diff --git a/src/bot/mod.rs b/src/bot/mod.rs index 73bbed0c2..bd56ba6d5 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -3,30 +3,27 @@ use crate::shared::models::{BotResponse, UserMessage, UserSession}; use crate::shared::state::AppState; use actix_web::{web, HttpRequest, HttpResponse, Result}; use actix_ws::Message as WsMessage; -use chrono::Utc; use log::{debug, error, info, warn}; +use chrono::Utc; use serde_json; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::mpsc; -use tokio::sync::Mutex; +use crate::kb::embeddings::generate_embeddings; use uuid::Uuid; -use redis::AsyncCommands; -use reqwest::Client; + +use crate::kb::qdrant_client::{ensure_collection_exists, get_qdrant_client, QdrantPoint}; +use crate::context::langcache::{get_langcache_client}; + pub struct BotOrchestrator { pub state: Arc<AppState>, - pub cache: Arc<Mutex<HashMap<String, String>>>, } impl BotOrchestrator { -pub fn new(state: Arc<AppState>) -> Self { - Self { - state, - cache: Arc::new(Mutex::new(std::collections::HashMap::new())), + pub fn new(state: Arc<AppState>) -> Self { + Self { state } } -} -} pub async fn handle_user_input( &self, @@ -301,7 +298,7 @@ pub fn new(state: Arc<AppState>) -> Self { session_manager.get_conversation_history(session.id, session.user_id)? }; - // Prompt compactor: keep only last 10 entries to limit size + // Prompt compactor: keep only last 10 entries let recent_history = if history.len() > 10 { &history[history.len() - 10..]
} else { @@ -313,27 +310,111 @@ pub fn new(state: Arc) -> Self { } prompt.push_str(&format!("User: {}\nAssistant:", message.content)); - // Check in-memory cache for existing response - { - let cache = self.cache.lock().await; - if let Some(cached) = cache.get(&prompt) { - return Ok(cached.clone()); + // Determine which cache backend to use + let use_langcache = std::env::var("LLM_CACHE") + .unwrap_or_else(|_| "false".to_string()) + .eq_ignore_ascii_case("true"); + + if use_langcache { + // Ensure LangCache collection exists + ensure_collection_exists(&self.state, "semantic_cache").await?; + + // Get LangCache client + let langcache_client = get_langcache_client()?; + + // Isolate the user question (ignore conversation history) + let isolated_question = message.content.trim().to_string(); + + // Generate embedding for the isolated question + let question_embeddings = generate_embeddings(vec![isolated_question.clone()]).await?; + let question_embedding = question_embeddings + .get(0) + .ok_or_else(|| "Failed to generate embedding for question")? 
+ .clone(); + + // Search for similar question in LangCache + let search_results = langcache_client + .search("semantic_cache", question_embedding.clone(), 1) + .await?; + + if let Some(result) = search_results.first() { + let payload = &result.payload; + if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) { + return Ok(resp.to_string()); + } } + + // Generate response via LLM provider using full prompt (including history) + let response = self.state + .llm_provider + .generate(&prompt, &serde_json::Value::Null) + .await?; + + // Store isolated question and response in LangCache + let point = QdrantPoint { + id: uuid::Uuid::new_v4().to_string(), + vector: question_embedding, + payload: serde_json::json!({ + "question": isolated_question, + "prompt": prompt, + "response": response + }), + }; + langcache_client + .upsert_points("semantic_cache", vec![point]) + .await?; + + Ok(response) + } else { + // Ensure semantic cache collection exists + ensure_collection_exists(&self.state, "semantic_cache").await?; + + // Get Qdrant client + let qdrant_client = get_qdrant_client(&self.state)?; + + // Generate embedding for the prompt + let embeddings = generate_embeddings(vec![prompt.clone()]).await?; + let embedding = embeddings + .get(0) + .ok_or_else(|| "Failed to generate embedding")? 
+ .clone(); + + // Search for similar prompt in Qdrant + let search_results = qdrant_client + .search("semantic_cache", embedding.clone(), 1) + .await?; + + if let Some(result) = search_results.first() { + if let Some(payload) = &result.payload { + if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) { + return Ok(resp.to_string()); + } + } + } + + // Generate response via LLM provider + let response = self.state + .llm_provider + .generate(&prompt, &serde_json::Value::Null) + .await?; + + // Store prompt and response in Qdrant + let point = QdrantPoint { + id: uuid::Uuid::new_v4().to_string(), + vector: embedding, + payload: serde_json::json!({ + "prompt": prompt, + "response": response + }), + }; + qdrant_client + .upsert_points("semantic_cache", vec![point]) + .await?; + + Ok(response) } - // Generate response via LLM provider - let response = self.state - .llm_provider - .generate(&prompt, &serde_json::Value::Null) - .await?; - // Store the new response in cache - { - let mut cache = self.cache.lock().await; - cache.insert(prompt.clone(), response.clone()); - } - - Ok(response) } pub async fn stream_response( @@ -727,7 +808,6 @@ impl Default for BotOrchestrator { fn default() -> Self { Self { state: Arc::new(AppState::default()), - cache: Arc::new(Mutex::new(std::collections::HashMap::new())), } } } diff --git a/src/context/langcache.rs b/src/context/langcache.rs new file mode 100644 index 000000000..6e2dc607d --- /dev/null +++ b/src/context/langcache.rs @@ -0,0 +1,67 @@ +use crate::kb::qdrant_client::{ensure_collection_exists, VectorDBClient, QdrantPoint}; +use std::error::Error; + +/// LangCache client – currently a thin wrapper around the existing Qdrant client, +/// allowing future replacement with a dedicated LangCache SDK or API without +/// changing the rest of the codebase. +pub struct LLMCacheClient { + inner: VectorDBClient, +} + +impl LLMCacheClient { + /// Create a new LangCache client. 
+ /// This client uses the internal Qdrant client with the default QDRANT_URL. + /// No external environment variable is required. + pub fn new() -> Result<Self, Box<dyn Error>> { + // Use the same URL as the Qdrant client (default or from QDRANT_URL env) + let qdrant_url = std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string()); + Ok(Self { + inner: VectorDBClient::new(qdrant_url), + }) + } + + + /// Ensure a collection exists in LangCache. + pub async fn ensure_collection_exists( + &self, + collection_name: &str, + ) -> Result<(), Box<dyn Error>> { + // Reuse the Qdrant helper – LangCache uses the same semantics. + ensure_collection_exists(&crate::shared::state::AppState::default(), collection_name).await + } + + /// Search for similar vectors in a LangCache collection. + pub async fn search( + &self, + collection_name: &str, + query_vector: Vec<f32>, + limit: usize, + ) -> Result<Vec<QdrantPoint>, Box<dyn Error>> { + // Forward to the inner Qdrant client and map results to QdrantPoint. + let results = self.inner.search(collection_name, query_vector, limit).await?; + // Convert SearchResult to QdrantPoint (payload and vector may be None) + let points = results + .into_iter() + .map(|res| QdrantPoint { + id: res.id, + vector: res.vector.unwrap_or_default(), + payload: res.payload.unwrap_or_else(|| serde_json::json!({})), + }) + .collect(); + Ok(points) + } + + /// Upsert points (prompt/response pairs) into a LangCache collection. + pub async fn upsert_points( + &self, + collection_name: &str, + points: Vec<QdrantPoint>, + ) -> Result<(), Box<dyn Error>> { + self.inner.upsert_points(collection_name, points).await + } +} + +/// Helper to obtain a LangCache client from the application state.
+pub fn get_langcache_client() -> Result<LLMCacheClient, Box<dyn Error>> { + LLMCacheClient::new() +} diff --git a/src/context/mod.rs b/src/context/mod.rs index ad33fab5d..a440f278f 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -1,95 +1 @@ -use async_trait::async_trait; -use serde_json::Value; -use std::sync::Arc; - -use crate::shared::models::SearchResult; - -pub mod prompt_processor; - -#[async_trait] -pub trait ContextStore: Send + Sync { - async fn store_embedding( - &self, - text: &str, - embedding: Vec<f32>, - metadata: Value, - ) -> Result<(), Box<dyn std::error::Error>>; - - async fn search_similar( - &self, - embedding: Vec<f32>, - limit: u32, - ) -> Result<Vec<SearchResult>, Box<dyn std::error::Error>>; -} - -pub struct QdrantContextStore { - vector_store: Arc<qdrant_client::Qdrant>, -} - -impl QdrantContextStore { - pub fn new(vector_store: qdrant_client::Qdrant) -> Self { - Self { - vector_store: Arc::new(vector_store), - } - } - - pub async fn get_conversation_context( - &self, - session_id: &str, - user_id: &str, - _limit: usize, - ) -> Result<Vec<SearchResult>, Box<dyn std::error::Error>> { - let _query = format!("session_id:{} AND user_id:{}", session_id, user_id); - Ok(vec![]) - } -} - -#[async_trait] -impl ContextStore for QdrantContextStore { - async fn store_embedding( - &self, - text: &str, - _embedding: Vec<f32>, - _metadata: Value, - ) -> Result<(), Box<dyn std::error::Error>> { - log::info!("Storing embedding for text: {}", text); - Ok(()) - } - - async fn search_similar( - &self, - _embedding: Vec<f32>, - _limit: u32, - ) -> Result<Vec<SearchResult>, Box<dyn std::error::Error>> { - Ok(vec![]) - } -} - -pub struct MockContextStore; - -impl MockContextStore { - pub fn new() -> Self { - Self - } -} - -#[async_trait] -impl ContextStore for MockContextStore { - async fn store_embedding( - &self, - text: &str, - _embedding: Vec<f32>, - _metadata: Value, - ) -> Result<(), Box<dyn std::error::Error>> { - log::info!("Mock storing embedding for: {}", text); - Ok(()) - } - - async fn search_similar( - &self, - _embedding: Vec<f32>, - _limit: u32, - ) -> Result<Vec<SearchResult>, Box<dyn std::error::Error>> { - Ok(vec![]) - } -} +pub mod langcache; diff --git a/src/kb/qdrant_client.rs b/src/kb/qdrant_client.rs index 58923876a..850dd9c8e 100644 ---
a/src/kb/qdrant_client.rs +++ b/src/kb/qdrant_client.rs @@ -53,12 +53,12 @@ pub struct CollectionInfo { pub status: String, } -pub struct QdrantClient { +pub struct VectorDBClient { base_url: String, client: Client, } -impl QdrantClient { +impl VectorDBClient { pub fn new(base_url: String) -> Self { Self { base_url, @@ -235,11 +235,11 @@ impl QdrantClient { } /// Get Qdrant client from app state -pub fn get_qdrant_client(_state: &AppState) -> Result<QdrantClient, Box<dyn Error>> { +pub fn get_qdrant_client(_state: &AppState) -> Result<VectorDBClient, Box<dyn Error>> { let qdrant_url = std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string()); - Ok(QdrantClient::new(qdrant_url)) + Ok(VectorDBClient::new(qdrant_url)) } /// Ensure a collection exists, create if not @@ -280,7 +280,7 @@ mod tests { #[test] fn test_qdrant_client_creation() { - let client = QdrantClient::new("http://localhost:6333".to_string()); + let client = VectorDBClient::new("http://localhost:6333".to_string()); assert_eq!(client.base_url, "http://localhost:6333"); } }