diff --git a/mcp_summary_server.py b/mcp_summary_server.py index 29b1b9e..e422fd6 100644 --- a/mcp_summary_server.py +++ b/mcp_summary_server.py @@ -25,36 +25,12 @@ import json import os import sys import re -import logging from http.server import HTTPServer, BaseHTTPRequestHandler -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import requests -from requests.exceptions import RequestException -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger("mcp-summary") - -# MCP Server Configuration API_KEY = os.environ.get("API_KEY", "").strip() -PORT = int(os.environ.get("PORT", "8080")) -# LLM Configuration -OPENAPI_URL = os.environ.get("OPENAPI_URL", "http://localhost:8080/v1") -OPENAPI_API_KEY = os.environ.get("OPENAPI_API_KEY", "") -MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o") - -# Summarization Configuration -CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "4000")) -OVERLAP = int(os.environ.get("OVERLAP", "200")) -TARGET_INTERMEDIATE_SUMMARY_LENGTH = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150")) -MAX_DIRECT_SUMMARY_LENGTH = int(os.environ.get("MAX_DIRECT_SUMMARY_LENGTH", "100")) -MAX_DIRECT_TEXT_LENGTH = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000")) - -# LLM call timeout in seconds - increase for large documents -LLM_TIMEOUT = int(os.environ.get("LLM_TIMEOUT", "120")) - -# Tool definitions TOOLS_LIST: Dict[str, Any] = { "tools": [ { @@ -69,7 +45,7 @@ TOOLS_LIST: Dict[str, Any] = { }, "max_length": { "type": "integer", - "description": f"Maximum length of summary in words (default: {MAX_DIRECT_SUMMARY_LENGTH})" + "description": "Maximum length of summary in words (default: 100)" } }, "required": ["text"] @@ -79,53 +55,73 @@ TOOLS_LIST: Dict[str, Any] = { } -def call_llm(messages: List[Dict], temperature: float = 0.3) -> str: - """Make an OpenAPI-compatible LLM call with error handling.""" - url = f"{OPENAPI_URL}/chat/completions" +def get_bearer_token(headers: Any) -> Optional[str]: + auth = (headers.get("Authorization") or "").strip() + if auth.startswith("Bearer "): + return auth[len("Bearer "):].strip() + return None + + +def require_auth(headers: Any) -> bool: + # If API_KEY is not set, allow unauthenticated access + if not API_KEY: + return True + + token = get_bearer_token(headers) + if not token or token != API_KEY: + raise PermissionError("Missing or invalid API key") + return True + + +def call_llm(text: str, system_prompt: str, max_tokens: int = 2000) -> str: + """Make an OpenAPI-compatible LLM call.""" + openapi_url = os.environ.get("OPENAPI_URL", "http://localhost:8080/v1") + openapi_api_key = os.environ.get("OPENAPI_API_KEY", "") + model_name = os.environ.get("MODEL_NAME", "gpt-4o") + timeout = int(os.environ.get("LLM_TIMEOUT", "120")) + + url = f"{openapi_url}/chat/completions" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {OPENAPI_API_KEY}" + "Authorization": f"Bearer {openapi_api_key}" } payload = { - "model": MODEL_NAME, - "messages": messages, - "temperature": temperature, - "max_tokens": 2000, + "model": model_name, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": text} + ], + "temperature": 0.3, + "max_tokens": max_tokens, "top_p": 0.9 } - try: - logger.info(f"Calling LLM at {OPENAPI_URL} with model {MODEL_NAME}") - response = requests.post(url, headers=headers, json=payload, timeout=LLM_TIMEOUT) - response.raise_for_status() - - data = response.json() - return data["choices"][0]["message"]["content"] + response = requests.post(url, headers=headers, json=payload, timeout=timeout) + response.raise_for_status() - except RequestException as e: - logger.error(f"LLM request failed: {e}") - raise RuntimeError(f"Failed to connect to LLM at {OPENAPI_URL}: {str(e)}") - except Exception as e: - logger.error(f"LLM call failed: {e}") - raise RuntimeError(f"LLM call failed: {str(e)}") + data = response.json() + return data["choices"][0]["message"]["content"] -def chunk_text(text: str) -> List[str]: +def chunk_text(text: str) -> list: """Split text into chunks with overlap for summarization.""" - if len(text) <= CHUNK_SIZE: + chunk_size = int(os.environ.get("CHUNK_SIZE", "4000")) + overlap = int(os.environ.get("OVERLAP", "200")) + + if len(text) <= chunk_size: return [text] chunks = [] start = 0 while start < len(text): - end = min(start + CHUNK_SIZE, len(text)) + end = min(start + chunk_size, len(text)) # Try to break at sentence/paragraph boundary break_point = end for marker in ["\n\n", "\n", ". ", "! ", "? "]: - pos = text.rfind(marker, start + CHUNK_SIZE // 2, end) + pos = text.rfind(marker, start + chunk_size // 2, end) if pos > start: break_point = pos break @@ -134,46 +130,84 @@ def chunk_text(text: str) -> List[str]: if chunk.strip(): chunks.append(chunk) - start = break_point - OVERLAP if break_point < len(text) else len(text) + start = break_point - overlap if break_point < len(text) else len(text) if start >= len(text): break - logger.info(f"Split text into {len(chunks)} chunks") return chunks -def summarize_chunk(chunk: str, chunk_num: int, total_chunks: int) -> str: - """Summarize a single chunk of text.""" - system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries. +def summarize_document(text: str, max_length: int = 100) -> dict: + """ + Main summarization function. + + - If text is short, summarize directly + - If text is long, chunk and summarize each chunk, then synthesize + """ + original_length = len(text) + + text = text.strip() + if not text: + raise ValueError("Empty text provided") + + max_direct_length = int(os.environ.get("MAX_DIRECT_TEXT_LENGTH", "8000")) + intermediate_length = int(os.environ.get("TARGET_INTERMEDIATE_SUMMARY_LENGTH", "150")) + + # Direct summarization for shorter texts + if len(text) <= max_direct_length: + system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries. -You are processing chunk {chunk_num} of {total_chunks} from a larger document. +Create a summary that: +- Is approximately {max_length} words +- Captures key points and important details +- Uses clear, professional language +- Preserves names, dates, and specific facts + +Format as plain text without bullet points.""" + + user_prompt = f"""Summarize the following document: + +{text} + +Summary:""" + + summary = call_llm(user_prompt, system_prompt) + + return { + "summary": summary, + "original_length": original_length, + "method": "direct", + "chunks": 1 + } + + # Chunked summarization for longer texts + chunks = chunk_text(text) + + chunk_summaries = [] + for i, chunk in enumerate(chunks, 1): + system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries. + +You are processing chunk {i} of {len(chunks)} from a larger document. Create a focused summary that: - Captures key points and important details -- Is approximately {TARGET_INTERMEDIATE_SUMMARY_LENGTH} words +- Is approximately {intermediate_length} words - Can be combined with other chunk summaries - Uses clear, professional language - Preserves names, dates, and specific facts Respond as plain text without bullet points.""" - - user_prompt = f"""Summarize this text (chunk {chunk_num} of {total_chunks}): + + user_prompt = f"""Summarize this text (chunk {i} of {len(chunks)}): -{text} +{chunk} Summary:""" + + chunk_summary = call_llm(user_prompt, system_prompt) + chunk_summaries.append(chunk_summary) - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ] - - logger.info(f"Summarizing chunk {chunk_num}/{total_chunks}") - return call_llm(messages) - - -def synthesize_summaries(chunk_summaries: List[str]) -> str: - """Synthesize multiple chunk summaries into a single final summary.""" + # Synthesize into final summary combined = "\n\n".join(chunk_summaries) system_prompt = """You are a precise legal assistant creating executive-level summaries. @@ -194,71 +228,7 @@ Format as a single paragraph of plain text.""" Final summary:""" - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ] - - logger.info(f"Synthesizing {len(chunk_summaries)} chunk summaries") - return call_llm(messages) - - -def summarize_document(text: str, max_length: int = MAX_DIRECT_SUMMARY_LENGTH) -> Dict[str, Any]: - """ - Main summarization function. - - - If text is short, summarize directly - - If text is long, chunk and summarize each chunk, then synthesize - """ - original_length = len(text) - - text = text.strip() - if not text: - raise ValueError("Empty text provided") - - logger.info(f"Summarizing text of {original_length} characters") - - # Direct summarization for shorter texts - if len(text) <= MAX_DIRECT_TEXT_LENGTH: - system_prompt = f"""You are a precise legal assistant creating concise, accurate summaries. - -Create a summary that: -- Is approximately {max_length} words -- Captures key points and important details -- Uses clear, professional language -- Preserves names, dates, and specific facts - -Format as plain text without bullet points.""" - - user_prompt = f"""Summarize the following document: - -{text} - -Summary:""" - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ] - - summary = call_llm(messages) - - return { - "summary": summary, - "original_length": original_length, - "method": "direct", - "chunks": 1 - } - - # Chunked summarization for longer texts - chunks = chunk_text(text) - - chunk_summaries = [] - for i, chunk in enumerate(chunks, 1): - chunk_summary = summarize_chunk(chunk, i, len(chunks)) - chunk_summaries.append(chunk_summary) - - final_summary = synthesize_summaries(chunk_summaries) + final_summary = call_llm(user_prompt, system_prompt) return { "summary": final_summary, @@ -269,79 +239,62 @@ Summary:""" class MCPSummaryHandler(BaseHTTPRequestHandler): - """HTTP handler for MCP summary server.""" - def log_message(self, format, *args): - logger.info(format % args) - + # Quiet logs by default + pass + def _send_json(self, status: int, payload: Any): - """Send JSON response.""" body = json.dumps(payload, ensure_ascii=False).encode("utf-8") self.send_response(status) self.send_header("Content-Type", "application/json") self.send_header("Content-Length", str(len(body))) self.end_headers() self.wfile.write(body) - - def _auth_or_401(self) -> bool: - """Check authentication if API key is configured.""" - if not API_KEY: - return True - - auth_header = self.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): + + def _auth_or_401(self): + try: + return require_auth(self.headers) + except PermissionError: self._send_json(401, {"error": "Missing or invalid API key"}) return False - - token = auth_header[len("Bearer "):].strip() - if token != API_KEY: - self._send_json(401, {"error": "Invalid API key"}) - return False - - return True - + def do_GET(self): - """Handle GET requests (health check).""" + # Basic info endpoint (not required by MCP, but useful) if self.path == "/": self._send_json(200, { "service": "mcp-summary", "transport": "streamable-http", - "model": MODEL_NAME, - "status": "running", "docs": "Use POST / with MCP JSON-RPC (initialize, tools/list, tools/call)." }) return - + self.send_error(404, "Not Found") - + def do_POST(self): - """Handle MCP JSON-RPC requests.""" + # Streamable HTTP MCP endpoint if self.path not in ("/", "/mcp"): self.send_error(404, "Not Found") return - + if not self._auth_or_401(): return - - # Parse request + length = int(self.headers.get("Content-Length", 0)) if length == 0: self._send_json(400, {"error": "Empty body"}) return - + raw = self.rfile.read(length) try: req = json.loads(raw) except json.JSONDecodeError: self._send_json(400, {"error": "Invalid JSON"}) return - + method = req.get("method") params = req.get("params") or {} req_id = req.get("id") - - logger.info(f"MCP request: method={method}, id={req_id}") - + # MCP: initialize if method == "initialize": self._send_json(200, { @@ -359,7 +312,7 @@ class MCPSummaryHandler(BaseHTTPRequestHandler): } }) return - + # MCP: tools/list if method == "tools/list": self._send_json(200, { @@ -368,12 +321,11 @@ class MCPSummaryHandler(BaseHTTPRequestHandler): "result": TOOLS_LIST }) return - + # MCP: tools/call if method == "tools/call": tool_name = params.get("name") tool_args = params.get("arguments") or {} - try: result = self._call_tool(tool_name, tool_args) self._send_json(200, { @@ -386,7 +338,6 @@ class MCPSummaryHandler(BaseHTTPRequestHandler): } }) except Exception as e: - logger.error(f"Tool call failed: {e}") self._send_json(200, { "jsonrpc": "2.0", "id": req_id, @@ -396,35 +347,27 @@ class MCPSummaryHandler(BaseHTTPRequestHandler): } }) return - + # Unknown method self._send_json(400, {"error": "Unknown method: " + str(method)}) - + def _call_tool(self, name: str, args: Dict[str, Any]) -> Any: - """Execute a tool call.""" if name == "summarize_document": text = args.get("text") if not text: raise ValueError("Text parameter is required") - max_length = args.get("max_length", MAX_DIRECT_SUMMARY_LENGTH) + max_length = args.get("max_length", 100) return summarize_document(text, max_length) - + raise ValueError(f"Unknown tool: {name}") def main(): - """Start the MCP summary server.""" - server = HTTPServer(("0.0.0.0", PORT), MCPSummaryHandler) + port = int(sys.argv[1]) if len(sys.argv) > 1 else int(os.environ.get("PORT", "8080")) + server = HTTPServer(("0.0.0.0", port), MCPSummaryHandler) mode = "auth enabled (Bearer)" if API_KEY else "no auth (API_KEY not set)" - - print(f"MCP Summary Server listening on 0.0.0.0:{PORT} [{mode}]") - print(f" - Model: {MODEL_NAME}") - print(f" - LLM URL: {OPENAPI_URL}") - print(f" - Chunk size: {CHUNK_SIZE} characters") - print(f" - Max direct text: {MAX_DIRECT_TEXT_LENGTH} characters") - print(f" - LLM timeout: {LLM_TIMEOUT} seconds") - + print(f"MCP Summary Server listening on 0.0.0.0:{port} [{mode}]") try: server.serve_forever() except KeyboardInterrupt: