botserver/src/core/kb/embedding_generator.rs
Rodrigo Rodriguez (Pragmatismo) f40cb6c7b4 Fix typos in bot file extensions and keyword names
Changed incorrect references to .vbs files to .bas and corrected
USE_WEBSITE keyword naming. Also added missing fields to API response
structure and clarified that start.bas is optional for bots.
2025-11-26 22:54:22 -03:00

443 lines
13 KiB
Rust

use anyhow::{Context, Result};
use log::{debug, info, warn};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Semaphore;
use super::document_processor::TextChunk;
/// Embedding model configuration.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Base URL of the embedding service (e.g., http://localhost:8082).
    pub embedding_url: String,
    /// Model name/path for embeddings (e.g., bge-small-en-v1.5).
    pub embedding_model: String,
    /// Expected dimension of the embeddings (e.g., 384, 768, 1024, 1536).
    pub dimensions: usize,
    /// Maximum number of texts sent to the service in one request.
    pub batch_size: usize,
    /// Per-request HTTP timeout in seconds.
    pub timeout_seconds: u64,
}

impl Default for EmbeddingConfig {
    /// Local service on port 8082 serving `bge-small-en-v1.5` (384 dimensions).
    fn default() -> Self {
        Self {
            embedding_url: "http://localhost:8082".to_string(),
            embedding_model: "bge-small-en-v1.5".to_string(),
            dimensions: 384, // bge-small output size
            batch_size: 32,
            timeout_seconds: 30,
        }
    }
}

impl EmbeddingConfig {
    /// Build a config from the `EMBEDDING_URL` and `EMBEDDING_MODEL` environment variables.
    ///
    /// Unset (or non-UTF-8) variables fall back to the [`Default`] values. `dimensions` is
    /// inferred from the model name, and the remaining settings use the defaults.
    pub fn from_env() -> Self {
        let mut config = Self::default();
        if let Ok(url) = std::env::var("EMBEDDING_URL") {
            config.embedding_url = url;
        }
        if let Ok(model) = std::env::var("EMBEDDING_MODEL") {
            config.embedding_model = model;
        }
        config.dimensions = Self::detect_dimensions(&config.embedding_model);
        config
    }

