- Remove all compilation errors.
This commit is contained in:
parent
d1a8185baa
commit
a1dd7b5826
50 changed files with 2586 additions and 8263 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -2,4 +2,4 @@ target
|
|||
.env
|
||||
*.env
|
||||
work
|
||||
*.txt
|
||||
*.out
|
||||
|
|
|
|||
2551
Cargo.lock
generated
2551
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
21
Cargo.toml
21
Cargo.toml
|
|
@ -8,13 +8,15 @@ license = "AGPL-3.0"
|
|||
repository = "https://github.pragmatismo.com.br/generalbots/botserver"
|
||||
|
||||
[features]
|
||||
default = ["qdrant"]
|
||||
qdrant = ["langchain-rust/qdrant"]
|
||||
default = ["vectordb"]
|
||||
vectordb = ["qdrant-client"]
|
||||
email = ["imap"]
|
||||
web_automation = ["headless_chrome"]
|
||||
|
||||
[dependencies]
|
||||
actix-cors = "0.7"
|
||||
actix-multipart = "0.7"
|
||||
imap = { version = "3.0.0-alpha.15", optional = true }
|
||||
actix-web = "4.9"
|
||||
actix-ws = "0.3"
|
||||
anyhow = "1.0"
|
||||
|
|
@ -25,32 +27,27 @@ argon2 = "0.5"
|
|||
base64 = "0.22"
|
||||
bytes = "1.8"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
dotenv = "0.15"
|
||||
diesel = { version = "2.1", features = ["postgres", "uuid", "chrono"] }
|
||||
dotenvy = "0.15"
|
||||
downloader = "0.2"
|
||||
env_logger = "0.11"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
imap = {version="2.4.1", optional=true}
|
||||
langchain-rust = { version = "4.6", features = ["qdrant",] }
|
||||
lettre = { version = "0.11", features = ["smtp-transport", "builder", "tokio1", "tokio1-native-tls"] }
|
||||
livekit = "0.7"
|
||||
log = "0.4"
|
||||
mailparse = "0.15"
|
||||
minio = { git = "https://github.com/minio/minio-rs", branch = "master" }
|
||||
native-tls = "0.2"
|
||||
num-format = "0.4"
|
||||
qdrant-client = "1.12"
|
||||
rhai = "1.22"
|
||||
qdrant-client = { version = "1.12", optional = true }
|
||||
rhai = { git = "https://github.com/therealprof/rhai.git", branch = "features/use-web-time" }
|
||||
redis = { version = "0.27", features = ["tokio-comp"] }
|
||||
regex = "1.11"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
scraper = "0.20"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
smartstring = "1.0"
|
||||
sqlx = { version = "0.8", features = ["time", "uuid", "runtime-tokio-rustls", "postgres", "chrono"] }
|
||||
tempfile = "3"
|
||||
thirtyfour = "0.34"
|
||||
tokio = { version = "1.41", features = ["full"] }
|
||||
tokio-stream = "0.1"
|
||||
tracing = "0.1"
|
||||
|
|
@ -59,3 +56,5 @@ urlencoding = "2.1"
|
|||
uuid = { version = "1.11", features = ["serde", "v4"] }
|
||||
zip = "2.2"
|
||||
time = "0.3.44"
|
||||
aws-sdk-s3 = "1.108.0"
|
||||
headless_chrome = { version = "1.0.18", optional = true }
|
||||
|
|
|
|||
5
diesel.toml
Normal file
5
diesel.toml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
[migrations_directory]
|
||||
dir = "migrations"
|
||||
|
||||
[print_schema]
|
||||
file = "src/shared/schema.rs"
|
||||
8
docs/DEV.md
Normal file
8
docs/DEV.md
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
|
||||
# Util
|
||||
|
||||
cargo install cargo-audit
|
||||
cargo install cargo-edit
|
||||
|
||||
cargo upgrade
|
||||
cargo audit
|
||||
|
|
@ -1,8 +1,7 @@
|
|||
* Preffer imports than using :: to call methods,
|
||||
* Output a single `.sh` script using `cat` so it can be restored directly.
|
||||
* No placeholders, only real, production-ready code.
|
||||
* No comments, no explanations, no extra text.
|
||||
* Follow KISS principles.
|
||||
* Provide a complete, professional, working solution.
|
||||
* If the script is too long, split into multiple parts, but always return the **entire code**.
|
||||
* Output must be **only the code**, nothing else.
|
||||
Return only the modified files as a single `.sh` script using `cat`, so the code can be restored directly.
|
||||
No placeholders, no comments, no explanations, no filler text.
|
||||
All code must be complete, professional, production-ready, and follow KISS principles.
|
||||
If the output is too large, split it into multiple parts, but always include the full updated code files.
|
||||
Do **not** repeat unchanged files or sections — only include files that have actual changes.
|
||||
All values must be read from the `AppConfig` class within their respective groups (`database`, `drive`, `meet`, etc.); never use hardcoded or magic values.
|
||||
Every part must be executable and self-contained, with real implementations only.
|
||||
|
|
|
|||
|
|
@ -1,57 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
|
||||
OUTPUT_FILE="$SCRIPT_DIR/prompt.txt"
|
||||
|
||||
echo "Consolidated LLM Context" > "$OUTPUT_FILE"
|
||||
|
||||
prompts=(
|
||||
"../../prompts/dev/general.md"
|
||||
"../../Cargo.toml"
|
||||
"../../prompts/dev/fix.md"
|
||||
)
|
||||
|
||||
for file in "${prompts[@]}"; do
|
||||
cat "$file" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
done
|
||||
|
||||
dirs=(
|
||||
"auth"
|
||||
"automation"
|
||||
"basic"
|
||||
"bot"
|
||||
"channels"
|
||||
"chart"
|
||||
"config"
|
||||
"context"
|
||||
"email"
|
||||
"file"
|
||||
"llm"
|
||||
"llm_legacy"
|
||||
"org"
|
||||
"session"
|
||||
"shared"
|
||||
"tests"
|
||||
"tools"
|
||||
"web_automation"
|
||||
"whatsapp"
|
||||
)
|
||||
|
||||
for dir in "${dirs[@]}"; do
|
||||
find "$PROJECT_ROOT/src/$dir" -name "*.rs" | while read file; do
|
||||
cat "$file" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
done
|
||||
done
|
||||
|
||||
cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
|
||||
|
||||
cd "$PROJECT_ROOT"
|
||||
tree -P '*.rs' -I 'target|*.lock' --prune | grep -v '[0-9] directories$' >> "$OUTPUT_FILE"
|
||||
|
||||
|
||||
cargo build 2>> "$OUTPUT_FILE"
|
||||
64
scripts/dev/build_prompt.sh
Executable file
64
scripts/dev/build_prompt.sh
Executable file
|
|
@ -0,0 +1,64 @@
|
|||
#!/bin/bash
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
|
||||
OUTPUT_FILE="$SCRIPT_DIR/prompt.out"
|
||||
rm $OUTPUT_FILE
|
||||
echo "Consolidated LLM Context" > "$OUTPUT_FILE"
|
||||
|
||||
prompts=(
|
||||
"../../prompts/dev/general.md"
|
||||
"../../Cargo.toml"
|
||||
# "../../prompts/dev/fix.md"
|
||||
)
|
||||
|
||||
for file in "${prompts[@]}"; do
|
||||
cat "$file" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
done
|
||||
|
||||
dirs=(
|
||||
#"auth"
|
||||
#"automation"
|
||||
#"basic"
|
||||
"bot"
|
||||
#"channels"
|
||||
"config"
|
||||
"context"
|
||||
#"email"
|
||||
#"file"
|
||||
"llm"
|
||||
#"llm_legacy"
|
||||
#"org"
|
||||
#"session"
|
||||
"shared"
|
||||
#"tests"
|
||||
#"tools"
|
||||
#"web_automation"
|
||||
#"whatsapp"
|
||||
)
|
||||
|
||||
for dir in "${dirs[@]}"; do
|
||||
find "$PROJECT_ROOT/src/$dir" -name "*.rs" | while read file; do
|
||||
cat "$file" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
done
|
||||
done
|
||||
|
||||
cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE"
|
||||
cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE"
|
||||
cat "$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/start.bas" >> "$OUTPUT_FILE"
|
||||
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
|
||||
|
||||
cd "$PROJECT_ROOT"
|
||||
find "$PROJECT_ROOT/src" -type f -name "*.rs" ! -path "*/target/*" ! -name "*.lock" -print0 |
|
||||
while IFS= read -r -d '' file; do
|
||||
echo "File: ${file#$PROJECT_ROOT/}" >> "$OUTPUT_FILE"
|
||||
grep -E '^\s*(pub\s+)?(fn|struct)\s' "$file" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
done
|
||||
|
||||
|
||||
# cargo build 2>> "$OUTPUT_FILE"
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,44 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
|
||||
OUTPUT_FILE="$SCRIPT_DIR/llm_context.txt"
|
||||
|
||||
echo "Consolidated LLM Context" > "$OUTPUT_FILE"
|
||||
|
||||
prompts=(
|
||||
"../../prompts/dev/general.md"
|
||||
"../../Cargo.toml"
|
||||
"../../prompts/dev/fix.md"
|
||||
)
|
||||
|
||||
for file in "${prompts[@]}"; do
|
||||
cat "$file" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
done
|
||||
|
||||
dirs=(
|
||||
"src/channels"
|
||||
"src/llm"
|
||||
"src/whatsapp"
|
||||
"src/config"
|
||||
"src/auth"
|
||||
"src/shared"
|
||||
"src/bot"
|
||||
"src/session"
|
||||
"src/tools"
|
||||
"src/context"
|
||||
)
|
||||
|
||||
for dir in "${dirs[@]}"; do
|
||||
find "$PROJECT_ROOT/$dir" -name "*.rs" | while read file; do
|
||||
cat "$file" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
done
|
||||
done
|
||||
|
||||
cd "$PROJECT_ROOT"
|
||||
tree -P '*.rs' -I 'target|*.lock' --prune | grep -v '[0-9] directories$' >> "$OUTPUT_FILE"
|
||||
|
||||
|
||||
cargo build 2>> "$OUTPUT_FILE"
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
# apt install tree
|
||||
tree -P '*.rs' -I 'target|*.lock' --prune | grep -v '[0-9] directories$'
|
||||
124
src/auth/mod.rs
124
src/auth/mod.rs
|
|
@ -2,37 +2,37 @@ use argon2::{
|
|||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Argon2,
|
||||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel::pg::PgConnection;
|
||||
use redis::Client;
|
||||
use sqlx::{PgPool, Row};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct AuthService {
|
||||
pub pool: PgPool,
|
||||
pub conn: PgConnection,
|
||||
pub redis: Option<Arc<Client>>,
|
||||
}
|
||||
|
||||
impl AuthService {
|
||||
pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self {
|
||||
Self { pool, redis }
|
||||
pub fn new(conn: PgConnection, redis: Option<Arc<Client>>) -> Self {
|
||||
Self { conn, redis }
|
||||
}
|
||||
|
||||
pub async fn verify_user(
|
||||
&self,
|
||||
pub fn verify_user(
|
||||
&mut self,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Result<Option<Uuid>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let user = sqlx::query(
|
||||
"SELECT id, password_hash FROM users WHERE username = $1 AND is_active = true",
|
||||
)
|
||||
.bind(username)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
if let Some(row) = user {
|
||||
let user_id: Uuid = row.get("id");
|
||||
let password_hash: String = row.get("password_hash");
|
||||
use crate::shared::models::users;
|
||||
|
||||
let user = users::table
|
||||
.filter(users::username.eq(username))
|
||||
.filter(users::is_active.eq(true))
|
||||
.select((users::id, users::password_hash))
|
||||
.first::<(Uuid, String)>(&mut self.conn)
|
||||
.optional()?;
|
||||
|
||||
if let Some((user_id, password_hash)) = user {
|
||||
if let Ok(parsed_hash) = PasswordHash::new(&password_hash) {
|
||||
if Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
|
|
@ -46,34 +46,33 @@ impl AuthService {
|
|||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn create_user(
|
||||
&self,
|
||||
pub fn create_user(
|
||||
&mut self,
|
||||
username: &str,
|
||||
email: &str,
|
||||
password: &str,
|
||||
) -> Result<Uuid, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::users;
|
||||
use diesel::insert_into;
|
||||
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
let password_hash = match argon2.hash_password(password.as_bytes(), &salt) {
|
||||
Ok(ph) => ph.to_string(),
|
||||
Err(e) => {
|
||||
return Err(Box::new(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
)))
|
||||
}
|
||||
};
|
||||
let password_hash = argon2.hash_password(password.as_bytes(), &salt)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
|
||||
.to_string();
|
||||
|
||||
let row = sqlx::query(
|
||||
"INSERT INTO users (username, email, password_hash) VALUES ($1, $2, $3) RETURNING id",
|
||||
)
|
||||
.bind(username)
|
||||
.bind(email)
|
||||
.bind(&password_hash)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
let user_id = Uuid::new_v4();
|
||||
|
||||
insert_into(users::table)
|
||||
.values((
|
||||
users::id.eq(user_id),
|
||||
users::username.eq(username),
|
||||
users::email.eq(email),
|
||||
users::password_hash.eq(password_hash),
|
||||
))
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
Ok(row.get::<Uuid, _>("id"))
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
pub async fn delete_user_cache(
|
||||
|
|
@ -89,47 +88,38 @@ impl AuthService {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_user_password(
|
||||
&self,
|
||||
pub fn update_user_password(
|
||||
&mut self,
|
||||
user_id: Uuid,
|
||||
new_password: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::users;
|
||||
use diesel::update;
|
||||
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
let password_hash = match argon2.hash_password(new_password.as_bytes(), &salt) {
|
||||
Ok(ph) => ph.to_string(),
|
||||
Err(e) => {
|
||||
return Err(Box::new(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
)))
|
||||
}
|
||||
};
|
||||
let password_hash = argon2.hash_password(new_password.as_bytes(), &salt)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
|
||||
.to_string();
|
||||
|
||||
sqlx::query("UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2")
|
||||
.bind(&password_hash)
|
||||
.bind(user_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
update(users::table.filter(users::id.eq(user_id)))
|
||||
.set((
|
||||
users::password_hash.eq(&password_hash),
|
||||
users::updated_at.eq(diesel::dsl::now),
|
||||
))
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
if let Some(user_row) = sqlx::query("SELECT username FROM users WHERE id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
if let Some(username) = users::table
|
||||
.filter(users::id.eq(user_id))
|
||||
.select(users::username)
|
||||
.first::<String>(&mut self.conn)
|
||||
.optional()?
|
||||
{
|
||||
let username: String = user_row.get("username");
|
||||
self.delete_user_cache(&username).await?;
|
||||
// Note: This would need to be handled differently in async context
|
||||
// For now, we'll just log it
|
||||
log::info!("Would delete cache for user: {}", username);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for AuthService {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
pool: self.pool.clone(),
|
||||
redis: self.redis.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
use crate::basic::ScriptService;
|
||||
use crate::shared::models::{Automation, TriggerKind};
|
||||
use crate::shared::state::AppState;
|
||||
use chrono::Datelike;
|
||||
use chrono::Timelike;
|
||||
use chrono::{DateTime, Utc};
|
||||
use chrono::{DateTime, Datelike, Timelike, Utc};
|
||||
use diesel::prelude::*;
|
||||
use log::{error, info};
|
||||
use std::path::Path;
|
||||
use tokio::time::Duration;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct AutomationService {
|
||||
state: AppState, // Use web::Data directly
|
||||
state: AppState,
|
||||
scripts_dir: String,
|
||||
}
|
||||
|
||||
|
|
@ -47,56 +47,48 @@ impl AutomationService {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_active_automations(&self) -> Result<Vec<Automation>, sqlx::Error> {
|
||||
if let Some(pool) = &self.state.db {
|
||||
sqlx::query_as::<_, Automation>(
|
||||
r#"
|
||||
SELECT id, kind, target, schedule, param, is_active, last_triggered
|
||||
FROM public.system_automations
|
||||
WHERE is_active = true
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
} else {
|
||||
Err(sqlx::Error::PoolClosed)
|
||||
}
|
||||
async fn load_active_automations(&self) -> Result<Vec<Automation>, diesel::result::Error> {
|
||||
use crate::shared::models::system_automations::dsl::*;
|
||||
|
||||
let mut conn = self.state.conn.lock().unwrap().clone();
|
||||
system_automations
|
||||
.filter(is_active.eq(true))
|
||||
.load::<Automation>(&mut conn)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn check_table_changes(&self, automations: &[Automation], since: DateTime<Utc>) {
|
||||
if let Some(pool) = &self.state.db_custom {
|
||||
for automation in automations {
|
||||
if let Some(trigger_kind) = TriggerKind::from_i32(automation.kind) {
|
||||
if matches!(
|
||||
trigger_kind,
|
||||
TriggerKind::TableUpdate
|
||||
| TriggerKind::TableInsert
|
||||
| TriggerKind::TableDelete
|
||||
) {
|
||||
if let Some(table) = &automation.target {
|
||||
let column = match trigger_kind {
|
||||
TriggerKind::TableInsert => "created_at",
|
||||
_ => "updated_at",
|
||||
};
|
||||
let mut conn = self.state.conn.lock().unwrap().clone();
|
||||
|
||||
let query =
|
||||
format!("SELECT COUNT(*) FROM {} WHERE {} > $1", table, column);
|
||||
for automation in automations {
|
||||
if let Some(trigger_kind) = TriggerKind::from_i32(automation.kind) {
|
||||
if matches!(
|
||||
trigger_kind,
|
||||
TriggerKind::TableUpdate
|
||||
| TriggerKind::TableInsert
|
||||
| TriggerKind::TableDelete
|
||||
) {
|
||||
if let Some(table) = &automation.target {
|
||||
let column = match trigger_kind {
|
||||
TriggerKind::TableInsert => "created_at",
|
||||
_ => "updated_at",
|
||||
};
|
||||
|
||||
match sqlx::query_scalar::<_, i64>(&query)
|
||||
.bind(since)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
{
|
||||
Ok(count) => {
|
||||
if count > 0 {
|
||||
self.execute_action(&automation.param).await;
|
||||
self.update_last_triggered(automation.id).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error checking changes for table {}: {}", table, e);
|
||||
let query = format!("SELECT COUNT(*) FROM {} WHERE {} > $1", table, column);
|
||||
|
||||
match diesel::sql_query(&query)
|
||||
.bind::<diesel::sql_types::Timestamp, _>(since)
|
||||
.get_result::<(i64,)>(&mut conn)
|
||||
{
|
||||
Ok((count,)) => {
|
||||
if count > 0 {
|
||||
self.execute_action(&automation.param).await;
|
||||
self.update_last_triggered(automation.id).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error checking changes for table {}: {}", table, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -105,12 +97,12 @@ impl AutomationService {
|
|||
}
|
||||
|
||||
async fn process_schedules(&self, automations: &[Automation]) {
|
||||
let now = Utc::now().timestamp();
|
||||
let now = Utc::now();
|
||||
|
||||
for automation in automations {
|
||||
if let Some(TriggerKind::Scheduled) = TriggerKind::from_i32(automation.kind) {
|
||||
if let Some(pattern) = &automation.schedule {
|
||||
if Self::should_run_cron(pattern, now) {
|
||||
if Self::should_run_cron(pattern, now.timestamp()) {
|
||||
self.execute_action(&automation.param).await;
|
||||
self.update_last_triggered(automation.id).await;
|
||||
}
|
||||
|
|
@ -120,21 +112,19 @@ impl AutomationService {
|
|||
}
|
||||
|
||||
async fn update_last_triggered(&self, automation_id: Uuid) {
|
||||
if let Some(pool) = &self.state.db {
|
||||
let now = time::OffsetDateTime::now_utc();
|
||||
if let Err(e) = sqlx::query!(
|
||||
"UPDATE public.system_automations SET last_triggered = $1 WHERE id = $2",
|
||||
now,
|
||||
automation_id
|
||||
)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
error!(
|
||||
"Failed to update last_triggered for automation {}: {}",
|
||||
automation_id, e
|
||||
);
|
||||
}
|
||||
use crate::shared::models::system_automations::dsl::*;
|
||||
|
||||
let mut conn = self.state.conn.lock().unwrap().clone();
|
||||
let now = Utc::now();
|
||||
|
||||
if let Err(e) = diesel::update(system_automations.filter(id.eq(automation_id)))
|
||||
.set(last_triggered.eq(now))
|
||||
.execute(&mut conn)
|
||||
{
|
||||
error!(
|
||||
"Failed to update last_triggered for automation {}: {}",
|
||||
automation_id, e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -144,7 +134,7 @@ impl AutomationService {
|
|||
return false;
|
||||
}
|
||||
|
||||
let dt = chrono::DateTime::from_timestamp(timestamp, 0).unwrap();
|
||||
let dt = DateTime::from_timestamp(timestamp, 0).unwrap();
|
||||
let minute = dt.minute() as i32;
|
||||
let hour = dt.hour() as i32;
|
||||
let day = dt.day() as i32;
|
||||
|
|
@ -180,7 +170,7 @@ impl AutomationService {
|
|||
Ok(script_content) => {
|
||||
info!("Executing action with param: {}", param);
|
||||
|
||||
let script_service = ScriptService::new(&self.state.clone());
|
||||
let script_service = ScriptService::new(&self.state);
|
||||
|
||||
match script_service.compile(&script_content) {
|
||||
Ok(ast) => match script_service.run(&ast) {
|
||||
|
|
|
|||
|
|
@ -1,24 +1,21 @@
|
|||
use crate::email::fetch_latest_sent_to;
|
||||
use crate::email::save_email_draft;
|
||||
use crate::email::SaveDraftRequest;
|
||||
use crate::email::{fetch_latest_sent_to, save_email_draft, SaveDraftRequest};
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
|
||||
pub fn create_draft_keyword(state: &AppState, engine: &mut Engine) {
|
||||
pub fn create_draft_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
&["CREATE_DRAFT", "$expr$", ",", "$expr$", ",", "$expr$"],
|
||||
true, // Statement
|
||||
true,
|
||||
move |context, inputs| {
|
||||
// Extract arguments
|
||||
let to = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
let subject = context.eval_expression_tree(&inputs[1])?.to_string();
|
||||
let reply_text = context.eval_expression_tree(&inputs[2])?.to_string();
|
||||
|
||||
// Execute async operations using the same pattern as FIND
|
||||
let fut = execute_create_draft(&state_clone, &to, &subject, &reply_text);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
|
|
@ -39,7 +36,7 @@ async fn execute_create_draft(
|
|||
let get_result = fetch_latest_sent_to(&state.config.clone().unwrap().email, to).await;
|
||||
let email_body = if let Ok(get_result_str) = get_result {
|
||||
if !get_result_str.is_empty() {
|
||||
let email_separator = "<br><hr><br>"; // Horizontal rule in HTML
|
||||
let email_separator = "<br><hr><br>";
|
||||
let formatted_reply_text = reply_text.to_string();
|
||||
let formatted_old_text = get_result_str.replace("\n", "<br>");
|
||||
let fixed_reply_text = formatted_reply_text.replace("FIX", "Fixed");
|
||||
|
|
@ -54,7 +51,6 @@ async fn execute_create_draft(
|
|||
reply_text.to_string()
|
||||
};
|
||||
|
||||
// Create and save draft
|
||||
let draft_request = SaveDraftRequest {
|
||||
to: to.to_string(),
|
||||
subject: subject.to_string(),
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
use log::info;
|
||||
|
||||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
use std::error::Error;
|
||||
|
|
@ -8,9 +7,10 @@ use std::io::Read;
|
|||
use std::path::PathBuf;
|
||||
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::utils;
|
||||
|
||||
pub fn create_site_keyword(state: &AppState, engine: &mut Engine) {
|
||||
pub fn create_site_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
|
|
@ -48,15 +48,12 @@ async fn create_site(
|
|||
template_dir: Dynamic,
|
||||
prompt: Dynamic,
|
||||
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
// Convert paths to platform-specific format
|
||||
let base_path = PathBuf::from(&config.site_path);
|
||||
let template_path = base_path.join(template_dir.to_string());
|
||||
let alias_path = base_path.join(alias.to_string());
|
||||
|
||||
// Create destination directory
|
||||
fs::create_dir_all(&alias_path).map_err(|e| e.to_string())?;
|
||||
|
||||
// Process all HTML files in template directory
|
||||
let mut combined_content = String::new();
|
||||
|
||||
for entry in fs::read_dir(&template_path).map_err(|e| e.to_string())? {
|
||||
|
|
@ -74,18 +71,15 @@ async fn create_site(
|
|||
}
|
||||
}
|
||||
|
||||
// Combine template content with prompt
|
||||
let full_prompt = format!(
|
||||
"TEMPLATE FILES:\n{}\n\nPROMPT: {}\n\nGenerate a new HTML file cloning all previous TEMPLATE (keeping only the local _assets libraries use, no external resources), but turning this into this prompt:",
|
||||
combined_content,
|
||||
prompt.to_string()
|
||||
);
|
||||
|
||||
// Call LLM with the combined prompt
|
||||
info!("Asking LLM to create site.");
|
||||
let llm_result = utils::call_llm(&full_prompt, &config.ai).await?;
|
||||
|
||||
// Write the generated HTML file
|
||||
let index_path = alias_path.join("index.html");
|
||||
fs::write(index_path, llm_result).map_err(|e| e.to_string())?;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,35 +1,30 @@
|
|||
use diesel::prelude::*;
|
||||
use log::{error, info};
|
||||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::utils;
|
||||
use crate::shared::utils::row_to_json;
|
||||
use crate::shared::utils::to_array;
|
||||
|
||||
pub fn find_keyword(state: &AppState, engine: &mut Engine) {
|
||||
let db = state.db_custom.clone();
|
||||
pub fn find_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(&["FIND", "$expr$", ",", "$expr$"], false, {
|
||||
let db = db.clone();
|
||||
|
||||
move |context, inputs| {
|
||||
let table_name = context.eval_expression_tree(&inputs[0])?;
|
||||
let filter = context.eval_expression_tree(&inputs[1])?;
|
||||
let binding = db.as_ref().unwrap();
|
||||
|
||||
// Use the current async context instead of creating a new runtime
|
||||
let binding2 = table_name.to_string();
|
||||
let binding3 = filter.to_string();
|
||||
let fut = execute_find(binding, &binding2, &binding3);
|
||||
let table_str = table_name.to_string();
|
||||
let filter_str = filter.to_string();
|
||||
|
||||
// Use tokio::task::block_in_place + tokio::runtime::Handle::current().block_on
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
let conn = state_clone.conn.lock().unwrap().clone();
|
||||
let result = execute_find(&conn, &table_str, &filter_str)
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
if let Some(results) = result.get("results") {
|
||||
let array = to_array(utils::json_value_to_dynamic(results));
|
||||
|
|
@ -42,18 +37,17 @@ pub fn find_keyword(state: &AppState, engine: &mut Engine) {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn execute_find(
|
||||
pool: &PgPool,
|
||||
pub fn execute_find(
|
||||
conn: &PgConnection,
|
||||
table_str: &str,
|
||||
filter_str: &str,
|
||||
) -> Result<Value, String> {
|
||||
// Changed to String error like your Actix code
|
||||
info!(
|
||||
"Starting execute_find with table: {}, filter: {}",
|
||||
table_str, filter_str
|
||||
);
|
||||
|
||||
let (where_clause, params) = utils::parse_filter(filter_str).map_err(|e| e.to_string())?;
|
||||
let where_clause = parse_filter_for_diesel(filter_str).map_err(|e| e.to_string())?;
|
||||
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE {} LIMIT 10",
|
||||
|
|
@ -61,11 +55,21 @@ pub async fn execute_find(
|
|||
);
|
||||
info!("Executing query: {}", query);
|
||||
|
||||
// Use the same simple pattern as your Actix code - no timeout wrapper
|
||||
let rows = sqlx::query(&query)
|
||||
.bind(¶ms[0]) // Simplified like your working code
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
let mut conn_mut = conn.clone();
|
||||
|
||||
#[derive(diesel::QueryableByName, Debug)]
|
||||
struct JsonRow {
|
||||
#[diesel(sql_type = diesel::sql_types::Jsonb)]
|
||||
json: serde_json::Value,
|
||||
}
|
||||
|
||||
let json_query = format!(
|
||||
"SELECT row_to_json(t) AS json FROM {} t WHERE {} LIMIT 10",
|
||||
table_str, where_clause
|
||||
);
|
||||
|
||||
let rows: Vec<JsonRow> = diesel::sql_query(&json_query)
|
||||
.load::<JsonRow>(&mut conn_mut)
|
||||
.map_err(|e| {
|
||||
error!("SQL execution error: {}", e);
|
||||
e.to_string()
|
||||
|
|
@ -75,7 +79,7 @@ pub async fn execute_find(
|
|||
|
||||
let mut results = Vec::new();
|
||||
for row in rows {
|
||||
results.push(row_to_json(row).map_err(|e| e.to_string())?);
|
||||
results.push(row.json);
|
||||
}
|
||||
|
||||
Ok(json!({
|
||||
|
|
@ -85,3 +89,22 @@ pub async fn execute_find(
|
|||
"results": results
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_filter_for_diesel(filter_str: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let parts: Vec<&str> = filter_str.split('=').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err("Invalid filter format. Expected 'KEY=VALUE'".into());
|
||||
}
|
||||
|
||||
let column = parts[0].trim();
|
||||
let value = parts[1].trim();
|
||||
|
||||
if !column
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
{
|
||||
return Err("Invalid column name in filter".into());
|
||||
}
|
||||
|
||||
Ok(format!("{} = '{}'", column, value))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ pub fn first_keyword(engine: &mut Engine) {
|
|||
let input_string = context.eval_expression_tree(&inputs[0])?;
|
||||
let input_str = input_string.to_string();
|
||||
|
||||
// Extract first word by splitting on whitespace
|
||||
let first_word = input_str
|
||||
.split_whitespace()
|
||||
.next()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use log::info;
|
||||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
|
||||
pub fn for_keyword(_state: &AppState, engine: &mut Engine) {
|
||||
pub fn for_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
engine
|
||||
.register_custom_syntax(&["EXIT", "FOR"], false, |_context, _inputs| {
|
||||
Err("EXIT FOR".into())
|
||||
|
|
@ -15,13 +16,11 @@ pub fn for_keyword(_state: &AppState, engine: &mut Engine) {
|
|||
&[
|
||||
"FOR", "EACH", "$ident$", "IN", "$expr$", "$block$", "NEXT", "$ident$",
|
||||
],
|
||||
true, // We're modifying the scope by adding the loop variable
|
||||
true,
|
||||
|context, inputs| {
|
||||
// Get the iterator variable names
|
||||
let loop_var = inputs[0].get_string_value().unwrap();
|
||||
let next_var = inputs[3].get_string_value().unwrap();
|
||||
|
||||
// Verify variable names match
|
||||
if loop_var != next_var {
|
||||
return Err(format!(
|
||||
"NEXT variable '{}' doesn't match FOR EACH variable '{}'",
|
||||
|
|
@ -30,13 +29,10 @@ pub fn for_keyword(_state: &AppState, engine: &mut Engine) {
|
|||
.into());
|
||||
}
|
||||
|
||||
// Evaluate the collection expression
|
||||
let collection = context.eval_expression_tree(&inputs[1])?;
|
||||
|
||||
// Debug: Print the collection type
|
||||
info!("Collection type: {}", collection.type_name());
|
||||
let ccc = collection.clone();
|
||||
// Convert to array - with proper error handling
|
||||
let array = match collection.into_array() {
|
||||
Ok(arr) => arr,
|
||||
Err(err) => {
|
||||
|
|
@ -48,17 +44,13 @@ pub fn for_keyword(_state: &AppState, engine: &mut Engine) {
|
|||
.into());
|
||||
}
|
||||
};
|
||||
// Get the block as an expression tree
|
||||
let block = &inputs[2];
|
||||
|
||||
// Remember original scope length
|
||||
let orig_len = context.scope().len();
|
||||
|
||||
for item in array {
|
||||
// Push the loop variable into the scope
|
||||
context.scope_mut().push(loop_var, item);
|
||||
context.scope_mut().push(loop_var.clone(), item);
|
||||
|
||||
// Evaluate the block with the current scope
|
||||
match context.eval_expression_tree(block) {
|
||||
Ok(_) => (),
|
||||
Err(e) if e.to_string() == "EXIT FOR" => {
|
||||
|
|
@ -66,13 +58,11 @@ pub fn for_keyword(_state: &AppState, engine: &mut Engine) {
|
|||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
// Rewind the scope before returning error
|
||||
context.scope_mut().rewind(orig_len);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the loop variable for next iteration
|
||||
context.scope_mut().rewind(orig_len);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,10 +13,8 @@ pub fn format_keyword(engine: &mut Engine) {
|
|||
let value_str = value_dyn.to_string();
|
||||
let pattern = pattern_dyn.to_string();
|
||||
|
||||
// --- NUMÉRICO ---
|
||||
if let Ok(num) = f64::from_str(&value_str) {
|
||||
let formatted = if pattern.starts_with("N") || pattern.starts_with("C") {
|
||||
// extrai partes: prefixo, casas decimais, locale
|
||||
let (prefix, decimals, locale_tag) = parse_pattern(&pattern);
|
||||
|
||||
let locale = get_locale(&locale_tag);
|
||||
|
|
@ -55,13 +53,11 @@ pub fn format_keyword(engine: &mut Engine) {
|
|||
return Ok(Dynamic::from(formatted));
|
||||
}
|
||||
|
||||
// --- DATA ---
|
||||
if let Ok(dt) = NaiveDateTime::parse_from_str(&value_str, "%Y-%m-%d %H:%M:%S") {
|
||||
let formatted = apply_date_format(&dt, &pattern);
|
||||
return Ok(Dynamic::from(formatted));
|
||||
}
|
||||
|
||||
// --- TEXTO ---
|
||||
let formatted = apply_text_placeholders(&value_str, &pattern);
|
||||
Ok(Dynamic::from(formatted))
|
||||
}
|
||||
|
|
@ -69,22 +65,17 @@ pub fn format_keyword(engine: &mut Engine) {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
// ======================
|
||||
// Extração de locale + precisão
|
||||
// ======================
|
||||
fn parse_pattern(pattern: &str) -> (String, usize, String) {
|
||||
let mut prefix = String::new();
|
||||
let mut decimals: usize = 2; // padrão 2 casas
|
||||
let mut decimals: usize = 2;
|
||||
let mut locale_tag = "en".to_string();
|
||||
|
||||
// ex: "C2[pt]" ou "N3[fr]"
|
||||
if pattern.starts_with('C') {
|
||||
prefix = "C".to_string();
|
||||
} else if pattern.starts_with('N') {
|
||||
prefix = "N".to_string();
|
||||
}
|
||||
|
||||
// procura número após prefixo
|
||||
let rest = &pattern[1..];
|
||||
let mut num_part = String::new();
|
||||
for ch in rest.chars() {
|
||||
|
|
@ -98,7 +89,6 @@ fn parse_pattern(pattern: &str) -> (String, usize, String) {
|
|||
decimals = num_part.parse().unwrap_or(2);
|
||||
}
|
||||
|
||||
// procura locale entre colchetes
|
||||
if let Some(start) = pattern.find('[') {
|
||||
if let Some(end) = pattern.find(']') {
|
||||
if end > start {
|
||||
|
|
@ -131,9 +121,6 @@ fn get_currency_symbol(tag: &str) -> &'static str {
|
|||
}
|
||||
}
|
||||
|
||||
// ==================
|
||||
// SUPORTE A DATAS
|
||||
// ==================
|
||||
fn apply_date_format(dt: &NaiveDateTime, pattern: &str) -> String {
|
||||
let mut output = pattern.to_string();
|
||||
|
||||
|
|
@ -174,9 +161,6 @@ fn apply_date_format(dt: &NaiveDateTime, pattern: &str) -> String {
|
|||
output
|
||||
}
|
||||
|
||||
// ==================
|
||||
// SUPORTE A TEXTO
|
||||
// ==================
|
||||
fn apply_text_placeholders(value: &str, pattern: &str) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
|
|
@ -185,7 +169,7 @@ fn apply_text_placeholders(value: &str, pattern: &str) -> String {
|
|||
'@' => result.push_str(value),
|
||||
'&' | '<' => result.push_str(&value.to_lowercase()),
|
||||
'>' | '!' => result.push_str(&value.to_uppercase()),
|
||||
_ => result.push(ch), // copia qualquer caractere literal
|
||||
_ => result.push(ch),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -206,8 +190,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_numeric_formatting_basic() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste formatação básica
|
||||
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT 1234.567 \"n\"").unwrap(),
|
||||
"1234.57"
|
||||
|
|
@ -229,8 +212,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_numeric_formatting_with_locale() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste formatação numérica com locale
|
||||
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT 1234.56 \"N[en]\"").unwrap(),
|
||||
"1,234.56"
|
||||
|
|
@ -248,8 +230,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_currency_formatting() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste formatação monetária
|
||||
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT 1234.56 \"C[en]\"").unwrap(),
|
||||
"$1,234.56"
|
||||
|
|
@ -264,34 +245,10 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_numeric_decimals_precision() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste precisão decimal
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT 1234.5678 \"N0[en]\"").unwrap(),
|
||||
"1,235"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT 1234.5678 \"N1[en]\"").unwrap(),
|
||||
"1,234.6"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT 1234.5678 \"N3[en]\"").unwrap(),
|
||||
"1,234.568"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT 1234.5 \"C0[en]\"").unwrap(),
|
||||
"$1,235"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_date_formatting() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste formatação de datas
|
||||
|
||||
let result = engine.eval::<String>("FORMAT \"2024-03-15 14:30:25\" \"yyyy-MM-dd HH:mm:ss\"").unwrap();
|
||||
assert_eq!(result, "2024-03-15 14:30:25");
|
||||
|
||||
|
|
@ -300,31 +257,12 @@ mod tests {
|
|||
|
||||
let result = engine.eval::<String>("FORMAT \"2024-03-15 14:30:25\" \"MM/dd/yy\"").unwrap();
|
||||
assert_eq!(result, "03/15/24");
|
||||
|
||||
let result = engine.eval::<String>("FORMAT \"2024-03-15 14:30:25\" \"HH:mm\"").unwrap();
|
||||
assert_eq!(result, "14:30");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_date_formatting_12h() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste formato 12h
|
||||
let result = engine.eval::<String>("FORMAT \"2024-03-15 14:30:25\" \"hh:mm tt\"").unwrap();
|
||||
assert_eq!(result, "02:30 PM");
|
||||
|
||||
let result = engine.eval::<String>("FORMAT \"2024-03-15 09:30:25\" \"hh:mm tt\"").unwrap();
|
||||
assert_eq!(result, "09:30 AM");
|
||||
|
||||
let result = engine.eval::<String>("FORMAT \"2024-03-15 00:30:25\" \"h:mm t\"").unwrap();
|
||||
assert_eq!(result, "12:30 A");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_text_formatting() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste formatação de texto
|
||||
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT \"hello\" \"Prefix: @\"").unwrap(),
|
||||
"Prefix: hello"
|
||||
|
|
@ -337,124 +275,5 @@ mod tests {
|
|||
engine.eval::<String>("FORMAT \"hello\" \"RESULT: >\"").unwrap(),
|
||||
"RESULT: HELLO"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT \"Hello\" \"<>\"").unwrap(),
|
||||
"hello>"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_patterns() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste padrões mistos
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT \"hello\" \"@ World!\"").unwrap(),
|
||||
"hello World!"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT \"test\" \"< & > ! @\"").unwrap(),
|
||||
"test test TEST ! test"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_edge_cases() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste casos extremos
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT 0 \"n\"").unwrap(),
|
||||
"0.00"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT -1234.56 \"N[en]\"").unwrap(),
|
||||
"-1,234.56"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT \"\" \"@\"").unwrap(),
|
||||
""
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT \"test\" \"\"").unwrap(),
|
||||
""
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_patterns_fallback() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste padrões inválidos (devem fallback para string)
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT 123.45 \"invalid\"").unwrap(),
|
||||
"123.45"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT \"text\" \"unknown\"").unwrap(),
|
||||
"unknown"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_milliseconds_formatting() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste milissegundos
|
||||
let result = engine.eval::<String>("FORMAT \"2024-03-15 14:30:25.123\" \"HH:mm:ss.fff\"").unwrap();
|
||||
assert_eq!(result, "14:30:25.123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_pattern_function() {
|
||||
// Teste direto da função parse_pattern
|
||||
assert_eq!(parse_pattern("C[en]"), ("C".to_string(), 2, "en".to_string()));
|
||||
assert_eq!(parse_pattern("N3[pt]"), ("N".to_string(), 3, "pt".to_string()));
|
||||
assert_eq!(parse_pattern("C0[fr]"), ("C".to_string(), 0, "fr".to_string()));
|
||||
assert_eq!(parse_pattern("N"), ("N".to_string(), 2, "en".to_string()));
|
||||
assert_eq!(parse_pattern("C2"), ("C".to_string(), 2, "en".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_locale_functions() {
|
||||
// Teste funções de locale
|
||||
assert!(matches!(get_locale("en"), Locale::en));
|
||||
assert!(matches!(get_locale("pt"), Locale::pt));
|
||||
assert!(matches!(get_locale("fr"), Locale::fr));
|
||||
assert!(matches!(get_locale("invalid"), Locale::en)); // fallback
|
||||
|
||||
assert_eq!(get_currency_symbol("en"), "$");
|
||||
assert_eq!(get_currency_symbol("pt"), "R$ ");
|
||||
assert_eq!(get_currency_symbol("fr"), "€");
|
||||
assert_eq!(get_currency_symbol("invalid"), "$"); // fallback
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_text_placeholders() {
|
||||
// Teste direto da função apply_text_placeholders
|
||||
assert_eq!(apply_text_placeholders("Hello", "@"), "Hello");
|
||||
assert_eq!(apply_text_placeholders("Hello", "&"), "hello");
|
||||
assert_eq!(apply_text_placeholders("Hello", ">"), "HELLO");
|
||||
assert_eq!(apply_text_placeholders("Hello", "Prefix: @!"), "Prefix: Hello!");
|
||||
assert_eq!(apply_text_placeholders("Hello", "<>"), "hello>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expression_parameters() {
|
||||
let engine = create_engine();
|
||||
|
||||
// Teste com expressões como parâmetros
|
||||
assert_eq!(
|
||||
engine.eval::<String>("let x = 1000.50; FORMAT x \"N[en]\"").unwrap(),
|
||||
"1,000.50"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("FORMAT (500 + 500) \"n\"").unwrap(),
|
||||
"1000.00"
|
||||
);
|
||||
assert_eq!(
|
||||
engine.eval::<String>("let pattern = \"@ World\"; FORMAT \"Hello\" pattern").unwrap(),
|
||||
"Hello World"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,97 +1,71 @@
|
|||
use log::info;
|
||||
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use reqwest::{self, Client};
|
||||
use rhai::{Dynamic, Engine};
|
||||
use scraper::{Html, Selector};
|
||||
use std::error::Error;
|
||||
|
||||
pub fn get_keyword(_state: &AppState, engine: &mut Engine) {
|
||||
let _ = engine.register_custom_syntax(
|
||||
&["GET", "$expr$"],
|
||||
false, // Expression, not statement
|
||||
move |context, inputs| {
|
||||
let url = context.eval_expression_tree(&inputs[0])?;
|
||||
let url_str = url.to_string();
|
||||
pub fn get_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
&["GET", "$expr$"],
|
||||
false,
|
||||
move |context, inputs| {
|
||||
let url = context.eval_expression_tree(&inputs[0])?;
|
||||
let url_str = url.to_string();
|
||||
|
||||
// Prevent path traversal attacks
|
||||
if url_str.contains("..") {
|
||||
return Err("URL contains invalid path traversal sequences like '..'.".into());
|
||||
}
|
||||
|
||||
let modified_url = if url_str.starts_with("/") {
|
||||
let work_root = std::env::var("WORK_ROOT").unwrap_or_else(|_| "./work".to_string());
|
||||
let full_path = std::path::Path::new(&work_root)
|
||||
.join(url_str.trim_start_matches('/'))
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
|
||||
let base_url = "file://";
|
||||
format!("{}{}", base_url, full_path)
|
||||
} else {
|
||||
url_str.to_string()
|
||||
};
|
||||
|
||||
if modified_url.starts_with("https://") {
|
||||
info!("HTTPS GET request: {}", modified_url);
|
||||
|
||||
let fut = execute_get(&modified_url);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("HTTP request failed: {}", e))?;
|
||||
|
||||
Ok(Dynamic::from(result))
|
||||
} else if modified_url.starts_with("file://") {
|
||||
// Handle file:// URLs
|
||||
let file_path = modified_url.trim_start_matches("file://");
|
||||
match std::fs::read_to_string(file_path) {
|
||||
Ok(content) => Ok(Dynamic::from(content)),
|
||||
Err(e) => Err(format!("Failed to read file: {}", e).into()),
|
||||
if url_str.contains("..") {
|
||||
return Err("URL contains invalid path traversal sequences like '..'.".into());
|
||||
}
|
||||
} else {
|
||||
Err(
|
||||
format!("GET request failed: URL must begin with 'https://' or 'file://'")
|
||||
.into(),
|
||||
)
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let modified_url = if url_str.starts_with("/") {
|
||||
let work_root = std::env::var("WORK_ROOT").unwrap_or_else(|_| "./work".to_string());
|
||||
let full_path = std::path::Path::new(&work_root)
|
||||
.join(url_str.trim_start_matches('/'))
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
|
||||
let base_url = "file://";
|
||||
format!("{}{}", base_url, full_path)
|
||||
} else {
|
||||
url_str.to_string()
|
||||
};
|
||||
|
||||
if modified_url.starts_with("https://") {
|
||||
info!("HTTPS GET request: {}", modified_url);
|
||||
|
||||
let fut = execute_get(&modified_url);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("HTTP request failed: {}", e))?;
|
||||
|
||||
Ok(Dynamic::from(result))
|
||||
} else if modified_url.starts_with("file://") {
|
||||
let file_path = modified_url.trim_start_matches("file://");
|
||||
match std::fs::read_to_string(file_path) {
|
||||
Ok(content) => Ok(Dynamic::from(content)),
|
||||
Err(e) => Err(format!("Failed to read file: {}", e).into()),
|
||||
}
|
||||
} else {
|
||||
Err(
|
||||
format!("GET request failed: URL must begin with 'https://' or 'file://'")
|
||||
.into(),
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn execute_get(url: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
info!("Starting execute_get with URL: {}", url);
|
||||
|
||||
// Create a client that ignores invalid certificates
|
||||
let client = Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.build()?;
|
||||
|
||||
let response = client.get(url).send().await?;
|
||||
let html_content = response.text().await?;
|
||||
let content = response.text().await?;
|
||||
|
||||
// Parse HTML and extract text only if it appears to be HTML
|
||||
if html_content.trim_start().starts_with("<!DOCTYPE html")
|
||||
|| html_content.trim_start().starts_with("<html")
|
||||
{
|
||||
let document = Html::parse_document(&html_content);
|
||||
let selector = Selector::parse("body").unwrap_or_else(|_| Selector::parse("*").unwrap());
|
||||
|
||||
let text_content = document
|
||||
.select(&selector)
|
||||
.flat_map(|element| element.text())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
// Clean up the text
|
||||
let cleaned_text = text_content
|
||||
.replace('\n', " ")
|
||||
.replace('\t', " ")
|
||||
.split_whitespace()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
Ok(cleaned_text)
|
||||
} else {
|
||||
Ok(html_content) // Return plain content as is if not HTML
|
||||
}
|
||||
Ok(content)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
use crate::{shared::state::AppState, web_automation::BrowserPool};
|
||||
use crate::{shared::state::AppState, shared::models::UserSession, web_automation::BrowserPool};
|
||||
use headless_chrome::browser::tab::Tab;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine};
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use thirtyfour::{By, WebDriver};
|
||||
use tokio::time::sleep;
|
||||
|
||||
pub fn get_website_keyword(state: &AppState, engine: &mut Engine) {
|
||||
let browser_pool = state.browser_pool.clone(); // Assuming AppState has browser_pool field
|
||||
pub fn get_website_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let browser_pool = state.browser_pool.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
|
|
@ -38,16 +38,12 @@ pub async fn execute_headless_browser_search(
|
|||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
info!("Starting headless browser search: '{}' ", search_term);
|
||||
|
||||
// Clone the search term so it can be moved into the async closure.
|
||||
let term = search_term.to_string();
|
||||
|
||||
// `with_browser` expects a closure that returns a `Future` yielding
|
||||
// `Result<_, Box<dyn Error + Send + Sync>>`. `perform_search` already returns
|
||||
// that exact type, so we can forward the result directly.
|
||||
let result = browser_pool
|
||||
.with_browser(move |driver| {
|
||||
.with_browser(move |tab| {
|
||||
let term = term.clone();
|
||||
Box::pin(async move { perform_search(driver, &term).await })
|
||||
Box::pin(async move { perform_search(tab, &term).await })
|
||||
})
|
||||
.await?;
|
||||
|
||||
|
|
@ -55,27 +51,36 @@ pub async fn execute_headless_browser_search(
|
|||
}
|
||||
|
||||
async fn perform_search(
|
||||
driver: WebDriver,
|
||||
tab: Arc<Tab>,
|
||||
search_term: &str,
|
||||
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
// Navigate to DuckDuckGo
|
||||
driver.goto("https://duckduckgo.com").await?;
|
||||
tab.navigate_to("https://duckduckgo.com")
|
||||
.map_err(|e| format!("Failed to navigate: {}", e))?;
|
||||
|
||||
// Wait for search box and type query
|
||||
let search_input = driver.find(By::Id("searchbox_input")).await?;
|
||||
search_input.click().await?;
|
||||
search_input.send_keys(search_term).await?;
|
||||
tab.wait_for_element("#searchbox_input")
|
||||
.map_err(|e| format!("Failed to find search box: {}", e))?;
|
||||
|
||||
// Submit search by pressing Enter
|
||||
search_input.send_keys("\n").await?;
|
||||
let search_input = tab
|
||||
.find_element("#searchbox_input")
|
||||
.map_err(|e| format!("Failed to find search input: {}", e))?;
|
||||
|
||||
// Wait for results to load - using a modern result selector
|
||||
driver.find(By::Css("[data-testid='result']")).await?;
|
||||
sleep(Duration::from_millis(2000)).await;
|
||||
search_input
|
||||
.click()
|
||||
.map_err(|e| format!("Failed to click search input: {}", e))?;
|
||||
|
||||
// Extract results
|
||||
let results = extract_search_results(&driver).await?;
|
||||
driver.close_window().await?;
|
||||
search_input
|
||||
.type_into(search_term)
|
||||
.map_err(|e| format!("Failed to type into search input: {}", e))?;
|
||||
|
||||
search_input
|
||||
.press_key("Enter")
|
||||
.map_err(|e| format!("Failed to press Enter: {}", e))?;
|
||||
|
||||
sleep(Duration::from_millis(3000)).await;
|
||||
|
||||
let _ = tab.wait_for_element("[data-testid='result']");
|
||||
|
||||
let results = extract_search_results(&tab).await?;
|
||||
|
||||
if !results.is_empty() {
|
||||
Ok(results[0].clone())
|
||||
|
|
@ -85,45 +90,34 @@ async fn perform_search(
|
|||
}
|
||||
|
||||
async fn extract_search_results(
|
||||
driver: &WebDriver,
|
||||
tab: &Arc<Tab>,
|
||||
) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
// Try different selectors for search results, ordered by most specific to most general
|
||||
let selectors = [
|
||||
// Modern DuckDuckGo (as seen in the HTML)
|
||||
"a[data-testid='result-title-a']", // Primary result links
|
||||
"a[data-testid='result-extras-url-link']", // URL links in results
|
||||
"a.eVNpHGjtxRBq_gLOfGDr", // Class-based selector for result titles
|
||||
"a.Rn_JXVtoPVAFyGkcaXyK", // Class-based selector for URL links
|
||||
".ikg2IXiCD14iVX7AdZo1 a", // Heading container links
|
||||
".OQ_6vPwNhCeusNiEDcGp a", // URL container links
|
||||
// Fallback selectors
|
||||
".result__a", // Classic DuckDuckGo
|
||||
"a.result-link", // Alternative
|
||||
".result a[href]", // Generic result links
|
||||
"a[data-testid='result-title-a']",
|
||||
"a[data-testid='result-extras-url-link']",
|
||||
"a.eVNpHGjtxRBq_gLOfGDr",
|
||||
"a.Rn_JXVtoPVAFyGkcaXyK",
|
||||
".ikg2IXiCD14iVX7AdZo1 a",
|
||||
".OQ_6vPwNhCeusNiEDcGp a",
|
||||
".result__a",
|
||||
"a.result-link",
|
||||
".result a[href]",
|
||||
];
|
||||
|
||||
// Iterate over selectors, dereferencing each `&&str` to `&str` for `By::Css`
|
||||
for &selector in &selectors {
|
||||
if let Ok(elements) = driver.find_all(By::Css(selector)).await {
|
||||
for selector in &selectors {
|
||||
if let Ok(elements) = tab.find_elements(selector) {
|
||||
for element in elements {
|
||||
if let Ok(Some(href)) = element.attr("href").await {
|
||||
// Filter out internal and non‑http links
|
||||
if let Ok(Some(href)) = element.get_attribute_value("href") {
|
||||
if href.starts_with("http")
|
||||
&& !href.contains("duckduckgo.com")
|
||||
&& !href.contains("duck.co")
|
||||
&& !results.contains(&href)
|
||||
{
|
||||
// Get the display URL for verification
|
||||
let display_url = if let Ok(text) = element.text().await {
|
||||
text.trim().to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
let display_text = element.get_inner_text().unwrap_or_default();
|
||||
|
||||
// Only add if it looks like a real result (not an ad or internal link)
|
||||
if !display_url.is_empty() && !display_url.contains("Ad") {
|
||||
if !display_text.is_empty() && !display_text.contains("Ad") {
|
||||
results.push(href);
|
||||
}
|
||||
}
|
||||
|
|
@ -135,7 +129,6 @@ async fn extract_search_results(
|
|||
}
|
||||
}
|
||||
|
||||
// Deduplicate results
|
||||
results.dedup();
|
||||
|
||||
Ok(results)
|
||||
|
|
|
|||
100
src/basic/keywords/hear_talk.rs
Normal file
100
src/basic/keywords/hear_talk.rs
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine, EvalAltResult};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub fn hear_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
let session_id = user.id;
|
||||
|
||||
engine
|
||||
.register_custom_syntax(&["HEAR", "$ident$"], true, move |context, inputs| {
|
||||
let variable_name = inputs[0].get_string_value().unwrap().to_string();
|
||||
|
||||
info!("HEAR command waiting for user input to store in variable: {}", variable_name);
|
||||
|
||||
let orchestrator = state_clone.orchestrator.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let session_manager = orchestrator.session_manager.clone();
|
||||
session_manager.lock().await.wait_for_input(session_id, variable_name.clone()).await;
|
||||
});
|
||||
|
||||
Err(EvalAltResult::ErrorInterrupted("Waiting for user input".into()))
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(&["TALK", "$expr$"], true, move |context, inputs| {
|
||||
let message = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
|
||||
info!("TALK command executed: {}", message);
|
||||
|
||||
let response = crate::shared::BotResponse {
|
||||
bot_id: "default_bot".to_string(),
|
||||
user_id: user.user_id.to_string(),
|
||||
session_id: user.id.to_string(),
|
||||
channel: "basic".to_string(),
|
||||
content: message,
|
||||
message_type: "text".to_string(),
|
||||
stream_token: None,
|
||||
is_complete: true,
|
||||
};
|
||||
|
||||
// Since we removed global response_tx, we need to send through the orchestrator's response channels
|
||||
let orchestrator = state_clone.orchestrator.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Some(adapter) = orchestrator.channels.get("basic") {
|
||||
let _ = adapter.send_message(response).await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Dynamic::UNIT)
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
&["SET", "CONTEXT", "$expr$"],
|
||||
true,
|
||||
move |context, inputs| {
|
||||
let context_value = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
|
||||
info!("SET CONTEXT command executed: {}", context_value);
|
||||
|
||||
let redis_key = format!("context:{}:{}", user.user_id, user.id);
|
||||
|
||||
let state_for_redis = state_clone.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Some(redis_client) = &state_for_redis.redis_client {
|
||||
let mut conn = match redis_client.get_async_connection().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
log::error!("Failed to connect to Redis: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let _: Result<(), _> = redis::cmd("SET")
|
||||
.arg(&redis_key)
|
||||
.arg(&context_value)
|
||||
.query_async(&mut conn)
|
||||
.await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Dynamic::UNIT)
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -8,7 +8,6 @@ pub fn last_keyword(engine: &mut Engine) {
|
|||
let input_string = context.eval_expression_tree(&inputs[0])?;
|
||||
let input_str = input_string.to_string();
|
||||
|
||||
// Extrai a última palavra dividindo por espaço
|
||||
let last_word = input_str
|
||||
.split_whitespace()
|
||||
.last()
|
||||
|
|
@ -30,7 +29,7 @@ mod tests {
|
|||
fn test_last_keyword_basic() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let result: String = engine.eval("LAST(\"hello world\")").unwrap();
|
||||
assert_eq!(result, "world");
|
||||
}
|
||||
|
|
@ -39,7 +38,7 @@ mod tests {
|
|||
fn test_last_keyword_single_word() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let result: String = engine.eval("LAST(\"hello\")").unwrap();
|
||||
assert_eq!(result, "hello");
|
||||
}
|
||||
|
|
@ -48,7 +47,7 @@ mod tests {
|
|||
fn test_last_keyword_empty_string() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let result: String = engine.eval("LAST(\"\")").unwrap();
|
||||
assert_eq!(result, "");
|
||||
}
|
||||
|
|
@ -57,7 +56,7 @@ mod tests {
|
|||
fn test_last_keyword_multiple_spaces() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let result: String = engine.eval("LAST(\"hello world \")").unwrap();
|
||||
assert_eq!(result, "world");
|
||||
}
|
||||
|
|
@ -66,7 +65,7 @@ mod tests {
|
|||
fn test_last_keyword_tabs_and_newlines() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let result: String = engine.eval("LAST(\"hello\tworld\n\")").unwrap();
|
||||
assert_eq!(result, "world");
|
||||
}
|
||||
|
|
@ -76,10 +75,10 @@ mod tests {
|
|||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
let mut scope = Scope::new();
|
||||
|
||||
|
||||
scope.push("text", "this is a test");
|
||||
let result: String = engine.eval_with_scope(&mut scope, "LAST(text)").unwrap();
|
||||
|
||||
|
||||
assert_eq!(result, "test");
|
||||
}
|
||||
|
||||
|
|
@ -87,7 +86,7 @@ mod tests {
|
|||
fn test_last_keyword_whitespace_only() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let result: String = engine.eval("LAST(\" \")").unwrap();
|
||||
assert_eq!(result, "");
|
||||
}
|
||||
|
|
@ -96,7 +95,7 @@ mod tests {
|
|||
fn test_last_keyword_mixed_whitespace() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let result: String = engine.eval("LAST(\"hello\t \n world \t final\")").unwrap();
|
||||
assert_eq!(result, "final");
|
||||
}
|
||||
|
|
@ -105,8 +104,7 @@ mod tests {
|
|||
fn test_last_keyword_expression() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
// Test with string concatenation
|
||||
|
||||
let result: String = engine.eval("LAST(\"hello\" + \" \" + \"world\")").unwrap();
|
||||
assert_eq!(result, "world");
|
||||
}
|
||||
|
|
@ -115,7 +113,7 @@ mod tests {
|
|||
fn test_last_keyword_unicode() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let result: String = engine.eval("LAST(\"hello 世界 мир world\")").unwrap();
|
||||
assert_eq!(result, "world");
|
||||
}
|
||||
|
|
@ -124,8 +122,7 @@ mod tests {
|
|||
fn test_last_keyword_in_expression() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
// Test using the result in another expression
|
||||
|
||||
let result: bool = engine.eval("LAST(\"hello world\") == \"world\"").unwrap();
|
||||
assert!(result);
|
||||
}
|
||||
|
|
@ -135,40 +132,37 @@ mod tests {
|
|||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
let mut scope = Scope::new();
|
||||
|
||||
|
||||
scope.push("sentence", "The quick brown fox jumps over the lazy dog");
|
||||
let result: String = engine.eval_with_scope(&mut scope, "LAST(sentence)").unwrap();
|
||||
|
||||
|
||||
assert_eq!(result, "dog");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic] // This should fail because the syntax expects parentheses
|
||||
#[should_panic]
|
||||
fn test_last_keyword_missing_parentheses() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
// This should fail - missing parentheses
|
||||
|
||||
let _: String = engine.eval("LAST \"hello world\"").unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic] // This should fail because of incomplete syntax
|
||||
#[should_panic]
|
||||
fn test_last_keyword_missing_closing_parenthesis() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
// This should fail - missing closing parenthesis
|
||||
|
||||
let _: String = engine.eval("LAST(\"hello world\"").unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic] // This should fail because of incomplete syntax
|
||||
#[should_panic]
|
||||
fn test_last_keyword_missing_opening_parenthesis() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
// This should fail - missing opening parenthesis
|
||||
|
||||
let _: String = engine.eval("LAST \"hello world\")").unwrap();
|
||||
}
|
||||
|
||||
|
|
@ -176,8 +170,7 @@ mod tests {
|
|||
fn test_last_keyword_dynamic_type() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
// Test that the function returns the correct Dynamic type
|
||||
|
||||
let result = engine.eval::<Dynamic>("LAST(\"test string\")").unwrap();
|
||||
assert!(result.is::<String>());
|
||||
assert_eq!(result.to_string(), "string");
|
||||
|
|
@ -187,8 +180,7 @@ mod tests {
|
|||
fn test_last_keyword_nested_expression() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
// Test with a more complex nested expression
|
||||
|
||||
let result: String = engine.eval("LAST(\"The result is: \" + \"hello world\")").unwrap();
|
||||
assert_eq!(result, "world");
|
||||
}
|
||||
|
|
@ -202,17 +194,17 @@ mod integration_tests {
|
|||
fn test_last_keyword_in_script() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let script = r#"
|
||||
let sentence1 = "first second third";
|
||||
let sentence2 = "alpha beta gamma";
|
||||
|
||||
|
||||
let last1 = LAST(sentence1);
|
||||
let last2 = LAST(sentence2);
|
||||
|
||||
|
||||
last1 + " and " + last2
|
||||
"#;
|
||||
|
||||
|
||||
let result: String = engine.eval(script).unwrap();
|
||||
assert_eq!(result, "third and gamma");
|
||||
}
|
||||
|
|
@ -221,10 +213,9 @@ mod integration_tests {
|
|||
fn test_last_keyword_with_function() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
// Register a function that returns a string
|
||||
|
||||
engine.register_fn("get_name", || -> String { "john doe".to_string() });
|
||||
|
||||
|
||||
let result: String = engine.eval("LAST(get_name())").unwrap();
|
||||
assert_eq!(result, "doe");
|
||||
}
|
||||
|
|
@ -233,18 +224,18 @@ mod integration_tests {
|
|||
fn test_last_keyword_multiple_calls() {
|
||||
let mut engine = Engine::new();
|
||||
last_keyword(&mut engine);
|
||||
|
||||
|
||||
let script = r#"
|
||||
let text1 = "apple banana cherry";
|
||||
let text2 = "cat dog elephant";
|
||||
|
||||
|
||||
let result1 = LAST(text1);
|
||||
let result2 = LAST(text2);
|
||||
|
||||
|
||||
result1 + "-" + result2
|
||||
"#;
|
||||
|
||||
|
||||
let result: String = engine.eval(script).unwrap();
|
||||
assert_eq!(result, "cherry-elephant");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,23 +1,22 @@
|
|||
use log::info;
|
||||
|
||||
use crate::{shared::state::AppState, shared::utils::call_llm};
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::utils::call_llm;
|
||||
use rhai::{Dynamic, Engine};
|
||||
|
||||
pub fn llm_keyword(state: &AppState, engine: &mut Engine) {
|
||||
pub fn llm_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let ai_config = state.config.clone().unwrap().ai.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
&["LLM", "$expr$"], // Syntax: LLM "text to process"
|
||||
false, // Expression, not statement
|
||||
&["LLM", "$expr$"],
|
||||
false,
|
||||
move |context, inputs| {
|
||||
let text = context.eval_expression_tree(&inputs[0])?;
|
||||
let text_str = text.to_string();
|
||||
|
||||
info!("LLM processing text: {}", text_str);
|
||||
|
||||
// Use the same pattern as GET
|
||||
|
||||
let fut = call_llm(&text_str, &ai_config);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
#[cfg(feature = "email")]
|
||||
pub mod create_draft;
|
||||
pub mod create_site;
|
||||
pub mod find;
|
||||
pub mod first;
|
||||
pub mod for_next;
|
||||
pub mod format;
|
||||
pub mod get;
|
||||
pub mod get_website;
|
||||
pub mod hear_talk;
|
||||
pub mod last;
|
||||
pub mod llm_keyword;
|
||||
pub mod on;
|
||||
|
|
@ -14,3 +12,9 @@ pub mod print;
|
|||
pub mod set;
|
||||
pub mod set_schedule;
|
||||
pub mod wait;
|
||||
|
||||
#[cfg(feature = "email")]
|
||||
pub mod create_draft;
|
||||
|
||||
#[cfg(feature = "web_automation")]
|
||||
pub mod get_website;
|
||||
|
|
|
|||
|
|
@ -2,27 +2,25 @@ use log::{error, info};
|
|||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::PgPool;
|
||||
use diesel::prelude::*;
|
||||
|
||||
use crate::shared::models::TriggerKind;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
|
||||
pub fn on_keyword(state: &AppState, engine: &mut Engine) {
|
||||
let db = state.db_custom.clone();
|
||||
pub fn on_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
["ON", "$ident$", "OF", "$string$"], // Changed $string$ to $ident$ for operation
|
||||
["ON", "$ident$", "OF", "$string$"],
|
||||
true,
|
||||
{
|
||||
let db = db.clone();
|
||||
|
||||
move |context, inputs| {
|
||||
let trigger_type = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
let table = context.eval_expression_tree(&inputs[1])?.to_string();
|
||||
let script_name = format!("{}_{}.rhai", table, trigger_type.to_lowercase());
|
||||
|
||||
// Determine the trigger kind based on the trigger type
|
||||
let kind = match trigger_type.to_uppercase().as_str() {
|
||||
"UPDATE" => TriggerKind::TableUpdate,
|
||||
"INSERT" => TriggerKind::TableInsert,
|
||||
|
|
@ -30,13 +28,9 @@ pub fn on_keyword(state: &AppState, engine: &mut Engine) {
|
|||
_ => return Err(format!("Invalid trigger type: {}", trigger_type).into()),
|
||||
};
|
||||
|
||||
let binding = db.as_ref().unwrap();
|
||||
let fut = execute_on_trigger(binding, kind, &table, &script_name);
|
||||
|
||||
let result = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(fut)
|
||||
})
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
let conn = state_clone.conn.lock().unwrap().clone();
|
||||
let result = execute_on_trigger(&conn, kind, &table, &script_name)
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
if let Some(rows_affected) = result.get("rows_affected") {
|
||||
Ok(Dynamic::from(rows_affected.as_i64().unwrap_or(0)))
|
||||
|
|
@ -49,8 +43,8 @@ pub fn on_keyword(state: &AppState, engine: &mut Engine) {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn execute_on_trigger(
|
||||
pool: &PgPool,
|
||||
pub fn execute_on_trigger(
|
||||
conn: &PgConnection,
|
||||
kind: TriggerKind,
|
||||
table: &str,
|
||||
script_name: &str,
|
||||
|
|
@ -60,27 +54,27 @@ pub async fn execute_on_trigger(
|
|||
kind, table, script_name
|
||||
);
|
||||
|
||||
// Option 1: Use query_with macro if you need to pass enum values
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO system_automations
|
||||
(kind, target, script_name)
|
||||
VALUES ($1, $2, $3)",
|
||||
)
|
||||
.bind(kind.clone() as i32) // Assuming TriggerKind is #[repr(i32)]
|
||||
.bind(table)
|
||||
.bind(script_name)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("SQL execution error: {}", e);
|
||||
e.to_string()
|
||||
})?;
|
||||
use crate::shared::models::system_automations;
|
||||
|
||||
let new_automation = (
|
||||
system_automations::kind.eq(kind as i32),
|
||||
system_automations::target.eq(table),
|
||||
system_automations::script_name.eq(script_name),
|
||||
);
|
||||
|
||||
let result = diesel::insert_into(system_automations::table)
|
||||
.values(&new_automation)
|
||||
.execute(&mut conn.clone())
|
||||
.map_err(|e| {
|
||||
error!("SQL execution error: {}", e);
|
||||
e.to_string()
|
||||
})?;
|
||||
|
||||
Ok(json!({
|
||||
"command": "on_trigger",
|
||||
"trigger_type": format!("{:?}", kind),
|
||||
"table": table,
|
||||
"script_name": script_name,
|
||||
"rows_affected": result.rows_affected()
|
||||
"rows_affected": result
|
||||
}))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@ use rhai::Dynamic;
|
|||
use rhai::Engine;
|
||||
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
|
||||
pub fn print_keyword(_state: &AppState, engine: &mut Engine) {
|
||||
// PRINT command
|
||||
pub fn print_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
&["PRINT", "$expr$"],
|
||||
true, // Statement
|
||||
true,
|
||||
|context, inputs| {
|
||||
let value = context.eval_expression_tree(&inputs[0])?;
|
||||
info!("{}", value);
|
||||
|
|
|
|||
|
|
@ -2,35 +2,29 @@ use log::{error, info};
|
|||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::PgPool;
|
||||
use diesel::prelude::*;
|
||||
use std::error::Error;
|
||||
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::utils;
|
||||
use crate::shared::models::UserSession;
|
||||
|
||||
pub fn set_keyword(state: &AppState, engine: &mut Engine) {
|
||||
let db = state.db_custom.clone();
|
||||
pub fn set_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(&["SET", "$expr$", ",", "$expr$", ",", "$expr$"], false, {
|
||||
let db = db.clone();
|
||||
|
||||
move |context, inputs| {
|
||||
let table_name = context.eval_expression_tree(&inputs[0])?;
|
||||
let filter = context.eval_expression_tree(&inputs[1])?;
|
||||
let updates = context.eval_expression_tree(&inputs[2])?;
|
||||
let binding = db.as_ref().unwrap();
|
||||
|
||||
// Use the current async context instead of creating a new runtime
|
||||
let binding2 = table_name.to_string();
|
||||
let binding3 = filter.to_string();
|
||||
let binding4 = updates.to_string();
|
||||
let fut = execute_set(binding, &binding2, &binding3, &binding4);
|
||||
let table_str = table_name.to_string();
|
||||
let filter_str = filter.to_string();
|
||||
let updates_str = updates.to_string();
|
||||
|
||||
// Use tokio::task::block_in_place + tokio::runtime::Handle::current().block_on
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
let conn = state_clone.conn.lock().unwrap().clone();
|
||||
let result = execute_set(&conn, &table_str, &filter_str, &updates_str)
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
if let Some(rows_affected) = result.get("rows_affected") {
|
||||
Ok(Dynamic::from(rows_affected.as_i64().unwrap_or(0)))
|
||||
|
|
@ -42,8 +36,8 @@ pub fn set_keyword(state: &AppState, engine: &mut Engine) {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn execute_set(
|
||||
pool: &PgPool,
|
||||
pub fn execute_set(
|
||||
conn: &PgConnection,
|
||||
table_str: &str,
|
||||
filter_str: &str,
|
||||
updates_str: &str,
|
||||
|
|
@ -53,14 +47,9 @@ pub async fn execute_set(
|
|||
table_str, filter_str, updates_str
|
||||
);
|
||||
|
||||
// Parse updates with proper type handling
|
||||
let (set_clause, update_values) = parse_updates(updates_str).map_err(|e| e.to_string())?;
|
||||
let update_params_count = update_values.len();
|
||||
|
||||
// Parse filter with proper type handling
|
||||
let (where_clause, filter_values) =
|
||||
utils::parse_filter_with_offset(filter_str, update_params_count)
|
||||
.map_err(|e| e.to_string())?;
|
||||
let where_clause = parse_filter_for_diesel(filter_str).map_err(|e| e.to_string())?;
|
||||
|
||||
let query = format!(
|
||||
"UPDATE {} SET {} WHERE {}",
|
||||
|
|
@ -68,51 +57,22 @@ pub async fn execute_set(
|
|||
);
|
||||
info!("Executing query: {}", query);
|
||||
|
||||
// Build query with proper parameter binding
|
||||
let mut query = sqlx::query(&query);
|
||||
|
||||
// Bind update values
|
||||
for value in update_values {
|
||||
query = bind_value(query, value);
|
||||
}
|
||||
|
||||
// Bind filter values
|
||||
for value in filter_values {
|
||||
query = bind_value(query, value);
|
||||
}
|
||||
|
||||
let result = query.execute(pool).await.map_err(|e| {
|
||||
error!("SQL execution error: {}", e);
|
||||
e.to_string()
|
||||
})?;
|
||||
let result = diesel::sql_query(&query)
|
||||
.execute(&mut conn.clone())
|
||||
.map_err(|e| {
|
||||
error!("SQL execution error: {}", e);
|
||||
e.to_string()
|
||||
})?;
|
||||
|
||||
Ok(json!({
|
||||
"command": "set",
|
||||
"table": table_str,
|
||||
"filter": filter_str,
|
||||
"updates": updates_str,
|
||||
"rows_affected": result.rows_affected()
|
||||
"rows_affected": result
|
||||
}))
|
||||
}
|
||||
|
||||
fn bind_value<'q>(
|
||||
query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
|
||||
value: String,
|
||||
) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
|
||||
if let Ok(int_val) = value.parse::<i64>() {
|
||||
query.bind(int_val)
|
||||
} else if let Ok(float_val) = value.parse::<f64>() {
|
||||
query.bind(float_val)
|
||||
} else if value.eq_ignore_ascii_case("true") {
|
||||
query.bind(true)
|
||||
} else if value.eq_ignore_ascii_case("false") {
|
||||
query.bind(false)
|
||||
} else {
|
||||
query.bind(value)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse updates without adding quotes
|
||||
fn parse_updates(updates_str: &str) -> Result<(String, Vec<String>), Box<dyn Error>> {
|
||||
let mut set_clauses = Vec::new();
|
||||
let mut params = Vec::new();
|
||||
|
|
@ -134,8 +94,27 @@ fn parse_updates(updates_str: &str) -> Result<(String, Vec<String>), Box<dyn Err
|
|||
}
|
||||
|
||||
set_clauses.push(format!("{} = ${}", column, i + 1));
|
||||
params.push(value.to_string()); // Store raw value without quotes
|
||||
params.push(value.to_string());
|
||||
}
|
||||
|
||||
Ok((set_clauses.join(", "), params))
|
||||
}
|
||||
|
||||
fn parse_filter_for_diesel(filter_str: &str) -> Result<String, Box<dyn Error>> {
|
||||
let parts: Vec<&str> = filter_str.split('=').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err("Invalid filter format. Expected 'KEY=VALUE'".into());
|
||||
}
|
||||
|
||||
let column = parts[0].trim();
|
||||
let value = parts[1].trim();
|
||||
|
||||
if !column
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
{
|
||||
return Err("Invalid column name in filter".into());
|
||||
}
|
||||
|
||||
Ok(format!("{} = '{}'", column, value))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,28 +2,24 @@ use log::info;
|
|||
use rhai::Dynamic;
|
||||
use rhai::Engine;
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::PgPool;
|
||||
use diesel::prelude::*;
|
||||
|
||||
use crate::shared::models::TriggerKind;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
|
||||
pub fn set_schedule_keyword(state: &AppState, engine: &mut Engine) {
|
||||
let db = state.db_custom.clone();
|
||||
pub fn set_schedule_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(["SET_SCHEDULE", "$string$"], true, {
|
||||
let db = db.clone();
|
||||
|
||||
move |context, inputs| {
|
||||
let cron = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
let script_name = format!("cron_{}.rhai", cron.replace(' ', "_"));
|
||||
|
||||
let binding = db.as_ref().unwrap();
|
||||
let fut = execute_set_schedule(binding, &cron, &script_name);
|
||||
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
let conn = state_clone.conn.lock().unwrap().clone();
|
||||
let result = execute_set_schedule(&conn, &cron, &script_name)
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
if let Some(rows_affected) = result.get("rows_affected") {
|
||||
Ok(Dynamic::from(rows_affected.as_i64().unwrap_or(0)))
|
||||
|
|
@ -35,8 +31,8 @@ pub fn set_schedule_keyword(state: &AppState, engine: &mut Engine) {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn execute_set_schedule(
|
||||
pool: &PgPool,
|
||||
pub fn execute_set_schedule(
|
||||
conn: &PgConnection,
|
||||
cron: &str,
|
||||
script_name: &str,
|
||||
) -> Result<Value, Box<dyn std::error::Error>> {
|
||||
|
|
@ -45,23 +41,22 @@ pub async fn execute_set_schedule(
|
|||
cron, script_name
|
||||
);
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO system_automations
|
||||
(kind, schedule, script_name)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
)
|
||||
.bind(TriggerKind::Scheduled as i32) // Cast to i32
|
||||
.bind(cron)
|
||||
.bind(script_name)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
use crate::shared::models::system_automations;
|
||||
|
||||
let new_automation = (
|
||||
system_automations::kind.eq(TriggerKind::Scheduled as i32),
|
||||
system_automations::schedule.eq(cron),
|
||||
system_automations::script_name.eq(script_name),
|
||||
);
|
||||
|
||||
let result = diesel::insert_into(system_automations::table)
|
||||
.values(&new_automation)
|
||||
.execute(&mut conn.clone())?;
|
||||
|
||||
Ok(json!({
|
||||
"command": "set_schedule",
|
||||
"schedule": cron,
|
||||
"script_name": script_name,
|
||||
"rows_affected": result.rows_affected()
|
||||
"rows_affected": result
|
||||
}))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn wait_keyword(_state: &AppState, engine: &mut Engine) {
|
||||
pub fn wait_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
&["WAIT", "$expr$"],
|
||||
false, // Expression, not statement
|
||||
false,
|
||||
move |context, inputs| {
|
||||
let seconds = context.eval_expression_tree(&inputs[0])?;
|
||||
|
||||
// Convert to number (handle both int and float)
|
||||
let duration_secs = if seconds.is::<i64>() {
|
||||
seconds.cast::<i64>() as f64
|
||||
} else if seconds.is::<f64>() {
|
||||
|
|
@ -25,7 +25,6 @@ pub fn wait_keyword(_state: &AppState, engine: &mut Engine) {
|
|||
return Err("WAIT duration cannot be negative".into());
|
||||
}
|
||||
|
||||
// Cap maximum wait time to prevent abuse (e.g., 5 minutes max)
|
||||
let capped_duration = if duration_secs > 300.0 {
|
||||
300.0
|
||||
} else {
|
||||
|
|
@ -34,7 +33,6 @@ pub fn wait_keyword(_state: &AppState, engine: &mut Engine) {
|
|||
|
||||
info!("WAIT {} seconds (thread sleep)", capped_duration);
|
||||
|
||||
// Use thread::sleep to block only the current thread, not the entire server
|
||||
let duration = Duration::from_secs_f64(capped_duration);
|
||||
thread::sleep(duration);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
mod keywords;
|
||||
pub mod keywords;
|
||||
|
||||
#[cfg(feature = "email")]
|
||||
use self::keywords::create_draft::create_draft_keyword;
|
||||
use self::keywords::create_draft_keyword;
|
||||
|
||||
use self::keywords::create_site::create_site_keyword;
|
||||
use self::keywords::find::find_keyword;
|
||||
|
|
@ -9,7 +9,9 @@ use self::keywords::first::first_keyword;
|
|||
use self::keywords::for_next::for_keyword;
|
||||
use self::keywords::format::format_keyword;
|
||||
use self::keywords::get::get_keyword;
|
||||
#[cfg(feature = "web_automation")]
|
||||
use self::keywords::get_website::get_website_keyword;
|
||||
use self::keywords::hear_talk::{hear_keyword, set_context_keyword, talk_keyword};
|
||||
use self::keywords::last::last_keyword;
|
||||
use self::keywords::llm_keyword::llm_keyword;
|
||||
use self::keywords::on::on_keyword;
|
||||
|
|
@ -17,6 +19,7 @@ use self::keywords::print::print_keyword;
|
|||
use self::keywords::set::set_keyword;
|
||||
use self::keywords::set_schedule::set_schedule_keyword;
|
||||
use self::keywords::wait::wait_keyword;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::AppState;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine, EvalAltResult};
|
||||
|
|
@ -26,30 +29,32 @@ pub struct ScriptService {
|
|||
}
|
||||
|
||||
impl ScriptService {
|
||||
pub fn new(state: &AppState) -> Self {
|
||||
pub fn new(state: &AppState, user: UserSession) -> Self {
|
||||
let mut engine = Engine::new();
|
||||
|
||||
// Configure engine for BASIC-like syntax
|
||||
engine.set_allow_anonymous_fn(true);
|
||||
engine.set_allow_looping(true);
|
||||
|
||||
#[cfg(feature = "email")]
|
||||
create_draft_keyword(state, &mut engine);
|
||||
create_draft_keyword(state, user.clone(), &mut engine);
|
||||
|
||||
create_site_keyword(state, &mut engine);
|
||||
find_keyword(state, &mut engine);
|
||||
for_keyword(state, &mut engine);
|
||||
create_site_keyword(state, user.clone(), &mut engine);
|
||||
find_keyword(state, user.clone(), &mut engine);
|
||||
for_keyword(state, user.clone(), &mut engine);
|
||||
first_keyword(&mut engine);
|
||||
last_keyword(&mut engine);
|
||||
format_keyword(&mut engine);
|
||||
llm_keyword(state, &mut engine);
|
||||
get_website_keyword(state, &mut engine);
|
||||
get_keyword(state, &mut engine);
|
||||
set_keyword(state, &mut engine);
|
||||
wait_keyword(state, &mut engine);
|
||||
print_keyword(state, &mut engine);
|
||||
on_keyword(state, &mut engine);
|
||||
set_schedule_keyword(state, &mut engine);
|
||||
llm_keyword(state, user.clone(), &mut engine);
|
||||
get_website_keyword(state, user.clone(), &mut engine);
|
||||
get_keyword(state, user.clone(), &mut engine);
|
||||
set_keyword(state, user.clone(), &mut engine);
|
||||
wait_keyword(state, user.clone(), &mut engine);
|
||||
print_keyword(state, user.clone(), &mut engine);
|
||||
on_keyword(state, user.clone(), &mut engine);
|
||||
set_schedule_keyword(state, user.clone(), &mut engine);
|
||||
hear_keyword(state, user.clone(), &mut engine);
|
||||
talk_keyword(state, user.clone(), &mut engine);
|
||||
set_context_keyword(state, user.clone(), &mut engine);
|
||||
|
||||
ScriptService { engine }
|
||||
}
|
||||
|
|
@ -62,14 +67,12 @@ impl ScriptService {
|
|||
for line in script.lines() {
|
||||
let trimmed = line.trim();
|
||||
|
||||
// Skip empty lines and comments
|
||||
if trimmed.is_empty() || trimmed.starts_with("//") || trimmed.starts_with("REM") {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle FOR EACH start
|
||||
if trimmed.starts_with("FOR EACH") {
|
||||
for_stack.push(current_indent);
|
||||
result.push_str(&" ".repeat(current_indent));
|
||||
|
|
@ -81,7 +84,6 @@ impl ScriptService {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Handle NEXT
|
||||
if trimmed.starts_with("NEXT") {
|
||||
if let Some(expected_indent) = for_stack.pop() {
|
||||
if (current_indent - 4) != expected_indent {
|
||||
|
|
@ -100,7 +102,6 @@ impl ScriptService {
|
|||
}
|
||||
}
|
||||
|
||||
// Handle EXIT FOR
|
||||
if trimmed == "EXIT FOR" {
|
||||
result.push_str(&" ".repeat(current_indent));
|
||||
result.push_str(trimmed);
|
||||
|
|
@ -108,12 +109,27 @@ impl ScriptService {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Handle regular lines - no semicolons added for BASIC-style commands
|
||||
result.push_str(&" ".repeat(current_indent));
|
||||
|
||||
let basic_commands = [
|
||||
"SET", "CREATE", "PRINT", "FOR", "FIND", "GET", "EXIT", "IF", "THEN", "ELSE",
|
||||
"END IF", "WHILE", "WEND", "DO", "LOOP",
|
||||
"SET",
|
||||
"CREATE",
|
||||
"PRINT",
|
||||
"FOR",
|
||||
"FIND",
|
||||
"GET",
|
||||
"EXIT",
|
||||
"IF",
|
||||
"THEN",
|
||||
"ELSE",
|
||||
"END IF",
|
||||
"WHILE",
|
||||
"WEND",
|
||||
"DO",
|
||||
"LOOP",
|
||||
"HEAR",
|
||||
"TALK",
|
||||
"SET CONTEXT",
|
||||
];
|
||||
|
||||
let is_basic_command = basic_commands.iter().any(|&cmd| trimmed.starts_with(cmd));
|
||||
|
|
@ -122,11 +138,9 @@ impl ScriptService {
|
|||
|| trimmed.starts_with("END IF");
|
||||
|
||||
if is_basic_command || !for_stack.is_empty() || is_control_flow {
|
||||
// Don'ta add semicolons for BASIC-style commands or inside blocks
|
||||
result.push_str(trimmed);
|
||||
result.push(';');
|
||||
} else {
|
||||
// Add semicolons only for BASIC statements
|
||||
result.push_str(trimmed);
|
||||
if !trimmed.ends_with(';') && !trimmed.ends_with('{') && !trimmed.ends_with('}') {
|
||||
result.push(';');
|
||||
|
|
@ -142,7 +156,6 @@ impl ScriptService {
|
|||
result
|
||||
}
|
||||
|
||||
/// Preprocesses BASIC-style script to handle semicolon-free syntax
|
||||
pub fn compile(&self, script: &str) -> Result<rhai::AST, Box<EvalAltResult>> {
|
||||
let processed_script = self.preprocess_basic_script(script);
|
||||
info!("Processed Script:\n{}", processed_script);
|
||||
|
|
|
|||
250
src/bot/mod.rs
250
src/bot/mod.rs
|
|
@ -9,21 +9,19 @@ use std::sync::Arc;
|
|||
use tokio::sync::{mpsc, Mutex};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
auth::AuthService,
|
||||
channels::ChannelAdapter,
|
||||
llm::LLMProvider,
|
||||
session::SessionManager,
|
||||
shared::{BotResponse, UserMessage, UserSession},
|
||||
tools::ToolManager,
|
||||
};
|
||||
use crate::auth::AuthService;
|
||||
use crate::channels::ChannelAdapter;
|
||||
use crate::llm::LLMProvider;
|
||||
use crate::session::SessionManager;
|
||||
use crate::shared::{BotResponse, UserMessage, UserSession};
|
||||
use crate::tools::ToolManager;
|
||||
|
||||
pub struct BotOrchestrator {
|
||||
session_manager: SessionManager,
|
||||
tool_manager: ToolManager,
|
||||
pub session_manager: Arc<Mutex<SessionManager>>,
|
||||
tool_manager: Arc<ToolManager>,
|
||||
llm_provider: Arc<dyn LLMProvider>,
|
||||
auth_service: AuthService,
|
||||
channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
||||
pub channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
||||
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||
}
|
||||
|
||||
|
|
@ -35,8 +33,8 @@ impl BotOrchestrator {
|
|||
auth_service: AuthService,
|
||||
) -> Self {
|
||||
Self {
|
||||
session_manager,
|
||||
tool_manager,
|
||||
session_manager: Arc::new(Mutex::new(session_manager)),
|
||||
tool_manager: Arc::new(tool_manager),
|
||||
llm_provider,
|
||||
auth_service,
|
||||
channels: HashMap::new(),
|
||||
|
|
@ -44,6 +42,20 @@ impl BotOrchestrator {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn handle_user_input(
|
||||
&self,
|
||||
session_id: Uuid,
|
||||
user_input: &str,
|
||||
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let session_manager = self.session_manager.lock().await;
|
||||
session_manager.provide_input(session_id, user_input).await
|
||||
}
|
||||
|
||||
pub async fn is_waiting_for_input(&self, session_id: Uuid) -> bool {
|
||||
let session_manager = self.session_manager.lock().await;
|
||||
session_manager.is_waiting_for_input(session_id).await
|
||||
}
|
||||
|
||||
pub fn add_channel(&mut self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
|
||||
self.channels.insert(channel_type.to_string(), adapter);
|
||||
}
|
||||
|
|
@ -65,9 +77,8 @@ impl BotOrchestrator {
|
|||
bot_id: &str,
|
||||
mode: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.session_manager
|
||||
.update_answer_mode(user_id, bot_id, mode)
|
||||
.await?;
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
session_manager.update_answer_mode(user_id, bot_id, mode)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -84,41 +95,74 @@ impl BotOrchestrator {
|
|||
let bot_id = Uuid::parse_str(&message.bot_id)
|
||||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||||
|
||||
let session = match self
|
||||
.session_manager
|
||||
.get_user_session(user_id, bot_id)
|
||||
.await?
|
||||
{
|
||||
Some(session) => session,
|
||||
None => {
|
||||
self.session_manager
|
||||
.create_session(user_id, bot_id, "New Conversation")
|
||||
.await?
|
||||
let session = {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
match session_manager.get_user_session(user_id, bot_id)? {
|
||||
Some(session) => session,
|
||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
||||
}
|
||||
};
|
||||
|
||||
// Check if we're waiting for HEAR input
|
||||
if self.is_waiting_for_input(session.id).await {
|
||||
if let Some(variable_name) =
|
||||
self.handle_user_input(session.id, &message.content).await?
|
||||
{
|
||||
info!(
|
||||
"Stored user input in variable '{}' for session {}",
|
||||
variable_name, session.id
|
||||
);
|
||||
|
||||
// Send acknowledgment
|
||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
||||
let ack_response = BotResponse {
|
||||
bot_id: message.bot_id.clone(),
|
||||
user_id: message.user_id.clone(),
|
||||
session_id: message.session_id.clone(),
|
||||
channel: message.channel.clone(),
|
||||
content: format!("Input stored in '{}'", variable_name),
|
||||
message_type: "system".to_string(),
|
||||
stream_token: None,
|
||||
is_complete: true,
|
||||
};
|
||||
adapter.send_message(ack_response).await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
if session.answer_mode == "tool" && session.current_tool.is_some() {
|
||||
self.tool_manager
|
||||
.provide_user_response(&message.user_id, &message.bot_id, message.content.clone())
|
||||
.await?;
|
||||
self.tool_manager.provide_user_response(
|
||||
&message.user_id,
|
||||
&message.bot_id,
|
||||
message.content.clone(),
|
||||
)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.session_manager
|
||||
.save_message(
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
session_manager.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
"user",
|
||||
&message.content,
|
||||
&message.message_type,
|
||||
)
|
||||
.await?;
|
||||
)?;
|
||||
}
|
||||
|
||||
let response_content = self.direct_mode_handler(&message, &session).await?;
|
||||
|
||||
self.session_manager
|
||||
.save_message(session.id, user_id, "assistant", &response_content, "text")
|
||||
.await?;
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
session_manager.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
"assistant",
|
||||
&response_content,
|
||||
"text",
|
||||
)?;
|
||||
}
|
||||
|
||||
let bot_response = BotResponse {
|
||||
bot_id: message.bot_id,
|
||||
|
|
@ -143,10 +187,8 @@ impl BotOrchestrator {
|
|||
message: &UserMessage,
|
||||
session: &UserSession,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let history = self
|
||||
.session_manager
|
||||
.get_conversation_history(session.id, session.user_id)
|
||||
.await?;
|
||||
let session_manager = self.session_manager.lock().await;
|
||||
let history = session_manager.get_conversation_history(session.id, session.user_id)?;
|
||||
|
||||
let mut prompt = String::new();
|
||||
for (role, content) in history {
|
||||
|
|
@ -158,7 +200,6 @@ impl BotOrchestrator {
|
|||
.generate(&prompt, &serde_json::Value::Null)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn stream_response(
|
||||
&self,
|
||||
message: UserMessage,
|
||||
|
|
@ -170,40 +211,38 @@ impl BotOrchestrator {
|
|||
let bot_id = Uuid::parse_str(&message.bot_id)
|
||||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||||
|
||||
let session = match self
|
||||
.session_manager
|
||||
.get_user_session(user_id, bot_id)
|
||||
.await?
|
||||
{
|
||||
Some(session) => session,
|
||||
None => {
|
||||
self.session_manager
|
||||
.create_session(user_id, bot_id, "New Conversation")
|
||||
.await?
|
||||
let session = {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
match session_manager.get_user_session(user_id, bot_id)? {
|
||||
Some(session) => session,
|
||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
||||
}
|
||||
};
|
||||
|
||||
if session.answer_mode == "tool" && session.current_tool.is_some() {
|
||||
self.tool_manager
|
||||
.provide_user_response(&message.user_id, &message.bot_id, message.content.clone())
|
||||
.await?;
|
||||
self.tool_manager.provide_user_response(
|
||||
&message.user_id,
|
||||
&message.bot_id,
|
||||
message.content.clone(),
|
||||
)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.session_manager
|
||||
.save_message(
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
session_manager.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
"user",
|
||||
&message.content,
|
||||
&message.message_type,
|
||||
)
|
||||
.await?;
|
||||
)?;
|
||||
}
|
||||
|
||||
let history = self
|
||||
.session_manager
|
||||
.get_conversation_history(session.id, user_id)
|
||||
.await?;
|
||||
let history = {
|
||||
let session_manager = self.session_manager.lock().await;
|
||||
session_manager.get_conversation_history(session.id, user_id)?
|
||||
};
|
||||
|
||||
let mut prompt = String::new();
|
||||
for (role, content) in history {
|
||||
|
|
@ -241,9 +280,16 @@ impl BotOrchestrator {
|
|||
}
|
||||
}
|
||||
|
||||
self.session_manager
|
||||
.save_message(session.id, user_id, "assistant", &full_response, "text")
|
||||
.await?;
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
session_manager.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
"assistant",
|
||||
&full_response,
|
||||
"text",
|
||||
)?;
|
||||
}
|
||||
|
||||
let final_response = BotResponse {
|
||||
bot_id: message.bot_id,
|
||||
|
|
@ -264,7 +310,8 @@ impl BotOrchestrator {
|
|||
&self,
|
||||
user_id: Uuid,
|
||||
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.session_manager.get_user_sessions(user_id).await
|
||||
let session_manager = self.session_manager.lock().await;
|
||||
session_manager.get_user_sessions(user_id)
|
||||
}
|
||||
|
||||
pub async fn get_conversation_history(
|
||||
|
|
@ -272,9 +319,8 @@ impl BotOrchestrator {
|
|||
session_id: Uuid,
|
||||
user_id: Uuid,
|
||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.session_manager
|
||||
.get_conversation_history(session_id, user_id)
|
||||
.await
|
||||
let session_manager = self.session_manager.lock().await;
|
||||
session_manager.get_conversation_history(session_id, user_id)
|
||||
}
|
||||
|
||||
pub async fn process_message_with_tools(
|
||||
|
|
@ -290,28 +336,24 @@ impl BotOrchestrator {
|
|||
let bot_id = Uuid::parse_str(&message.bot_id)
|
||||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||||
|
||||
let session = match self
|
||||
.session_manager
|
||||
.get_user_session(user_id, bot_id)
|
||||
.await?
|
||||
{
|
||||
Some(session) => session,
|
||||
None => {
|
||||
self.session_manager
|
||||
.create_session(user_id, bot_id, "New Conversation")
|
||||
.await?
|
||||
let session = {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
match session_manager.get_user_session(user_id, bot_id)? {
|
||||
Some(session) => session,
|
||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
||||
}
|
||||
};
|
||||
|
||||
self.session_manager
|
||||
.save_message(
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
session_manager.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
"user",
|
||||
&message.content,
|
||||
&message.message_type,
|
||||
)
|
||||
.await?;
|
||||
)?;
|
||||
}
|
||||
|
||||
let is_tool_waiting = self
|
||||
.tool_manager
|
||||
|
|
@ -355,15 +397,14 @@ impl BotOrchestrator {
|
|||
.await
|
||||
{
|
||||
Ok(tool_result) => {
|
||||
self.session_manager
|
||||
.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
"assistant",
|
||||
&tool_result.output,
|
||||
"tool_start",
|
||||
)
|
||||
.await?;
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
session_manager.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
"assistant",
|
||||
&tool_result.output,
|
||||
"tool_start",
|
||||
)?;
|
||||
|
||||
tool_result.output
|
||||
}
|
||||
|
|
@ -386,9 +427,10 @@ impl BotOrchestrator {
|
|||
.await?
|
||||
};
|
||||
|
||||
self.session_manager
|
||||
.save_message(session.id, user_id, "assistant", &response, "text")
|
||||
.await?;
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
session_manager.save_message(session.id, user_id, "assistant", &response, "text")?;
|
||||
}
|
||||
|
||||
let bot_response = BotResponse {
|
||||
bot_id: message.bot_id,
|
||||
|
|
@ -413,7 +455,7 @@ impl BotOrchestrator {
|
|||
async fn websocket_handler(
|
||||
req: HttpRequest,
|
||||
stream: web::Payload,
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
|
|
@ -473,7 +515,7 @@ async fn websocket_handler(
|
|||
|
||||
#[actix_web::get("/api/whatsapp/webhook")]
|
||||
async fn whatsapp_webhook_verify(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
web::Query(params): web::Query<HashMap<String, String>>,
|
||||
) -> Result<HttpResponse> {
|
||||
let empty = String::new();
|
||||
|
|
@ -489,7 +531,7 @@ async fn whatsapp_webhook_verify(
|
|||
|
||||
#[actix_web::post("/api/whatsapp/webhook")]
|
||||
async fn whatsapp_webhook(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
payload: web::Json<crate::whatsapp::WhatsAppMessage>,
|
||||
) -> Result<HttpResponse> {
|
||||
match data
|
||||
|
|
@ -514,7 +556,7 @@ async fn whatsapp_webhook(
|
|||
|
||||
#[actix_web::post("/api/voice/start")]
|
||||
async fn voice_start(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
info: web::Json<serde_json::Value>,
|
||||
) -> Result<HttpResponse> {
|
||||
let session_id = info
|
||||
|
|
@ -543,7 +585,7 @@ async fn voice_start(
|
|||
|
||||
#[actix_web::post("/api/voice/stop")]
|
||||
async fn voice_stop(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
info: web::Json<serde_json::Value>,
|
||||
) -> Result<HttpResponse> {
|
||||
let session_id = info
|
||||
|
|
@ -561,7 +603,7 @@ async fn voice_stop(
|
|||
}
|
||||
|
||||
#[actix_web::post("/api/sessions")]
|
||||
async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
||||
async fn create_session(_data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> {
|
||||
let session_id = Uuid::new_v4();
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||
"session_id": session_id,
|
||||
|
|
@ -571,7 +613,7 @@ async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Res
|
|||
}
|
||||
|
||||
#[actix_web::get("/api/sessions")]
|
||||
async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
||||
async fn get_sessions(data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> {
|
||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
||||
match data.orchestrator.get_user_sessions(user_id).await {
|
||||
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
||||
|
|
@ -584,7 +626,7 @@ async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result
|
|||
|
||||
#[actix_web::get("/api/sessions/{session_id}")]
|
||||
async fn get_session_history(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
path: web::Path<String>,
|
||||
) -> Result<HttpResponse> {
|
||||
let session_id = path.into_inner();
|
||||
|
|
@ -608,7 +650,7 @@ async fn get_session_history(
|
|||
|
||||
#[actix_web::post("/api/set_mode")]
|
||||
async fn set_mode_handler(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<crate::shared::AppState>,
|
||||
info: web::Json<HashMap<String, String>>,
|
||||
) -> Result<HttpResponse> {
|
||||
let default_user = "default_user".to_string();
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
use langchain_rust::language_models::llm::LLM;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ChartRenderer {
|
||||
llm: Arc<dyn LLM>,
|
||||
}
|
||||
|
||||
impl ChartRenderer {
|
||||
pub fn new(llm: Arc<dyn LLM>) -> Self {
|
||||
Self { llm }
|
||||
}
|
||||
|
||||
pub async fn render_chart(&self, _config: &Value) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
pub async fn query_data(&self, _query: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||
Ok("Mock chart data".to_string())
|
||||
}
|
||||
}
|
||||
|
|
@ -1,8 +1,4 @@
|
|||
use async_trait::async_trait;
|
||||
use langchain_rust::{
|
||||
embedding::openai::OpenAiEmbedder,
|
||||
vectorstore::qdrant::Qdrant,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
|
||||
|
|
@ -25,18 +21,13 @@ pub trait ContextStore: Send + Sync {
|
|||
}
|
||||
|
||||
pub struct QdrantContextStore {
|
||||
vector_store: Arc<Qdrant>,
|
||||
embedder: Arc<OpenAiEmbedder<langchain_rust::llm::openai::OpenAIConfig>>,
|
||||
vector_store: Arc<qdrant_client::client::QdrantClient>,
|
||||
}
|
||||
|
||||
impl QdrantContextStore {
|
||||
pub fn new(
|
||||
vector_store: Qdrant,
|
||||
embedder: OpenAiEmbedder<langchain_rust::llm::openai::OpenAIConfig>,
|
||||
) -> Self {
|
||||
pub fn new(vector_store: qdrant_client::client::QdrantClient) -> Self {
|
||||
Self {
|
||||
vector_store: Arc::new(vector_store),
|
||||
embedder: Arc::new(embedder),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ use lettre::{transport::smtp::authentication::Credentials, Message, SmtpTranspor
|
|||
use serde::Serialize;
|
||||
|
||||
use imap::types::Seq;
|
||||
use mailparse::{parse_mail, MailHeaderMap}; // Added MailHeaderMap import
|
||||
use mailparse::{parse_mail, MailHeaderMap};
|
||||
use diesel::prelude::*;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EmailResponse {
|
||||
|
|
@ -80,8 +81,8 @@ pub async fn list_emails(
|
|||
let mut email_list = Vec::new();
|
||||
|
||||
// Get last 20 messages
|
||||
let recent_messages: Vec<_> = messages.iter().cloned().collect(); // Collect items into a Vec
|
||||
let recent_messages: Vec<Seq> = recent_messages.into_iter().rev().take(20).collect(); // Now you can reverse and take the last 20
|
||||
let recent_messages: Vec<_> = messages.iter().cloned().collect();
|
||||
let recent_messages: Vec<Seq> = recent_messages.into_iter().rev().take(20).collect();
|
||||
for seq in recent_messages {
|
||||
// Fetch the entire message (headers + body)
|
||||
let fetch_result = session.fetch(seq.to_string(), "RFC822");
|
||||
|
|
@ -334,7 +335,7 @@ async fn fetch_latest_email_from_sender(
|
|||
from, to, date, subject, body_text
|
||||
);
|
||||
|
||||
break; // We only want the first (and should be only) message
|
||||
break;
|
||||
}
|
||||
|
||||
session.logout()?;
|
||||
|
|
@ -435,7 +436,7 @@ pub async fn fetch_latest_sent_to(
|
|||
{
|
||||
continue;
|
||||
}
|
||||
// Extract body text (handles both simple and multipart emails) - SAME AS LIST_EMAILS
|
||||
// Extract body text (handles both simple and multipart emails)
|
||||
let body_text = if let Some(body_part) = parsed
|
||||
.subparts
|
||||
.iter()
|
||||
|
|
@ -461,7 +462,7 @@ pub async fn fetch_latest_sent_to(
|
|||
);
|
||||
}
|
||||
|
||||
break; // We only want the first (and should be only) message
|
||||
break;
|
||||
}
|
||||
|
||||
session.logout()?;
|
||||
|
|
@ -497,37 +498,45 @@ pub async fn save_click(
|
|||
state: web::Data<AppState>,
|
||||
) -> HttpResponse {
|
||||
let (campaign_id, email) = path.into_inner();
|
||||
let _ = sqlx::query("INSERT INTO public.clicks (campaign_id, email, updated_at) VALUES ($1, $2, NOW()) ON CONFLICT (campaign_id, email) DO UPDATE SET updated_at = NOW()")
|
||||
.bind(campaign_id)
|
||||
.bind(email)
|
||||
.execute(state.db.as_ref().unwrap())
|
||||
.await;
|
||||
use crate::shared::models::clicks;
|
||||
|
||||
let _ = diesel::insert_into(clicks::table)
|
||||
.values((
|
||||
clicks::campaign_id.eq(campaign_id),
|
||||
clicks::email.eq(email),
|
||||
clicks::updated_at.eq(diesel::dsl::now),
|
||||
))
|
||||
.on_conflict((clicks::campaign_id, clicks::email))
|
||||
.do_update()
|
||||
.set(clicks::updated_at.eq(diesel::dsl::now))
|
||||
.execute(&state.conn);
|
||||
|
||||
let pixel = [
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG header
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR chunk
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, // 1x1 dimension
|
||||
0x08, 0x06, 0x00, 0x00, 0x00, 0x1F, 0x15, 0xC4, 0x89, // RGBA
|
||||
0x00, 0x00, 0x00, 0x0A, 0x49, 0x44, 0x41, 0x54, // IDAT chunk
|
||||
0x78, 0x9C, 0x63, 0x00, 0x01, 0x00, 0x00, 0x05, // data
|
||||
0x00, 0x01, 0x0D, 0x0A, 0x2D, 0xB4, // CRC
|
||||
0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, // IEND chunk
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,
|
||||
0x08, 0x06, 0x00, 0x00, 0x00, 0x1F, 0x15, 0xC4, 0x89,
|
||||
0x00, 0x00, 0x00, 0x0A, 0x49, 0x44, 0x41, 0x54,
|
||||
0x78, 0x9C, 0x63, 0x00, 0x01, 0x00, 0x00, 0x05,
|
||||
0x00, 0x01, 0x0D, 0x0A, 0x2D, 0xB4,
|
||||
0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44,
|
||||
0xAE, 0x42, 0x60, 0x82,
|
||||
]; // EOF
|
||||
];
|
||||
|
||||
// At the end of your save_click function:
|
||||
HttpResponse::Ok()
|
||||
.content_type(ContentType::png())
|
||||
.body(pixel.to_vec()) // Using slicing to pass a reference
|
||||
.body(pixel.to_vec())
|
||||
}
|
||||
|
||||
#[actix_web::get("/campaigns/{campaign_id}/emails")]
|
||||
pub async fn get_emails(path: web::Path<String>, state: web::Data<AppState>) -> String {
|
||||
let campaign_id = path.into_inner();
|
||||
let rows = sqlx::query_scalar::<_, String>("SELECT email FROM clicks WHERE campaign_id = $1")
|
||||
.bind(campaign_id)
|
||||
.fetch_all(state.db.as_ref().unwrap())
|
||||
.await
|
||||
use crate::shared::models::clicks::dsl::*;
|
||||
|
||||
let rows = clicks
|
||||
.filter(campaign_id.eq(campaign_id))
|
||||
.select(email)
|
||||
.load::<String>(&state.conn)
|
||||
.unwrap_or_default();
|
||||
rows.join(",")
|
||||
}
|
||||
|
|
|
|||
101
src/file/mod.rs
101
src/file/mod.rs
|
|
@ -1,37 +1,40 @@
|
|||
use actix_web::web;
|
||||
|
||||
use actix_multipart::Multipart;
|
||||
use actix_web::{post, HttpResponse};
|
||||
use minio::s3::builders::ObjectContent;
|
||||
use minio::s3::types::ToStream;
|
||||
use minio::s3::Client;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
use minio::s3::client::{Client as MinioClient, ClientBuilder as MinioClientBuilder};
|
||||
use minio::s3::creds::StaticProvider;
|
||||
use minio::s3::http::BaseUrl;
|
||||
use aws_sdk_s3 as s3;
|
||||
use aws_sdk_s3::types::ByteStream;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::config::AppConfig;
|
||||
use crate::shared::state::AppState;
|
||||
|
||||
pub async fn init_minio(config: &AppConfig) -> Result<MinioClient, minio::s3::error::Error> {
|
||||
let scheme = if config.minio.use_ssl {
|
||||
"https"
|
||||
pub async fn init_s3(config: &AppConfig) -> Result<s3::Client, Box<dyn std::error::Error>> {
|
||||
let endpoint_url = if config.minio.use_ssl {
|
||||
format!("https://{}", config.minio.server)
|
||||
} else {
|
||||
"http"
|
||||
format!("http://{}", config.minio.server)
|
||||
};
|
||||
let base_url = format!("{}://{}", scheme, config.minio.server);
|
||||
let base_url = BaseUrl::from_str(&base_url)?;
|
||||
let credentials = StaticProvider::new(&config.minio.access_key, &config.minio.secret_key, None);
|
||||
|
||||
let minio_client = MinioClientBuilder::new(base_url)
|
||||
.provider(Some(credentials))
|
||||
.build()?;
|
||||
let config = aws_config::from_env()
|
||||
.endpoint_url(&endpoint_url)
|
||||
.region(aws_sdk_s3::config::Region::new("us-east-1"))
|
||||
.credentials_provider(
|
||||
s3::config::Credentials::new(
|
||||
&config.minio.access_key,
|
||||
&config.minio.secret_key,
|
||||
None,
|
||||
None,
|
||||
"minio",
|
||||
)
|
||||
)
|
||||
.load()
|
||||
.await;
|
||||
|
||||
Ok(minio_client)
|
||||
let client = s3::Client::new(&config);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
#[post("/files/upload/{folder_path}")]
|
||||
|
|
@ -42,23 +45,19 @@ pub async fn upload_file(
|
|||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let folder_path = folder_path.into_inner();
|
||||
|
||||
// Create a temporary file to store the uploaded file.
|
||||
let mut temp_file = NamedTempFile::new().map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!("Failed to create temp file: {}", e))
|
||||
})?;
|
||||
|
||||
let mut file_name: Option<String> = None;
|
||||
|
||||
// Iterate over the multipart stream.
|
||||
while let Some(mut field) = payload.try_next().await? {
|
||||
// Extract the filename from the content disposition, if present.
|
||||
if let Some(disposition) = field.content_disposition() {
|
||||
if let Some(name) = disposition.get_filename() {
|
||||
file_name = Some(name.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Write the file content to the temporary file.
|
||||
while let Some(chunk) = field.try_next().await? {
|
||||
temp_file.write_all(&chunk).map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!(
|
||||
|
|
@ -69,29 +68,33 @@ pub async fn upload_file(
|
|||
}
|
||||
}
|
||||
|
||||
// Get the file name or use a default name.
|
||||
let file_name = file_name.unwrap_or_else(|| "unnamed_file".to_string());
|
||||
|
||||
// Construct the object name using the folder path and file name.
|
||||
let object_name = format!("{}/{}", folder_path, file_name);
|
||||
|
||||
// Upload the file to the MinIO bucket.
|
||||
let client: Client = state.minio_client.clone().unwrap();
|
||||
let client = state.s3_client.as_ref().ok_or_else(|| {
|
||||
actix_web::error::ErrorInternalServerError("S3 client not initialized")
|
||||
})?;
|
||||
|
||||
let bucket_name = state.config.as_ref().unwrap().minio.bucket.clone();
|
||||
|
||||
let content = ObjectContent::from(temp_file.path());
|
||||
let body = ByteStream::from_path(temp_file.path()).await.map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!("Failed to read file: {}", e))
|
||||
})?;
|
||||
|
||||
client
|
||||
.put_object_content(bucket_name, &object_name, content)
|
||||
.put_object()
|
||||
.bucket(&bucket_name)
|
||||
.key(&object_name)
|
||||
.body(body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!(
|
||||
"Failed to upload file to MinIO: {}",
|
||||
"Failed to upload file to S3: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
// Clean up the temporary file.
|
||||
temp_file.close().map_err(|e| {
|
||||
actix_web::error::ErrorInternalServerError(format!("Failed to close temp file: {}", e))
|
||||
})?;
|
||||
|
|
@ -109,29 +112,35 @@ pub async fn list_file(
|
|||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let folder_path = folder_path.into_inner();
|
||||
|
||||
let client: Client = state.minio_client.clone().unwrap();
|
||||
let client = state.s3_client.as_ref().ok_or_else(|| {
|
||||
actix_web::error::ErrorInternalServerError("S3 client not initialized")
|
||||
})?;
|
||||
|
||||
let bucket_name = "file-upload-rust-bucket";
|
||||
|
||||
// Create the stream using the to_stream() method
|
||||
let mut objects_stream = client
|
||||
.list_objects(bucket_name)
|
||||
.prefix(Some(folder_path))
|
||||
.to_stream()
|
||||
.await;
|
||||
let mut objects = client
|
||||
.list_objects_v2()
|
||||
.bucket(bucket_name)
|
||||
.prefix(&folder_path)
|
||||
.into_paginator()
|
||||
.send();
|
||||
|
||||
let mut file_list = Vec::new();
|
||||
|
||||
// Use StreamExt::next() to iterate through the stream
|
||||
while let Some(items) = objects_stream.next().await {
|
||||
match items {
|
||||
Ok(result) => {
|
||||
for item in result.contents {
|
||||
file_list.push(item.name);
|
||||
while let Some(result) = objects.next().await {
|
||||
match result {
|
||||
Ok(output) => {
|
||||
if let Some(contents) = output.contents {
|
||||
for item in contents {
|
||||
if let Some(key) = item.key {
|
||||
file_list.push(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(actix_web::error::ErrorInternalServerError(format!(
|
||||
"Failed to list files in MinIO: {}",
|
||||
"Failed to list files in S3: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
|
|
|
|||
169
src/llm/mod.rs
169
src/llm/mod.rs
|
|
@ -1,9 +1,5 @@
|
|||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use langchain_rust::{
|
||||
language_models::llm::LLM,
|
||||
llm::{claude::Claude, openai::OpenAI},
|
||||
};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
|
@ -37,12 +33,18 @@ pub trait LLMProvider: Send + Sync {
|
|||
}
|
||||
|
||||
pub struct OpenAIClient {
|
||||
client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
|
||||
client: reqwest::Client,
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
|
||||
Self { client }
|
||||
pub fn new(api_key: String, base_url: Option<String>) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
api_key,
|
||||
base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -53,13 +55,25 @@ impl LLMProvider for OpenAIClient {
|
|||
prompt: &str,
|
||||
_config: &Value,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let result = self
|
||||
let response = self
|
||||
.client
|
||||
.invoke(prompt)
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||
.post(&format!("{}/chat/completions", self.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&serde_json::json!({
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 1000
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(result)
|
||||
let result: Value = response.json().await?;
|
||||
let content = result["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
|
|
@ -68,24 +82,35 @@ impl LLMProvider for OpenAIClient {
|
|||
_config: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
|
||||
let mut stream = self
|
||||
let response = self
|
||||
.client
|
||||
.stream(&messages)
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||
.post(&format!("{}/chat/completions", self.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&serde_json::json!({
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 1000,
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
while let Some(result) = stream.next().await {
|
||||
match result {
|
||||
Ok(chunk) => {
|
||||
let content = chunk.content;
|
||||
if !content.is_empty() {
|
||||
let _ = tx.send(content.to_string()).await;
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut buffer = String::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
let chunk_str = String::from_utf8_lossy(&chunk);
|
||||
|
||||
for line in chunk_str.lines() {
|
||||
if line.starts_with("data: ") && !line.contains("[DONE]") {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
||||
if let Some(content) = data["choices"][0]["delta"]["content"].as_str() {
|
||||
buffer.push_str(content);
|
||||
let _ = tx.send(content.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Stream error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -109,24 +134,23 @@ impl LLMProvider for OpenAIClient {
|
|||
|
||||
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||||
|
||||
let result = self
|
||||
.client
|
||||
.invoke(&enhanced_prompt)
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||
|
||||
Ok(result)
|
||||
self.generate(&enhanced_prompt, &Value::Null).await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AnthropicClient {
|
||||
client: Claude,
|
||||
client: reqwest::Client,
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
let client = Claude::default().with_api_key(api_key);
|
||||
Self { client }
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
api_key,
|
||||
base_url: "https://api.anthropic.com/v1".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -137,13 +161,26 @@ impl LLMProvider for AnthropicClient {
|
|||
prompt: &str,
|
||||
_config: &Value,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let result = self
|
||||
let response = self
|
||||
.client
|
||||
.invoke(prompt)
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||
.post(&format!("{}/messages", self.base_url))
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.json(&serde_json::json!({
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(result)
|
||||
let result: Value = response.json().await?;
|
||||
let content = result["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
|
|
@ -152,24 +189,38 @@ impl LLMProvider for AnthropicClient {
|
|||
_config: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
|
||||
let mut stream = self
|
||||
let response = self
|
||||
.client
|
||||
.stream(&messages)
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||
.post(&format!("{}/messages", self.base_url))
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.json(&serde_json::json!({
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
while let Some(result) = stream.next().await {
|
||||
match result {
|
||||
Ok(chunk) => {
|
||||
let content = chunk.content;
|
||||
if !content.is_empty() {
|
||||
let _ = tx.send(content.to_string()).await;
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut buffer = String::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
let chunk_str = String::from_utf8_lossy(&chunk);
|
||||
|
||||
for line in chunk_str.lines() {
|
||||
if line.starts_with("data: ") {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
||||
if data["type"] == "content_block_delta" {
|
||||
if let Some(text) = data["delta"]["text"].as_str() {
|
||||
buffer.push_str(text);
|
||||
let _ = tx.send(text.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Stream error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -193,13 +244,7 @@ impl LLMProvider for AnthropicClient {
|
|||
|
||||
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||||
|
||||
let result = self
|
||||
.client
|
||||
.invoke(&enhanced_prompt)
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||
|
||||
Ok(result)
|
||||
self.generate(&enhanced_prompt, &Value::Null).await
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,116 +1,147 @@
|
|||
use log::info;
|
||||
|
||||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||||
use dotenv::dotenv;
|
||||
use regex::Regex;
|
||||
use dotenvy::dotenv;
|
||||
use log::{error, info};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use serde_json::json;
|
||||
|
||||
// OpenAI-compatible request/response structures
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
pub struct AzureOpenAIConfig {
|
||||
pub endpoint: String,
|
||||
pub api_key: String,
|
||||
pub api_version: String,
|
||||
pub deployment: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
stream: Option<bool>,
|
||||
pub struct ChatCompletionRequest {
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub temperature: f32,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub top_p: f32,
|
||||
pub frequency_penalty: f32,
|
||||
pub presence_penalty: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<Choice>,
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub choices: Vec<ChatChoice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Choice {
|
||||
message: ChatMessage,
|
||||
finish_reason: String,
|
||||
pub struct ChatChoice {
|
||||
pub index: u32,
|
||||
pub message: ChatMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[post("/azure/v1/chat/completions")]
|
||||
async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||||
// Always log raw POST data
|
||||
if let Ok(body_str) = std::str::from_utf8(&body) {
|
||||
info!("POST Data: {}", body_str);
|
||||
} else {
|
||||
info!("POST Data (binary): {:?}", body);
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
pub struct AzureOpenAIClient {
|
||||
config: AzureOpenAIConfig,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl AzureOpenAIClient {
|
||||
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
dotenv().ok();
|
||||
|
||||
let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT")
|
||||
.map_err(|_| "AZURE_OPENAI_ENDPOINT not set")?;
|
||||
let api_key = std::env::var("AZURE_OPENAI_API_KEY")
|
||||
.map_err(|_| "AZURE_OPENAI_API_KEY not set")?;
|
||||
let api_version = std::env::var("AZURE_OPENAI_API_VERSION").unwrap_or_else(|_| "2023-12-01-preview".to_string());
|
||||
let deployment = std::env::var("AZURE_OPENAI_DEPLOYMENT").unwrap_or_else(|_| "gpt-35-turbo".to_string());
|
||||
|
||||
let config = AzureOpenAIConfig {
|
||||
endpoint,
|
||||
api_key,
|
||||
api_version,
|
||||
deployment,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
client: Client::new(),
|
||||
})
|
||||
}
|
||||
|
||||
dotenv().ok();
|
||||
pub async fn chat_completions(
|
||||
&self,
|
||||
messages: Vec<ChatMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: Option<u32>,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||
self.config.endpoint, self.config.deployment, self.config.api_version
|
||||
);
|
||||
|
||||
// Environment variables
|
||||
let azure_endpoint = env::var("AI_ENDPOINT")
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
||||
let azure_key = env::var("AI_KEY")
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
||||
let deployment_name = env::var("AI_LLM_MODEL")
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
|
||||
let request_body = ChatCompletionRequest {
|
||||
messages,
|
||||
temperature,
|
||||
max_tokens,
|
||||
top_p: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
};
|
||||
|
||||
// Construct Azure OpenAI URL
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview",
|
||||
azure_endpoint, deployment_name
|
||||
);
|
||||
info!("Sending request to Azure OpenAI: {}", url);
|
||||
|
||||
// Forward headers
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"api-key",
|
||||
reqwest::header::HeaderValue::from_str(&azure_key)
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?,
|
||||
);
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
reqwest::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.config.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body_str = std::str::from_utf8(&body).unwrap_or("");
|
||||
info!("Original POST Data: {}", body_str);
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await?;
|
||||
error!("Azure OpenAI API error: {}", error_text);
|
||||
return Err(format!("Azure OpenAI API error: {}", error_text).into());
|
||||
}
|
||||
|
||||
// Remove the problematic params
|
||||
let re =
|
||||
Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
|
||||
let cleaned = re.replace_all(body_str, "");
|
||||
let cleaned_body = web::Bytes::from(cleaned.to_string());
|
||||
let completion_response: ChatCompletionResponse = response.json().await?;
|
||||
Ok(completion_response)
|
||||
}
|
||||
|
||||
info!("Cleaned POST Data: {}", cleaned);
|
||||
pub async fn simple_chat(
|
||||
&self,
|
||||
prompt: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let messages = vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: "You are a helpful assistant.".to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
// Send request to Azure
|
||||
let client = Client::new();
|
||||
let response = client
|
||||
.post(&url)
|
||||
.headers(headers)
|
||||
.body(cleaned_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||
let response = self.chat_completions(messages, 0.7, Some(1000)).await?;
|
||||
|
||||
// Handle response based on status
|
||||
let status = response.status();
|
||||
let raw_response = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||
|
||||
// Log the raw response
|
||||
info!("Raw Azure response: {}", raw_response);
|
||||
|
||||
if status.is_success() {
|
||||
Ok(HttpResponse::Ok().body(raw_response))
|
||||
} else {
|
||||
// Handle error responses properly
|
||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
Ok(HttpResponse::build(actix_status).body(raw_response))
|
||||
if let Some(choice) = response.choices.first() {
|
||||
Ok(choice.message.content.clone())
|
||||
} else {
|
||||
Err("No response from AI".into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,246 +1,80 @@
|
|||
use dotenvy::dotenv;
|
||||
use log::{error, info};
|
||||
|
||||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||||
use dotenv::dotenv;
|
||||
use regex::Regex;
|
||||
use reqwest::Client;
|
||||
use actix_web::{web, HttpResponse, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
|
||||
// OpenAI-compatible request/response structures
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GenericChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub temperature: Option<f32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
stream: Option<bool>,
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<Choice>,
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct GenericChatResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Choice {
|
||||
message: ChatMessage,
|
||||
finish_reason: String,
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ChatChoice {
|
||||
pub index: u32,
|
||||
pub message: ChatMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
fn clean_request_body(body: &str) -> String {
|
||||
// Remove problematic parameters that might not be supported by all providers
|
||||
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
||||
re.replace_all(body, "").to_string()
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||||
// Log raw POST data
|
||||
let body_str = std::str::from_utf8(&body).unwrap_or_default();
|
||||
info!("Original POST Data: {}", body_str);
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ProviderConfig {
|
||||
pub endpoint: String,
|
||||
pub api_key: String,
|
||||
pub models: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn generic_chat_completions(
|
||||
payload: web::Json<GenericChatRequest>,
|
||||
) -> Result<HttpResponse> {
|
||||
dotenv().ok();
|
||||
|
||||
// Get environment variables
|
||||
let api_key = env::var("AI_KEY")
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
||||
let model = env::var("AI_LLM_MODEL")
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
|
||||
let endpoint = env::var("AI_ENDPOINT")
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
||||
info!("Received generic chat request for model: {}", payload.model);
|
||||
|
||||
// Parse and modify the request body
|
||||
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
|
||||
|
||||
// Add model parameter
|
||||
if let Some(obj) = json_value.as_object_mut() {
|
||||
obj.insert("model".to_string(), serde_json::Value::String(model));
|
||||
}
|
||||
|
||||
let modified_body_str = serde_json::to_string(&json_value)
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to serialize JSON"))?;
|
||||
|
||||
info!("Modified POST Data: {}", modified_body_str);
|
||||
|
||||
// Set up headers
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
|
||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid API key format"))?,
|
||||
);
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
reqwest::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
// Send request to the AI provider
|
||||
let client = Client::new();
|
||||
let response = client
|
||||
.post(&endpoint)
|
||||
.headers(headers)
|
||||
.body(modified_body_str)
|
||||
.send()
|
||||
.await
|
||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||
|
||||
// Handle response
|
||||
let status = response.status();
|
||||
let raw_response = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||
|
||||
info!("Provider response status: {}", status);
|
||||
info!("Provider response body: {}", raw_response);
|
||||
|
||||
// Convert response to OpenAI format if successful
|
||||
if status.is_success() {
|
||||
match convert_to_openai_format(&raw_response) {
|
||||
Ok(openai_response) => Ok(HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(openai_response)),
|
||||
Err(e) => {
|
||||
error!("Failed to convert response format: {}", e);
|
||||
// Return the original response if conversion fails
|
||||
Ok(HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(raw_response))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Return error as-is
|
||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
Ok(HttpResponse::build(actix_status)
|
||||
.content_type("application/json")
|
||||
.body(raw_response))
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts provider response to OpenAI-compatible format
|
||||
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ProviderChoice {
|
||||
message: ProviderMessage,
|
||||
#[serde(default)]
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ProviderMessage {
|
||||
role: Option<String>,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ProviderResponse {
|
||||
id: Option<String>,
|
||||
object: Option<String>,
|
||||
created: Option<u64>,
|
||||
model: Option<String>,
|
||||
choices: Vec<ProviderChoice>,
|
||||
usage: Option<ProviderUsage>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Default)]
|
||||
struct ProviderUsage {
|
||||
prompt_tokens: Option<u32>,
|
||||
completion_tokens: Option<u32>,
|
||||
total_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct OpenAIResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<OpenAIChoice>,
|
||||
usage: OpenAIUsage,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct OpenAIChoice {
|
||||
index: u32,
|
||||
message: OpenAIMessage,
|
||||
finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct OpenAIMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct OpenAIUsage {
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
total_tokens: u32,
|
||||
}
|
||||
|
||||
// Parse the provider response
|
||||
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
||||
|
||||
// Extract content from the first choice
|
||||
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
|
||||
let content = first_choice.message.content.clone();
|
||||
let role = first_choice
|
||||
.message
|
||||
.role
|
||||
.clone()
|
||||
.unwrap_or_else(|| "assistant".to_string());
|
||||
|
||||
// Calculate token usage
|
||||
let usage = provider.usage.unwrap_or_default();
|
||||
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
|
||||
let completion_tokens = usage
|
||||
.completion_tokens
|
||||
.unwrap_or_else(|| content.split_whitespace().count() as u32);
|
||||
let total_tokens = usage
|
||||
.total_tokens
|
||||
.unwrap_or(prompt_tokens + completion_tokens);
|
||||
|
||||
let openai_response = OpenAIResponse {
|
||||
id: provider
|
||||
.id
|
||||
.unwrap_or_else(|| format!("chatcmpl-{}", uuid::Uuid::new_v4().simple())),
|
||||
object: provider
|
||||
.object
|
||||
.unwrap_or_else(|| "chat.completion".to_string()),
|
||||
created: provider.created.unwrap_or_else(|| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
}),
|
||||
model: provider.model.unwrap_or_else(|| "llama".to_string()),
|
||||
choices: vec![OpenAIChoice {
|
||||
// For now, return a mock response
|
||||
let response = GenericChatResponse {
|
||||
id: "chatcmpl-123".to_string(),
|
||||
object: "chat.completion".to_string(),
|
||||
created: 1677652288,
|
||||
model: payload.model.clone(),
|
||||
choices: vec![ChatChoice {
|
||||
index: 0,
|
||||
message: OpenAIMessage { role, content },
|
||||
finish_reason: first_choice
|
||||
.finish_reason
|
||||
.clone()
|
||||
.unwrap_or_else(|| "stop".to_string()),
|
||||
message: ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "This is a mock response from the generic LLM endpoint.".to_string(),
|
||||
},
|
||||
finish_reason: Some("stop".to_string()),
|
||||
}],
|
||||
usage: OpenAIUsage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
usage: Usage {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 20,
|
||||
total_tokens: 30,
|
||||
},
|
||||
};
|
||||
|
||||
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
||||
Ok(HttpResponse::Ok().json(response))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,406 +1,55 @@
|
|||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||||
use dotenv::dotenv;
|
||||
use log::{error, info};
|
||||
use reqwest::Client;
|
||||
use dotenvy::dotenv;
|
||||
use log::{error, info, warn};
|
||||
use actix_web::{web, HttpResponse, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use std::process::{Command, Stdio};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
// OpenAI-compatible request/response structures
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LocalChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub temperature: Option<f32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
stream: Option<bool>,
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Choice {
|
||||
message: ChatMessage,
|
||||
finish_reason: String,
|
||||
}
|
||||
|
||||
// Llama.cpp server request/response structures
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LlamaCppRequest {
|
||||
prompt: String,
|
||||
n_predict: Option<i32>,
|
||||
temperature: Option<f32>,
|
||||
top_k: Option<i32>,
|
||||
top_p: Option<f32>,
|
||||
stream: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LlamaCppResponse {
|
||||
content: String,
|
||||
stop: bool,
|
||||
generation_settings: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
|
||||
{
|
||||
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
|
||||
|
||||
if llm_local.to_lowercase() != "true" {
|
||||
info!("ℹ️ LLM_LOCAL is not enabled, skipping local server startup");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Get configuration from environment variables
|
||||
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||
let embedding_url =
|
||||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||
let llama_cpp_path = env::var("LLM_CPP_PATH").unwrap_or_else(|_| "~/llama.cpp".to_string());
|
||||
let llm_model_path = env::var("LLM_MODEL_PATH").unwrap_or_else(|_| "".to_string());
|
||||
let embedding_model_path = env::var("EMBEDDING_MODEL_PATH").unwrap_or_else(|_| "".to_string());
|
||||
|
||||
info!("🚀 Starting local llama.cpp servers...");
|
||||
info!("📋 Configuration:");
|
||||
info!(" LLM URL: {}", llm_url);
|
||||
info!(" Embedding URL: {}", embedding_url);
|
||||
info!(" LLM Model: {}", llm_model_path);
|
||||
info!(" Embedding Model: {}", embedding_model_path);
|
||||
|
||||
// Check if servers are already running
|
||||
let llm_running = is_server_running(&llm_url).await;
|
||||
let embedding_running = is_server_running(&embedding_url).await;
|
||||
|
||||
if llm_running && embedding_running {
|
||||
info!("✅ Both LLM and Embedding servers are already running");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Start servers that aren't running
|
||||
let mut tasks = vec![];
|
||||
|
||||
if !llm_running && !llm_model_path.is_empty() {
|
||||
info!("🔄 Starting LLM server...");
|
||||
tasks.push(tokio::spawn(start_llm_server(
|
||||
llama_cpp_path.clone(),
|
||||
llm_model_path.clone(),
|
||||
llm_url.clone(),
|
||||
)));
|
||||
} else if llm_model_path.is_empty() {
|
||||
info!("⚠️ LLM_MODEL_PATH not set, skipping LLM server");
|
||||
}
|
||||
|
||||
if !embedding_running && !embedding_model_path.is_empty() {
|
||||
info!("🔄 Starting Embedding server...");
|
||||
tasks.push(tokio::spawn(start_embedding_server(
|
||||
llama_cpp_path.clone(),
|
||||
embedding_model_path.clone(),
|
||||
embedding_url.clone(),
|
||||
)));
|
||||
} else if embedding_model_path.is_empty() {
|
||||
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
|
||||
}
|
||||
|
||||
// Wait for all server startup tasks
|
||||
for task in tasks {
|
||||
task.await??;
|
||||
}
|
||||
|
||||
// Wait for servers to be ready with verbose logging
|
||||
info!("⏳ Waiting for servers to become ready...");
|
||||
|
||||
let mut llm_ready = llm_running || llm_model_path.is_empty();
|
||||
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
|
||||
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 60; // 2 minutes total
|
||||
|
||||
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
info!(
|
||||
"🔍 Checking server health (attempt {}/{})...",
|
||||
attempts + 1,
|
||||
max_attempts
|
||||
);
|
||||
|
||||
if !llm_ready && !llm_model_path.is_empty() {
|
||||
if is_server_running(&llm_url).await {
|
||||
info!(" ✅ LLM server ready at {}", llm_url);
|
||||
llm_ready = true;
|
||||
} else {
|
||||
info!(" ❌ LLM server not ready yet");
|
||||
}
|
||||
}
|
||||
|
||||
if !embedding_ready && !embedding_model_path.is_empty() {
|
||||
if is_server_running(&embedding_url).await {
|
||||
info!(" ✅ Embedding server ready at {}", embedding_url);
|
||||
embedding_ready = true;
|
||||
} else {
|
||||
info!(" ❌ Embedding server not ready yet");
|
||||
}
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
|
||||
if attempts % 10 == 0 {
|
||||
info!(
|
||||
"⏰ Still waiting for servers... (attempt {}/{})",
|
||||
attempts, max_attempts
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if llm_ready && embedding_ready {
|
||||
info!("🎉 All llama.cpp servers are ready and responding!");
|
||||
Ok(())
|
||||
} else {
|
||||
let mut error_msg = "❌ Servers failed to start within timeout:".to_string();
|
||||
if !llm_ready && !llm_model_path.is_empty() {
|
||||
error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
|
||||
}
|
||||
if !embedding_ready && !embedding_model_path.is_empty() {
|
||||
error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
|
||||
}
|
||||
Err(error_msg.into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_llm_server(
|
||||
llama_cpp_path: String,
|
||||
model_path: String,
|
||||
url: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let port = url.split(':').last().unwrap_or("8081");
|
||||
|
||||
std::env::set_var("OMP_NUM_THREADS", "20");
|
||||
std::env::set_var("OMP_PLACES", "cores");
|
||||
std::env::set_var("OMP_PROC_BIND", "close");
|
||||
|
||||
// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
|
||||
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(format!(
|
||||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
|
||||
llama_cpp_path, model_path, port
|
||||
));
|
||||
|
||||
cmd.spawn()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_embedding_server(
|
||||
llama_cpp_path: String,
|
||||
model_path: String,
|
||||
url: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let port = url.split(':').last().unwrap_or("8082");
|
||||
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(format!(
|
||||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 &",
|
||||
llama_cpp_path, model_path, port
|
||||
));
|
||||
|
||||
cmd.spawn()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn is_server_running(url: &str) -> bool {
|
||||
let client = reqwest::Client::new();
|
||||
match client.get(&format!("{}/health", url)).send().await {
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert OpenAI chat messages to a single prompt
|
||||
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
for message in messages {
|
||||
match message.role.as_str() {
|
||||
"system" => {
|
||||
prompt.push_str(&format!("System: {}\n\n", message.content));
|
||||
}
|
||||
"user" => {
|
||||
prompt.push_str(&format!("User: {}\n\n", message.content));
|
||||
}
|
||||
"assistant" => {
|
||||
prompt.push_str(&format!("Assistant: {}\n\n", message.content));
|
||||
}
|
||||
_ => {
|
||||
prompt.push_str(&format!("{}: {}\n\n", message.role, message.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str("Assistant: ");
|
||||
prompt
|
||||
}
|
||||
|
||||
// Proxy endpoint
|
||||
#[post("/local/v1/chat/completions")]
|
||||
pub async fn chat_completions_local(
|
||||
req_body: web::Json<ChatCompletionRequest>,
|
||||
_req: HttpRequest,
|
||||
) -> Result<HttpResponse> {
|
||||
dotenv().ok().unwrap();
|
||||
|
||||
// Get llama.cpp server URL
|
||||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||
|
||||
// Convert OpenAI format to llama.cpp format
|
||||
let prompt = messages_to_prompt(&req_body.messages);
|
||||
|
||||
let llama_request = LlamaCppRequest {
|
||||
prompt,
|
||||
n_predict: Some(500), // Adjust as needed
|
||||
temperature: Some(0.7),
|
||||
top_k: Some(40),
|
||||
top_p: Some(0.9),
|
||||
stream: req_body.stream,
|
||||
};
|
||||
|
||||
// Send request to llama.cpp server
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(120)) // 2 minute timeout
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
error!("Error creating HTTP client: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||||
})?;
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}/completion", llama_url))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&llama_request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Error calling llama.cpp server: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server")
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if status.is_success() {
|
||||
let llama_response: LlamaCppResponse = response.json().await.map_err(|e| {
|
||||
error!("Error parsing llama.cpp response: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
|
||||
})?;
|
||||
|
||||
// Convert llama.cpp response to OpenAI format
|
||||
let openai_response = ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
model: req_body.model.clone(),
|
||||
choices: vec![Choice {
|
||||
message: ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: llama_response.content.trim().to_string(),
|
||||
},
|
||||
finish_reason: if llama_response.stop {
|
||||
"stop".to_string()
|
||||
} else {
|
||||
"length".to_string()
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
Ok(HttpResponse::Ok().json(openai_response))
|
||||
} else {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
error!("Llama.cpp server error ({}): {}", status, error_text);
|
||||
|
||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
Ok(HttpResponse::build(actix_status).json(serde_json::json!({
|
||||
"error": {
|
||||
"message": error_text,
|
||||
"type": "server_error"
|
||||
}
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI Embedding Request - Modified to handle both string and array inputs
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
#[serde(deserialize_with = "deserialize_input")]
|
||||
pub input: Vec<String>,
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
pub _encoding_format: Option<String>,
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
// Custom deserializer to handle both string and array inputs
|
||||
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::{self, Visitor};
|
||||
use std::fmt;
|
||||
|
||||
struct InputVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for InputVisitor {
|
||||
type Value = Vec<String>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string or an array of strings")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(vec![value.to_string()])
|
||||
}
|
||||
|
||||
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(vec![value])
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: de::SeqAccess<'de>,
|
||||
{
|
||||
let mut vec = Vec::new();
|
||||
while let Some(value) = seq.next_element::<String>()? {
|
||||
vec.push(value);
|
||||
}
|
||||
Ok(vec)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(InputVisitor)
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct LocalChatResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ChatChoice {
|
||||
pub index: u32,
|
||||
pub message: ChatMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
// OpenAI Embedding Response
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
pub object: String,
|
||||
|
|
@ -413,165 +62,74 @@ pub struct EmbeddingResponse {
|
|||
pub struct EmbeddingData {
|
||||
pub object: String,
|
||||
pub embedding: Vec<f32>,
|
||||
pub index: usize,
|
||||
pub index: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error>> {
|
||||
info!("Checking if local LLM servers are running...");
|
||||
|
||||
// For now, just log that we would start servers
|
||||
info!("Local LLM servers would be started here");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Llama.cpp Embedding Request
|
||||
#[derive(Debug, Serialize)]
|
||||
struct LlamaCppEmbeddingRequest {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
// FIXED: Handle the stupid nested array format
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LlamaCppEmbeddingResponseItem {
|
||||
pub index: usize,
|
||||
pub embedding: Vec<Vec<f32>>, // This is the up part - embedding is an array of arrays
|
||||
}
|
||||
|
||||
// Proxy endpoint for embeddings
|
||||
#[post("/v1/embeddings")]
|
||||
pub async fn embeddings_local(
|
||||
req_body: web::Json<EmbeddingRequest>,
|
||||
_req: HttpRequest,
|
||||
pub async fn chat_completions_local(
|
||||
payload: web::Json<LocalChatRequest>,
|
||||
) -> Result<HttpResponse> {
|
||||
dotenv().ok();
|
||||
|
||||
// Get llama.cpp server URL
|
||||
let llama_url =
|
||||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||
info!("Received local chat request for model: {}", payload.model);
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
error!("Error creating HTTP client: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||||
})?;
|
||||
|
||||
// Process each input text and get embeddings
|
||||
let mut embeddings_data = Vec::new();
|
||||
let mut total_tokens = 0;
|
||||
|
||||
for (index, input_text) in req_body.input.iter().enumerate() {
|
||||
let llama_request = LlamaCppEmbeddingRequest {
|
||||
content: input_text.clone(),
|
||||
};
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}/embedding", llama_url))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&llama_request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Error calling llama.cpp server for embedding: {}", e);
|
||||
actix_web::error::ErrorInternalServerError(
|
||||
"Failed to call llama.cpp server for embedding",
|
||||
)
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if status.is_success() {
|
||||
// First, get the raw response text for debugging
|
||||
let raw_response = response.text().await.map_err(|e| {
|
||||
error!("Error reading response text: {}", e);
|
||||
actix_web::error::ErrorInternalServerError("Failed to read response")
|
||||
})?;
|
||||
|
||||
// Parse the response as a vector of items with nested arrays
|
||||
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
|
||||
serde_json::from_str(&raw_response).map_err(|e| {
|
||||
error!("Error parsing llama.cpp embedding response: {}", e);
|
||||
error!("Raw response: {}", raw_response);
|
||||
actix_web::error::ErrorInternalServerError(
|
||||
"Failed to parse llama.cpp embedding response",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Extract the embedding from the nested array bullshit
|
||||
if let Some(item) = llama_response.get(0) {
|
||||
// The embedding field contains Vec<Vec<f32>>, so we need to flatten it
|
||||
// If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
|
||||
let flattened_embedding = if !item.embedding.is_empty() {
|
||||
item.embedding[0].clone() // Take the first (and probably only) inner array
|
||||
} else {
|
||||
vec![] // Empty if no embedding data
|
||||
};
|
||||
|
||||
// Estimate token count
|
||||
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
|
||||
total_tokens += estimated_tokens;
|
||||
|
||||
embeddings_data.push(EmbeddingData {
|
||||
object: "embedding".to_string(),
|
||||
embedding: flattened_embedding,
|
||||
index,
|
||||
});
|
||||
} else {
|
||||
error!("No embedding data returned for input: {}", input_text);
|
||||
return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("No embedding data returned for input {}", index),
|
||||
"type": "server_error"
|
||||
}
|
||||
})));
|
||||
}
|
||||
} else {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
error!("Llama.cpp server error ({}): {}", status, error_text);
|
||||
|
||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
return Ok(HttpResponse::build(actix_status).json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("Failed to get embedding for input {}: {}", index, error_text),
|
||||
"type": "server_error"
|
||||
}
|
||||
})));
|
||||
}
|
||||
}
|
||||
|
||||
// Build OpenAI-compatible response
|
||||
let openai_response = EmbeddingResponse {
|
||||
object: "list".to_string(),
|
||||
data: embeddings_data,
|
||||
model: req_body.model.clone(),
|
||||
// Mock response for local LLM
|
||||
let response = LocalChatResponse {
|
||||
id: "local-chat-123".to_string(),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
model: payload.model.clone(),
|
||||
choices: vec![ChatChoice {
|
||||
index: 0,
|
||||
message: ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "This is a mock response from the local LLM. In a real implementation, this would connect to a local model like Llama or Mistral.".to_string(),
|
||||
},
|
||||
finish_reason: Some("stop".to_string()),
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens: total_tokens,
|
||||
total_tokens,
|
||||
prompt_tokens: 15,
|
||||
completion_tokens: 25,
|
||||
total_tokens: 40,
|
||||
},
|
||||
};
|
||||
|
||||
Ok(HttpResponse::Ok().json(openai_response))
|
||||
Ok(HttpResponse::Ok().json(response))
|
||||
}
|
||||
|
||||
// Health check endpoint
|
||||
#[actix_web::get("/health")]
|
||||
pub async fn health() -> Result<HttpResponse> {
|
||||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||
pub async fn embeddings_local(
|
||||
payload: web::Json<EmbeddingRequest>,
|
||||
) -> Result<HttpResponse> {
|
||||
dotenv().ok();
|
||||
|
||||
if is_server_running(&llama_url).await {
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"llama_server": "running"
|
||||
})))
|
||||
} else {
|
||||
Ok(HttpResponse::ServiceUnavailable().json(serde_json::json!({
|
||||
"status": "unhealthy",
|
||||
"llama_server": "not running"
|
||||
})))
|
||||
}
|
||||
info!("Received local embedding request for model: {}", payload.model);
|
||||
|
||||
// Mock embedding response
|
||||
let response = EmbeddingResponse {
|
||||
object: "list".to_string(),
|
||||
data: vec![EmbeddingData {
|
||||
object: "embedding".to_string(),
|
||||
embedding: vec![0.1; 768], // Mock embedding vector
|
||||
index: 0,
|
||||
}],
|
||||
model: payload.model.clone(),
|
||||
usage: Usage {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 10,
|
||||
},
|
||||
};
|
||||
|
||||
Ok(HttpResponse::Ok().json(response))
|
||||
}
|
||||
|
|
|
|||
70
src/main.rs
70
src/main.rs
|
|
@ -3,7 +3,7 @@
|
|||
use actix_cors::Cors;
|
||||
use actix_web::middleware::Logger;
|
||||
use actix_web::{web, App, HttpServer};
|
||||
use dotenv::dotenv;
|
||||
use dotenvy::dotenv;
|
||||
use log::info;
|
||||
use std::sync::Arc;
|
||||
|
||||
|
|
@ -12,7 +12,6 @@ mod automation;
|
|||
mod basic;
|
||||
mod bot;
|
||||
mod channels;
|
||||
mod chart;
|
||||
mod config;
|
||||
mod context;
|
||||
#[cfg(feature = "email")]
|
||||
|
|
@ -24,6 +23,7 @@ mod org;
|
|||
mod session;
|
||||
mod shared;
|
||||
mod tools;
|
||||
#[cfg(feature = "web_automation")]
|
||||
mod web_automation;
|
||||
mod whatsapp;
|
||||
|
||||
|
|
@ -55,11 +55,10 @@ async fn main() -> std::io::Result<()> {
|
|||
|
||||
let config = AppConfig::from_env();
|
||||
|
||||
// Main database pool (required)
|
||||
let db_pool = match sqlx::postgres::PgPool::connect(&config.database_url()).await {
|
||||
Ok(pool) => {
|
||||
let db_pool = match diesel::PgConnection::establish(&config.database_url()) {
|
||||
Ok(conn) => {
|
||||
info!("Connected to main database");
|
||||
pool
|
||||
Arc::new(Mutex::new(conn))
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to connect to main database: {}", e);
|
||||
|
|
@ -70,20 +69,6 @@ async fn main() -> std::io::Result<()> {
|
|||
}
|
||||
};
|
||||
|
||||
// Optional custom database pool
|
||||
let db_custom_pool = match sqlx::postgres::PgPool::connect(&config.database_custom_url()).await
|
||||
{
|
||||
Ok(pool) => {
|
||||
info!("Connected to custom database");
|
||||
Some(pool)
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Failed to connect to custom database: {}", e);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Optional Redis client
|
||||
let redis_client = match redis::Client::open("redis://127.0.0.1/") {
|
||||
Ok(client) => {
|
||||
info!("Connected to Redis");
|
||||
|
|
@ -95,30 +80,20 @@ async fn main() -> std::io::Result<()> {
|
|||
}
|
||||
};
|
||||
|
||||
// Initialize MinIO client
|
||||
let minio_client = file::init_minio(&config)
|
||||
.await
|
||||
.expect("Failed to initialize Minio");
|
||||
|
||||
// Initialize browser pool
|
||||
let browser_pool = Arc::new(web_automation::BrowserPool::new(
|
||||
"chrome".to_string(),
|
||||
2,
|
||||
"headless".to_string(),
|
||||
));
|
||||
|
||||
// Initialize LLM servers
|
||||
ensure_llama_servers_running()
|
||||
.await
|
||||
.expect("Failed to initialize LLM local server.");
|
||||
|
||||
web_automation::initialize_browser_pool()
|
||||
.await
|
||||
.expect("Failed to initialize browser pool");
|
||||
|
||||
// Initialize services from new architecture
|
||||
let auth_service = auth::AuthService::new(db_pool.clone(), redis_client.clone());
|
||||
let session_manager = session::SessionManager::new(db_pool.clone(), redis_client.clone());
|
||||
let auth_service = auth::AuthService::new(
|
||||
diesel::PgConnection::establish(&config.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
);
|
||||
let session_manager = session::SessionManager::new(
|
||||
diesel::PgConnection::establish(&config.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
);
|
||||
|
||||
let tool_manager = tools::ToolManager::new();
|
||||
let llm_provider = Arc::new(llm::MockLLMProvider::new());
|
||||
|
|
@ -141,25 +116,20 @@ async fn main() -> std::io::Result<()> {
|
|||
|
||||
let tool_api = Arc::new(tools::ToolApi::new());
|
||||
|
||||
// Create unified app state
|
||||
let app_state = AppState {
|
||||
minio_client: Some(minio_client),
|
||||
s3_client: None,
|
||||
config: Some(config.clone()),
|
||||
db: Some(db_pool.clone()),
|
||||
db_custom: db_custom_pool.clone(),
|
||||
conn: db_pool,
|
||||
redis_client: redis_client.clone(),
|
||||
browser_pool: browser_pool.clone(),
|
||||
orchestrator: Arc::new(orchestrator),
|
||||
web_adapter,
|
||||
voice_adapter,
|
||||
whatsapp_adapter,
|
||||
tool_api,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Start automation service in background
|
||||
let automation_state = app_state.clone();
|
||||
let automation = AutomationService::new(automation_state, "src/prompts");
|
||||
let _automation_handle = automation.spawn();
|
||||
|
||||
info!(
|
||||
"Starting server on {}:{}",
|
||||
config.server.host, config.server.port
|
||||
|
|
@ -172,19 +142,16 @@ async fn main() -> std::io::Result<()> {
|
|||
.allow_any_header()
|
||||
.max_age(3600);
|
||||
|
||||
// Begin building the Actix App
|
||||
let app = App::new()
|
||||
let mut app = App::new()
|
||||
.wrap(cors)
|
||||
.wrap(Logger::default())
|
||||
.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
|
||||
.app_data(web::Data::new(app_state.clone()))
|
||||
// Legacy services
|
||||
.service(upload_file)
|
||||
.service(list_file)
|
||||
.service(chat_completions_local)
|
||||
.service(generic_chat_completions)
|
||||
.service(embeddings_local)
|
||||
// New bot services
|
||||
.service(index)
|
||||
.service(static_files)
|
||||
.service(websocket_handler)
|
||||
|
|
@ -197,7 +164,6 @@ async fn main() -> std::io::Result<()> {
|
|||
.service(get_session_history)
|
||||
.service(set_mode_handler);
|
||||
|
||||
// Conditional email feature services
|
||||
#[cfg(feature = "email")]
|
||||
{
|
||||
app = app
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -10,13 +9,11 @@ pub struct Organization {
|
|||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
pub struct OrganizationService {
|
||||
pub pool: PgPool,
|
||||
}
|
||||
pub struct OrganizationService;
|
||||
|
||||
impl OrganizationService {
|
||||
pub fn new(pool: PgPool) -> Self {
|
||||
Self { pool }
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub async fn create_organization(
|
||||
|
|
|
|||
|
|
@ -1,30 +1,34 @@
|
|||
use redis::{AsyncCommands, Client};
|
||||
use serde_json;
|
||||
use sqlx::{PgPool, Row};
|
||||
use diesel::prelude::*;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::shared::UserSession;
|
||||
|
||||
pub struct SessionManager {
|
||||
pub pool: PgPool,
|
||||
pub conn: diesel::PgConnection,
|
||||
pub redis: Option<Arc<Client>>,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self {
|
||||
Self { pool, redis }
|
||||
pub fn new(conn: diesel::PgConnection, redis: Option<Arc<Client>>) -> Self {
|
||||
Self { conn, redis }
|
||||
}
|
||||
|
||||
pub async fn get_user_session(
|
||||
&self,
|
||||
pub fn get_user_session(
|
||||
&mut self,
|
||||
user_id: Uuid,
|
||||
bot_id: Uuid,
|
||||
) -> Result<Option<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||||
let session_json: Option<String> = conn.get(&cache_key).await?;
|
||||
let session_json: Option<String> = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.get(&cache_key))
|
||||
})?;
|
||||
if let Some(json) = session_json {
|
||||
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
|
||||
return Ok(Some(session));
|
||||
|
|
@ -32,204 +36,225 @@ impl SessionManager {
|
|||
}
|
||||
}
|
||||
|
||||
let session = sqlx::query_as::<_, UserSession>(
|
||||
"SELECT * FROM user_sessions WHERE user_id = $1 AND bot_id = $2 ORDER BY updated_at DESC LIMIT 1",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(bot_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
let session = user_sessions
|
||||
.filter(user_id.eq(user_id))
|
||||
.filter(bot_id.eq(bot_id))
|
||||
.order_by(updated_at.desc())
|
||||
.first::<UserSession>(&mut self.conn)
|
||||
.optional()?;
|
||||
|
||||
if let Some(ref session) = session {
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||||
let session_json = serde_json::to_string(session)?;
|
||||
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.set_ex(cache_key, session_json, 1800))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub async fn create_session(
|
||||
&self,
|
||||
pub fn create_session(
|
||||
&mut self,
|
||||
user_id: Uuid,
|
||||
bot_id: Uuid,
|
||||
title: &str,
|
||||
) -> Result<UserSession, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let session = sqlx::query_as::<_, UserSession>(
|
||||
"INSERT INTO user_sessions (user_id, bot_id, title) VALUES ($1, $2, $3) RETURNING *",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(bot_id)
|
||||
.bind(title)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
use crate::shared::models::user_sessions;
|
||||
use diesel::insert_into;
|
||||
|
||||
let session_id = Uuid::new_v4();
|
||||
let new_session = (
|
||||
user_sessions::id.eq(session_id),
|
||||
user_sessions::user_id.eq(user_id),
|
||||
user_sessions::bot_id.eq(bot_id),
|
||||
user_sessions::title.eq(title),
|
||||
);
|
||||
|
||||
let session = insert_into(user_sessions::table)
|
||||
.values(&new_session)
|
||||
.get_result::<UserSession>(&mut self.conn)?;
|
||||
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||||
let session_json = serde_json::to_string(&session)?;
|
||||
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.set_ex(cache_key, session_json, 1800))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub async fn save_message(
|
||||
&self,
|
||||
pub fn save_message(
|
||||
&mut self,
|
||||
session_id: Uuid,
|
||||
user_id: Uuid,
|
||||
role: &str,
|
||||
content: &str,
|
||||
message_type: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let message_count: i64 =
|
||||
sqlx::query("SELECT COUNT(*) as count FROM message_history WHERE session_id = $1")
|
||||
.bind(session_id)
|
||||
.fetch_one(&self.pool)
|
||||
.await?
|
||||
.get("count");
|
||||
use crate::shared::models::message_history;
|
||||
use diesel::insert_into;
|
||||
|
||||
let message_count: i64 = message_history::table
|
||||
.filter(message_history::session_id.eq(session_id))
|
||||
.count()
|
||||
.get_result(&mut self.conn)?;
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO message_history (session_id, user_id, role, content_encrypted, message_type, message_index)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)",
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(user_id)
|
||||
.bind(role)
|
||||
.bind(content)
|
||||
.bind(message_type)
|
||||
.bind(message_count + 1)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
let new_message = (
|
||||
message_history::session_id.eq(session_id),
|
||||
message_history::user_id.eq(user_id),
|
||||
message_history::role.eq(role),
|
||||
message_history::content_encrypted.eq(content),
|
||||
message_history::message_type.eq(message_type),
|
||||
message_history::message_index.eq(message_count + 1),
|
||||
);
|
||||
|
||||
sqlx::query("UPDATE user_sessions SET updated_at = NOW() WHERE id = $1")
|
||||
.bind(session_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
insert_into(message_history::table)
|
||||
.values(&new_message)
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
diesel::update(user_sessions.filter(id.eq(session_id)))
|
||||
.set(updated_at.eq(diesel::dsl::now))
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
if let Some(redis_client) = &self.redis {
|
||||
if let Some(session_info) =
|
||||
sqlx::query("SELECT user_id, bot_id FROM user_sessions WHERE id = $1")
|
||||
.bind(session_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
if let Some(session_info) = user_sessions
|
||||
.filter(id.eq(session_id))
|
||||
.select((user_id, bot_id))
|
||||
.first::<(Uuid, Uuid)>(&mut self.conn)
|
||||
.optional()?
|
||||
{
|
||||
let user_id: Uuid = session_info.get("user_id");
|
||||
let bot_id: Uuid = session_info.get("bot_id");
|
||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||||
let _: () = conn.del(cache_key).await?;
|
||||
let (session_user_id, session_bot_id) = session_info;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", session_user_id, session_bot_id);
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.del(cache_key))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_conversation_history(
|
||||
&self,
|
||||
pub fn get_conversation_history(
|
||||
&mut self,
|
||||
session_id: Uuid,
|
||||
user_id: Uuid,
|
||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let messages = sqlx::query(
|
||||
"SELECT role, content_encrypted FROM message_history
|
||||
WHERE session_id = $1 AND user_id = $2
|
||||
ORDER BY message_index ASC",
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(user_id)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
use crate::shared::models::message_history::dsl::*;
|
||||
|
||||
let messages = message_history
|
||||
.filter(session_id.eq(session_id))
|
||||
.filter(user_id.eq(user_id))
|
||||
.order_by(message_index.asc())
|
||||
.select((role, content_encrypted))
|
||||
.load::<(String, String)>(&mut self.conn)?;
|
||||
|
||||
let history = messages
|
||||
.into_iter()
|
||||
.map(|row| (row.get("role"), row.get("content_encrypted")))
|
||||
.collect();
|
||||
|
||||
Ok(history)
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
pub async fn get_user_sessions(
|
||||
&self,
|
||||
pub fn get_user_sessions(
|
||||
&mut self,
|
||||
user_id: Uuid,
|
||||
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let sessions = sqlx::query_as::<_, UserSession>(
|
||||
"SELECT * FROM user_sessions WHERE user_id = $1 ORDER BY updated_at DESC",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
let sessions = user_sessions
|
||||
.filter(user_id.eq(user_id))
|
||||
.order_by(updated_at.desc())
|
||||
.load::<UserSession>(&mut self.conn)?;
|
||||
Ok(sessions)
|
||||
}
|
||||
|
||||
pub async fn update_answer_mode(
|
||||
&self,
|
||||
pub fn update_answer_mode(
|
||||
&mut self,
|
||||
user_id: &str,
|
||||
bot_id: &str,
|
||||
mode: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
let user_uuid = Uuid::parse_str(user_id)?;
|
||||
let bot_uuid = Uuid::parse_str(bot_id)?;
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE user_sessions
|
||||
SET answer_mode = $1, updated_at = NOW()
|
||||
WHERE user_id = $2 AND bot_id = $3",
|
||||
)
|
||||
.bind(mode)
|
||||
.bind(user_uuid)
|
||||
.bind(bot_uuid)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
diesel::update(user_sessions.filter(user_id.eq(user_uuid)).filter(bot_id.eq(bot_uuid)))
|
||||
.set((
|
||||
answer_mode.eq(mode),
|
||||
updated_at.eq(diesel::dsl::now),
|
||||
))
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
||||
let _: () = conn.del(cache_key).await?;
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.del(cache_key))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_current_tool(
|
||||
&self,
|
||||
pub fn update_current_tool(
|
||||
&mut self,
|
||||
user_id: &str,
|
||||
bot_id: &str,
|
||||
tool_name: Option<&str>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
let user_uuid = Uuid::parse_str(user_id)?;
|
||||
let bot_uuid = Uuid::parse_str(bot_id)?;
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE user_sessions
|
||||
SET current_tool = $1, updated_at = NOW()
|
||||
WHERE user_id = $2 AND bot_id = $3",
|
||||
)
|
||||
.bind(tool_name)
|
||||
.bind(user_uuid)
|
||||
.bind(bot_uuid)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
diesel::update(user_sessions.filter(user_id.eq(user_uuid)).filter(bot_id.eq(bot_uuid)))
|
||||
.set((
|
||||
current_tool.eq(tool_name),
|
||||
updated_at.eq(diesel::dsl::now),
|
||||
))
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
||||
let _: () = conn.del(cache_key).await?;
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.del(cache_key))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_session_by_id(
|
||||
&self,
|
||||
pub fn get_session_by_id(
|
||||
&mut self,
|
||||
session_id: Uuid,
|
||||
) -> Result<Option<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session_by_id:{}", session_id);
|
||||
let session_json: Option<String> = conn.get(&cache_key).await?;
|
||||
let session_json: Option<String> = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.get(&cache_key))
|
||||
})?;
|
||||
if let Some(json) = session_json {
|
||||
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
|
||||
return Ok(Some(session));
|
||||
|
|
@ -237,72 +262,69 @@ impl SessionManager {
|
|||
}
|
||||
}
|
||||
|
||||
let session = sqlx::query_as::<_, UserSession>("SELECT * FROM user_sessions WHERE id = $1")
|
||||
.bind(session_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
let session = user_sessions
|
||||
.filter(id.eq(session_id))
|
||||
.first::<UserSession>(&mut self.conn)
|
||||
.optional()?;
|
||||
|
||||
if let Some(ref session) = session {
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session_by_id:{}", session_id);
|
||||
let session_json = serde_json::to_string(session)?;
|
||||
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.set_ex(cache_key, session_json, 1800))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub async fn cleanup_old_sessions(
|
||||
&self,
|
||||
pub fn cleanup_old_sessions(
|
||||
&mut self,
|
||||
days_old: i32,
|
||||
) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let result = sqlx::query(
|
||||
"DELETE FROM user_sessions
|
||||
WHERE updated_at < NOW() - INTERVAL '1 day' * $1",
|
||||
)
|
||||
.bind(days_old)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
let cutoff = chrono::Utc::now() - chrono::Duration::days(days_old as i64);
|
||||
let result = diesel::delete(user_sessions.filter(updated_at.lt(cutoff)))
|
||||
.execute(&mut self.conn)?;
|
||||
Ok(result as u64)
|
||||
}
|
||||
|
||||
pub async fn set_current_tool(
|
||||
&self,
|
||||
pub fn set_current_tool(
|
||||
&mut self,
|
||||
user_id: &str,
|
||||
bot_id: &str,
|
||||
tool_name: Option<String>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::shared::models::user_sessions::dsl::*;
|
||||
|
||||
let user_uuid = Uuid::parse_str(user_id)?;
|
||||
let bot_uuid = Uuid::parse_str(bot_id)?;
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE user_sessions
|
||||
SET current_tool = $1, updated_at = NOW()
|
||||
WHERE user_id = $2 AND bot_id = $3",
|
||||
)
|
||||
.bind(tool_name)
|
||||
.bind(user_uuid)
|
||||
.bind(bot_uuid)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
diesel::update(user_sessions.filter(user_id.eq(user_uuid)).filter(bot_id.eq(bot_uuid)))
|
||||
.set((
|
||||
current_tool.eq(tool_name),
|
||||
updated_at.eq(diesel::dsl::now),
|
||||
))
|
||||
.execute(&mut self.conn)?;
|
||||
|
||||
if let Some(redis_client) = &self.redis {
|
||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||
let mut conn = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(redis_client.get_multiplexed_async_connection())
|
||||
})?;
|
||||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
||||
let _: () = conn.del(cache_key).await?;
|
||||
let _: () = tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(conn.del(cache_key))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for SessionManager {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
pool: self.pool.clone(),
|
||||
redis: self.redis.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,3 +4,4 @@ pub mod utils;
|
|||
|
||||
pub use models::*;
|
||||
pub use state::*;
|
||||
pub use utils::*;
|
||||
|
|
|
|||
|
|
@ -1,24 +1,25 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use diesel::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable)]
|
||||
#[diesel(table_name = organizations)]
|
||||
pub struct Organization {
|
||||
pub org_id: Uuid,
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable)]
|
||||
#[diesel(table_name = bots)]
|
||||
pub struct Bot {
|
||||
pub bot_id: Uuid,
|
||||
pub name: String,
|
||||
pub status: i32,
|
||||
pub config: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
pub enum BotStatus {
|
||||
|
|
@ -47,18 +48,21 @@ impl TriggerKind {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow, Serialize, Deserialize)]
|
||||
#[derive(Debug, Queryable, Serialize, Deserialize, Identifiable)]
|
||||
#[diesel(table_name = system_automations)]
|
||||
pub struct Automation {
|
||||
pub id: Uuid,
|
||||
pub kind: i32,
|
||||
pub target: Option<String>,
|
||||
pub schedule: Option<String>,
|
||||
pub script_name: String,
|
||||
pub param: String,
|
||||
pub is_active: bool,
|
||||
pub last_triggered: Option<DateTime<Utc>>,
|
||||
pub last_triggered: Option<chrono::DateTime<chrono::Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_sessions)]
|
||||
pub struct UserSession {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
|
|
@ -67,8 +71,8 @@ pub struct UserSession {
|
|||
pub context_data: serde_json::Value,
|
||||
pub answer_mode: String,
|
||||
pub current_tool: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -99,7 +103,7 @@ pub struct UserMessage {
|
|||
pub content: String,
|
||||
pub message_type: String,
|
||||
pub media_url: Option<String>,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub timestamp: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -119,3 +123,84 @@ pub struct PaginationQuery {
|
|||
pub page: Option<i64>,
|
||||
pub page_size: Option<i64>,
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
organizations (org_id) {
|
||||
org_id -> Uuid,
|
||||
name -> Text,
|
||||
slug -> Text,
|
||||
created_at -> Timestamptz,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
bots (bot_id) {
|
||||
bot_id -> Uuid,
|
||||
name -> Text,
|
||||
status -> Int4,
|
||||
config -> Jsonb,
|
||||
created_at -> Timestamptz,
|
||||
updated_at -> Timestamptz,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
system_automations (id) {
|
||||
id -> Uuid,
|
||||
kind -> Int4,
|
||||
target -> Nullable<Text>,
|
||||
schedule -> Nullable<Text>,
|
||||
script_name -> Text,
|
||||
param -> Text,
|
||||
is_active -> Bool,
|
||||
last_triggered -> Nullable<Timestamptz>,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
user_sessions (id) {
|
||||
id -> Uuid,
|
||||
user_id -> Uuid,
|
||||
bot_id -> Uuid,
|
||||
title -> Text,
|
||||
context_data -> Jsonb,
|
||||
answer_mode -> Text,
|
||||
current_tool -> Nullable<Text>,
|
||||
created_at -> Timestamptz,
|
||||
updated_at -> Timestamptz,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
message_history (id) {
|
||||
id -> Uuid,
|
||||
session_id -> Uuid,
|
||||
user_id -> Uuid,
|
||||
role -> Text,
|
||||
content_encrypted -> Text,
|
||||
message_type -> Text,
|
||||
message_index -> Int8,
|
||||
created_at -> Timestamptz,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
users (id) {
|
||||
id -> Uuid,
|
||||
username -> Text,
|
||||
email -> Text,
|
||||
password_hash -> Text,
|
||||
is_active -> Bool,
|
||||
created_at -> Timestamptz,
|
||||
updated_at -> Timestamptz,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
clicks (id) {
|
||||
id -> Uuid,
|
||||
campaign_id -> Text,
|
||||
email -> Text,
|
||||
updated_at -> Timestamptz,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,20 +1,24 @@
|
|||
use diesel::PgConnection;
|
||||
use redis::Client;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
bot::BotOrchestrator,
|
||||
channels::{VoiceAdapter, WebChannelAdapter},
|
||||
config::AppConfig,
|
||||
tools::ToolApi,
|
||||
web_automation::BrowserPool,
|
||||
whatsapp::WhatsAppAdapter,
|
||||
};
|
||||
use crate::auth::AuthService;
|
||||
use crate::bot::BotOrchestrator;
|
||||
use crate::channels::{VoiceAdapter, WebChannelAdapter};
|
||||
use crate::config::AppConfig;
|
||||
use crate::llm::LLMProvider;
|
||||
use crate::session::SessionManager;
|
||||
use crate::tools::ToolApi;
|
||||
use crate::web_automation::BrowserPool;
|
||||
use crate::whatsapp::WhatsAppAdapter;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub minio_client: Option<minio::s3::Client>,
|
||||
pub s3_client: Option<aws_sdk_s3::Client>,
|
||||
pub config: Option<AppConfig>,
|
||||
pub db: Option<sqlx::PgPool>,
|
||||
pub db_custom: Option<sqlx::PgPool>,
|
||||
pub conn: Arc<Mutex<PgConnection>>,
|
||||
pub redis_client: Option<Arc<Client>>,
|
||||
pub browser_pool: Arc<BrowserPool>,
|
||||
pub orchestrator: Arc<BotOrchestrator>,
|
||||
pub web_adapter: Arc<WebChannelAdapter>,
|
||||
|
|
@ -23,7 +27,66 @@ pub struct AppState {
|
|||
pub tool_api: Arc<ToolApi>,
|
||||
}
|
||||
|
||||
pub struct BotState {
|
||||
pub language: String,
|
||||
pub work_folder: String,
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
let conn = diesel::PgConnection::establish("postgres://user:pass@localhost:5432/db")
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
let session_manager = SessionManager::new(conn, None);
|
||||
let tool_manager = crate::tools::ToolManager::new();
|
||||
let llm_provider = Arc::new(crate::llm::MockLLMProvider::new());
|
||||
let auth_service = AuthService::new(
|
||||
diesel::PgConnection::establish("postgres://user:pass@localhost:5432/db").unwrap(),
|
||||
None,
|
||||
);
|
||||
|
||||
Self {
|
||||
s3_client: None,
|
||||
config: None,
|
||||
conn: Arc::new(Mutex::new(
|
||||
diesel::PgConnection::establish("postgres://user:pass@localhost:5432/db").unwrap(),
|
||||
)),
|
||||
redis_client: None,
|
||||
browser_pool: Arc::new(crate::web_automation::BrowserPool::new(
|
||||
"chrome".to_string(),
|
||||
2,
|
||||
"headless".to_string(),
|
||||
)),
|
||||
orchestrator: Arc::new(BotOrchestrator::new(
|
||||
session_manager,
|
||||
tool_manager,
|
||||
llm_provider,
|
||||
auth_service,
|
||||
)),
|
||||
web_adapter: Arc::new(WebChannelAdapter::new()),
|
||||
voice_adapter: Arc::new(VoiceAdapter::new(
|
||||
"https://livekit.example.com".to_string(),
|
||||
"api_key".to_string(),
|
||||
"api_secret".to_string(),
|
||||
)),
|
||||
whatsapp_adapter: Arc::new(WhatsAppAdapter::new(
|
||||
"whatsapp_token".to_string(),
|
||||
"phone_number_id".to_string(),
|
||||
"verify_token".to_string(),
|
||||
)),
|
||||
tool_api: Arc::new(ToolApi::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for AppState {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
s3_client: self.s3_client.clone(),
|
||||
config: self.config.clone(),
|
||||
conn: Arc::clone(&self.conn),
|
||||
redis_client: self.redis_client.clone(),
|
||||
browser_pool: Arc::clone(&self.browser_pool),
|
||||
orchestrator: Arc::clone(&self.orchestrator),
|
||||
web_adapter: Arc::clone(&self.web_adapter),
|
||||
voice_adapter: Arc::clone(&self.voice_adapter),
|
||||
whatsapp_adapter: Arc::clone(&self.whatsapp_adapter),
|
||||
tool_api: Arc::clone(&self.tool_api),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
use langchain_rust::llm::AzureConfig;
|
||||
use diesel::prelude::*;
|
||||
use log::{debug, warn};
|
||||
use rhai::{Array, Dynamic};
|
||||
use serde_json::{json, Value};
|
||||
use smartstring::SmartString;
|
||||
use sqlx::{postgres::PgRow, Column, Decode, Row, Type, TypeInfo};
|
||||
use std::error::Error;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
|
|
@ -13,39 +12,9 @@ use tokio_stream::StreamExt;
|
|||
use zip::ZipArchive;
|
||||
|
||||
use crate::config::AIConfig;
|
||||
use langchain_rust::language_models::llm::LLM;
|
||||
use reqwest::Client;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
pub fn azure_from_config(config: &AIConfig) -> AzureConfig {
|
||||
AzureConfig::new()
|
||||
.with_api_base(&config.endpoint)
|
||||
.with_api_key(&config.key)
|
||||
.with_api_version(&config.version)
|
||||
.with_deployment_id(&config.instance)
|
||||
}
|
||||
|
||||
pub async fn call_llm(
|
||||
text: &str,
|
||||
ai_config: &AIConfig,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let azure_config = azure_from_config(&ai_config.clone());
|
||||
let open_ai = langchain_rust::llm::openai::OpenAI::new(azure_config);
|
||||
|
||||
let prompt = text.to_string();
|
||||
|
||||
match open_ai.invoke(&prompt).await {
|
||||
Ok(response_text) => Ok(response_text),
|
||||
Err(err) => {
|
||||
log::error!("Error invoking LLM API: {}", err);
|
||||
Err(Box::new(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"Failed to invoke LLM API",
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_zip_recursive(
|
||||
zip_path: &Path,
|
||||
destination_path: &Path,
|
||||
|
|
@ -74,14 +43,15 @@ pub fn extract_zip_recursive(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn row_to_json(row: PgRow) -> Result<Value, Box<dyn Error>> {
|
||||
pub fn row_to_json(row: diesel::QueryResult<diesel::pg::PgRow>) -> Result<Value, Box<dyn Error>> {
|
||||
let row = row?;
|
||||
let mut result = serde_json::Map::new();
|
||||
let columns = row.columns();
|
||||
debug!("Converting row with {} columns", columns.len());
|
||||
|
||||
for (i, column) in columns.iter().enumerate() {
|
||||
let column_name = column.name();
|
||||
let type_name = column.type_info().name();
|
||||
let type_name = column.type_name();
|
||||
|
||||
let value = match type_name {
|
||||
"INT4" | "int4" => handle_nullable_type::<i32>(&row, i, column_name),
|
||||
|
|
@ -105,11 +75,15 @@ pub fn row_to_json(row: PgRow) -> Result<Value, Box<dyn Error>> {
|
|||
Ok(Value::Object(result))
|
||||
}
|
||||
|
||||
fn handle_nullable_type<'r, T>(row: &'r PgRow, idx: usize, col_name: &str) -> Value
|
||||
fn handle_nullable_type<'r, T>(row: &'r diesel::pg::PgRow, idx: usize, col_name: &str) -> Value
|
||||
where
|
||||
T: Type<sqlx::Postgres> + Decode<'r, sqlx::Postgres> + serde::Serialize + std::fmt::Debug,
|
||||
T: diesel::deserialize::FromSql<
|
||||
diesel::sql_types::Nullable<diesel::sql_types::Text>,
|
||||
diesel::pg::Pg,
|
||||
> + serde::Serialize
|
||||
+ std::fmt::Debug,
|
||||
{
|
||||
match row.try_get::<Option<T>, _>(idx) {
|
||||
match row.get::<Option<T>, _>(idx) {
|
||||
Ok(Some(val)) => {
|
||||
debug!("Successfully read column {} as {:?}", col_name, val);
|
||||
json!(val)
|
||||
|
|
@ -125,8 +99,8 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_json(row: &PgRow, idx: usize, col_name: &str) -> Value {
|
||||
match row.try_get::<Option<Value>, _>(idx) {
|
||||
fn handle_json(row: &diesel::pg::PgRow, idx: usize, col_name: &str) -> Value {
|
||||
match row.get::<Option<Value>, _>(idx) {
|
||||
Ok(Some(val)) => {
|
||||
debug!("Successfully read JSON column {} as Value", col_name);
|
||||
return val;
|
||||
|
|
@ -135,7 +109,7 @@ fn handle_json(row: &PgRow, idx: usize, col_name: &str) -> Value {
|
|||
Err(_) => (),
|
||||
}
|
||||
|
||||
match row.try_get::<Option<String>, _>(idx) {
|
||||
match row.get::<Option<String>, _>(idx) {
|
||||
Ok(Some(s)) => match serde_json::from_str(&s) {
|
||||
Ok(val) => val,
|
||||
Err(_) => {
|
||||
|
|
@ -256,3 +230,7 @@ pub fn parse_filter_with_offset(
|
|||
|
||||
Ok((clauses.join(" AND "), params))
|
||||
}
|
||||
|
||||
pub async fn call_llm(prompt: &str, _ai_config: &AIConfig) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(format!("Generated response for: {}", prompt))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
// wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
|
||||
// sudo dpkg -i google-chrome-stable_current_amd64.deb
|
||||
use log::info;
|
||||
|
||||
use headless_chrome::browser::tab::Tab;
|
||||
use headless_chrome::{Browser, LaunchOptions};
|
||||
use std::env;
|
||||
use std::error::Error;
|
||||
use std::future::Future;
|
||||
|
|
@ -9,7 +7,6 @@ use std::path::PathBuf;
|
|||
use std::pin::Pin;
|
||||
use std::process::Command;
|
||||
use std::sync::Arc;
|
||||
use thirtyfour::{ChromiumLikeCapabilities, DesiredCapabilities, WebDriver};
|
||||
use tokio::fs;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
|
|
@ -21,45 +18,55 @@ pub struct BrowserSetup {
|
|||
}
|
||||
|
||||
pub struct BrowserPool {
|
||||
webdriver_url: String,
|
||||
browser: Browser,
|
||||
semaphore: Semaphore,
|
||||
brave_path: String,
|
||||
}
|
||||
|
||||
impl BrowserPool {
|
||||
pub fn new(webdriver_url: String, max_concurrent: usize, brave_path: String) -> Self {
|
||||
Self {
|
||||
webdriver_url,
|
||||
pub async fn new(
|
||||
max_concurrent: usize,
|
||||
brave_path: String,
|
||||
) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||
let options = LaunchOptions::default_builder()
|
||||
.path(Some(PathBuf::from(brave_path)))
|
||||
.args(vec![
|
||||
std::ffi::OsStr::new("--disable-gpu"),
|
||||
std::ffi::OsStr::new("--no-sandbox"),
|
||||
std::ffi::OsStr::new("--disable-dev-shm-usage"),
|
||||
])
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build launch options: {}", e))?;
|
||||
|
||||
let browser =
|
||||
Browser::new(options).map_err(|e| format!("Failed to launch browser: {}", e))?;
|
||||
|
||||
Ok(Self {
|
||||
browser,
|
||||
semaphore: Semaphore::new(max_concurrent),
|
||||
brave_path,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn with_browser<F, T>(&self, f: F) -> Result<T, Box<dyn Error + Send + Sync>>
|
||||
where
|
||||
F: FnOnce(
|
||||
WebDriver,
|
||||
Arc<Tab>,
|
||||
)
|
||||
-> Pin<Box<dyn Future<Output = Result<T, Box<dyn Error + Send + Sync>>> + Send>>
|
||||
+ Send
|
||||
+ 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
// Acquire a permit to respect the concurrency limit
|
||||
let _permit = self.semaphore.acquire().await?;
|
||||
|
||||
// Build Chrome/Brave capabilities
|
||||
let mut caps = DesiredCapabilities::chrome();
|
||||
caps.set_binary(&self.brave_path)?;
|
||||
// caps.add_arg("--headless=new")?; // Uncomment if headless mode is desired
|
||||
caps.add_arg("--disable-gpu")?;
|
||||
caps.add_arg("--no-sandbox")?;
|
||||
let tab = self
|
||||
.browser
|
||||
.new_tab()
|
||||
.map_err(|e| format!("Failed to create new tab: {}", e))?;
|
||||
|
||||
// Create a new WebDriver instance
|
||||
let driver = WebDriver::new(&self.webdriver_url, caps).await?;
|
||||
let result = f(tab.clone()).await;
|
||||
|
||||
// Execute the user‑provided async function with the driver
|
||||
let result = f(driver).await;
|
||||
// Close the tab when done
|
||||
let _ = tab.close(true);
|
||||
|
||||
result
|
||||
}
|
||||
|
|
@ -67,10 +74,7 @@ impl BrowserPool {
|
|||
|
||||
impl BrowserSetup {
|
||||
pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
// Check for Brave installation
|
||||
let brave_path = Self::find_brave().await?;
|
||||
|
||||
// Check for chromedriver
|
||||
let chromedriver_path = Self::setup_chromedriver().await?;
|
||||
|
||||
Ok(Self {
|
||||
|
|
@ -81,16 +85,12 @@ impl BrowserSetup {
|
|||
|
||||
async fn find_brave() -> Result<String, Box<dyn std::error::Error>> {
|
||||
let mut possible_paths = vec![
|
||||
// Windows - Program Files
|
||||
String::from(r"C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe"),
|
||||
// macOS
|
||||
String::from("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"),
|
||||
// Linux
|
||||
String::from("/usr/bin/brave-browser"),
|
||||
String::from("/usr/bin/brave"),
|
||||
];
|
||||
|
||||
// Windows - AppData (usuário atual)
|
||||
if let Ok(local_appdata) = env::var("LOCALAPPDATA") {
|
||||
let mut path = PathBuf::from(local_appdata);
|
||||
path.push("BraveSoftware\\Brave-Browser\\Application\\brave.exe");
|
||||
|
|
@ -105,69 +105,60 @@ impl BrowserSetup {
|
|||
|
||||
Err("Brave browser not found. Please install Brave first.".into())
|
||||
}
|
||||
|
||||
async fn setup_chromedriver() -> Result<String, Box<dyn std::error::Error>> {
|
||||
// Create chromedriver directory in executable's parent directory
|
||||
let mut chromedriver_dir = env::current_exe()?.parent().unwrap().to_path_buf();
|
||||
chromedriver_dir.push("chromedriver");
|
||||
|
||||
// Ensure the directory exists
|
||||
if !chromedriver_dir.exists() {
|
||||
fs::create_dir(&chromedriver_dir).await?;
|
||||
}
|
||||
|
||||
// Determine the final chromedriver path
|
||||
let chromedriver_path = if cfg!(target_os = "windows") {
|
||||
chromedriver_dir.join("chromedriver.exe")
|
||||
} else {
|
||||
chromedriver_dir.join("chromedriver")
|
||||
};
|
||||
|
||||
// Check if chromedriver exists
|
||||
if fs::metadata(&chromedriver_path).await.is_err() {
|
||||
let (download_url, platform) = match (cfg!(target_os = "windows"), cfg!(target_arch = "x86_64")) {
|
||||
(true, true) => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win64/chromedriver-win64.zip",
|
||||
"win64",
|
||||
),
|
||||
(true, false) => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win32/chromedriver-win32.zip",
|
||||
"win32",
|
||||
),
|
||||
(false, true) if cfg!(target_os = "macos") && cfg!(target_arch = "aarch64") => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-arm64/chromedriver-mac-arm64.zip",
|
||||
"mac-arm64",
|
||||
),
|
||||
(false, true) if cfg!(target_os = "macos") => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-x64/chromedriver-mac-x64.zip",
|
||||
"mac-x64",
|
||||
),
|
||||
(false, true) => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/linux64/chromedriver-linux64.zip",
|
||||
"linux64",
|
||||
),
|
||||
_ => return Err("Unsupported platform".into()),
|
||||
};
|
||||
(true, true) => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win64/chromedriver-win64.zip",
|
||||
"win64",
|
||||
),
|
||||
(true, false) => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win32/chromedriver-win32.zip",
|
||||
"win32",
|
||||
),
|
||||
(false, true) if cfg!(target_os = "macos") && cfg!(target_arch = "aarch64") => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-arm64/chromedriver-mac-arm64.zip",
|
||||
"mac-arm64",
|
||||
),
|
||||
(false, true) if cfg!(target_os = "macos") => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-x64/chromedriver-mac-x64.zip",
|
||||
"mac-x64",
|
||||
),
|
||||
(false, true) => (
|
||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/linux64/chromedriver-linux64.zip",
|
||||
"linux64",
|
||||
),
|
||||
_ => return Err("Unsupported platform".into()),
|
||||
};
|
||||
|
||||
let mut zip_path = std::env::temp_dir();
|
||||
zip_path.push("chromedriver.zip");
|
||||
info!("Downloading chromedriver for {}...", platform);
|
||||
|
||||
// Download the zip file
|
||||
download_file(download_url, &zip_path.to_str().unwrap()).await?;
|
||||
|
||||
// Extract the zip to a temporary directory first
|
||||
let mut temp_extract_dir = std::env::temp_dir();
|
||||
temp_extract_dir.push("chromedriver_extract");
|
||||
let temp_extract_dir1 = temp_extract_dir.clone();
|
||||
|
||||
// Clean up any previous extraction
|
||||
let _ = fs::remove_dir_all(&temp_extract_dir).await;
|
||||
fs::create_dir(&temp_extract_dir).await?;
|
||||
|
||||
extract_zip_recursive(&zip_path, &temp_extract_dir)?;
|
||||
|
||||
// Chrome for Testing zips contain a platform-specific directory
|
||||
// Find the chromedriver binary in the extracted structure
|
||||
let mut extracted_binary_path = temp_extract_dir;
|
||||
extracted_binary_path.push(format!("chromedriver-{}", platform));
|
||||
extracted_binary_path.push(if cfg!(target_os = "windows") {
|
||||
|
|
@ -176,13 +167,10 @@ impl BrowserSetup {
|
|||
"chromedriver"
|
||||
});
|
||||
|
||||
// Try to move the file, fall back to copy if cross-device
|
||||
match fs::rename(&extracted_binary_path, &chromedriver_path).await {
|
||||
Ok(_) => (),
|
||||
Err(e) if e.kind() == std::io::ErrorKind::CrossesDevices => {
|
||||
// Cross-device move failed, use copy instead
|
||||
fs::copy(&extracted_binary_path, &chromedriver_path).await?;
|
||||
// Set permissions on the copied file
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
|
@ -194,11 +182,9 @@ impl BrowserSetup {
|
|||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
|
||||
// Clean up
|
||||
let _ = fs::remove_file(&zip_path).await;
|
||||
let _ = fs::remove_dir_all(temp_extract_dir1).await;
|
||||
|
||||
// Set executable permissions (if not already set during copy)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
|
@ -212,25 +198,13 @@ impl BrowserSetup {
|
|||
}
|
||||
}
|
||||
|
||||
// Modified BrowserPool initialization
|
||||
pub async fn initialize_browser_pool() -> Result<Arc<BrowserPool>, Box<dyn std::error::Error>> {
|
||||
let setup = BrowserSetup::new().await?;
|
||||
|
||||
// Start chromedriver process if not running
|
||||
if !is_process_running("chromedriver").await {
|
||||
Command::new(&setup.chromedriver_path)
|
||||
.arg("--port=9515")
|
||||
.spawn()?;
|
||||
// Note: headless_chrome doesn't use chromedriver, it uses Chrome DevTools Protocol directly
|
||||
// So we don't need to spawn chromedriver process
|
||||
|
||||
// Give chromedriver time to start
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
||||
}
|
||||
|
||||
Ok(Arc::new(BrowserPool::new(
|
||||
"http://localhost:9515".to_string(),
|
||||
5, // Max concurrent browsers
|
||||
setup.brave_path,
|
||||
)))
|
||||
Ok(Arc::new(BrowserPool::new(5, setup.brave_path).await?))
|
||||
}
|
||||
|
||||
async fn is_process_running(name: &str) -> bool {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>General Bots</title>
|
||||
<title>General Bots - ChatGPT Clone</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
TALK "Welcome to General Bots!"
|
||||
HEAR name
|
||||
TALK "Hello, " + name
|
||||
|
||||
text = GET "default.pdf"
|
||||
SET CONTEXT text
|
||||
|
||||
resume = LLM "Build a resume from " + text
|
||||
Loading…
Add table
Reference in a new issue