Refactor: Match time tools pattern exactly for MCP compatibility

This commit is contained in:
2026-06-14 05:00:05 +00:00
parent b0f19810d4
commit a98903c048
+117 -174
View File
@@ -25,36 +25,12 @@ import json
import os
import sys
import re
import logging
from http.server import HTTPServer, BaseHTTPRequestHandler
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import requests
from requests.exceptions import RequestException
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("mcp-summary")
# MCP Server Configuration
API_KEY = os.environ.get("API_KEY", "").strip()
PORT = int(os.environ.get("PORT", "8080"))
# LLM Configuration
OPENAPI_URL = os.environ.get("OPENAPI_URL", "http://localhost:8080/v1")
OPENAPI_API_KEY = os.environ.get("OPENAPI_API_KEY", "")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
# Summarization Configuration
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "4000"))
OVERLAP = int(os.environ.get("OVERLAP", "200"))
TARGET_INTERMEDIATE_SUMMARY_LENGTH = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150"))
MAX_DIRECT_SUMMARY_LENGTH = int(os.environ.get("MAX_DIRECT_SUMMARY_LENGTH", "100"))
MAX_DIRECT_TEXT_LENGTH = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000"))
# LLM call timeout in seconds - increase for large documents
LLM_TIMEOUT = int(os.environ.get("LLM_TIMEOUT", "120"))
# Tool definitions
TOOLS_LIST: Dict[str, Any] = {
"tools": [
{
@@ -69,7 +45,7 @@ TOOLS_LIST: Dict[str, Any] = {
},
"max_length": {
"type": "integer",
"description": f"Maximum length of summary in words (default: {MAX_DIRECT_SUMMARY_LENGTH})"
"description": "Maximum length of summary in words (default: 100)"
}
},
"required": ["text"]
@@ -79,53 +55,73 @@ TOOLS_LIST: Dict[str, Any] = {
}
def call_llm(messages: List[Dict], temperature: float = 0.3) -> str:
"""Make an OpenAPI-compatible LLM call with error handling."""
url = f"{OPENAPI_URL}/chat/completions"
def get_bearer_token(headers: Any) -> Optional[str]:
    """Extract the token from an ``Authorization: Bearer <token>`` header.

    Args:
        headers: Any object exposing ``get(name)`` (a dict, or the message
            object an HTTP handler stores its request headers in).

    Returns:
        The token with surrounding whitespace removed, or ``None`` when the
        header is absent, uses another scheme, or carries no token.
    """
    auth = (headers.get("Authorization") or "").strip()
    # HTTP authentication schemes are case-insensitive, so "bearer" and
    # "BEARER" must be accepted exactly like "Bearer".
    scheme, _, credentials = auth.partition(" ")
    if scheme.lower() != "bearer":
        return None
    return credentials.strip() or None
def require_auth(headers: Any) -> bool:
    """Check the request's bearer token against the module-level ``API_KEY``.

    Args:
        headers: Request headers; passed unchanged to ``get_bearer_token``.

    Returns:
        ``True`` when access is allowed (including when ``API_KEY`` is empty,
        which disables authentication).

    Raises:
        PermissionError: If ``API_KEY`` is set and the token is missing or
            does not match.
    """
    # Local import keeps this block self-contained; hmac is standard library.
    import hmac

    # If API_KEY is not set, allow unauthenticated access
    if not API_KEY:
        return True
    token = get_bearer_token(headers)
    if not token:
        raise PermissionError("Missing or invalid API key")
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of a guessed key were correct. Bytes are compared
    # because compare_digest rejects non-ASCII str; "surrogatepass" keeps the
    # encode from failing on an environment value holding undecodable bytes.
    supplied = token.encode("utf-8", "surrogatepass")
    expected = API_KEY.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, expected):
        raise PermissionError("Missing or invalid API key")
    return True
def call_llm(text: str, system_prompt: str, max_tokens: int = 2000) -> str:
    """Send one system + user message pair to a chat-completions endpoint.

    Configuration is read from the environment on every call:
    ``OPENAPI_URL`` (base URL, default ``http://localhost:8080/v1``),
    ``OPENAPI_API_KEY`` (default empty), ``MODEL_NAME`` (default ``gpt-4o``)
    and ``LLM_TIMEOUT`` (seconds, default 120).

    Args:
        text: Content of the user message.
        system_prompt: Content of the system message.
        max_tokens: Value sent as ``max_tokens`` in the request payload.

    Returns:
        The ``content`` of the first choice's message in the response.

    Raises:
        requests.exceptions.RequestException: On connection failure, timeout,
            or a non-success HTTP status (via ``raise_for_status``).
        RuntimeError: If the response JSON lacks
            ``choices[0].message.content``.
    """
    openapi_url = os.environ.get("OPENAPI_URL", "http://localhost:8080/v1")
    openapi_api_key = os.environ.get("OPENAPI_API_KEY", "")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o")
    timeout = int(os.environ.get("LLM_TIMEOUT", "120"))

    # Tolerate a trailing slash in the configured base URL so the request
    # does not go to ".../v1//chat/completions".
    url = f"{openapi_url.rstrip('/')}/chat/completions"

    headers = {"Content-Type": "application/json"}
    # Only send credentials when a key is configured; an empty "Bearer "
    # value is a malformed Authorization header.
    if openapi_api_key:
        headers["Authorization"] = f"Bearer {openapi_api_key}"

    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        "temperature": 0.3,
        "max_tokens": max_tokens,
        "top_p": 0.9,
    }

    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    data = response.json()

    try:
        return data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        # A bare KeyError('choices') tells the MCP client nothing useful;
        # name the real problem and keep the original error as the cause.
        raise RuntimeError(
            "LLM response did not contain choices[0].message.content"
        ) from exc
