This commit is contained in:
parent
c54904b18b
commit
f0cdf2047a
2 changed files with 48 additions and 2 deletions
|
@ -82,7 +82,6 @@ async fn main() -> std::io::Result<()> {
|
|||
.service(list_emails)
|
||||
.service(send_email)
|
||||
.service(chat_stream)
|
||||
.service(chat_completions)
|
||||
.service(chat_completions_local)
|
||||
.service(chat)
|
||||
.service(embeddings_local)
|
||||
|
|
|
@ -339,15 +339,62 @@ pub async fn chat_completions_local(
|
|||
}
|
||||
}
|
||||
|
||||
// OpenAI Embedding Request
|
||||
// OpenAI Embedding Request - Modified to handle both string and array inputs
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
#[serde(deserialize_with = "deserialize_input")]
|
||||
pub input: Vec<String>,
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
pub encoding_format: Option<String>,
|
||||
}
|
||||
|
||||
// Custom deserializer to handle both string and array inputs
|
||||
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::{self, Visitor};
|
||||
use std::fmt;
|
||||
|
||||
struct InputVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for InputVisitor {
|
||||
type Value = Vec<String>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string or an array of strings")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(vec![value.to_string()])
|
||||
}
|
||||
|
||||
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(vec![value])
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: de::SeqAccess<'de>,
|
||||
{
|
||||
let mut vec = Vec::new();
|
||||
while let Some(value) = seq.next_element::<String>()? {
|
||||
vec.push(value);
|
||||
}
|
||||
Ok(vec)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(InputVisitor)
|
||||
}
|
||||
|
||||
// OpenAI Embedding Response
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
|
|
Loading…
Add table
Reference in a new issue