Warning removal and restore of old code.
This commit is contained in: parent 76dc51e980, commit 8a9cd104d6
11 changed files with 522 additions and 339 deletions
@@ -1,15 +1,17 @@
 use async_trait::async_trait;
 use chrono::Utc;
 use log::info;
 use std::collections::HashMap;
 use std::sync::Arc;
 use tokio::sync::{mpsc, Mutex};

-use crate::shared::{BotResponse, UserMessage};
+use crate::shared::BotResponse;

 #[async_trait]
 pub trait ChannelAdapter: Send + Sync {
-    async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
+    async fn send_message(
+        &self,
+        response: BotResponse,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
 }

 pub struct WebChannelAdapter {
@@ -34,7 +36,10 @@ impl WebChannelAdapter {

 #[async_trait]
 impl ChannelAdapter for WebChannelAdapter {
-    async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    async fn send_message(
+        &self,
+        response: BotResponse,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         let connections = self.connections.lock().await;
         if let Some(tx) = connections.get(&response.session_id) {
             tx.send(response).await?;
@@ -67,11 +72,17 @@ impl VoiceAdapter {
         session_id: &str,
         user_id: &str,
     ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        info!("Starting voice session for user: {} with session: {}", user_id, session_id);
+        info!(
+            "Starting voice session for user: {} with session: {}",
+            user_id, session_id
+        );

         let token = format!("mock_token_{}_{}", session_id, user_id);
-        self.rooms.lock().await.insert(session_id.to_string(), token.clone());
+        self.rooms
+            .lock()
+            .await
+            .insert(session_id.to_string(), token.clone());

         Ok(token)
     }

@@ -99,7 +110,10 @@ impl VoiceAdapter {

 #[async_trait]
 impl ChannelAdapter for VoiceAdapter {
-    async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    async fn send_message(
+        &self,
+        response: BotResponse,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         info!("Sending voice response to: {}", response.user_id);
         self.send_voice_response(&response.session_id, &response.content)
             .await
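For context, a minimal sketch of a caller dispatching through the reformatted ChannelAdapter trait. The free function is illustrative only, not part of the commit; it assumes ChannelAdapter and BotResponse are in scope.

use std::error::Error;

// Hypothetical helper: route a BotResponse through any adapter behind the trait.
async fn dispatch(
    adapter: &dyn ChannelAdapter,
    response: BotResponse,
) -> Result<(), Box<dyn Error + Send + Sync>> {
    adapter.send_message(response).await
}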
@@ -1,257 +0,0 @@
-use async_trait::async_trait;
-use futures::StreamExt;
-use langchain_rust::{
-    language_models::llm::LLM,
-    llm::{claude::Claude, openai::OpenAI},
-};
-use serde_json::Value;
-use std::sync::Arc;
-use tokio::sync::mpsc;
-
-use crate::tools::ToolManager;
-
-#[async_trait]
-pub trait LLMProvider: Send + Sync {
-    async fn generate(
-        &self,
-        prompt: &str,
-        config: &Value,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        config: &Value,
-        tx: mpsc::Sender<String>,
-    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
-
-    async fn generate_with_tools(
-        &self,
-        prompt: &str,
-        config: &Value,
-        available_tools: &[String],
-        tool_manager: Arc<ToolManager>,
-        session_id: &str,
-        user_id: &str,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
-}
-
-pub struct OpenAIClient {
-    client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
-}
-
-impl OpenAIClient {
-    pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
-        Self { client }
-    }
-}
-
-#[async_trait]
-impl LLMProvider for OpenAIClient {
-    async fn generate(
-        &self,
-        prompt: &str,
-        _config: &Value,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let result = self
-            .client
-            .invoke(prompt)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        Ok(result)
-    }
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        tx: mpsc::Sender<String>,
-    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-        let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
-        let mut stream = self
-            .client
-            .stream(&messages)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        while let Some(result) = stream.next().await {
-            match result {
-                Ok(chunk) => {
-                    let content = chunk.content;
-                    if !content.is_empty() {
-                        let _ = tx.send(content.to_string()).await;
-                    }
-                }
-                Err(e) => {
-                    eprintln!("Stream error: {}", e);
-                }
-            }
-        }
-
-        Ok(())
-    }
-
-    async fn generate_with_tools(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        available_tools: &[String],
-        _tool_manager: Arc<ToolManager>,
-        _session_id: &str,
-        _user_id: &str,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let tools_info = if available_tools.is_empty() {
-            String::new()
-        } else {
-            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
-        };
-
-        let enhanced_prompt = format!("{}{}", prompt, tools_info);
-
-        let result = self
-            .client
-            .invoke(&enhanced_prompt)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        Ok(result)
-    }
-}
-
-pub struct AnthropicClient {
-    client: Claude,
-}
-
-impl AnthropicClient {
-    pub fn new(api_key: String) -> Self {
-        let client = Claude::default().with_api_key(api_key);
-        Self { client }
-    }
-}
-
-#[async_trait]
-impl LLMProvider for AnthropicClient {
-    async fn generate(
-        &self,
-        prompt: &str,
-        _config: &Value,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let result = self
-            .client
-            .invoke(prompt)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        Ok(result)
-    }
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        tx: mpsc::Sender<String>,
-    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-        let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
-        let mut stream = self
-            .client
-            .stream(&messages)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        while let Some(result) = stream.next().await {
-            match result {
-                Ok(chunk) => {
-                    let content = chunk.content;
-                    if !content.is_empty() {
-                        let _ = tx.send(content.to_string()).await;
-                    }
-                }
-                Err(e) => {
-                    eprintln!("Stream error: {}", e);
-                }
-            }
-        }
-
-        Ok(())
-    }
-
-    async fn generate_with_tools(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        available_tools: &[String],
-        _tool_manager: Arc<ToolManager>,
-        _session_id: &str,
-        _user_id: &str,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let tools_info = if available_tools.is_empty() {
-            String::new()
-        } else {
-            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
-        };
-
-        let enhanced_prompt = format!("{}{}", prompt, tools_info);
-
-        let result = self
-            .client
-            .invoke(&enhanced_prompt)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        Ok(result)
-    }
-}
-
-pub struct MockLLMProvider;
-
-impl MockLLMProvider {
-    pub fn new() -> Self {
-        Self
-    }
-}
-
-#[async_trait]
-impl LLMProvider for MockLLMProvider {
-    async fn generate(
-        &self,
-        prompt: &str,
-        _config: &Value,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        Ok(format!("Mock response to: {}", prompt))
-    }
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        tx: mpsc::Sender<String>,
-    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-        let response = format!("Mock stream response to: {}", prompt);
-        for word in response.split_whitespace() {
-            let _ = tx.send(format!("{} ", word)).await;
-            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-        }
-        Ok(())
-    }
-
-    async fn generate_with_tools(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        available_tools: &[String],
-        _tool_manager: Arc<ToolManager>,
-        _session_id: &str,
-        _user_id: &str,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let tools_list = if available_tools.is_empty() {
-            "no tools available".to_string()
-        } else {
-            available_tools.join(", ")
-        };
-        Ok(format!(
-            "Mock response with tools [{}] to: {}",
-            tools_list, prompt
-        ))
-    }
-}
258 src/llm/mod.rs
@@ -1,3 +1,257 @@
-pub mod llm_provider;
+use async_trait::async_trait;
+use futures::StreamExt;
+use langchain_rust::{
+    language_models::llm::LLM,
+    llm::{claude::Claude, openai::OpenAI},
+};
+use serde_json::Value;
+use std::sync::Arc;
+use tokio::sync::mpsc;

-pub use llm_provider::*;
+use crate::tools::ToolManager;
+
+#[async_trait]
+pub trait LLMProvider: Send + Sync {
+    async fn generate(
+        &self,
+        prompt: &str,
+        config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        config: &Value,
+        available_tools: &[String],
+        tool_manager: Arc<ToolManager>,
+        session_id: &str,
+        user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
+}
+
+pub struct OpenAIClient {
+    client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
+}
+
+impl OpenAIClient {
+    pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
+        Self { client }
+    }
+}
+
+#[async_trait]
+impl LLMProvider for OpenAIClient {
+    async fn generate(
+        &self,
+        prompt: &str,
+        _config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let result = self
+            .client
+            .invoke(prompt)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        Ok(result)
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
+        let mut stream = self
+            .client
+            .stream(&messages)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        while let Some(result) = stream.next().await {
+            match result {
+                Ok(chunk) => {
+                    let content = chunk.content;
+                    if !content.is_empty() {
+                        let _ = tx.send(content.to_string()).await;
+                    }
+                }
+                Err(e) => {
+                    eprintln!("Stream error: {}", e);
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        available_tools: &[String],
+        _tool_manager: Arc<ToolManager>,
+        _session_id: &str,
+        _user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let tools_info = if available_tools.is_empty() {
+            String::new()
+        } else {
+            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
+        };
+
+        let enhanced_prompt = format!("{}{}", prompt, tools_info);
+
+        let result = self
+            .client
+            .invoke(&enhanced_prompt)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        Ok(result)
+    }
+}
+
+pub struct AnthropicClient {
+    client: Claude,
+}
+
+impl AnthropicClient {
+    pub fn new(api_key: String) -> Self {
+        let client = Claude::default().with_api_key(api_key);
+        Self { client }
+    }
+}
+
+#[async_trait]
+impl LLMProvider for AnthropicClient {
+    async fn generate(
+        &self,
+        prompt: &str,
+        _config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let result = self
+            .client
+            .invoke(prompt)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        Ok(result)
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
+        let mut stream = self
+            .client
+            .stream(&messages)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        while let Some(result) = stream.next().await {
+            match result {
+                Ok(chunk) => {
+                    let content = chunk.content;
+                    if !content.is_empty() {
+                        let _ = tx.send(content.to_string()).await;
+                    }
+                }
+                Err(e) => {
+                    eprintln!("Stream error: {}", e);
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        available_tools: &[String],
+        _tool_manager: Arc<ToolManager>,
+        _session_id: &str,
+        _user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let tools_info = if available_tools.is_empty() {
+            String::new()
+        } else {
+            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
+        };
+
+        let enhanced_prompt = format!("{}{}", prompt, tools_info);
+
+        let result = self
+            .client
+            .invoke(&enhanced_prompt)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        Ok(result)
+    }
+}
+
+pub struct MockLLMProvider;
+
+impl MockLLMProvider {
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+#[async_trait]
+impl LLMProvider for MockLLMProvider {
+    async fn generate(
+        &self,
+        prompt: &str,
+        _config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        Ok(format!("Mock response to: {}", prompt))
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let response = format!("Mock stream response to: {}", prompt);
+        for word in response.split_whitespace() {
+            let _ = tx.send(format!("{} ", word)).await;
+            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+        }
+        Ok(())
+    }
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        available_tools: &[String],
+        _tool_manager: Arc<ToolManager>,
+        _session_id: &str,
+        _user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let tools_list = if available_tools.is_empty() {
+            "no tools available".to_string()
+        } else {
+            available_tools.join(", ")
+        };
+        Ok(format!(
+            "Mock response with tools [{}] to: {}",
+            tools_list, prompt
+        ))
+    }
+}
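A quick sketch of exercising the restored LLMProvider trait through the mock implementation. The driver below is illustrative and assumes the types above are in scope; the empty JSON config is an assumption.

// Hypothetical driver, not part of the commit.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let provider = MockLLMProvider::new();
    let config = serde_json::json!({});
    let reply = provider.generate("hello", &config).await?;
    println!("{}", reply); // Mock response to: hello
    Ok(())
}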
@@ -1,29 +0,0 @@
-use actix_web::{web, HttpResponse, Result};
-use serde_json::json;
-use crate::shared::state::AppState;
-use crate::shared::utils::azure_from_config;
-
-pub async fn health() -> Result<HttpResponse> {
-    Ok(HttpResponse::Ok().json(json!({"status": "healthy"})))
-}
-
-pub async fn chat_completions_local(
-    _data: web::Data<AppState>,
-    _payload: web::Json<serde_json::Value>,
-) -> Result<HttpResponse> {
-    Ok(HttpResponse::NotImplemented().json(json!({"error": "Local LLM not implemented"})))
-}
-
-pub async fn embeddings_local(
-    _data: web::Data<AppState>,
-    _payload: web::Json<serde_json::Value>,
-) -> Result<HttpResponse> {
-    Ok(HttpResponse::NotImplemented().json(json!({"error": "Local embeddings not implemented"})))
-}
-
-pub async fn generic_chat_completions(
-    _data: web::Data<AppState>,
-    _payload: web::Json<serde_json::Value>,
-) -> Result<HttpResponse> {
-    Ok(HttpResponse::NotImplemented().json(json!({"error": "Generic chat not implemented"})))
-}
116 src/llm_legacy/llm_azure.rs (new file)
@@ -0,0 +1,116 @@
+use log::info;
+
+use actix_web::{post, web, HttpRequest, HttpResponse, Result};
+use dotenv::dotenv;
+use regex::Regex;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::env;
+
+// OpenAI-compatible request/response structures
+#[derive(Debug, Serialize, Deserialize)]
+struct ChatMessage {
+    role: String,
+    content: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct ChatCompletionRequest {
+    model: String,
+    messages: Vec<ChatMessage>,
+    stream: Option<bool>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct ChatCompletionResponse {
+    id: String,
+    object: String,
+    created: u64,
+    model: String,
+    choices: Vec<Choice>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Choice {
+    message: ChatMessage,
+    finish_reason: String,
+}
+
+#[post("/azure/v1/chat/completions")]
+async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
+    // Always log raw POST data
+    if let Ok(body_str) = std::str::from_utf8(&body) {
+        info!("POST Data: {}", body_str);
+    } else {
+        info!("POST Data (binary): {:?}", body);
+    }
+
+    dotenv().ok();
+
+    // Environment variables
+    let azure_endpoint = env::var("AI_ENDPOINT")
+        .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
+    let azure_key = env::var("AI_KEY")
+        .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
+    let deployment_name = env::var("AI_LLM_MODEL")
+        .map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
+
+    // Construct Azure OpenAI URL
+    let url = format!(
+        "{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview",
+        azure_endpoint, deployment_name
+    );
+
+    // Forward headers
+    let mut headers = reqwest::header::HeaderMap::new();
+    headers.insert(
+        "api-key",
+        reqwest::header::HeaderValue::from_str(&azure_key)
+            .map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?,
+    );
+    headers.insert(
+        "Content-Type",
+        reqwest::header::HeaderValue::from_static("application/json"),
+    );
+
+    let body_str = std::str::from_utf8(&body).unwrap_or("");
+    info!("Original POST Data: {}", body_str);
+
+    // Remove the problematic params
+    let re =
+        Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
+    let cleaned = re.replace_all(body_str, "");
+    let cleaned_body = web::Bytes::from(cleaned.to_string());
+
+    info!("Cleaned POST Data: {}", cleaned);
+
+    // Send request to Azure
+    let client = Client::new();
+    let response = client
+        .post(&url)
+        .headers(headers)
+        .body(cleaned_body)
+        .send()
+        .await
+        .map_err(actix_web::error::ErrorInternalServerError)?;
+
+    // Handle response based on status
+    let status = response.status();
+    let raw_response = response
+        .text()
+        .await
+        .map_err(actix_web::error::ErrorInternalServerError)?;
+
+    // Log the raw response
+    info!("Raw Azure response: {}", raw_response);
+
+    if status.is_success() {
+        Ok(HttpResponse::Ok().body(raw_response))
+    } else {
+        // Handle error responses properly
+        let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
+            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
+
+        Ok(HttpResponse::build(actix_status).body(raw_response))
+    }
+}
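To see what the parameter-stripping regex above actually removes, a small standalone sketch; the sample payload is hypothetical.

// Demonstrates the cleaning regex from the Azure proxy handler.
fn main() {
    let body = r#"{"model":"gpt-4","max_completion_tokens":100,"messages":[]}"#;
    let re = regex::Regex::new(
        r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#,
    )
    .unwrap();
    // Prints: {"model":"gpt-4","messages":[]}
    println!("{}", re.replace_all(body, ""));
}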
@@ -1,11 +1,13 @@
-use log::{error, info};
-
 use actix_web::{post, web, HttpRequest, HttpResponse, Result};
 use dotenv::dotenv;
+use log::{error, info};
 use regex::Regex;
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
 use std::env;

+// OpenAI-compatible request/response structures
 #[derive(Debug, Serialize, Deserialize)]
 struct ChatMessage {
     role: String,
@@ -35,17 +37,20 @@ struct Choice {
 }

 fn clean_request_body(body: &str) -> String {
+    // Remove problematic parameters that might not be supported by all providers
     let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
     re.replace_all(body, "").to_string()
 }

 #[post("/v1/chat/completions")]
 pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
+    // Log raw POST data
     let body_str = std::str::from_utf8(&body).unwrap_or_default();
     info!("Original POST Data: {}", body_str);

     dotenv().ok();

+    // Get environment variables
     let api_key = env::var("AI_KEY")
         .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
     let model = env::var("AI_LLM_MODEL")
@@ -53,9 +58,11 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
     let endpoint = env::var("AI_ENDPOINT")
         .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;

+    // Parse and modify the request body
     let mut json_value: serde_json::Value = serde_json::from_str(body_str)
         .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;

+    // Add model parameter
     if let Some(obj) = json_value.as_object_mut() {
         obj.insert("model".to_string(), serde_json::Value::String(model));
     }
@@ -65,6 +72,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re

     info!("Modified POST Data: {}", modified_body_str);

+    // Set up headers
     let mut headers = reqwest::header::HeaderMap::new();
     headers.insert(
         "Authorization",
@@ -76,6 +84,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
         reqwest::header::HeaderValue::from_static("application/json"),
     );

+    // Send request to the AI provider
     let client = Client::new();
     let response = client
         .post(&endpoint)
@@ -85,6 +94,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
         .await
         .map_err(actix_web::error::ErrorInternalServerError)?;

+    // Handle response
     let status = response.status();
     let raw_response = response
         .text()
@@ -94,6 +104,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
     info!("Provider response status: {}", status);
     info!("Provider response body: {}", raw_response);

+    // Convert response to OpenAI format if successful
     if status.is_success() {
         match convert_to_openai_format(&raw_response) {
             Ok(openai_response) => Ok(HttpResponse::Ok()
@@ -101,12 +112,14 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
                 .body(openai_response)),
             Err(e) => {
                 error!("Failed to convert response format: {}", e);
+                // Return the original response if conversion fails
                 Ok(HttpResponse::Ok()
                     .content_type("application/json")
                     .body(raw_response))
             }
         }
     } else {
+        // Return error as-is
         let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
             .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

@@ -116,6 +129,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
     }
 }

+/// Converts provider response to OpenAI-compatible format
 fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
     #[derive(serde::Deserialize)]
     struct ProviderChoice {
@@ -177,8 +191,10 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
         total_tokens: u32,
     }

+    // Parse the provider response
     let provider: ProviderResponse = serde_json::from_str(provider_response)?;

+    // Extract content from the first choice
     let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
     let content = first_choice.message.content.clone();
     let role = first_choice
@@ -187,6 +203,7 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
         .clone()
         .unwrap_or_else(|| "assistant".to_string());

+    // Calculate token usage
     let usage = provider.usage.unwrap_or_default();
     let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
     let completion_tokens = usage
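The model-injection step in the handler above can be seen in isolation in this small sketch; the payload and model name are hypothetical.

// Injects a "model" field into an incoming OpenAI-style JSON body.
fn main() {
    let mut json_value: serde_json::Value =
        serde_json::from_str(r#"{"messages":[{"role":"user","content":"hi"}]}"#).unwrap();
    if let Some(obj) = json_value.as_object_mut() {
        obj.insert(
            "model".to_string(),
            serde_json::Value::String("gpt-4o".to_string()),
        );
    }
    println!("{}", json_value); // {"messages":[...],"model":"gpt-4o"}
}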
@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
 use std::env;
 use tokio::time::{sleep, Duration};

+// OpenAI-compatible request/response structures
 #[derive(Debug, Serialize, Deserialize)]
 struct ChatMessage {
     role: String,
@@ -34,6 +35,7 @@ struct Choice {
     finish_reason: String,
 }

+// Llama.cpp server request/response structures
 #[derive(Debug, Serialize, Deserialize)]
 struct LlamaCppRequest {
     prompt: String,
@@ -51,7 +53,8 @@ struct LlamaCppResponse {
     generation_settings: Option<serde_json::Value>,
 }

-pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
+{
     let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());

     if llm_local.to_lowercase() != "true" {
@@ -59,6 +62,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
         return Ok(());
     }

+    // Get configuration from environment variables
     let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
     let embedding_url =
         env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@@ -73,6 +77,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
     info!(" LLM Model: {}", llm_model_path);
     info!(" Embedding Model: {}", embedding_model_path);

+    // Check if servers are already running
     let llm_running = is_server_running(&llm_url).await;
     let embedding_running = is_server_running(&embedding_url).await;

@@ -81,6 +86,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
         return Ok(());
     }

+    // Start servers that aren't running
     let mut tasks = vec![];

     if !llm_running && !llm_model_path.is_empty() {
@@ -105,17 +111,19 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
         info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
     }

+    // Wait for all server startup tasks
     for task in tasks {
         task.await??;
     }

+    // Wait for servers to be ready with verbose logging
     info!("⏳ Waiting for servers to become ready...");

     let mut llm_ready = llm_running || llm_model_path.is_empty();
     let mut embedding_ready = embedding_running || embedding_model_path.is_empty();

     let mut attempts = 0;
-    let max_attempts = 60;
+    let max_attempts = 60; // 2 minutes total

     while attempts < max_attempts && (!llm_ready || !embedding_ready) {
         sleep(Duration::from_secs(2)).await;
@@ -180,6 +188,8 @@ async fn start_llm_server(
     std::env::set_var("OMP_PLACES", "cores");
     std::env::set_var("OMP_PROC_BIND", "close");

+    // "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
+
     let mut cmd = tokio::process::Command::new("sh");
     cmd.arg("-c").arg(format!(
         "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
@@ -215,6 +225,7 @@ async fn is_server_running(url: &str) -> bool {
     }
 }

+// Convert OpenAI chat messages to a single prompt
 fn messages_to_prompt(messages: &[ChatMessage]) -> String {
     let mut prompt = String::new();

@@ -239,28 +250,32 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String {
     prompt
 }

+// Proxy endpoint
 #[post("/local/v1/chat/completions")]
 pub async fn chat_completions_local(
     req_body: web::Json<ChatCompletionRequest>,
     _req: HttpRequest,
 ) -> Result<HttpResponse> {
-    dotenv().ok();
+    dotenv().ok().unwrap();

+    // Get llama.cpp server URL
     let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());

+    // Convert OpenAI format to llama.cpp format
     let prompt = messages_to_prompt(&req_body.messages);

     let llama_request = LlamaCppRequest {
         prompt,
-        n_predict: Some(500),
+        n_predict: Some(500), // Adjust as needed
         temperature: Some(0.7),
         top_k: Some(40),
         top_p: Some(0.9),
         stream: req_body.stream,
     };

+    // Send request to llama.cpp server
     let client = Client::builder()
-        .timeout(Duration::from_secs(120))
+        .timeout(Duration::from_secs(120)) // 2 minute timeout
         .build()
         .map_err(|e| {
             error!("Error creating HTTP client: {}", e);
@@ -286,6 +301,7 @@ pub async fn chat_completions_local(
         actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
     })?;

+    // Convert llama.cpp response to OpenAI format
     let openai_response = ChatCompletionResponse {
         id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
         object: "chat.completion".to_string(),
@@ -328,6 +344,7 @@ pub async fn chat_completions_local(
     }
 }

+// OpenAI Embedding Request - Modified to handle both string and array inputs
 #[derive(Debug, Deserialize)]
 pub struct EmbeddingRequest {
     #[serde(deserialize_with = "deserialize_input")]
@@ -337,6 +354,7 @@ pub struct EmbeddingRequest {
     pub _encoding_format: Option<String>,
 }

+// Custom deserializer to handle both string and array inputs
 fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
 where
     D: serde::Deserializer<'de>,
@@ -382,6 +400,7 @@ where
     deserializer.deserialize_any(InputVisitor)
 }

+// OpenAI Embedding Response
 #[derive(Debug, Serialize)]
 pub struct EmbeddingResponse {
     pub object: String,
@@ -403,17 +422,20 @@ pub struct Usage {
     pub total_tokens: u32,
 }

+// Llama.cpp Embedding Request
 #[derive(Debug, Serialize)]
 struct LlamaCppEmbeddingRequest {
     pub content: String,
 }

+// FIXED: Handle the stupid nested array format
 #[derive(Debug, Deserialize)]
 struct LlamaCppEmbeddingResponseItem {
     pub index: usize,
-    pub embedding: Vec<Vec<f32>>,
+    pub embedding: Vec<Vec<f32>>, // This is the messed-up part - embedding is an array of arrays
 }

+// Proxy endpoint for embeddings
 #[post("/v1/embeddings")]
 pub async fn embeddings_local(
     req_body: web::Json<EmbeddingRequest>,
@@ -421,6 +443,7 @@ pub async fn embeddings_local(
 ) -> Result<HttpResponse> {
     dotenv().ok();

+    // Get llama.cpp server URL
     let llama_url =
         env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());

@@ -432,6 +455,7 @@ pub async fn embeddings_local(
         actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
     })?;

+    // Process each input text and get embeddings
     let mut embeddings_data = Vec::new();
     let mut total_tokens = 0;

@@ -456,11 +480,13 @@ pub async fn embeddings_local(
     let status = response.status();

     if status.is_success() {
+        // First, get the raw response text for debugging
         let raw_response = response.text().await.map_err(|e| {
             error!("Error reading response text: {}", e);
             actix_web::error::ErrorInternalServerError("Failed to read response")
         })?;

+        // Parse the response as a vector of items with nested arrays
         let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
             serde_json::from_str(&raw_response).map_err(|e| {
                 error!("Error parsing llama.cpp embedding response: {}", e);
@@ -470,13 +496,17 @@ pub async fn embeddings_local(
             )
         })?;

+        // Extract the embedding from the nested array bullshit
         if let Some(item) = llama_response.get(0) {
+            // The embedding field contains Vec<Vec<f32>>, so we need to flatten it
+            // If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
             let flattened_embedding = if !item.embedding.is_empty() {
-                item.embedding[0].clone()
+                item.embedding[0].clone() // Take the first (and probably only) inner array
             } else {
-                vec![]
+                vec![] // Empty if no embedding data
             };

+            // Estimate token count
             let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
             total_tokens += estimated_tokens;

@@ -514,6 +544,7 @@ pub async fn embeddings_local(
         }
     }

+    // Build OpenAI-compatible response
     let openai_response = EmbeddingResponse {
         object: "list".to_string(),
         data: embeddings_data,
@@ -527,6 +558,7 @@ pub async fn embeddings_local(
     Ok(HttpResponse::Ok().json(openai_response))
 }

+// Health check endpoint
 #[actix_web::get("/health")]
 pub async fn health() -> Result<HttpResponse> {
     let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
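The nested-array flattening at the heart of the embeddings handler can be shown standalone; the sample vector is hypothetical.

// llama.cpp returns [[...]] per item; the handler keeps only the first inner array.
fn main() {
    let nested: Vec<Vec<f32>> = vec![vec![0.1, 0.2, 0.3]];
    let flattened: Vec<f32> = nested.into_iter().next().unwrap_or_default();
    assert_eq!(flattened, vec![0.1, 0.2, 0.3]);
}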
@@ -1,3 +1,3 @@
-pub mod llm;
+pub mod llm_azure;
 pub mod llm_generic;
 pub mod llm_local;
97 src/main.rs
@@ -1,4 +1,5 @@
+use actix_cors::Cors;
 use actix_web::middleware::Logger;
 use actix_web::{web, App, HttpServer};
 use dotenv::dotenv;
 use log::info;
@@ -23,23 +24,30 @@ mod tools;
 mod web_automation;
 mod whatsapp;

+use crate::automation::AutomationService;
 use crate::bot::{
     create_session, get_session_history, get_sessions, index, set_mode_handler, static_files,
     voice_start, voice_stop, websocket_handler, whatsapp_webhook, whatsapp_webhook_verify,
 };
 use crate::channels::{VoiceAdapter, WebChannelAdapter};
 use crate::config::AppConfig;
-use crate::email::send_email;
+use crate::email::{
+    get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
+};
 use crate::file::{list_file, upload_file};
+use crate::llm_legacy::llm_generic::generic_chat_completions;
+use crate::llm_legacy::llm_local::{
+    chat_completions_local, embeddings_local, ensure_llama_servers_running,
+};
 use crate::shared::AppState;
 use crate::whatsapp::WhatsAppAdapter;

 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
     dotenv().ok();
-    env_logger::init();
+    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();

-    info!("Starting BotServer...");
+    info!("Starting General Bots 6.0...");

     let config = AppConfig::from_env();

@@ -59,17 +67,17 @@ async fn main() -> std::io::Result<()> {
     };

     // Optional custom database pool
-    let db_custom_pool: Option<sqlx::Pool<sqlx::Postgres>> =
-        match sqlx::postgres::PgPool::connect(&config.database_custom_url()).await {
-            Ok(pool) => {
-                info!("Connected to custom database");
-                Some(pool)
-            }
-            Err(e) => {
-                log::warn!("Failed to connect to custom database: {}", e);
-                None
-            }
-        };
+    let db_custom_pool = match sqlx::postgres::PgPool::connect(&config.database_custom_url()).await
+    {
+        Ok(pool) => {
+            info!("Connected to custom database");
+            Some(pool)
+        }
+        Err(e) => {
+            log::warn!("Failed to connect to custom database: {}", e);
+            None
+        }
+    };

     // Optional Redis client
     let redis_client = match redis::Client::open("redis://127.0.0.1/") {
@@ -83,9 +91,28 @@ async fn main() -> std::io::Result<()> {
         }
     };

-    // Placeholder for MinIO (not yet implemented)
-    let minio_client = None;
+    // Initialize MinIO client
+    let minio_client = file::init_minio(&config)
+        .await
+        .expect("Failed to initialize Minio");
+
+    // Initialize browser pool
+    let browser_pool = Arc::new(web_automation::BrowserPool::new(
+        "chrome".to_string(),
+        2,
+        "headless".to_string(),
+    ));
+
+    // Initialize LLM servers
+    ensure_llama_servers_running()
+        .await
+        .expect("Failed to initialize LLM local server.");
+
+    web_automation::initialize_browser_pool()
+        .await
+        .expect("Failed to initialize browser pool");

     // Initialize services from new architecture
     let auth_service = auth::AuthService::new(db_pool.clone(), redis_client.clone());
     let session_manager = session::SessionManager::new(db_pool.clone(), redis_client.clone());
@@ -110,20 +137,13 @@ async fn main() -> std::io::Result<()> {

     let tool_api = Arc::new(tools::ToolApi::new());

-    // Browser pool – constructed with the required three arguments.
-    // Adjust the strings as needed for your environment.
-    let browser_pool = Arc::new(web_automation::BrowserPool::new(
-        "chrome".to_string(),
-        2,
-        "headless".to_string(),
-    ));
-
     // Create unified app state
     let app_state = AppState {
-        minio_client,
+        minio_client: Some(minio_client),
         config: Some(config.clone()),
         db: Some(db_pool.clone()),
-        db_custom: db_custom_pool,
-        browser_pool,
+        db_custom: db_custom_pool.clone(),
+        browser_pool: browser_pool.clone(),
         orchestrator: Arc::new(orchestrator),
         web_adapter,
         voice_adapter,
@@ -131,6 +151,11 @@ async fn main() -> std::io::Result<()> {
         tool_api,
     };

+    // Start automation service in background
+    let automation_state = app_state.clone();
+    let automation = AutomationService::new(automation_state, "src/prompts");
+    let _automation_handle = automation.spawn();
+
     info!(
         "Starting server on {}:{}",
         config.server.host, config.server.port
@@ -145,7 +170,22 @@ async fn main() -> std::io::Result<()> {

         App::new()
             .wrap(cors)
-            .wrap(Logger::default())
+            .wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
             .app_data(web::Data::new(app_state.clone()))
+            // Legacy services
+            .service(upload_file)
+            .service(list_file)
+            .service(save_click)
+            .service(get_emails)
+            .service(list_emails)
+            .service(send_email)
+            .service(chat_completions_local)
+            .service(save_draft)
+            .service(generic_chat_completions)
+            .service(embeddings_local)
+            .service(get_latest_email_from)
+            // New bot services
             .service(index)
             .service(static_files)
             .service(websocket_handler)
@@ -157,9 +197,6 @@ async fn main() -> std::io::Result<()> {
             .service(get_sessions)
             .service(get_session_history)
             .service(set_mode_handler)
-            .service(send_email)
-            .service(upload_file)
-            .service(list_file)
     })
     .bind((config.server.host.clone(), config.server.port))?
     .run()
@@ -1,4 +1,3 @@
-use chrono::Utc;
 use langchain_rust::llm::AzureConfig;
 use log::{debug, warn};
 use rhai::{Array, Dynamic};
@@ -1,7 +1,7 @@
 <!doctype html>
 <html>
 <head>
-    <title>General Bots - ChatGPT Clone</title>
+    <title>General Bots</title>
     <style>
         * {
             margin: 0;