    /// Guess the embedding dimension from well-known model naming conventions.
    ///
    /// Matching is case-insensitive. Unknown names default to 384.
    fn detect_dimensions(model: &str) -> usize {
        let model = model.to_ascii_lowercase();
        // Check OpenAI names first: "text-embedding-3-small" would otherwise match "small".
        if model.contains("text-embedding-3-large") {
            3072
        } else if model.contains("text-embedding-3-small") || model.contains("ada") {
            1536
        } else if model.contains("small") || model.contains("minilm") {
            384
        } else if model.contains("base") || model.contains("mpnet") || model.contains("nomic") {
            768
        } else if model.contains("large") {
            // bge-large, e5-large, mxbai-embed-large, ... all produce 1024-d vectors.
            1024
        } else {
            384 // Default
        }
    }
}
/// Request payload for embedding generation
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
input: Vec<String>,
model: String,
}
/// Response from embedding service
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
model: String,
usage: Option<EmbeddingUsage>,
}
#[derive(Debug, Deserialize)]
struct EmbeddingData {
embedding: Vec<f32>,
index: usize,
}
#[derive(Debug, Deserialize)]
struct EmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
/// Generated embedding with metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// The embedding vector itself.
    pub vector: Vec<f32>,
    /// Dimension count recorded by the generator that produced this embedding.
    /// NOTE(review): this is set by the producer, not derived here, so it may not equal
    /// `vector.len()` if a producer records a configured value. Confirm at call sites.
    pub dimensions: usize,
    /// Model name recorded for the request that produced the vector.
    pub model: String,
    /// Token usage reported by the service, if any. Generators in this module copy the
    /// batch-level total onto every embedding of that batch.
    pub tokens_used: Option<usize>,
}
/// Knowledge base embedding generator
pub struct KbEmbeddingGenerator {
config: EmbeddingConfig,
client: Client,
semaphore: Arc<Semaphore>,
}
impl std::fmt::Debug for KbEmbeddingGenerator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("KbEmbeddingGenerator")
.field("config", &self.config)
.field("client", &"Client")
.field("semaphore", &"Semaphore")
.finish()
}
}
impl KbEmbeddingGenerator {
pub fn new(config: EmbeddingConfig) -> Self {
let client = Client::builder()
.timeout(std::time::Duration::from_secs(config.timeout_seconds))
.build()
.expect("Failed to create HTTP client");
// Limit concurrent requests
let semaphore = Arc::new(Semaphore::new(4));
Self {
config,
client,
semaphore,
}
}
/// Generate embeddings for text chunks
pub async fn generate_embeddings(
&self,
chunks: &[TextChunk],
) -> Result<Vec<(TextChunk, Embedding)>> {
if chunks.is_empty() {
return Ok(Vec::new());
}
info!("Generating embeddings for {} chunks", chunks.len());
let mut results = Vec::new();
// Process in batches
for batch in chunks.chunks(self.config.batch_size) {
let batch_embeddings = self.generate_batch_embeddings(batch).await?;
// Pair chunks with their embeddings
for (chunk, embedding) in batch.iter().zip(batch_embeddings.iter()) {
results.push((chunk.clone(), embedding.clone()));
}
}
info!("Generated {} embeddings", results.len());
Ok(results)
}
/// Generate embeddings for a batch of chunks
async fn generate_batch_embeddings(&self, chunks: &[TextChunk]) -> Result<Vec<Embedding>> {
let _permit = self.semaphore.acquire().await?;
let texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
debug!("Generating embeddings for batch of {} texts", texts.len());
// Try local embedding service first
match self.generate_local_embeddings(&texts).await {
Ok(embeddings) => Ok(embeddings),
Err(e) => {
warn!("Local embedding service failed: {}, trying OpenAI API", e);
self.generate_openai_embeddings(&texts).await
}
}
}
/// Generate embeddings using local service
async fn generate_local_embeddings(&self, texts: &[String]) -> Result<Vec<Embedding>> {
let request = EmbeddingRequest {
input: texts.to_vec(),
model: self.config.embedding_model.clone(),
};
let response = self
.client
.post(&format!("{}/embeddings", self.config.embedding_url))
.json(&request)
.send()
.await
.context("Failed to send request to embedding service")?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(anyhow::anyhow!(
"Embedding service error {}: {}",
status,
error_text
));
}
let embedding_response: EmbeddingResponse = response
.json()
.await
.context("Failed to parse embedding response")?;
let mut embeddings = Vec::new();
for data in embedding_response.data {
embeddings.push(Embedding {
vector: data.embedding,
dimensions: self.config.dimensions,
model: embedding_response.model.clone(),
tokens_used: embedding_response.usage.as_ref().map(|u| u.total_tokens),
});
}
Ok(embeddings)
}
/// Generate embeddings using OpenAI API (fallback)
async fn generate_openai_embeddings(&self, texts: &[String]) -> Result<Vec<Embedding>> {
let api_key = std::env::var("OPENAI_API_KEY")
.context("OPENAI_API_KEY not set for fallback embedding generation")?;
let request = serde_json::json!({
"input": texts,
"model": "text-embedding-ada-002"
});
let response = self
.client
.post("https://api.openai.com/v1/embeddings")
.header("Authorization", format!("Bearer {}", api_key))
.json(&request)
.send()
.await
.context("Failed to send request to OpenAI")?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(anyhow::anyhow!(
"OpenAI API error {}: {}",
status,
error_text
));
}
let response_json: serde_json::Value = response
.json()
.await
.context("Failed to parse OpenAI response")?;
let mut embeddings = Vec::new();
if let Some(data) = response_json["data"].as_array() {
for item in data {
if let Some(embedding) = item["embedding"].as_array() {
let vector: Vec<f32> = embedding
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
embeddings.push(Embedding {
vector,
dimensions: 1536, // OpenAI ada-002 dimensions
model: "text-embedding-ada-002".to_string(),
tokens_used: response_json["usage"]["total_tokens"]
.as_u64()
.map(|t| t as usize),
});
}
}
}
Ok(embeddings)
}
/// Generate embedding for a single text
pub async fn generate_single_embedding(&self, text: &str) -> Result<Embedding> {
let embeddings = self
.generate_batch_embeddings(&[TextChunk {
content: text.to_string(),
metadata: super::document_processor::ChunkMetadata {
document_path: "query".to_string(),
document_title: None,
chunk_index: 0,
total_chunks: 1,
start_char: 0,
end_char: text.len(),
page_number: None,
},
}])
.await?;
embeddings
.into_iter()
.next()
.ok_or_else(|| anyhow::anyhow!("No embedding generated"))
}
}
/// Generic embedding generator for other uses (email, etc.)
pub struct EmbeddingGenerator {
kb_generator: KbEmbeddingGenerator,
}
impl std::fmt::Debug for EmbeddingGenerator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EmbeddingGenerator")
.field("kb_generator", &self.kb_generator)
.finish()
}
}
impl EmbeddingGenerator {
pub fn new(llm_endpoint: String) -> Self {
let config = EmbeddingConfig {
embedding_url: llm_endpoint,
..Default::default()
};
Self {
kb_generator: KbEmbeddingGenerator::new(config),
}
}
/// Generate embedding for arbitrary text
pub async fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> {
let embedding = self.kb_generator.generate_single_embedding(text).await?;
Ok(embedding.vector)
}
}
/// Email-oriented facade over [`EmbeddingGenerator`], kept for compatibility with callers.
pub struct EmailEmbeddingGenerator {
    generator: EmbeddingGenerator,
}

