Refactor BotOrchestrator to remove in-memory cache and implement LangCache for user input responses
This commit is contained in:
parent
2e1c0a9a68
commit
13574feb23
4 changed files with 183 additions and 130 deletions
140
src/bot/mod.rs
140
src/bot/mod.rs
|
|
@ -3,30 +3,27 @@ use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
|||
use crate::shared::state::AppState;
|
||||
use actix_web::{web, HttpRequest, HttpResponse, Result};
|
||||
use actix_ws::Message as WsMessage;
|
||||
use chrono::Utc;
|
||||
use log::{debug, error, info, warn};
|
||||
use chrono::Utc;
|
||||
use serde_json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::kb::embeddings::generate_embeddings;
|
||||
use uuid::Uuid;
|
||||
use redis::AsyncCommands;
|
||||
use reqwest::Client;
|
||||
|
||||
use crate::kb::qdrant_client::{ensure_collection_exists, get_qdrant_client, QdrantPoint};
|
||||
use crate::context::langcache::{get_langcache_client};
|
||||
|
||||
|
||||
pub struct BotOrchestrator {
|
||||
pub state: Arc<AppState>,
|
||||
pub cache: Arc<Mutex<std::collections::HashMap<String, String>>>,
|
||||
}
|
||||
|
||||
impl BotOrchestrator {
|
||||
pub fn new(state: Arc<AppState>) -> Self {
|
||||
Self {
|
||||
state,
|
||||
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||
pub fn new(state: Arc<AppState>) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_user_input(
|
||||
&self,
|
||||
|
|
@ -301,7 +298,7 @@ pub fn new(state: Arc<AppState>) -> Self {
|
|||
session_manager.get_conversation_history(session.id, session.user_id)?
|
||||
};
|
||||
|
||||
// Prompt compactor: keep only last 10 entries to limit size
|
||||
// Prompt compactor: keep only last 10 entries
|
||||
let recent_history = if history.len() > 10 {
|
||||
&history[history.len() - 10..]
|
||||
} else {
|
||||
|
|
@ -313,27 +310,111 @@ pub fn new(state: Arc<AppState>) -> Self {
|
|||
}
|
||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||||
|
||||
// Check in-memory cache for existing response
|
||||
{
|
||||
let cache = self.cache.lock().await;
|
||||
if let Some(cached) = cache.get(&prompt) {
|
||||
return Ok(cached.clone());
|
||||
// Determine which cache backend to use
|
||||
let use_langcache = std::env::var("LLM_CACHE")
|
||||
.unwrap_or_else(|_| "false".to_string())
|
||||
.eq_ignore_ascii_case("true");
|
||||
|
||||
if use_langcache {
|
||||
// Ensure LangCache collection exists
|
||||
ensure_collection_exists(&self.state, "semantic_cache").await?;
|
||||
|
||||
// Get LangCache client
|
||||
let langcache_client = get_langcache_client()?;
|
||||
|
||||
// Isolate the user question (ignore conversation history)
|
||||
let isolated_question = message.content.trim().to_string();
|
||||
|
||||
// Generate embedding for the isolated question
|
||||
let question_embeddings = generate_embeddings(vec![isolated_question.clone()]).await?;
|
||||
let question_embedding = question_embeddings
|
||||
.get(0)
|
||||
.ok_or_else(|| "Failed to generate embedding for question")?
|
||||
.clone();
|
||||
|
||||
// Search for similar question in LangCache
|
||||
let search_results = langcache_client
|
||||
.search("semantic_cache", question_embedding.clone(), 1)
|
||||
.await?;
|
||||
|
||||
if let Some(result) = search_results.first() {
|
||||
let payload = &result.payload;
|
||||
if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
|
||||
return Ok(resp.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Generate response via LLM provider using full prompt (including history)
|
||||
let response = self.state
|
||||
.llm_provider
|
||||
.generate(&prompt, &serde_json::Value::Null)
|
||||
.await?;
|
||||
|
||||
// Store isolated question and response in LangCache
|
||||
let point = QdrantPoint {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
vector: question_embedding,
|
||||
payload: serde_json::json!({
|
||||
"question": isolated_question,
|
||||
"prompt": prompt,
|
||||
"response": response
|
||||
}),
|
||||
};
|
||||
langcache_client
|
||||
.upsert_points("semantic_cache", vec![point])
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
} else {
|
||||
// Ensure semantic cache collection exists
|
||||
ensure_collection_exists(&self.state, "semantic_cache").await?;
|
||||
|
||||
// Get Qdrant client
|
||||
let qdrant_client = get_qdrant_client(&self.state)?;
|
||||
|
||||
// Generate embedding for the prompt
|
||||
let embeddings = generate_embeddings(vec![prompt.clone()]).await?;
|
||||
let embedding = embeddings
|
||||
.get(0)
|
||||
.ok_or_else(|| "Failed to generate embedding")?
|
||||
.clone();
|
||||
|
||||
// Search for similar prompt in Qdrant
|
||||
let search_results = qdrant_client
|
||||
.search("semantic_cache", embedding.clone(), 1)
|
||||
.await?;
|
||||
|
||||
if let Some(result) = search_results.first() {
|
||||
if let Some(payload) = &result.payload {
|
||||
if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
|
||||
return Ok(resp.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate response via LLM provider
|
||||
let response = self.state
|
||||
.llm_provider
|
||||
.generate(&prompt, &serde_json::Value::Null)
|
||||
.await?;
|
||||
|
||||
// Store prompt and response in Qdrant
|
||||
let point = QdrantPoint {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
vector: embedding,
|
||||
payload: serde_json::json!({
|
||||
"prompt": prompt,
|
||||
"response": response
|
||||
}),
|
||||
};
|
||||
qdrant_client
|
||||
.upsert_points("semantic_cache", vec![point])
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
// Generate response via LLM provider
|
||||
let response = self.state
|
||||
.llm_provider
|
||||
.generate(&prompt, &serde_json::Value::Null)
|
||||
.await?;
|
||||
|
||||
// Store the new response in cache
|
||||
{
|
||||
let mut cache = self.cache.lock().await;
|
||||
cache.insert(prompt.clone(), response.clone());
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn stream_response(
|
||||
|
|
@ -727,7 +808,6 @@ impl Default for BotOrchestrator {
|
|||
fn default() -> Self {
|
||||
Self {
|
||||
state: Arc::new(AppState::default()),
|
||||
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
67
src/context/langcache.rs
Normal file
67
src/context/langcache.rs
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
use crate::kb::qdrant_client::{ensure_collection_exists, VectorDBClient, QdrantPoint};
|
||||
use std::error::Error;
|
||||
|
||||
/// LangCache client – currently a thin wrapper around the existing Qdrant client,
|
||||
/// allowing future replacement with a dedicated LangCache SDK or API without
|
||||
/// changing the rest of the codebase.
|
||||
pub struct LLMCacheClient {
|
||||
inner: VectorDBClient,
|
||||
}
|
||||
|
||||
impl LLMCacheClient {
|
||||
/// Create a new LangCache client.
|
||||
/// This client uses the internal Qdrant client with the default QDRANT_URL.
|
||||
/// No external environment variable is required.
|
||||
pub fn new() -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||
// Use the same URL as the Qdrant client (default or from QDRANT_URL env)
|
||||
let qdrant_url = std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string());
|
||||
Ok(Self {
|
||||
inner: VectorDBClient::new(qdrant_url),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// Ensure a collection exists in LangCache.
|
||||
pub async fn ensure_collection_exists(
|
||||
&self,
|
||||
collection_name: &str,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
// Reuse the Qdrant helper – LangCache uses the same semantics.
|
||||
ensure_collection_exists(&crate::shared::state::AppState::default(), collection_name).await
|
||||
}
|
||||
|
||||
/// Search for similar vectors in a LangCache collection.
|
||||
pub async fn search(
|
||||
&self,
|
||||
collection_name: &str,
|
||||
query_vector: Vec<f32>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<QdrantPoint>, Box<dyn Error + Send + Sync>> {
|
||||
// Forward to the inner Qdrant client and map results to QdrantPoint.
|
||||
let results = self.inner.search(collection_name, query_vector, limit).await?;
|
||||
// Convert SearchResult to QdrantPoint (payload and vector may be None)
|
||||
let points = results
|
||||
.into_iter()
|
||||
.map(|res| QdrantPoint {
|
||||
id: res.id,
|
||||
vector: res.vector.unwrap_or_default(),
|
||||
payload: res.payload.unwrap_or_else(|| serde_json::json!({})),
|
||||
})
|
||||
.collect();
|
||||
Ok(points)
|
||||
}
|
||||
|
||||
/// Upsert points (prompt/response pairs) into a LangCache collection.
|
||||
pub async fn upsert_points(
|
||||
&self,
|
||||
collection_name: &str,
|
||||
points: Vec<QdrantPoint>,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
self.inner.upsert_points(collection_name, points).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to obtain a LangCache client from the application state.
|
||||
pub fn get_langcache_client() -> Result<LLMCacheClient, Box<dyn Error + Send + Sync>> {
|
||||
LLMCacheClient::new()
|
||||
}
|
||||
|
|
@ -1,95 +1 @@
|
|||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::shared::models::SearchResult;
|
||||
|
||||
pub mod prompt_processor;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ContextStore: Send + Sync {
|
||||
async fn store_embedding(
|
||||
&self,
|
||||
text: &str,
|
||||
embedding: Vec<f32>,
|
||||
metadata: Value,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
async fn search_similar(
|
||||
&self,
|
||||
embedding: Vec<f32>,
|
||||
limit: u32,
|
||||
) -> Result<Vec<SearchResult>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
}
|
||||
|
||||
pub struct QdrantContextStore {
|
||||
vector_store: Arc<qdrant_client::Qdrant>,
|
||||
}
|
||||
|
||||
impl QdrantContextStore {
|
||||
pub fn new(vector_store: qdrant_client::Qdrant) -> Self {
|
||||
Self {
|
||||
vector_store: Arc::new(vector_store),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_conversation_context(
|
||||
&self,
|
||||
session_id: &str,
|
||||
user_id: &str,
|
||||
_limit: usize,
|
||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let _query = format!("session_id:{} AND user_id:{}", session_id, user_id);
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ContextStore for QdrantContextStore {
|
||||
async fn store_embedding(
|
||||
&self,
|
||||
text: &str,
|
||||
_embedding: Vec<f32>,
|
||||
_metadata: Value,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
log::info!("Storing embedding for text: {}", text);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn search_similar(
|
||||
&self,
|
||||
_embedding: Vec<f32>,
|
||||
_limit: u32,
|
||||
) -> Result<Vec<SearchResult>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MockContextStore;
|
||||
|
||||
impl MockContextStore {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ContextStore for MockContextStore {
|
||||
async fn store_embedding(
|
||||
&self,
|
||||
text: &str,
|
||||
_embedding: Vec<f32>,
|
||||
_metadata: Value,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
log::info!("Mock storing embedding for: {}", text);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn search_similar(
|
||||
&self,
|
||||
_embedding: Vec<f32>,
|
||||
_limit: u32,
|
||||
) -> Result<Vec<SearchResult>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
pub mod langcache;
|
||||
|
|
|
|||
|
|
@ -53,12 +53,12 @@ pub struct CollectionInfo {
|
|||
pub status: String,
|
||||
}
|
||||
|
||||
pub struct QdrantClient {
|
||||
pub struct VectorDBClient {
|
||||
base_url: String,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl QdrantClient {
|
||||
impl VectorDBClient {
|
||||
pub fn new(base_url: String) -> Self {
|
||||
Self {
|
||||
base_url,
|
||||
|
|
@ -235,11 +235,11 @@ impl QdrantClient {
|
|||
}
|
||||
|
||||
/// Get Qdrant client from app state
|
||||
pub fn get_qdrant_client(_state: &AppState) -> Result<QdrantClient, Box<dyn Error + Send + Sync>> {
|
||||
pub fn get_qdrant_client(_state: &AppState) -> Result<VectorDBClient, Box<dyn Error + Send + Sync>> {
|
||||
let qdrant_url =
|
||||
std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string());
|
||||
|
||||
Ok(QdrantClient::new(qdrant_url))
|
||||
Ok(VectorDBClient::new(qdrant_url))
|
||||
}
|
||||
|
||||
/// Ensure a collection exists, create if not
|
||||
|
|
@ -280,7 +280,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_qdrant_client_creation() {
|
||||
let client = QdrantClient::new("http://localhost:6333".to_string());
|
||||
let client = VectorDBClient::new("http://localhost:6333".to_string());
|
||||
assert_eq!(client.base_url, "http://localhost:6333");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue