feat: Add HuggingFace embedding API support with authentication
All checks were successful
BotServer CI / build (push) Successful in 11m52s
All checks were successful
BotServer CI / build (push) Successful in 11m52s
- Added api_key field to LocalEmbeddingService for authentication
- Added embedding-key config parameter support in bootstrap
- Smart URL handling: doesn't append /embedding for full URLs (HuggingFace, OpenAI, etc.)
- Detects HuggingFace URLs and uses correct request format (inputs instead of input/model)
- Handles multiple response formats:
  - HuggingFace: direct array [0.1, 0.2, ...]
  - Standard/OpenAI: {"data": [{"embedding": [...]}]}
- Added Authorization header with Bearer token when api_key is provided
- Improved error messages with full response details
Fixes embedding errors when using HuggingFace Inference API
This commit is contained in:
parent
d5b877f8e8
commit
e389178a36
2 changed files with 60 additions and 14 deletions
|
|
@@ -605,13 +605,15 @@ impl LLMProvider for CachedLLMProvider {
|
||||||
pub struct LocalEmbeddingService {
|
pub struct LocalEmbeddingService {
|
||||||
embedding_url: String,
|
embedding_url: String,
|
||||||
model: String,
|
model: String,
|
||||||
|
api_key: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LocalEmbeddingService {
|
impl LocalEmbeddingService {
|
||||||
pub fn new(embedding_url: String, model: String) -> Self {
|
pub fn new(embedding_url: String, model: String, api_key: Option<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
embedding_url,
|
embedding_url,
|
||||||
model,
|
model,
|
||||||
|
api_key,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@@ -623,12 +625,37 @@ impl EmbeddingService for LocalEmbeddingService {
|
||||||
text: &str,
|
text: &str,
|
||||||
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let response = client
|
|
||||||
.post(format!("{}/embedding", self.embedding_url))
|
// Determine if URL already includes endpoint path
|
||||||
.json(&serde_json::json!({
|
let url = if self.embedding_url.contains("/pipeline/") ||
|
||||||
|
self.embedding_url.contains("/v1/") ||
|
||||||
|
self.embedding_url.ends_with("/embeddings") {
|
||||||
|
self.embedding_url.clone()
|
||||||
|
} else {
|
||||||
|
format!("{}/embedding", self.embedding_url)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut request = client.post(&url);
|
||||||
|
|
||||||
|
// Add authorization header if API key is provided
|
||||||
|
if let Some(ref api_key) = self.api_key {
|
||||||
|
request = request.header("Authorization", format!("Bearer {}", api_key));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine request body format based on URL
|
||||||
|
let request_body = if self.embedding_url.contains("huggingface.co") {
|
||||||
|
serde_json::json!({
|
||||||
|
"inputs": text,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
serde_json::json!({
|
||||||
"input": text,
|
"input": text,
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
}))
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = request
|
||||||
|
.json(&request_body)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
@@ -659,7 +686,15 @@ impl EmbeddingService for LocalEmbeddingService {
|
||||||
return Err(format!("Embedding service error: {}", error).into());
|
return Err(format!("Embedding service error: {}", error).into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let embedding = result["data"][0]["embedding"]
|
// Try multiple response formats
|
||||||
|
let embedding = if let Some(arr) = result.as_array() {
|
||||||
|
// HuggingFace format: direct array [0.1, 0.2, ...]
|
||||||
|
arr.iter()
|
||||||
|
.filter_map(|v| v.as_f64().map(|f| f as f32))
|
||||||
|
.collect()
|
||||||
|
} else if let Some(data) = result.get("data") {
|
||||||
|
// OpenAI/Standard format: {"data": [{"embedding": [...]}]}
|
||||||
|
data[0]["embedding"]
|
||||||
.as_array()
|
.as_array()
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
debug!("Invalid embedding response format. Expected data[0].embedding array. Got: {}", response_text);
|
debug!("Invalid embedding response format. Expected data[0].embedding array. Got: {}", response_text);
|
||||||
|
|
@@ -667,7 +702,13 @@ impl EmbeddingService for LocalEmbeddingService {
|
||||||
})?
|
})?
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|v| v.as_f64().map(|f| f as f32))
|
.filter_map(|v| v.as_f64().map(|f| f as f32))
|
||||||
.collect();
|
.collect()
|
||||||
|
} else {
|
||||||
|
return Err(format!(
|
||||||
|
"Invalid embedding response format - Expected array or data[0].embedding, got: {}",
|
||||||
|
response_text
|
||||||
|
).into());
|
||||||
|
};
|
||||||
|
|
||||||
Ok(embedding)
|
Ok(embedding)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@@ -748,6 +748,9 @@ fn init_llm_provider(
|
||||||
let embedding_model = config_manager
|
let embedding_model = config_manager
|
||||||
.get_config(&bot_id, "embedding-model", Some("all-MiniLM-L6-v2"))
|
.get_config(&bot_id, "embedding-model", Some("all-MiniLM-L6-v2"))
|
||||||
.unwrap_or_else(|_| "all-MiniLM-L6-v2".to_string());
|
.unwrap_or_else(|_| "all-MiniLM-L6-v2".to_string());
|
||||||
|
let embedding_key = config_manager
|
||||||
|
.get_config(&bot_id, "embedding-key", None)
|
||||||
|
.ok();
|
||||||
let semantic_cache_enabled = config_manager
|
let semantic_cache_enabled = config_manager
|
||||||
.get_config(&bot_id, "semantic-cache-enabled", Some("true"))
|
.get_config(&bot_id, "semantic-cache-enabled", Some("true"))
|
||||||
.unwrap_or_else(|_| "true".to_string())
|
.unwrap_or_else(|_| "true".to_string())
|
||||||
|
|
@@ -755,12 +758,14 @@ fn init_llm_provider(
|
||||||
|
|
||||||
info!("Embedding URL: {}", embedding_url);
|
info!("Embedding URL: {}", embedding_url);
|
||||||
info!("Embedding Model: {}", embedding_model);
|
info!("Embedding Model: {}", embedding_model);
|
||||||
|
info!("Embedding Key: {}", if embedding_key.is_some() { "configured" } else { "not set" });
|
||||||
info!("Semantic Cache Enabled: {}", semantic_cache_enabled);
|
info!("Semantic Cache Enabled: {}", semantic_cache_enabled);
|
||||||
|
|
||||||
let embedding_service = if semantic_cache_enabled {
|
let embedding_service = if semantic_cache_enabled {
|
||||||
Some(Arc::new(LocalEmbeddingService::new(
|
Some(Arc::new(LocalEmbeddingService::new(
|
||||||
embedding_url,
|
embedding_url,
|
||||||
embedding_model,
|
embedding_model,
|
||||||
|
embedding_key,
|
||||||
)) as Arc<dyn EmbeddingService>)
|
)) as Arc<dyn EmbeddingService>)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue