//! Server-Sent Events (SSE) streaming handlers for chat responses //! //! This module provides real-time streaming of LLM responses using SSE, //! enabling token-by-token delivery to the client for a responsive chat experience. use axum::{ extract::{Query, State}, response::{ sse::{Event, KeepAlive, Sse}, IntoResponse, }, Json, }; use futures::stream::Stream; use log::{error, info, trace}; use serde::{Deserialize, Serialize}; use std::{convert::Infallible, sync::Arc, time::Duration}; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use uuid::Uuid; use crate::llm::{LLMProvider, OpenAIClient}; use crate::shared::state::AppState; /// Request payload for streaming chat #[derive(Debug, Deserialize)] pub struct StreamChatRequest { /// Session ID pub session_id: String, /// User message content pub message: String, /// Optional system prompt override pub system_prompt: Option, /// Optional model name override pub model: Option, /// Optional bot ID pub bot_id: Option, } /// Query parameters for SSE connection #[derive(Debug, Deserialize)] pub struct StreamQuery { pub session_id: String, } /// SSE event types #[derive(Debug, Clone, Serialize)] #[serde(tag = "type", content = "data")] pub enum StreamEvent { /// Token chunk Token { content: String }, /// Thinking/reasoning content (for models that support it) Thinking { content: String }, /// Tool call request ToolCall { name: String, arguments: String }, /// Error occurred Error { message: String }, /// Stream completed Done { total_tokens: Option }, /// Stream started Start { session_id: String, model: String }, /// Metadata update Meta { key: String, value: String }, } impl StreamEvent { pub fn to_sse_event(&self) -> Result { let event_type = match self { StreamEvent::Token { .. } => "token", StreamEvent::Thinking { .. } => "thinking", StreamEvent::ToolCall { .. } => "tool_call", StreamEvent::Error { .. } => "error", StreamEvent::Done { .. } => "done", StreamEvent::Start { .. } => "start", StreamEvent::Meta { .. } => "meta", }; let data = serde_json::to_string(self)?; Ok(Event::default().event(event_type).data(data)) } } /// Stream a chat response using SSE pub async fn stream_chat_response( State(state): State, Json(payload): Json, ) -> Sse>> { let (tx, rx) = mpsc::channel::>(100); // Clone state for the spawned task let state_clone = state.clone(); let session_id = payload.session_id.clone(); let message = payload.message.clone(); let model = payload.model.clone(); let system_prompt = payload.system_prompt.clone(); let bot_id = payload .bot_id .as_ref() .and_then(|id| Uuid::parse_str(id).ok()); // Spawn the streaming task tokio::spawn(async move { if let Err(e) = handle_stream_generation( state_clone, tx.clone(), session_id, message, model, system_prompt, bot_id, ) .await { error!("Stream generation error: {}", e); let error_event = StreamEvent::Error { message: e.to_string(), }; if let Ok(event) = error_event.to_sse_event() { let _ = tx.send(Ok(event)).await; } } // Send done event let done_event = StreamEvent::Done { total_tokens: None }; if let Ok(event) = done_event.to_sse_event() { let _ = tx.send(Ok(event)).await; } }); Sse::new(ReceiverStream::new(rx)).keep_alive( KeepAlive::new() .interval(Duration::from_secs(15)) .text("keep-alive"), ) } /// Handle the actual stream generation async fn handle_stream_generation( state: AppState, tx: mpsc::Sender>, session_id: String, message: String, model_override: Option, system_prompt_override: Option, bot_id: Option, ) -> Result<(), Box> { // Get LLM configuration let (llm_url, llm_model, llm_key) = get_llm_config(&state, bot_id).await?; let model = model_override.unwrap_or(llm_model); // Send start event let start_event = StreamEvent::Start { session_id: session_id.clone(), model: model.clone(), }; if let Ok(event) = start_event.to_sse_event() { tx.send(Ok(event)).await?; } info!( "Starting SSE stream for session: {}, model: {}", session_id, model ); // Build messages let system_prompt = system_prompt_override.unwrap_or_else(|| { "You are a helpful AI assistant powered by General Bots.".to_string() }); let messages = OpenAIClient::build_messages(&system_prompt, "", &[("user".to_string(), message)]); // Create LLM client let client = OpenAIClient::new(llm_key.clone(), Some(llm_url.clone())); // Create channel for token streaming let (token_tx, mut token_rx) = mpsc::channel::(100); // Spawn LLM streaming task let client_clone = client; let messages_clone = messages.clone(); let model_clone = model.clone(); let key_clone = llm_key.clone(); tokio::spawn(async move { if let Err(e) = client_clone .generate_stream(&"", &messages_clone, token_tx, &model_clone, &key_clone) .await { error!("LLM stream error: {}", e); } }); // Forward tokens as SSE events while let Some(token) = token_rx.recv().await { trace!("Streaming token: {}", token); let token_event = StreamEvent::Token { content: token.clone(), }; if let Ok(event) = token_event.to_sse_event() { if tx.send(Ok(event)).await.is_err() { // Client disconnected info!("Client disconnected from SSE stream"); break; } } } Ok(()) } /// Get LLM configuration for a bot async fn get_llm_config( state: &AppState, bot_id: Option, ) -> Result<(String, String, String), Box> { use diesel::prelude::*; let mut conn = state .conn .get() .map_err(|e| format!("Failed to acquire database connection: {}", e))?; let target_bot_id = bot_id.unwrap_or(Uuid::nil()); #[derive(QueryableByName)] struct ConfigRow { #[diesel(sql_type = diesel::sql_types::Text)] config_key: String, #[diesel(sql_type = diesel::sql_types::Text)] config_value: String, } let configs: Vec = diesel::sql_query( "SELECT config_key, config_value FROM bot_configuration \ WHERE bot_id = $1 AND config_key IN ('llm-url', 'llm-model', 'llm-key')", ) .bind::(target_bot_id) .load(&mut conn) .unwrap_or_default(); let mut llm_url = "http://localhost:8081".to_string(); let mut llm_model = "default".to_string(); let mut llm_key = "none".to_string(); for config in configs { match config.config_key.as_str() { "llm-url" => llm_url = config.config_value, "llm-model" => llm_model = config.config_value, "llm-key" => llm_key = config.config_value, _ => {} } } Ok((llm_url, llm_model, llm_key)) } /// Create routes for streaming endpoints pub fn routes() -> axum::Router { use axum::routing::post; axum::Router::new() .route("/api/chat/stream", post(stream_chat_response)) .route("/api/v1/stream", post(stream_chat_response)) } /// Streaming chat with conversation history #[derive(Debug, Deserialize)] pub struct StreamChatWithHistoryRequest { pub session_id: String, pub message: String, pub history: Option>, pub system_prompt: Option, pub model: Option, pub bot_id: Option, pub temperature: Option, pub max_tokens: Option, } #[derive(Debug, Deserialize, Serialize, Clone)] pub struct HistoryMessage { pub role: String, pub content: String, } /// Stream chat with full conversation history pub async fn stream_chat_with_history( State(state): State, Json(payload): Json, ) -> Sse>> { let (tx, rx) = mpsc::channel::>(100); let state_clone = state.clone(); tokio::spawn(async move { if let Err(e) = handle_stream_with_history(state_clone, tx.clone(), payload).await { error!("Stream with history error: {}", e); let error_event = StreamEvent::Error { message: e.to_string(), }; if let Ok(event) = error_event.to_sse_event() { let _ = tx.send(Ok(event)).await; } } let done_event = StreamEvent::Done { total_tokens: None }; if let Ok(event) = done_event.to_sse_event() { let _ = tx.send(Ok(event)).await; } }); Sse::new(ReceiverStream::new(rx)).keep_alive( KeepAlive::new() .interval(Duration::from_secs(15)) .text("keep-alive"), ) } async fn handle_stream_with_history( state: AppState, tx: mpsc::Sender>, payload: StreamChatWithHistoryRequest, ) -> Result<(), Box> { let bot_id = payload .bot_id .as_ref() .and_then(|id| Uuid::parse_str(id).ok()); let (llm_url, llm_model, llm_key) = get_llm_config(&state, bot_id).await?; let model = payload.model.unwrap_or(llm_model); // Send start event let start_event = StreamEvent::Start { session_id: payload.session_id.clone(), model: model.clone(), }; if let Ok(event) = start_event.to_sse_event() { tx.send(Ok(event)).await?; } // Build history let history: Vec<(String, String)> = payload .history .unwrap_or_default() .into_iter() .map(|h| (h.role, h.content)) .chain(std::iter::once(("user".to_string(), payload.message))) .collect(); let system_prompt = payload.system_prompt.unwrap_or_else(|| { "You are a helpful AI assistant powered by General Bots.".to_string() }); let messages = OpenAIClient::build_messages(&system_prompt, "", &history); let client = OpenAIClient::new(llm_key.clone(), Some(llm_url.clone())); let (token_tx, mut token_rx) = mpsc::channel::(100); let client_clone = client; let messages_clone = messages.clone(); let model_clone = model.clone(); let key_clone = llm_key.clone(); tokio::spawn(async move { if let Err(e) = client_clone .generate_stream(&"", &messages_clone, token_tx, &model_clone, &key_clone) .await { error!("LLM stream error: {}", e); } }); while let Some(token) = token_rx.recv().await { let token_event = StreamEvent::Token { content: token.clone(), }; if let Ok(event) = token_event.to_sse_event() { if tx.send(Ok(event)).await.is_err() { break; } } } Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_stream_event_to_sse() { let event = StreamEvent::Token { content: "Hello".to_string(), }; let sse = event.to_sse_event(); assert!(sse.is_ok()); } #[test] fn test_stream_event_done() { let event = StreamEvent::Done { total_tokens: Some(100), }; let sse = event.to_sse_event(); assert!(sse.is_ok()); } #[test] fn test_stream_event_error() { let event = StreamEvent::Error { message: "Test error".to_string(), }; let sse = event.to_sse_event(); assert!(sse.is_ok()); } #[test] fn test_history_message_serialization() { let msg = HistoryMessage { role: "user".to_string(), content: "Hello".to_string(), }; let json = serde_json::to_string(&msg); assert!(json.is_ok()); } }