This commit is contained in:
parent
8349a044b5
commit
660d7d436f
2 changed files with 71 additions and 27 deletions
|
|
@ -36,11 +36,12 @@ struct Choice {
|
|||
finish_reason: String,
|
||||
}
|
||||
|
||||
fn _clean_request_body(body: &str) -> String {
|
||||
fn clean_request_body(body: &str) -> String {
|
||||
// Remove problematic parameters that might not be supported by all providers
|
||||
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
||||
re.replace_all(body, "").to_string()
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||||
// Log raw POST data
|
||||
|
|
@ -130,13 +131,34 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
|||
|
||||
/// Converts provider response to OpenAI-compatible format
|
||||
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ProviderChoice {
|
||||
message: ProviderMessage,
|
||||
#[serde(default)]
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ProviderMessage {
|
||||
role: Option<String>,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ProviderResponse {
|
||||
text: String,
|
||||
#[serde(default)]
|
||||
generated_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
input_tokens: Option<u32>,
|
||||
id: Option<String>,
|
||||
object: Option<String>,
|
||||
created: Option<u64>,
|
||||
model: Option<String>,
|
||||
choices: Vec<ProviderChoice>,
|
||||
usage: Option<ProviderUsage>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Default)]
|
||||
struct ProviderUsage {
|
||||
prompt_tokens: Option<u32>,
|
||||
completion_tokens: Option<u32>,
|
||||
total_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
|
|
@ -172,28 +194,46 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
|||
// Parse the provider response
|
||||
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
||||
|
||||
let completion_tokens = provider
|
||||
.generated_tokens
|
||||
.unwrap_or_else(|| provider.text.split_whitespace().count() as u32);
|
||||
// Extract content from the first choice
|
||||
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
|
||||
let content = first_choice.message.content.clone();
|
||||
let role = first_choice
|
||||
.message
|
||||
.role
|
||||
.clone()
|
||||
.unwrap_or_else(|| "assistant".to_string());
|
||||
|
||||
let prompt_tokens = provider.input_tokens.unwrap_or(0);
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
// Calculate token usage
|
||||
let usage = provider.usage.unwrap_or_default();
|
||||
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
|
||||
let completion_tokens = usage
|
||||
.completion_tokens
|
||||
.unwrap_or_else(|| content.split_whitespace().count() as u32);
|
||||
let total_tokens = usage
|
||||
.total_tokens
|
||||
.unwrap_or(prompt_tokens + completion_tokens);
|
||||
|
||||
let openai_response = OpenAIResponse {
|
||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().simple()),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
id: provider
|
||||
.id
|
||||
.unwrap_or_else(|| format!("chatcmpl-{}", uuid::Uuid::new_v4().simple())),
|
||||
object: provider
|
||||
.object
|
||||
.unwrap_or_else(|| "chat.completion".to_string()),
|
||||
created: provider.created.unwrap_or_else(|| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
model: "llama".to_string(),
|
||||
.as_secs()
|
||||
}),
|
||||
model: provider.model.unwrap_or_else(|| "llama".to_string()),
|
||||
choices: vec![OpenAIChoice {
|
||||
index: 0,
|
||||
message: OpenAIMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: provider.text,
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
message: OpenAIMessage { role, content },
|
||||
finish_reason: first_choice
|
||||
.finish_reason
|
||||
.clone()
|
||||
.unwrap_or_else(|| "stop".to_string()),
|
||||
}],
|
||||
usage: OpenAIUsage {
|
||||
prompt_tokens,
|
||||
|
|
@ -204,3 +244,5 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
|||
|
||||
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
// Default implementation for ProviderUsage
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ use minio::s3::Client;
|
|||
|
||||
use crate::services::{config::AppConfig, web_automation::BrowserPool};
|
||||
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub minio_client: Option<Client>,
|
||||
|
|
@ -13,4 +12,7 @@ pub struct AppState {
|
|||
pub db_custom: Option<sqlx::PgPool>,
|
||||
pub browser_pool: Arc<BrowserPool>,
|
||||
}
|
||||
|
||||
pub struct _BotState {
|
||||
pub language: String,
|
||||
pub work_folder: String,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue