feat: Add HuggingFace embedding API support with authentication
All checks were successful
BotServer CI / build (push) Successful in 11m52s

- Added api_key field to LocalEmbeddingService for authentication
- Added embedding-key config parameter support in bootstrap
- Smart URL handling: doesn't append /embedding for full URLs (HuggingFace, OpenAI, etc.)
- Detects HuggingFace URLs and uses correct request format (inputs instead of input/model)
- Handles multiple response formats:
  - HuggingFace: direct array [0.1, 0.2, ...]
  - Standard/OpenAI: {"data": [{"embedding": [...]}]}
- Added Authorization header with Bearer token when api_key is provided
- Improved error messages with full response details

Fixes embedding errors when using HuggingFace Inference API
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2026-03-04 16:54:25 -03:00
parent d5b877f8e8
commit e389178a36
2 changed files with 60 additions and 14 deletions

View file

@@ -605,13 +605,15 @@ impl LLMProvider for CachedLLMProvider {
/// HTTP client configuration for a local or hosted embedding endpoint.
///
/// Holds the service base URL, the model name sent with each request, and
/// an optional API key that (per this commit) is sent as a Bearer token
/// for authenticated providers such as the HuggingFace Inference API.
pub struct LocalEmbeddingService {
    embedding_url: String,
    model: String,
    // `None` means the endpoint requires no authentication header.
    api_key: Option<String>,
}

impl LocalEmbeddingService {
    /// Creates a new embedding service client.
    ///
    /// Generalized to accept anything convertible into `String` for the
    /// URL and model, so callers may pass `&str` or `String`. This is
    /// backward compatible: existing callers passing owned `String`s
    /// continue to compile unchanged (`String: Into<String>`).
    pub fn new(
        embedding_url: impl Into<String>,
        model: impl Into<String>,
        api_key: Option<String>,
    ) -> Self {
        Self {
            embedding_url: embedding_url.into(),
            model: model.into(),
            api_key,
        }
    }
}
@@ -623,12 +625,37 @@ impl EmbeddingService for LocalEmbeddingService {
text: &str, text: &str,
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let response = client
.post(format!("{}/embedding", self.embedding_url)) // Determine if URL already includes endpoint path
.json(&serde_json::json!({ let url = if self.embedding_url.contains("/pipeline/") ||
self.embedding_url.contains("/v1/") ||
self.embedding_url.ends_with("/embeddings") {
self.embedding_url.clone()
} else {
format!("{}/embedding", self.embedding_url)
};
let mut request = client.post(&url);
// Add authorization header if API key is provided
if let Some(ref api_key) = self.api_key {
request = request.header("Authorization", format!("Bearer {}", api_key));
}
// Determine request body format based on URL
let request_body = if self.embedding_url.contains("huggingface.co") {
serde_json::json!({
"inputs": text,
})
} else {
serde_json::json!({
"input": text, "input": text,
"model": self.model, "model": self.model,
})) })
};
let response = request
.json(&request_body)
.send() .send()
.await?; .await?;
@@ -659,15 +686,29 @@ impl EmbeddingService for LocalEmbeddingService {
return Err(format!("Embedding service error: {}", error).into()); return Err(format!("Embedding service error: {}", error).into());
} }
let embedding = result["data"][0]["embedding"] // Try multiple response formats
.as_array() let embedding = if let Some(arr) = result.as_array() {
.ok_or_else(|| { // HuggingFace format: direct array [0.1, 0.2, ...]
debug!("Invalid embedding response format. Expected data[0].embedding array. Got: {}", response_text); arr.iter()
format!("Invalid embedding response format - Expected data[0].embedding array, got: {}", response_text) .filter_map(|v| v.as_f64().map(|f| f as f32))
})? .collect()
.iter() } else if let Some(data) = result.get("data") {
.filter_map(|v| v.as_f64().map(|f| f as f32)) // OpenAI/Standard format: {"data": [{"embedding": [...]}]}
.collect(); data[0]["embedding"]
.as_array()
.ok_or_else(|| {
debug!("Invalid embedding response format. Expected data[0].embedding array. Got: {}", response_text);
format!("Invalid embedding response format - Expected data[0].embedding array, got: {}", response_text)
})?
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect()
} else {
return Err(format!(
"Invalid embedding response format - Expected array or data[0].embedding, got: {}",
response_text
).into());
};
Ok(embedding) Ok(embedding)
} }

View file

@@ -748,6 +748,9 @@ fn init_llm_provider(
let embedding_model = config_manager let embedding_model = config_manager
.get_config(&bot_id, "embedding-model", Some("all-MiniLM-L6-v2")) .get_config(&bot_id, "embedding-model", Some("all-MiniLM-L6-v2"))
.unwrap_or_else(|_| "all-MiniLM-L6-v2".to_string()); .unwrap_or_else(|_| "all-MiniLM-L6-v2".to_string());
let embedding_key = config_manager
.get_config(&bot_id, "embedding-key", None)
.ok();
let semantic_cache_enabled = config_manager let semantic_cache_enabled = config_manager
.get_config(&bot_id, "semantic-cache-enabled", Some("true")) .get_config(&bot_id, "semantic-cache-enabled", Some("true"))
.unwrap_or_else(|_| "true".to_string()) .unwrap_or_else(|_| "true".to_string())
@@ -755,12 +758,14 @@ fn init_llm_provider(
info!("Embedding URL: {}", embedding_url); info!("Embedding URL: {}", embedding_url);
info!("Embedding Model: {}", embedding_model); info!("Embedding Model: {}", embedding_model);
info!("Embedding Key: {}", if embedding_key.is_some() { "configured" } else { "not set" });
info!("Semantic Cache Enabled: {}", semantic_cache_enabled); info!("Semantic Cache Enabled: {}", semantic_cache_enabled);
let embedding_service = if semantic_cache_enabled { let embedding_service = if semantic_cache_enabled {
Some(Arc::new(LocalEmbeddingService::new( Some(Arc::new(LocalEmbeddingService::new(
embedding_url, embedding_url,
embedding_model, embedding_model,
embedding_key,
)) as Arc<dyn EmbeddingService>) )) as Arc<dyn EmbeddingService>)
} else { } else {
None None