From babdaa79cafd163121749653d8fe7bee40b46a88 Mon Sep 17 00:00:00 2001
From: "Rodrigo Rodriguez (Pragmatismo)"
Date: Fri, 19 Sep 2025 08:25:39 -0300
Subject: [PATCH] Add generic LLM service implementation

---
 src/main.rs                  |   9 ++-
 src/services.rs              |   1 +
 src/services/llm_generic.rs  | 120 +++++++++++++++++++++++++++++++++++
 src/services/llm_provider.rs |   2 +-
 4 files changed, 128 insertions(+), 4 deletions(-)
 create mode 100644 src/services/llm_generic.rs

diff --git a/src/main.rs b/src/main.rs
index 866e148..c8dfe0c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -13,8 +13,10 @@ use crate::services::email::{
     get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
 };
 use crate::services::llm::{chat, chat_stream};
-use crate::services::llm_local::ensure_llama_servers_running;
-use crate::services::llm_local::{chat_completions_local, embeddings_local};
+use crate::services::llm_generic::generic_chat_completions;
+use crate::services::llm_local::{
+    chat_completions_local, embeddings_local, ensure_llama_servers_running,
+};
 use crate::services::web_automation::{initialize_browser_pool, BrowserPool};
 
 mod models;
@@ -85,8 +87,9 @@ async fn main() -> std::io::Result<()> {
             .service(list_emails)
             .service(send_email)
             .service(chat_stream)
-            .service(chat_completions_local)
+            .service(generic_chat_completions)
             .service(chat)
+            .service(chat_completions_local)
             .service(save_draft)
             .service(embeddings_local)
             .service(get_latest_email_from)
diff --git a/src/services.rs b/src/services.rs
index 567bcc4..e56db6a 100644
--- a/src/services.rs
+++ b/src/services.rs
@@ -4,6 +4,7 @@ pub mod email;
 pub mod file;
 pub mod keywords;
 pub mod llm;
+pub mod llm_generic;
 pub mod llm_local;
 pub mod llm_provider;
 pub mod script;
diff --git a/src/services/llm_generic.rs b/src/services/llm_generic.rs
new file mode 100644
index 0000000..f20384d
--- /dev/null
+++ b/src/services/llm_generic.rs
@@ -0,0 +1,120 @@
+use log::info;
+
+use actix_web::{post, web, HttpRequest, HttpResponse, Result};
+use dotenv::dotenv;
+use regex::Regex;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::env;
+
+// OpenAI-compatible request/response structures
+#[derive(Debug, Serialize, Deserialize)]
+struct ChatMessage {
+    role: String,
+    content: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct ChatCompletionRequest {
+    model: String,
+    messages: Vec<ChatMessage>,
+    stream: Option<bool>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct ChatCompletionResponse {
+    id: String,
+    object: String,
+    created: u64,
+    model: String,
+    choices: Vec<Choice>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Choice {
+    message: ChatMessage,
+    finish_reason: String,
+}
+
+fn clean_request_body(body: &str) -> String {
+    // Remove parameters that might not be supported by all providers
+    let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
+    re.replace_all(body, "").to_string()
+}
+
+#[post("/v1/chat/completions")]
+pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
+    // Log raw POST data
+    let body_str = std::str::from_utf8(&body).unwrap_or_default();
+    info!("Original POST Data: {}", body_str);
+
+    dotenv().ok();
+
+    // Get environment variables
+    let api_key = env::var("AI_KEY")
+        .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
+    let model = env::var("AI_LLM_MODEL")
+        .map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
+    let endpoint = env::var("AI_ENDPOINT")
+        .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
+
+    // Clean the request body (remove unsupported parameters)
+    let cleaned_body_str = clean_request_body(body_str);
+    info!("Cleaned POST Data: {}", cleaned_body_str);
+
+    // Set up headers
+    let mut headers = reqwest::header::HeaderMap::new();
+    headers.insert(
+        "Authorization",
+        reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
+            .map_err(|_| actix_web::error::ErrorInternalServerError("Invalid API key format"))?,
+    );
+    headers.insert(
+        "Content-Type",
+        reqwest::header::HeaderValue::from_static("application/json"),
+    );
+
+    // Parse the cleaned body so the configured model can be injected
+    let mut json_value: serde_json::Value = serde_json::from_str(&cleaned_body_str)
+        .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
+
+    // Override the client-supplied model with the configured one
+    json_value
+        .as_object_mut()
+        .unwrap()
+        .insert("model".to_string(), serde_json::Value::String(model));
+
+    // Serialize the modified JSON
+    let modified_body_str = serde_json::to_string(&json_value)
+        .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to serialize JSON"))?;
+
+    // Send request to the OpenAI-compatible provider
+    let client = Client::new();
+    let response = client
+        .post(&endpoint)
+        .headers(headers)
+        .body(modified_body_str)
+        .send()
+        .await
+        .map_err(actix_web::error::ErrorInternalServerError)?;
+
+    // Handle response
+    let status = response.status();
+    let raw_response = response
+        .text()
+        .await
+        .map_err(actix_web::error::ErrorInternalServerError)?;
+
+    info!("Provider response status: {}", status);
+    info!("Provider response body: {}", raw_response);
+
+    // Return the response with the provider's status code
+    if status.is_success() {
+        Ok(HttpResponse::Ok().body(raw_response))
+    } else {
+        let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
+            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
+
+        Ok(HttpResponse::build(actix_status).body(raw_response))
+    }
+}
diff --git a/src/services/llm_provider.rs b/src/services/llm_provider.rs
index 45539a1..4fc02bc 100644
--- a/src/services/llm_provider.rs
+++ b/src/services/llm_provider.rs
@@ -36,7 +36,7 @@ struct Choice {
     finish_reason: String,
 }
 
-#[post("/v1/chat/completions")]
+#[post("/azure/v1/chat/completions")]
 async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
     // Always log raw POST data
     if let Ok(body_str) = std::str::from_utf8(&body) {
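
Reviewer note: a minimal sketch of a smoke test for the new generic route,
assuming the server listens on localhost:8080 with AI_KEY, AI_LLM_MODEL, and
AI_ENDPOINT set (the port and payload below are illustrative, not part of the
patch). The handler injects AI_LLM_MODEL itself, so the client payload does
not need a "model" field:

    curl -s http://localhost:8080/v1/chat/completions \
      -H "Content-Type: application/json" \
      -d '{"messages": [{"role": "user", "content": "ping"}], "stream": false}'

Since the generic handler now owns /v1/chat/completions, any clients of the
llm_provider handler must switch to the relocated /azure/v1/chat/completions
route.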