Warning removal and restore of old code.
This commit is contained in:
parent 76dc51e980
commit 8a9cd104d6
11 changed files with 522 additions and 339 deletions
@@ -1,15 +1,17 @@
 use async_trait::async_trait;
-use chrono::Utc;
 use log::info;
 use std::collections::HashMap;
 use std::sync::Arc;
 use tokio::sync::{mpsc, Mutex};

-use crate::shared::{BotResponse, UserMessage};
+use crate::shared::BotResponse;

 #[async_trait]
 pub trait ChannelAdapter: Send + Sync {
-    async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
+    async fn send_message(
+        &self,
+        response: BotResponse,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
 }

 pub struct WebChannelAdapter {
@@ -34,7 +36,10 @@ impl WebChannelAdapter {

 #[async_trait]
 impl ChannelAdapter for WebChannelAdapter {
-    async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    async fn send_message(
+        &self,
+        response: BotResponse,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         let connections = self.connections.lock().await;
         if let Some(tx) = connections.get(&response.session_id) {
             tx.send(response).await?;
@@ -67,11 +72,17 @@ impl VoiceAdapter {
         session_id: &str,
         user_id: &str,
     ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        info!("Starting voice session for user: {} with session: {}", user_id, session_id);
+        info!(
+            "Starting voice session for user: {} with session: {}",
+            user_id, session_id
+        );

         let token = format!("mock_token_{}_{}", session_id, user_id);
-        self.rooms.lock().await.insert(session_id.to_string(), token.clone());
+        self.rooms
+            .lock()
+            .await
+            .insert(session_id.to_string(), token.clone());

         Ok(token)
     }
@@ -99,7 +110,10 @@ impl VoiceAdapter {

 #[async_trait]
 impl ChannelAdapter for VoiceAdapter {
-    async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    async fn send_message(
+        &self,
+        response: BotResponse,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         info!("Sending voice response to: {}", response.user_id);
         self.send_voice_response(&response.session_id, &response.content)
             .await
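For reference, a minimal sketch (not part of the commit) of implementing the reformatted ChannelAdapter trait for a new channel; LogChannelAdapter is hypothetical and the import paths assume the crate layout shown in this diff:

use async_trait::async_trait;

use crate::channels::ChannelAdapter; // assumed path
use crate::shared::BotResponse;

pub struct LogChannelAdapter;

#[async_trait]
impl ChannelAdapter for LogChannelAdapter {
    async fn send_message(
        &self,
        response: BotResponse,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // No real transport here: just log the outgoing payload.
        println!("[{}] {}", response.session_id, response.content);
        Ok(())
    }
}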
@@ -1,257 +0,0 @@
-use async_trait::async_trait;
-use futures::StreamExt;
-use langchain_rust::{
-    language_models::llm::LLM,
-    llm::{claude::Claude, openai::OpenAI},
-};
-use serde_json::Value;
-use std::sync::Arc;
-use tokio::sync::mpsc;
-
-use crate::tools::ToolManager;
-
-#[async_trait]
-pub trait LLMProvider: Send + Sync {
-    async fn generate(
-        &self,
-        prompt: &str,
-        config: &Value,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        config: &Value,
-        tx: mpsc::Sender<String>,
-    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
-
-    async fn generate_with_tools(
-        &self,
-        prompt: &str,
-        config: &Value,
-        available_tools: &[String],
-        tool_manager: Arc<ToolManager>,
-        session_id: &str,
-        user_id: &str,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
-}
-
-pub struct OpenAIClient {
-    client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
-}
-
-impl OpenAIClient {
-    pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
-        Self { client }
-    }
-}
-
-#[async_trait]
-impl LLMProvider for OpenAIClient {
-    async fn generate(
-        &self,
-        prompt: &str,
-        _config: &Value,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let result = self
-            .client
-            .invoke(prompt)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        Ok(result)
-    }
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        tx: mpsc::Sender<String>,
-    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-        let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
-        let mut stream = self
-            .client
-            .stream(&messages)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        while let Some(result) = stream.next().await {
-            match result {
-                Ok(chunk) => {
-                    let content = chunk.content;
-                    if !content.is_empty() {
-                        let _ = tx.send(content.to_string()).await;
-                    }
-                }
-                Err(e) => {
-                    eprintln!("Stream error: {}", e);
-                }
-            }
-        }
-
-        Ok(())
-    }
-
-    async fn generate_with_tools(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        available_tools: &[String],
-        _tool_manager: Arc<ToolManager>,
-        _session_id: &str,
-        _user_id: &str,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let tools_info = if available_tools.is_empty() {
-            String::new()
-        } else {
-            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
-        };
-
-        let enhanced_prompt = format!("{}{}", prompt, tools_info);
-
-        let result = self
-            .client
-            .invoke(&enhanced_prompt)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        Ok(result)
-    }
-}
-
-pub struct AnthropicClient {
-    client: Claude,
-}
-
-impl AnthropicClient {
-    pub fn new(api_key: String) -> Self {
-        let client = Claude::default().with_api_key(api_key);
-        Self { client }
-    }
-}
-
-#[async_trait]
-impl LLMProvider for AnthropicClient {
-    async fn generate(
-        &self,
-        prompt: &str,
-        _config: &Value,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let result = self
-            .client
-            .invoke(prompt)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        Ok(result)
-    }
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        tx: mpsc::Sender<String>,
-    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-        let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
-        let mut stream = self
-            .client
-            .stream(&messages)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        while let Some(result) = stream.next().await {
-            match result {
-                Ok(chunk) => {
-                    let content = chunk.content;
-                    if !content.is_empty() {
-                        let _ = tx.send(content.to_string()).await;
-                    }
-                }
-                Err(e) => {
-                    eprintln!("Stream error: {}", e);
-                }
-            }
-        }
-
-        Ok(())
-    }
-
-    async fn generate_with_tools(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        available_tools: &[String],
-        _tool_manager: Arc<ToolManager>,
-        _session_id: &str,
-        _user_id: &str,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let tools_info = if available_tools.is_empty() {
-            String::new()
-        } else {
-            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
-        };
-
-        let enhanced_prompt = format!("{}{}", prompt, tools_info);
-
-        let result = self
-            .client
-            .invoke(&enhanced_prompt)
-            .await
-            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
-
-        Ok(result)
-    }
-}
-
-pub struct MockLLMProvider;
-
-impl MockLLMProvider {
-    pub fn new() -> Self {
-        Self
-    }
-}
-
-#[async_trait]
-impl LLMProvider for MockLLMProvider {
-    async fn generate(
-        &self,
-        prompt: &str,
-        _config: &Value,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        Ok(format!("Mock response to: {}", prompt))
-    }
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        tx: mpsc::Sender<String>,
-    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-        let response = format!("Mock stream response to: {}", prompt);
-        for word in response.split_whitespace() {
-            let _ = tx.send(format!("{} ", word)).await;
-            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-        }
-        Ok(())
-    }
-
-    async fn generate_with_tools(
-        &self,
-        prompt: &str,
-        _config: &Value,
-        available_tools: &[String],
-        _tool_manager: Arc<ToolManager>,
-        _session_id: &str,
-        _user_id: &str,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let tools_list = if available_tools.is_empty() {
-            "no tools available".to_string()
-        } else {
-            available_tools.join(", ")
-        };
-        Ok(format!(
-            "Mock response with tools [{}] to: {}",
-            tools_list, prompt
-        ))
-    }
-}
258 src/llm/mod.rs
@@ -1,3 +1,257 @@
-pub mod llm_provider;
-
-pub use llm_provider::*;
+use async_trait::async_trait;
+use futures::StreamExt;
+use langchain_rust::{
+    language_models::llm::LLM,
+    llm::{claude::Claude, openai::OpenAI},
+};
+use serde_json::Value;
+use std::sync::Arc;
+use tokio::sync::mpsc;
+
+use crate::tools::ToolManager;
+
+#[async_trait]
+pub trait LLMProvider: Send + Sync {
+    async fn generate(
+        &self,
+        prompt: &str,
+        config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        config: &Value,
+        available_tools: &[String],
+        tool_manager: Arc<ToolManager>,
+        session_id: &str,
+        user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
+}
+
+pub struct OpenAIClient {
+    client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
+}
+
+impl OpenAIClient {
+    pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
+        Self { client }
+    }
+}
+
+#[async_trait]
+impl LLMProvider for OpenAIClient {
+    async fn generate(
+        &self,
+        prompt: &str,
+        _config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let result = self
+            .client
+            .invoke(prompt)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        Ok(result)
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
+        let mut stream = self
+            .client
+            .stream(&messages)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        while let Some(result) = stream.next().await {
+            match result {
+                Ok(chunk) => {
+                    let content = chunk.content;
+                    if !content.is_empty() {
+                        let _ = tx.send(content.to_string()).await;
+                    }
+                }
+                Err(e) => {
+                    eprintln!("Stream error: {}", e);
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        available_tools: &[String],
+        _tool_manager: Arc<ToolManager>,
+        _session_id: &str,
+        _user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let tools_info = if available_tools.is_empty() {
+            String::new()
+        } else {
+            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
+        };
+
+        let enhanced_prompt = format!("{}{}", prompt, tools_info);
+
+        let result = self
+            .client
+            .invoke(&enhanced_prompt)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        Ok(result)
+    }
+}
+
+pub struct AnthropicClient {
+    client: Claude,
+}
+
+impl AnthropicClient {
+    pub fn new(api_key: String) -> Self {
+        let client = Claude::default().with_api_key(api_key);
+        Self { client }
+    }
+}
+
+#[async_trait]
+impl LLMProvider for AnthropicClient {
+    async fn generate(
+        &self,
+        prompt: &str,
+        _config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let result = self
+            .client
+            .invoke(prompt)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        Ok(result)
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
+        let mut stream = self
+            .client
+            .stream(&messages)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        while let Some(result) = stream.next().await {
+            match result {
+                Ok(chunk) => {
+                    let content = chunk.content;
+                    if !content.is_empty() {
+                        let _ = tx.send(content.to_string()).await;
+                    }
+                }
+                Err(e) => {
+                    eprintln!("Stream error: {}", e);
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        available_tools: &[String],
+        _tool_manager: Arc<ToolManager>,
+        _session_id: &str,
+        _user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let tools_info = if available_tools.is_empty() {
+            String::new()
+        } else {
+            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
+        };
+
+        let enhanced_prompt = format!("{}{}", prompt, tools_info);
+
+        let result = self
+            .client
+            .invoke(&enhanced_prompt)
+            .await
+            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+
+        Ok(result)
+    }
+}
+
+pub struct MockLLMProvider;
+
+impl MockLLMProvider {
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+#[async_trait]
+impl LLMProvider for MockLLMProvider {
+    async fn generate(
+        &self,
+        prompt: &str,
+        _config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        Ok(format!("Mock response to: {}", prompt))
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let response = format!("Mock stream response to: {}", prompt);
+        for word in response.split_whitespace() {
+            let _ = tx.send(format!("{} ", word)).await;
+            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+        }
+        Ok(())
+    }
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        available_tools: &[String],
+        _tool_manager: Arc<ToolManager>,
+        _session_id: &str,
+        _user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let tools_list = if available_tools.is_empty() {
+            "no tools available".to_string()
+        } else {
+            available_tools.join(", ")
+        };
+        Ok(format!(
+            "Mock response with tools [{}] to: {}",
+            tools_list, prompt
+        ))
+    }
+}
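A minimal sketch, not part of the commit, of exercising the restored LLMProvider trait via MockLLMProvider; the crate::llm import path is an assumption based on the module layout above:

use serde_json::json;
use tokio::sync::mpsc;

use crate::llm::{LLMProvider, MockLLMProvider}; // assumed path

async fn demo() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let provider = MockLLMProvider::new();

    // Plain generation returns a canned string.
    let reply = provider.generate("hello", &json!({})).await?;
    assert_eq!(reply, "Mock response to: hello");

    // Streaming sends one word at a time; the buffer is large enough that the
    // mock's five sends never block before we start receiving.
    let (tx, mut rx) = mpsc::channel(16);
    provider.generate_stream("hello", &json!({}), tx).await?;
    while let Some(chunk) = rx.recv().await {
        print!("{}", chunk);
    }
    Ok(())
}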
@@ -1,29 +0,0 @@
-use actix_web::{web, HttpResponse, Result};
-use serde_json::json;
-use crate::shared::state::AppState;
-use crate::shared::utils::azure_from_config;
-
-pub async fn health() -> Result<HttpResponse> {
-    Ok(HttpResponse::Ok().json(json!({"status": "healthy"})))
-}
-
-pub async fn chat_completions_local(
-    _data: web::Data<AppState>,
-    _payload: web::Json<serde_json::Value>,
-) -> Result<HttpResponse> {
-    Ok(HttpResponse::NotImplemented().json(json!({"error": "Local LLM not implemented"})))
-}
-
-pub async fn embeddings_local(
-    _data: web::Data<AppState>,
-    _payload: web::Json<serde_json::Value>,
-) -> Result<HttpResponse> {
-    Ok(HttpResponse::NotImplemented().json(json!({"error": "Local embeddings not implemented"})))
-}
-
-pub async fn generic_chat_completions(
-    _data: web::Data<AppState>,
-    _payload: web::Json<serde_json::Value>,
-) -> Result<HttpResponse> {
-    Ok(HttpResponse::NotImplemented().json(json!({"error": "Generic chat not implemented"})))
-}
116 src/llm_legacy/llm_azure.rs Normal file
@@ -0,0 +1,116 @@
+use log::info;
+
+use actix_web::{post, web, HttpRequest, HttpResponse, Result};
+use dotenv::dotenv;
+use regex::Regex;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::env;
+
+// OpenAI-compatible request/response structures
+#[derive(Debug, Serialize, Deserialize)]
+struct ChatMessage {
+    role: String,
+    content: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct ChatCompletionRequest {
+    model: String,
+    messages: Vec<ChatMessage>,
+    stream: Option<bool>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct ChatCompletionResponse {
+    id: String,
+    object: String,
+    created: u64,
+    model: String,
+    choices: Vec<Choice>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Choice {
+    message: ChatMessage,
+    finish_reason: String,
+}
+
+#[post("/azure/v1/chat/completions")]
+async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
+    // Always log raw POST data
+    if let Ok(body_str) = std::str::from_utf8(&body) {
+        info!("POST Data: {}", body_str);
+    } else {
+        info!("POST Data (binary): {:?}", body);
+    }
+
+    dotenv().ok();
+
+    // Environment variables
+    let azure_endpoint = env::var("AI_ENDPOINT")
+        .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
+    let azure_key = env::var("AI_KEY")
+        .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
+    let deployment_name = env::var("AI_LLM_MODEL")
+        .map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
+
+    // Construct Azure OpenAI URL
+    let url = format!(
+        "{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview",
+        azure_endpoint, deployment_name
+    );
+
+    // Forward headers
+    let mut headers = reqwest::header::HeaderMap::new();
+    headers.insert(
+        "api-key",
+        reqwest::header::HeaderValue::from_str(&azure_key)
+            .map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?,
+    );
+    headers.insert(
+        "Content-Type",
+        reqwest::header::HeaderValue::from_static("application/json"),
+    );
+
+    let body_str = std::str::from_utf8(&body).unwrap_or("");
+    info!("Original POST Data: {}", body_str);
+
+    // Remove the problematic params
+    let re =
+        Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
+    let cleaned = re.replace_all(body_str, "");
+    let cleaned_body = web::Bytes::from(cleaned.to_string());
+
+    info!("Cleaned POST Data: {}", cleaned);
+
+    // Send request to Azure
+    let client = Client::new();
+    let response = client
+        .post(&url)
+        .headers(headers)
+        .body(cleaned_body)
+        .send()
+        .await
+        .map_err(actix_web::error::ErrorInternalServerError)?;
+
+    // Handle response based on status
+    let status = response.status();
+    let raw_response = response
+        .text()
+        .await
+        .map_err(actix_web::error::ErrorInternalServerError)?;
+
+    // Log the raw response
+    info!("Raw Azure response: {}", raw_response);
+
+    if status.is_success() {
+        Ok(HttpResponse::Ok().body(raw_response))
+    } else {
+        // Handle error responses properly
+        let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
+            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
+
+        Ok(HttpResponse::build(actix_status).body(raw_response))
+    }
+}
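A quick sketch, not part of the commit, of what the parameter-stripping regex above does to a request body; the sample JSON is made up:

use regex::Regex;

fn main() {
    let body = r#"{"model":"gpt-4","max_completion_tokens":256,"parallel_tool_calls":true,"stream":false}"#;
    let re =
        Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
    // Both unsupported keys (and their leading commas) are dropped before the
    // body is forwarded to Azure.
    assert_eq!(re.replace_all(body, ""), r#"{"model":"gpt-4","stream":false}"#);
}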
@@ -1,11 +1,13 @@
+use log::{error, info};
+
 use actix_web::{post, web, HttpRequest, HttpResponse, Result};
 use dotenv::dotenv;
-use log::{error, info};
 use regex::Regex;
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
 use std::env;

+// OpenAI-compatible request/response structures
 #[derive(Debug, Serialize, Deserialize)]
 struct ChatMessage {
     role: String,
@@ -35,17 +37,20 @@ struct Choice {
 }

 fn clean_request_body(body: &str) -> String {
+    // Remove problematic parameters that might not be supported by all providers
     let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
     re.replace_all(body, "").to_string()
 }

 #[post("/v1/chat/completions")]
 pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
+    // Log raw POST data
     let body_str = std::str::from_utf8(&body).unwrap_or_default();
     info!("Original POST Data: {}", body_str);

     dotenv().ok();

+    // Get environment variables
     let api_key = env::var("AI_KEY")
         .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
     let model = env::var("AI_LLM_MODEL")
@@ -53,9 +58,11 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
     let endpoint = env::var("AI_ENDPOINT")
         .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;

+    // Parse and modify the request body
     let mut json_value: serde_json::Value = serde_json::from_str(body_str)
         .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;

+    // Add model parameter
     if let Some(obj) = json_value.as_object_mut() {
         obj.insert("model".to_string(), serde_json::Value::String(model));
     }
@@ -65,6 +72,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {

     info!("Modified POST Data: {}", modified_body_str);

+    // Set up headers
     let mut headers = reqwest::header::HeaderMap::new();
     headers.insert(
         "Authorization",
@@ -76,6 +84,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
         reqwest::header::HeaderValue::from_static("application/json"),
     );

+    // Send request to the AI provider
     let client = Client::new();
     let response = client
         .post(&endpoint)
@@ -85,6 +94,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
         .await
         .map_err(actix_web::error::ErrorInternalServerError)?;

+    // Handle response
     let status = response.status();
     let raw_response = response
         .text()
@@ -94,6 +104,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
     info!("Provider response status: {}", status);
     info!("Provider response body: {}", raw_response);

+    // Convert response to OpenAI format if successful
     if status.is_success() {
         match convert_to_openai_format(&raw_response) {
             Ok(openai_response) => Ok(HttpResponse::Ok()
@@ -101,12 +112,14 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
                 .body(openai_response)),
             Err(e) => {
                 error!("Failed to convert response format: {}", e);
+                // Return the original response if conversion fails
                 Ok(HttpResponse::Ok()
                     .content_type("application/json")
                     .body(raw_response))
             }
         }
     } else {
+        // Return error as-is
         let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
             .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

@@ -116,6 +129,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
     }
 }

+/// Converts provider response to OpenAI-compatible format
 fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
     #[derive(serde::Deserialize)]
     struct ProviderChoice {
@@ -177,8 +191,10 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
         total_tokens: u32,
     }

+    // Parse the provider response
     let provider: ProviderResponse = serde_json::from_str(provider_response)?;

+    // Extract content from the first choice
     let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
     let content = first_choice.message.content.clone();
     let role = first_choice
@@ -187,6 +203,7 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
         .clone()
         .unwrap_or_else(|| "assistant".to_string());

+    // Calculate token usage
     let usage = provider.usage.unwrap_or_default();
     let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
     let completion_tokens = usage
@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
 use std::env;
 use tokio::time::{sleep, Duration};

+// OpenAI-compatible request/response structures
 #[derive(Debug, Serialize, Deserialize)]
 struct ChatMessage {
     role: String,
@@ -34,6 +35,7 @@ struct Choice {
     finish_reason: String,
 }

+// Llama.cpp server request/response structures
 #[derive(Debug, Serialize, Deserialize)]
 struct LlamaCppRequest {
     prompt: String,
@@ -51,7 +53,8 @@ struct LlamaCppResponse {
     generation_settings: Option<serde_json::Value>,
 }

-pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
+{
     let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());

     if llm_local.to_lowercase() != "true" {
@@ -59,6 +62,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
         return Ok(());
     }

+    // Get configuration from environment variables
     let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
     let embedding_url =
         env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@@ -73,6 +77,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
     info!(" LLM Model: {}", llm_model_path);
     info!(" Embedding Model: {}", embedding_model_path);

+    // Check if servers are already running
     let llm_running = is_server_running(&llm_url).await;
     let embedding_running = is_server_running(&embedding_url).await;

@@ -81,6 +86,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
         return Ok(());
     }

+    // Start servers that aren't running
     let mut tasks = vec![];

     if !llm_running && !llm_model_path.is_empty() {
@@ -105,17 +111,19 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
         info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
     }

+    // Wait for all server startup tasks
     for task in tasks {
         task.await??;
     }

+    // Wait for servers to be ready with verbose logging
     info!("⏳ Waiting for servers to become ready...");

     let mut llm_ready = llm_running || llm_model_path.is_empty();
     let mut embedding_ready = embedding_running || embedding_model_path.is_empty();

     let mut attempts = 0;
-    let max_attempts = 60;
+    let max_attempts = 60; // 2 minutes total

     while attempts < max_attempts && (!llm_ready || !embedding_ready) {
         sleep(Duration::from_secs(2)).await;
@@ -180,6 +188,8 @@ async fn start_llm_server(
     std::env::set_var("OMP_PLACES", "cores");
     std::env::set_var("OMP_PROC_BIND", "close");

+    // "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
+
     let mut cmd = tokio::process::Command::new("sh");
     cmd.arg("-c").arg(format!(
         "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
@@ -215,6 +225,7 @@ async fn is_server_running(url: &str) -> bool {
     }
 }

+// Convert OpenAI chat messages to a single prompt
 fn messages_to_prompt(messages: &[ChatMessage]) -> String {
     let mut prompt = String::new();

@@ -239,28 +250,32 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String {
     prompt
 }

+// Proxy endpoint
 #[post("/local/v1/chat/completions")]
 pub async fn chat_completions_local(
     req_body: web::Json<ChatCompletionRequest>,
     _req: HttpRequest,
 ) -> Result<HttpResponse> {
-    dotenv().ok();
+    dotenv().ok().unwrap();

+    // Get llama.cpp server URL
     let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());

+    // Convert OpenAI format to llama.cpp format
     let prompt = messages_to_prompt(&req_body.messages);

     let llama_request = LlamaCppRequest {
         prompt,
-        n_predict: Some(500),
+        n_predict: Some(500), // Adjust as needed
         temperature: Some(0.7),
         top_k: Some(40),
         top_p: Some(0.9),
         stream: req_body.stream,
     };

+    // Send request to llama.cpp server
     let client = Client::builder()
-        .timeout(Duration::from_secs(120))
+        .timeout(Duration::from_secs(120)) // 2 minute timeout
         .build()
         .map_err(|e| {
             error!("Error creating HTTP client: {}", e);
@@ -286,6 +301,7 @@ pub async fn chat_completions_local(
             actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
         })?;

+    // Convert llama.cpp response to OpenAI format
     let openai_response = ChatCompletionResponse {
         id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
         object: "chat.completion".to_string(),
@@ -328,6 +344,7 @@ pub async fn chat_completions_local(
     }
 }

+// OpenAI Embedding Request - Modified to handle both string and array inputs
 #[derive(Debug, Deserialize)]
 pub struct EmbeddingRequest {
     #[serde(deserialize_with = "deserialize_input")]
@@ -337,6 +354,7 @@ pub struct EmbeddingRequest {
     pub _encoding_format: Option<String>,
 }

+// Custom deserializer to handle both string and array inputs
 fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
 where
     D: serde::Deserializer<'de>,
@@ -382,6 +400,7 @@ where
     deserializer.deserialize_any(InputVisitor)
 }

+// OpenAI Embedding Response
 #[derive(Debug, Serialize)]
 pub struct EmbeddingResponse {
     pub object: String,
@@ -403,17 +422,20 @@ pub struct Usage {
     pub total_tokens: u32,
 }

+// Llama.cpp Embedding Request
 #[derive(Debug, Serialize)]
 struct LlamaCppEmbeddingRequest {
     pub content: String,
 }

+// FIXED: Handle the stupid nested array format
 #[derive(Debug, Deserialize)]
 struct LlamaCppEmbeddingResponseItem {
     pub index: usize,
-    pub embedding: Vec<Vec<f32>>,
+    pub embedding: Vec<Vec<f32>>, // This is the up part - embedding is an array of arrays
 }

+// Proxy endpoint for embeddings
 #[post("/v1/embeddings")]
 pub async fn embeddings_local(
     req_body: web::Json<EmbeddingRequest>,
@@ -421,6 +443,7 @@ pub async fn embeddings_local(
 ) -> Result<HttpResponse> {
     dotenv().ok();

+    // Get llama.cpp server URL
     let llama_url =
         env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());

@@ -432,6 +455,7 @@ pub async fn embeddings_local(
             actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
         })?;

+    // Process each input text and get embeddings
     let mut embeddings_data = Vec::new();
     let mut total_tokens = 0;

@@ -456,11 +480,13 @@ pub async fn embeddings_local(
         let status = response.status();

         if status.is_success() {
+            // First, get the raw response text for debugging
             let raw_response = response.text().await.map_err(|e| {
                 error!("Error reading response text: {}", e);
                 actix_web::error::ErrorInternalServerError("Failed to read response")
             })?;

+            // Parse the response as a vector of items with nested arrays
             let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
                 serde_json::from_str(&raw_response).map_err(|e| {
                     error!("Error parsing llama.cpp embedding response: {}", e);
@@ -470,13 +496,17 @@ pub async fn embeddings_local(
                     )
                 })?;

+            // Extract the embedding from the nested array bullshit
             if let Some(item) = llama_response.get(0) {
+                // The embedding field contains Vec<Vec<f32>>, so we need to flatten it
+                // If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
                 let flattened_embedding = if !item.embedding.is_empty() {
-                    item.embedding[0].clone()
+                    item.embedding[0].clone() // Take the first (and probably only) inner array
                 } else {
-                    vec![]
+                    vec![] // Empty if no embedding data
                 };

+                // Estimate token count
                 let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
                 total_tokens += estimated_tokens;

@@ -514,6 +544,7 @@ pub async fn embeddings_local(
         }
     }

+    // Build OpenAI-compatible response
     let openai_response = EmbeddingResponse {
         object: "list".to_string(),
         data: embeddings_data,
@@ -527,6 +558,7 @@ pub async fn embeddings_local(
     Ok(HttpResponse::Ok().json(openai_response))
 }

+// Health check endpoint
 #[actix_web::get("/health")]
 pub async fn health() -> Result<HttpResponse> {
     let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
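A small sketch, not part of the commit, of the nested-array flattening the embeddings proxy performs: llama.cpp returns `embedding` as Vec<Vec<f32>>, and the handler keeps only the first inner vector:

fn flatten_embedding(nested: Vec<Vec<f32>>) -> Vec<f32> {
    // [[0.1, 0.2, 0.3]] becomes [0.1, 0.2, 0.3]; empty input yields an empty vector.
    nested.into_iter().next().unwrap_or_default()
}

fn main() {
    assert_eq!(flatten_embedding(vec![vec![0.1, 0.2, 0.3]]), vec![0.1, 0.2, 0.3]);
    assert_eq!(flatten_embedding(vec![]), Vec::<f32>::new());
}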
@@ -1,3 +1,3 @@
-pub mod llm;
+pub mod llm_azure;
 pub mod llm_generic;
 pub mod llm_local;
97 src/main.rs
@@ -1,4 +1,5 @@
 use actix_cors::Cors;
+use actix_web::middleware::Logger;
 use actix_web::{web, App, HttpServer};
 use dotenv::dotenv;
 use log::info;
@@ -23,23 +24,30 @@ mod tools;
 mod web_automation;
 mod whatsapp;

+use crate::automation::AutomationService;
 use crate::bot::{
     create_session, get_session_history, get_sessions, index, set_mode_handler, static_files,
     voice_start, voice_stop, websocket_handler, whatsapp_webhook, whatsapp_webhook_verify,
 };
 use crate::channels::{VoiceAdapter, WebChannelAdapter};
 use crate::config::AppConfig;
-use crate::email::send_email;
+use crate::email::{
+    get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
+};
 use crate::file::{list_file, upload_file};
+use crate::llm_legacy::llm_generic::generic_chat_completions;
+use crate::llm_legacy::llm_local::{
+    chat_completions_local, embeddings_local, ensure_llama_servers_running,
+};
 use crate::shared::AppState;
 use crate::whatsapp::WhatsAppAdapter;

 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
     dotenv().ok();
-    env_logger::init();
+    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();

-    info!("Starting BotServer...");
+    info!("Starting General Bots 6.0...");

     let config = AppConfig::from_env();

@@ -59,17 +67,17 @@ async fn main() -> std::io::Result<()> {
     };

     // Optional custom database pool
-    let db_custom_pool: Option<sqlx::Pool<sqlx::Postgres>> =
-        match sqlx::postgres::PgPool::connect(&config.database_custom_url()).await {
+    let db_custom_pool = match sqlx::postgres::PgPool::connect(&config.database_custom_url()).await
+    {
         Ok(pool) => {
             info!("Connected to custom database");
             Some(pool)
         }
         Err(e) => {
             log::warn!("Failed to connect to custom database: {}", e);
             None
         }
     };

     // Optional Redis client
     let redis_client = match redis::Client::open("redis://127.0.0.1/") {
@@ -83,9 +91,28 @@ async fn main() -> std::io::Result<()> {
         }
     };

-    // Placeholder for MinIO (not yet implemented)
-    let minio_client = None;
+    // Initialize MinIO client
+    let minio_client = file::init_minio(&config)
+        .await
+        .expect("Failed to initialize Minio");
+
+    // Initialize browser pool
+    let browser_pool = Arc::new(web_automation::BrowserPool::new(
+        "chrome".to_string(),
+        2,
+        "headless".to_string(),
+    ));
+
+    // Initialize LLM servers
+    ensure_llama_servers_running()
+        .await
+        .expect("Failed to initialize LLM local server.");
+
+    web_automation::initialize_browser_pool()
+        .await
+        .expect("Failed to initialize browser pool");

     // Initialize services from new architecture
     let auth_service = auth::AuthService::new(db_pool.clone(), redis_client.clone());
     let session_manager = session::SessionManager::new(db_pool.clone(), redis_client.clone());
@@ -110,20 +137,13 @@ async fn main() -> std::io::Result<()> {

     let tool_api = Arc::new(tools::ToolApi::new());

-    // Browser pool – constructed with the required three arguments.
-    // Adjust the strings as needed for your environment.
-    let browser_pool = Arc::new(web_automation::BrowserPool::new(
-        "chrome".to_string(),
-        2,
-        "headless".to_string(),
-    ));
-
+    // Create unified app state
     let app_state = AppState {
-        minio_client,
+        minio_client: Some(minio_client),
         config: Some(config.clone()),
         db: Some(db_pool.clone()),
-        db_custom: db_custom_pool,
-        browser_pool,
+        db_custom: db_custom_pool.clone(),
+        browser_pool: browser_pool.clone(),
         orchestrator: Arc::new(orchestrator),
         web_adapter,
         voice_adapter,
@@ -131,6 +151,11 @@ async fn main() -> std::io::Result<()> {
         tool_api,
     };

+    // Start automation service in background
+    let automation_state = app_state.clone();
+    let automation = AutomationService::new(automation_state, "src/prompts");
+    let _automation_handle = automation.spawn();
+
     info!(
         "Starting server on {}:{}",
         config.server.host, config.server.port
@@ -145,7 +170,22 @@ async fn main() -> std::io::Result<()> {

         App::new()
             .wrap(cors)
+            .wrap(Logger::default())
+            .wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
             .app_data(web::Data::new(app_state.clone()))
+            // Legacy services
+            .service(upload_file)
+            .service(list_file)
+            .service(save_click)
+            .service(get_emails)
+            .service(list_emails)
+            .service(send_email)
+            .service(chat_completions_local)
+            .service(save_draft)
+            .service(generic_chat_completions)
+            .service(embeddings_local)
+            .service(get_latest_email_from)
+            // New bot services
             .service(index)
             .service(static_files)
             .service(websocket_handler)
@@ -157,9 +197,6 @@ async fn main() -> std::io::Result<()> {
             .service(get_sessions)
             .service(get_session_history)
             .service(set_mode_handler)
-            .service(send_email)
-            .service(upload_file)
-            .service(list_file)
     })
     .bind((config.server.host.clone(), config.server.port))?
     .run()
@@ -1,4 +1,3 @@
-use chrono::Utc;
 use langchain_rust::llm::AzureConfig;
 use log::{debug, warn};
 use rhai::{Array, Dynamic};
@@ -1,7 +1,7 @@
 <!doctype html>
 <html>
 <head>
-    <title>General Bots - ChatGPT Clone</title>
+    <title>General Bots</title>
     <style>
         * {
             margin: 0;