use axum::{ extract::{ ws::{Message}, State, WebSocketUpgrade, }, response::Html, routing::{get, post}, Router, Json, }; use futures::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use std::{ net::SocketAddr, str::FromStr, sync::Arc, }; use tower_http::cors::{Any, CorsLayer}; use tracing::info; use crate::docx_tools::DocxToolsProvider; /// Application state shared across HTTP handlers pub struct AppState { pub provider: DocxToolsProvider, } /// Request to call a tool #[derive(Debug, Deserialize)] pub struct ToolCallRequest { pub name: String, pub arguments: serde_json::Value, } /// Response from a tool call #[derive(Debug, Serialize)] pub struct ToolCallResponse { pub success: bool, pub content: serde_json::Value, pub error: Option, } /// Response with list of tools #[derive(Debug, Serialize)] pub struct ListToolsResponse { pub success: bool, pub tools: Vec, } /// Start the HTTP server pub async fn start_http_server(addr: &str, provider: DocxToolsProvider) -> anyhow::Result<()> { let state = Arc::new(AppState { provider }); let app = Router::new() .with_state(state.clone()) // Serve HTML interface .route("/", get(index_handler)) .route("/api/tools", get(list_tools_handler)) .route("/api/call", post(call_tool_handler)) .route("/ws", get(ws_handler)) // CORS policy - allow all origins on LAN .layer( CorsLayer::new() .allow_origin(Any) .allow_methods([axum::http::Method::GET, axum::http::Method::POST]) ); let addr = SocketAddr::from_str(addr).unwrap_or_else(|_| { info!("Invalid address format, using default 0.0.0.0:3000"); "0.0.0.0:3000".parse().unwrap() }); info!("Starting HTTP server on {}", addr); let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, app).await?; Ok(()) } /// Serve the HTML interface async fn index_handler() -> Html { Html(include_str!("../assets/html_interface.html").to_string()) } /// List available tools async fn list_tools_handler(State(state): State>) -> Json { let tools = state.provider.list_tools().await; let tool_list: Vec = tools.iter().map(|t| { serde_json::json!({ "name": t.name, "description": t.description, "input_schema": t.input_schema }) }).collect(); Json(ListToolsResponse { success: true, tools: tool_list, }) } /// Call a tool via HTTP POST async fn call_tool_handler( State(state): State>, Json(request): Json, ) -> Json { let response = state.provider.call_tool(&request.name, request.arguments).await; // Convert response to JSON let content = if let Some(content) = response.content.first() { match content { mcp_core::types::ToolResponseContent::Text(text) => { serde_json::from_str(&text.text).unwrap_or_else(|_| { serde_json::json!({"text": text.text.clone()}) }) }, mcp_core::types::ToolResponseContent::Image(image) => { serde_json::json!({ "data": image.data, "mimeType": image.mime_type }) }, _ => serde_json::json!({}), } } else { serde_json::json!({}) }; Json(ToolCallResponse { success: response.is_error.unwrap_or(false) == false, content, error: response.is_error.unwrap_or(false).then(|| "Tool call failed".to_string()), }) } /// WebSocket handler for real-time communication async fn ws_handler( ws: WebSocketUpgrade, State(state): State> ) -> axum::response::Response { ws.on_upgrade(move |socket| async move { let provider = state.provider.clone(); let mut ws = socket; // Handle WebSocket messages while let Some(msg) = ws.recv().await { let msg = match msg { Ok(msg) => msg, Err(_) => continue, }; let text = match msg { Message::Text(text) => text.to_string(), _ => continue, }; // Parse request let request: ToolCallRequest = match serde_json::from_str(&text) { Ok(req) => req, Err(e) => { let error_response = ToolCallResponse { success: false, content: serde_json::json!({}), error: Some(format!("Parse error: {}", e)), }; let _ = ws.send(Message::Text( serde_json::to_string(&error_response).unwrap_or("{}".to_string()) )).await; continue; } }; // Call tool let response = provider.call_tool(&request.name, request.arguments).await; // Convert response to JSON let content = if let Some(content) = response.content.first() { match content { mcp_core::types::ToolResponseContent::Text(text) => { serde_json::from_str(&text.text).unwrap_or_else(|_| { serde_json::json!({"text": text.text.clone()}) }) }, mcp_core::types::ToolResponseContent::Image(image) => { serde_json::json!({ "data": image.data, "mimeType": image.mime_type }) }, } } else { serde_json::json!({}) }; let ws_response = ToolCallResponse { success: response.is_error.unwrap_or(false) == false, content, error: response.is_error.unwrap_or(false).then(|| "Tool call failed".to_string()), }; let _ = ws.send(Message::Text( serde_json::to_string(&ws_response).unwrap_or("{}".to_string()) )).await; } }) }