Fix: robust MCP handler with proper logging and error handling
This commit is contained in:
+139
-104
@@ -24,10 +24,19 @@ Auth:
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import logging
|
||||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||||
|
stream=sys.stdout,
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("mcp-summary")
|
||||||
|
|
||||||
API_KEY = os.environ.get("API_KEY", "").strip()
|
API_KEY = os.environ.get("API_KEY", "").strip()
|
||||||
|
|
||||||
# Tool definitions
|
# Tool definitions
|
||||||
@@ -81,13 +90,13 @@ def call_llm(text: str, system_prompt: str, max_tokens: int = 2000) -> str:
|
|||||||
openapi_api_key = os.environ.get("OPENAPI_API_KEY", "")
|
openapi_api_key = os.environ.get("OPENAPI_API_KEY", "")
|
||||||
model_name = os.environ.get("MODEL_NAME", "gpt-4o")
|
model_name = os.environ.get("MODEL_NAME", "gpt-4o")
|
||||||
timeout = int(os.environ.get("LLM_TIMEOUT", "120"))
|
timeout = int(os.environ.get("LLM_TIMEOUT", "120"))
|
||||||
|
|
||||||
url = f"{openapi_url}/chat/completions"
|
url = f"{openapi_url}/chat/completions"
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {openapi_api_key}"
|
"Authorization": f"Bearer {openapi_api_key}"
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -98,10 +107,11 @@ def call_llm(text: str, system_prompt: str, max_tokens: int = 2000) -> str:
|
|||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"top_p": 0.9
|
"top_p": 0.9
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.info(f"Calling LLM: {url} model={model_name}")
|
||||||
response = requests.post(url, headers=headers, json=payload, timeout=timeout)
|
response = requests.post(url, headers=headers, json=payload, timeout=timeout)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
return data["choices"][0]["message"]["content"]
|
return data["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
@@ -110,16 +120,16 @@ def chunk_text(text: str) -> list:
|
|||||||
"""Split text into chunks with overlap for summarization."""
|
"""Split text into chunks with overlap for summarization."""
|
||||||
chunk_size = int(os.environ.get("CHUNK_SIZE", "4000"))
|
chunk_size = int(os.environ.get("CHUNK_SIZE", "4000"))
|
||||||
overlap = int(os.environ.get("OVERLAP", "200"))
|
overlap = int(os.environ.get("OVERLAP", "200"))
|
||||||
|
|
||||||
if len(text) <= chunk_size:
|
if len(text) <= chunk_size:
|
||||||
return [text]
|
return [text]
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
start = 0
|
start = 0
|
||||||
|
|
||||||
while start < len(text):
|
while start < len(text):
|
||||||
end = min(start + chunk_size, len(text))
|
end = min(start + chunk_size, len(text))
|
||||||
|
|
||||||
# Try to break at sentence/paragraph boundary
|
# Try to break at sentence/paragraph boundary
|
||||||
break_point = end
|
break_point = end
|
||||||
for marker in ["\n\n", "\n", ". ", "! ", "? "]:
|
for marker in ["\n\n", "\n", ". ", "! ", "? "]:
|
||||||
@@ -127,34 +137,34 @@ def chunk_text(text: str) -> list:
|
|||||||
if pos > start:
|
if pos > start:
|
||||||
break_point = pos
|
break_point = pos
|
||||||
break
|
break
|
||||||
|
|
||||||
chunk = text[start:break_point]
|
chunk = text[start:break_point]
|
||||||
if chunk.strip():
|
if chunk.strip():
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
start = break_point - overlap if break_point < len(text) else len(text)
|
start = break_point - overlap if break_point < len(text) else len(text)
|
||||||
if start >= len(text):
|
if start >= len(text):
|
||||||
break
|
break
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
def summarize_document(text: str, max_length: int = 100) -> dict:
|
def summarize_document(text: str, max_length: int = 100) -> dict:
|
||||||
"""
|
"""
|
||||||
Main summarization function.
|
Main summarization function.
|
||||||
|
|
||||||
- If text is short, summarize directly
|
- If text is short, summarize directly
|
||||||
- If text is long, chunk and summarize each chunk, then synthesize
|
- If text is long, chunk and summarize each chunk, then synthesize
|
||||||
"""
|
"""
|
||||||
original_length = len(text)
|
original_length = len(text)
|
||||||
|
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
if not text:
|
if not text:
|
||||||
raise ValueError("Empty text provided")
|
raise ValueError("Empty text provided")
|
||||||
|
|
||||||
max_direct_length = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000"))
|
max_direct_length = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000"))
|
||||||
intermediate_length = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150"))
|
intermediate_length = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150"))
|
||||||
|
|
||||||
# Direct summarization for shorter texts
|
# Direct summarization for shorter texts
|
||||||
if len(text) <= max_direct_length:
|
if len(text) <= max_direct_length:
|
||||||
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
|
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
|
||||||
@@ -166,25 +176,25 @@ Create a summary that:
|
|||||||
- Preserves names, dates, and specific facts
|
- Preserves names, dates, and specific facts
|
||||||
|
|
||||||
Format as plain text without bullet points."""
|
Format as plain text without bullet points."""
|
||||||
|
|
||||||
user_prompt = f"""Summarize the following document:
|
user_prompt = f"""Summarize the following document:
|
||||||
|
|
||||||
{text}
|
{text}
|
||||||
|
|
||||||
Summary:"""
|
Summary:"""
|
||||||
|
|
||||||
summary = call_llm(user_prompt, system_prompt)
|
summary = call_llm(user_prompt, system_prompt)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"summary": summary,
|
"summary": summary,
|
||||||
"original_length": original_length,
|
"original_length": original_length,
|
||||||
"method": "direct",
|
"method": "direct",
|
||||||
"chunks": 1
|
"chunks": 1
|
||||||
}
|
}
|
||||||
|
|
||||||
# Chunked summarization for longer texts
|
# Chunked summarization for longer texts
|
||||||
chunks = chunk_text(text)
|
chunks = chunk_text(text)
|
||||||
|
|
||||||
chunk_summaries = []
|
chunk_summaries = []
|
||||||
for i, chunk in enumerate(chunks, 1):
|
for i, chunk in enumerate(chunks, 1):
|
||||||
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
|
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
|
||||||
@@ -199,19 +209,19 @@ Create a focused summary that:
|
|||||||
- Preserves names, dates, and specific facts
|
- Preserves names, dates, and specific facts
|
||||||
|
|
||||||
Respond as plain text without bullet points."""
|
Respond as plain text without bullet points."""
|
||||||
|
|
||||||
user_prompt = f"""Summarize this text (chunk {i} of {len(chunks)}):
|
user_prompt = f"""Summarize this text (chunk {i} of {len(chunks)}):
|
||||||
|
|
||||||
{chunk}
|
{chunk}
|
||||||
|
|
||||||
Summary:"""
|
Summary:"""
|
||||||
|
|
||||||
chunk_summary = call_llm(user_prompt, system_prompt)
|
chunk_summary = call_llm(user_prompt, system_prompt)
|
||||||
chunk_summaries.append(chunk_summary)
|
chunk_summaries.append(chunk_summary)
|
||||||
|
|
||||||
# Synthesize into final summary
|
# Synthesize into final summary
|
||||||
combined = "\n\n".join(chunk_summaries)
|
combined = "\n\n".join(chunk_summaries)
|
||||||
|
|
||||||
system_prompt = """You are a precise legal assistant creating executive-level summaries.
|
system_prompt = """You are a precise legal assistant creating executive-level summaries.
|
||||||
|
|
||||||
Synthesize the provided partial summaries into a single, cohesive summary that:
|
Synthesize the provided partial summaries into a single, cohesive summary that:
|
||||||
@@ -223,15 +233,15 @@ Synthesize the provided partial summaries into a single, cohesive summary that:
|
|||||||
- Preserves all critical information
|
- Preserves all critical information
|
||||||
|
|
||||||
Format as a single paragraph of plain text."""
|
Format as a single paragraph of plain text."""
|
||||||
|
|
||||||
user_prompt = f"""Synthesize these partial summaries into one cohesive summary:
|
user_prompt = f"""Synthesize these partial summaries into one cohesive summary:
|
||||||
|
|
||||||
{combined}
|
{combined}
|
||||||
|
|
||||||
Final summary:"""
|
Final summary:"""
|
||||||
|
|
||||||
final_summary = call_llm(user_prompt, system_prompt)
|
final_summary = call_llm(user_prompt, system_prompt)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"summary": final_summary,
|
"summary": final_summary,
|
||||||
"original_length": original_length,
|
"original_length": original_length,
|
||||||
@@ -242,10 +252,10 @@ Final summary:"""
|
|||||||
|
|
||||||
class MCPSummaryHandler(BaseHTTPRequestHandler):
|
class MCPSummaryHandler(BaseHTTPRequestHandler):
|
||||||
"""HTTP handler for MCP summary server."""
|
"""HTTP handler for MCP summary server."""
|
||||||
|
|
||||||
def log_message(self, format, *args):
|
def log_message(self, format, *args):
|
||||||
# Quiet logs by default
|
# Use our logger instead of default stderr logging
|
||||||
pass
|
logger.info(format % args)
|
||||||
|
|
||||||
def _send_json(self, status: int, payload: Any):
|
def _send_json(self, status: int, payload: Any):
|
||||||
"""Send JSON response."""
|
"""Send JSON response."""
|
||||||
@@ -266,96 +276,116 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
"""Handle GET requests (health check)."""
|
"""Handle GET requests (health check)."""
|
||||||
if self.path == "/":
|
try:
|
||||||
self._send_json(200, {
|
if self.path == "/":
|
||||||
"service": "mcp-summary",
|
self._send_json(200, {
|
||||||
"transport": "streamable-http",
|
"service": "mcp-summary",
|
||||||
"docs": "Use POST / with MCP JSON-RPC (initialize, tools/list, tools/call)."
|
"transport": "streamable-http",
|
||||||
})
|
"docs": "Use POST / with MCP JSON-RPC (initialize, tools/list, tools/call)."
|
||||||
return
|
})
|
||||||
|
return
|
||||||
|
|
||||||
self.send_error(404, "Not Found")
|
self.send_error(404, "Not Found")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"GET error: {e}", exc_info=True)
|
||||||
|
# Ensure we still send something
|
||||||
|
try:
|
||||||
|
self.send_error(500, "Internal Server Error")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
"""Handle MCP JSON-RPC requests."""
|
"""Handle MCP JSON-RPC requests."""
|
||||||
if self.path not in ("/", "/mcp"):
|
|
||||||
self.send_error(404, "Not Found")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self._auth_or_401():
|
|
||||||
return
|
|
||||||
|
|
||||||
length = int(self.headers.get("Content-Length", 0))
|
|
||||||
if length == 0:
|
|
||||||
self._send_json(400, {"error": "Empty body"})
|
|
||||||
return
|
|
||||||
|
|
||||||
raw = self.rfile.read(length)
|
|
||||||
try:
|
try:
|
||||||
req = json.loads(raw)
|
if self.path not in ("/", "/mcp"):
|
||||||
except json.JSONDecodeError:
|
self.send_error(404, "Not Found")
|
||||||
self._send_json(400, {"error": "Invalid JSON"})
|
return
|
||||||
return
|
|
||||||
|
|
||||||
method = req.get("method")
|
if not self._auth_or_401():
|
||||||
params = req.get("params") or {}
|
return
|
||||||
req_id = req.get("id")
|
|
||||||
|
|
||||||
# MCP: initialize
|
length = int(self.headers.get("Content-Length", 0))
|
||||||
if method == "initialize":
|
if length == 0:
|
||||||
self._send_json(200, {
|
self._send_json(400, {"error": "Empty body"})
|
||||||
"jsonrpc": "2.0",
|
return
|
||||||
"id": req_id,
|
|
||||||
"result": {
|
|
||||||
"protocolVersion": "2025-11-25",
|
|
||||||
"capabilities": {
|
|
||||||
"tools": {}
|
|
||||||
},
|
|
||||||
"serverInfo": {
|
|
||||||
"name": "mcp-summary",
|
|
||||||
"version": "1.0.0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return
|
|
||||||
|
|
||||||
# MCP: tools/list
|
raw = self.rfile.read(length)
|
||||||
if method == "tools/list":
|
|
||||||
self._send_json(200, {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": req_id,
|
|
||||||
"result": TOOLS_LIST
|
|
||||||
})
|
|
||||||
return
|
|
||||||
|
|
||||||
# MCP: tools/call
|
|
||||||
if method == "tools/call":
|
|
||||||
tool_name = params.get("name")
|
|
||||||
tool_args = params.get("arguments") or {}
|
|
||||||
try:
|
try:
|
||||||
result = self._call_tool(tool_name, tool_args)
|
req = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self._send_json(400, {"error": "Invalid JSON"})
|
||||||
|
return
|
||||||
|
|
||||||
|
method = req.get("method")
|
||||||
|
params = req.get("params") or {}
|
||||||
|
req_id = req.get("id")
|
||||||
|
|
||||||
|
logger.info(f"MCP request: method={method}, id={req_id}")
|
||||||
|
|
||||||
|
# MCP: initialize
|
||||||
|
if method == "initialize":
|
||||||
self._send_json(200, {
|
self._send_json(200, {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": req_id,
|
"id": req_id,
|
||||||
"result": {
|
"result": {
|
||||||
"content": [
|
"protocolVersion": "2025-11-25",
|
||||||
{"type": "text", "text": json.dumps(result, ensure_ascii=False)}
|
"capabilities": {
|
||||||
]
|
"tools": {}
|
||||||
|
},
|
||||||
|
"serverInfo": {
|
||||||
|
"name": "mcp-summary",
|
||||||
|
"version": "1.0.0"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
except Exception as e:
|
return
|
||||||
|
|
||||||
|
# MCP: tools/list
|
||||||
|
if method == "tools/list":
|
||||||
self._send_json(200, {
|
self._send_json(200, {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": req_id,
|
"id": req_id,
|
||||||
"error": {
|
"result": TOOLS_LIST
|
||||||
"code": -32000,
|
|
||||||
"message": str(e)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
||||||
# Unknown method
|
# MCP: tools/call
|
||||||
self._send_json(400, {"error": "Unknown method: " + str(method)})
|
if method == "tools/call":
|
||||||
|
tool_name = params.get("name")
|
||||||
|
tool_args = params.get("arguments") or {}
|
||||||
|
try:
|
||||||
|
result = self._call_tool(tool_name, tool_args)
|
||||||
|
self._send_json(200, {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": json.dumps(result, ensure_ascii=False)}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Tool call error: {e}", exc_info=True)
|
||||||
|
self._send_json(200, {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"error": {
|
||||||
|
"code": -32000,
|
||||||
|
"message": str(e)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
# Unknown method
|
||||||
|
self._send_json(400, {"error": "Unknown method: " + str(method)})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"POST error: {e}", exc_info=True)
|
||||||
|
# Fallback response to avoid silent drop
|
||||||
|
try:
|
||||||
|
self.send_error(500, "Internal Server Error")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _call_tool(self, name: str, args: Dict[str, Any]) -> Any:
|
def _call_tool(self, name: str, args: Dict[str, Any]) -> Any:
|
||||||
"""Execute a tool call."""
|
"""Execute a tool call."""
|
||||||
@@ -363,7 +393,7 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
|
|||||||
text = args.get("text")
|
text = args.get("text")
|
||||||
if not text:
|
if not text:
|
||||||
raise ValueError("Text parameter is required")
|
raise ValueError("Text parameter is required")
|
||||||
|
|
||||||
max_length = args.get("max_length", 100)
|
max_length = args.get("max_length", 100)
|
||||||
return summarize_document(text, max_length)
|
return summarize_document(text, max_length)
|
||||||
|
|
||||||
@@ -373,13 +403,18 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
|
|||||||
def main():
|
def main():
|
||||||
"""Start the MCP summary server."""
|
"""Start the MCP summary server."""
|
||||||
port = int(sys.argv[1]) if len(sys.argv) > 1 else int(os.environ.get("PORT", "8080"))
|
port = int(sys.argv[1]) if len(sys.argv) > 1 else int(os.environ.get("PORT", "8080"))
|
||||||
|
|
||||||
|
logger.info(f"Starting MCP Summary Server on 0.0.0.0:{port}")
|
||||||
|
logger.info(f"Auth mode: {'Bearer (API_KEY set)' if API_KEY else 'none (API_KEY not set)'}")
|
||||||
|
logger.info(f"LLM URL: {os.environ.get('OPENAPI_URL', 'http://localhost:8080/v1')}")
|
||||||
|
logger.info(f"Model: {os.environ.get('MODEL_NAME', 'gpt-4o')}")
|
||||||
|
|
||||||
server = HTTPServer(("0.0.0.0", port), MCPSummaryHandler)
|
server = HTTPServer(("0.0.0.0", port), MCPSummaryHandler)
|
||||||
mode = "auth enabled (Bearer)" if API_KEY else "no auth (API_KEY not set)"
|
|
||||||
print(f"MCP Summary Server listening on 0.0.0.0:{port} [{mode}]")
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"MCP Summary Server listening on 0.0.0.0:{port}")
|
||||||
server.serve_forever()
|
server.serve_forever()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nShutting down...")
|
logger.info("Shutting down...")
|
||||||
server.server_close()
|
server.server_close()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user