impl std::fmt::Debug for EmailEmbeddingGenerator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EmailEmbeddingGenerator");
        out.field("generator", &self.generator);
        out.finish()
    }
}

impl EmailEmbeddingGenerator {
    /// Create a generator that sends requests to `llm_endpoint`.
    pub fn new(llm_endpoint: String) -> Self {
        let generator = EmbeddingGenerator::new(llm_endpoint);
        Self { generator }
    }

    /// Embed an email rendered as a header block (Subject/From/To) followed by its body.
    pub async fn generate_embedding(&self, email: &impl EmailLike) -> Result<Vec<f32>> {
        let subject = email.subject();
        let sender = email.from();
        let recipient = email.to();
        let body = email.body();
        let document = format!(
            "Subject: {}\nFrom: {}\nTo: {}\n\n{}",
            subject, sender, recipient, body
        );
        self.generate_text_embedding(&document).await
    }

    /// Embed arbitrary text and return the raw vector.
    pub async fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let generator = &self.generator;
        generator.generate_text_embedding(text).await
    }
}
/// Read-only access to the email fields used to build an embedding text.
pub trait EmailLike {
    /// Subject line.
    fn subject(&self) -> &str;
    /// Sender (rendered on the `From:` line by `EmailEmbeddingGenerator`).
    fn from(&self) -> &str;
    /// Recipient (rendered on the `To:` line by `EmailEmbeddingGenerator`).
    fn to(&self) -> &str;
    /// Message body.
    fn body(&self) -> &str;
}

/// Minimal owned [`EmailLike`] implementation, handy in tests.
#[derive(Debug)]
pub struct SimpleEmail {
    pub id: String,
    pub subject: String,
    pub from: String,
    pub to: String,
    pub body: String,
}

impl EmailLike for SimpleEmail {
    fn subject(&self) -> &str {
        self.subject.as_str()
    }

    fn from(&self) -> &str {
        self.from.as_str()
    }

    fn to(&self) -> &str {
        self.to.as_str()
    }

    fn body(&self) -> &str {
        self.body.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dimension_detection() {
        assert_eq!(EmbeddingConfig::detect_dimensions("bge-small-en"), 384);
        assert_eq!(EmbeddingConfig::detect_dimensions("all-mpnet-base-v2"), 768);
        assert_eq!(
            EmbeddingConfig::detect_dimensions("text-embedding-ada-002"),
            1536
        );
        assert_eq!(EmbeddingConfig::detect_dimensions("unknown-model"), 384);
    }

    /// Constructing a generator must not contact the service; it only wires the endpoint
    /// into the config, which is visible through the `Debug` output.
    #[tokio::test]
    async fn test_generator_uses_given_endpoint() {
        let generator = EmbeddingGenerator::new("http://embeddings.test:9000".to_string());
        let debug = format!("{:?}", generator);
        assert!(debug.contains("http://embeddings.test:9000"));
        assert!(debug.contains("bge-small-en-v1.5"));
    }

    #[test]
    fn test_simple_email_accessors() {
        let email = SimpleEmail {
            id: "1".to_string(),
            subject: "Hello".to_string(),
            from: "a@example.com".to_string(),
            to: "b@example.com".to_string(),
            body: "Body text".to_string(),
        };
        assert_eq!(EmailLike::subject(&email), "Hello");
        assert_eq!(EmailLike::from(&email), "a@example.com");
        assert_eq!(EmailLike::to(&email), "b@example.com");
        assert_eq!(EmailLike::body(&email), "Body text");
    }
}