Fix: robust MCP handler with proper logging and error handling

This commit is contained in:
2026-06-14 17:13:10 +00:00
parent 63617550a1
commit 208e4195b0
+139 -104
View File
@@ -24,10 +24,19 @@ Auth:
import json import json
import os import os
import sys import sys
import logging
from http.server import HTTPServer, BaseHTTPRequestHandler from http.server import HTTPServer, BaseHTTPRequestHandler
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import requests import requests
# Configure logging: the root logger emits INFO and above to stdout, each
# record prefixed with a timestamp and its level name.
# NOTE(review): stdout (rather than the default stderr) is presumably chosen so
# a container runtime captures the stream - confirm.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
stream=sys.stdout,
)
# Module-level logger shared by the LLM client code and the HTTP handler.
logger = logging.getLogger("mcp-summary")
API_KEY = os.environ.get("API_KEY", "").strip() API_KEY = os.environ.get("API_KEY", "").strip()
# Tool definitions # Tool definitions
@@ -81,13 +90,13 @@ def call_llm(text: str, system_prompt: str, max_tokens: int = 2000) -> str:
openapi_api_key = os.environ.get("OPENAPI_API_KEY", "") openapi_api_key = os.environ.get("OPENAPI_API_KEY", "")
model_name = os.environ.get("MODEL_NAME", "gpt-4o") model_name = os.environ.get("MODEL_NAME", "gpt-4o")
timeout = int(os.environ.get("LLM_TIMEOUT", "120")) timeout = int(os.environ.get("LLM_TIMEOUT", "120"))
url = f"{openapi_url}/chat/completions" url = f"{openapi_url}/chat/completions"
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {openapi_api_key}" "Authorization": f"Bearer {openapi_api_key}"
} }
payload = { payload = {
"model": model_name, "model": model_name,
"messages": [ "messages": [
@@ -98,10 +107,11 @@ def call_llm(text: str, system_prompt: str, max_tokens: int = 2000) -> str:
"max_tokens": max_tokens, "max_tokens": max_tokens,
"top_p": 0.9 "top_p": 0.9
} }
logger.info(f"Calling LLM: {url} model={model_name}")
response = requests.post(url, headers=headers, json=payload, timeout=timeout) response = requests.post(url, headers=headers, json=payload, timeout=timeout)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
return data["choices"][0]["message"]["content"] return data["choices"][0]["message"]["content"]
@@ -110,16 +120,16 @@ def chunk_text(text: str) -> list:
"""Split text into chunks with overlap for summarization.""" """Split text into chunks with overlap for summarization."""
chunk_size = int(os.environ.get("CHUNK_SIZE", "4000")) chunk_size = int(os.environ.get("CHUNK_SIZE", "4000"))
overlap = int(os.environ.get("OVERLAP", "200")) overlap = int(os.environ.get("OVERLAP", "200"))
if len(text) <= chunk_size: if len(text) <= chunk_size:
return [text] return [text]
chunks = [] chunks = []
start = 0 start = 0
while start < len(text): while start < len(text):
end = min(start + chunk_size, len(text)) end = min(start + chunk_size, len(text))
# Try to break at sentence/paragraph boundary # Try to break at sentence/paragraph boundary
break_point = end break_point = end
for marker in ["\n\n", "\n", ". ", "! ", "? "]: for marker in ["\n\n", "\n", ". ", "! ", "? "]:
@@ -127,34 +137,34 @@ def chunk_text(text: str) -> list:
if pos > start: if pos > start:
break_point = pos break_point = pos
break break
chunk = text[start:break_point] chunk = text[start:break_point]
if chunk.strip(): if chunk.strip():
chunks.append(chunk) chunks.append(chunk)
start = break_point - overlap if break_point < len(text) else len(text) start = break_point - overlap if break_point < len(text) else len(text)
if start >= len(text): if start >= len(text):
break break
return chunks return chunks
def summarize_document(text: str, max_length: int = 100) -> dict: def summarize_document(text: str, max_length: int = 100) -> dict:
""" """
Main summarization function. Main summarization function.
- If text is short, summarize directly - If text is short, summarize directly
- If text is long, chunk and summarize each chunk, then synthesize - If text is long, chunk and summarize each chunk, then synthesize
""" """
original_length = len(text) original_length = len(text)
text = text.strip() text = text.strip()
if not text: if not text:
raise ValueError("Empty text provided") raise ValueError("Empty text provided")
max_direct_length = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000")) max_direct_length = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000"))
intermediate_length = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150")) intermediate_length = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150"))
# Direct summarization for shorter texts # Direct summarization for shorter texts
if len(text) <= max_direct_length: if len(text) <= max_direct_length:
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries. system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
@@ -166,25 +176,25 @@ Create a summary that:
- Preserves names, dates, and specific facts - Preserves names, dates, and specific facts
Format as plain text without bullet points.""" Format as plain text without bullet points."""
user_prompt = f"""Summarize the following document: user_prompt = f"""Summarize the following document:
{text} {text}
Summary:""" Summary:"""
summary = call_llm(user_prompt, system_prompt) summary = call_llm(user_prompt, system_prompt)
return { return {
"summary": summary, "summary": summary,
"original_length": original_length, "original_length": original_length,
"method": "direct", "method": "direct",
"chunks": 1 "chunks": 1
} }
# Chunked summarization for longer texts # Chunked summarization for longer texts
chunks = chunk_text(text) chunks = chunk_text(text)
chunk_summaries = [] chunk_summaries = []
for i, chunk in enumerate(chunks, 1): for i, chunk in enumerate(chunks, 1):
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries. system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
@@ -199,19 +209,19 @@ Create a focused summary that:
- Preserves names, dates, and specific facts - Preserves names, dates, and specific facts
Respond as plain text without bullet points.""" Respond as plain text without bullet points."""
user_prompt = f"""Summarize this text (chunk {i} of {len(chunks)}): user_prompt = f"""Summarize this text (chunk {i} of {len(chunks)}):
{chunk} {chunk}
Summary:""" Summary:"""
chunk_summary = call_llm(user_prompt, system_prompt) chunk_summary = call_llm(user_prompt, system_prompt)
chunk_summaries.append(chunk_summary) chunk_summaries.append(chunk_summary)
# Synthesize into final summary # Synthesize into final summary
combined = "\n\n".join(chunk_summaries) combined = "\n\n".join(chunk_summaries)
system_prompt = """You are a precise legal assistant creating executive-level summaries. system_prompt = """You are a precise legal assistant creating executive-level summaries.
Synthesize the provided partial summaries into a single, cohesive summary that: Synthesize the provided partial summaries into a single, cohesive summary that:
@@ -223,15 +233,15 @@ Synthesize the provided partial summaries into a single, cohesive summary that:
- Preserves all critical information - Preserves all critical information
Format as a single paragraph of plain text.""" Format as a single paragraph of plain text."""
user_prompt = f"""Synthesize these partial summaries into one cohesive summary: user_prompt = f"""Synthesize these partial summaries into one cohesive summary:
{combined} {combined}
Final summary:""" Final summary:"""
final_summary = call_llm(user_prompt, system_prompt) final_summary = call_llm(user_prompt, system_prompt)
return { return {
"summary": final_summary, "summary": final_summary,
"original_length": original_length, "original_length": original_length,
@@ -242,10 +252,10 @@ Final summary:"""
class MCPSummaryHandler(BaseHTTPRequestHandler): class MCPSummaryHandler(BaseHTTPRequestHandler):
"""HTTP handler for MCP summary server.""" """HTTP handler for MCP summary server."""
def log_message(self, format, *args):
    """Route http.server's per-request log lines through the module logger.

    ``BaseHTTPRequestHandler`` calls this hook with a %-style template and
    its arguments (for example from ``log_request`` and ``log_error``).

    Args:
        format: %-style message template supplied by ``http.server``. The
            name shadows the builtin only because it must match the
            signature of the overridden base-class method.
        *args: Values to interpolate into ``format``.
    """
    # Pass the template and arguments through unformatted. ``logging``
    # interpolates lazily and reports a bad template through its own error
    # handling, whereas an eager ``format % args`` would raise in the middle
    # of request handling.
    logger.info(format, *args)
def _send_json(self, status: int, payload: Any): def _send_json(self, status: int, payload: Any):
"""Send JSON response.""" """Send JSON response."""
@@ -266,96 +276,116 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
def do_GET(self):
    """Handle GET requests (health check).

    ``GET /`` answers 200 with a small JSON service descriptor; every other
    path answers 404. Any unexpected failure is logged and turned into a
    best-effort 500 so the client is not left without a response.
    """
    try:
        # ``self.path`` is the raw request target and includes any query
        # string, so strip it before matching; otherwise a probe such as
        # ``/?ts=123`` would be answered with 404.
        route = self.path.partition("?")[0]
        if route == "/":
            self._send_json(200, {
                "service": "mcp-summary",
                "transport": "streamable-http",
                "docs": "Use POST / with MCP JSON-RPC (initialize, tools/list, tools/call).",
            })
            return

        self.send_error(404, "Not Found")
    except Exception as e:
        logger.exception("GET error: %s", e)
        # Ensure we still send something. This can itself fail (for example
        # when the client has already disconnected), so it stays best-effort,
        # but the failure is recorded instead of being silently dropped.
        try:
            self.send_error(500, "Internal Server Error")
        except Exception:
            logger.debug("Could not deliver 500 response for GET", exc_info=True)
def do_POST(self):
    """Handle MCP JSON-RPC requests sent to ``/`` or ``/mcp``.

    Flow: path check (404) -> bearer auth via ``_auth_or_401`` -> body
    validation (400 on a bad ``Content-Length``, an empty body, invalid JSON
    or a non-object payload) -> dispatch on the JSON-RPC ``method``:

    * ``initialize``      -> protocol version, capabilities and server info
    * ``tools/list``      -> ``TOOLS_LIST``
    * ``tools/call``      -> result of ``_call_tool`` wrapped as text content;
      tool failures are returned as a JSON-RPC error (code -32000) with
      HTTP 200
    * ``notifications/*`` sent without an ``id`` -> 202 with no body
    * anything else       -> 400 ``Unknown method``

    Any unexpected exception is logged and answered with a best-effort 500.
    """
    try:
        if self.path not in ("/", "/mcp"):
            self.send_error(404, "Not Found")
            return

        if not self._auth_or_401():
            return

        # A malformed Content-Length is a client error, not a server fault.
        # A negative value must also be rejected: ``rfile.read(-1)`` reads
        # until EOF and would block on a keep-alive connection.
        try:
            length = int(self.headers.get("Content-Length", 0))
        except (TypeError, ValueError):
            length = -1
        if length < 0:
            self._send_json(400, {"error": "Invalid Content-Length"})
            return
        if length == 0:
            self._send_json(400, {"error": "Empty body"})
            return

        raw = self.rfile.read(length)

        try:
            req = json.loads(raw)
        except ValueError:
            # Covers json.JSONDecodeError as well as the UnicodeDecodeError
            # raised for bodies that are not valid UTF-8 (both subclass
            # ValueError).
            self._send_json(400, {"error": "Invalid JSON"})
            return

        # Only a single JSON object is supported; arrays and scalars would
        # otherwise crash on ``.get`` below.
        if not isinstance(req, dict):
            self._send_json(400, {"error": "Request body must be a JSON object"})
            return

        method = req.get("method")
        params = req.get("params") or {}
        req_id = req.get("id")

        logger.info("MCP request: method=%s, id=%s", method, req_id)

        # JSON-RPC notifications carry no id and expect no JSON-RPC reply.
        # MCP clients send e.g. ``notifications/initialized`` right after
        # ``initialize``; acknowledge with 202 and an empty body.
        if (
            "id" not in req
            and isinstance(method, str)
            and method.startswith("notifications/")
        ):
            self.send_response(202)
            self.send_header("Content-Length", "0")
            self.end_headers()
            return

        # MCP: initialize
        if method == "initialize":
            self._send_json(200, {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "protocolVersion": "2025-11-25",
                    "capabilities": {
                        "tools": {}
                    },
                    "serverInfo": {
                        "name": "mcp-summary",
                        "version": "1.0.0"
                    }
                }
            })
            return

        # MCP: tools/list
        if method == "tools/list":
            self._send_json(200, {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": TOOLS_LIST
            })
            return

        # MCP: tools/call
        if method == "tools/call":
            try:
                # Shape errors are reported through the same JSON-RPC error
                # path as tool failures rather than surfacing as a 500.
                if not isinstance(params, dict):
                    raise ValueError("params must be an object")
                tool_name = params.get("name")
                tool_args = params.get("arguments") or {}
                if not isinstance(tool_args, dict):
                    raise ValueError("arguments must be an object")

                result = self._call_tool(tool_name, tool_args)
                self._send_json(200, {
                    "jsonrpc": "2.0",
                    "id": req_id,
                    "result": {
                        "content": [
                            {"type": "text", "text": json.dumps(result, ensure_ascii=False)}
                        ]
                    }
                })
            except Exception as e:
                logger.exception("Tool call error: %s", e)
                self._send_json(200, {
                    "jsonrpc": "2.0",
                    "id": req_id,
                    "error": {
                        "code": -32000,
                        "message": str(e)
                    }
                })
            return

        # Unknown method
        self._send_json(400, {"error": "Unknown method: " + str(method)})
    except Exception as e:
        logger.exception("POST error: %s", e)
        # Fallback response to avoid a silent drop. It can fail when the
        # connection is gone or a response was already started, so it stays
        # best-effort, but the failure is recorded.
        try:
            self.send_error(500, "Internal Server Error")
        except Exception:
            logger.debug("Could not deliver 500 response for POST", exc_info=True)
def _call_tool(self, name: str, args: Dict[str, Any]) -> Any: def _call_tool(self, name: str, args: Dict[str, Any]) -> Any:
"""Execute a tool call.""" """Execute a tool call."""
@@ -363,7 +393,7 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
text = args.get("text") text = args.get("text")
if not text: if not text:
raise ValueError("Text parameter is required") raise ValueError("Text parameter is required")
max_length = args.get("max_length", 100) max_length = args.get("max_length", 100)
return summarize_document(text, max_length) return summarize_document(text, max_length)
@@ -373,13 +403,18 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
def main():
    """Start the MCP summary server and serve until interrupted.

    The port is taken from the first command-line argument when present,
    otherwise from the ``PORT`` environment variable (default ``8080``).
    The effective configuration is logged before the server binds to
    ``0.0.0.0``.
    """
    port = int(sys.argv[1]) if len(sys.argv) > 1 else int(os.environ.get("PORT", "8080"))

    auth_mode = "Bearer (API_KEY set)" if API_KEY else "none (API_KEY not set)"
    logger.info("Starting MCP Summary Server on 0.0.0.0:%s", port)
    logger.info("Auth mode: %s", auth_mode)
    logger.info("LLM URL: %s", os.environ.get("OPENAPI_URL", "http://localhost:8080/v1"))
    logger.info("Model: %s", os.environ.get("MODEL_NAME", "gpt-4o"))

    server = HTTPServer(("0.0.0.0", port), MCPSummaryHandler)

    try:
        logger.info("MCP Summary Server listening on 0.0.0.0:%s", port)
        server.serve_forever()
    except KeyboardInterrupt:
        logger.info("Shutting down...")
    finally:
        # Release the listening socket on every exit path, not only Ctrl-C.
        server.server_close()