From f0cdf2047a2c0e4d76877b2f8990391d6932be8f Mon Sep 17 00:00:00 2001 From: "Rodrigo Rodriguez (Pragmatismo)" Date: Wed, 10 Sep 2025 11:01:39 -0300 Subject: [PATCH] - Local LLM Embeddings. --- src/main.rs | 1 - src/services/llm_local.rs | 49 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index 7da0941..05271bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -82,7 +82,6 @@ async fn main() -> std::io::Result<()> { .service(list_emails) .service(send_email) .service(chat_stream) - .service(chat_completions) .service(chat_completions_local) .service(chat) .service(embeddings_local) diff --git a/src/services/llm_local.rs b/src/services/llm_local.rs index c018708..2216d6d 100644 --- a/src/services/llm_local.rs +++ b/src/services/llm_local.rs @@ -339,15 +339,62 @@ pub async fn chat_completions_local( } } -// OpenAI Embedding Request +// OpenAI Embedding Request - Modified to handle both string and array inputs #[derive(Debug, Deserialize)] pub struct EmbeddingRequest { + #[serde(deserialize_with = "deserialize_input")] pub input: Vec, pub model: String, #[serde(default)] pub encoding_format: Option, } +// Custom deserializer to handle both string and array inputs +fn deserialize_input<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + use serde::de::{self, Visitor}; + use std::fmt; + + struct InputVisitor; + + impl<'de> Visitor<'de> for InputVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string or an array of strings") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(vec![value.to_string()]) + } + + fn visit_string(self, value: String) -> Result + where + E: de::Error, + { + Ok(vec![value]) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut vec = Vec::new(); + while let Some(value) = seq.next_element::()? { + vec.push(value); + } + Ok(vec) + } + } + + deserializer.deserialize_any(InputVisitor) +} + // OpenAI Embedding Response #[derive(Debug, Serialize)] pub struct EmbeddingResponse {