From dd15899ac330a5f6f486a28e2d091fe2845c6bc1 Mon Sep 17 00:00:00 2001 From: "Rodrigo Rodriguez (Pragmatismo)" Date: Wed, 15 Apr 2026 09:44:42 -0300 Subject: [PATCH] fix: Use broadcast channel for LLM streaming cancellation - Broadcast channel allows multiple subscribers for cancellation - Aborts LLM task when user sends new message - Properly stops LLM generation when cancelled --- src/core/bot/mod.rs | 63 ++++++++++++++++++++-------------- src/core/kb/mod.rs | 23 +++++++++++++ src/core/shared/state.rs | 2 +- src/drive/drive_monitor/mod.rs | 15 +++++++- 4 files changed, 76 insertions(+), 27 deletions(-) diff --git a/src/core/bot/mod.rs b/src/core/bot/mod.rs index c1356128..d653bba8 100644 --- a/src/core/bot/mod.rs +++ b/src/core/bot/mod.rs @@ -40,6 +40,7 @@ use serde_json; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::mpsc; +use tokio::sync::broadcast; use regex; #[cfg(feature = "drive")] #[cfg(feature = "drive")] @@ -826,27 +827,34 @@ impl BotOrchestrator { // set_llm_streaming(true); let stream_tx_clone = stream_tx.clone(); - + // Create cancellation channel for this streaming session - let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); + let (cancel_tx, mut cancel_rx) = broadcast::channel::<()>(1); let session_id_str = session.id.to_string(); - + // Register this streaming session for potential cancellation { let mut active_streams = self.state.active_streams.lock().await; active_streams.insert(session_id_str.clone(), cancel_tx); } - - tokio::spawn(async move { + + // Wrap the LLM task in a JoinHandle so we can abort it + let mut cancel_rx_for_abort = cancel_rx.resubscribe(); + let llm_task = tokio::spawn(async move { if let Err(e) = llm .generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref()) .await { error!("LLM streaming error: {}", e); } - // REMOVED: LLM streaming lock was causing deadlocks - // #[cfg(feature = "drive")] - // set_llm_streaming(false); + }); + + // Wait for cancellation to abort LLM task + tokio::spawn(async move { + if cancel_rx_for_abort.recv().await.is_ok() { + info!("Aborting LLM task for session {}", session_id_str); + llm_task.abort(); + } }); let mut full_response = String::new(); @@ -883,14 +891,19 @@ impl BotOrchestrator { } } - while let Some(chunk) = stream_rx.recv().await { - // Check if cancellation was requested (user sent new message) - if cancel_rx.try_recv().is_ok() { +while let Some(chunk) = stream_rx.recv().await { + // Check if cancellation was requested (user sent new message) + match cancel_rx.try_recv() { + Ok(_) => { info!("Streaming cancelled for session {} - user sent new message", session.id); break; } - - chunk_count += 1; + Err(broadcast::error::TryRecvError::Empty) => {} + Err(broadcast::error::TryRecvError::Closed) => break, + Err(broadcast::error::TryRecvError::Lagged(_)) => {} + } + + chunk_count += 1; if chunk_count <= 3 || chunk_count % 50 == 0 { info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len()); } @@ -1734,18 +1747,18 @@ let mut send_task = tokio::spawn(async move { channels.get(&session_id.to_string()).cloned() }; - if let Some(tx_clone) = tx_opt { - // CANCEL any existing streaming for this session first - let session_id_str = session_id.to_string(); - { - let mut active_streams = state_clone.active_streams.lock().await; - if let Some(cancel_tx) = active_streams.remove(&session_id_str) { - info!("Cancelling existing streaming for session {}", session_id); - let _ = cancel_tx.send(()).await; - // Give a moment for the streaming to stop - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - } - } + if let Some(tx_clone) = tx_opt { + // CANCEL any existing streaming for this session first + let session_id_str = session_id.to_string(); + { + let mut active_streams = state_clone.active_streams.lock().await; + if let Some(cancel_tx) = active_streams.remove(&session_id_str) { + info!("Cancelling existing streaming for session {}", session_id); + let _ = cancel_tx.send(()); + // Give a moment for the streaming to stop + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + } let corrected_msg = UserMessage { bot_id: bot_id.to_string(), diff --git a/src/core/kb/mod.rs b/src/core/kb/mod.rs index 7cfe64de..b97ec159 100644 --- a/src/core/kb/mod.rs +++ b/src/core/kb/mod.rs @@ -272,6 +272,29 @@ impl KnowledgeBaseManager { } } + pub async fn delete_file_from_kb(&self, bot_id: Uuid, bot_name: &str, kb_name: &str, file_path: &str) -> Result<()> { + let bot_id_short = bot_id.to_string().chars().take(8).collect::(); + let collection_name = format!("{}_{}_{}", bot_name, bot_id_short, kb_name); + + // Use the relative path within the gbkb folder (e.g., "cartas/file.pdf") + let relative_path = file_path + .strip_prefix(&format!("{}/", kb_name)) + .unwrap_or(file_path); + + info!("Deleting vectors for file {} from collection {}", relative_path, collection_name); + + match self.indexer.delete_file_points(&collection_name, relative_path).await { + Ok(_) => { + info!("Successfully deleted vectors for file {} from {}", relative_path, collection_name); + Ok(()) + } + Err(e) => { + error!("Failed to delete vectors for file {} from {}: {}", relative_path, collection_name, e); + Err(e) + } + } + } + pub async fn get_kb_stats(&self, bot_id: Uuid, bot_name: &str, kb_name: &str) -> Result { let bot_id_short = bot_id.to_string().chars().take(8).collect::(); let collection_name = format!("{}_{}_{}", bot_name, bot_id_short, kb_name); diff --git a/src/core/shared/state.rs b/src/core/shared/state.rs index ffa5f890..26dd6764 100644 --- a/src/core/shared/state.rs +++ b/src/core/shared/state.rs @@ -399,7 +399,7 @@ pub struct AppState { pub channels: Arc>>>, pub response_channels: Arc>>>, /// Active streaming sessions for cancellation: session_id → cancellation sender - pub active_streams: Arc>>>, + pub active_streams: Arc>>>, /// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver. pub hear_channels: Arc>>>, pub web_adapter: Arc, diff --git a/src/drive/drive_monitor/mod.rs b/src/drive/drive_monitor/mod.rs index 36b4129f..ac52283a 100644 --- a/src/drive/drive_monitor/mod.rs +++ b/src/drive/drive_monitor/mod.rs @@ -1594,7 +1594,7 @@ match result { .strip_suffix(".gbai") .unwrap_or(&self.bucket_name); let local_path = self.work_root.join(bot_name).join(&path); - + if local_path.exists() { if let Err(e) = std::fs::remove_file(&local_path) { warn!("Failed to delete orphaned .gbkb file {}: {}", local_path.display(), e); @@ -1603,10 +1603,23 @@ match result { } } + // Delete vectors for this specific file from Qdrant let path_parts: Vec<&str> = path.split('/').collect(); if path_parts.len() >= 2 { let kb_name = path_parts[1]; + #[cfg(any(feature = "research", feature = "llm"))] + { + if let Err(e) = self.kb_manager.delete_file_from_kb(self.bot_id, bot_name, kb_name, &path).await { + log::error!("Failed to delete vectors for file {}: {}", path, e); + } + } + #[cfg(not(any(feature = "research", feature = "llm")))] + { + let _ = (bot_name, kb_name); + debug!("Bypassing vector delete because research/llm features are not enabled"); + } + let kb_prefix = format!("{}{}/", gbkb_prefix, kb_name); if !self.file_repo.has_files_with_prefix(self.bot_id, &kb_prefix) { #[cfg(any(feature = "research", feature = "llm"))]