// botserver/src/vector-db/hybrid_search.rs

//! Hybrid Search Module for RAG 2.0
//!
//! Implements hybrid search combining sparse (BM25) and dense (embedding) retrieval
//! with Reciprocal Rank Fusion (RRF) for optimal results.
//!
//! Config.csv properties:
//! ```csv
//! rag-hybrid-enabled,true
//! rag-dense-weight,0.7
//! rag-sparse-weight,0.3
//! rag-reranker-enabled,true
//! rag-reranker-model,cross-encoder/ms-marco-MiniLM-L-6-v2
//! ```
use log::{debug, error, info, trace, warn};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
use crate::shared::state::AppState;
/// Configuration for hybrid search.
///
/// Loaded per bot from the `rag-*` keys listed in the module docs; any key
/// that is absent or unparseable falls back to the `Default` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridSearchConfig {
    /// Weight for dense (embedding) search results (0.0 - 1.0)
    pub dense_weight: f32,
    /// Weight for sparse (BM25) search results (0.0 - 1.0)
    pub sparse_weight: f32,
    /// Whether to use reranker for final results
    pub reranker_enabled: bool,
    /// Reranker model name/path
    pub reranker_model: String,
    /// Maximum number of results to return
    pub max_results: usize,
    /// Minimum score threshold (0.0 - 1.0), applied to RRF-normalized scores
    pub min_score: f32,
    /// K parameter for RRF (typically 60)
    pub rrf_k: u32,
}
impl Default for HybridSearchConfig {
fn default() -> Self {
Self {
dense_weight: 0.7,
sparse_weight: 0.3,
reranker_enabled: false,
reranker_model: "cross-encoder/ms-marco-MiniLM-L-6-v2".to_string(),
max_results: 10,
min_score: 0.0,
rrf_k: 60,
}
}
}
impl HybridSearchConfig {
    /// Build the hybrid-search configuration for `bot_id` from the
    /// `bot_configuration` table.
    ///
    /// Starts from `Self::default()`, overrides any recognized `rag-*` keys
    /// found in the table (bad values fall back to the per-key default), and
    /// finally normalizes `dense_weight`/`sparse_weight` so they sum to 1.0
    /// (unless both are zero). A failed DB connection simply yields defaults.
    pub fn from_bot_config(state: &AppState, bot_id: Uuid) -> Self {
        use diesel::prelude::*;
        // Row shape produced by the raw SQL query below.
        #[derive(QueryableByName)]
        struct ConfigRow {
            #[diesel(sql_type = diesel::sql_types::Text)]
            config_key: String,
            #[diesel(sql_type = diesel::sql_types::Text)]
            config_value: String,
        }
        let mut config = Self::default();
        if let Ok(mut conn) = state.conn.get() {
            let rows: Vec<ConfigRow> = diesel::sql_query(
                "SELECT config_key, config_value FROM bot_configuration \
                WHERE bot_id = $1 AND config_key LIKE 'rag-%'",
            )
            .bind::<diesel::sql_types::Uuid, _>(bot_id)
            .load(&mut conn)
            .unwrap_or_default();
            for ConfigRow { config_key, config_value } in rows {
                match config_key.as_str() {
                    "rag-dense-weight" => {
                        config.dense_weight = config_value.parse().unwrap_or(0.7);
                    }
                    "rag-sparse-weight" => {
                        config.sparse_weight = config_value.parse().unwrap_or(0.3);
                    }
                    "rag-reranker-enabled" => {
                        config.reranker_enabled = config_value.to_lowercase() == "true";
                    }
                    "rag-reranker-model" => config.reranker_model = config_value,
                    "rag-max-results" => {
                        config.max_results = config_value.parse().unwrap_or(10);
                    }
                    "rag-min-score" => {
                        config.min_score = config_value.parse().unwrap_or(0.0);
                    }
                    "rag-rrf-k" => {
                        config.rrf_k = config_value.parse().unwrap_or(60);
                    }
                    // Other rag-* keys belong to different subsystems; ignore.
                    _ => {}
                }
            }
        }
        // Normalize the two retrieval weights into a convex combination.
        let total = config.dense_weight + config.sparse_weight;
        if total > 0.0 {
            config.dense_weight /= total;
            config.sparse_weight /= total;
        }
        config
    }
}
/// A single retrieved document, produced by sparse, dense, or fused search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Unique document identifier
    pub doc_id: String,
    /// Document content
    pub content: String,
    /// Source file/email/etc path
    pub source: String,
    /// Relevance score. Hybrid results are RRF-normalized to 0.0 - 1.0;
    /// NOTE(review): sparse-only results carry raw, unbounded BM25 scores.
    pub score: f32,
    /// Additional metadata
    pub metadata: HashMap<String, String>,
    /// Search method that produced this result
    pub search_method: SearchMethod,
}
/// Retrieval strategy that produced a `SearchResult`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SearchMethod {
    /// Embedding-similarity search (Qdrant).
    Dense,
    /// Lexical BM25 search.
    Sparse,
    /// Reciprocal-rank fusion of dense and sparse rankings.
    Hybrid,
    /// Result re-scored by the reranking pass.
    Reranked,
}
/// BM25 search index for sparse (lexical) retrieval.
///
/// Implements Okapi BM25 scoring over lowercase alphanumeric tokens of
/// length > 2 (see `tokenize`).
pub struct BM25Index {
    /// Number of indexed documents containing each term.
    doc_freq: HashMap<String, usize>,
    /// Total number of documents in the index.
    doc_count: usize,
    /// Average document length, in tokens.
    avg_doc_len: f32,
    /// Length of each document: doc_id -> token count.
    doc_lengths: HashMap<String, usize>,
    /// Term frequencies per document: doc_id -> (term -> count).
    term_freqs: HashMap<String, HashMap<String, usize>>,
    /// BM25 term-frequency saturation parameter.
    k1: f32,
    /// BM25 length-normalization parameter.
    b: f32,
}
impl BM25Index {
    /// Create an empty index with the standard parameters k1 = 1.2, b = 0.75.
    pub fn new() -> Self {
        Self {
            doc_freq: HashMap::new(),
            doc_count: 0,
            avg_doc_len: 0.0,
            doc_lengths: HashMap::new(),
            term_freqs: HashMap::new(),
            k1: 1.2,
            b: 0.75,
        }
    }
    /// Add a document to the index.
    ///
    /// Re-adding an existing `doc_id` replaces the previous version: the
    /// stale entry is removed first so `doc_count` and the per-term document
    /// frequencies are not inflated by repeated indexing.
    pub fn add_document(&mut self, doc_id: &str, content: &str) {
        // Fix: without this, indexing the same id twice double-counted the
        // document and permanently skewed the IDF statistics.
        if self.term_freqs.contains_key(doc_id) {
            self.remove_document(doc_id);
        }
        let terms = self.tokenize(content);
        let doc_len = terms.len();
        self.doc_lengths.insert(doc_id.to_string(), doc_len);
        // Per-document term frequencies; document frequency counts each
        // term at most once per document.
        let mut term_freq: HashMap<String, usize> = HashMap::new();
        let mut seen_terms: std::collections::HashSet<String> = std::collections::HashSet::new();
        for term in &terms {
            *term_freq.entry(term.clone()).or_insert(0) += 1;
            // `insert` returns true only on first occurrence in this doc.
            if seen_terms.insert(term.clone()) {
                *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
            }
        }
        self.term_freqs.insert(doc_id.to_string(), term_freq);
        self.doc_count += 1;
        // Keep the average length in sync with the new totals.
        let total_len: usize = self.doc_lengths.values().sum();
        self.avg_doc_len = total_len as f32 / self.doc_count as f32;
    }
    /// Remove a document from the index. Unknown ids are a no-op.
    pub fn remove_document(&mut self, doc_id: &str) {
        // Fix: removing an unknown id used to decrement doc_count anyway,
        // corrupting the statistics.
        let term_freq = match self.term_freqs.remove(doc_id) {
            Some(tf) => tf,
            None => return,
        };
        // Decrement document frequencies; drop terms that reach zero.
        for term in term_freq.keys() {
            if let Some(freq) = self.doc_freq.get_mut(term) {
                *freq = freq.saturating_sub(1);
                if *freq == 0 {
                    self.doc_freq.remove(term);
                }
            }
        }
        self.doc_lengths.remove(doc_id);
        self.doc_count = self.doc_count.saturating_sub(1);
        // Recompute the average document length over the remaining docs.
        if self.doc_count > 0 {
            let total_len: usize = self.doc_lengths.values().sum();
            self.avg_doc_len = total_len as f32 / self.doc_count as f32;
        } else {
            self.avg_doc_len = 0.0;
        }
    }
    /// Score all documents against `query` with BM25 and return up to
    /// `max_results` `(doc_id, score)` pairs, best first.
    pub fn search(&self, query: &str, max_results: usize) -> Vec<(String, f32)> {
        let query_terms = self.tokenize(query);
        let mut scores: HashMap<String, f32> = HashMap::new();
        for term in &query_terms {
            let df = *self.doc_freq.get(term).unwrap_or(&0);
            if df == 0 {
                continue;
            }
            // IDF with the "+1 inside ln" variant, which never goes negative.
            let idf = ((self.doc_count as f32 - df as f32 + 0.5) / (df as f32 + 0.5) + 1.0).ln();
            for (doc_id, term_freqs) in &self.term_freqs {
                if let Some(&tf) = term_freqs.get(term) {
                    let doc_len = *self.doc_lengths.get(doc_id).unwrap_or(&1) as f32;
                    // Standard BM25 term-frequency saturation with length
                    // normalization controlled by k1 and b.
                    let tf_normalized = (tf as f32 * (self.k1 + 1.0))
                        / (tf as f32
                            + self.k1 * (1.0 - self.b + self.b * (doc_len / self.avg_doc_len)));
                    *scores.entry(doc_id.clone()).or_insert(0.0) += idf * tf_normalized;
                }
            }
        }
        // Sort descending by score and keep the top results.
        let mut results: Vec<(String, f32)> = scores.into_iter().collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(max_results);
        results
    }
    /// Split text into lowercase alphanumeric tokens longer than 2 chars.
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| s.len() > 2) // Filter out very short tokens
            .map(|s| s.to_string())
            .collect()
    }
    /// Snapshot of the current index statistics.
    pub fn stats(&self) -> BM25Stats {
        BM25Stats {
            doc_count: self.doc_count,
            unique_terms: self.doc_freq.len(),
            avg_doc_len: self.avg_doc_len,
        }
    }
}
impl Default for BM25Index {
    fn default() -> Self {
        Self::new()
    }
}
/// BM25 index statistics snapshot (see `BM25Index::stats`).
#[derive(Debug, Clone)]
pub struct BM25Stats {
    /// Number of indexed documents.
    pub doc_count: usize,
    /// Number of distinct terms tracked across the index.
    pub unique_terms: usize,
    /// Average document length, in tokens.
    pub avg_doc_len: f32,
}
/// Hybrid search engine combining dense and sparse retrieval.
///
/// Holds the BM25 index and document contents in memory; dense vectors live
/// in an external Qdrant collection reached over HTTP.
pub struct HybridSearchEngine {
    /// BM25 sparse index
    bm25_index: BM25Index,
    /// In-memory document store for content retrieval (doc_id -> entry)
    documents: HashMap<String, DocumentEntry>,
    /// Search configuration (weights, limits, reranker settings)
    config: HybridSearchConfig,
    /// Base URL of the Qdrant REST API used for dense search
    qdrant_url: String,
    /// Qdrant collection holding this engine's vectors
    collection_name: String,
}
/// Document entry in the in-memory store backing search results.
#[derive(Debug, Clone)]
struct DocumentEntry {
    /// Full document text, copied into `SearchResult.content`.
    pub content: String,
    /// Origin of the document (file path, email, etc.).
    pub source: String,
    /// Arbitrary key/value metadata copied into results.
    pub metadata: HashMap<String, String>,
}
impl HybridSearchEngine {
    /// Create a new engine with an empty BM25 index and document store,
    /// targeting the given Qdrant endpoint/collection for dense search.
    pub fn new(config: HybridSearchConfig, qdrant_url: &str, collection_name: &str) -> Self {
        Self {
            bm25_index: BM25Index::new(),
            documents: HashMap::new(),
            config,
            qdrant_url: qdrant_url.to_string(),
            collection_name: collection_name.to_string(),
        }
    }
    /// Index a document for both dense and sparse search.
    ///
    /// The document always enters the BM25 index and the in-memory store;
    /// it is upserted to Qdrant only when `embedding` is provided.
    ///
    /// # Errors
    /// Returns `Err` with a description if the Qdrant upsert fails.
    pub async fn index_document(
        &mut self,
        doc_id: &str,
        content: &str,
        source: &str,
        metadata: HashMap<String, String>,
        embedding: Option<Vec<f32>>,
    ) -> Result<(), String> {
        self.bm25_index.add_document(doc_id, content);
        self.documents.insert(
            doc_id.to_string(),
            DocumentEntry {
                content: content.to_string(),
                source: source.to_string(),
                metadata,
            },
        );
        if let Some(emb) = embedding {
            self.upsert_to_qdrant(doc_id, &emb).await?;
        }
        Ok(())
    }
    /// Remove a document from the BM25 index, document store, and Qdrant.
    pub async fn remove_document(&mut self, doc_id: &str) -> Result<(), String> {
        self.bm25_index.remove_document(doc_id);
        self.documents.remove(doc_id);
        self.delete_from_qdrant(doc_id).await?;
        Ok(())
    }
    /// Perform hybrid search.
    ///
    /// Runs BM25 always and dense (Qdrant) search when `query_embedding` is
    /// provided, fuses both rankings with Reciprocal Rank Fusion, drops
    /// results below `config.min_score` (fused scores are normalized to
    /// 0..=1), truncates to `config.max_results`, and optionally reranks.
    pub async fn search(
        &self,
        query: &str,
        query_embedding: Option<Vec<f32>>,
    ) -> Result<Vec<SearchResult>, String> {
        // Over-fetch from each retriever so the fusion step has enough
        // candidates; saturating_mul guards against pathological configs.
        let fetch_count = self.config.max_results.saturating_mul(3);
        // Sparse search (BM25)
        let sparse_results = self.bm25_index.search(query, fetch_count);
        trace!(
            "BM25 search returned {} results for query: {}",
            sparse_results.len(),
            query
        );
        // Dense search (Qdrant)
        let dense_results = if let Some(embedding) = query_embedding {
            self.search_qdrant(&embedding, fetch_count).await?
        } else {
            Vec::new()
        };
        trace!(
            "Dense search returned {} results for query: {}",
            dense_results.len(),
            query
        );
        // Reciprocal Rank Fusion
        let fused_results = self.reciprocal_rank_fusion(&sparse_results, &dense_results);
        trace!("RRF produced {} fused results", fused_results.len());
        // Hydrate fused ids from the document store; ids missing from the
        // store (e.g. stale Qdrant points) are silently skipped.
        let mut results: Vec<SearchResult> = fused_results
            .into_iter()
            .filter_map(|(doc_id, score)| {
                self.documents.get(&doc_id).map(|doc| SearchResult {
                    doc_id,
                    content: doc.content.clone(),
                    source: doc.source.clone(),
                    score,
                    metadata: doc.metadata.clone(),
                    search_method: SearchMethod::Hybrid,
                })
            })
            .filter(|r| r.score >= self.config.min_score)
            .take(self.config.max_results)
            .collect();
        // Optional reranking
        if self.config.reranker_enabled && !results.is_empty() {
            results = self.rerank(query, results).await?;
        }
        Ok(results)
    }
    /// Perform only sparse (BM25) search.
    ///
    /// NOTE: scores here are raw BM25 values, not normalized to 0..=1.
    pub fn sparse_search(&self, query: &str) -> Vec<SearchResult> {
        self.bm25_index
            .search(query, self.config.max_results)
            .into_iter()
            .filter_map(|(doc_id, score)| {
                self.documents.get(&doc_id).map(|doc| SearchResult {
                    doc_id,
                    content: doc.content.clone(),
                    source: doc.source.clone(),
                    score,
                    metadata: doc.metadata.clone(),
                    search_method: SearchMethod::Sparse,
                })
            })
            .collect()
    }
    /// Perform only dense (embedding) search via Qdrant.
    pub async fn dense_search(
        &self,
        query_embedding: Vec<f32>,
    ) -> Result<Vec<SearchResult>, String> {
        let results = self
            .search_qdrant(&query_embedding, self.config.max_results)
            .await?;
        let search_results: Vec<SearchResult> = results
            .into_iter()
            .filter_map(|(doc_id, score)| {
                self.documents.get(&doc_id).map(|doc| SearchResult {
                    doc_id,
                    content: doc.content.clone(),
                    source: doc.source.clone(),
                    score,
                    metadata: doc.metadata.clone(),
                    search_method: SearchMethod::Dense,
                })
            })
            .collect();
        Ok(search_results)
    }
    /// Reciprocal Rank Fusion: each ranking contributes
    /// `weight / (k + rank + 1)` per document; combined scores are
    /// normalized so the top result is 1.0.
    fn reciprocal_rank_fusion(
        &self,
        sparse: &[(String, f32)],
        dense: &[(String, f32)],
    ) -> Vec<(String, f32)> {
        let k = self.config.rrf_k as f32;
        let mut scores: HashMap<String, f32> = HashMap::new();
        // Score from sparse results (only the rank matters, not raw scores).
        for (rank, (doc_id, _)) in sparse.iter().enumerate() {
            let rrf_score = self.config.sparse_weight / (k + rank as f32 + 1.0);
            *scores.entry(doc_id.clone()).or_insert(0.0) += rrf_score;
        }
        // Score from dense results
        for (rank, (doc_id, _)) in dense.iter().enumerate() {
            let rrf_score = self.config.dense_weight / (k + rank as f32 + 1.0);
            *scores.entry(doc_id.clone()).or_insert(0.0) += rrf_score;
        }
        // Sort by combined score
        let mut results: Vec<(String, f32)> = scores.into_iter().collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        // Normalize scores to 0..=1. Fix: copy the max out by value first —
        // holding `results.first()`'s borrow across the `&mut results` loop
        // is a borrow-check error (E0502).
        if let Some(&(_, max_score)) = results.first() {
            if max_score > 0.0 {
                for (_, score) in &mut results {
                    *score /= max_score;
                }
            }
        }
        results
    }
    /// Rerank results using a query-term-overlap heuristic.
    ///
    /// A full implementation would call the configured cross-encoder model
    /// (`config.reranker_model`); for now the fused score is blended 70/30
    /// with the fraction of query terms present in each document.
    async fn rerank(
        &self,
        query: &str,
        results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, String> {
        // Fixes two issues with the previous version: the term set borrowed
        // from a temporary `to_lowercase()` String (E0716), and it was
        // rebuilt inside the per-result loop despite being loop-invariant.
        let query_lower = query.to_lowercase();
        let query_terms: std::collections::HashSet<&str> =
            query_lower.split_whitespace().collect();
        let term_count = query_terms.len().max(1) as f32;
        let mut reranked = results;
        for result in &mut reranked {
            let content_lower = result.content.to_lowercase();
            let overlap = query_terms
                .iter()
                .filter(|term| content_lower.contains(*term))
                .count() as f32;
            // Blend: 70% fused score, 30% query-term coverage.
            result.score = result.score * 0.7 + (overlap / term_count) * 0.3;
            result.search_method = SearchMethod::Reranked;
        }
        // Re-sort by new scores
        reranked.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
        Ok(reranked)
    }
    /// Search Qdrant for the `limit` nearest vectors to `embedding`,
    /// returning `(point_id, similarity_score)` pairs.
    async fn search_qdrant(
        &self,
        embedding: &[f32],
        limit: usize,
    ) -> Result<Vec<(String, f32)>, String> {
        let client = reqwest::Client::new();
        let search_request = serde_json::json!({
            "vector": embedding,
            "limit": limit,
            "with_payload": false
        });
        let response = client
            .post(format!(
                "{}/collections/{}/points/search",
                self.qdrant_url, self.collection_name
            ))
            .json(&search_request)
            .send()
            .await
            .map_err(|e| format!("Qdrant search failed: {}", e))?;
        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(format!("Qdrant search error: {}", error_text));
        }
        let result: serde_json::Value = response
            .json()
            .await
            .map_err(|e| format!("Failed to parse Qdrant response: {}", e))?;
        let points = result["result"]
            .as_array()
            .ok_or("Invalid Qdrant response format")?;
        let results: Vec<(String, f32)> = points
            .iter()
            .filter_map(|p| {
                // Qdrant point ids may be strings (UUIDs) or unsigned
                // integers; accept both instead of dropping numeric ids.
                let id = match &p["id"] {
                    serde_json::Value::String(s) => s.clone(),
                    serde_json::Value::Number(n) => n.to_string(),
                    _ => return None,
                };
                let score = p["score"].as_f64()? as f32;
                Some((id, score))
            })
            .collect();
        Ok(results)
    }
    /// Upsert a vector into the Qdrant collection under `doc_id`.
    async fn upsert_to_qdrant(&self, doc_id: &str, embedding: &[f32]) -> Result<(), String> {
        let client = reqwest::Client::new();
        let upsert_request = serde_json::json!({
            "points": [{
                "id": doc_id,
                "vector": embedding
            }]
        });
        let response = client
            .put(format!(
                "{}/collections/{}/points",
                self.qdrant_url, self.collection_name
            ))
            .json(&upsert_request)
            .send()
            .await
            .map_err(|e| format!("Qdrant upsert failed: {}", e))?;
        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(format!("Qdrant upsert error: {}", error_text));
        }
        Ok(())
    }
    /// Delete a vector from Qdrant.
    ///
    /// A failed remote delete is logged (warn) but not treated as an error,
    /// so local removal is never blocked by Qdrant being unavailable.
    async fn delete_from_qdrant(&self, doc_id: &str) -> Result<(), String> {
        let client = reqwest::Client::new();
        let delete_request = serde_json::json!({
            "points": [doc_id]
        });
        let response = client
            .post(format!(
                "{}/collections/{}/points/delete",
                self.qdrant_url, self.collection_name
            ))
            .json(&delete_request)
            .send()
            .await
            .map_err(|e| format!("Qdrant delete failed: {}", e))?;
        if !response.status().is_success() {
            warn!(
                "Qdrant delete may have failed for {}: {}",
                doc_id,
                response.status()
            );
        }
        Ok(())
    }
    /// Snapshot of engine and BM25 statistics plus the active configuration.
    pub fn stats(&self) -> HybridSearchStats {
        let bm25_stats = self.bm25_index.stats();
        HybridSearchStats {
            total_documents: self.documents.len(),
            bm25_doc_count: bm25_stats.doc_count,
            unique_terms: bm25_stats.unique_terms,
            avg_doc_len: bm25_stats.avg_doc_len,
            config: self.config.clone(),
        }
    }
}
/// Hybrid search engine statistics snapshot (see `HybridSearchEngine::stats`).
#[derive(Debug, Clone)]
pub struct HybridSearchStats {
    /// Documents held in the in-memory store.
    pub total_documents: usize,
    /// Documents in the BM25 index (should match `total_documents`).
    pub bm25_doc_count: usize,
    /// Distinct terms tracked by the BM25 index.
    pub unique_terms: usize,
    /// Average BM25 document length, in tokens.
    pub avg_doc_len: f32,
    /// Active configuration at snapshot time.
    pub config: HybridSearchConfig,
}
/// Query decomposition for complex questions.
///
/// Splits compound questions into simpler sub-queries so each can be
/// retrieved independently, and synthesizes their answers back together.
pub struct QueryDecomposer {
    // NOTE(review): both fields are currently unused — reserved for the
    // LLM-based decomposition mentioned in `decompose`.
    llm_endpoint: String,
    api_key: String,
}
impl QueryDecomposer {
    /// Create a decomposer pointed at the given LLM endpoint.
    pub fn new(llm_endpoint: &str, api_key: &str) -> Self {
        Self {
            llm_endpoint: llm_endpoint.to_string(),
            api_key: api_key.to_string(),
        }
    }
    /// Decompose a complex query into simpler sub-queries.
    ///
    /// Heuristic implementation (a full one would use an LLM): first split
    /// on whole-word conjunctions, then fall back to splitting multi-question
    /// input on '?'. Returns the original query when nothing applies.
    pub async fn decompose(&self, query: &str) -> Result<Vec<String>, String> {
        let mut sub_queries = Vec::new();
        // Fix: conjunctions are matched as space-delimited whole words.
        // Bare substring matching used to shred words that merely contain
        // them — e.g. "standard" was split into "st" and "ard".
        let conjunctions = [" and ", " also ", " as well as ", " in addition to "];
        let mut parts: Vec<&str> = vec![query];
        for conj in &conjunctions {
            parts = parts
                .iter()
                .flat_map(|p| p.split(conj))
                .map(|s| s.trim())
                .filter(|s| !s.is_empty())
                .collect();
        }
        if parts.len() > 1 {
            for part in parts {
                sub_queries.push(part.to_string());
            }
        } else {
            // Fall back: if several question words occur, split on '?'.
            let question_words = ["what", "how", "why", "when", "where", "who"];
            let lower = query.to_lowercase();
            let has_multiple_questions = question_words
                .iter()
                .any(|qw| lower.matches(qw).count() > 1);
            if has_multiple_questions {
                for part in query.split('?') {
                    let trimmed = part.trim();
                    if !trimmed.is_empty() {
                        sub_queries.push(format!("{}?", trimmed));
                    }
                }
            }
        }
        // If no decomposition happened, return the original query as-is.
        if sub_queries.is_empty() {
            sub_queries.push(query.to_string());
        }
        Ok(sub_queries)
    }
    /// Synthesize answers from multiple sub-query results.
    ///
    /// A single answer is returned verbatim; several answers are numbered
    /// under a header that echoes the original question.
    pub fn synthesize(&self, query: &str, sub_answers: &[String]) -> String {
        if sub_answers.len() == 1 {
            return sub_answers[0].clone();
        }
        let mut synthesis = format!(
            "Based on your question about \"{}\", here's what I found:\n\n",
            query
        );
        for (i, answer) in sub_answers.iter().enumerate() {
            synthesis.push_str(&format!("{}. {}\n\n", i + 1, answer));
        }
        synthesis
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Indexing several documents tracks count and average length.
    #[test]
    fn test_bm25_index_basic() {
        let mut index = BM25Index::new();
        for (id, text) in [
            ("doc1", "The quick brown fox jumps over the lazy dog"),
            ("doc2", "A quick brown dog runs in the park"),
            ("doc3", "The lazy cat sleeps all day"),
        ] {
            index.add_document(id, text);
        }
        let stats = index.stats();
        assert_eq!(stats.doc_count, 3);
        assert!(stats.avg_doc_len > 0.0);
    }

    /// The document with both query terms ranks first.
    #[test]
    fn test_bm25_search() {
        let mut index = BM25Index::new();
        for (id, text) in [
            ("doc1", "machine learning artificial intelligence"),
            ("doc2", "natural language processing NLP"),
            ("doc3", "computer vision image recognition"),
        ] {
            index.add_document(id, text);
        }
        let hits = index.search("machine learning", 10);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, "doc1"); // doc1 should be first
    }

    /// Removal drops the document from counts and search results.
    #[test]
    fn test_bm25_remove_document() {
        let mut index = BM25Index::new();
        index.add_document("doc1", "test document one");
        index.add_document("doc2", "test document two");
        assert_eq!(index.stats().doc_count, 2);
        index.remove_document("doc1");
        assert_eq!(index.stats().doc_count, 1);
        let hits = index.search("one", 10);
        assert!(hits.is_empty() || hits[0].0 != "doc1");
    }

    /// Defaults match the documented Config.csv fallbacks.
    #[test]
    fn test_hybrid_config_default() {
        let HybridSearchConfig {
            dense_weight,
            sparse_weight,
            reranker_enabled,
            max_results,
            ..
        } = HybridSearchConfig::default();
        assert_eq!(dense_weight, 0.7);
        assert_eq!(sparse_weight, 0.3);
        assert!(!reranker_enabled);
        assert_eq!(max_results, 10);
    }

    /// Documents appearing in both rankings should dominate the fusion.
    #[test]
    fn test_reciprocal_rank_fusion() {
        let engine =
            HybridSearchEngine::new(HybridSearchConfig::default(), "http://localhost:6333", "test");
        let to_owned = |pairs: &[(&str, f32)]| -> Vec<(String, f32)> {
            pairs.iter().map(|(id, s)| (id.to_string(), *s)).collect()
        };
        let sparse = to_owned(&[("doc1", 0.9), ("doc2", 0.7), ("doc3", 0.5)]);
        let dense = to_owned(&[("doc2", 0.95), ("doc1", 0.8), ("doc4", 0.6)]);
        let fused = engine.reciprocal_rank_fusion(&sparse, &dense);
        assert!(!fused.is_empty());
        let top_ids: Vec<&str> = fused.iter().take(2).map(|(id, _)| id.as_str()).collect();
        assert!(top_ids.contains(&"doc1") || top_ids.contains(&"doc2"));
    }

    /// Decomposition yields at least one sub-query for a compound question.
    #[test]
    fn test_query_decomposer_simple() {
        let decomposer = QueryDecomposer::new("http://localhost:8081", "none");
        // Drive the async fn with a throwaway tokio runtime.
        let rt = tokio::runtime::Runtime::new().unwrap();
        let queries = rt
            .block_on(decomposer.decompose("What is machine learning and how does it work?"))
            .expect("decompose should succeed");
        assert!(!queries.is_empty());
    }

    /// SearchResult round-trips through serde_json.
    #[test]
    fn test_search_result_serialization() {
        let original = SearchResult {
            doc_id: "test123".to_string(),
            content: "Test content".to_string(),
            source: "/path/to/file".to_string(),
            score: 0.85,
            metadata: HashMap::new(),
            search_method: SearchMethod::Hybrid,
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let parsed: SearchResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.doc_id, "test123");
    }
}