fix: Use broadcast channel for LLM streaming cancellation
All checks were successful
BotServer CI/CD / build (push) Successful in 5m48s

- Broadcast channel allows multiple subscribers for cancellation
- Aborts LLM task when user sends new message
- Properly stops LLM generation when cancelled
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2026-04-15 09:44:42 -03:00
parent 9db784fd5c
commit dd15899ac3
4 changed files with 76 additions and 27 deletions

View file

@ -40,6 +40,7 @@ use serde_json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::broadcast;
use regex;
#[cfg(feature = "drive")]
#[cfg(feature = "drive")]
@ -826,27 +827,34 @@ impl BotOrchestrator {
// set_llm_streaming(true);
let stream_tx_clone = stream_tx.clone();
// Create cancellation channel for this streaming session
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
let (cancel_tx, mut cancel_rx) = broadcast::channel::<()>(1);
let session_id_str = session.id.to_string();
// Register this streaming session for potential cancellation
{
let mut active_streams = self.state.active_streams.lock().await;
active_streams.insert(session_id_str.clone(), cancel_tx);
}
tokio::spawn(async move {
// Wrap the LLM task in a JoinHandle so we can abort it
let mut cancel_rx_for_abort = cancel_rx.resubscribe();
let llm_task = tokio::spawn(async move {
if let Err(e) = llm
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref())
.await
{
error!("LLM streaming error: {}", e);
}
// REMOVED: LLM streaming lock was causing deadlocks
// #[cfg(feature = "drive")]
// set_llm_streaming(false);
});
// Wait for cancellation to abort LLM task
tokio::spawn(async move {
if cancel_rx_for_abort.recv().await.is_ok() {
info!("Aborting LLM task for session {}", session_id_str);
llm_task.abort();
}
});
let mut full_response = String::new();
@ -883,14 +891,19 @@ impl BotOrchestrator {
}
}
while let Some(chunk) = stream_rx.recv().await {
// Check if cancellation was requested (user sent new message)
if cancel_rx.try_recv().is_ok() {
while let Some(chunk) = stream_rx.recv().await {
// Check if cancellation was requested (user sent new message)
match cancel_rx.try_recv() {
Ok(_) => {
info!("Streaming cancelled for session {} - user sent new message", session.id);
break;
}
chunk_count += 1;
Err(broadcast::error::TryRecvError::Empty) => {}
Err(broadcast::error::TryRecvError::Closed) => break,
Err(broadcast::error::TryRecvError::Lagged(_)) => {}
}
chunk_count += 1;
if chunk_count <= 3 || chunk_count % 50 == 0 {
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len());
}
@ -1734,18 +1747,18 @@ let mut send_task = tokio::spawn(async move {
channels.get(&session_id.to_string()).cloned()
};
if let Some(tx_clone) = tx_opt {
// CANCEL any existing streaming for this session first
let session_id_str = session_id.to_string();
{
let mut active_streams = state_clone.active_streams.lock().await;
if let Some(cancel_tx) = active_streams.remove(&session_id_str) {
info!("Cancelling existing streaming for session {}", session_id);
let _ = cancel_tx.send(()).await;
// Give a moment for the streaming to stop
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
}
if let Some(tx_clone) = tx_opt {
// CANCEL any existing streaming for this session first
let session_id_str = session_id.to_string();
{
let mut active_streams = state_clone.active_streams.lock().await;
if let Some(cancel_tx) = active_streams.remove(&session_id_str) {
info!("Cancelling existing streaming for session {}", session_id);
let _ = cancel_tx.send(());
// Give a moment for the streaming to stop
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
}
let corrected_msg = UserMessage {
bot_id: bot_id.to_string(),

View file

@ -272,6 +272,29 @@ impl KnowledgeBaseManager {
}
}
pub async fn delete_file_from_kb(&self, bot_id: Uuid, bot_name: &str, kb_name: &str, file_path: &str) -> Result<()> {
let bot_id_short = bot_id.to_string().chars().take(8).collect::<String>();
let collection_name = format!("{}_{}_{}", bot_name, bot_id_short, kb_name);
// Use the relative path within the gbkb folder (e.g., "cartas/file.pdf")
let relative_path = file_path
.strip_prefix(&format!("{}/", kb_name))
.unwrap_or(file_path);
info!("Deleting vectors for file {} from collection {}", relative_path, collection_name);
match self.indexer.delete_file_points(&collection_name, relative_path).await {
Ok(_) => {
info!("Successfully deleted vectors for file {} from {}", relative_path, collection_name);
Ok(())
}
Err(e) => {
error!("Failed to delete vectors for file {} from {}: {}", relative_path, collection_name, e);
Err(e)
}
}
}
pub async fn get_kb_stats(&self, bot_id: Uuid, bot_name: &str, kb_name: &str) -> Result<KbStatistics> {
let bot_id_short = bot_id.to_string().chars().take(8).collect::<String>();
let collection_name = format!("{}_{}_{}", bot_name, bot_id_short, kb_name);

View file

@ -399,7 +399,7 @@ pub struct AppState {
pub channels: Arc<tokio::sync::Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
/// Active streaming sessions for cancellation: session_id → cancellation sender
pub active_streams: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<()>>>>,
pub active_streams: Arc<tokio::sync::Mutex<HashMap<String, broadcast::Sender<()>>>>,
/// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver.
pub hear_channels: Arc<std::sync::Mutex<HashMap<uuid::Uuid, std::sync::mpsc::SyncSender<String>>>>,
pub web_adapter: Arc<WebChannelAdapter>,

View file

@ -1594,7 +1594,7 @@ match result {
.strip_suffix(".gbai")
.unwrap_or(&self.bucket_name);
let local_path = self.work_root.join(bot_name).join(&path);
if local_path.exists() {
if let Err(e) = std::fs::remove_file(&local_path) {
warn!("Failed to delete orphaned .gbkb file {}: {}", local_path.display(), e);
@ -1603,10 +1603,23 @@ match result {
}
}
// Delete vectors for this specific file from Qdrant
let path_parts: Vec<&str> = path.split('/').collect();
if path_parts.len() >= 2 {
let kb_name = path_parts[1];
#[cfg(any(feature = "research", feature = "llm"))]
{
if let Err(e) = self.kb_manager.delete_file_from_kb(self.bot_id, bot_name, kb_name, &path).await {
log::error!("Failed to delete vectors for file {}: {}", path, e);
}
}
#[cfg(not(any(feature = "research", feature = "llm")))]
{
let _ = (bot_name, kb_name);
debug!("Bypassing vector delete because research/llm features are not enabled");
}
let kb_prefix = format!("{}{}/", gbkb_prefix, kb_name);
if !self.file_repo.has_files_with_prefix(self.bot_id, &kb_prefix) {
#[cfg(any(feature = "research", feature = "llm"))]