Refactor: Match time tools pattern exactly for MCP compatibility

This commit is contained in:
2026-06-14 05:00:05 +00:00
parent b0f19810d4
commit a98903c048
+113 -170
View File
@@ -25,36 +25,12 @@ import json
import os import os
import sys import sys
import re import re
import logging
from http.server import HTTPServer, BaseHTTPRequestHandler from http.server import HTTPServer, BaseHTTPRequestHandler
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
import requests import requests
from requests.exceptions import RequestException
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("mcp-summary")
# MCP Server Configuration
API_KEY = os.environ.get("API_KEY", "").strip() API_KEY = os.environ.get("API_KEY", "").strip()
PORT = int(os.environ.get("PORT", "8080"))
# LLM Configuration
OPENAPI_URL = os.environ.get("OPENAPI_URL", "http://localhost:8080/v1")
OPENAPI_API_KEY = os.environ.get("OPENAPI_API_KEY", "")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
# Summarization Configuration
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "4000"))
OVERLAP = int(os.environ.get("OVERLAP", "200"))
TARGET_INTERMEDIATE_SUMMARY_LENGTH = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150"))
MAX_DIRECT_SUMMARY_LENGTH = int(os.environ.get("MAX_DIRECT_SUMMARY_LENGTH", "100"))
MAX_DIRECT_TEXT_LENGTH = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000"))
# LLM call timeout in seconds - increase for large documents
LLM_TIMEOUT = int(os.environ.get("LLM_TIMEOUT", "120"))
# Tool definitions
TOOLS_LIST: Dict[str, Any] = { TOOLS_LIST: Dict[str, Any] = {
"tools": [ "tools": [
{ {
@@ -69,7 +45,7 @@ TOOLS_LIST: Dict[str, Any] = {
}, },
"max_length": { "max_length": {
"type": "integer", "type": "integer",
"description": f"Maximum length of summary in words (default: {MAX_DIRECT_SUMMARY_LENGTH})" "description": "Maximum length of summary in words (default: 100)"
} }
}, },
"required": ["text"] "required": ["text"]
@@ -79,53 +55,73 @@ TOOLS_LIST: Dict[str, Any] = {
} }
def call_llm(messages: List[Dict], temperature: float = 0.3) -> str: def get_bearer_token(headers: Any) -> Optional[str]:
"""Make an OpenAPI-compatible LLM call with error handling.""" auth = (headers.get("Authorization") or "").strip()
url = f"{OPENAPI_URL}/chat/completions" if auth.startswith("Bearer "):
return auth[len("Bearer "):].strip()
return None
def require_auth(headers: Any) -> bool:
    """Check the request's Bearer token against the configured API key.

    Args:
        headers: Request headers object; it is handed unchanged to
            ``get_bearer_token`` to extract the token.

    Returns:
        True when access is allowed: either ``API_KEY`` is empty
        (authentication disabled) or the presented token matches it.

    Raises:
        PermissionError: If ``API_KEY`` is set and the token is missing
            or does not match.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import hmac

    # If API_KEY is not set, allow unauthenticated access
    if not API_KEY:
        return True

    token = get_bearer_token(headers)
    if not token:
        raise PermissionError("Missing or invalid API key")

    # Constant-time comparison: a plain `!=` short-circuits on the first
    # differing character, which leaks timing information about the key to
    # a remote caller. Compare as bytes because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8")):
        raise PermissionError("Missing or invalid API key")
    return True
def call_llm(text: str, system_prompt: str, max_tokens: int = 2000) -> str:
"""Make an OpenAPI-compatible LLM call."""
openapi_url = os.environ.get("OPENAPI_URL", "http://localhost:8080/v1")
openapi_api_key = os.environ.get("OPENAPI_API_KEY", "")
model_name = os.environ.get("MODEL_NAME", "gpt-4o")
timeout = int(os.environ.get("LLM_TIMEOUT", "120"))
url = f"{openapi_url}/chat/completions"
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {OPENAPI_API_KEY}" "Authorization": f"Bearer {openapi_api_key}"
} }
payload = { payload = {
"model": MODEL_NAME, "model": model_name,
"messages": messages, "messages": [
"temperature": temperature, {"role": "system", "content": system_prompt},
"max_tokens": 2000, {"role": "user", "content": text}
],
"temperature": 0.3,
"max_tokens": max_tokens,
"top_p": 0.9 "top_p": 0.9
} }
try: response = requests.post(url, headers=headers, json=payload, timeout=timeout)
logger.info(f"Calling LLM at {OPENAPI_URL} with model {MODEL_NAME}")
response = requests.post(url, headers=headers, json=payload, timeout=LLM_TIMEOUT)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
return data["choices"][0]["message"]["content"] return data["choices"][0]["message"]["content"]
except RequestException as e:
logger.error(f"LLM request failed: {e}")
raise RuntimeError(f"Failed to connect to LLM at {OPENAPI_URL}: {str(e)}")
except Exception as e:
logger.error(f"LLM call failed: {e}")
raise RuntimeError(f"LLM call failed: {str(e)}")
def chunk_text(text: str) -> list:
def chunk_text(text: str) -> List[str]:
"""Split text into chunks with overlap for summarization.""" """Split text into chunks with overlap for summarization."""
if len(text) <= CHUNK_SIZE: chunk_size = int(os.environ.get("CHUNK_SIZE", "4000"))
overlap = int(os.environ.get("OVERLAP", "200"))
if len(text) <= chunk_size:
return [text] return [text]
chunks = [] chunks = []
start = 0 start = 0
while start < len(text): while start < len(text):
end = min(start + CHUNK_SIZE, len(text)) end = min(start + chunk_size, len(text))
# Try to break at sentence/paragraph boundary # Try to break at sentence/paragraph boundary
break_point = end break_point = end
for marker in ["\n\n", "\n", ". ", "! ", "? "]: for marker in ["\n\n", "\n", ". ", "! ", "? "]:
pos = text.rfind(marker, start + CHUNK_SIZE // 2, end) pos = text.rfind(marker, start + chunk_size // 2, end)
if pos > start: if pos > start:
break_point = pos break_point = pos
break break
@@ -134,46 +130,84 @@ def chunk_text(text: str) -> List[str]:
if chunk.strip(): if chunk.strip():
chunks.append(chunk) chunks.append(chunk)
start = break_point - OVERLAP if break_point < len(text) else len(text) start = break_point - overlap if break_point < len(text) else len(text)
if start >= len(text): if start >= len(text):
break break
logger.info(f"Split text into {len(chunks)} chunks")
return chunks return chunks
def summarize_chunk(chunk: str, chunk_num: int, total_chunks: int) -> str: def summarize_document(text: str, max_length: int = 100) -> dict:
"""Summarize a single chunk of text.""" """
Main summarization function.
- If text is short, summarize directly
- If text is long, chunk and summarize each chunk, then synthesize
"""
original_length = len(text)
text = text.strip()
if not text:
raise ValueError("Empty text provided")
max_direct_length = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000"))
intermediate_length = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150"))
# Direct summarization for shorter texts
if len(text) <= max_direct_length:
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries. system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
You are processing chunk {chunk_num} of {total_chunks} from a larger document. Create a summary that:
- Is approximately {max_length} words
- Captures key points and important details
- Uses clear, professional language
- Preserves names, dates, and specific facts
Format as plain text without bullet points."""
user_prompt = f"""Summarize the following document:
{text}
Summary:"""
summary = call_llm(user_prompt, system_prompt)
return {
"summary": summary,
"original_length": original_length,
"method": "direct",
"chunks": 1
}
# Chunked summarization for longer texts
chunks = chunk_text(text)
chunk_summaries = []
for i, chunk in enumerate(chunks, 1):
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
You are processing chunk {i} of {len(chunks)} from a larger document.
Create a focused summary that: Create a focused summary that:
- Captures key points and important details - Captures key points and important details
- Is approximately {TARGET_INTERMEDIATE_SUMMARY_LENGTH} words - Is approximately {intermediate_length} words
- Can be combined with other chunk summaries - Can be combined with other chunk summaries
- Uses clear, professional language - Uses clear, professional language
- Preserves names, dates, and specific facts - Preserves names, dates, and specific facts
Respond as plain text without bullet points.""" Respond as plain text without bullet points."""
user_prompt = f"""Summarize this text (chunk {chunk_num} of {total_chunks}): user_prompt = f"""Summarize this text (chunk {i} of {len(chunks)}):
{text} {chunk}
Summary:""" Summary:"""
messages = [ chunk_summary = call_llm(user_prompt, system_prompt)
{"role": "system", "content": system_prompt}, chunk_summaries.append(chunk_summary)
{"role": "user", "content": user_prompt}
]
logger.info(f"Summarizing chunk {chunk_num}/{total_chunks}") # Synthesize into final summary
return call_llm(messages)
def synthesize_summaries(chunk_summaries: List[str]) -> str:
"""Synthesize multiple chunk summaries into a single final summary."""
combined = "\n\n".join(chunk_summaries) combined = "\n\n".join(chunk_summaries)
system_prompt = """You are a precise legal assistant creating executive-level summaries. system_prompt = """You are a precise legal assistant creating executive-level summaries.
@@ -194,71 +228,7 @@ Format as a single paragraph of plain text."""
Final summary:""" Final summary:"""
messages = [ final_summary = call_llm(user_prompt, system_prompt)
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
logger.info(f"Synthesizing {len(chunk_summaries)} chunk summaries")
return call_llm(messages)
def summarize_document(text: str, max_length: int = MAX_DIRECT_SUMMARY_LENGTH) -> Dict[str, Any]:
"""
Main summarization function.
- If text is short, summarize directly
- If text is long, chunk and summarize each chunk, then synthesize
"""
original_length = len(text)
text = text.strip()
if not text:
raise ValueError("Empty text provided")
logger.info(f"Summarizing text of {original_length} characters")
# Direct summarization for shorter texts
if len(text) <= MAX_DIRECT_TEXT_LENGTH:
system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries.
Create a summary that:
- Is approximately {max_length} words
- Captures key points and important details
- Uses clear, professional language
- Preserves names, dates, and specific facts
Format as plain text without bullet points."""
user_prompt = f"""Summarize the following document:
{text}
Summary:"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
summary = call_llm(messages)
return {
"summary": summary,
"original_length": original_length,
"method": "direct",
"chunks": 1
}
# Chunked summarization for longer texts
chunks = chunk_text(text)
chunk_summaries = []
for i, chunk in enumerate(chunks, 1):
chunk_summary = summarize_chunk(chunk, i, len(chunks))
chunk_summaries.append(chunk_summary)
final_summary = synthesize_summaries(chunk_summaries)
return { return {
"summary": final_summary, "summary": final_summary,
@@ -269,13 +239,11 @@ Summary:"""
class MCPSummaryHandler(BaseHTTPRequestHandler): class MCPSummaryHandler(BaseHTTPRequestHandler):
"""HTTP handler for MCP summary server."""
def log_message(self, format, *args): def log_message(self, format, *args):
logger.info(format % args) # Quiet logs by default
pass
def _send_json(self, status: int, payload: Any): def _send_json(self, status: int, payload: Any):
"""Send JSON response."""
body = json.dumps(payload, ensure_ascii=False).encode("utf-8") body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
self.send_response(status) self.send_response(status)
self.send_header("Content-Type", "application/json") self.send_header("Content-Type", "application/json")
@@ -283,31 +251,19 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
self.wfile.write(body) self.wfile.write(body)
def _auth_or_401(self) -> bool: def _auth_or_401(self):
"""Check authentication if API key is configured.""" try:
if not API_KEY: return require_auth(self.headers)
return True except PermissionError:
auth_header = self.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
self._send_json(401, {"error": "Missing or invalid API key"}) self._send_json(401, {"error": "Missing or invalid API key"})
return False return False
token = auth_header[len("Bearer "):].strip()
if token != API_KEY:
self._send_json(401, {"error": "Invalid API key"})
return False
return True
def do_GET(self): def do_GET(self):
"""Handle GET requests (health check).""" # Basic info endpoint (not required by MCP, but useful)
if self.path == "/": if self.path == "/":
self._send_json(200, { self._send_json(200, {
"service": "mcp-summary", "service": "mcp-summary",
"transport": "streamable-http", "transport": "streamable-http",
"model": MODEL_NAME,
"status": "running",
"docs": "Use POST / with MCP JSON-RPC (initialize, tools/list, tools/call)." "docs": "Use POST / with MCP JSON-RPC (initialize, tools/list, tools/call)."
}) })
return return
@@ -315,7 +271,7 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
self.send_error(404, "Not Found") self.send_error(404, "Not Found")
def do_POST(self): def do_POST(self):
"""Handle MCP JSON-RPC requests.""" # Streamable HTTP MCP endpoint
if self.path not in ("/", "/mcp"): if self.path not in ("/", "/mcp"):
self.send_error(404, "Not Found") self.send_error(404, "Not Found")
return return
@@ -323,7 +279,6 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
if not self._auth_or_401(): if not self._auth_or_401():
return return
# Parse request
length = int(self.headers.get("Content-Length", 0)) length = int(self.headers.get("Content-Length", 0))
if length == 0: if length == 0:
self._send_json(400, {"error": "Empty body"}) self._send_json(400, {"error": "Empty body"})
@@ -340,8 +295,6 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
params = req.get("params") or {} params = req.get("params") or {}
req_id = req.get("id") req_id = req.get("id")
logger.info(f"MCP request: method={method}, id={req_id}")
# MCP: initialize # MCP: initialize
if method == "initialize": if method == "initialize":
self._send_json(200, { self._send_json(200, {
@@ -373,7 +326,6 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
if method == "tools/call": if method == "tools/call":
tool_name = params.get("name") tool_name = params.get("name")
tool_args = params.get("arguments") or {} tool_args = params.get("arguments") or {}
try: try:
result = self._call_tool(tool_name, tool_args) result = self._call_tool(tool_name, tool_args)
self._send_json(200, { self._send_json(200, {
@@ -386,7 +338,6 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
} }
}) })
except Exception as e: except Exception as e:
logger.error(f"Tool call failed: {e}")
self._send_json(200, { self._send_json(200, {
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": req_id, "id": req_id,
@@ -401,30 +352,22 @@ class MCPSummaryHandler(BaseHTTPRequestHandler):
self._send_json(400, {"error": "Unknown method: " + str(method)}) self._send_json(400, {"error": "Unknown method: " + str(method)})
def _call_tool(self, name: str, args: Dict[str, Any]) -> Any: def _call_tool(self, name: str, args: Dict[str, Any]) -> Any:
"""Execute a tool call."""
if name == "summarize_document": if name == "summarize_document":
text = args.get("text") text = args.get("text")
if not text: if not text:
raise ValueError("Text parameter is required") raise ValueError("Text parameter is required")
max_length = args.get("max_length", MAX_DIRECT_SUMMARY_LENGTH) max_length = args.get("max_length", 100)
return summarize_document(text, max_length) return summarize_document(text, max_length)
raise ValueError(f"Unknown tool: {name}") raise ValueError(f"Unknown tool: {name}")
def main(): def main():
"""Start the MCP summary server.""" port = int(sys.argv[1]) if len(sys.argv) > 1 else int(os.environ.get("PORT", "8080"))
server = HTTPServer(("0.0.0.0", PORT), MCPSummaryHandler) server = HTTPServer(("0.0.0.0", port), MCPSummaryHandler)
mode = "auth enabled (Bearer)" if API_KEY else "no auth (API_KEY not set)" mode = "auth enabled (Bearer)" if API_KEY else "no auth (API_KEY not set)"
print(f"MCP Summary Server listening on 0.0.0.0:{port} [{mode}]")
print(f"MCP Summary Server listening on 0.0.0.0:{PORT} [{mode}]")
print(f" - Model: {MODEL_NAME}")
print(f" - LLM URL: {OPENAPI_URL}")
print(f" - Chunk size: {CHUNK_SIZE} characters")
print(f" - Max direct text: {MAX_DIRECT_TEXT_LENGTH} characters")
print(f" - LLM timeout: {LLM_TIMEOUT} seconds")
try: try:
server.serve_forever() server.serve_forever()
except KeyboardInterrupt: except KeyboardInterrupt: