From e389178a360b7786d429370941e570bdc3b03cda Mon Sep 17 00:00:00 2001 From: "Rodrigo Rodriguez (Pragmatismo)" Date: Wed, 4 Mar 2026 16:54:25 -0300 Subject: [PATCH] feat: Add HuggingFace embedding API support with authentication - Added api_key field to LocalEmbeddingService for authentication - Added embedding-key config parameter support in bootstrap - Smart URL handling: doesn't append /embedding for full URLs (HuggingFace, OpenAI, etc.) - Detects HuggingFace URLs and uses correct request format (inputs instead of input/model) - Handles multiple response formats: - HuggingFace: direct array [0.1, 0.2, ...] - Standard/OpenAI: {"data": [{"embedding": [...]}]} - Added Authorization header with Bearer token when api_key is provided - Improved error messages with full response details Fixes embedding errors when using HuggingFace Inference API --- src/llm/cache.rs | 69 ++++++++++++++++++++++++++++-------- src/main_module/bootstrap.rs | 5 +++ 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/src/llm/cache.rs b/src/llm/cache.rs index cb753c8ab..e914d2dc3 100644 --- a/src/llm/cache.rs +++ b/src/llm/cache.rs @@ -605,13 +605,15 @@ impl LLMProvider for CachedLLMProvider { pub struct LocalEmbeddingService { embedding_url: String, model: String, + api_key: Option<String>, } impl LocalEmbeddingService { - pub fn new(embedding_url: String, model: String) -> Self { + pub fn new(embedding_url: String, model: String, api_key: Option<String>) -> Self { Self { embedding_url, model, + api_key, } } } @@ -623,12 +625,37 @@ impl EmbeddingService for LocalEmbeddingService { text: &str, ) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> { let client = reqwest::Client::new(); - let response = client - .post(format!("{}/embedding", self.embedding_url)) - .json(&serde_json::json!({ + + // Determine if URL already includes endpoint path + let url = if self.embedding_url.contains("/pipeline/") || + self.embedding_url.contains("/v1/") || + self.embedding_url.ends_with("/embeddings") { + self.embedding_url.clone() + } else { + 
format!("{}/embedding", self.embedding_url) + }; + + let mut request = client.post(&url); + + // Add authorization header if API key is provided + if let Some(ref api_key) = self.api_key { + request = request.header("Authorization", format!("Bearer {}", api_key)); + } + + // Determine request body format based on URL + let request_body = if self.embedding_url.contains("huggingface.co") { + serde_json::json!({ + "inputs": text, + }) + } else { + serde_json::json!({ "input": text, "model": self.model, - })) + }) + }; + + let response = request + .json(&request_body) .send() .await?; @@ -659,15 +686,29 @@ impl EmbeddingService for LocalEmbeddingService { return Err(format!("Embedding service error: {}", error).into()); } - let embedding = result["data"][0]["embedding"] - .as_array() - .ok_or_else(|| { - debug!("Invalid embedding response format. Expected data[0].embedding array. Got: {}", response_text); - format!("Invalid embedding response format - Expected data[0].embedding array, got: {}", response_text) - })? - .iter() - .filter_map(|v| v.as_f64().map(|f| f as f32)) - .collect(); + // Try multiple response formats + let embedding = if let Some(arr) = result.as_array() { + // HuggingFace format: direct array [0.1, 0.2, ...] + arr.iter() + .filter_map(|v| v.as_f64().map(|f| f as f32)) + .collect() + } else if let Some(data) = result.get("data") { + // OpenAI/Standard format: {"data": [{"embedding": [...]}]} + data[0]["embedding"] + .as_array() + .ok_or_else(|| { + debug!("Invalid embedding response format. Expected data[0].embedding array. Got: {}", response_text); + format!("Invalid embedding response format - Expected data[0].embedding array, got: {}", response_text) + })? 
+ .iter() + .filter_map(|v| v.as_f64().map(|f| f as f32)) + .collect() + } else { + return Err(format!( + "Invalid embedding response format - Expected array or data[0].embedding, got: {}", + response_text + ).into()); + }; Ok(embedding) } diff --git a/src/main_module/bootstrap.rs b/src/main_module/bootstrap.rs index 85a984d4c..c19737033 100644 --- a/src/main_module/bootstrap.rs +++ b/src/main_module/bootstrap.rs @@ -748,6 +748,9 @@ fn init_llm_provider( let embedding_model = config_manager .get_config(&bot_id, "embedding-model", Some("all-MiniLM-L6-v2")) .unwrap_or_else(|_| "all-MiniLM-L6-v2".to_string()); + let embedding_key = config_manager + .get_config(&bot_id, "embedding-key", None) + .ok(); let semantic_cache_enabled = config_manager .get_config(&bot_id, "semantic-cache-enabled", Some("true")) .unwrap_or_else(|_| "true".to_string()) @@ -755,12 +758,14 @@ info!("Embedding URL: {}", embedding_url); info!("Embedding Model: {}", embedding_model); + info!("Embedding Key: {}", if embedding_key.is_some() { "configured" } else { "not set" }); info!("Semantic Cache Enabled: {}", semantic_cache_enabled); let embedding_service = if semantic_cache_enabled { Some(Arc::new(LocalEmbeddingService::new( embedding_url, embedding_model, + embedding_key, )) as Arc<dyn EmbeddingService>) } else { None