Add generic LLM service implementation
This commit is contained in:
parent
e5dfcf2659
commit
babdaa79ca
4 changed files with 128 additions and 4 deletions
|
|
@ -13,8 +13,10 @@ use crate::services::email::{
|
|||
get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
|
||||
};
|
||||
use crate::services::llm::{chat, chat_stream};
|
||||
use crate::services::llm_local::ensure_llama_servers_running;
|
||||
use crate::services::llm_local::{chat_completions_local, embeddings_local};
|
||||
use crate::services::llm_generic::generic_chat_completions;
|
||||
use crate::services::llm_local::{
|
||||
chat_completions_local, embeddings_local, ensure_llama_servers_running,
|
||||
};
|
||||
use crate::services::web_automation::{initialize_browser_pool, BrowserPool};
|
||||
|
||||
mod models;
|
||||
|
|
@ -85,8 +87,9 @@ async fn main() -> std::io::Result<()> {
|
|||
.service(list_emails)
|
||||
.service(send_email)
|
||||
.service(chat_stream)
|
||||
.service(chat_completions_local)
|
||||
.service(generic_chat_completions)
|
||||
.service(chat)
|
||||
.service(chat_completions_local)
|
||||
.service(save_draft)
|
||||
.service(embeddings_local)
|
||||
.service(get_latest_email_from)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ pub mod email;
|
|||
pub mod file;
|
||||
pub mod keywords;
|
||||
pub mod llm;
|
||||
pub mod llm_generic;
|
||||
pub mod llm_local;
|
||||
pub mod llm_provider;
|
||||
pub mod script;
|
||||
|
|
|
|||
120
src/services/llm_generic.rs
Normal file
120
src/services/llm_generic.rs
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
use log::info;
|
||||
|
||||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||||
use dotenv::dotenv;
|
||||
use regex::Regex;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
|
||||
// OpenAI-compatible request/response structures
|
||||
// A single conversation turn in the OpenAI chat-completions wire format.
// NOTE(review): none of these request/response structs are referenced in this
// file — the handler forwards the raw JSON body — confirm they are needed by
// other code or remove them.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    // Speaker of the turn; in the OpenAI format typically "system", "user",
    // or "assistant" — TODO confirm accepted values for the target provider.
    role: String,
    // The message text.
    content: String,
}
|
||||
|
||||
// Minimal OpenAI-compatible /chat/completions request payload.
// NOTE(review): currently unused in this file; the handler operates on the
// raw JSON body instead of deserializing into this type.
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionRequest {
    // Requested model identifier (the handler overwrites this server-side
    // with the AI_LLM_MODEL environment variable).
    model: String,
    // Ordered conversation history.
    messages: Vec<ChatMessage>,
    // Whether a streamed response was requested; optional on the wire.
    stream: Option<bool>,
}
|
||||
|
||||
// Minimal OpenAI-compatible /chat/completions response payload.
// NOTE(review): currently unused in this file; the handler relays the
// provider's raw response text without deserializing it.
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionResponse {
    // Provider-assigned completion id.
    id: String,
    // Object type tag (e.g. "chat.completion" in the OpenAI format).
    object: String,
    // Unix timestamp of creation.
    created: u64,
    // Model that produced the completion.
    model: String,
    // One entry per generated alternative.
    choices: Vec<Choice>,
}
|
||||
|
||||
// One generated alternative inside a ChatCompletionResponse.
// NOTE(review): the OpenAI format also carries an `index` field per choice;
// absent here — confirm whether deserialization of real provider responses
// is ever needed, since this type is currently unused.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    // The assistant message produced for this choice.
    message: ChatMessage,
    // Why generation stopped (e.g. "stop", "length").
    finish_reason: String,
}
|
||||
|
||||
fn clean_request_body(body: &str) -> String {
|
||||
// Remove problematic parameters that might not be supported by all providers
|
||||
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
||||
re.replace_all(body, "").to_string()
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||||
// Log raw POST data
|
||||
let body_str = std::str::from_utf8(&body).unwrap_or_default();
|
||||
info!("Original POST Data: {}", body_str);
|
||||
|
||||
dotenv().ok();
|
||||
|
||||
// Get environment variables
|
||||
let api_key = env::var("AI_KEY")
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
||||
let model = env::var("AI_LLM_MODEL")
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
|
||||
let endpoint = env::var("AI_ENDPOINT")
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
||||
|
||||
// Clean the request body (remove unsupported parameters)
|
||||
let cleaned_body_str = clean_request_body(body_str);
|
||||
info!("Cleaned POST Data: {}", cleaned_body_str);
|
||||
|
||||
// Set up headers
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid API key format"))?,
|
||||
);
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
reqwest::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
// After cleaning the request body, add the unused parameter
|
||||
let mut json_value: serde_json::Value = serde_json::from_str(&cleaned_body_str)
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
|
||||
|
||||
// Add the unused parameter
|
||||
json_value
|
||||
.as_object_mut()
|
||||
.unwrap()
|
||||
.insert("model".to_string(), serde_json::Value::String(model));
|
||||
|
||||
// Serialize the modified JSON
|
||||
let modified_body_str = serde_json::to_string(&json_value)
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to serialize JSON"))?;
|
||||
|
||||
// Send request to the OpenAI-compatible provider
|
||||
let client = Client::new();
|
||||
let response = client
|
||||
.post(&endpoint)
|
||||
.headers(headers)
|
||||
.body(modified_body_str)
|
||||
.send()
|
||||
.await
|
||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||
|
||||
// Handle response
|
||||
let status = response.status();
|
||||
let raw_response = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||
|
||||
info!("Provider response status: {}", status);
|
||||
info!("Provider response body: {}", raw_response);
|
||||
|
||||
// Return the response with appropriate status code
|
||||
if status.is_success() {
|
||||
Ok(HttpResponse::Ok().body(raw_response))
|
||||
} else {
|
||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
Ok(HttpResponse::build(actix_status).body(raw_response))
|
||||
}
|
||||
}
|
||||
|
|
@ -36,7 +36,7 @@ struct Choice {
|
|||
finish_reason: String,
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
#[post("/azure/v1/chat/completions")]
|
||||
async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||||
// Always log raw POST data
|
||||
if let Ok(body_str) = std::str::from_utf8(&body) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue