This commit is contained in:
parent
8349a044b5
commit
660d7d436f
2 changed files with 71 additions and 27 deletions
|
|
@ -36,11 +36,12 @@ struct Choice {
|
||||||
finish_reason: String,
|
finish_reason: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _clean_request_body(body: &str) -> String {
|
fn clean_request_body(body: &str) -> String {
|
||||||
// Remove problematic parameters that might not be supported by all providers
|
// Remove problematic parameters that might not be supported by all providers
|
||||||
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
||||||
re.replace_all(body, "").to_string()
|
re.replace_all(body, "").to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/v1/chat/completions")]
|
#[post("/v1/chat/completions")]
|
||||||
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||||||
// Log raw POST data
|
// Log raw POST data
|
||||||
|
|
@ -130,13 +131,34 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
|
|
||||||
/// Converts provider response to OpenAI-compatible format
|
/// Converts provider response to OpenAI-compatible format
|
||||||
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct ProviderChoice {
|
||||||
|
message: ProviderMessage,
|
||||||
|
#[serde(default)]
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct ProviderMessage {
|
||||||
|
role: Option<String>,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
struct ProviderResponse {
|
struct ProviderResponse {
|
||||||
text: String,
|
id: Option<String>,
|
||||||
#[serde(default)]
|
object: Option<String>,
|
||||||
generated_tokens: Option<u32>,
|
created: Option<u64>,
|
||||||
#[serde(default)]
|
model: Option<String>,
|
||||||
input_tokens: Option<u32>,
|
choices: Vec<ProviderChoice>,
|
||||||
|
usage: Option<ProviderUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Default)]
|
||||||
|
struct ProviderUsage {
|
||||||
|
prompt_tokens: Option<u32>,
|
||||||
|
completion_tokens: Option<u32>,
|
||||||
|
total_tokens: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Serialize)]
|
#[derive(serde::Serialize)]
|
||||||
|
|
@ -172,28 +194,46 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
||||||
// Parse the provider response
|
// Parse the provider response
|
||||||
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
||||||
|
|
||||||
let completion_tokens = provider
|
// Extract content from the first choice
|
||||||
.generated_tokens
|
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
|
||||||
.unwrap_or_else(|| provider.text.split_whitespace().count() as u32);
|
let content = first_choice.message.content.clone();
|
||||||
|
let role = first_choice
|
||||||
|
.message
|
||||||
|
.role
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "assistant".to_string());
|
||||||
|
|
||||||
let prompt_tokens = provider.input_tokens.unwrap_or(0);
|
// Calculate token usage
|
||||||
let total_tokens = prompt_tokens + completion_tokens;
|
let usage = provider.usage.unwrap_or_default();
|
||||||
|
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
|
||||||
|
let completion_tokens = usage
|
||||||
|
.completion_tokens
|
||||||
|
.unwrap_or_else(|| content.split_whitespace().count() as u32);
|
||||||
|
let total_tokens = usage
|
||||||
|
.total_tokens
|
||||||
|
.unwrap_or(prompt_tokens + completion_tokens);
|
||||||
|
|
||||||
let openai_response = OpenAIResponse {
|
let openai_response = OpenAIResponse {
|
||||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().simple()),
|
id: provider
|
||||||
object: "chat.completion".to_string(),
|
.id
|
||||||
created: std::time::SystemTime::now()
|
.unwrap_or_else(|| format!("chatcmpl-{}", uuid::Uuid::new_v4().simple())),
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
object: provider
|
||||||
.unwrap()
|
.object
|
||||||
.as_secs(),
|
.unwrap_or_else(|| "chat.completion".to_string()),
|
||||||
model: "llama".to_string(),
|
created: provider.created.unwrap_or_else(|| {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs()
|
||||||
|
}),
|
||||||
|
model: provider.model.unwrap_or_else(|| "llama".to_string()),
|
||||||
choices: vec![OpenAIChoice {
|
choices: vec![OpenAIChoice {
|
||||||
index: 0,
|
index: 0,
|
||||||
message: OpenAIMessage {
|
message: OpenAIMessage { role, content },
|
||||||
role: "assistant".to_string(),
|
finish_reason: first_choice
|
||||||
content: provider.text,
|
.finish_reason
|
||||||
},
|
.clone()
|
||||||
finish_reason: "stop".to_string(),
|
.unwrap_or_else(|| "stop".to_string()),
|
||||||
}],
|
}],
|
||||||
usage: OpenAIUsage {
|
usage: OpenAIUsage {
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
|
|
@ -204,3 +244,5 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
||||||
|
|
||||||
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default implementation for ProviderUsage
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,15 @@ use minio::s3::Client;
|
||||||
|
|
||||||
use crate::services::{config::AppConfig, web_automation::BrowserPool};
|
use crate::services::{config::AppConfig, web_automation::BrowserPool};
|
||||||
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub minio_client: Option<Client>,
|
pub minio_client: Option<Client>,
|
||||||
pub config: Option<AppConfig>,
|
pub config: Option<AppConfig>,
|
||||||
pub db: Option<sqlx::PgPool>,
|
pub db: Option<sqlx::PgPool>,
|
||||||
pub db_custom: Option<sqlx::PgPool>,
|
pub db_custom: Option<sqlx::PgPool>,
|
||||||
pub browser_pool: Arc<BrowserPool>,
|
pub browser_pool: Arc<BrowserPool>,
|
||||||
}
|
}
|
||||||
|
pub struct _BotState {
|
||||||
|
pub language: String,
|
||||||
|
pub work_folder: String,
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue