feat(llm): add cancel_job support and integrate session cleanup
Introduce a new `cancel_job` method on the `LLMProvider` trait so ongoing LLM tasks can be cancelled. Provide no-op implementations for the OpenAI, Anthropic, and Mock providers, and add a new `LlamaClient` provider (`src/llm/llama.rs`) whose `cancel_job` posts to a llama.cpp-style `/cancel` endpoint. Update the WebSocket handler to cancel jobs when a session closes, improving resource management and preventing orphaned tasks. Also fix an unused-variable warning in `add_suggestion.rs`.
This commit is contained in:
parent 87fc13b08a
commit e7fa2f7bdb

4 changed files with 160 additions and 1 deletion
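
As a minimal, self-contained sketch of the shape of this change (`NoOpProvider` and `on_session_close` are illustrative names, not code from this commit; the real signatures appear in the diffs below): providers implement the new `cancel_job` hook, and session teardown calls it best-effort.

use async_trait::async_trait;
use std::sync::Arc;

#[async_trait]
pub trait LLMProvider: Send + Sync {
    // Ask the provider to abort any in-flight work for this session.
    async fn cancel_job(
        &self,
        session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}

// Providers without a real cancellation endpoint (OpenAI, Anthropic, Mock
// in this commit) satisfy the trait with a no-op.
pub struct NoOpProvider;

#[async_trait]
impl LLMProvider for NoOpProvider {
    async fn cancel_job(
        &self,
        _session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }
}

// Session teardown treats cancellation as best-effort: a failure is logged
// and cleanup continues, so a stuck job cannot wedge the close path.
pub async fn on_session_close(provider: Arc<dyn LLMProvider>, session_id: &str) {
    if let Err(e) = provider.cancel_job(session_id).await {
        eprintln!("Failed to cancel LLM job for session {}: {}", session_id, e);
    }
}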
add_suggestion.rs

@@ -9,7 +9,7 @@ pub fn clear_suggestions_keyword(state: Arc<AppState>, user_session: UserSession
     let cache = state.cache.clone();
 
     engine
-        .register_custom_syntax(&["CLEAR_SUGGESTIONS"], true, move |context, _inputs| {
+        .register_custom_syntax(&["CLEAR_SUGGESTIONS"], true, move |_context, _inputs| {
             info!("CLEAR_SUGGESTIONS command executed");
 
             if let Some(cache_client) = &cache {
@@ -1367,6 +1367,11 @@ async fn websocket_handler(
                 .unregister_response_channel(&session_id_clone2)
                 .await;
 
+            // Cancel any ongoing LLM jobs for this session
+            if let Err(e) = data.llm_provider.cancel_job(&session_id_clone2).await {
+                warn!("Failed to cancel LLM job for session {}: {}", session_id_clone2, e);
+            }
+
             info!("WebSocket fully closed for session {}", session_id_clone2);
             break;
         }
src/llm/llama.rs (new file, 123 additions)
@@ -0,0 +1,123 @@
+use async_trait::async_trait;
+use reqwest::Client;
+use serde_json::Value;
+use std::sync::Arc;
+use tokio::sync::mpsc;
+use futures_util::StreamExt;
+use crate::tools::ToolManager;
+use super::LLMProvider;
+
+pub struct LlamaClient {
+    client: Client,
+    base_url: String,
+}
+
+impl LlamaClient {
+    pub fn new(base_url: String) -> Self {
+        Self {
+            client: Client::new(),
+            base_url,
+        }
+    }
+}
+
+#[async_trait]
+impl LLMProvider for LlamaClient {
+    async fn generate(
+        &self,
+        prompt: &str,
+        config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let response = self.client
+            .post(&format!("{}/completion", self.base_url))
+            .json(&serde_json::json!({
+                "prompt": prompt,
+                "n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
+                "temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
+            }))
+            .send()
+            .await?;
+
+        let result: Value = response.json().await?;
+        Ok(result["content"]
+            .as_str()
+            .unwrap_or("")
+            .to_string())
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let response = self.client
+            .post(&format!("{}/completion", self.base_url))
+            .json(&serde_json::json!({
+                "prompt": prompt,
+                "n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
+                "temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
+                "stream": true
+            }))
+            .send()
+            .await?;
+
+        let mut stream = response.bytes_stream();
+        let mut buffer = String::new();
+
+        while let Some(chunk) = stream.next().await {
+            let chunk = chunk?;
+            let chunk_str = String::from_utf8_lossy(&chunk);
+
+            for line in chunk_str.lines() {
+                if let Ok(data) = serde_json::from_str::<Value>(line) {
+                    if let Some(content) = data["content"].as_str() {
+                        buffer.push_str(content);
+                        let _ = tx.send(content.to_string()).await;
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        config: &Value,
+        available_tools: &[String],
+        _tool_manager: Arc<ToolManager>,
+        _session_id: &str,
+        _user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let tools_info = if available_tools.is_empty() {
+            String::new()
+        } else {
+            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
+        };
+
+        let enhanced_prompt = format!("{}{}", prompt, tools_info);
+        self.generate(&enhanced_prompt, config).await
+    }
+
+    async fn cancel_job(
+        &self,
+        session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        // llama.cpp cancellation endpoint
+        let response = self.client
+            .post(&format!("{}/cancel", self.base_url))
+            .json(&serde_json::json!({
+                "session_id": session_id
+            }))
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            Ok(())
+        } else {
+            Err(format!("Failed to cancel job for session {}", session_id).into())
+        }
+    }
+}
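
A hypothetical driver for the new client, assuming a llama.cpp-style server at `http://localhost:8080` (the URL, prompt, and channel capacity are invented for illustration; `LlamaClient` and the `LLMProvider` trait are assumed to be in scope):

use tokio::sync::mpsc;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let client = LlamaClient::new("http://localhost:8080".to_string());
    let (tx, mut rx) = mpsc::channel::<String>(32);

    tokio::spawn(async move {
        let config = serde_json::json!({ "max_tokens": 128, "temperature": 0.2 });
        // Tokens are pushed into `tx` as the server streams them back.
        if let Err(e) = client.generate_stream("Hello, llama!", &config, tx).await {
            eprintln!("stream failed: {}", e);
        }
    });

    // Print tokens as they arrive.
    while let Some(token) = rx.recv().await {
        print!("{}", token);
    }
    Ok(())
}

Because `generate_stream` takes ownership of the sender, the receive loop terminates cleanly once the request finishes and `tx` is dropped.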
src/llm/mod.rs

@@ -6,6 +6,8 @@ use tokio::sync::mpsc;
 
 use crate::tools::ToolManager;
 
+pub mod llama;
+
 #[async_trait]
 pub trait LLMProvider: Send + Sync {
     async fn generate(
@@ -30,6 +32,11 @@ pub trait LLMProvider: Send + Sync {
         session_id: &str,
         user_id: &str,
     ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
+
+    async fn cancel_job(
+        &self,
+        session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
 }
 
 pub struct OpenAIClient {
@@ -144,6 +151,14 @@ impl LLMProvider for OpenAIClient {
 
         self.generate(&enhanced_prompt, &Value::Null).await
     }
+
+    async fn cancel_job(
+        &self,
+        _session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        // OpenAI doesn't support job cancellation
+        Ok(())
+    }
 }
 
 pub struct AnthropicClient {
@@ -254,6 +269,14 @@ impl LLMProvider for AnthropicClient {
 
         self.generate(&enhanced_prompt, &Value::Null).await
     }
+
+    async fn cancel_job(
+        &self,
+        _session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        // Anthropic doesn't support job cancellation
+        Ok(())
+    }
 }
 
 pub struct MockLLMProvider;
@@ -307,4 +330,12 @@ impl LLMProvider for MockLLMProvider {
             tools_list, prompt
         ))
     }
+
+    async fn cancel_job(
+        &self,
+        _session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        // Mock implementation: nothing to cancel
+        Ok(())
+    }
 }