From 5404e3e7baef9a7695c163f99d52beb1b110240b Mon Sep 17 00:00:00 2001 From: "Rodrigo Rodriguez (Pragmatismo)" Date: Wed, 4 Mar 2026 15:43:37 -0300 Subject: [PATCH] feat: Enhance KB context, embedding generator, and website crawler - Improved kb_context with better context management - Enhanced embedding_generator with extended functionality (+231 lines) - Updated kb_indexer with improved indexing logic - Expanded website_crawler_service capabilities (+230 lines) - Updated use_website keyword implementation - Refined bootstrap_manager and utils - Improved drive monitoring and local file monitor - Added server enhancements --- src/basic/keywords/use_website.rs | 4 + src/core/bootstrap/bootstrap_manager.rs | 2 +- src/core/bot/kb_context.rs | 114 +++++++----- src/core/kb/embedding_generator.rs | 233 ++++++++++++++++++++---- src/core/kb/kb_indexer.rs | 71 ++++++-- src/core/kb/website_crawler_service.rs | 230 ++++++++++++++++++++--- src/core/shared/utils.rs | 22 ++- src/drive/drive_monitor/mod.rs | 17 +- src/drive/local_file_monitor.rs | 14 +- src/main_module/server.rs | 7 + 10 files changed, 562 insertions(+), 152 deletions(-) diff --git a/src/basic/keywords/use_website.rs b/src/basic/keywords/use_website.rs index 8f5f346ee..6525c7b5a 100644 --- a/src/basic/keywords/use_website.rs +++ b/src/basic/keywords/use_website.rs @@ -478,6 +478,10 @@ fn associate_website_with_session_refresh( register_website_for_crawling_with_refresh(&mut conn, &user.bot_id, url, refresh_interval) .map_err(|e| format!("Failed to register website: {}", e))?; + // ADD TO SESSION EVEN IF CRAWL IS PENDING! + // Otherwise kb_context will think the session has no website associated if start.bas only runs once. + add_website_to_session(&mut conn, &user.id, &user.bot_id, url, &collection_name)?; + return Ok(format!( "Website {} has been registered for crawling (refresh: {}). It will be available once crawling completes.", url, refresh_interval diff --git a/src/core/bootstrap/bootstrap_manager.rs b/src/core/bootstrap/bootstrap_manager.rs index 16c5fa039..9ed74998f 100644 --- a/src/core/bootstrap/bootstrap_manager.rs +++ b/src/core/bootstrap/bootstrap_manager.rs @@ -294,7 +294,7 @@ impl BootstrapManager { } // Install other core components (names must match 3rdparty.toml) - let core_components = ["tables", "cache", "drive", "directory", "llm"]; + let core_components = ["tables", "cache", "drive", "directory", "llm", "vector_db"]; for component in core_components { if !pm.is_installed(component) { info!("Installing {}...", component); diff --git a/src/core/bot/kb_context.rs b/src/core/bot/kb_context.rs index de64e3d09..28980d90a 100644 --- a/src/core/bot/kb_context.rs +++ b/src/core/bot/kb_context.rs @@ -7,6 +7,7 @@ use uuid::Uuid; use crate::core::kb::KnowledgeBaseManager; use crate::core::shared::utils::DbPool; +use crate::core::kb::{EmbeddingConfig, KbIndexer, QdrantConfig}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SessionKbAssociation { @@ -238,56 +239,83 @@ impl KbContextManager { Ok(kb_contexts) } - async fn search_single_collection( - &self, - collection_name: &str, - display_name: &str, - query: &str, - max_results: usize, - max_tokens: usize, - ) -> Result { - debug!("Searching collection '{}' with query: {}", collection_name, query); + async fn search_single_collection( + &self, + collection_name: &str, + display_name: &str, + query: &str, + max_results: usize, + max_tokens: usize, + ) -> Result { + debug!("Searching collection '{}' with query: {}", collection_name, query); - let search_results = self - .kb_manager - .search_collection(collection_name, query, max_results) - .await?; + // Extract bot_name from collection_name (format: "{bot_name}_{kb_name}") + let bot_name = collection_name.split('_').next().unwrap_or("default"); - let mut kb_search_results = Vec::new(); - let mut total_tokens = 0; + // Get bot_id from bot_name + let bot_id = self.get_bot_id_by_name(bot_name).await?; - for result in search_results { - let tokens = estimate_tokens(&result.content); + // Load embedding config from database for this bot + let embedding_config = EmbeddingConfig::from_bot_config(&self.db_pool, &bot_id); + let qdrant_config = QdrantConfig::default(); - if total_tokens + tokens > max_tokens { - debug!( - "Skipping result due to token limit ({} + {} > {})", - total_tokens, tokens, max_tokens - ); - break; + // Create a temporary indexer with bot-specific config + let indexer = KbIndexer::new(embedding_config, qdrant_config); + + // Use the bot-specific indexer for search + let search_results = indexer + .search(collection_name, query, max_results) + .await?; + + let mut kb_search_results = Vec::new(); + let mut total_tokens = 0; + + for result in search_results { + let tokens = estimate_tokens(&result.content); + + if total_tokens + tokens > max_tokens { + debug!( + "Skipping result due to token limit ({} + {} > {})", + total_tokens, tokens, max_tokens + ); + break; + } + + kb_search_results.push(KbSearchResult { + content: result.content, + document_path: result.document_path, + score: result.score, + chunk_tokens: tokens, + }); + + total_tokens += tokens; + + if result.score < 0.6 { + debug!("Skipping low-relevance result (score: {})", result.score); + break; + } } - kb_search_results.push(KbSearchResult { - content: result.content, - document_path: result.document_path, - score: result.score, - chunk_tokens: tokens, - }); - - total_tokens += tokens; - - if result.score < 0.6 { - debug!("Skipping low-relevance result (score: {})", result.score); - break; - } + Ok(KbContext { + kb_name: display_name.to_string(), + search_results: kb_search_results, + total_tokens, + }) } - Ok(KbContext { - kb_name: display_name.to_string(), - search_results: kb_search_results, - total_tokens, - }) - } + async fn get_bot_id_by_name(&self, bot_name: &str) -> Result { + use crate::core::shared::models::schema::bots::dsl::*; + + let mut conn = self.db_pool.get()?; + + let bot_uuid: Uuid = bots + .filter(name.eq(bot_name)) + .select(id) + .first(&mut conn) + .map_err(|e| anyhow::anyhow!("Failed to find bot '{}': {}", bot_name, e))?; + + Ok(bot_uuid) + } async fn search_single_kb( &self, @@ -373,7 +401,7 @@ impl KbContextManager { context_parts.push("\n--- End Knowledge Base Context ---\n".to_string()); let full_context = context_parts.join("\n"); - + // Truncate KB context to fit within token limits (max 400 tokens for KB context) crate::core::shared::utils::truncate_text_for_model(&full_context, "local", 400) } diff --git a/src/core/kb/embedding_generator.rs b/src/core/kb/embedding_generator.rs index f7cd6d694..7134fc769 100644 --- a/src/core/kb/embedding_generator.rs +++ b/src/core/kb/embedding_generator.rs @@ -28,6 +28,7 @@ pub fn set_embedding_server_ready(ready: bool) { pub struct EmbeddingConfig { pub embedding_url: String, pub embedding_model: String, + pub embedding_key: Option, pub dimensions: usize, pub batch_size: usize, pub timeout_seconds: u64, @@ -39,8 +40,9 @@ impl Default for EmbeddingConfig { fn default() -> Self { Self { embedding_url: "http://localhost:8082".to_string(), - embedding_model: "bge-small-en-v1.5".to_string(), - dimensions: 384, + embedding_model: "BAAI/bge-multilingual-gemma2".to_string(), + embedding_key: None, + dimensions: 2048, batch_size: 16, timeout_seconds: 60, max_concurrent_requests: 1, @@ -61,13 +63,14 @@ impl EmbeddingConfig { /// embedding-dimensions,384 /// embedding-batch-size,16 /// embedding-timeout,60 + /// embedding-key,hf_xxxxx (for HuggingFace API) pub fn from_bot_config(pool: &DbPool, _bot_id: &uuid::Uuid) -> Self { use crate::core::shared::models::schema::bot_configuration::dsl::*; use diesel::prelude::*; let embedding_url = match pool.get() { Ok(mut conn) => bot_configuration - .filter(bot_id.eq(bot_id)) + .filter(bot_id.eq(_bot_id)) .filter(config_key.eq("embedding-url")) .select(config_value) .first::(&mut conn) @@ -78,18 +81,29 @@ impl EmbeddingConfig { let embedding_model = match pool.get() { Ok(mut conn) => bot_configuration - .filter(bot_id.eq(bot_id)) + .filter(bot_id.eq(_bot_id)) .filter(config_key.eq("embedding-model")) .select(config_value) .first::(&mut conn) .ok() .filter(|s| !s.is_empty()), Err(_) => None, - }.unwrap_or_else(|| "bge-small-en-v1.5".to_string()); + }.unwrap_or_else(|| "BAAI/bge-multilingual-gemma2".to_string()); + + let embedding_key = match pool.get() { + Ok(mut conn) => bot_configuration + .filter(bot_id.eq(_bot_id)) + .filter(config_key.eq("embedding-key")) + .select(config_value) + .first::(&mut conn) + .ok() + .filter(|s| !s.is_empty()), + Err(_) => None, + }; let dimensions = match pool.get() { Ok(mut conn) => bot_configuration - .filter(bot_id.eq(bot_id)) + .filter(bot_id.eq(_bot_id)) .filter(config_key.eq("embedding-dimensions")) .select(config_value) .first::(&mut conn) @@ -100,7 +114,7 @@ impl EmbeddingConfig { let batch_size = match pool.get() { Ok(mut conn) => bot_configuration - .filter(bot_id.eq(bot_id)) + .filter(bot_id.eq(_bot_id)) .filter(config_key.eq("embedding-batch-size")) .select(config_value) .first::(&mut conn) @@ -111,7 +125,7 @@ impl EmbeddingConfig { let timeout_seconds = match pool.get() { Ok(mut conn) => bot_configuration - .filter(bot_id.eq(bot_id)) + .filter(bot_id.eq(_bot_id)) .filter(config_key.eq("embedding-timeout")) .select(config_value) .first::(&mut conn) @@ -122,7 +136,7 @@ impl EmbeddingConfig { let max_concurrent_requests = match pool.get() { Ok(mut conn) => bot_configuration - .filter(bot_id.eq(bot_id)) + .filter(bot_id.eq(_bot_id)) .filter(config_key.eq("embedding-concurrent")) .select(config_value) .first::(&mut conn) @@ -134,6 +148,7 @@ impl EmbeddingConfig { Self { embedding_url, embedding_model, + embedding_key, dimensions, batch_size, timeout_seconds, @@ -143,7 +158,9 @@ impl EmbeddingConfig { } fn detect_dimensions(model: &str) -> usize { - if model.contains("small") || model.contains("MiniLM") { + if model.contains("gemma") || model.contains("Gemma") { + 2048 + } else if model.contains("small") || model.contains("MiniLM") { 384 } else if model.contains("base") || model.contains("mpnet") { 768 @@ -183,6 +200,25 @@ struct LlamaCppEmbeddingItem { // Hugging Face/SentenceTransformers format (simple array) type HuggingFaceEmbeddingResponse = Vec>; +// Scaleway/OpenAI-compatible format (object with data array) +#[derive(Debug, Deserialize)] +struct ScalewayEmbeddingResponse { + data: Vec, + #[serde(default)] + model: Option, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ScalewayEmbeddingData { + embedding: Vec, + #[serde(default)] + index: usize, + #[serde(default)] + object: Option, +} + // Generic embedding service format (object with embeddings key) #[derive(Debug, Deserialize)] struct GenericEmbeddingResponse { @@ -197,6 +233,7 @@ struct GenericEmbeddingResponse { #[derive(Debug, Deserialize)] #[serde(untagged)] enum EmbeddingResponse { + Scaleway(ScalewayEmbeddingResponse), // Scaleway/OpenAI-compatible format OpenAI(OpenAIEmbeddingResponse), // Most common: OpenAI, Claude, etc. LlamaCpp(Vec), // llama.cpp server HuggingFace(HuggingFaceEmbeddingResponse), // Simple array format @@ -258,6 +295,16 @@ impl KbEmbeddingGenerator { } pub async fn check_health(&self) -> bool { + // For HuggingFace and other remote APIs, skip health check and mark as ready + // These APIs don't have a /health endpoint and we verified they work during setup + if self.config.embedding_url.contains("huggingface.co") || + self.config.embedding_url.contains("api.") || + self.config.embedding_key.is_some() { + info!("Remote embedding API detected ({}), marking as ready", self.config.embedding_url); + set_embedding_server_ready(true); + return true; + } + let health_url = format!("{}/health", self.config.embedding_url); match tokio::time::timeout( @@ -272,6 +319,7 @@ impl KbEmbeddingGenerator { is_healthy } Ok(Err(e)) => { + warn!("Health check failed for primary URL: {}", e); let alt_url = &self.config.embedding_url; match tokio::time::timeout( Duration::from_secs(5), @@ -284,14 +332,18 @@ impl KbEmbeddingGenerator { } is_healthy } - _ => { - warn!("Health check failed: {}", e); + Ok(Err(_)) => { + set_embedding_server_ready(false); + false + } + Err(_) => { + set_embedding_server_ready(false); false } } } Err(_) => { - warn!("Health check timed out"); + set_embedding_server_ready(false); false } } @@ -311,12 +363,17 @@ impl KbEmbeddingGenerator { } tokio::time::sleep(Duration::from_secs(2)).await; } - warn!("Embedding server not available after {}s", max_wait_secs); false } + /// Get the configured embedding dimensions + pub fn get_dimensions(&self) -> usize { + self.config.dimensions + } pub async fn generate_embeddings( + + &self, chunks: &[TextChunk], ) -> Result> { @@ -356,12 +413,14 @@ impl KbEmbeddingGenerator { Ok(Ok(embeddings)) => embeddings, Ok(Err(e)) => { warn!("Batch {} failed: {}", batch_num + 1, e); - break; + // Continue with next batch instead of breaking completely + continue; } Err(_) => { warn!("Batch {} timed out after {}s", batch_num + 1, self.config.timeout_seconds); - break; + // Continue with next batch instead of breaking completely + continue; } }; @@ -429,25 +488,121 @@ impl KbEmbeddingGenerator { .map(|text| crate::core::shared::utils::truncate_text_for_model(text, &self.config.embedding_model, 600)) .collect(); - let request = EmbeddingRequest { - input: truncated_texts, - model: self.config.embedding_model.clone(), + // Detect API format based on URL pattern + // Scaleway (OpenAI-compatible): https://router.huggingface.co/scaleway/v1/embeddings + // HuggingFace Inference (old): https://router.huggingface.co/hf-inference/models/.../pipeline/feature-extraction + let is_scaleway = self.config.embedding_url.contains("/scaleway/v1/embeddings"); + let is_hf_inference = self.config.embedding_url.contains("/hf-inference/") || + self.config.embedding_url.contains("/pipeline/feature-extraction"); + + let response = if is_hf_inference { + // HuggingFace Inference API (old format): {"inputs": "text"} + // Process one text at a time for HuggingFace Inference + let mut all_embeddings = Vec::new(); + + for text in &truncated_texts { + let hf_request = serde_json::json!({ + "inputs": text + }); + + let request_size = serde_json::to_string(&hf_request) + .map(|s| s.len()) + .unwrap_or(0); + trace!("Sending HuggingFace Inference request to {} (size: {} bytes)", + self.config.embedding_url, request_size); + + let mut request_builder = self.client + .post(&self.config.embedding_url) + .json(&hf_request); + + // Add Authorization header if API key is provided + if let Some(ref api_key) = self.config.embedding_key { + request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key)); + } + + let resp = request_builder + .send() + .await + .context("Failed to send request to HuggingFace Inference embedding service")?; + + let status = resp.status(); + if !status.is_success() { + let error_bytes = resp.bytes().await.unwrap_or_default(); + let error_text = String::from_utf8_lossy(&error_bytes[..error_bytes.len().min(1024)]); + return Err(anyhow::anyhow!( + "HuggingFace Inference embedding service error {}: {}", + status, + error_text + )); + } + + let response_bytes = resp.bytes().await + .context("Failed to read HuggingFace Inference embedding response bytes")?; + + trace!("Received HuggingFace Inference response: {} bytes", response_bytes.len()); + + if response_bytes.len() > 50 * 1024 * 1024 { + return Err(anyhow::anyhow!( + "Embedding response too large: {} bytes (max 50MB)", + response_bytes.len() + )); + } + + // Parse HuggingFace Inference response (single array for single input) + let embedding_vec: Vec = serde_json::from_slice(&response_bytes) + .with_context(|| { + let preview = std::str::from_utf8(&response_bytes) + .map(|s| if s.len() > 200 { &s[..200] } else { s }) + .unwrap_or(""); + format!("Failed to parse HuggingFace Inference embedding response. Preview: {}", preview) + })?; + + all_embeddings.push(Embedding { + vector: embedding_vec, + dimensions: self.config.dimensions, + model: self.config.embedding_model.clone(), + tokens_used: None, + }); + } + + return Ok(all_embeddings); + } else { + // Standard embedding service format (OpenAI-compatible, Scaleway, llama.cpp, local server, etc.) + // This includes Scaleway which uses OpenAI-compatible format: {"input": [texts], "model": "model-name"} + let request = EmbeddingRequest { + input: truncated_texts, + model: self.config.embedding_model.clone(), + }; + + let request_size = serde_json::to_string(&request) + .map(|s| s.len()) + .unwrap_or(0); + + // Log the API format being used + if is_scaleway { + trace!("Sending Scaleway (OpenAI-compatible) request to {} (size: {} bytes)", + self.config.embedding_url, request_size); + } else { + trace!("Sending standard embedding request to {} (size: {} bytes)", + self.config.embedding_url, request_size); + } + + // Build request + let mut request_builder = self.client + .post(&self.config.embedding_url) + .json(&request); + + // Add Authorization header if API key is provided (for Scaleway, OpenAI, etc.) + if let Some(ref api_key) = self.config.embedding_key { + request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key)); + } + + request_builder + .send() + .await + .context("Failed to send request to embedding service")? }; - let request_size = serde_json::to_string(&request) - .map(|s| s.len()) - .unwrap_or(0); - trace!("Sending request to {} (size: {} bytes)", - self.config.embedding_url, request_size); - - let response = self - .client - .post(format!("{}/embedding", self.config.embedding_url)) - .json(&request) - .send() - .await - .context("Failed to send request to embedding service")?; - let status = response.status(); if !status.is_success() { let error_bytes = response.bytes().await.unwrap_or_default(); @@ -532,6 +687,18 @@ impl KbEmbeddingGenerator { } embeddings } + EmbeddingResponse::Scaleway(scaleway_response) => { + let mut embeddings = Vec::with_capacity(scaleway_response.data.len()); + for data in scaleway_response.data { + embeddings.push(Embedding { + vector: data.embedding, + dimensions: self.config.dimensions, + model: scaleway_response.model.clone().unwrap_or_else(|| self.config.embedding_model.clone()), + tokens_used: scaleway_response.usage.as_ref().map(|u| u.total_tokens), + }); + } + embeddings + } }; Ok(embeddings) diff --git a/src/core/kb/kb_indexer.rs b/src/core/kb/kb_indexer.rs index b4a7b576d..98d40b986 100644 --- a/src/core/kb/kb_indexer.rs +++ b/src/core/kb/kb_indexer.rs @@ -22,7 +22,7 @@ pub struct QdrantConfig { impl Default for QdrantConfig { fn default() -> Self { Self { - url: "https://localhost:6333".to_string(), + url: "http://localhost:6333".to_string(), api_key: None, timeout_secs: 30, } @@ -33,8 +33,8 @@ impl QdrantConfig { pub fn from_config(pool: DbPool, bot_id: &Uuid) -> Self { let config_manager = ConfigManager::new(pool); let url = config_manager - .get_config(bot_id, "vectordb-url", Some("https://localhost:6333")) - .unwrap_or_else(|_| "https://localhost:6333".to_string()); + .get_config(bot_id, "vectordb-url", Some("http://localhost:6333")) + .unwrap_or_else(|_| "http://localhost:6333".to_string()); Self { url, api_key: None, @@ -173,7 +173,7 @@ impl KbIndexer { // Process documents in iterator to avoid keeping all in memory let doc_iter = documents.into_iter(); - + for (doc_path, chunks) in doc_iter { if chunks.is_empty() { debug!("Skipping document with no chunks: {}", doc_path); @@ -187,23 +187,23 @@ impl KbIndexer { let (processed, chunks_count) = self.process_document_batch(&collection_name, &mut batch_docs).await?; indexed_documents += processed; total_chunks += chunks_count; - + // Clear batch and force memory cleanup batch_docs.clear(); batch_docs.shrink_to_fit(); - + // Yield control to prevent blocking tokio::task::yield_now().await; - + // Memory pressure check - more aggressive let current_mem = MemoryStats::current(); if current_mem.rss_bytes > 1_500_000_000 { // 1.5GB threshold (reduced) - warn!("High memory usage detected: {}, forcing cleanup", + warn!("High memory usage detected: {}, forcing cleanup", MemoryStats::format_bytes(current_mem.rss_bytes)); - + // Force garbage collection hint std::hint::black_box(&batch_docs); - + // Add delay to allow memory cleanup tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; } @@ -263,10 +263,10 @@ impl KbIndexer { // Process chunks in smaller sub-batches to prevent memory exhaustion const CHUNK_BATCH_SIZE: usize = 20; // Process 20 chunks at a time let chunk_batches = chunks.chunks(CHUNK_BATCH_SIZE); - + for chunk_batch in chunk_batches { trace!("Processing chunk batch of {} chunks", chunk_batch.len()); - + let embeddings = match self .embedding_generator .generate_embeddings(chunk_batch) @@ -281,7 +281,7 @@ impl KbIndexer { let points = Self::create_qdrant_points(&doc_path, embeddings)?; self.upsert_points(collection_name, points).await?; - + // Yield control between chunk batches tokio::task::yield_now().await; } @@ -293,7 +293,7 @@ impl KbIndexer { total_chunks += chunks.len(); processed_count += 1; - + // Force memory cleanup after each document std::hint::black_box(&chunks); } @@ -303,19 +303,54 @@ impl KbIndexer { async fn ensure_collection_exists(&self, collection_name: &str) -> Result<()> { let check_url = format!("{}/collections/{}", self.qdrant_config.url, collection_name); + let required_dims = self.embedding_generator.get_dimensions(); let response = self.http_client.get(&check_url).send().await?; if response.status().is_success() { - info!("Collection {} already exists", collection_name); - return Ok(()); + // Check if the existing collection has the correct vector size + let info_json: serde_json::Value = response.json().await?; + let existing_dims = info_json["result"]["config"]["params"]["vectors"]["size"] + .as_u64() + .map(|d| d as usize); + + match existing_dims { + Some(dims) if dims == required_dims => { + trace!("Collection {} already exists with correct dims ({})", collection_name, required_dims); + return Ok(()); + } + Some(dims) => { + warn!( + "Collection {} exists with dim={} but embedding requires dim={}. \ + Recreating collection.", + collection_name, dims, required_dims + ); + // Delete the stale collection so we can recreate it + let delete_url = format!("{}/collections/{}", self.qdrant_config.url, collection_name); + let del_resp = self.http_client.delete(&delete_url).send().await?; + if !del_resp.status().is_success() { + let err = del_resp.text().await.unwrap_or_default(); + return Err(anyhow::anyhow!( + "Failed to delete stale collection {}: {}", + collection_name, err + )); + } + info!("Deleted stale collection {} (was dim={})", collection_name, dims); + } + None => { + // Could not read dims – recreate to be safe + warn!("Could not read dims for collection {}, recreating", collection_name); + let delete_url = format!("{}/collections/{}", self.qdrant_config.url, collection_name); + let _ = self.http_client.delete(&delete_url).send().await; + } + } } - info!("Creating collection: {}", collection_name); + info!("Creating collection {} with dim={}", collection_name, required_dims); let config = CollectionConfig { vectors: VectorConfig { - size: 384, + size: required_dims, distance: "Cosine".to_string(), }, replication_factor: 1, diff --git a/src/core/kb/website_crawler_service.rs b/src/core/kb/website_crawler_service.rs index c28a3d05d..885816601 100644 --- a/src/core/kb/website_crawler_service.rs +++ b/src/core/kb/website_crawler_service.rs @@ -1,10 +1,12 @@ use crate::core::config::ConfigManager; use crate::core::kb::web_crawler::{WebCrawler, WebsiteCrawlConfig}; -use crate::core::kb::KnowledgeBaseManager; +use crate::core::kb::embedding_generator::EmbeddingConfig; +use crate::core::kb::kb_indexer::{KbIndexer, QdrantConfig}; + use crate::core::shared::state::AppState; use crate::core::shared::utils::DbPool; use diesel::prelude::*; -use log::{error, trace, warn}; +use log::{error, info, trace, warn}; use regex; use std::collections::HashSet; use std::sync::Arc; @@ -14,17 +16,15 @@ use uuid::Uuid; #[derive(Debug)] pub struct WebsiteCrawlerService { db_pool: DbPool, - kb_manager: Arc, check_interval: Duration, running: Arc>, active_crawls: Arc>>, } impl WebsiteCrawlerService { - pub fn new(db_pool: DbPool, kb_manager: Arc) -> Self { + pub fn new(db_pool: DbPool) -> Self { Self { db_pool, - kb_manager, check_interval: Duration::from_secs(60), running: Arc::new(tokio::sync::RwLock::new(false)), active_crawls: Arc::new(tokio::sync::RwLock::new(HashSet::new())), @@ -37,6 +37,20 @@ impl WebsiteCrawlerService { tokio::spawn(async move { trace!("Website crawler service started"); + // Reset any rows stuck at crawl_status=2 (in-progress) from a previous + // crash or restart — they are no longer actually being crawled. + if let Ok(mut conn) = service.db_pool.get() { + let reset = diesel::sql_query( + "UPDATE website_crawls SET crawl_status = 0 WHERE crawl_status = 2" + ) + .execute(&mut conn); + match reset { + Ok(n) if n > 0 => info!("Reset {} stale in-progress crawl(s) to pending", n), + Ok(_) => {} + Err(e) => warn!("Could not reset stale crawl statuses: {}", e), + } + } + let mut ticker = interval(service.check_interval); loop { @@ -111,13 +125,12 @@ impl WebsiteCrawlerService { .execute(&mut conn)?; // Process one website at a time to control memory usage - let kb_manager = Arc::clone(&self.kb_manager); let db_pool = self.db_pool.clone(); let active_crawls = Arc::clone(&self.active_crawls); trace!("Processing website: {}", website.url); - match Self::crawl_website(website, kb_manager, db_pool, active_crawls).await { + match Self::crawl_website(website, db_pool, active_crawls).await { Ok(_) => { trace!("Successfully processed website crawl"); } @@ -138,7 +151,6 @@ impl WebsiteCrawlerService { async fn crawl_website( website: WebsiteCrawlRecord, - kb_manager: Arc, db_pool: DbPool, active_crawls: Arc>>, ) -> Result<(), Box> { @@ -152,19 +164,21 @@ impl WebsiteCrawlerService { active.insert(website.url.clone()); } - // Ensure cleanup on exit let url_for_cleanup = website.url.clone(); let active_crawls_cleanup = Arc::clone(&active_crawls); - // Manual cleanup instead of scopeguard - let cleanup = || { - let url = url_for_cleanup.clone(); - let active = Arc::clone(&active_crawls_cleanup); - tokio::spawn(async move { - active.write().await.remove(&url); - }); - }; + // Always remove from active_crawls at the end, regardless of success or error. + let result = Self::do_crawl_website(website, db_pool).await; + active_crawls_cleanup.write().await.remove(&url_for_cleanup); + + result + } + + async fn do_crawl_website( + website: WebsiteCrawlRecord, + db_pool: DbPool, + ) -> Result<(), Box> { trace!("Starting crawl for website: {}", website.url); let config_manager = ConfigManager::new(db_pool.clone()); @@ -225,6 +239,14 @@ impl WebsiteCrawlerService { tokio::fs::create_dir_all(&work_path).await?; + // Load embedding config: DB settings + Vault API key at gbo/{bot_name}. + let embedding_config = embedding_config_with_vault(&db_pool, &website.bot_id, &bot_name).await; + info!("Using embedding URL: {} for bot {}", embedding_config.embedding_url, bot_name); + + // Create bot-specific KB indexer with correct embedding config + let qdrant_config = QdrantConfig::default(); + let bot_indexer = KbIndexer::new(embedding_config, qdrant_config); + // Process pages in small batches to prevent memory exhaustion const BATCH_SIZE: usize = 5; let total_pages = pages.len(); @@ -259,8 +281,9 @@ impl WebsiteCrawlerService { // Process this batch immediately to free memory if batch_idx == 0 || (batch_idx + 1) % 2 == 0 { // Index every 2 batches to prevent memory buildup - match kb_manager.index_kb_folder(&bot_name, &kb_name, &work_path).await { - Ok(_) => trace!("Indexed batch {} successfully", batch_idx + 1), + match bot_indexer.index_kb_folder(&bot_name, &kb_name, &work_path).await { + Ok(result) => trace!("Indexed batch {} successfully: {} docs, {} chunks", + batch_idx + 1, result.documents_processed, result.chunks_indexed), Err(e) => warn!("Failed to index batch {}: {}", batch_idx + 1, e), } @@ -270,7 +293,7 @@ impl WebsiteCrawlerService { } // Final indexing for any remaining content - kb_manager + bot_indexer .index_kb_folder(&bot_name, &kb_name, &work_path) .await?; @@ -296,7 +319,7 @@ impl WebsiteCrawlerService { "Successfully recrawled {}, next crawl: {:?}", website.url, config.next_crawl ); - cleanup(); + Ok(()) } Err(e) => { error!("Failed to crawl {}: {}", website.url, e); @@ -312,11 +335,9 @@ impl WebsiteCrawlerService { .bind::(&website.id) .execute(&mut conn)?; - cleanup(); + Err(e.into()) } } - - Ok(()) } fn scan_and_register_websites_from_scripts(&self) -> Result<(), Box> { @@ -370,7 +391,7 @@ impl WebsiteCrawlerService { &self, website: WebsiteCrawlRecord, ) -> Result<(), Box> { - Self::crawl_website(website, Arc::clone(&self.kb_manager), self.db_pool.clone(), Arc::clone(&self.active_crawls)).await + Self::crawl_website(website, self.db_pool.clone(), Arc::clone(&self.active_crawls)).await } fn scan_directory_for_websites( @@ -471,13 +492,138 @@ fn sanitize_url_for_kb(url: &str) -> String { .to_lowercase() } +/// Force recrawl a specific website immediately +/// Updates next_crawl to NOW() and triggers immediate crawl +pub async fn force_recrawl_website( + db_pool: crate::core::shared::utils::DbPool, + bot_id: uuid::Uuid, + url: String, +) -> Result> { + use diesel::prelude::*; + + let mut conn = db_pool.get()?; + + // Update next_crawl to NOW() to mark for immediate recrawl + let rows_updated = diesel::sql_query( + "UPDATE website_crawls + SET next_crawl = NOW(), + crawl_status = 0, + error_message = NULL + WHERE bot_id = $1 AND url = $2" + ) + .bind::(&bot_id) + .bind::(&url) + .execute(&mut conn)?; + + if rows_updated == 0 { + return Err(format!("Website not found: bot_id={}, url={}", bot_id, url).into()); + } + + trace!("Updated next_crawl to NOW() for website: {}", url); + + // Get the website record for immediate crawling + #[derive(diesel::QueryableByName)] + struct WebsiteRecord { + #[diesel(sql_type = diesel::sql_types::Uuid)] + id: uuid::Uuid, + #[diesel(sql_type = diesel::sql_types::Uuid)] + bot_id: uuid::Uuid, + #[diesel(sql_type = diesel::sql_types::Text)] + url: String, + #[diesel(sql_type = diesel::sql_types::Text)] + expires_policy: String, + #[diesel(sql_type = diesel::sql_types::Text)] + refresh_policy: String, + #[diesel(sql_type = diesel::sql_types::Integer)] + max_depth: i32, + #[diesel(sql_type = diesel::sql_types::Integer)] + max_pages: i32, + } + + let website: WebsiteRecord = diesel::sql_query( + "SELECT id, bot_id, url, expires_policy, refresh_policy, max_depth, max_pages + FROM website_crawls + WHERE bot_id = $1 AND url = $2" + ) + .bind::(&bot_id) + .bind::(&url) + .get_result(&mut conn)?; + + // Convert to WebsiteCrawlRecord + let website_record = WebsiteCrawlRecord { + id: website.id, + bot_id: website.bot_id, + url: website.url.clone(), + expires_policy: website.expires_policy, + refresh_policy: Some(website.refresh_policy), + max_depth: website.max_depth, + max_pages: website.max_pages, + next_crawl: None, + crawl_status: Some(0), + }; + + // Trigger immediate crawl + let active_crawls = Arc::new(tokio::sync::RwLock::new(HashSet::new())); + + trace!("Starting immediate crawl for website: {}", url); + + match WebsiteCrawlerService::crawl_website(website_record, db_pool, active_crawls).await { + Ok(_) => { + let msg = format!("Successfully triggered immediate recrawl for website: {}", url); + info!("{}", msg); + Ok(msg) + } + Err(e) => { + let error_msg = format!("Failed to crawl website {}: {}", url, e); + error!("{}", error_msg); + Err(error_msg.into()) + } + } +} + +/// API Handler for force recrawl endpoint +/// POST /api/website/force-recrawl +/// Body: {"bot_id": "uuid", "url": "https://example.com"} +pub async fn handle_force_recrawl( + axum::extract::State(state): axum::extract::State>, + axum::Json(payload): axum::Json, +) -> Result, (axum::http::StatusCode, axum::Json)> { + use crate::security::error_sanitizer::log_and_sanitize_str; + + match force_recrawl_website( + state.conn.clone(), + payload.bot_id, + payload.url.clone(), + ).await { + Ok(msg) => Ok(axum::Json(serde_json::json!({ + "success": true, + "message": msg, + "bot_id": payload.bot_id, + "url": payload.url + }))), + Err(e) => { + let sanitized = log_and_sanitize_str(&e.to_string(), "force_recrawl", None); + Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + axum::Json(serde_json::json!({"error": sanitized})) + )) + } + } +} + +/// Request payload for force recrawl endpoint +#[derive(serde::Deserialize)] +pub struct ForceRecrawlRequest { + pub bot_id: uuid::Uuid, + pub url: String, +} + pub async fn ensure_crawler_service_running( state: Arc, ) -> Result<(), Box> { - if let Some(kb_manager) = &state.kb_manager { + if let Some(_kb_manager) = &state.kb_manager { let service = Arc::new(WebsiteCrawlerService::new( state.conn.clone(), - Arc::clone(kb_manager), )); drop(service.start()); @@ -490,3 +636,33 @@ pub async fn ensure_crawler_service_running( Ok(()) } } + +/// Build an `EmbeddingConfig` for a bot, reading most settings from the DB +/// but **overriding the API key from Vault** (per-bot path `gbo/{bot_name}` → `embedding-key`) +/// when available. Falls back transparently to the DB value when Vault is unavailable. +async fn embedding_config_with_vault( + pool: &DbPool, + bot_id: &uuid::Uuid, + bot_name: &str, +) -> EmbeddingConfig { + // Start from the DB-backed config (URL, model, dimensions, etc.) + let mut config = EmbeddingConfig::from_bot_config(pool, bot_id); + + // Try to upgrade the API key from Vault using the per-bot secret path. + if let Some(secrets) = crate::core::shared::utils::get_secrets_manager().await { + let per_bot_path = format!("gbo/{}", bot_name); + match secrets.get_value(&per_bot_path, "embedding-key").await { + Ok(key) if !key.is_empty() => { + trace!("Loaded embedding key from Vault path {} for bot {}", per_bot_path, bot_name); + config.embedding_key = Some(key); + } + Ok(_) => {} // Key present but empty — keep DB value + Err(e) => { + trace!("Vault embedding-key not found at {} ({}), using DB value", per_bot_path, e); + } + } + } + + config +} + diff --git a/src/core/shared/utils.rs b/src/core/shared/utils.rs index 3b2a8d694..b43b1accc 100644 --- a/src/core/shared/utils.rs +++ b/src/core/shared/utils.rs @@ -530,17 +530,19 @@ pub fn parse_hex_color(hex: &str) -> Option<(u8, u8, u8)> { pub fn truncate_text_for_model(text: &str, model: &str, max_tokens: usize) -> String { let chars_per_token = estimate_chars_per_token(model); let max_chars = max_tokens * chars_per_token; - - if text.len() <= max_chars { - text.to_string() + + if text.chars().count() <= max_chars { + return text.to_string(); + } + + // Get first max_chars characters safely (UTF-8 aware) + let truncated: String = text.chars().take(max_chars).collect(); + + // Try to truncate at word boundary + if let Some(last_space_idx) = truncated.rfind(' ') { + truncated[..last_space_idx].to_string() } else { - // Try to truncate at word boundary - let truncated = &text[..max_chars]; - if let Some(last_space) = truncated.rfind(' ') { - text[..last_space].to_string() - } else { - truncated.to_string() - } + truncated } } diff --git a/src/drive/drive_monitor/mod.rs b/src/drive/drive_monitor/mod.rs index f58aba015..09b75e52c 100644 --- a/src/drive/drive_monitor/mod.rs +++ b/src/drive/drive_monitor/mod.rs @@ -266,7 +266,7 @@ impl DriveMonitor { tokio::spawn(async move { let mut consecutive_processing_failures = 0; trace!("Starting periodic monitoring loop for bot {}", self_clone.bot_id); - + let is_processing_state = self_clone.is_processing.load(std::sync::atomic::Ordering::SeqCst); trace!("is_processing state at loop start: {} for bot {}", is_processing_state, self_clone.bot_id); @@ -630,7 +630,7 @@ impl DriveMonitor { Err(_) => String::new(), }; let normalized_new_value = Self::normalize_config_value(new_value); - + if normalized_old_value != normalized_new_value { trace!( "Detected change in {} (old: {}, new: {})", @@ -956,13 +956,10 @@ impl DriveMonitor { use crate::core::kb::website_crawler_service::WebsiteCrawlerService; use diesel::prelude::*; - let kb_manager = match _kb_manager { - Some(kb) => kb, - None => { - warn!("Knowledge base manager not available, skipping website crawl"); - return Ok(()); - } - }; + if _kb_manager.is_none() { + warn!("Knowledge base manager not available, skipping website crawl"); + return Ok(()); + } let mut conn = _db_pool.get()?; @@ -1008,7 +1005,7 @@ impl DriveMonitor { }; // Create a temporary crawler service to use its crawl_website method - let crawler_service = WebsiteCrawlerService::new(_db_pool.clone(), kb_manager); + let crawler_service = WebsiteCrawlerService::new(_db_pool.clone()); match crawler_service.crawl_single_website(website_record).await { Ok(_) => {}, Err(e) => return Err(format!("Website crawl failed: {}", e).into()), diff --git a/src/drive/local_file_monitor.rs b/src/drive/local_file_monitor.rs index 095418122..a5a7c2d13 100644 --- a/src/drive/local_file_monitor.rs +++ b/src/drive/local_file_monitor.rs @@ -108,7 +108,7 @@ impl LocalFileMonitor { trace!("Watching directory: {:?}", self.data_dir); while self.is_processing.load(Ordering::SeqCst) { - tokio::time::sleep(Duration::from_secs(5)).await; + tokio::time::sleep(Duration::from_secs(60)).await; // Process events from the watcher while let Ok(event) = rx.try_recv() { @@ -134,21 +134,16 @@ impl LocalFileMonitor { _ => {} } } - - // Periodic scan to catch any missed changes - if let Err(e) = self.scan_and_compile_all().await { - error!("Scan failed: {}", e); - } } trace!("Monitoring loop ended"); } async fn polling_loop(&self) { - trace!("Using polling fallback (checking every 10s)"); + trace!("Using polling fallback (checking every 60s)"); while self.is_processing.load(Ordering::SeqCst) { - tokio::time::sleep(Duration::from_secs(10)).await; + tokio::time::sleep(Duration::from_secs(60)).await; if let Err(e) = self.scan_and_compile_all().await { error!("Scan failed: {}", e); @@ -203,8 +198,6 @@ impl LocalFileMonitor { } async fn compile_gbdialog(&self, bot_name: &str, gbdialog_path: &Path) -> Result<(), Box> { - info!("Compiling bot: {}", bot_name); - let entries = tokio::fs::read_dir(gbdialog_path).await?; let mut entries = entries; @@ -231,6 +224,7 @@ impl LocalFileMonitor { }; if should_compile { + info!("Compiling bot: {}", bot_name); debug!("Recompiling {:?} - modification detected", path); if let Err(e) = self.compile_local_file(&path).await { error!("Failed to compile {:?}: {}", path, e); diff --git a/src/main_module/server.rs b/src/main_module/server.rs index 5685aaecc..6e98274c7 100644 --- a/src/main_module/server.rs +++ b/src/main_module/server.rs @@ -247,6 +247,13 @@ pub async fn run_axum_server( api_router = api_router.merge(crate::research::configure_research_routes()); api_router = api_router.merge(crate::research::ui::configure_research_ui_routes()); } + #[cfg(any(feature = "research", feature = "llm"))] + { + api_router = api_router.route( + "/api/website/force-recrawl", + post(crate::core::kb::website_crawler_service::handle_force_recrawl) + ); + } #[cfg(feature = "sources")] { api_router = api_router.merge(crate::sources::configure_sources_routes());