feat: Cancel streaming LLM when user sends new message
All checks were successful
BotServer CI/CD / build (push) Successful in 6m4s

- Add active_streams HashMap to AppState to track streaming sessions
- Create cancellation channel for each streaming session
- Cancel existing streaming when new message arrives
- Prevents overlapping responses and improves UX
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2026-04-15 07:37:07 -03:00
parent 01d4f47a93
commit 9db784fd5c
3 changed files with 79 additions and 45 deletions

View file

@ -825,18 +825,29 @@ impl BotOrchestrator {
// #[cfg(feature = "drive")]
// set_llm_streaming(true);
let stream_tx_clone = stream_tx.clone();
tokio::spawn(async move {
if let Err(e) = llm
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref())
.await
{
error!("LLM streaming error: {}", e);
}
// REMOVED: LLM streaming lock was causing deadlocks
// #[cfg(feature = "drive")]
// set_llm_streaming(false);
});
let stream_tx_clone = stream_tx.clone();
// Create cancellation channel for this streaming session
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
let session_id_str = session.id.to_string();
// Register this streaming session for potential cancellation
{
let mut active_streams = self.state.active_streams.lock().await;
active_streams.insert(session_id_str.clone(), cancel_tx);
}
tokio::spawn(async move {
if let Err(e) = llm
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref())
.await
{
error!("LLM streaming error: {}", e);
}
// REMOVED: LLM streaming lock was causing deadlocks
// #[cfg(feature = "drive")]
// set_llm_streaming(false);
});
let mut full_response = String::new();
let mut analysis_buffer = String::new();
@ -872,11 +883,17 @@ impl BotOrchestrator {
}
}
while let Some(chunk) = stream_rx.recv().await {
chunk_count += 1;
if chunk_count <= 3 || chunk_count % 50 == 0 {
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len());
}
while let Some(chunk) = stream_rx.recv().await {
// Check if cancellation was requested (user sent new message)
if cancel_rx.try_recv().is_ok() {
info!("Streaming cancelled for session {} - user sent new message", session.id);
break;
}
chunk_count += 1;
if chunk_count <= 3 || chunk_count % 50 == 0 {
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len());
}
// ===== GENERIC TOOL EXECUTION =====
// Add chunk to tool_call_buffer and try to parse
@ -1718,25 +1735,37 @@ let mut send_task = tokio::spawn(async move {
};
if let Some(tx_clone) = tx_opt {
let corrected_msg = UserMessage {
bot_id: bot_id.to_string(),
user_id: session.user_id.to_string(),
session_id: session.id.to_string(),
..user_msg
};
info!("Calling orchestrator for session {}", session_id);
// CANCEL any existing streaming for this session first
let session_id_str = session_id.to_string();
{
let mut active_streams = state_clone.active_streams.lock().await;
if let Some(cancel_tx) = active_streams.remove(&session_id_str) {
info!("Cancelling existing streaming for session {}", session_id);
let _ = cancel_tx.send(()).await;
// Give a moment for the streaming to stop
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
}
let corrected_msg = UserMessage {
bot_id: bot_id.to_string(),
user_id: session.user_id.to_string(),
session_id: session.id.to_string(),
..user_msg
};
info!("Calling orchestrator for session {}", session_id);
// Spawn LLM in its own task so recv_task stays free to handle
// new messages — prevents one hung LLM from locking the session.
let orch = BotOrchestrator::new(state_clone.clone());
tokio::spawn(async move {
if let Err(e) = orch
.stream_response(corrected_msg, tx_clone)
.await
{
error!("Failed to stream response: {}", e);
}
});
// Spawn LLM in its own task so recv_task stays free to handle
// new messages — prevents one hung LLM from locking the session.
let orch = BotOrchestrator::new(state_clone.clone());
tokio::spawn(async move {
if let Err(e) = orch
.stream_response(corrected_msg, tx_clone)
.await
{
error!("Failed to stream response: {}", e);
}
});
} else {
warn!("Response channel NOT found for session: {}", session_id);
}

View file

@ -398,6 +398,8 @@ pub struct AppState {
pub auth_service: Arc<tokio::sync::Mutex<AuthService>>,
pub channels: Arc<tokio::sync::Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
/// Active streaming sessions for cancellation: session_id → cancellation sender
pub active_streams: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<()>>>>,
/// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver.
pub hear_channels: Arc<std::sync::Mutex<HashMap<uuid::Uuid, std::sync::mpsc::SyncSender<String>>>>,
pub web_adapter: Arc<WebChannelAdapter>,
@ -450,6 +452,7 @@ impl Clone for AppState {
kb_manager: self.kb_manager.clone(),
channels: Arc::clone(&self.channels),
response_channels: Arc::clone(&self.response_channels),
active_streams: Arc::clone(&self.active_streams),
hear_channels: Arc::clone(&self.hear_channels),
web_adapter: Arc::clone(&self.web_adapter),
voice_adapter: Arc::clone(&self.voice_adapter),
@ -665,6 +668,7 @@ impl Default for AppState {
auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())),
channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
web_adapter: Arc::new(WebChannelAdapter::new()),
voice_adapter: Arc::new(VoiceAdapter::new()),

View file

@ -606,16 +606,17 @@ pub async fn create_app_state(
dynamic_llm_provider: Some(dynamic_llm_provider.clone()),
#[cfg(feature = "directory")]
auth_service: auth_service.clone(),
channels: Arc::new(tokio::sync::Mutex::new({
let mut map = HashMap::new();
map.insert(
"web".to_string(),
web_adapter.clone() as Arc<dyn crate::core::bot::channels::ChannelAdapter>,
);
map
})),
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
channels: Arc::new(tokio::sync::Mutex::new({
let mut map = HashMap::new();
map.insert(
"web".to_string(),
web_adapter.clone() as Arc<dyn crate::core::bot::channels::ChannelAdapter>,
);
map
})),
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
web_adapter: web_adapter.clone(),
voice_adapter: voice_adapter.clone(),
#[cfg(any(feature = "research", feature = "llm"))]