feat(llm): add cancel_job support and integrate session cleanup

Introduce a new `cancel_job` method on the `LLMProvider` trait so ongoing LLM tasks can be cancelled. Add a `LlamaClient` provider in `src/llm/llama.rs` that targets a llama.cpp-style HTTP server and forwards cancellations to its `/cancel` endpoint, and implement no-op versions of `cancel_job` for the OpenAI, Anthropic, and Mock providers. Update the WebSocket handler to cancel any outstanding job when a session closes, improving resource management and preventing orphaned tasks. Also fix an unused-variable warning in `add_suggestion.rs`.
Rodrigo Rodriguez (Pragmatismo) 2025-11-02 15:13:47 -03:00
parent 87fc13b08a
commit e7fa2f7bdb
4 changed files with 160 additions and 1 deletion
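For context only (not part of this commit): a minimal sketch of how a provider with genuinely cancellable in-flight work could back `cancel_job`, assuming a per-session registry of `tokio_util::sync::CancellationToken` handles. `JobRegistry` and its methods are hypothetical names, not part of this codebase.

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;

/// Hypothetical per-session registry a provider could hold internally.
#[derive(Clone, Default)]
pub struct JobRegistry {
    tokens: Arc<Mutex<HashMap<String, CancellationToken>>>,
}

impl JobRegistry {
    /// Create and store a token for a session before starting generation.
    pub async fn register(&self, session_id: &str) -> CancellationToken {
        let token = CancellationToken::new();
        self.tokens
            .lock()
            .await
            .insert(session_id.to_string(), token.clone());
        token
    }

    /// What a real `cancel_job` could delegate to: trigger and drop the token.
    pub async fn cancel(&self, session_id: &str) {
        if let Some(token) = self.tokens.lock().await.remove(session_id) {
            token.cancel();
        }
    }
}

Generation code would then `tokio::select!` between the streaming request and `token.cancelled()`. In this commit, the OpenAI, Anthropic, and Mock providers simply return `Ok(())`, while the new Llama provider forwards the cancellation to its server.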


@@ -9,7 +9,7 @@ pub fn clear_suggestions_keyword(state: Arc<AppState>, user_session: UserSession
let cache = state.cache.clone();
engine
.register_custom_syntax(&["CLEAR_SUGGESTIONS"], true, move |context, _inputs| {
.register_custom_syntax(&["CLEAR_SUGGESTIONS"], true, move |_context, _inputs| {
info!("CLEAR_SUGGESTIONS command executed");
if let Some(cache_client) = &cache {


@@ -1367,6 +1367,11 @@ async fn websocket_handler(
.unregister_response_channel(&session_id_clone2)
.await;
// Cancel any ongoing LLM jobs for this session
if let Err(e) = data.llm_provider.cancel_job(&session_id_clone2).await {
warn!("Failed to cancel LLM job for session {}: {}", session_id_clone2, e);
}
info!("WebSocket fully closed for session {}", session_id_clone2);
break;
}

src/llm/llama.rs (new file, 123 lines added)

@@ -0,0 +1,123 @@
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
use futures_util::StreamExt;
use crate::tools::ToolManager;
use super::LLMProvider;
pub struct LlamaClient {
client: Client,
base_url: String,
}
impl LlamaClient {
pub fn new(base_url: String) -> Self {
Self {
client: Client::new(),
base_url,
}
}
}
#[async_trait]
impl LLMProvider for LlamaClient {
async fn generate(
&self,
prompt: &str,
config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let response = self.client
.post(&format!("{}/completion", self.base_url))
.json(&serde_json::json!({
"prompt": prompt,
"n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
"temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
}))
.send()
.await?;
let result: Value = response.json().await?;
Ok(result["content"]
.as_str()
.unwrap_or("")
.to_string())
}
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = self.client
.post(&format!("{}/completion", self.base_url))
.json(&serde_json::json!({
"prompt": prompt,
"n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
"temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
"stream": true
}))
.send()
.await?;
let mut stream = response.bytes_stream();
let mut buffer = String::new();
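// Each streamed chunk may contain one or more newline-delimited JSON objects;
// forward every "content" field to the channel as it arrives. Lines that do not
// parse as bare JSON (for example SSE-style "data: "-prefixed frames) are skipped.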
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk_str = String::from_utf8_lossy(&chunk);
for line in chunk_str.lines() {
if let Ok(data) = serde_json::from_str::<Value>(&line) {
if let Some(content) = data["content"].as_str() {
buffer.push_str(content);
let _ = tx.send(content.to_string()).await;
}
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
self.generate(&enhanced_prompt, config).await
}
async fn cancel_job(
&self,
session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// llama.cpp cancellation endpoint
let response = self.client
.post(&format!("{}/cancel", self.base_url))
.json(&serde_json::json!({
"session_id": session_id
}))
.send()
.await?;
if response.status().is_success() {
Ok(())
} else {
Err(format!("Failed to cancel job for session {}", session_id).into())
}
}
}
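A possible wiring sketch for the new provider, assuming the application stores it as an `Arc<dyn LLMProvider>` (hypothetical helper; this commit does not show how `llm_provider` is constructed, and the real base URL would come from configuration):

use std::sync::Arc;
use crate::llm::{llama::LlamaClient, LLMProvider};

// Hypothetical constructor wrapper for the trait object used elsewhere in the app.
fn make_llama_provider(base_url: String) -> Arc<dyn LLMProvider> {
    Arc::new(LlamaClient::new(base_url))
}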


@@ -6,6 +6,8 @@ use tokio::sync::mpsc;
use crate::tools::ToolManager;
pub mod llama;
#[async_trait]
pub trait LLMProvider: Send + Sync {
async fn generate(
@@ -30,6 +32,11 @@ pub trait LLMProvider: Send + Sync {
session_id: &str,
user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
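/// Cancel any in-flight generation for `session_id`; providers without
/// server-side cancellation may implement this as a no-op.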
async fn cancel_job(
&self,
session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}
pub struct OpenAIClient {
@@ -144,6 +151,14 @@ impl LLMProvider for OpenAIClient {
self.generate(&enhanced_prompt, &Value::Null).await
}
async fn cancel_job(
&self,
_session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// OpenAI doesn't support job cancellation
Ok(())
}
}
pub struct AnthropicClient {
@@ -254,6 +269,14 @@ impl LLMProvider for AnthropicClient {
self.generate(&enhanced_prompt, &Value::Null).await
}
async fn cancel_job(
&self,
_session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Anthropic doesn't support job cancellation
Ok(())
}
}
pub struct MockLLMProvider;
@@ -307,4 +330,12 @@ impl LLMProvider for MockLLMProvider {
tools_list, prompt
))
}
async fn cancel_job(
&self,
_session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Mock implementation just logs the cancellation
Ok(())
}
}