diff --git a/src/main.rs b/src/main.rs index 7da0941..05271bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -82,7 +82,6 @@ async fn main() -> std::io::Result<()> { .service(list_emails) .service(send_email) .service(chat_stream) - .service(chat_completions) .service(chat_completions_local) .service(chat) .service(embeddings_local) diff --git a/src/services/llm_local.rs b/src/services/llm_local.rs index c018708..2216d6d 100644 --- a/src/services/llm_local.rs +++ b/src/services/llm_local.rs @@ -339,15 +339,62 @@ pub async fn chat_completions_local( } } -// OpenAI Embedding Request +// OpenAI Embedding Request - Modified to handle both string and array inputs #[derive(Debug, Deserialize)] pub struct EmbeddingRequest { + #[serde(deserialize_with = "deserialize_input")] pub input: Vec, pub model: String, #[serde(default)] pub encoding_format: Option, } +// Custom deserializer to handle both string and array inputs +fn deserialize_input<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + use serde::de::{self, Visitor}; + use std::fmt; + + struct InputVisitor; + + impl<'de> Visitor<'de> for InputVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string or an array of strings") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(vec![value.to_string()]) + } + + fn visit_string(self, value: String) -> Result + where + E: de::Error, + { + Ok(vec![value]) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut vec = Vec::new(); + while let Some(value) = seq.next_element::()? { + vec.push(value); + } + Ok(vec) + } + } + + deserializer.deserialize_any(InputVisitor) +} + // OpenAI Embedding Response #[derive(Debug, Serialize)] pub struct EmbeddingResponse {