use anyhow::{Context, Result}; use log::{debug, info, warn}; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::sync::Semaphore; use super::document_processor::TextChunk; /// Embedding model configuration #[derive(Debug, Clone)] pub struct EmbeddingConfig { /// URL for the embedding service (e.g., http://localhost:8082) pub embedding_url: String, /// Model name/path for embeddings (e.g., bge-small-en-v1.5) pub embedding_model: String, /// Dimension of embeddings (e.g., 384, 768, 1536) pub dimensions: usize, /// Maximum batch size for embedding generation pub batch_size: usize, /// Request timeout in seconds pub timeout_seconds: u64, } impl Default for EmbeddingConfig { fn default() -> Self { Self { embedding_url: "http://localhost:8082".to_string(), embedding_model: "bge-small-en-v1.5".to_string(), dimensions: 384, // Default for bge-small batch_size: 32, timeout_seconds: 30, } } } impl EmbeddingConfig { /// Create config from environment or config.csv values pub fn from_env() -> Self { let embedding_url = std::env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); let embedding_model = std::env::var("EMBEDDING_MODEL").unwrap_or_else(|_| "bge-small-en-v1.5".to_string()); // Detect dimensions based on model name let dimensions = Self::detect_dimensions(&embedding_model); Self { embedding_url, embedding_model, dimensions, batch_size: 32, timeout_seconds: 30, } } /// Detect embedding dimensions based on model name fn detect_dimensions(model: &str) -> usize { if model.contains("small") || model.contains("MiniLM") { 384 } else if model.contains("base") || model.contains("mpnet") { 768 } else if model.contains("large") || model.contains("ada") { 1536 } else { 384 // Default } } } /// Request payload for embedding generation #[derive(Debug, Serialize)] struct EmbeddingRequest { input: Vec, model: String, } /// Response from embedding service #[derive(Debug, Deserialize)] struct EmbeddingResponse { data: Vec, model: String, usage: Option, } #[derive(Debug, Deserialize)] struct EmbeddingData { embedding: Vec, index: usize, } #[derive(Debug, Deserialize)] struct EmbeddingUsage { prompt_tokens: usize, total_tokens: usize, } /// Generated embedding with metadata #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Embedding { pub vector: Vec, pub dimensions: usize, pub model: String, pub tokens_used: Option, } /// Knowledge base embedding generator pub struct KbEmbeddingGenerator { config: EmbeddingConfig, client: Client, semaphore: Arc, } impl std::fmt::Debug for KbEmbeddingGenerator { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("KbEmbeddingGenerator") .field("config", &self.config) .field("client", &"Client") .field("semaphore", &"Semaphore") .finish() } } impl KbEmbeddingGenerator { pub fn new(config: EmbeddingConfig) -> Self { let client = Client::builder() .timeout(std::time::Duration::from_secs(config.timeout_seconds)) .build() .expect("Failed to create HTTP client"); // Limit concurrent requests let semaphore = Arc::new(Semaphore::new(4)); Self { config, client, semaphore, } } /// Generate embeddings for text chunks pub async fn generate_embeddings( &self, chunks: &[TextChunk], ) -> Result> { if chunks.is_empty() { return Ok(Vec::new()); } info!("Generating embeddings for {} chunks", chunks.len()); let mut results = Vec::new(); // Process in batches for batch in chunks.chunks(self.config.batch_size) { let batch_embeddings = self.generate_batch_embeddings(batch).await?; // Pair chunks with their embeddings for (chunk, embedding) in batch.iter().zip(batch_embeddings.iter()) { results.push((chunk.clone(), embedding.clone())); } } info!("Generated {} embeddings", results.len()); Ok(results) } /// Generate embeddings for a batch of chunks async fn generate_batch_embeddings(&self, chunks: &[TextChunk]) -> Result> { let _permit = self.semaphore.acquire().await?; let texts: Vec = chunks.iter().map(|c| c.content.clone()).collect(); debug!("Generating embeddings for batch of {} texts", texts.len()); // Try local embedding service first match self.generate_local_embeddings(&texts).await { Ok(embeddings) => Ok(embeddings), Err(e) => { warn!("Local embedding service failed: {}, trying OpenAI API", e); self.generate_openai_embeddings(&texts).await } } } /// Generate embeddings using local service async fn generate_local_embeddings(&self, texts: &[String]) -> Result> { let request = EmbeddingRequest { input: texts.to_vec(), model: self.config.embedding_model.clone(), }; let response = self .client .post(&format!("{}/embeddings", self.config.embedding_url)) .json(&request) .send() .await .context("Failed to send request to embedding service")?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); return Err(anyhow::anyhow!( "Embedding service error {}: {}", status, error_text )); } let embedding_response: EmbeddingResponse = response .json() .await .context("Failed to parse embedding response")?; let mut embeddings = Vec::new(); for data in embedding_response.data { embeddings.push(Embedding { vector: data.embedding, dimensions: self.config.dimensions, model: embedding_response.model.clone(), tokens_used: embedding_response.usage.as_ref().map(|u| u.total_tokens), }); } Ok(embeddings) } /// Generate embeddings using OpenAI API (fallback) async fn generate_openai_embeddings(&self, texts: &[String]) -> Result> { let api_key = std::env::var("OPENAI_API_KEY") .context("OPENAI_API_KEY not set for fallback embedding generation")?; let request = serde_json::json!({ "input": texts, "model": "text-embedding-ada-002" }); let response = self .client .post("https://api.openai.com/v1/embeddings") .header("Authorization", format!("Bearer {}", api_key)) .json(&request) .send() .await .context("Failed to send request to OpenAI")?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); return Err(anyhow::anyhow!( "OpenAI API error {}: {}", status, error_text )); } let response_json: serde_json::Value = response .json() .await .context("Failed to parse OpenAI response")?; let mut embeddings = Vec::new(); if let Some(data) = response_json["data"].as_array() { for item in data { if let Some(embedding) = item["embedding"].as_array() { let vector: Vec = embedding .iter() .filter_map(|v| v.as_f64().map(|f| f as f32)) .collect(); embeddings.push(Embedding { vector, dimensions: 1536, // OpenAI ada-002 dimensions model: "text-embedding-ada-002".to_string(), tokens_used: response_json["usage"]["total_tokens"] .as_u64() .map(|t| t as usize), }); } } } Ok(embeddings) } /// Generate embedding for a single text pub async fn generate_single_embedding(&self, text: &str) -> Result { let embeddings = self .generate_batch_embeddings(&[TextChunk { content: text.to_string(), metadata: super::document_processor::ChunkMetadata { document_path: "query".to_string(), document_title: None, chunk_index: 0, total_chunks: 1, start_char: 0, end_char: text.len(), page_number: None, }, }]) .await?; embeddings .into_iter() .next() .ok_or_else(|| anyhow::anyhow!("No embedding generated")) } } /// Generic embedding generator for other uses (email, etc.) pub struct EmbeddingGenerator { kb_generator: KbEmbeddingGenerator, } impl std::fmt::Debug for EmbeddingGenerator { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EmbeddingGenerator") .field("kb_generator", &self.kb_generator) .finish() } } impl EmbeddingGenerator { pub fn new(llm_endpoint: String) -> Self { let config = EmbeddingConfig { embedding_url: llm_endpoint, ..Default::default() }; Self { kb_generator: KbEmbeddingGenerator::new(config), } } /// Generate embedding for arbitrary text pub async fn generate_text_embedding(&self, text: &str) -> Result> { let embedding = self.kb_generator.generate_single_embedding(text).await?; Ok(embedding.vector) } } /// Email-specific embedding generator (for compatibility) pub struct EmailEmbeddingGenerator { generator: EmbeddingGenerator, } impl std::fmt::Debug for EmailEmbeddingGenerator { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EmailEmbeddingGenerator") .field("generator", &self.generator) .finish() } } impl EmailEmbeddingGenerator { pub fn new(llm_endpoint: String) -> Self { Self { generator: EmbeddingGenerator::new(llm_endpoint), } } /// Generate embedding for email content pub async fn generate_embedding(&self, email: &impl EmailLike) -> Result> { let text = format!( "Subject: {}\nFrom: {}\nTo: {}\n\n{}", email.subject(), email.from(), email.to(), email.body() ); self.generator.generate_text_embedding(&text).await } /// Generate embedding for text pub async fn generate_text_embedding(&self, text: &str) -> Result> { self.generator.generate_text_embedding(text).await } } /// Trait for email-like objects pub trait EmailLike { fn subject(&self) -> &str; fn from(&self) -> &str; fn to(&self) -> &str; fn body(&self) -> &str; } /// Simple email struct for testing #[derive(Debug)] pub struct SimpleEmail { pub id: String, pub subject: String, pub from: String, pub to: String, pub body: String, } impl EmailLike for SimpleEmail { fn subject(&self) -> &str { &self.subject } fn from(&self) -> &str { &self.from } fn to(&self) -> &str { &self.to } fn body(&self) -> &str { &self.body } } #[cfg(test)] mod tests { use super::*; #[test] fn test_dimension_detection() { assert_eq!(EmbeddingConfig::detect_dimensions("bge-small-en"), 384); assert_eq!(EmbeddingConfig::detect_dimensions("all-mpnet-base-v2"), 768); assert_eq!( EmbeddingConfig::detect_dimensions("text-embedding-ada-002"), 1536 ); assert_eq!(EmbeddingConfig::detect_dimensions("unknown-model"), 384); } #[tokio::test] async fn test_text_cleaning_for_embedding() { let text = "This is a test\n\nWith multiple lines"; let generator = EmbeddingGenerator::new("http://localhost:8082".to_string()); // This would test actual embedding generation if service is available // For unit tests, we just verify the structure is correct assert!(!text.is_empty()); } }