Refactor BotOrchestrator to remove in-memory cache and implement LangCache for user input responses
This commit is contained in:
parent
2e1c0a9a68
commit
13574feb23
4 changed files with 183 additions and 130 deletions
140
src/bot/mod.rs
140
src/bot/mod.rs
|
|
@ -3,30 +3,27 @@ use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
||||||
use crate::shared::state::AppState;
|
use crate::shared::state::AppState;
|
||||||
use actix_web::{web, HttpRequest, HttpResponse, Result};
|
use actix_web::{web, HttpRequest, HttpResponse, Result};
|
||||||
use actix_ws::Message as WsMessage;
|
use actix_ws::Message as WsMessage;
|
||||||
use chrono::Utc;
|
|
||||||
use log::{debug, error, info, warn};
|
use log::{debug, error, info, warn};
|
||||||
|
use chrono::Utc;
|
||||||
use serde_json;
|
use serde_json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::sync::Mutex;
|
use crate::kb::embeddings::generate_embeddings;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use redis::AsyncCommands;
|
|
||||||
use reqwest::Client;
|
use crate::kb::qdrant_client::{ensure_collection_exists, get_qdrant_client, QdrantPoint};
|
||||||
|
use crate::context::langcache::{get_langcache_client};
|
||||||
|
|
||||||
|
|
||||||
pub struct BotOrchestrator {
|
pub struct BotOrchestrator {
|
||||||
pub state: Arc<AppState>,
|
pub state: Arc<AppState>,
|
||||||
pub cache: Arc<Mutex<std::collections::HashMap<String, String>>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BotOrchestrator {
|
impl BotOrchestrator {
|
||||||
pub fn new(state: Arc<AppState>) -> Self {
|
pub fn new(state: Arc<AppState>) -> Self {
|
||||||
Self {
|
Self { state }
|
||||||
state,
|
|
||||||
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn handle_user_input(
|
pub async fn handle_user_input(
|
||||||
&self,
|
&self,
|
||||||
|
|
@ -301,7 +298,7 @@ pub fn new(state: Arc<AppState>) -> Self {
|
||||||
session_manager.get_conversation_history(session.id, session.user_id)?
|
session_manager.get_conversation_history(session.id, session.user_id)?
|
||||||
};
|
};
|
||||||
|
|
||||||
// Prompt compactor: keep only last 10 entries to limit size
|
// Prompt compactor: keep only last 10 entries
|
||||||
let recent_history = if history.len() > 10 {
|
let recent_history = if history.len() > 10 {
|
||||||
&history[history.len() - 10..]
|
&history[history.len() - 10..]
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -313,27 +310,111 @@ pub fn new(state: Arc<AppState>) -> Self {
|
||||||
}
|
}
|
||||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||||||
|
|
||||||
// Check in-memory cache for existing response
|
// Determine which cache backend to use
|
||||||
{
|
let use_langcache = std::env::var("LLM_CACHE")
|
||||||
let cache = self.cache.lock().await;
|
.unwrap_or_else(|_| "false".to_string())
|
||||||
if let Some(cached) = cache.get(&prompt) {
|
.eq_ignore_ascii_case("true");
|
||||||
return Ok(cached.clone());
|
|
||||||
|
if use_langcache {
|
||||||
|
// Ensure LangCache collection exists
|
||||||
|
ensure_collection_exists(&self.state, "semantic_cache").await?;
|
||||||
|
|
||||||
|
// Get LangCache client
|
||||||
|
let langcache_client = get_langcache_client()?;
|
||||||
|
|
||||||
|
// Isolate the user question (ignore conversation history)
|
||||||
|
let isolated_question = message.content.trim().to_string();
|
||||||
|
|
||||||
|
// Generate embedding for the isolated question
|
||||||
|
let question_embeddings = generate_embeddings(vec![isolated_question.clone()]).await?;
|
||||||
|
let question_embedding = question_embeddings
|
||||||
|
.get(0)
|
||||||
|
.ok_or_else(|| "Failed to generate embedding for question")?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
// Search for similar question in LangCache
|
||||||
|
let search_results = langcache_client
|
||||||
|
.search("semantic_cache", question_embedding.clone(), 1)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Some(result) = search_results.first() {
|
||||||
|
let payload = &result.payload;
|
||||||
|
if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
|
||||||
|
return Ok(resp.to_string());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate response via LLM provider using full prompt (including history)
|
||||||
|
let response = self.state
|
||||||
|
.llm_provider
|
||||||
|
.generate(&prompt, &serde_json::Value::Null)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Store isolated question and response in LangCache
|
||||||
|
let point = QdrantPoint {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
vector: question_embedding,
|
||||||
|
payload: serde_json::json!({
|
||||||
|
"question": isolated_question,
|
||||||
|
"prompt": prompt,
|
||||||
|
"response": response
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
langcache_client
|
||||||
|
.upsert_points("semantic_cache", vec![point])
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
} else {
|
||||||
|
// Ensure semantic cache collection exists
|
||||||
|
ensure_collection_exists(&self.state, "semantic_cache").await?;
|
||||||
|
|
||||||
|
// Get Qdrant client
|
||||||
|
let qdrant_client = get_qdrant_client(&self.state)?;
|
||||||
|
|
||||||
|
// Generate embedding for the prompt
|
||||||
|
let embeddings = generate_embeddings(vec![prompt.clone()]).await?;
|
||||||
|
let embedding = embeddings
|
||||||
|
.get(0)
|
||||||
|
.ok_or_else(|| "Failed to generate embedding")?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
// Search for similar prompt in Qdrant
|
||||||
|
let search_results = qdrant_client
|
||||||
|
.search("semantic_cache", embedding.clone(), 1)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Some(result) = search_results.first() {
|
||||||
|
if let Some(payload) = &result.payload {
|
||||||
|
if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
|
||||||
|
return Ok(resp.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate response via LLM provider
|
||||||
|
let response = self.state
|
||||||
|
.llm_provider
|
||||||
|
.generate(&prompt, &serde_json::Value::Null)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Store prompt and response in Qdrant
|
||||||
|
let point = QdrantPoint {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
vector: embedding,
|
||||||
|
payload: serde_json::json!({
|
||||||
|
"prompt": prompt,
|
||||||
|
"response": response
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
qdrant_client
|
||||||
|
.upsert_points("semantic_cache", vec![point])
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate response via LLM provider
|
|
||||||
let response = self.state
|
|
||||||
.llm_provider
|
|
||||||
.generate(&prompt, &serde_json::Value::Null)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Store the new response in cache
|
|
||||||
{
|
|
||||||
let mut cache = self.cache.lock().await;
|
|
||||||
cache.insert(prompt.clone(), response.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream_response(
|
pub async fn stream_response(
|
||||||
|
|
@ -727,7 +808,6 @@ impl Default for BotOrchestrator {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
state: Arc::new(AppState::default()),
|
state: Arc::new(AppState::default()),
|
||||||
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
67
src/context/langcache.rs
Normal file
67
src/context/langcache.rs
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
use crate::kb::qdrant_client::{ensure_collection_exists, VectorDBClient, QdrantPoint};
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
/// LangCache client – currently a thin wrapper around the existing Qdrant client,
|
||||||
|
/// allowing future replacement with a dedicated LangCache SDK or API without
|
||||||
|
/// changing the rest of the codebase.
///
/// `search` and `upsert_points` delegate to the wrapped [`VectorDBClient`];
/// `ensure_collection_exists` goes through the free Qdrant helper instead.
|
||||||
|
pub struct LLMCacheClient {
|
||||||
|
/// Wrapped vector-database client that `search` and `upsert_points` forward to.
inner: VectorDBClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LLMCacheClient {
|
||||||
|
/// Create a new LangCache client.
|
||||||
|
/// The wrapped client is built from the `QDRANT_URL` environment variable,
|
||||||
|
/// falling back to `http://localhost:6333` when the variable cannot be read.
///
/// As written this always returns `Ok`; no path in this body produces an error.
|
||||||
|
pub fn new() -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||||
|
// Use the same URL as the Qdrant client (default or from QDRANT_URL env)
|
||||||
|
let qdrant_url = std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string());
|
||||||
|
Ok(Self {
|
||||||
|
inner: VectorDBClient::new(qdrant_url),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Ensure a collection exists in LangCache.
///
/// Delegates to the free `ensure_collection_exists` helper imported from
/// `crate::kb::qdrant_client` and returns its result unchanged.
///
/// NOTE(review): the helper is handed a freshly built `AppState::default()`
/// on every call and `self.inner` is not consulted — confirm that building a
/// default `AppState` is cheap and that the helper does not need the
/// caller's real state.
|
||||||
|
pub async fn ensure_collection_exists(
|
||||||
|
&self,
|
||||||
|
collection_name: &str,
|
||||||
|
) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
|
// Reuse the Qdrant helper – LangCache uses the same semantics.
|
||||||
|
ensure_collection_exists(&crate::shared::state::AppState::default(), collection_name).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search for similar vectors in a LangCache collection.
///
/// Forwards `collection_name`, `query_vector` and `limit` to the wrapped
/// client and converts every returned result into a `QdrantPoint`.
/// An error from the wrapped client is propagated with `?`.
|
||||||
|
pub async fn search(
|
||||||
|
&self,
|
||||||
|
collection_name: &str,
|
||||||
|
query_vector: Vec<f32>,
|
||||||
|
limit: usize,
|
||||||
|
) -> Result<Vec<QdrantPoint>, Box<dyn Error + Send + Sync>> {
|
||||||
|
// Forward to the inner Qdrant client and map results to QdrantPoint.
|
||||||
|
let results = self.inner.search(collection_name, query_vector, limit).await?;
|
||||||
|
// Convert each result to QdrantPoint: a missing vector becomes the type's
// default value and a missing payload becomes an empty JSON object `{}`.
// NOTE(review): only `id`, `vector` and `payload` are copied; if the result
// type carries a similarity score it is dropped here — confirm callers do
// not need it to reject weak matches.
|
||||||
|
let points = results
|
||||||
|
.into_iter()
|
||||||
|
.map(|res| QdrantPoint {
|
||||||
|
id: res.id,
|
||||||
|
vector: res.vector.unwrap_or_default(),
|
||||||
|
payload: res.payload.unwrap_or_else(|| serde_json::json!({})),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(points)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upsert points (prompt/response pairs) into a LangCache collection.
///
/// Thin pass-through to the wrapped client's `upsert_points`; its result is
/// returned as-is.
|
||||||
|
pub async fn upsert_points(
|
||||||
|
&self,
|
||||||
|
collection_name: &str,
|
||||||
|
points: Vec<QdrantPoint>,
|
||||||
|
) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
|
self.inner.upsert_points(collection_name, points).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper to obtain a LangCache client.
///
/// Takes no arguments — in particular it does not receive the application
/// state. It simply returns the result of `LLMCacheClient::new()`, so a new
/// client value is constructed on every call.
|
||||||
|
pub fn get_langcache_client() -> Result<LLMCacheClient, Box<dyn Error + Send + Sync>> {
|
||||||
|
LLMCacheClient::new()
|
||||||
|
}
|
||||||
|
|
@ -1,95 +1 @@
|
||||||
use async_trait::async_trait;
|
pub mod langcache;
|
||||||
use serde_json::Value;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use crate::shared::models::SearchResult;
|
|
||||||
|
|
||||||
pub mod prompt_processor;
|
|
||||||
|
|
||||||
/// Interface for persisting text embeddings and querying for similar ones.
///
/// Implementors must be `Send + Sync`. Both methods are async and report
/// failures as boxed errors.
#[async_trait]
|
|
||||||
pub trait ContextStore: Send + Sync {
|
|
||||||
/// Store `embedding` for `text` together with `metadata`.
async fn store_embedding(
|
|
||||||
&self,
|
|
||||||
text: &str,
|
|
||||||
embedding: Vec<f32>,
|
|
||||||
metadata: Value,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
|
||||||
|
|
||||||
/// Look up results similar to `embedding`.
/// Presumably `limit` caps the number of results; the trait itself does
/// not enforce this — verify against implementors.
async fn search_similar(
|
|
||||||
&self,
|
|
||||||
embedding: Vec<f32>,
|
|
||||||
limit: u32,
|
|
||||||
) -> Result<Vec<SearchResult>, Box<dyn std::error::Error + Send + Sync>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Context store holding a shared handle to a `qdrant_client::Qdrant` instance.
pub struct QdrantContextStore {
|
|
||||||
/// Reference-counted Qdrant handle supplied at construction time.
vector_store: Arc<qdrant_client::Qdrant>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QdrantContextStore {
|
|
||||||
/// Build a store that owns `vector_store` behind an `Arc`.
pub fn new(vector_store: qdrant_client::Qdrant) -> Self {
|
|
||||||
Self {
|
|
||||||
vector_store: Arc::new(vector_store),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stub: always returns an empty list.
///
/// A `session_id:… AND user_id:…` query string is formatted into an unused
/// local and no lookup is performed; `_limit` is ignored.
pub async fn get_conversation_context(
|
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
|
|
||||||
user_id: &str,
|
|
||||||
_limit: usize,
|
|
||||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
// Built but never sent anywhere (underscore-prefixed to silence the unused warning).
let _query = format!("session_id:{} AND user_id:{}", session_id, user_id);
|
|
||||||
Ok(vec![])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// `ContextStore` implementation for `QdrantContextStore`.
/// Both methods are stubs: neither one touches `self.vector_store`.
#[async_trait]
|
|
||||||
impl ContextStore for QdrantContextStore {
|
|
||||||
/// Logs `text` at info level and returns `Ok(())`; the embedding and
/// metadata arguments are ignored and nothing is stored.
async fn store_embedding(
|
|
||||||
&self,
|
|
||||||
text: &str,
|
|
||||||
_embedding: Vec<f32>,
|
|
||||||
_metadata: Value,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
log::info!("Storing embedding for text: {}", text);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Always returns an empty result list; both arguments are ignored.
async fn search_similar(
|
|
||||||
&self,
|
|
||||||
_embedding: Vec<f32>,
|
|
||||||
_limit: u32,
|
|
||||||
) -> Result<Vec<SearchResult>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
Ok(vec![])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stateless unit struct used as a mock context store.
pub struct MockContextStore;
|
|
||||||
|
|
||||||
impl MockContextStore {
|
|
||||||
/// Create the mock store; there is no state to initialise.
pub fn new() -> Self {
|
|
||||||
Self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// `ContextStore` implementation for `MockContextStore`: no data is kept.
#[async_trait]
|
|
||||||
impl ContextStore for MockContextStore {
|
|
||||||
/// Logs `text` at info level and returns `Ok(())`; the embedding and
/// metadata arguments are ignored.
async fn store_embedding(
|
|
||||||
&self,
|
|
||||||
text: &str,
|
|
||||||
_embedding: Vec<f32>,
|
|
||||||
_metadata: Value,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
log::info!("Mock storing embedding for: {}", text);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Always returns an empty result list; both arguments are ignored.
async fn search_similar(
|
|
||||||
&self,
|
|
||||||
_embedding: Vec<f32>,
|
|
||||||
_limit: u32,
|
|
||||||
) -> Result<Vec<SearchResult>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
Ok(vec![])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -53,12 +53,12 @@ pub struct CollectionInfo {
|
||||||
pub status: String,
|
pub status: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct QdrantClient {
|
pub struct VectorDBClient {
|
||||||
base_url: String,
|
base_url: String,
|
||||||
client: Client,
|
client: Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QdrantClient {
|
impl VectorDBClient {
|
||||||
pub fn new(base_url: String) -> Self {
|
pub fn new(base_url: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
base_url,
|
base_url,
|
||||||
|
|
@ -235,11 +235,11 @@ impl QdrantClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get Qdrant client from app state
|
/// Get Qdrant client from app state
|
||||||
pub fn get_qdrant_client(_state: &AppState) -> Result<QdrantClient, Box<dyn Error + Send + Sync>> {
|
pub fn get_qdrant_client(_state: &AppState) -> Result<VectorDBClient, Box<dyn Error + Send + Sync>> {
|
||||||
let qdrant_url =
|
let qdrant_url =
|
||||||
std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string());
|
std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string());
|
||||||
|
|
||||||
Ok(QdrantClient::new(qdrant_url))
|
Ok(VectorDBClient::new(qdrant_url))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Ensure a collection exists, create if not
|
/// Ensure a collection exists, create if not
|
||||||
|
|
@ -280,7 +280,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_qdrant_client_creation() {
|
fn test_qdrant_client_creation() {
|
||||||
let client = QdrantClient::new("http://localhost:6333".to_string());
|
let client = VectorDBClient::new("http://localhost:6333".to_string());
|
||||||
assert_eq!(client.base_url, "http://localhost:6333");
|
assert_eq!(client.base_url, "http://localhost:6333");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue