diff --git a/src/bot/mod.rs b/src/bot/mod.rs
index e40998fe..1ff35d08 100644
--- a/src/bot/mod.rs
+++ b/src/bot/mod.rs
@@ -455,28 +455,6 @@ impl BotOrchestrator {
             return Ok(());
         }
 
-
-
-        // Handle context change messages (type 4) immediately
-        // before any other processing
-        if message.message_type == 4 {
-            if let Some(context_name) = &message.context_name {
-                self
-                    .handle_context_change(
-                        &message.user_id,
-                        &message.bot_id,
-                        &message.session_id,
-                        &message.channel,
-                        context_name,
-                    )
-                    .await?;
-
-            }
-        }
-
-
-        let response_content = self.direct_mode_handler(&message, &session).await?;
-
         {
             let mut session_manager = self.state.session_manager.lock().await;
             session_manager.save_message(
@@ -486,10 +464,31 @@ impl BotOrchestrator {
                 &message.content,
                 message.message_type,
             )?;
+        }
+
+        let response_content = self.direct_mode_handler(&message, &session).await?;
+
+        {
+            let mut session_manager = self.state.session_manager.lock().await;
             session_manager.save_message(session.id, user_id, 2, &response_content, 1)?;
         }
 
-        // Create regular response for non-context-change messages
+        // Handle context change messages (type 4) before creating a regular response
+        if message.message_type == 4 {
+            if let Some(context_name) = &message.context_name {
+                return self
+                    .handle_context_change(
+                        &message.user_id,
+                        &message.bot_id,
+                        &message.session_id,
+                        &message.channel,
+                        context_name,
+                    )
+                    .await;
+            }
+        }
+
+        // Create regular response
         let channel = message.channel.clone();
         let config_manager = ConfigManager::new(Arc::clone(&self.state.conn));
         let max_context_size = config_manager
@@ -566,30 +565,94 @@ impl BotOrchestrator {
         prompt.push_str(&format!("User: {}\nAssistant:", message.content));
 
+        let use_langcache = std::env::var("LLM_CACHE")
+            .unwrap_or_else(|_| "false".to_string())
+            .eq_ignore_ascii_case("true");
-        let user_message = UserMessage {
-            bot_id: "default".to_string(),
-            user_id: session.user_id.to_string(),
-            session_id: session.id.to_string(),
-            channel: "web".to_string(),
-            content: message.content.clone(),
-            message_type: 1,
-            media_url: None,
-            timestamp: Utc::now(),
-            context_name: None,
-        };
+        if use_langcache {
+            ensure_collection_exists(&self.state, "semantic_cache").await?;
+            let langcache_client = get_langcache_client()?;
+            let isolated_question = message.content.trim().to_string();
+            let question_embeddings = generate_embeddings(vec![isolated_question.clone()]).await?;
+            let question_embedding = question_embeddings
+                .get(0)
+                .ok_or_else(|| "Failed to generate embedding for question")?
+                .clone();
-        let (response_tx, mut response_rx) = mpsc::channel::<BotResponse>(100);
-        if let Err(e) = self.stream_response(user_message, response_tx).await {
-            error!("Failed to stream response in direct_mode_handler: {}", e);
-        }
+            let search_results = langcache_client
+                .search("semantic_cache", question_embedding.clone(), 1)
+                .await?;
+
+            if let Some(result) = search_results.first() {
+                let payload = &result.payload;
+                if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
+                    return Ok(resp.to_string());
+                }
+            }
+
+            let response = self
+                .state
+                .llm_provider
+                .generate(&prompt, &serde_json::Value::Null)
+                .await?;
+
+            let point = QdrantPoint {
+                id: uuid::Uuid::new_v4().to_string(),
+                vector: question_embedding,
+                payload: serde_json::json!({
+                    "question": isolated_question,
+                    "prompt": prompt,
+                    "response": response
+                }),
+            };
+
+            langcache_client
+                .upsert_points("semantic_cache", vec![point])
+                .await?;
+
+            Ok(response)
+        } else {
+            ensure_collection_exists(&self.state, "semantic_cache").await?;
+            let qdrant_client = get_qdrant_client(&self.state)?;
+            let embeddings = generate_embeddings(vec![prompt.clone()]).await?;
+            let embedding = embeddings
+                .get(0)
+                .ok_or_else(|| "Failed to generate embedding")?
+                .clone();
+
+            let search_results = qdrant_client
+                .search("semantic_cache", embedding.clone(), 1)
+                .await?;
+
+            if let Some(result) = search_results.first() {
+                if let Some(payload) = &result.payload {
+                    if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
+                        return Ok(resp.to_string());
+                    }
+                }
+            }
+
+            let response = self
+                .state
+                .llm_provider
+                .generate(&prompt, &serde_json::Value::Null)
+                .await?;
+
+            let point = QdrantPoint {
+                id: uuid::Uuid::new_v4().to_string(),
+                vector: embedding,
+                payload: serde_json::json!({
+                    "prompt": prompt,
+                    "response": response
+                }),
+            };
+
+            qdrant_client
+                .upsert_points("semantic_cache", vec![point])
+                .await?;
+
+            Ok(response)
+        }
-
-        let mut full_response = String::new();
-        while let Some(response) = response_rx.recv().await {
-            full_response.push_str(&response.content);
-        }
-
-        Ok(full_response)
     }
 
     pub async fn stream_response(
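Review note on the cache lookup above: both branches accept the single nearest neighbor as a cache hit without checking its similarity score, so a semantically unrelated question can be served a stale cached answer. Below is a minimal, self-contained sketch of a score-thresholded lookup; ScoredHit, the score field, and the 0.9 cutoff are illustrative assumptions, not the API of the langcache/Qdrant clients used in this PR.

    // Sketch only: gate cache hits on a minimum similarity score.
    // `ScoredHit` and `MIN_SCORE` are hypothetical; adapt them to the
    // search-result type actually returned by the vector-store client.
    const MIN_SCORE: f32 = 0.9;

    struct ScoredHit {
        score: f32,       // similarity between the stored and query vectors
        response: String, // cached LLM response payload
    }

    /// Returns the cached response only when the best hit clears the cutoff.
    fn cache_hit(hits: &[ScoredHit], min_score: f32) -> Option<&str> {
        hits.first()
            .filter(|hit| hit.score >= min_score)
            .map(|hit| hit.response.as_str())
    }

    fn main() {
        let hits = vec![ScoredHit { score: 0.97, response: "cached answer".into() }];
        // A 0.97 hit clears the 0.9 cutoff; a weak match would instead fall
        // through to a fresh llm_provider.generate() call.
        assert_eq!(cache_hit(&hits, MIN_SCORE), Some("cached answer"));
    }

The same gate would slot into both branches of the diff right after the search call, turning "any neighbor" into "a close-enough neighbor" before returning the cached payload.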
diff --git a/templates/announcements.gbai/announcements.gbdialog/update-summary.bas b/templates/announcements.gbai/announcements.gbdialog/update-summary.bas
index fe80fb2b..de494afe 100644
--- a/templates/announcements.gbai/announcements.gbdialog/update-summary.bas
+++ b/templates/announcements.gbai/announcements.gbdialog/update-summary.bas
@@ -1,4 +1,3 @@
-SET_SCHEDULE "*/10 * * * *"
 let text = GET "announcements.gbkb/news/news.pdf"
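Note on the .bas change: the removed line is the dialog's cron trigger. "*/10 * * * *" is standard five-field cron (minute, hour, day-of-month, month, day-of-week), here firing every 10 minutes. A dialog meant to run hourly instead would use, for example (illustrative value only):

    SET_SCHEDULE "0 * * * *"

After this change update-summary.bas no longer runs on a timer; whatever triggers it instead is not shown in this diff.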