def chunk_text(text: str) -> List[str]:
def chunk_text(text: str) -> list:
"""Split text into chunks with overlap for summarization."""
if len(text) <= CHUNK_SIZE:
chunk_size = int(os.environ.get("CHUNK_SIZE", "4000"))
overlap = int(os.environ.get("OVERLAP", "200"))
if len(text) <= chunk_size:
return [text]
chunks = []
start = 0
while start < len(text):
end = min(start + CHUNK_SIZE, len(text))
end = min(start + chunk_size, len(text))
# Try to break at sentence/paragraph boundary
break_point = end
for marker in ["\n\n", "\n", ". ", "! ", "? "]:
pos = text.rfind(marker, start + CHUNK_SIZE // 2, end)
pos = text.rfind(marker, start + chunk_size // 2, end)
if pos > start:
break_point = pos
break
@@ -134,46 +130,84 @@ def chunk_text(text: str) -> List[str]:
if chunk.strip():
chunks.append(chunk)
start = break_point - OVERLAP if break_point < len(text) else len(text)
start = break_point - overlap if break_point < len(text) else len(text)
if start >= len(text):
break
logger.info(f"Split text into {len(chunks)} chunks")
return chunks
def summarize_chunk(chunk: str, chunk_num: int, total_chunks: int) -> str:
"""Summarize a single chunk of text."""
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
def summarize_document(text: str, max_length: int = 100) -> dict:
"""
Main summarization function.
You are processing chunk {chunk_num} of {total_chunks} from a larger document.
- If text is short, summarize directly
- If text is long, chunk and summarize each chunk, then synthesize
"""
original_length = len(text)
text = text.strip()
if not text:
raise ValueError("Empty text provided")
max_direct_length = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000"))
intermediate_length = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150"))
# Direct summarization for shorter texts
if len(text) <= max_direct_length:
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
Create a summary that:
- Is approximately {max_length} words
- Captures key points and important details
- Uses clear, professional language
- Preserves names, dates, and specific facts
Format as plain text without bullet points."""
user_prompt = f"""Summarize the following document:
{text}
Summary:"""
summary = call_llm(user_prompt, system_prompt)
return {
"summary": summary,
"original_length": original_length,
"method": "direct",
"chunks": 1
}
# Chunked summarization for longer texts
chunks = chunk_text(text)
chunk_summaries = []
for i, chunk in enumerate(chunks, 1):
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
You are processing chunk {i} of {len(chunks)} from a larger document.
Create a focused summary that:
- Captures key points and important details
- Is approximately {TARGET_INTERMEDIATE_SUMMARY_LENGTH} words
- Is approximately {intermediate_length} words
- Can be combined with other chunk summaries
- Uses clear, professional language
- Preserves names, dates, and specific facts
Respond as plain text without bullet points."""
user_prompt = f"""Summarize this text (chunk {chunk_num} of {total_chunks}):
user_prompt = f"""Summarize this text (chunk {i} of {len(chunks)}):
{text}
{chunk}
Summary:"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
chunk_summary = call_llm(user_prompt, system_prompt)
chunk_summaries.append(chunk_summary)
logger.info(f"Summarizing chunk {chunk_num}/{total_chunks}")
return call_llm(messages)
def synthesize_summaries(chunk_summaries: List[str]) -> str:
"""Synthesize multiple chunk summaries into a single final summary."""
# Synthesize into final summary
combined = "\n\n".join(chunk_summaries)
system_prompt = """You are a precise legal assistant creating executive-level summaries.
@@ -194,71 +228,7 @@ Format as a single paragraph of plain text."""
Final summary:"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
logger.info(f"Synthesizing {len(chunk_summaries)} chunk summaries")
return call_llm(messages)
def summarize_document(text: str, max_length: int = MAX_DIRECT_SUMMARY_LENGTH) -> Dict[str, Any]:
"""
Main summarization function.
- If text is short, summarize directly
- If text is long, chunk and summarize each chunk, then synthesize
"""
original_length = len(text)
text = text.strip()
if not text:
raise ValueError("Empty text provided")
logger.info(f"Summarizing text of {original_length} characters")
# Direct summarization for shorter texts
if len(text) <= MAX_DIRECT_TEXT_LENGTH:
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
Create a summary that:
- Is approximately {max_length} words
- Captures key points and important details
- Uses clear, professional language
- Preserves names, dates, and specific facts
Format as plain text without bullet points."""
user_prompt = f"""Summarize the following document:
{text}
Summary:"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
summary = call_llm(messages)
return {
"summary": summary,
"original_length": original_length,
"method": "direct",
"chunks": 1
}
# Chunked summarization for longer texts
chunks = chunk_text(text)
chunk_summaries = []
for i, chunk in enumerate(chunks, 1):
chunk_summary = summarize_chunk(chunk, i, len(chunks))
chunk_summaries.append(chunk_summary)
final_summary = synthesize_summaries(chunk_summaries)
final_summary = call_llm(user_prompt, system_prompt)
return {
"summary": final_summary,
@@ -269,13 +239,11 @@ Summary:"""
class MCPSummaryHandler(BaseHTTPRequestHandler):
"""HTTP handler for MCP summary server."""
def log_message(self, format, *args):
    """Discard the per-request log line.

    Overrides the ``BaseHTTPRequestHandler`` logging hook so that handling a
    request writes nothing. The parameter names mirror the base-class
    signature and are intentionally unused.
    """
    # Quiet logs by default
    return None
def _send_json(self, status: int, payload: Any):
"""Send JSON response."""
body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
self.send_response(status)
self.send_header("Content-Type", "application/json")
@@ -283,31 +251,19 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(body)
def _auth_or_401(self) -> bool:
    """Authenticate the current request, answering 401 on failure.

    Delegates the check to ``require_auth`` with this request's headers.

    Returns:
        ``True`` when the request may proceed. ``False`` when authentication
        failed; in that case a 401 JSON error body has already been sent and
        the caller should simply return.
    """
    try:
        return require_auth(self.headers)
    except PermissionError:
        self._send_json(401, {"error": "Missing or invalid API key"})
        return False
def do_GET(self):
"""Handle GET requests (health check)."""
# Basic info endpoint (not required by MCP, but useful)
if self.path == "/":
self._send_json(200, {
"service": "mcp-summary",
"transport": "streamable-http",
"model": MODEL_NAME,
"status": "running",
"docs": "Use POST / with MCP JSON-RPC (initialize, tools/list, tools/call)."
})
return
@@ -315,7 +271,7 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
self.send_error(404, "Not Found")
def do_POST(self):
"""Handle MCP JSON-RPC requests."""
# Streamable HTTP MCP endpoint
if self.path not in ("/", "/mcp"):
self.send_error(404, "Not Found")
return
@@ -323,7 +279,6 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
if not self._auth_or_401():
return
# Parse request
length = int(self.headers.get("Content-Length", 0))
if length == 0:
self._send_json(400, {"error": "Empty body"})
@@ -340,8 +295,6 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
params = req.get("params") or {}
req_id = req.get("id")
logger.info(f"MCP request: method={method}, id={req_id}")
# MCP: initialize
if method == "initialize":
self._send_json(200, {
@@ -373,7 +326,6 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
if method == "tools/call":
tool_name = params.get("name")
tool_args = params.get("arguments") or {}
try:
result = self._call_tool(tool_name, tool_args)
self._send_json(200, {
@@ -386,7 +338,6 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
}
})
except Exception as e:
logger.error(f"Tool call failed: {e}")
self._send_json(200, {
"jsonrpc": "2.0",
"id": req_id,
@@ -401,30 +352,22 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
self._send_json(400, {"error": "Unknown method: " + str(method)})
def _call_tool(self, name: str, args: Dict[str, Any]) -> Any:
"""Execute a tool call."""
if name == "summarize_document":
text = args.get("text")
if not text:
raise ValueError("Text parameter is required")
max_length = args.get("max_length", MAX_DIRECT_SUMMARY_LENGTH)
max_length = args.get("max_length", 100)
return summarize_document(text, max_length)
raise ValueError(f"Unknown tool: {name}")
def main():
"""Start the MCP summary server."""
server = HTTPServer(("0.0.0.0", PORT), MCPSummaryHandler)
port = int(sys.argv[1]) if len(sys.argv) > 1 else int(os.environ.get("PORT", "8080"))
server = HTTPServer(("0.0.0.0", port), MCPSummaryHandler)
mode = "auth enabled (Bearer)" if API_KEY else "no auth (API_KEY not set)"
print(f"MCP Summary Server listening on 0.0.0.0:{PORT} [{mode}]")
print(f" - Model: {MODEL_NAME}")
print(f" - LLM URL: {OPENAPI_URL}")
print(f" - Chunk size: {CHUNK_SIZE} characters")
print(f" - Max direct text: {MAX_DIRECT_TEXT_LENGTH} characters")
print(f" - LLM timeout: {LLM_TIMEOUT} seconds")
print(f"MCP Summary Server listening on 0.0.0.0:{port} [{mode}]")
try:
server.serve_forever()
except KeyboardInterrupt: