diff --git a/src/core/bot/mod.rs b/src/core/bot/mod.rs index 3063639c..c1356128 100644 --- a/src/core/bot/mod.rs +++ b/src/core/bot/mod.rs @@ -825,18 +825,29 @@ impl BotOrchestrator { // #[cfg(feature = "drive")] // set_llm_streaming(true); - let stream_tx_clone = stream_tx.clone(); - tokio::spawn(async move { - if let Err(e) = llm - .generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref()) - .await - { - error!("LLM streaming error: {}", e); - } - // REMOVED: LLM streaming lock was causing deadlocks - // #[cfg(feature = "drive")] - // set_llm_streaming(false); - }); + let stream_tx_clone = stream_tx.clone(); + + // Create cancellation channel for this streaming session + let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); + let session_id_str = session.id.to_string(); + + // Register this streaming session for potential cancellation + { + let mut active_streams = self.state.active_streams.lock().await; + active_streams.insert(session_id_str.clone(), cancel_tx); + } + + tokio::spawn(async move { + if let Err(e) = llm + .generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref()) + .await + { + error!("LLM streaming error: {}", e); + } + // REMOVED: LLM streaming lock was causing deadlocks + // #[cfg(feature = "drive")] + // set_llm_streaming(false); + }); let mut full_response = String::new(); let mut analysis_buffer = String::new(); @@ -872,11 +883,17 @@ impl BotOrchestrator { } } - while let Some(chunk) = stream_rx.recv().await { - chunk_count += 1; - if chunk_count <= 3 || chunk_count % 50 == 0 { - info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len()); - } + while let Some(chunk) = stream_rx.recv().await { + // Check if cancellation was requested (user sent new message) + if cancel_rx.try_recv().is_ok() { + info!("Streaming cancelled for session {} - user sent new message", session.id); + break; + } + + chunk_count += 1; + if chunk_count <= 3 || chunk_count % 50 == 0 { + info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len()); + } // ===== GENERIC TOOL EXECUTION ===== // Add chunk to tool_call_buffer and try to parse @@ -1718,25 +1735,37 @@ let mut send_task = tokio::spawn(async move { }; if let Some(tx_clone) = tx_opt { - let corrected_msg = UserMessage { - bot_id: bot_id.to_string(), - user_id: session.user_id.to_string(), - session_id: session.id.to_string(), - ..user_msg - }; - info!("Calling orchestrator for session {}", session_id); + // CANCEL any existing streaming for this session first + let session_id_str = session_id.to_string(); + { + let mut active_streams = state_clone.active_streams.lock().await; + if let Some(cancel_tx) = active_streams.remove(&session_id_str) { + info!("Cancelling existing streaming for session {}", session_id); + let _ = cancel_tx.send(()).await; + // Give a moment for the streaming to stop + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + } + } + + let corrected_msg = UserMessage { + bot_id: bot_id.to_string(), + user_id: session.user_id.to_string(), + session_id: session.id.to_string(), + ..user_msg + }; + info!("Calling orchestrator for session {}", session_id); - // Spawn LLM in its own task so recv_task stays free to handle - // new messages — prevents one hung LLM from locking the session. - let orch = BotOrchestrator::new(state_clone.clone()); - tokio::spawn(async move { - if let Err(e) = orch - .stream_response(corrected_msg, tx_clone) - .await - { - error!("Failed to stream response: {}", e); - } - }); + // Spawn LLM in its own task so recv_task stays free to handle + // new messages — prevents one hung LLM from locking the session. + let orch = BotOrchestrator::new(state_clone.clone()); + tokio::spawn(async move { + if let Err(e) = orch + .stream_response(corrected_msg, tx_clone) + .await + { + error!("Failed to stream response: {}", e); + } + }); } else { warn!("Response channel NOT found for session: {}", session_id); } diff --git a/src/core/shared/state.rs b/src/core/shared/state.rs index 6aa95df9..ffa5f890 100644 --- a/src/core/shared/state.rs +++ b/src/core/shared/state.rs @@ -398,6 +398,8 @@ pub struct AppState { pub auth_service: Arc>, pub channels: Arc>>>, pub response_channels: Arc>>>, + /// Active streaming sessions for cancellation: session_id → cancellation sender + pub active_streams: Arc>>>, /// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver. pub hear_channels: Arc>>>, pub web_adapter: Arc, @@ -450,6 +452,7 @@ impl Clone for AppState { kb_manager: self.kb_manager.clone(), channels: Arc::clone(&self.channels), response_channels: Arc::clone(&self.response_channels), + active_streams: Arc::clone(&self.active_streams), hear_channels: Arc::clone(&self.hear_channels), web_adapter: Arc::clone(&self.web_adapter), voice_adapter: Arc::clone(&self.voice_adapter), @@ -665,6 +668,7 @@ impl Default for AppState { auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())), channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), + active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())), hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())), web_adapter: Arc::new(WebChannelAdapter::new()), voice_adapter: Arc::new(VoiceAdapter::new()), diff --git a/src/main_module/bootstrap.rs b/src/main_module/bootstrap.rs index 3c720d4e..18f53f2e 100644 --- a/src/main_module/bootstrap.rs +++ b/src/main_module/bootstrap.rs @@ -606,16 +606,17 @@ pub async fn create_app_state( dynamic_llm_provider: Some(dynamic_llm_provider.clone()), #[cfg(feature = "directory")] auth_service: auth_service.clone(), - channels: Arc::new(tokio::sync::Mutex::new({ - let mut map = HashMap::new(); - map.insert( - "web".to_string(), - web_adapter.clone() as Arc, - ); - map - })), - response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), - hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())), + channels: Arc::new(tokio::sync::Mutex::new({ + let mut map = HashMap::new(); + map.insert( + "web".to_string(), + web_adapter.clone() as Arc, + ); + map + })), + response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), + active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())), + hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())), web_adapter: web_adapter.clone(), voice_adapter: voice_adapter.clone(), #[cfg(any(feature = "research", feature = "llm"))]