feat(llm): add cancel_job support and integrate session cleanup
Introduce a new `cancel_job` method in the `LLMProvider` trait to allow cancellation of ongoing LLM tasks. Add a new `LlamaClient` provider (`src/llm/llama.rs`) that talks to a llama.cpp-compatible server and implements cancellation via its `/cancel` endpoint; implement no-op versions of `cancel_job` for the OpenAI, Anthropic, and Mock providers. Update the WebSocket handler to invoke job cancellation when a session closes, ensuring better resource management and preventing orphaned tasks. Also, fix an unused variable warning in `add_suggestion.rs`.
This commit is contained in:
parent
87fc13b08a
commit
e7fa2f7bdb
4 changed files with 160 additions and 1 deletions
|
|
@ -9,7 +9,7 @@ pub fn clear_suggestions_keyword(state: Arc<AppState>, user_session: UserSession
|
||||||
let cache = state.cache.clone();
|
let cache = state.cache.clone();
|
||||||
|
|
||||||
engine
|
engine
|
||||||
.register_custom_syntax(&["CLEAR_SUGGESTIONS"], true, move |context, _inputs| {
|
.register_custom_syntax(&["CLEAR_SUGGESTIONS"], true, move |_context, _inputs| {
|
||||||
info!("CLEAR_SUGGESTIONS command executed");
|
info!("CLEAR_SUGGESTIONS command executed");
|
||||||
|
|
||||||
if let Some(cache_client) = &cache {
|
if let Some(cache_client) = &cache {
|
||||||
|
|
|
||||||
|
|
@ -1367,6 +1367,11 @@ async fn websocket_handler(
|
||||||
.unregister_response_channel(&session_id_clone2)
|
.unregister_response_channel(&session_id_clone2)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
// Cancel any ongoing LLM jobs for this session
|
||||||
|
if let Err(e) = data.llm_provider.cancel_job(&session_id_clone2).await {
|
||||||
|
warn!("Failed to cancel LLM job for session {}: {}", session_id_clone2, e);
|
||||||
|
}
|
||||||
|
|
||||||
info!("WebSocket fully closed for session {}", session_id_clone2);
|
info!("WebSocket fully closed for session {}", session_id_clone2);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
123
src/llm/llama.rs
Normal file
123
src/llm/llama.rs
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use crate::tools::ToolManager;
|
||||||
|
use super::LLMProvider;
|
||||||
|
|
||||||
|
/// LLM provider backed by a llama.cpp-compatible HTTP completion server.
pub struct LlamaClient {
    // Shared reqwest HTTP client (reused for connection pooling).
    client: Client,
    // Base URL of the llama.cpp server, e.g. "http://localhost:8080" — no trailing slash.
    base_url: String,
}
|
||||||
|
|
||||||
|
impl LlamaClient {
|
||||||
|
pub fn new(base_url: String) -> Self {
|
||||||
|
Self {
|
||||||
|
client: Client::new(),
|
||||||
|
base_url,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LLMProvider for LlamaClient {
|
||||||
|
async fn generate(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
config: &Value,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let response = self.client
|
||||||
|
.post(&format!("{}/completion", self.base_url))
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"prompt": prompt,
|
||||||
|
"n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
|
||||||
|
"temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let result: Value = response.json().await?;
|
||||||
|
Ok(result["content"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
config: &Value,
|
||||||
|
tx: mpsc::Sender<String>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let response = self.client
|
||||||
|
.post(&format!("{}/completion", self.base_url))
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"prompt": prompt,
|
||||||
|
"n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
|
||||||
|
"temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
|
||||||
|
"stream": true
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
let mut buffer = String::new();
|
||||||
|
|
||||||
|
while let Some(chunk) = stream.next().await {
|
||||||
|
let chunk = chunk?;
|
||||||
|
let chunk_str = String::from_utf8_lossy(&chunk);
|
||||||
|
|
||||||
|
for line in chunk_str.lines() {
|
||||||
|
if let Ok(data) = serde_json::from_str::<Value>(&line) {
|
||||||
|
if let Some(content) = data["content"].as_str() {
|
||||||
|
buffer.push_str(content);
|
||||||
|
let _ = tx.send(content.to_string()).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_with_tools(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
config: &Value,
|
||||||
|
available_tools: &[String],
|
||||||
|
_tool_manager: Arc<ToolManager>,
|
||||||
|
_session_id: &str,
|
||||||
|
_user_id: &str,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let tools_info = if available_tools.is_empty() {
|
||||||
|
String::new()
|
||||||
|
} else {
|
||||||
|
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
|
||||||
|
};
|
||||||
|
|
||||||
|
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||||||
|
self.generate(&enhanced_prompt, config).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn cancel_job(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
// llama.cpp cancellation endpoint
|
||||||
|
let response = self.client
|
||||||
|
.post(&format!("{}/cancel", self.base_url))
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"session_id": session_id
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if response.status().is_success() {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Failed to cancel job for session {}", session_id).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -6,6 +6,8 @@ use tokio::sync::mpsc;
|
||||||
|
|
||||||
use crate::tools::ToolManager;
|
use crate::tools::ToolManager;
|
||||||
|
|
||||||
|
pub mod llama;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait LLMProvider: Send + Sync {
|
pub trait LLMProvider: Send + Sync {
|
||||||
async fn generate(
|
async fn generate(
|
||||||
|
|
@ -30,6 +32,11 @@ pub trait LLMProvider: Send + Sync {
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
||||||
|
|
||||||
|
    /// Cancels any in-flight generation job associated with `session_id`.
    ///
    /// Providers with no cancellation mechanism implement this as a no-op
    /// returning `Ok(())`, so callers may invoke it unconditionally on
    /// session teardown.
    async fn cancel_job(
        &self,
        session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct OpenAIClient {
|
pub struct OpenAIClient {
|
||||||
|
|
@ -144,6 +151,14 @@ impl LLMProvider for OpenAIClient {
|
||||||
|
|
||||||
self.generate(&enhanced_prompt, &Value::Null).await
|
self.generate(&enhanced_prompt, &Value::Null).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
    /// No-op: the OpenAI API exposes no way to cancel an in-flight
    /// completion, so this always succeeds without side effects.
    async fn cancel_job(
        &self,
        _session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // OpenAI doesn't support job cancellation
        Ok(())
    }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct AnthropicClient {
|
pub struct AnthropicClient {
|
||||||
|
|
@ -254,6 +269,14 @@ impl LLMProvider for AnthropicClient {
|
||||||
|
|
||||||
self.generate(&enhanced_prompt, &Value::Null).await
|
self.generate(&enhanced_prompt, &Value::Null).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
    /// No-op: the Anthropic API exposes no way to cancel an in-flight
    /// completion, so this always succeeds without side effects.
    async fn cancel_job(
        &self,
        _session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Anthropic doesn't support job cancellation
        Ok(())
    }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct MockLLMProvider;
|
pub struct MockLLMProvider;
|
||||||
|
|
@ -307,4 +330,12 @@ impl LLMProvider for MockLLMProvider {
|
||||||
tools_list, prompt
|
tools_list, prompt
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
    /// No-op: the mock provider has nothing to cancel and always succeeds.
    /// (The previous comment claimed this "logs the cancellation", but no
    /// logging happens here.)
    async fn cancel_job(
        &self,
        _session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Nothing running in the mock — cancellation is trivially successful.
        Ok(())
    }
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue