fix: Use broadcast channel for LLM streaming cancellation
All checks were successful
BotServer CI/CD / build (push) Successful in 5m48s
All checks were successful
BotServer CI/CD / build (push) Successful in 5m48s
- Broadcast channel allows multiple subscribers for cancellation
- Aborts LLM task when user sends new message
- Properly stops LLM generation when cancelled
This commit is contained in:
parent
9db784fd5c
commit
dd15899ac3
4 changed files with 76 additions and 27 deletions
|
|
@ -40,6 +40,7 @@ use serde_json;
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::broadcast;
|
||||
use regex;
|
||||
#[cfg(feature = "drive")]
|
||||
#[cfg(feature = "drive")]
|
||||
|
|
@ -826,27 +827,34 @@ impl BotOrchestrator {
|
|||
// set_llm_streaming(true);
|
||||
|
||||
let stream_tx_clone = stream_tx.clone();
|
||||
|
||||
|
||||
// Create cancellation channel for this streaming session
|
||||
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
|
||||
let (cancel_tx, mut cancel_rx) = broadcast::channel::<()>(1);
|
||||
let session_id_str = session.id.to_string();
|
||||
|
||||
|
||||
// Register this streaming session for potential cancellation
|
||||
{
|
||||
let mut active_streams = self.state.active_streams.lock().await;
|
||||
active_streams.insert(session_id_str.clone(), cancel_tx);
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
|
||||
// Wrap the LLM task in a JoinHandle so we can abort it
|
||||
let mut cancel_rx_for_abort = cancel_rx.resubscribe();
|
||||
let llm_task = tokio::spawn(async move {
|
||||
if let Err(e) = llm
|
||||
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref())
|
||||
.await
|
||||
{
|
||||
error!("LLM streaming error: {}", e);
|
||||
}
|
||||
// REMOVED: LLM streaming lock was causing deadlocks
|
||||
// #[cfg(feature = "drive")]
|
||||
// set_llm_streaming(false);
|
||||
});
|
||||
|
||||
// Wait for cancellation to abort LLM task
|
||||
tokio::spawn(async move {
|
||||
if cancel_rx_for_abort.recv().await.is_ok() {
|
||||
info!("Aborting LLM task for session {}", session_id_str);
|
||||
llm_task.abort();
|
||||
}
|
||||
});
|
||||
|
||||
let mut full_response = String::new();
|
||||
|
|
@ -883,14 +891,19 @@ impl BotOrchestrator {
|
|||
}
|
||||
}
|
||||
|
||||
while let Some(chunk) = stream_rx.recv().await {
|
||||
// Check if cancellation was requested (user sent new message)
|
||||
if cancel_rx.try_recv().is_ok() {
|
||||
while let Some(chunk) = stream_rx.recv().await {
|
||||
// Check if cancellation was requested (user sent new message)
|
||||
match cancel_rx.try_recv() {
|
||||
Ok(_) => {
|
||||
info!("Streaming cancelled for session {} - user sent new message", session.id);
|
||||
break;
|
||||
}
|
||||
|
||||
chunk_count += 1;
|
||||
Err(broadcast::error::TryRecvError::Empty) => {}
|
||||
Err(broadcast::error::TryRecvError::Closed) => break,
|
||||
Err(broadcast::error::TryRecvError::Lagged(_)) => {}
|
||||
}
|
||||
|
||||
chunk_count += 1;
|
||||
if chunk_count <= 3 || chunk_count % 50 == 0 {
|
||||
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len());
|
||||
}
|
||||
|
|
@ -1734,18 +1747,18 @@ let mut send_task = tokio::spawn(async move {
|
|||
channels.get(&session_id.to_string()).cloned()
|
||||
};
|
||||
|
||||
if let Some(tx_clone) = tx_opt {
|
||||
// CANCEL any existing streaming for this session first
|
||||
let session_id_str = session_id.to_string();
|
||||
{
|
||||
let mut active_streams = state_clone.active_streams.lock().await;
|
||||
if let Some(cancel_tx) = active_streams.remove(&session_id_str) {
|
||||
info!("Cancelling existing streaming for session {}", session_id);
|
||||
let _ = cancel_tx.send(()).await;
|
||||
// Give a moment for the streaming to stop
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
if let Some(tx_clone) = tx_opt {
|
||||
// CANCEL any existing streaming for this session first
|
||||
let session_id_str = session_id.to_string();
|
||||
{
|
||||
let mut active_streams = state_clone.active_streams.lock().await;
|
||||
if let Some(cancel_tx) = active_streams.remove(&session_id_str) {
|
||||
info!("Cancelling existing streaming for session {}", session_id);
|
||||
let _ = cancel_tx.send(());
|
||||
// Give a moment for the streaming to stop
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
|
||||
let corrected_msg = UserMessage {
|
||||
bot_id: bot_id.to_string(),
|
||||
|
|
|
|||
|
|
@ -272,6 +272,29 @@ impl KnowledgeBaseManager {
|
|||
}
|
||||
}
|
||||
|
||||
/// Deletes the indexed points (vectors, per the log messages) of a single file
/// from a knowledge-base collection.
///
/// The collection name is built as `{bot_name}_{id8}_{kb_name}`, where `id8` is
/// the first 8 characters of `bot_id` rendered as a string. If `file_path`
/// starts with `{kb_name}/`, that prefix is stripped before the path is handed
/// to `self.indexer.delete_file_points`; otherwise `file_path` is used as-is.
///
/// # Errors
///
/// Logs and returns the error produced by `self.indexer.delete_file_points`.
pub async fn delete_file_from_kb(&self, bot_id: Uuid, bot_name: &str, kb_name: &str, file_path: &str) -> Result<()> {
|
||||
let bot_id_short = bot_id.to_string().chars().take(8).collect::<String>();
|
||||
let collection_name = format!("{}_{}_{}", bot_name, bot_id_short, kb_name);
|
||||
|
||||
// Use the relative path within the gbkb folder (e.g., "cartas/file.pdf")
|
||||
let relative_path = file_path
|
||||
.strip_prefix(&format!("{}/", kb_name))
|
||||
.unwrap_or(file_path);
|
||||
|
||||
info!("Deleting vectors for file {} from collection {}", relative_path, collection_name);
|
||||
|
||||
match self.indexer.delete_file_points(&collection_name, relative_path).await {
|
||||
Ok(_) => {
|
||||
info!("Successfully deleted vectors for file {} from {}", relative_path, collection_name);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to delete vectors for file {} from {}: {}", relative_path, collection_name, e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_kb_stats(&self, bot_id: Uuid, bot_name: &str, kb_name: &str) -> Result<KbStatistics> {
|
||||
let bot_id_short = bot_id.to_string().chars().take(8).collect::<String>();
|
||||
let collection_name = format!("{}_{}_{}", bot_name, bot_id_short, kb_name);
|
||||
|
|
|
|||
|
|
@ -399,7 +399,7 @@ pub struct AppState {
|
|||
pub channels: Arc<tokio::sync::Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
|
||||
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||
/// Active streaming sessions for cancellation: session_id → cancellation sender
|
||||
pub active_streams: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<()>>>>,
|
||||
pub active_streams: Arc<tokio::sync::Mutex<HashMap<String, broadcast::Sender<()>>>>,
|
||||
/// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver.
|
||||
pub hear_channels: Arc<std::sync::Mutex<HashMap<uuid::Uuid, std::sync::mpsc::SyncSender<String>>>>,
|
||||
pub web_adapter: Arc<WebChannelAdapter>,
|
||||
|
|
|
|||
|
|
@ -1594,7 +1594,7 @@ match result {
|
|||
.strip_suffix(".gbai")
|
||||
.unwrap_or(&self.bucket_name);
|
||||
let local_path = self.work_root.join(bot_name).join(&path);
|
||||
|
||||
|
||||
if local_path.exists() {
|
||||
if let Err(e) = std::fs::remove_file(&local_path) {
|
||||
warn!("Failed to delete orphaned .gbkb file {}: {}", local_path.display(), e);
|
||||
|
|
@ -1603,10 +1603,23 @@ match result {
|
|||
}
|
||||
}
|
||||
|
||||
// Delete vectors for this specific file from Qdrant
|
||||
let path_parts: Vec<&str> = path.split('/').collect();
|
||||
if path_parts.len() >= 2 {
|
||||
let kb_name = path_parts[1];
|
||||
|
||||
#[cfg(any(feature = "research", feature = "llm"))]
|
||||
{
|
||||
if let Err(e) = self.kb_manager.delete_file_from_kb(self.bot_id, bot_name, kb_name, &path).await {
|
||||
log::error!("Failed to delete vectors for file {}: {}", path, e);
|
||||
}
|
||||
}
|
||||
#[cfg(not(any(feature = "research", feature = "llm")))]
|
||||
{
|
||||
let _ = (bot_name, kb_name);
|
||||
debug!("Bypassing vector delete because research/llm features are not enabled");
|
||||
}
|
||||
|
||||
let kb_prefix = format!("{}{}/", gbkb_prefix, kb_name);
|
||||
if !self.file_repo.has_files_with_prefix(self.bot_id, &kb_prefix) {
|
||||
#[cfg(any(feature = "research", feature = "llm"))]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue