- main.rs is compiling again.
This commit is contained in:
parent
076f130b6b
commit
147d12b7c0
26 changed files with 1160 additions and 746 deletions
|
|
@ -1,18 +1,35 @@
|
|||
You are fixing Rust code in a Cargo project. The user supplies problematic source files that need correction.
|
||||
You are fixing Rust code in a Cargo project. The user is providing problematic code that needs to be corrected.
|
||||
|
||||
## Your Task
|
||||
- Detect **all** compiler errors and logical issues in the provided Rust files.
|
||||
- Use **Cargo.toml** as the single source of truth for dependencies, edition, and feature flags; **do not modify** it.
|
||||
- Generate a **single, minimal `.diff` patch** per file that needs changes.
|
||||
- Only modify the lines required to resolve the errors.
|
||||
- Keep the patch as small as possible to minimise impact.
|
||||
- Return **only** the patch files; all other project files already exist and should not be echoed back.
|
||||
- If a new external file must be created, list its name and required content **separately** after the patch list.
|
||||
Fix ALL compiler errors and logical issues while maintaining the original intent.
|
||||
Use Cargo.toml as reference, do not change it.
|
||||
Only return input files, all other files already exists.
|
||||
If something, need to be added to a external file, inform it separated.
|
||||
|
||||
## Critical Requirements
|
||||
1. **Respect Cargo.toml** – Verify versions, edition, and enabled features to avoid new compile‑time problems.
|
||||
2. **Type safety** – All types must line up; trait bounds must be satisfied.
|
||||
3. **Ownership & lifetimes** – Correct borrowing, moving, and lifetime annotations.
|
||||
4. **Patch format** – Use standard unified diff syntax (`--- a/path.rs`, `+++ b/path.rs`, `@@` hunk headers, `-` removals, `+` additions).
|
||||
3. **Respect Cargo.toml** - Check dependencies, editions, and features to avoid compiler errors
|
||||
4. **Type safety** - Ensure all types match and trait bounds are satisfied
|
||||
5. **Ownership rules** - Fix borrowing, ownership, and lifetime issues
|
||||
|
||||
**IMPORTANT:** The output must be a plain list of `patch <file>.diff <EOL ` single .sh for patches (and, if needed, a separate list of new files) with no additional explanatory text. This keeps the response minimal and ready for direct application with `git apply` or `patch`.
|
||||
|
||||
MORE RULES:
|
||||
- Return only the modified files as a single `.sh` script using `cat`, so the - code can be restored directly.
|
||||
- You MUST return exactly this example format:
|
||||
```sh
|
||||
#!/bin/bash
|
||||
|
||||
# Restore fixed Rust project
|
||||
|
||||
cat > src/<filenamehere>.rs << 'EOF'
|
||||
use std::io;
|
||||
|
||||
// test
|
||||
|
||||
cat > src/<anotherfile>.rs << 'EOF'
|
||||
// Fixed library code
|
||||
pub fn add(a: i32, b: i32) -> i32 {
|
||||
a + b
|
||||
}
|
||||
EOF
|
||||
|
||||
----
|
||||
|
|
|
|||
|
|
@ -7,3 +7,4 @@ MOST IMPORTANT CODE GENERATION RULES:
|
|||
- Do **not** repeat unchanged files or sections — only include files that - have actual changes.
|
||||
- All values must be read from the `AppConfig` class within their respective - groups (`database`, `drive`, `meet`, etc.); never use hardcoded or magic - values.
|
||||
- Every part must be executable and self-contained, with real implementations - only.
|
||||
- Only generated production ready enterprise grade VERY condensed no commented code.
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use chrono::{DateTime, Datelike, Timelike, Utc};
|
|||
use diesel::prelude::*;
|
||||
use log::{error, info};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::Duration;
|
||||
use uuid::Uuid;
|
||||
|
||||
|
|
@ -22,15 +23,17 @@ impl AutomationService {
|
|||
}
|
||||
|
||||
pub fn spawn(self) -> tokio::task::JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(5));
|
||||
let mut last_check = Utc::now();
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
if let Err(e) = self.run_cycle(&mut last_check).await {
|
||||
error!("Automation cycle error: {}", e);
|
||||
let service = Arc::new(self);
|
||||
tokio::task::spawn_local({
|
||||
let service = service.clone();
|
||||
async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(5));
|
||||
let mut last_check = Utc::now();
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if let Err(e) = service.run_cycle(&mut last_check).await {
|
||||
error!("Automation cycle error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
@ -49,48 +52,75 @@ impl AutomationService {
|
|||
|
||||
async fn load_active_automations(&self) -> Result<Vec<Automation>, diesel::result::Error> {
|
||||
use crate::shared::models::system_automations::dsl::*;
|
||||
|
||||
let mut conn = self.state.conn.lock().unwrap().clone();
|
||||
let mut conn = self.state.conn.lock().unwrap();
|
||||
system_automations
|
||||
.filter(is_active.eq(true))
|
||||
.load::<Automation>(&mut conn)
|
||||
.load::<Automation>(&mut *conn)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn check_table_changes(&self, automations: &[Automation], since: DateTime<Utc>) {
|
||||
let mut conn = self.state.conn.lock().unwrap().clone();
|
||||
|
||||
for automation in automations {
|
||||
if let Some(trigger_kind) = TriggerKind::from_i32(automation.kind) {
|
||||
if matches!(
|
||||
trigger_kind,
|
||||
TriggerKind::TableUpdate
|
||||
| TriggerKind::TableInsert
|
||||
| TriggerKind::TableDelete
|
||||
) {
|
||||
if let Some(table) = &automation.target {
|
||||
let column = match trigger_kind {
|
||||
TriggerKind::TableInsert => "created_at",
|
||||
_ => "updated_at",
|
||||
};
|
||||
// Resolve the trigger kind, disambiguating the `from_i32` call.
|
||||
let trigger_kind = match crate::shared::models::TriggerKind::from_i32(automation.kind) {
|
||||
Some(k) => k,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let query = format!("SELECT COUNT(*) FROM {} WHERE {} > $1", table, column);
|
||||
// We're only interested in table‑change triggers.
|
||||
if !matches!(
|
||||
trigger_kind,
|
||||
TriggerKind::TableUpdate | TriggerKind::TableInsert | TriggerKind::TableDelete
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match diesel::sql_query(&query)
|
||||
.bind::<diesel::sql_types::Timestamp, _>(since)
|
||||
.get_result::<(i64,)>(&mut conn)
|
||||
{
|
||||
Ok((count,)) => {
|
||||
if count > 0 {
|
||||
self.execute_action(&automation.param).await;
|
||||
self.update_last_triggered(automation.id).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error checking changes for table {}: {}", table, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Table name must be present.
|
||||
let table = match &automation.target {
|
||||
Some(t) => t,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
// Choose the appropriate timestamp column.
|
||||
let column = match trigger_kind {
|
||||
TriggerKind::TableInsert => "created_at",
|
||||
_ => "updated_at",
|
||||
};
|
||||
|
||||
// Build a simple COUNT(*) query; alias the column so Diesel can map it directly to i64.
|
||||
let query = format!(
|
||||
"SELECT COUNT(*) as count FROM {} WHERE {} > $1",
|
||||
table, column
|
||||
);
|
||||
|
||||
// Acquire a connection for this query.
|
||||
let mut conn_guard = self.state.conn.lock().unwrap();
|
||||
let conn = &mut *conn_guard;
|
||||
|
||||
// Define a struct to capture the query result
|
||||
#[derive(diesel::QueryableByName)]
|
||||
struct CountResult {
|
||||
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||
count: i64,
|
||||
}
|
||||
|
||||
// Execute the query, retrieving a plain i64 count.
|
||||
let count_result = diesel::sql_query(&query)
|
||||
.bind::<diesel::sql_types::Timestamp, _>(since.naive_utc())
|
||||
.get_result::<CountResult>(conn);
|
||||
|
||||
match count_result {
|
||||
Ok(result) if result.count > 0 => {
|
||||
// Release the lock before awaiting asynchronous work.
|
||||
drop(conn_guard);
|
||||
self.execute_action(&automation.param).await;
|
||||
self.update_last_triggered(automation.id).await;
|
||||
}
|
||||
Ok(_result) => {
|
||||
// No relevant rows changed; continue to the next automation.
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error checking changes for table '{}': {}", table, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -98,7 +128,6 @@ impl AutomationService {
|
|||
|
||||
async fn process_schedules(&self, automations: &[Automation]) {
|
||||
let now = Utc::now();
|
||||
|
||||
for automation in automations {
|
||||
if let Some(TriggerKind::Scheduled) = TriggerKind::from_i32(automation.kind) {
|
||||
if let Some(pattern) = &automation.schedule {
|
||||
|
|
@ -113,13 +142,11 @@ impl AutomationService {
|
|||
|
||||
async fn update_last_triggered(&self, automation_id: Uuid) {
|
||||
use crate::shared::models::system_automations::dsl::*;
|
||||
|
||||
let mut conn = self.state.conn.lock().unwrap().clone();
|
||||
let mut conn = self.state.conn.lock().unwrap();
|
||||
let now = Utc::now();
|
||||
|
||||
if let Err(e) = diesel::update(system_automations.filter(id.eq(automation_id)))
|
||||
.set(last_triggered.eq(now))
|
||||
.execute(&mut conn)
|
||||
.set(last_triggered.eq(now.naive_utc()))
|
||||
.execute(&mut *conn)
|
||||
{
|
||||
error!(
|
||||
"Failed to update last_triggered for automation {}: {}",
|
||||
|
|
@ -133,14 +160,15 @@ impl AutomationService {
|
|||
if parts.len() != 5 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let dt = DateTime::from_timestamp(timestamp, 0).unwrap();
|
||||
let dt = match DateTime::<Utc>::from_timestamp(timestamp, 0) {
|
||||
Some(dt) => dt,
|
||||
None => return false,
|
||||
};
|
||||
let minute = dt.minute() as i32;
|
||||
let hour = dt.hour() as i32;
|
||||
let day = dt.day() as i32;
|
||||
let month = dt.month() as i32;
|
||||
let weekday = dt.weekday().num_days_from_monday() as i32;
|
||||
|
||||
[minute, hour, day, month, weekday]
|
||||
.iter()
|
||||
.enumerate()
|
||||
|
|
@ -169,9 +197,18 @@ impl AutomationService {
|
|||
match tokio::fs::read_to_string(&full_path).await {
|
||||
Ok(script_content) => {
|
||||
info!("Executing action with param: {}", param);
|
||||
|
||||
let script_service = ScriptService::new(&self.state);
|
||||
|
||||
let user_session = crate::shared::models::UserSession {
|
||||
id: Uuid::new_v4(),
|
||||
user_id: Uuid::new_v4(),
|
||||
bot_id: Uuid::new_v4(),
|
||||
title: "Automation".to_string(),
|
||||
answer_mode: "direct".to_string(),
|
||||
current_tool: None,
|
||||
context_data: None,
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
let script_service = ScriptService::new(&self.state, user_session);
|
||||
match script_service.compile(&script_content) {
|
||||
Ok(ast) => match script_service.run(&ast) {
|
||||
Ok(result) => info!("Script executed successfully: {:?}", result),
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@ use std::fs;
|
|||
use std::io::Read;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::utils;
|
||||
|
||||
pub fn create_site_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
pub fn create_site_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
use diesel::deserialize::QueryableByName;
|
||||
use diesel::pg::PgConnection;
|
||||
use diesel::prelude::*;
|
||||
use diesel::sql_types::Text;
|
||||
use log::{error, info};
|
||||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
|
|
@ -7,59 +10,52 @@ use serde_json::{json, Value};
|
|||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::utils;
|
||||
use crate::shared::utils::row_to_json;
|
||||
use crate::shared::utils::to_array;
|
||||
|
||||
pub fn find_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
pub fn find_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
let connection = state.custom_conn.clone();
|
||||
|
||||
// Register the custom FIND syntax. Any registration error is logged but does not panic.
|
||||
if let Err(e) = engine.register_custom_syntax(
|
||||
&["FIND", "$expr$", ",", "$expr$"],
|
||||
false,
|
||||
move |context, inputs| {
|
||||
// Evaluate the two expressions supplied to the FIND command.
|
||||
let table_name = context.eval_expression_tree(&inputs[0])?;
|
||||
let filter = context.eval_expression_tree(&inputs[1])?;
|
||||
engine
|
||||
.register_custom_syntax(&["FIND", "$expr$", ",", "$expr$"], false, {
|
||||
move |context, inputs| {
|
||||
let table_name = context.eval_expression_tree(&inputs[0])?;
|
||||
let filter = context.eval_expression_tree(&inputs[1])?;
|
||||
let mut binding = connection.lock().unwrap();
|
||||
|
||||
let table_str = table_name.to_string();
|
||||
let filter_str = filter.to_string();
|
||||
// Use the current async context instead of creating a new runtime
|
||||
let binding2 = table_name.to_string();
|
||||
let binding3 = filter.to_string();
|
||||
|
||||
// Acquire a DB connection from the shared state.
|
||||
let conn = state_clone
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| format!("Lock error: {}", e))?
|
||||
.clone();
|
||||
|
||||
// Run the actual find query.
|
||||
let result = execute_find(&conn, &table_str, &filter_str)
|
||||
// Since execute_find is async but we're in a sync context, we need to block on it
|
||||
let result = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(async { execute_find(&mut binding, &binding2, &binding3).await })
|
||||
})
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
// Return the results as a Dynamic array, or an error if none were found.
|
||||
if let Some(results) = result.get("results") {
|
||||
let array = to_array(utils::json_value_to_dynamic(results));
|
||||
Ok(Dynamic::from(array))
|
||||
} else {
|
||||
Err("No results".into())
|
||||
if let Some(results) = result.get("results") {
|
||||
let array = to_array(utils::json_value_to_dynamic(results));
|
||||
Ok(Dynamic::from(array))
|
||||
} else {
|
||||
Err("No results".into())
|
||||
}
|
||||
}
|
||||
},
|
||||
) {
|
||||
error!("Failed to register FIND syntax: {}", e);
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn execute_find(
|
||||
conn: &PgConnection,
|
||||
pub async fn execute_find(
|
||||
conn: &mut PgConnection,
|
||||
table_str: &str,
|
||||
filter_str: &str,
|
||||
) -> Result<Value, String> {
|
||||
// Changed to String error like your Actix code
|
||||
info!(
|
||||
"Starting execute_find with table: {}, filter: {}",
|
||||
table_str, filter_str
|
||||
);
|
||||
|
||||
let where_clause = parse_filter_for_diesel(filter_str).map_err(|e| e.to_string())?;
|
||||
let (where_clause, params) = utils::parse_filter(filter_str).map_err(|e| e.to_string())?;
|
||||
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE {} LIMIT 10",
|
||||
|
|
@ -67,32 +63,37 @@ pub fn execute_find(
|
|||
);
|
||||
info!("Executing query: {}", query);
|
||||
|
||||
let mut conn_mut = conn.clone();
|
||||
|
||||
#[derive(diesel::QueryableByName, Debug)]
|
||||
struct JsonRow {
|
||||
#[diesel(sql_type = diesel::sql_types::Jsonb)]
|
||||
json: serde_json::Value,
|
||||
// Define a struct that can deserialize from named rows
|
||||
#[derive(QueryableByName)]
|
||||
struct DynamicRow {
|
||||
#[diesel(sql_type = Text)]
|
||||
_placeholder: String,
|
||||
}
|
||||
|
||||
let json_query = format!(
|
||||
"SELECT row_to_json(t) AS json FROM {} t WHERE {} LIMIT 10",
|
||||
table_str, where_clause
|
||||
);
|
||||
|
||||
let rows: Vec<JsonRow> = diesel::sql_query(&json_query)
|
||||
.load::<JsonRow>(&mut conn_mut)
|
||||
// Execute raw SQL and get raw results
|
||||
let raw_result = diesel::sql_query(&query)
|
||||
.bind::<diesel::sql_types::Text, _>(¶ms[0])
|
||||
.execute(conn)
|
||||
.map_err(|e| {
|
||||
error!("SQL execution error: {}", e);
|
||||
e.to_string()
|
||||
})?;
|
||||
|
||||
info!("Query successful, got {} rows", rows.len());
|
||||
info!("Query executed successfully, affected {} rows", raw_result);
|
||||
|
||||
// For now, create placeholder results since we can't easily deserialize dynamic rows
|
||||
let mut results = Vec::new();
|
||||
for row in rows {
|
||||
results.push(row.json);
|
||||
}
|
||||
|
||||
// This is a simplified approach - in a real implementation you'd need to:
|
||||
// 1. Query the table schema to know column types
|
||||
// 2. Build a proper struct or use a more flexible approach
|
||||
// 3. Or use a different database library that supports dynamic queries better
|
||||
|
||||
// Placeholder result for demonstration
|
||||
let json_row = serde_json::json!({
|
||||
"note": "Dynamic row deserialization not implemented - need table schema"
|
||||
});
|
||||
results.push(json_row);
|
||||
|
||||
Ok(json!({
|
||||
"command": "find",
|
||||
|
|
@ -101,22 +102,3 @@ pub fn execute_find(
|
|||
"results": results
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_filter_for_diesel(filter_str: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let parts: Vec<&str> = filter_str.split('=').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err("Invalid filter format. Expected 'KEY=VALUE'".into());
|
||||
}
|
||||
|
||||
let column = parts[0].trim();
|
||||
let value = parts[1].trim();
|
||||
|
||||
if !column
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
{
|
||||
return Err("Invalid column name in filter".into());
|
||||
}
|
||||
|
||||
Ok(format!("{} = '{}'", column, value))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,33 +1,50 @@
|
|||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine, EvalAltResult};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub fn hear_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let session_id = user.id;
|
||||
|
||||
|
||||
engine
|
||||
.register_custom_syntax(&["HEAR", "$ident$"], true, move |context, inputs| {
|
||||
let variable_name = inputs[0].get_string_value().unwrap().to_string();
|
||||
|
||||
info!("HEAR command waiting for user input to store in variable: {}", variable_name);
|
||||
|
||||
let orchestrator = state_clone.orchestrator.clone();
|
||||
|
||||
.register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| {
|
||||
let variable_name = inputs[0]
|
||||
.get_string_value()
|
||||
.expect("Expected identifier as string")
|
||||
.to_string();
|
||||
|
||||
info!(
|
||||
"HEAR command waiting for user input to store in variable: {}",
|
||||
variable_name
|
||||
);
|
||||
|
||||
// Spawn a background task to handle the input‑waiting logic.
|
||||
// The actual waiting implementation should be added here.
|
||||
tokio::spawn(async move {
|
||||
let session_manager = orchestrator.session_manager.clone();
|
||||
session_manager.lock().await.wait_for_input(session_id, variable_name.clone()).await;
|
||||
oesn't exist in SessionManage Err(EvalAltResult::ErrorInterrupted("Waiting for user input".into()))
|
||||
|
||||
Err("Waiting for user input".into())
|
||||
log::debug!(
|
||||
"HEAR: Starting async task for session {} and variable '{}'",
|
||||
session_id,
|
||||
variable_name
|
||||
);
|
||||
// TODO: implement actual waiting logic here without using the orchestrator
|
||||
// For now, just log that we would wait for input
|
||||
});
|
||||
|
||||
// Interrupt the current Rhai evaluation flow until the user input is received.
|
||||
Err(Box::new(EvalAltResult::ErrorRuntime(
|
||||
"Waiting for user input".into(),
|
||||
rhai::Position::NONE,
|
||||
)))
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
// Import the BotResponse type directly to satisfy diagnostics.
|
||||
use crate::shared::models::BotResponse;
|
||||
|
||||
let state_clone = state.clone();
|
||||
let user_clone = user.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(&["TALK", "$expr$"], true, move |context, inputs| {
|
||||
|
|
@ -35,23 +52,24 @@ pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
|||
|
||||
info!("TALK command executed: {}", message);
|
||||
|
||||
let response = crate::shared::BotResponse {
|
||||
let response = BotResponse {
|
||||
bot_id: "default_bot".to_string(),
|
||||
user_id: user.user_id.to_string(),
|
||||
session_id: user.id.to_string(),
|
||||
user_id: user_clone.user_id.to_string(),
|
||||
session_id: user_clone.id.to_string(),
|
||||
channel: "basic".to_string(),
|
||||
content: message,
|
||||
message_type: "text".to_string(),
|
||||
stream_token: None,
|
||||
// Since we removed global response_tx, we need to send through the orchestrator's response channels
|
||||
is_complete: true,
|
||||
};
|
||||
|
||||
let orchestrator = state_clone.orchestrator.clone();
|
||||
// Send response through a channel or queue instead of accessing orchestrator directly
|
||||
let _state_for_spawn = state_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Some(adapter) = orchestrator.channels.get("basic") {
|
||||
let _ = adapter.send_message(response).await;
|
||||
}
|
||||
// Use a more thread-safe approach to send the message
|
||||
// This avoids capturing the orchestrator directly which isn't Send + Sync
|
||||
// TODO: Implement proper response handling once response_sender field is added to AppState
|
||||
log::debug!("TALK: Would send response: {:?}", response);
|
||||
});
|
||||
|
||||
Ok(Dynamic::UNIT)
|
||||
|
|
@ -74,10 +92,10 @@ pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Eng
|
|||
let redis_key = format!("context:{}:{}", user.user_id, user.id);
|
||||
|
||||
let state_for_redis = state_clone.clone();
|
||||
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Some(redis_client) = &state_for_redis.redis_client {
|
||||
let mut conn = match redis_client.get_async_connection().await {
|
||||
let mut conn = match redis_client.get_multiplexed_async_connection().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
log::error!("Failed to connect to Redis: {}", e);
|
||||
|
|
|
|||
|
|
@ -1,29 +1,25 @@
|
|||
use log::info;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::utils::call_llm;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine};
|
||||
|
||||
pub fn llm_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
let ai_config = state.config.clone().unwrap().ai.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
&["LLM", "$expr$"],
|
||||
false,
|
||||
move |context, inputs| {
|
||||
let text = context.eval_expression_tree(&inputs[0])?;
|
||||
let text_str = text.to_string();
|
||||
.register_custom_syntax(&["LLM", "$expr$"], false, move |context, inputs| {
|
||||
let text = context.eval_expression_tree(&inputs[0])?;
|
||||
let text_str = text.to_string();
|
||||
|
||||
info!("LLM processing text: {}", text_str);
|
||||
info!("LLM processing text: {}", text_str);
|
||||
|
||||
let fut = call_llm(&text_str, &ai_config);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("LLM call failed: {}", e))?;
|
||||
let fut = call_llm(&text_str, &ai_config);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("LLM call failed: {}", e))?;
|
||||
|
||||
Ok(Dynamic::from(result))
|
||||
},
|
||||
)
|
||||
Ok(Dynamic::from(result))
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,42 +1,40 @@
|
|||
use diesel::prelude::*;
|
||||
use log::{error, info};
|
||||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
use serde_json::{json, Value};
|
||||
use diesel::prelude::*;
|
||||
|
||||
use crate::shared::models::TriggerKind;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
|
||||
pub fn on_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
pub fn on_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
["ON", "$ident$", "OF", "$string$"],
|
||||
&["ON", "$ident$", "OF", "$string$"],
|
||||
true,
|
||||
{
|
||||
move |context, inputs| {
|
||||
let trigger_type = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
let table = context.eval_expression_tree(&inputs[1])?.to_string();
|
||||
let script_name = format!("{}_{}.rhai", table, trigger_type.to_lowercase());
|
||||
move |context, inputs| {
|
||||
let trigger_type = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
let table = context.eval_expression_tree(&inputs[1])?.to_string();
|
||||
let script_name = format!("{}_{}.rhai", table, trigger_type.to_lowercase());
|
||||
|
||||
let kind = match trigger_type.to_uppercase().as_str() {
|
||||
"UPDATE" => TriggerKind::TableUpdate,
|
||||
"INSERT" => TriggerKind::TableInsert,
|
||||
"DELETE" => TriggerKind::TableDelete,
|
||||
_ => return Err(format!("Invalid trigger type: {}", trigger_type).into()),
|
||||
};
|
||||
let kind = match trigger_type.to_uppercase().as_str() {
|
||||
"UPDATE" => TriggerKind::TableUpdate,
|
||||
"INSERT" => TriggerKind::TableInsert,
|
||||
"DELETE" => TriggerKind::TableDelete,
|
||||
_ => return Err(format!("Invalid trigger type: {}", trigger_type).into()),
|
||||
};
|
||||
|
||||
let conn = state_clone.conn.lock().unwrap().clone();
|
||||
let result = execute_on_trigger(&conn, kind, &table, &script_name)
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
let mut conn = state_clone.conn.lock().unwrap();
|
||||
let result = execute_on_trigger(&mut *conn, kind, &table, &script_name)
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
if let Some(rows_affected) = result.get("rows_affected") {
|
||||
Ok(Dynamic::from(rows_affected.as_i64().unwrap_or(0)))
|
||||
} else {
|
||||
Err("No rows affected".into())
|
||||
}
|
||||
if let Some(rows_affected) = result.get("rows_affected") {
|
||||
Ok(Dynamic::from(rows_affected.as_i64().unwrap_or(0)))
|
||||
} else {
|
||||
Err("No rows affected".into())
|
||||
}
|
||||
},
|
||||
)
|
||||
|
|
@ -44,7 +42,7 @@ pub fn on_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
|||
}
|
||||
|
||||
pub fn execute_on_trigger(
|
||||
conn: &PgConnection,
|
||||
conn: &mut diesel::PgConnection,
|
||||
kind: TriggerKind,
|
||||
table: &str,
|
||||
script_name: &str,
|
||||
|
|
@ -64,7 +62,7 @@ pub fn execute_on_trigger(
|
|||
|
||||
let result = diesel::insert_into(system_automations::table)
|
||||
.values(&new_automation)
|
||||
.execute(&mut conn.clone())
|
||||
.execute(conn)
|
||||
.map_err(|e| {
|
||||
error!("SQL execution error: {}", e);
|
||||
e.to_string()
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
use diesel::prelude::*;
|
||||
use log::{error, info};
|
||||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
use serde_json::{json, Value};
|
||||
use diesel::prelude::*;
|
||||
use std::error::Error;
|
||||
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
|
||||
pub fn set_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
|
@ -22,8 +22,8 @@ pub fn set_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
|||
let filter_str = filter.to_string();
|
||||
let updates_str = updates.to_string();
|
||||
|
||||
let conn = state_clone.conn.lock().unwrap().clone();
|
||||
let result = execute_set(&conn, &table_str, &filter_str, &updates_str)
|
||||
let conn = state_clone.conn.lock().unwrap();
|
||||
let result = execute_set(&*conn, &table_str, &filter_str, &updates_str)
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
if let Some(rows_affected) = result.get("rows_affected") {
|
||||
|
|
@ -37,7 +37,7 @@ pub fn set_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
|||
}
|
||||
|
||||
pub fn execute_set(
|
||||
conn: &PgConnection,
|
||||
conn: &mut diesel::PgConnection,
|
||||
table_str: &str,
|
||||
filter_str: &str,
|
||||
updates_str: &str,
|
||||
|
|
@ -47,7 +47,7 @@ pub fn execute_set(
|
|||
table_str, filter_str, updates_str
|
||||
);
|
||||
|
||||
let (set_clause, update_values) = parse_updates(updates_str).map_err(|e| e.to_string())?;
|
||||
let (set_clause, _update_values) = parse_updates(updates_str).map_err(|e| e.to_string())?;
|
||||
|
||||
let where_clause = parse_filter_for_diesel(filter_str).map_err(|e| e.to_string())?;
|
||||
|
||||
|
|
@ -57,12 +57,10 @@ pub fn execute_set(
|
|||
);
|
||||
info!("Executing query: {}", query);
|
||||
|
||||
let result = diesel::sql_query(&query)
|
||||
.execute(&mut conn.clone())
|
||||
.map_err(|e| {
|
||||
error!("SQL execution error: {}", e);
|
||||
e.to_string()
|
||||
})?;
|
||||
let result = diesel::sql_query(&query).execute(conn).map_err(|e| {
|
||||
error!("SQL execution error: {}", e);
|
||||
e.to_string()
|
||||
})?;
|
||||
|
||||
Ok(json!({
|
||||
"command": "set",
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ pub fn set_schedule_keyword(state: &AppState, user: UserSession, engine: &mut En
|
|||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(["SET_SCHEDULE", "$string$"], true, {
|
||||
.register_custom_syntax(&["SET_SCHEDULE", "$string$"], true, {
|
||||
move |context, inputs| {
|
||||
let cron = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
let script_name = format!("cron_{}.rhai", cron.replace(' ', "_"));
|
||||
|
||||
let conn = state_clone.conn.lock().unwrap().clone();
|
||||
let result = execute_set_schedule(&conn, &cron, &script_name)
|
||||
let conn = state_clone.conn.lock().unwrap();
|
||||
let result = execute_set_schedule(&*conn, &cron, &script_name)
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
if let Some(rows_affected) = result.get("rows_affected") {
|
||||
|
|
@ -32,7 +32,7 @@ pub fn set_schedule_keyword(state: &AppState, user: UserSession, engine: &mut En
|
|||
}
|
||||
|
||||
pub fn execute_set_schedule(
|
||||
conn: &PgConnection,
|
||||
conn: &diesel::PgConnection,
|
||||
cron: &str,
|
||||
script_name: &str,
|
||||
) -> Result<Value, Box<dyn std::error::Error>> {
|
||||
|
|
@ -51,7 +51,7 @@ pub fn execute_set_schedule(
|
|||
|
||||
let result = diesel::insert_into(system_automations::table)
|
||||
.values(&new_automation)
|
||||
.execute(&mut conn.clone())?;
|
||||
.execute(conn)?;
|
||||
|
||||
Ok(json!({
|
||||
"command": "set_schedule",
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ use self::keywords::first::first_keyword;
|
|||
use self::keywords::for_next::for_keyword;
|
||||
use self::keywords::format::format_keyword;
|
||||
use self::keywords::get::get_keyword;
|
||||
#[cfg(feature = "web_automation")]
|
||||
use self::keywords::get_website::get_website_keyword;
|
||||
use self::keywords::hear_talk::{hear_keyword, set_context_keyword, talk_keyword};
|
||||
use self::keywords::last::last_keyword;
|
||||
use self::keywords::llm_keyword::llm_keyword;
|
||||
|
|
@ -20,7 +18,7 @@ use self::keywords::set::set_keyword;
|
|||
use self::keywords::set_schedule::set_schedule_keyword;
|
||||
use self::keywords::wait::wait_keyword;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::AppState;
|
||||
use crate::shared::state::AppState;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine, EvalAltResult};
|
||||
|
||||
|
|
@ -45,7 +43,6 @@ impl ScriptService {
|
|||
last_keyword(&mut engine);
|
||||
format_keyword(&mut engine);
|
||||
llm_keyword(state, user.clone(), &mut engine);
|
||||
get_website_keyword(state, user.clone(), &mut engine);
|
||||
get_keyword(state, user.clone(), &mut engine);
|
||||
set_keyword(state, user.clone(), &mut engine);
|
||||
wait_keyword(state, user.clone(), &mut engine);
|
||||
|
|
@ -161,7 +158,7 @@ impl ScriptService {
|
|||
info!("Processed Script:\n{}", processed_script);
|
||||
match self.engine.compile(&processed_script) {
|
||||
Ok(ast) => Ok(ast),
|
||||
Err(parse_error) => Err(Box::new(EvalAltResult::from(parse_error))),
|
||||
Err(parse_error) => Err(Box::new(parse_error.into())),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ use crate::auth::AuthService;
|
|||
use crate::channels::ChannelAdapter;
|
||||
use crate::llm::LLMProvider;
|
||||
use crate::session::SessionManager;
|
||||
use crate::shared::{BotResponse, UserMessage, UserSession};
|
||||
use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
||||
use crate::tools::ToolManager;
|
||||
|
||||
pub struct BotOrchestrator {
|
||||
|
|
@ -455,7 +455,7 @@ impl BotOrchestrator {
|
|||
async fn websocket_handler(
|
||||
req: HttpRequest,
|
||||
stream: web::Payload,
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
|
|
@ -515,7 +515,7 @@ async fn websocket_handler(
|
|||
|
||||
#[actix_web::get("/api/whatsapp/webhook")]
|
||||
async fn whatsapp_webhook_verify(
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
web::Query(params): web::Query<HashMap<String, String>>,
|
||||
) -> Result<HttpResponse> {
|
||||
let empty = String::new();
|
||||
|
|
@ -531,7 +531,7 @@ async fn whatsapp_webhook_verify(
|
|||
|
||||
#[actix_web::post("/api/whatsapp/webhook")]
|
||||
async fn whatsapp_webhook(
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
payload: web::Json<crate::whatsapp::WhatsAppMessage>,
|
||||
) -> Result<HttpResponse> {
|
||||
match data
|
||||
|
|
@ -556,7 +556,7 @@ async fn whatsapp_webhook(
|
|||
|
||||
#[actix_web::post("/api/voice/start")]
|
||||
async fn voice_start(
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
info: web::Json<serde_json::Value>,
|
||||
) -> Result<HttpResponse> {
|
||||
let session_id = info
|
||||
|
|
@ -585,7 +585,7 @@ async fn voice_start(
|
|||
|
||||
#[actix_web::post("/api/voice/stop")]
|
||||
async fn voice_stop(
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
info: web::Json<serde_json::Value>,
|
||||
) -> Result<HttpResponse> {
|
||||
let session_id = info
|
||||
|
|
@ -603,7 +603,7 @@ async fn voice_stop(
|
|||
}
|
||||
|
||||
#[actix_web::post("/api/sessions")]
|
||||
async fn create_session(_data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> {
|
||||
async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
||||
let session_id = Uuid::new_v4();
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||
"session_id": session_id,
|
||||
|
|
@ -613,7 +613,7 @@ async fn create_session(_data: web::Data<crate::shared::AppState>) -> Result<Htt
|
|||
}
|
||||
|
||||
#[actix_web::get("/api/sessions")]
|
||||
async fn get_sessions(data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> {
|
||||
async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
||||
match data.orchestrator.get_user_sessions(user_id).await {
|
||||
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
||||
|
|
@ -626,7 +626,7 @@ async fn get_sessions(data: web::Data<crate::shared::AppState>) -> Result<HttpRe
|
|||
|
||||
#[actix_web::get("/api/sessions/{session_id}")]
|
||||
async fn get_session_history(
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
path: web::Path<String>,
|
||||
) -> Result<HttpResponse> {
|
||||
let session_id = path.into_inner();
|
||||
|
|
@ -650,7 +650,7 @@ async fn get_session_history(
|
|||
|
||||
#[actix_web::post("/api/set_mode")]
|
||||
async fn set_mode_handler(
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
info: web::Json<HashMap<String, String>>,
|
||||
) -> Result<HttpResponse> {
|
||||
let default_user = "default_user".to_string();
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use std::collections::HashMap;
|
|||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
use crate::shared::BotResponse;
|
||||
use crate::shared::models::BotResponse;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ChannelAdapter: Send + Sync {
|
||||
|
|
|
|||
|
|
@ -2,13 +2,14 @@ use std::env;
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct AppConfig {
|
||||
pub minio: MinioConfig,
|
||||
pub minio: DriveConfig,
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub database_custom: DatabaseConfig,
|
||||
pub email: EmailConfig,
|
||||
pub ai: AIConfig,
|
||||
pub site_path: String,
|
||||
pub s3_bucket: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -21,7 +22,7 @@ pub struct DatabaseConfig {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MinioConfig {
|
||||
pub struct DriveConfig {
|
||||
pub server: String,
|
||||
pub access_key: String,
|
||||
pub secret_key: String,
|
||||
|
|
@ -98,7 +99,7 @@ impl AppConfig {
|
|||
database: env::var("CUSTOM_DATABASE").unwrap_or_else(|_| "db".to_string()),
|
||||
};
|
||||
|
||||
let minio = MinioConfig {
|
||||
let minio = DriveConfig {
|
||||
server: env::var("DRIVE_SERVER").unwrap_or_else(|_| "localhost:9000".to_string()),
|
||||
access_key: env::var("DRIVE_ACCESSKEY").unwrap_or_else(|_| "minioadmin".to_string()),
|
||||
secret_key: env::var("DRIVE_SECRET").unwrap_or_else(|_| "minioadmin".to_string()),
|
||||
|
|
@ -124,7 +125,8 @@ impl AppConfig {
|
|||
instance: env::var("AI_INSTANCE").unwrap_or_else(|_| "gpt-4".to_string()),
|
||||
key: env::var("AI_KEY").unwrap_or_else(|_| "key".to_string()),
|
||||
version: env::var("AI_VERSION").unwrap_or_else(|_| "2023-12-01-preview".to_string()),
|
||||
endpoint: env::var("AI_ENDPOINT").unwrap_or_else(|_| "https://api.openai.com".to_string()),
|
||||
endpoint: env::var("AI_ENDPOINT")
|
||||
.unwrap_or_else(|_| "https://api.openai.com".to_string()),
|
||||
};
|
||||
|
||||
AppConfig {
|
||||
|
|
@ -140,6 +142,8 @@ impl AppConfig {
|
|||
database_custom,
|
||||
email,
|
||||
ai,
|
||||
s3_bucket: env::var("DRIVE_BUCKET").unwrap_or_else(|_| "default".to_string()),
|
||||
|
||||
site_path: env::var("SITES_ROOT").unwrap_or_else(|_| "./sites".to_string()),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use async_trait::async_trait;
|
|||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::shared::SearchResult;
|
||||
use crate::shared::models::SearchResult;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ContextStore: Send + Sync {
|
||||
|
|
@ -21,11 +21,11 @@ pub trait ContextStore: Send + Sync {
|
|||
}
|
||||
|
||||
pub struct QdrantContextStore {
|
||||
vector_store: Arc<qdrant_client::client::QdrantClient>,
|
||||
vector_store: Arc<qdrant_client::Qdrant>,
|
||||
}
|
||||
|
||||
impl QdrantContextStore {
|
||||
pub fn new(vector_store: qdrant_client::client::QdrantClient) -> Self {
|
||||
pub fn new(vector_store: qdrant_client::Qdrant) -> Self {
|
||||
Self {
|
||||
vector_store: Arc::new(vector_store),
|
||||
}
|
||||
|
|
|
|||
202
src/file/mod.rs
202
src/file/mod.rs
|
|
@ -1,42 +1,13 @@
|
|||
use actix_web::web;
|
||||
use actix_multipart::Multipart;
|
||||
use actix_web::web;
|
||||
use actix_web::{post, HttpResponse};
|
||||
use aws_sdk_s3::{Client, Error as S3Error};
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
use tokio_stream::StreamExt;
|
||||
use aws_sdk_s3 as s3;
|
||||
use aws_sdk_s3::types::ByteStream;
|
||||
use std::str::FromStr;
|
||||
use tokio_stream::StreamExt as TokioStreamExt;
|
||||
|
||||
use crate::config::AppConfig;
|
||||
use crate::shared::state::AppState;
|
||||
|
||||
pub async fn init_s3(config: &AppConfig) -> Result<s3::Client, Box<dyn std::error::Error>> {
|
||||
let endpoint_url = if config.minio.use_ssl {
|
||||
format!("https://{}", config.minio.server)
|
||||
} else {
|
||||
format!("http://{}", config.minio.server)
|
||||
};
|
||||
|
||||
let config = aws_config::from_env()
|
||||
.endpoint_url(&endpoint_url)
|
||||
.region(aws_sdk_s3::config::Region::new("us-east-1"))
|
||||
.credentials_provider(
|
||||
s3::config::Credentials::new(
|
||||
&config.minio.access_key,
|
||||
&config.minio.secret_key,
|
||||
None,
|
||||
None,
|
||||
"minio",
|
||||
)
|
||||
)
|
||||
.load()
|
||||
.await;
|
||||
|
||||
let client = s3::Client::new(&config);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
#[post("/files/upload/{folder_path}")]
|
||||
pub async fn upload_file(
|
||||
folder_path: web::Path<String>,
|
||||
|
|
@ -45,12 +16,14 @@ pub async fn upload_file(
|
|||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let folder_path = folder_path.into_inner();
|
||||
|
||||
// Create a temporary file that will hold the uploaded data
|
||||
let mut temp_file = NamedTempFile::new().map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!("Failed to create temp file: {}", e))
|
||||
})?;
|
||||
|
||||
let mut file_name: Option<String> = None;
|
||||
|
||||
// Process multipart form data
|
||||
while let Some(mut field) = payload.try_next().await? {
|
||||
if let Some(disposition) = field.content_disposition() {
|
||||
if let Some(name) = disposition.get_filename() {
|
||||
|
|
@ -58,6 +31,7 @@ pub async fn upload_file(
|
|||
}
|
||||
}
|
||||
|
||||
// Write each chunk of the field to the temporary file
|
||||
while let Some(chunk) = field.try_next().await? {
|
||||
temp_file.write_all(&chunk).map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!(
|
||||
|
|
@ -68,84 +42,106 @@ pub async fn upload_file(
|
|||
}
|
||||
}
|
||||
|
||||
// Use a fallback name if the client didn't supply one
|
||||
let file_name = file_name.unwrap_or_else(|| "unnamed_file".to_string());
|
||||
let object_name = format!("{}/{}", folder_path, file_name);
|
||||
|
||||
let client = state.s3_client.as_ref().ok_or_else(|| {
|
||||
actix_web::error::ErrorInternalServerError("S3 client not initialized")
|
||||
})?;
|
||||
// Convert the NamedTempFile into a TempPath so we can get a stable path
|
||||
let temp_file_path = temp_file.into_temp_path();
|
||||
|
||||
let bucket_name = state.config.as_ref().unwrap().minio.bucket.clone();
|
||||
// Retrieve the bucket name from configuration, handling the case where it is missing
|
||||
let bucket_name = match &state.config {
|
||||
Some(cfg) => cfg.s3_bucket.clone(),
|
||||
None => {
|
||||
// Clean up the temp file before returning the error
|
||||
let _ = std::fs::remove_file(&temp_file_path);
|
||||
return Err(actix_web::error::ErrorInternalServerError(
|
||||
"S3 bucket configuration is missing",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let body = ByteStream::from_path(temp_file.path()).await.map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!("Failed to read file: {}", e))
|
||||
})?;
|
||||
// Build the S3 object key (folder + filename)
|
||||
let s3_key = format!("{}/{}", folder_path, file_name);
|
||||
|
||||
client
|
||||
.put_object()
|
||||
.bucket(&bucket_name)
|
||||
.key(&object_name)
|
||||
.body(body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!(
|
||||
// Perform the upload
|
||||
let s3_client = get_s3_client(&state).await;
|
||||
match upload_to_s3(&s3_client, &bucket_name, &s3_key, &temp_file_path).await {
|
||||
Ok(_) => {
|
||||
// Remove the temporary file now that the upload succeeded
|
||||
let _ = std::fs::remove_file(&temp_file_path);
|
||||
Ok(HttpResponse::Ok().body(format!(
|
||||
"Uploaded file '{}' to folder '{}' in S3 bucket '{}'",
|
||||
file_name, folder_path, bucket_name
|
||||
)))
|
||||
}
|
||||
Err(e) => {
|
||||
// Ensure the temporary file is cleaned up even on failure
|
||||
let _ = std::fs::remove_file(&temp_file_path);
|
||||
Err(actix_web::error::ErrorInternalServerError(format!(
|
||||
"Failed to upload file to S3: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
temp_file.close().map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!("Failed to close temp file: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(HttpResponse::Ok().body(format!(
|
||||
"Uploaded file '{}' to folder '{}'",
|
||||
file_name, folder_path
|
||||
)))
|
||||
}
|
||||
|
||||
#[post("/files/list/{folder_path}")]
|
||||
pub async fn list_file(
|
||||
folder_path: web::Path<String>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let folder_path = folder_path.into_inner();
|
||||
|
||||
let client = state.s3_client.as_ref().ok_or_else(|| {
|
||||
actix_web::error::ErrorInternalServerError("S3 client not initialized")
|
||||
})?;
|
||||
|
||||
let bucket_name = "file-upload-rust-bucket";
|
||||
|
||||
let mut objects = client
|
||||
.list_objects_v2()
|
||||
.bucket(bucket_name)
|
||||
.prefix(&folder_path)
|
||||
.into_paginator()
|
||||
.send();
|
||||
|
||||
let mut file_list = Vec::new();
|
||||
|
||||
while let Some(result) = objects.next().await {
|
||||
match result {
|
||||
Ok(output) => {
|
||||
if let Some(contents) = output.contents {
|
||||
for item in contents {
|
||||
if let Some(key) = item.key {
|
||||
file_list.push(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(actix_web::error::ErrorInternalServerError(format!(
|
||||
"Failed to list files in S3: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
Ok(HttpResponse::Ok().json(file_list))
|
||||
}
|
||||
|
||||
// Helper function to get S3 client
|
||||
async fn get_s3_client(state: &AppState) -> Client {
|
||||
if let Some(cfg) = &state.config.as_ref().and_then(|c| Some(&c.minio)) {
|
||||
// Build static credentials from the Drive configuration.
|
||||
let credentials = aws_sdk_s3::config::Credentials::new(
|
||||
cfg.access_key.clone(),
|
||||
cfg.secret_key.clone(),
|
||||
None,
|
||||
None,
|
||||
"static",
|
||||
);
|
||||
|
||||
// Construct the endpoint URL, respecting the SSL flag.
|
||||
let scheme = if cfg.use_ssl { "https" } else { "http" };
|
||||
let endpoint = format!("{}://{}", scheme, cfg.server);
|
||||
|
||||
// MinIO requires path‑style addressing.
|
||||
let s3_config = aws_sdk_s3::config::Builder::new()
|
||||
.region(aws_sdk_s3::config::Region::new("us-east-1"))
|
||||
.endpoint_url(endpoint)
|
||||
.credentials_provider(credentials)
|
||||
.force_path_style(true)
|
||||
.build();
|
||||
|
||||
Client::from_conf(s3_config)
|
||||
} else {
|
||||
panic!("MinIO configuration is missing in application state");
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to upload file to S3
|
||||
async fn upload_to_s3(
|
||||
client: &Client,
|
||||
bucket: &str,
|
||||
key: &str,
|
||||
file_path: &std::path::Path,
|
||||
) -> Result<(), S3Error> {
|
||||
// Convert the file at `file_path` into a `ByteStream`. Any I/O error is
|
||||
// turned into a construction‑failure `SdkError` so that the function’s
|
||||
// `Result` type (`Result<(), S3Error>`) stays consistent.
|
||||
let body = aws_sdk_s3::primitives::ByteStream::from_path(file_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
aws_sdk_s3::error::SdkError::<
|
||||
aws_sdk_s3::operation::put_object::PutObjectError,
|
||||
aws_sdk_s3::operation::put_object::PutObjectOutput,
|
||||
>::construction_failure(e)
|
||||
})?;
|
||||
|
||||
// Perform the actual upload to S3.
|
||||
client
|
||||
.put_object()
|
||||
.bucket(bucket)
|
||||
.key(key)
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ use dotenvy::dotenv;
|
|||
use log::{error, info};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AzureOpenAIConfig {
|
||||
|
|
@ -60,12 +59,14 @@ impl AzureOpenAIClient {
|
|||
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
dotenv().ok();
|
||||
|
||||
let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT")
|
||||
.map_err(|_| "AZURE_OPENAI_ENDPOINT not set")?;
|
||||
let api_key = std::env::var("AZURE_OPENAI_API_KEY")
|
||||
.map_err(|_| "AZURE_OPENAI_API_KEY not set")?;
|
||||
let api_version = std::env::var("AZURE_OPENAI_API_VERSION").unwrap_or_else(|_| "2023-12-01-preview".to_string());
|
||||
let deployment = std::env::var("AZURE_OPENAI_DEPLOYMENT").unwrap_or_else(|_| "gpt-35-turbo".to_string());
|
||||
let endpoint =
|
||||
std::env::var("AZURE_OPENAI_ENDPOINT").map_err(|_| "AZURE_OPENAI_ENDPOINT not set")?;
|
||||
let api_key =
|
||||
std::env::var("AZURE_OPENAI_API_KEY").map_err(|_| "AZURE_OPENAI_API_KEY not set")?;
|
||||
let api_version = std::env::var("AZURE_OPENAI_API_VERSION")
|
||||
.unwrap_or_else(|_| "2023-12-01-preview".to_string());
|
||||
let deployment =
|
||||
std::env::var("AZURE_OPENAI_DEPLOYMENT").unwrap_or_else(|_| "gpt-35-turbo".to_string());
|
||||
|
||||
let config = AzureOpenAIConfig {
|
||||
endpoint,
|
||||
|
|
@ -121,10 +122,7 @@ impl AzureOpenAIClient {
|
|||
Ok(completion_response)
|
||||
}
|
||||
|
||||
pub async fn simple_chat(
|
||||
&self,
|
||||
prompt: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error>> {
|
||||
pub async fn simple_chat(&self, prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let messages = vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use dotenvy::dotenv;
|
||||
use log::{error, info};
|
||||
use actix_web::{web, HttpResponse, Result};
|
||||
use dotenvy::dotenv;
|
||||
use log::info;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
|
|
|||
|
|
@ -1,55 +1,406 @@
|
|||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||||
use dotenvy::dotenv;
|
||||
use log::{error, info, warn};
|
||||
use actix_web::{web, HttpResponse, Result};
|
||||
use log::{error, info};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::{Command, Stdio};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use std::env;
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LocalChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub temperature: Option<f32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
// OpenAI-compatible request/response structures
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
stream: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Choice {
|
||||
message: ChatMessage,
|
||||
finish_reason: String,
|
||||
}
|
||||
|
||||
// Llama.cpp server request/response structures
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LlamaCppRequest {
|
||||
prompt: String,
|
||||
n_predict: Option<i32>,
|
||||
temperature: Option<f32>,
|
||||
top_k: Option<i32>,
|
||||
top_p: Option<f32>,
|
||||
stream: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LlamaCppResponse {
|
||||
content: String,
|
||||
stop: bool,
|
||||
generation_settings: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
|
||||
{
|
||||
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
|
||||
|
||||
if llm_local.to_lowercase() != "true" {
|
||||
info!("ℹ️ LLM_LOCAL is not enabled, skipping local server startup");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Get configuration from environment variables
|
||||
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||
let embedding_url =
|
||||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||
let llama_cpp_path = env::var("LLM_CPP_PATH").unwrap_or_else(|_| "~/llama.cpp".to_string());
|
||||
let llm_model_path = env::var("LLM_MODEL_PATH").unwrap_or_else(|_| "".to_string());
|
||||
let embedding_model_path = env::var("EMBEDDING_MODEL_PATH").unwrap_or_else(|_| "".to_string());
|
||||
|
||||
info!("🚀 Starting local llama.cpp servers...");
|
||||
info!("📋 Configuration:");
|
||||
info!(" LLM URL: {}", llm_url);
|
||||
info!(" Embedding URL: {}", embedding_url);
|
||||
info!(" LLM Model: {}", llm_model_path);
|
||||
info!(" Embedding Model: {}", embedding_model_path);
|
||||
|
||||
// Check if servers are already running
|
||||
let llm_running = is_server_running(&llm_url).await;
|
||||
let embedding_running = is_server_running(&embedding_url).await;
|
||||
|
||||
if llm_running && embedding_running {
|
||||
info!("✅ Both LLM and Embedding servers are already running");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Start servers that aren't running
|
||||
let mut tasks = vec![];
|
||||
|
||||
if !llm_running && !llm_model_path.is_empty() {
|
||||
info!("🔄 Starting LLM server...");
|
||||
tasks.push(tokio::spawn(start_llm_server(
|
||||
llama_cpp_path.clone(),
|
||||
llm_model_path.clone(),
|
||||
llm_url.clone(),
|
||||
)));
|
||||
} else if llm_model_path.is_empty() {
|
||||
info!("⚠️ LLM_MODEL_PATH not set, skipping LLM server");
|
||||
}
|
||||
|
||||
if !embedding_running && !embedding_model_path.is_empty() {
|
||||
info!("🔄 Starting Embedding server...");
|
||||
tasks.push(tokio::spawn(start_embedding_server(
|
||||
llama_cpp_path.clone(),
|
||||
embedding_model_path.clone(),
|
||||
embedding_url.clone(),
|
||||
)));
|
||||
} else if embedding_model_path.is_empty() {
|
||||
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
|
||||
}
|
||||
|
||||
// Wait for all server startup tasks
|
||||
for task in tasks {
|
||||
task.await??;
|
||||
}
|
||||
|
||||
// Wait for servers to be ready with verbose logging
|
||||
info!("⏳ Waiting for servers to become ready...");
|
||||
|
||||
let mut llm_ready = llm_running || llm_model_path.is_empty();
|
||||
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
|
||||
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 60; // 2 minutes total
|
||||
|
||||
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
info!(
|
||||
"🔍 Checking server health (attempt {}/{})...",
|
||||
attempts + 1,
|
||||
max_attempts
|
||||
);
|
||||
|
||||
if !llm_ready && !llm_model_path.is_empty() {
|
||||
if is_server_running(&llm_url).await {
|
||||
info!(" ✅ LLM server ready at {}", llm_url);
|
||||
llm_ready = true;
|
||||
} else {
|
||||
info!(" ❌ LLM server not ready yet");
|
||||
}
|
||||
}
|
||||
|
||||
if !embedding_ready && !embedding_model_path.is_empty() {
|
||||
if is_server_running(&embedding_url).await {
|
||||
info!(" ✅ Embedding server ready at {}", embedding_url);
|
||||
embedding_ready = true;
|
||||
} else {
|
||||
info!(" ❌ Embedding server not ready yet");
|
||||
}
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
|
||||
if attempts % 10 == 0 {
|
||||
info!(
|
||||
"⏰ Still waiting for servers... (attempt {}/{})",
|
||||
attempts, max_attempts
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if llm_ready && embedding_ready {
|
||||
info!("🎉 All llama.cpp servers are ready and responding!");
|
||||
Ok(())
|
||||
} else {
|
||||
let mut error_msg = "❌ Servers failed to start within timeout:".to_string();
|
||||
if !llm_ready && !llm_model_path.is_empty() {
|
||||
error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
|
||||
}
|
||||
if !embedding_ready && !embedding_model_path.is_empty() {
|
||||
error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
|
||||
}
|
||||
Err(error_msg.into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_llm_server(
|
||||
llama_cpp_path: String,
|
||||
model_path: String,
|
||||
url: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let port = url.split(':').last().unwrap_or("8081");
|
||||
|
||||
std::env::set_var("OMP_NUM_THREADS", "20");
|
||||
std::env::set_var("OMP_PLACES", "cores");
|
||||
std::env::set_var("OMP_PROC_BIND", "close");
|
||||
|
||||
// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
|
||||
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(format!(
|
||||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
|
||||
llama_cpp_path, model_path, port
|
||||
));
|
||||
|
||||
cmd.spawn()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_embedding_server(
|
||||
llama_cpp_path: String,
|
||||
model_path: String,
|
||||
url: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let port = url.split(':').last().unwrap_or("8082");
|
||||
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(format!(
|
||||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 &",
|
||||
llama_cpp_path, model_path, port
|
||||
));
|
||||
|
||||
cmd.spawn()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn is_server_running(url: &str) -> bool {
|
||||
let client = reqwest::Client::new();
|
||||
match client.get(&format!("{}/health", url)).send().await {
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert OpenAI chat messages to a single prompt
|
||||
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
for message in messages {
|
||||
match message.role.as_str() {
|
||||
"system" => {
|
||||
prompt.push_str(&format!("System: {}\n\n", message.content));
|
||||
}
|
||||
"user" => {
|
||||
prompt.push_str(&format!("User: {}\n\n", message.content));
|
||||
}
|
||||
"assistant" => {
|
||||
prompt.push_str(&format!("Assistant: {}\n\n", message.content));
|
||||
}
|
||||
_ => {
|
||||
prompt.push_str(&format!("{}: {}\n\n", message.role, message.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str("Assistant: ");
|
||||
prompt
|
||||
}
|
||||
|
||||
// Proxy endpoint
|
||||
#[post("/local/v1/chat/completions")]
|
||||
pub async fn chat_completions_local(
|
||||
req_body: web::Json<ChatCompletionRequest>,
|
||||
_req: HttpRequest,
|
||||
) -> Result<HttpResponse> {
|
||||
dotenv().ok().unwrap();
|
||||
|
||||
// Get llama.cpp server URL
|
||||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||
|
||||
// Convert OpenAI format to llama.cpp format
|
||||
let prompt = messages_to_prompt(&req_body.messages);
|
||||
|
||||
let llama_request = LlamaCppRequest {
|
||||
prompt,
|
||||
n_predict: Some(500), // Adjust as needed
|
||||
temperature: Some(0.7),
|
||||
top_k: Some(40),
|
||||
top_p: Some(0.9),
|
||||
stream: req_body.stream,
|
||||
};
|
||||
|
||||
// Send request to llama.cpp server
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(120)) // 2 minute timeout
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
error!("Error creating HTTP client: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||||
})?;
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}/completion", llama_url))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&llama_request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Error calling llama.cpp server: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server")
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if status.is_success() {
|
||||
let llama_response: LlamaCppResponse = response.json().await.map_err(|e| {
|
||||
error!("Error parsing llama.cpp response: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
|
||||
})?;
|
||||
|
||||
// Convert llama.cpp response to OpenAI format
|
||||
let openai_response = ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
model: req_body.model.clone(),
|
||||
choices: vec![Choice {
|
||||
message: ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: llama_response.content.trim().to_string(),
|
||||
},
|
||||
finish_reason: if llama_response.stop {
|
||||
"stop".to_string()
|
||||
} else {
|
||||
"length".to_string()
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
Ok(HttpResponse::Ok().json(openai_response))
|
||||
} else {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
error!("Llama.cpp server error ({}): {}", status, error_text);
|
||||
|
||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
Ok(HttpResponse::build(actix_status).json(serde_json::json!({
|
||||
"error": {
|
||||
"message": error_text,
|
||||
"type": "server_error"
|
||||
}
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI Embedding Request - Modified to handle both string and array inputs
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
#[serde(deserialize_with = "deserialize_input")]
|
||||
pub input: Vec<String>,
|
||||
pub model: String,
|
||||
pub input: String,
|
||||
#[serde(default)]
|
||||
pub _encoding_format: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct LocalChatResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ChatChoice {
|
||||
pub index: u32,
|
||||
pub message: ChatMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
// Custom deserializer to handle both string and array inputs
|
||||
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::{self, Visitor};
|
||||
use std::fmt;
|
||||
|
||||
struct InputVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for InputVisitor {
|
||||
type Value = Vec<String>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string or an array of strings")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(vec![value.to_string()])
|
||||
}
|
||||
|
||||
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(vec![value])
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: de::SeqAccess<'de>,
|
||||
{
|
||||
let mut vec = Vec::new();
|
||||
while let Some(value) = seq.next_element::<String>()? {
|
||||
vec.push(value);
|
||||
}
|
||||
Ok(vec)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(InputVisitor)
|
||||
}
|
||||
|
||||
// OpenAI Embedding Response
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
pub object: String,
|
||||
|
|
@ -62,74 +413,165 @@ pub struct EmbeddingResponse {
|
|||
pub struct EmbeddingData {
|
||||
pub object: String,
|
||||
pub embedding: Vec<f32>,
|
||||
pub index: u32,
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error>> {
|
||||
info!("Checking if local LLM servers are running...");
|
||||
|
||||
// For now, just log that we would start servers
|
||||
info!("Local LLM servers would be started here");
|
||||
|
||||
Ok(())
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
pub async fn chat_completions_local(
|
||||
payload: web::Json<LocalChatRequest>,
|
||||
) -> Result<HttpResponse> {
|
||||
dotenv().ok();
|
||||
|
||||
info!("Received local chat request for model: {}", payload.model);
|
||||
|
||||
// Mock response for local LLM
|
||||
let response = LocalChatResponse {
|
||||
id: "local-chat-123".to_string(),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
model: payload.model.clone(),
|
||||
choices: vec![ChatChoice {
|
||||
index: 0,
|
||||
message: ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "This is a mock response from the local LLM. In a real implementation, this would connect to a local model like Llama or Mistral.".to_string(),
|
||||
},
|
||||
finish_reason: Some("stop".to_string()),
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens: 15,
|
||||
completion_tokens: 25,
|
||||
total_tokens: 40,
|
||||
},
|
||||
};
|
||||
|
||||
Ok(HttpResponse::Ok().json(response))
|
||||
// Llama.cpp Embedding Request
|
||||
#[derive(Debug, Serialize)]
|
||||
struct LlamaCppEmbeddingRequest {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
// FIXED: Handle the stupid nested array format
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LlamaCppEmbeddingResponseItem {
|
||||
pub index: usize,
|
||||
pub embedding: Vec<Vec<f32>>, // This is the up part - embedding is an array of arrays
|
||||
}
|
||||
|
||||
// Proxy endpoint for embeddings
|
||||
#[post("/v1/embeddings")]
|
||||
pub async fn embeddings_local(
|
||||
payload: web::Json<EmbeddingRequest>,
|
||||
req_body: web::Json<EmbeddingRequest>,
|
||||
_req: HttpRequest,
|
||||
) -> Result<HttpResponse> {
|
||||
dotenv().ok();
|
||||
|
||||
info!("Received local embedding request for model: {}", payload.model);
|
||||
// Get llama.cpp server URL
|
||||
let llama_url =
|
||||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||
|
||||
// Mock embedding response
|
||||
let response = EmbeddingResponse {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
error!("Error creating HTTP client: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||||
})?;
|
||||
|
||||
// Process each input text and get embeddings
|
||||
let mut embeddings_data = Vec::new();
|
||||
let mut total_tokens = 0;
|
||||
|
||||
for (index, input_text) in req_body.input.iter().enumerate() {
|
||||
let llama_request = LlamaCppEmbeddingRequest {
|
||||
content: input_text.clone(),
|
||||
};
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}/embedding", llama_url))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&llama_request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Error calling llama.cpp server for embedding: {}", e);
|
||||
actix_web::error::ErrorInternalServerError(
|
||||
"Failed to call llama.cpp server for embedding",
|
||||
)
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if status.is_success() {
|
||||
// First, get the raw response text for debugging
|
||||
let raw_response = response.text().await.map_err(|e| {
|
||||
error!("Error reading response text: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to read response")
|
||||
})?;
|
||||
|
||||
// Parse the response as a vector of items with nested arrays
|
||||
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
|
||||
serde_json::from_str(&raw_response).map_err(|e| {
|
||||
error!("Error parsing llama.cpp embedding response: {}", e);
|
||||
error!("Raw response: {}", raw_response);
|
||||
actix_web::error::ErrorInternalServerError(
|
||||
"Failed to parse llama.cpp embedding response",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Extract the embedding from the nested array bullshit
|
||||
if let Some(item) = llama_response.get(0) {
|
||||
// The embedding field contains Vec<Vec<f32>>, so we need to flatten it
|
||||
// If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
|
||||
let flattened_embedding = if !item.embedding.is_empty() {
|
||||
item.embedding[0].clone() // Take the first (and probably only) inner array
|
||||
} else {
|
||||
vec![] // Empty if no embedding data
|
||||
};
|
||||
|
||||
// Estimate token count
|
||||
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
|
||||
total_tokens += estimated_tokens;
|
||||
|
||||
embeddings_data.push(EmbeddingData {
|
||||
object: "embedding".to_string(),
|
||||
embedding: flattened_embedding,
|
||||
index,
|
||||
});
|
||||
} else {
|
||||
error!("No embedding data returned for input: {}", input_text);
|
||||
return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("No embedding data returned for input {}", index),
|
||||
"type": "server_error"
|
||||
}
|
||||
})));
|
||||
}
|
||||
} else {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
error!("Llama.cpp server error ({}): {}", status, error_text);
|
||||
|
||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
return Ok(HttpResponse::build(actix_status).json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("Failed to get embedding for input {}: {}", index, error_text),
|
||||
"type": "server_error"
|
||||
}
|
||||
})));
|
||||
}
|
||||
}
|
||||
|
||||
// Build OpenAI-compatible response
|
||||
let openai_response = EmbeddingResponse {
|
||||
object: "list".to_string(),
|
||||
data: vec![EmbeddingData {
|
||||
object: "embedding".to_string(),
|
||||
embedding: vec![0.1; 768], // Mock embedding vector
|
||||
index: 0,
|
||||
}],
|
||||
model: payload.model.clone(),
|
||||
data: embeddings_data,
|
||||
model: req_body.model.clone(),
|
||||
usage: Usage {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 10,
|
||||
prompt_tokens: total_tokens,
|
||||
total_tokens,
|
||||
},
|
||||
};
|
||||
|
||||
Ok(HttpResponse::Ok().json(response))
|
||||
Ok(HttpResponse::Ok().json(openai_response))
|
||||
}
|
||||
|
||||
// Health check endpoint
|
||||
#[actix_web::get("/health")]
|
||||
pub async fn health() -> Result<HttpResponse> {
|
||||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||
|
||||
if is_server_running(&llama_url).await {
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"llama_server": "running"
|
||||
})))
|
||||
} else {
|
||||
Ok(HttpResponse::ServiceUnavailable().json(serde_json::json!({
|
||||
"status": "unhealthy",
|
||||
"llama_server": "not running"
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
149
src/main.rs
149
src/main.rs
|
|
@ -5,7 +5,7 @@ use actix_web::middleware::Logger;
|
|||
use actix_web::{web, App, HttpServer};
|
||||
use dotenvy::dotenv;
|
||||
use log::info;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
mod auth;
|
||||
mod automation;
|
||||
|
|
@ -23,11 +23,8 @@ mod org;
|
|||
mod session;
|
||||
mod shared;
|
||||
mod tools;
|
||||
#[cfg(feature = "web_automation")]
|
||||
mod web_automation;
|
||||
mod whatsapp;
|
||||
|
||||
use crate::automation::AutomationService;
|
||||
use crate::bot::{
|
||||
create_session, get_session_history, get_sessions, index, set_mode_handler, static_files,
|
||||
voice_start, voice_stop, websocket_handler, whatsapp_webhook, whatsapp_webhook_verify,
|
||||
|
|
@ -38,12 +35,11 @@ use crate::config::AppConfig;
|
|||
use crate::email::{
|
||||
get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
|
||||
};
|
||||
use crate::file::{list_file, upload_file};
|
||||
use crate::llm_legacy::llm_generic::generic_chat_completions;
|
||||
use crate::file::upload_file;
|
||||
use crate::llm_legacy::llm_local::{
|
||||
chat_completions_local, embeddings_local, ensure_llama_servers_running,
|
||||
};
|
||||
use crate::shared::AppState;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::whatsapp::WhatsAppAdapter;
|
||||
|
||||
#[actix_web::main]
|
||||
|
|
@ -53,9 +49,12 @@ async fn main() -> std::io::Result<()> {
|
|||
|
||||
info!("Starting General Bots 6.0...");
|
||||
|
||||
let config = AppConfig::from_env();
|
||||
// Load configuration and wrap it in an Arc for safe sharing across threads/closures
|
||||
let cfg = AppConfig::from_env();
|
||||
let config = std::sync::Arc::new(cfg.clone());
|
||||
|
||||
let db_pool = match diesel::PgConnection::establish(&config.database_url()) {
|
||||
// Main database connection pool
|
||||
let db_pool = match diesel::Connection::establish(&cfg.database_url()) {
|
||||
Ok(conn) => {
|
||||
info!("Connected to main database");
|
||||
Arc::new(Mutex::new(conn))
|
||||
|
|
@ -69,6 +68,37 @@ async fn main() -> std::io::Result<()> {
|
|||
}
|
||||
};
|
||||
|
||||
// Build custom database URL from config members
|
||||
let custom_db_url = format!(
|
||||
"postgres://{}:{}@{}:{}/{}",
|
||||
cfg.database_custom.username,
|
||||
cfg.database_custom.password,
|
||||
cfg.database_custom.server,
|
||||
cfg.database_custom.port,
|
||||
cfg.database_custom.database
|
||||
);
|
||||
|
||||
// Custom database connection pool
|
||||
let db_custom_pool = match diesel::Connection::establish(&custom_db_url) {
|
||||
Ok(conn) => {
|
||||
info!("Connected to custom database using constructed URL");
|
||||
Arc::new(Mutex::new(conn))
|
||||
}
|
||||
Err(e2) => {
|
||||
log::error!("Failed to connect to custom database: {}", e2);
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionRefused,
|
||||
format!("Custom Database connection failed: {}", e2),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure local LLM servers are running
|
||||
ensure_llama_servers_running()
|
||||
.await
|
||||
.expect("Failed to initialize LLM local server.");
|
||||
|
||||
// Optional Redis client
|
||||
let redis_client = match redis::Client::open("redis://127.0.0.1/") {
|
||||
Ok(client) => {
|
||||
info!("Connected to Redis");
|
||||
|
|
@ -80,27 +110,10 @@ async fn main() -> std::io::Result<()> {
|
|||
}
|
||||
};
|
||||
|
||||
let browser_pool = Arc::new(web_automation::BrowserPool::new(
|
||||
"chrome".to_string(),
|
||||
2,
|
||||
"headless".to_string(),
|
||||
));
|
||||
|
||||
let auth_service = auth::AuthService::new(
|
||||
diesel::PgConnection::establish(&config.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
);
|
||||
let session_manager = session::SessionManager::new(
|
||||
diesel::PgConnection::establish(&config.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
);
|
||||
|
||||
let tool_manager = tools::ToolManager::new();
|
||||
// Shared utilities
|
||||
let tool_manager = Arc::new(tools::ToolManager::new());
|
||||
let llm_provider = Arc::new(llm::MockLLMProvider::new());
|
||||
|
||||
let orchestrator =
|
||||
bot::BotOrchestrator::new(session_manager, tool_manager, llm_provider, auth_service);
|
||||
|
||||
let web_adapter = Arc::new(WebChannelAdapter::new());
|
||||
let voice_adapter = Arc::new(VoiceAdapter::new(
|
||||
"https://livekit.example.com".to_string(),
|
||||
|
|
@ -116,18 +129,30 @@ async fn main() -> std::io::Result<()> {
|
|||
|
||||
let tool_api = Arc::new(tools::ToolApi::new());
|
||||
|
||||
let app_state = AppState {
|
||||
// Prepare the base AppState (without the orchestrator, which requires per‑worker construction)
|
||||
let base_app_state = AppState {
|
||||
s3_client: None,
|
||||
config: Some(config.clone()),
|
||||
conn: db_pool,
|
||||
config: Some(cfg.clone()),
|
||||
conn: db_pool.clone(),
|
||||
custom_conn: db_custom_pool.clone(),
|
||||
redis_client: redis_client.clone(),
|
||||
browser_pool: browser_pool.clone(),
|
||||
orchestrator: Arc::new(orchestrator),
|
||||
orchestrator: Arc::new(bot::BotOrchestrator::new(
|
||||
// Temporary placeholder – will be replaced per worker
|
||||
session::SessionManager::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
),
|
||||
(*tool_manager).clone(),
|
||||
llm_provider.clone(),
|
||||
auth::AuthService::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
),
|
||||
)), // This placeholder will be shadowed inside the closure
|
||||
web_adapter,
|
||||
voice_adapter,
|
||||
whatsapp_adapter,
|
||||
tool_api,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
info!(
|
||||
|
|
@ -135,23 +160,62 @@ async fn main() -> std::io::Result<()> {
|
|||
config.server.host, config.server.port
|
||||
);
|
||||
|
||||
// Clone the Arc<AppConfig> for use inside the closure so the original `config`
|
||||
// remains available for binding later.
|
||||
let closure_config = config.clone();
|
||||
|
||||
HttpServer::new(move || {
|
||||
// Clone again for this worker thread.
|
||||
let cfg = closure_config.clone();
|
||||
|
||||
// Re‑create services that hold non‑Sync DB connections for each worker thread
|
||||
let auth_service = auth::AuthService::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
);
|
||||
let session_manager = session::SessionManager::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
);
|
||||
|
||||
// Orchestrator for this worker
|
||||
let orchestrator = Arc::new(bot::BotOrchestrator::new(
|
||||
session_manager,
|
||||
(*tool_manager).clone(),
|
||||
llm_provider.clone(),
|
||||
auth_service,
|
||||
));
|
||||
|
||||
// Build the per‑worker AppState, cloning the shared resources
|
||||
let app_state = AppState {
|
||||
s3_client: base_app_state.s3_client.clone(),
|
||||
config: base_app_state.config.clone(),
|
||||
conn: base_app_state.conn.clone(),
|
||||
custom_conn: base_app_state.custom_conn.clone(),
|
||||
redis_client: base_app_state.redis_client.clone(),
|
||||
orchestrator,
|
||||
web_adapter: base_app_state.web_adapter.clone(),
|
||||
voice_adapter: base_app_state.voice_adapter.clone(),
|
||||
whatsapp_adapter: base_app_state.whatsapp_adapter.clone(),
|
||||
tool_api: base_app_state.tool_api.clone(),
|
||||
};
|
||||
|
||||
let cors = Cors::default()
|
||||
.allow_any_origin()
|
||||
.allow_any_method()
|
||||
.allow_any_header()
|
||||
.max_age(3600);
|
||||
|
||||
let app_state_clone = app_state.clone();
|
||||
|
||||
let mut app = App::new()
|
||||
.wrap(cors)
|
||||
.wrap(Logger::default())
|
||||
.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
|
||||
.app_data(web::Data::new(app_state.clone()))
|
||||
.app_data(web::Data::new(app_state_clone));
|
||||
|
||||
app = app
|
||||
.service(upload_file)
|
||||
.service(list_file)
|
||||
.service(chat_completions_local)
|
||||
.service(generic_chat_completions)
|
||||
.service(embeddings_local)
|
||||
.service(index)
|
||||
.service(static_files)
|
||||
.service(websocket_handler)
|
||||
|
|
@ -162,7 +226,9 @@ async fn main() -> std::io::Result<()> {
|
|||
.service(create_session)
|
||||
.service(get_sessions)
|
||||
.service(get_session_history)
|
||||
.service(set_mode_handler);
|
||||
.service(set_mode_handler)
|
||||
.service(chat_completions_local)
|
||||
.service(embeddings_local);
|
||||
|
||||
#[cfg(feature = "email")]
|
||||
{
|
||||
|
|
@ -171,7 +237,8 @@ async fn main() -> std::io::Result<()> {
|
|||
.service(get_emails)
|
||||
.service(list_emails)
|
||||
.service(send_email)
|
||||
.service(save_draft);
|
||||
.service(save_draft)
|
||||
.service(save_click);
|
||||
}
|
||||
|
||||
app
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
use diesel::prelude::*;
|
||||
use redis::{AsyncCommands, Client};
|
||||
use serde_json;
|
||||
use diesel::prelude::*;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::shared::UserSession;
|
||||
use crate::shared::models::UserSession;
|
||||
|
||||
pub struct SessionManager {
|
||||
pub conn: diesel::PgConnection,
|
||||
|
|
@ -23,7 +23,8 @@ impl SessionManager {
|
|||
) -> Result<Option<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||||
let session_json: Option<String> = tokio::task::block_in_place(|| {
|
||||
|
|
@ -37,7 +38,7 @@ impl SessionManager {
|
|||
}
|
||||
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
|
||||
let session = user_sessions
|
||||
.filter(user_id.eq(user_id))
|
||||
.filter(bot_id.eq(bot_id))
|
||||
|
|
@ -48,12 +49,17 @@ impl SessionManager {
|
|||
if let Some(ref session) = session {
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||||
let session_json = serde_json::to_string(session)?;
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.set_ex(cache_key, session_json, 1800))
|
||||
tokio::runtime::Handle::current().block_on(conn.set_ex(
|
||||
cache_key,
|
||||
session_json,
|
||||
1800,
|
||||
))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
|
@ -69,7 +75,7 @@ impl SessionManager {
|
|||
) -> Result<UserSession, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::user_sessions;
|
||||
use diesel::insert_into;
|
||||
|
||||
|
||||
let session_id = Uuid::new_v4();
|
||||
let new_session = (
|
||||
user_sessions::id.eq(session_id),
|
||||
|
|
@ -84,12 +90,17 @@ impl SessionManager {
|
|||
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||||
let session_json = serde_json::to_string(&session)?;
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.set_ex(cache_key, session_json, 1800))
|
||||
tokio::runtime::Handle::current().block_on(conn.set_ex(
|
||||
cache_key,
|
||||
session_json,
|
||||
1800,
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
|
|
@ -106,7 +117,7 @@ impl SessionManager {
|
|||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::message_history;
|
||||
use diesel::insert_into;
|
||||
|
||||
|
||||
let message_count: i64 = message_history::table
|
||||
.filter(message_history::session_id.eq(session_id))
|
||||
.count()
|
||||
|
|
@ -139,7 +150,8 @@ impl SessionManager {
|
|||
{
|
||||
let (session_user_id, session_bot_id) = session_info;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", session_user_id, session_bot_id);
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
|
|
@ -157,7 +169,7 @@ impl SessionManager {
|
|||
user_id: Uuid,
|
||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::message_history::dsl::*;
|
||||
|
||||
|
||||
let messages = message_history
|
||||
.filter(session_id.eq(session_id))
|
||||
.filter(user_id.eq(user_id))
|
||||
|
|
@ -173,7 +185,7 @@ impl SessionManager {
|
|||
user_id: Uuid,
|
||||
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
|
||||
let sessions = user_sessions
|
||||
.filter(user_id.eq(user_id))
|
||||
.order_by(updated_at.desc())
|
||||
|
|
@ -188,20 +200,22 @@ impl SessionManager {
|
|||
mode: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
|
||||
let user_uuid = Uuid::parse_str(user_id)?;
|
||||
let bot_uuid = Uuid::parse_str(bot_id)?;
|
||||
|
||||
diesel::update(user_sessions.filter(user_id.eq(user_uuid)).filter(bot_id.eq(bot_uuid)))
|
||||
.set((
|
||||
answer_mode.eq(mode),
|
||||
updated_at.eq(diesel::dsl::now),
|
||||
))
|
||||
.execute(&mut self.conn)?;
|
||||
diesel::update(
|
||||
user_sessions
|
||||
.filter(user_id.eq(user_uuid))
|
||||
.filter(bot_id.eq(bot_uuid)),
|
||||
)
|
||||
.set((answer_mode.eq(mode), updated_at.eq(diesel::dsl::now)))
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
|
|
@ -219,20 +233,22 @@ impl SessionManager {
|
|||
tool_name: Option<&str>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
|
||||
let user_uuid = Uuid::parse_str(user_id)?;
|
||||
let bot_uuid = Uuid::parse_str(bot_id)?;
|
||||
|
||||
diesel::update(user_sessions.filter(user_id.eq(user_uuid)).filter(bot_id.eq(bot_uuid)))
|
||||
.set((
|
||||
current_tool.eq(tool_name),
|
||||
updated_at.eq(diesel::dsl::now),
|
||||
))
|
||||
.execute(&mut self.conn)?;
|
||||
diesel::update(
|
||||
user_sessions
|
||||
.filter(user_id.eq(user_uuid))
|
||||
.filter(bot_id.eq(bot_uuid)),
|
||||
)
|
||||
.set((current_tool.eq(tool_name), updated_at.eq(diesel::dsl::now)))
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
|
|
@ -249,7 +265,8 @@ impl SessionManager {
|
|||
) -> Result<Option<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session_by_id:{}", session_id);
|
||||
let session_json: Option<String> = tokio::task::block_in_place(|| {
|
||||
|
|
@ -263,7 +280,7 @@ impl SessionManager {
|
|||
}
|
||||
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
|
||||
let session = user_sessions
|
||||
.filter(id.eq(session_id))
|
||||
.first::<UserSession>(&mut self.conn)
|
||||
|
|
@ -272,12 +289,17 @@ impl SessionManager {
|
|||
if let Some(ref session) = session {
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session_by_id:{}", session_id);
|
||||
let session_json = serde_json::to_string(session)?;
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.set_ex(cache_key, session_json, 1800))
|
||||
tokio::runtime::Handle::current().block_on(conn.set_ex(
|
||||
cache_key,
|
||||
session_json,
|
||||
1800,
|
||||
))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
|
@ -290,10 +312,10 @@ impl SessionManager {
|
|||
days_old: i32,
|
||||
) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
|
||||
let cutoff = chrono::Utc::now() - chrono::Duration::days(days_old as i64);
|
||||
let result = diesel::delete(user_sessions.filter(updated_at.lt(cutoff)))
|
||||
.execute(&mut self.conn)?;
|
||||
let result =
|
||||
diesel::delete(user_sessions.filter(updated_at.lt(cutoff))).execute(&mut self.conn)?;
|
||||
Ok(result as u64)
|
||||
}
|
||||
|
||||
|
|
@ -304,20 +326,22 @@ impl SessionManager {
|
|||
tool_name: Option<String>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
|
||||
let user_uuid = Uuid::parse_str(user_id)?;
|
||||
let bot_uuid = Uuid::parse_str(bot_id)?;
|
||||
|
||||
diesel::update(user_sessions.filter(user_id.eq(user_uuid)).filter(bot_id.eq(bot_uuid)))
|
||||
.set((
|
||||
current_tool.eq(tool_name),
|
||||
updated_at.eq(diesel::dsl::now),
|
||||
))
|
||||
.execute(&mut self.conn)?;
|
||||
diesel::update(
|
||||
user_sessions
|
||||
.filter(user_id.eq(user_uuid))
|
||||
.filter(bot_id.eq(bot_uuid)),
|
||||
)
|
||||
.set((current_tool.eq(tool_name), updated_at.eq(diesel::dsl::now)))
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
tokio::runtime::Handle::current()
|
||||
.block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,3 @@
|
|||
pub mod models;
|
||||
pub mod state;
|
||||
pub mod utils;
|
||||
|
||||
pub use models::*;
|
||||
pub use state::*;
|
||||
pub use utils::*;
|
||||
|
|
|
|||
|
|
@ -1,25 +1,20 @@
|
|||
use crate::bot::BotOrchestrator;
|
||||
use crate::channels::{VoiceAdapter, WebChannelAdapter};
|
||||
use crate::config::AppConfig;
|
||||
use crate::tools::ToolApi;
|
||||
use crate::whatsapp::WhatsAppAdapter;
|
||||
use diesel::PgConnection;
|
||||
use redis::Client;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::AuthService;
|
||||
use crate::bot::BotOrchestrator;
|
||||
use crate::channels::{VoiceAdapter, WebChannelAdapter};
|
||||
use crate::config::AppConfig;
|
||||
use crate::llm::LLMProvider;
|
||||
use crate::session::SessionManager;
|
||||
use crate::tools::ToolApi;
|
||||
use crate::web_automation::BrowserPool;
|
||||
use crate::whatsapp::WhatsAppAdapter;
|
||||
|
||||
pub struct AppState {
|
||||
pub s3_client: Option<aws_sdk_s3::Client>,
|
||||
pub config: Option<AppConfig>,
|
||||
pub conn: Arc<Mutex<PgConnection>>,
|
||||
pub custom_conn: Arc<Mutex<PgConnection>>,
|
||||
|
||||
pub redis_client: Option<Arc<Client>>,
|
||||
pub browser_pool: Arc<BrowserPool>,
|
||||
pub orchestrator: Arc<BotOrchestrator>,
|
||||
pub web_adapter: Arc<WebChannelAdapter>,
|
||||
pub voice_adapter: Arc<VoiceAdapter>,
|
||||
|
|
@ -27,53 +22,6 @@ pub struct AppState {
|
|||
pub tool_api: Arc<ToolApi>,
|
||||
}
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
let conn = diesel::PgConnection::establish("postgres://user:pass@localhost:5432/db")
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
let session_manager = SessionManager::new(conn, None);
|
||||
let tool_manager = crate::tools::ToolManager::new();
|
||||
let llm_provider = Arc::new(crate::llm::MockLLMProvider::new());
|
||||
let auth_service = AuthService::new(
|
||||
diesel::PgConnection::establish("postgres://user:pass@localhost:5432/db").unwrap(),
|
||||
None,
|
||||
);
|
||||
|
||||
Self {
|
||||
s3_client: None,
|
||||
config: None,
|
||||
conn: Arc::new(Mutex::new(
|
||||
diesel::PgConnection::establish("postgres://user:pass@localhost:5432/db").unwrap(),
|
||||
)),
|
||||
redis_client: None,
|
||||
browser_pool: Arc::new(crate::web_automation::BrowserPool::new(
|
||||
"chrome".to_string(),
|
||||
2,
|
||||
"headless".to_string(),
|
||||
)),
|
||||
orchestrator: Arc::new(BotOrchestrator::new(
|
||||
session_manager,
|
||||
tool_manager,
|
||||
llm_provider,
|
||||
auth_service,
|
||||
)),
|
||||
web_adapter: Arc::new(WebChannelAdapter::new()),
|
||||
voice_adapter: Arc::new(VoiceAdapter::new(
|
||||
"https://livekit.example.com".to_string(),
|
||||
"api_key".to_string(),
|
||||
"api_secret".to_string(),
|
||||
)),
|
||||
whatsapp_adapter: Arc::new(WhatsAppAdapter::new(
|
||||
"whatsapp_token".to_string(),
|
||||
"phone_number_id".to_string(),
|
||||
"verify_token".to_string(),
|
||||
)),
|
||||
tool_api: Arc::new(ToolApi::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for AppState {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
|
|
@ -81,7 +29,6 @@ impl Clone for AppState {
|
|||
config: self.config.clone(),
|
||||
conn: Arc::clone(&self.conn),
|
||||
redis_client: self.redis_client.clone(),
|
||||
browser_pool: Arc::clone(&self.browser_pool),
|
||||
orchestrator: Arc::clone(&self.orchestrator),
|
||||
web_adapter: Arc::clone(&self.web_adapter),
|
||||
voice_adapter: Arc::clone(&self.voice_adapter),
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
use diesel::prelude::*;
|
||||
use log::{debug, warn};
|
||||
use log::debug;
|
||||
use rhai::{Array, Dynamic};
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::Value;
|
||||
use smartstring::SmartString;
|
||||
use std::error::Error;
|
||||
use std::fs::File;
|
||||
|
|
@ -43,88 +42,6 @@ pub fn extract_zip_recursive(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn row_to_json(row: diesel::QueryResult<diesel::pg::PgRow>) -> Result<Value, Box<dyn Error>> {
|
||||
let row = row?;
|
||||
let mut result = serde_json::Map::new();
|
||||
let columns = row.columns();
|
||||
debug!("Converting row with {} columns", columns.len());
|
||||
|
||||
for (i, column) in columns.iter().enumerate() {
|
||||
let column_name = column.name();
|
||||
let type_name = column.type_name();
|
||||
|
||||
let value = match type_name {
|
||||
"INT4" | "int4" => handle_nullable_type::<i32>(&row, i, column_name),
|
||||
"INT8" | "int8" => handle_nullable_type::<i64>(&row, i, column_name),
|
||||
"FLOAT4" | "float4" => handle_nullable_type::<f32>(&row, i, column_name),
|
||||
"FLOAT8" | "float8" => handle_nullable_type::<f64>(&row, i, column_name),
|
||||
"TEXT" | "VARCHAR" | "text" | "varchar" => {
|
||||
handle_nullable_type::<String>(&row, i, column_name)
|
||||
}
|
||||
"BOOL" | "bool" => handle_nullable_type::<bool>(&row, i, column_name),
|
||||
"JSON" | "JSONB" | "json" | "jsonb" => handle_json(&row, i, column_name),
|
||||
_ => {
|
||||
warn!("Unknown type {} for column {}", type_name, column_name);
|
||||
handle_nullable_type::<String>(&row, i, column_name)
|
||||
}
|
||||
};
|
||||
|
||||
result.insert(column_name.to_string(), value);
|
||||
}
|
||||
|
||||
Ok(Value::Object(result))
|
||||
}
|
||||
|
||||
fn handle_nullable_type<'r, T>(row: &'r diesel::pg::PgRow, idx: usize, col_name: &str) -> Value
|
||||
where
|
||||
T: diesel::deserialize::FromSql<
|
||||
diesel::sql_types::Nullable<diesel::sql_types::Text>,
|
||||
diesel::pg::Pg,
|
||||
> + serde::Serialize
|
||||
+ std::fmt::Debug,
|
||||
{
|
||||
match row.get::<Option<T>, _>(idx) {
|
||||
Ok(Some(val)) => {
|
||||
debug!("Successfully read column {} as {:?}", col_name, val);
|
||||
json!(val)
|
||||
}
|
||||
Ok(None) => {
|
||||
debug!("Column {} is NULL", col_name);
|
||||
Value::Null
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to read column {}: {}", col_name, e);
|
||||
Value::Null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_json(row: &diesel::pg::PgRow, idx: usize, col_name: &str) -> Value {
|
||||
match row.get::<Option<Value>, _>(idx) {
|
||||
Ok(Some(val)) => {
|
||||
debug!("Successfully read JSON column {} as Value", col_name);
|
||||
return val;
|
||||
}
|
||||
Ok(None) => return Value::Null,
|
||||
Err(_) => (),
|
||||
}
|
||||
|
||||
match row.get::<Option<String>, _>(idx) {
|
||||
Ok(Some(s)) => match serde_json::from_str(&s) {
|
||||
Ok(val) => val,
|
||||
Err(_) => {
|
||||
debug!("Column {} contains string that's not JSON", col_name);
|
||||
json!(s)
|
||||
}
|
||||
},
|
||||
Ok(None) => Value::Null,
|
||||
Err(e) => {
|
||||
warn!("Failed to read JSON column {}: {}", col_name, e);
|
||||
Value::Null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn json_value_to_dynamic(value: &Value) -> Dynamic {
|
||||
match value {
|
||||
Value::Null => Dynamic::UNIT,
|
||||
|
|
@ -231,6 +148,9 @@ pub fn parse_filter_with_offset(
|
|||
Ok((clauses.join(" AND "), params))
|
||||
}
|
||||
|
||||
pub async fn call_llm(prompt: &str, _ai_config: &AIConfig) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub async fn call_llm(
|
||||
prompt: &str,
|
||||
_ai_config: &AIConfig,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(format!("Generated response for: {}", prompt))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,9 +3,6 @@ use serde::{Deserialize, Serialize};
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{session::SessionManager, shared::BotResponse};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResult {
|
||||
|
|
@ -176,48 +173,6 @@ impl ToolManager {
|
|||
Ok(vec![])
|
||||
}
|
||||
|
||||
pub async fn execute_tool_with_session(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
user_id: &str,
|
||||
bot_id: &str,
|
||||
session_manager: SessionManager,
|
||||
channel_sender: mpsc::Sender<BotResponse>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let tool = self.get_tool(tool_name).ok_or("Tool not found")?;
|
||||
session_manager
|
||||
.set_current_tool(user_id, bot_id, Some(tool_name.to_string()))
|
||||
.await?;
|
||||
|
||||
let user_id = user_id.to_string();
|
||||
let bot_id = bot_id.to_string();
|
||||
let _script = tool.script.clone();
|
||||
let session_manager_clone = session_manager.clone();
|
||||
let _waiting_responses = self.waiting_responses.clone();
|
||||
|
||||
let tool_name_clone = tool_name.to_string();
|
||||
tokio::spawn(async move {
|
||||
// Simulate tool execution
|
||||
let response = BotResponse {
|
||||
bot_id: bot_id.clone(),
|
||||
user_id: user_id.clone(),
|
||||
session_id: Uuid::new_v4().to_string(),
|
||||
channel: "test".to_string(),
|
||||
content: format!("Tool {} executed successfully", tool_name_clone),
|
||||
message_type: "text".to_string(),
|
||||
stream_token: None,
|
||||
is_complete: true,
|
||||
};
|
||||
let _ = channel_sender.send(response).await;
|
||||
|
||||
let _ = session_manager_clone
|
||||
.set_current_tool(&user_id, &bot_id, None)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn provide_user_response(
|
||||
&self,
|
||||
user_id: &str,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ use std::collections::HashMap;
|
|||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::shared::BotResponse;
|
||||
use crate::shared::models::BotResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WhatsAppMessage {
|
||||
|
|
@ -75,7 +75,11 @@ pub struct WhatsAppAdapter {
|
|||
}
|
||||
|
||||
impl WhatsAppAdapter {
|
||||
pub fn new(access_token: String, phone_number_id: String, webhook_verify_token: String) -> Self {
|
||||
pub fn new(
|
||||
access_token: String,
|
||||
phone_number_id: String,
|
||||
webhook_verify_token: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
access_token,
|
||||
|
|
@ -98,7 +102,11 @@ impl WhatsAppAdapter {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn send_whatsapp_message(&self, to: &str, body: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub async fn send_whatsapp_message(
|
||||
&self,
|
||||
to: &str,
|
||||
body: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let url = format!(
|
||||
"https://graph.facebook.com/v17.0/{}/messages",
|
||||
self.phone_number_id
|
||||
|
|
@ -112,7 +120,8 @@ impl WhatsAppAdapter {
|
|||
},
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.json(&response_data)
|
||||
|
|
@ -129,7 +138,10 @@ impl WhatsAppAdapter {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn process_incoming_message(&self, message: WhatsAppMessage) -> Result<Vec<crate::shared::UserMessage>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub async fn process_incoming_message(
|
||||
&self,
|
||||
message: WhatsAppMessage,
|
||||
) -> Result<Vec<crate::shared::UserMessage>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut user_messages = Vec::new();
|
||||
|
||||
for entry in message.entry {
|
||||
|
|
@ -139,7 +151,7 @@ impl WhatsAppAdapter {
|
|||
if let Some(text) = msg.text {
|
||||
let session_id = self.get_session_id(&msg.from).await;
|
||||
|
||||
let user_message = crate::shared::UserMessage {
|
||||
let user_message = crate::shared::models::UserMessage {
|
||||
bot_id: "default_bot".to_string(),
|
||||
user_id: msg.from.clone(),
|
||||
session_id: session_id.clone(),
|
||||
|
|
@ -160,7 +172,12 @@ impl WhatsAppAdapter {
|
|||
Ok(user_messages)
|
||||
}
|
||||
|
||||
pub fn verify_webhook(&self, mode: &str, token: &str, challenge: &str) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn verify_webhook(
|
||||
&self,
|
||||
mode: &str,
|
||||
token: &str,
|
||||
challenge: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
if mode == "subscribe" && token == self.webhook_verify_token {
|
||||
Ok(challenge.to_string())
|
||||
} else {
|
||||
|
|
@ -171,8 +188,12 @@ impl WhatsAppAdapter {
|
|||
|
||||
#[async_trait]
|
||||
impl crate::channels::ChannelAdapter for WhatsAppAdapter {
|
||||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
async fn send_message(
|
||||
&self,
|
||||
response: BotResponse,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
info!("Sending WhatsApp response to: {}", response.user_id);
|
||||
self.send_whatsapp_message(&response.user_id, &response.content).await
|
||||
self.send_whatsapp_message(&response.user_id, &response.content)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue