Compare commits
11 commits
c264ad1294
...
de017241f2
| Author | SHA1 | Date | |
|---|---|---|---|
| de017241f2 | |||
| e143968179 | |||
| df9b228a35 | |||
| 98813fbdc8 | |||
| ac5b814536 | |||
| d7211a6c19 | |||
| 3b21ab5ef9 | |||
| b1118f977d | |||
| f7c60362e3 | |||
| 9b86b204f2 | |||
| 848b875698 |
105 changed files with 8055 additions and 1847 deletions
|
|
@ -124,6 +124,7 @@ sha1 = { workspace = true }
|
|||
tokio = { workspace = true, features = ["full", "process"] }
|
||||
tower-http = { workspace = true, features = ["cors", "fs", "trace"] }
|
||||
tracing = { workspace = true }
|
||||
url = { workspace = true }
|
||||
urlencoding = { workspace = true }
|
||||
uuid = { workspace = true, features = ["v4", "v5"] }
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ On first run, botserver automatically:
|
|||
- Installs required components (PostgreSQL, S3 storage, Redis cache, LLM)
|
||||
- Sets up database with migrations
|
||||
- Downloads AI models
|
||||
- Starts HTTP server at `http://localhost:8088`
|
||||
- Starts HTTP server at `http://localhost:9000`
|
||||
|
||||
### Command-Line Options
|
||||
|
||||
|
|
@ -318,7 +318,7 @@ When a file grows beyond this limit:
|
|||
| `attendance/llm_assist.rs` | 2053 | → 5 files |
|
||||
| `drive/mod.rs` | 1522 | → 4 files |
|
||||
|
||||
**See `TODO-refactor1.md` for detailed refactoring plans**
|
||||
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -465,7 +465,7 @@ We welcome contributions! Please read our contributing guidelines before submitt
|
|||
|
||||
1. **Replace 955 unwrap()/expect() calls** with proper error handling
|
||||
2. **Optimize 12,973 clone()/to_string() calls** for performance
|
||||
3. **Refactor 5 large files** following TODO-refactor1.md
|
||||
3. **Refactor 5 large files** following refactoring plan
|
||||
4. **Add missing error handling** in critical paths
|
||||
5. **Implement proper logging** instead of panicking
|
||||
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ Type=simple
|
|||
User=pi
|
||||
Environment=DISPLAY=:0
|
||||
ExecStartPre=/bin/sleep 5
|
||||
ExecStart=/usr/bin/chromium-browser --kiosk --noerrdialogs --disable-infobars --disable-session-crashed-bubble --app=http://localhost:8088/embedded/
|
||||
ExecStart=/usr/bin/chromium-browser --kiosk --noerrdialogs --disable-infobars --disable-session-crashed-bubble --app=http://localhost:9000/embedded/
|
||||
Restart=always
|
||||
RestartSec=10
|
||||
|
||||
|
|
@ -498,10 +498,10 @@ echo "View logs:"
|
|||
echo " ssh $TARGET_HOST 'sudo journalctl -u botserver -f'"
|
||||
echo ""
|
||||
if [ "$WITH_UI" = true ]; then
|
||||
echo "Access UI at: http://$TARGET_HOST:8088/embedded/"
|
||||
echo "Access UI at: http://$TARGET_HOST:9000/embedded/"
|
||||
fi
|
||||
if [ "$WITH_LLAMA" = true ]; then
|
||||
echo ""
|
||||
echo "llama.cpp server running at: http://$TARGET_HOST:8080"
|
||||
echo "Test: curl http://$TARGET_HOST:8080/v1/models"
|
||||
echo "llama.cpp server running at: http://$TARGET_HOST:9000"
|
||||
echo "Test: curl http://$TARGET_HOST:9000/v1/models"
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -5,8 +5,6 @@ pub mod goals_ui;
|
|||
pub mod insights;
|
||||
|
||||
use crate::core::urls::ApiUrls;
|
||||
#[cfg(feature = "llm")]
|
||||
use crate::llm::observability::{ObservabilityConfig, ObservabilityManager, QuickStats};
|
||||
use crate::core::shared::state::AppState;
|
||||
use axum::{
|
||||
extract::State,
|
||||
|
|
|
|||
|
|
@ -1,15 +1,4 @@
|
|||
pub mod llm_assist_types;
|
||||
pub mod llm_assist_config;
|
||||
pub mod llm_assist_handlers;
|
||||
pub mod llm_assist_commands;
|
||||
pub mod llm_assist_helpers;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use llm_assist_types::*;
|
||||
|
||||
// Re-export handlers for routing
|
||||
pub use llm_assist_handlers::*;
|
||||
pub use llm_assist_commands::*;
|
||||
use crate::attendance::{llm_assist_types, llm_assist_config, llm_assist_handlers, llm_assist_commands};
|
||||
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
|
|
@ -18,6 +7,10 @@ use axum::{
|
|||
use std::sync::Arc;
|
||||
use crate::core::shared::state::AppState;
|
||||
|
||||
pub use llm_assist_types::*;
|
||||
pub use llm_assist_handlers::*;
|
||||
pub use llm_assist_commands::*;
|
||||
|
||||
pub fn llm_assist_routes() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
.route("/llm-assist/config/:bot_id", get(get_llm_config))
|
||||
|
|
|
|||
143
src/basic/compiler/blocks/mail.rs
Normal file
143
src/basic/compiler/blocks/mail.rs
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
use log::info;
|
||||
|
||||
pub fn convert_mail_line_with_substitution(line: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut chars = line.chars().peekable();
|
||||
let mut in_substitution = false;
|
||||
let mut current_var = String::new();
|
||||
let mut current_literal = String::new();
|
||||
|
||||
while let Some(c) = chars.next() {
|
||||
match c {
|
||||
'$' => {
|
||||
if let Some(&'{') = chars.peek() {
|
||||
chars.next();
|
||||
|
||||
if !current_literal.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push('"');
|
||||
result.push_str(¤t_literal.replace('"', "\\\""));
|
||||
result.push('"');
|
||||
} else {
|
||||
result.push_str(" + \"");
|
||||
result.push_str(¤t_literal.replace('"', "\\\""));
|
||||
result.push('"');
|
||||
}
|
||||
current_literal.clear();
|
||||
}
|
||||
in_substitution = true;
|
||||
current_var.clear();
|
||||
} else {
|
||||
current_literal.push(c);
|
||||
}
|
||||
}
|
||||
'}' if in_substitution => {
|
||||
in_substitution = false;
|
||||
if !current_var.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push_str(¤t_var);
|
||||
} else {
|
||||
result.push_str(" + ");
|
||||
result.push_str(¤t_var);
|
||||
}
|
||||
}
|
||||
current_var.clear();
|
||||
}
|
||||
_ if in_substitution => {
|
||||
if c.is_alphanumeric() || c == '_' || c == '(' || c == ')' || c == ',' || c == ' ' || c == '\"' {
|
||||
current_var.push(c);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if !in_substitution {
|
||||
current_literal.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !current_literal.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push('"');
|
||||
result.push_str(¤t_literal.replace('"', "\\\""));
|
||||
result.push('"');
|
||||
} else {
|
||||
result.push_str(" + \"");
|
||||
result.push_str(¤t_literal.replace('"', "\\\""));
|
||||
result.push('"');
|
||||
}
|
||||
}
|
||||
|
||||
info!("[TOOL] Converted mail line: '{}' → '{}'", line, result);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn convert_mail_block(recipient: &str, lines: &[String]) -> String {
|
||||
let mut subject = String::new();
|
||||
let mut body_lines: Vec<String> = Vec::new();
|
||||
// let mut in_subject = true; // Removed unused variable
|
||||
let mut skip_blank = true;
|
||||
|
||||
for line in lines.iter() {
|
||||
if line.to_uppercase().starts_with("SUBJECT:") {
|
||||
subject = line[8..].trim().to_string();
|
||||
// in_subject = false; // Removed unused assignment
|
||||
skip_blank = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if skip_blank && line.trim().is_empty() {
|
||||
skip_blank = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
skip_blank = false;
|
||||
let converted = convert_mail_line_with_substitution(line);
|
||||
body_lines.push(converted);
|
||||
}
|
||||
|
||||
let mut result = String::new();
|
||||
let chunk_size = 5;
|
||||
let mut all_vars: Vec<String> = Vec::new();
|
||||
|
||||
for (var_count, chunk) in body_lines.chunks(chunk_size).enumerate() {
|
||||
let var_name = format!("__mail_body_{}__", var_count);
|
||||
all_vars.push(var_name.clone());
|
||||
|
||||
if chunk.len() == 1 {
|
||||
result.push_str(&format!("let {} = {};\n", var_name, chunk[0]));
|
||||
} else {
|
||||
let mut chunk_expr = chunk[0].clone();
|
||||
for line in &chunk[1..] {
|
||||
chunk_expr.push_str(" + \"\\n\" + ");
|
||||
chunk_expr.push_str(line);
|
||||
}
|
||||
result.push_str(&format!("let {} = {};\n", var_name, chunk_expr));
|
||||
}
|
||||
}
|
||||
|
||||
let body_expr = if all_vars.is_empty() {
|
||||
"\"\"".to_string()
|
||||
} else if all_vars.len() == 1 {
|
||||
all_vars[0].clone()
|
||||
} else {
|
||||
let mut expr = all_vars[0].clone();
|
||||
for var in &all_vars[1..] {
|
||||
expr.push_str(" + \"\\n\" + ");
|
||||
expr.push_str(var);
|
||||
}
|
||||
expr
|
||||
};
|
||||
|
||||
let recipient_expr = if recipient.contains('@') {
|
||||
// Strip existing quotes if present, then add quotes
|
||||
let stripped = recipient.trim_matches('"');
|
||||
format!("\"{}\"", stripped)
|
||||
} else {
|
||||
recipient.to_string()
|
||||
};
|
||||
result.push_str(&format!("send_mail({}, \"{}\", {}, []);\n", recipient_expr, subject, body_expr));
|
||||
|
||||
info!("[TOOL] Converted MAIL block → {}", result);
|
||||
result
|
||||
}
|
||||
76
src/basic/compiler/blocks/mod.rs
Normal file
76
src/basic/compiler/blocks/mod.rs
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
pub mod mail;
|
||||
pub mod talk;
|
||||
|
||||
pub use mail::convert_mail_block;
|
||||
pub use talk::convert_talk_block;
|
||||
|
||||
use log::info;
|
||||
|
||||
pub fn convert_begin_blocks(script: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut in_talk_block = false;
|
||||
let mut talk_block_lines: Vec<String> = Vec::new();
|
||||
let mut in_mail_block = false;
|
||||
let mut mail_recipient = String::new();
|
||||
let mut mail_block_lines: Vec<String> = Vec::new();
|
||||
|
||||
for line in script.lines() {
|
||||
let trimmed = line.trim();
|
||||
let upper = trimmed.to_uppercase();
|
||||
|
||||
if trimmed.is_empty() || trimmed.starts_with('\'') || trimmed.starts_with("//") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if upper == "BEGIN TALK" {
|
||||
info!("[TOOL] Converting BEGIN TALK statement");
|
||||
in_talk_block = true;
|
||||
talk_block_lines.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if upper == "END TALK" {
|
||||
info!("[TOOL] Converting END TALK statement, processing {} lines", talk_block_lines.len());
|
||||
in_talk_block = false;
|
||||
let converted = convert_talk_block(&talk_block_lines);
|
||||
result.push_str(&converted);
|
||||
talk_block_lines.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if in_talk_block {
|
||||
talk_block_lines.push(trimmed.to_string());
|
||||
continue;
|
||||
}
|
||||
|
||||
if upper.starts_with("BEGIN MAIL ") {
|
||||
let recipient = &trimmed[11..].trim();
|
||||
info!("[TOOL] Converting BEGIN MAIL statement: recipient='{}'", recipient);
|
||||
mail_recipient = recipient.to_string();
|
||||
in_mail_block = true;
|
||||
mail_block_lines.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if upper == "END MAIL" {
|
||||
info!("[TOOL] Converting END MAIL statement, processing {} lines", mail_block_lines.len());
|
||||
in_mail_block = false;
|
||||
let converted = convert_mail_block(&mail_recipient, &mail_block_lines);
|
||||
result.push_str(&converted);
|
||||
result.push('\n');
|
||||
mail_recipient.clear();
|
||||
mail_block_lines.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if in_mail_block {
|
||||
mail_block_lines.push(trimmed.to_string());
|
||||
continue;
|
||||
}
|
||||
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
205
src/basic/compiler/blocks/talk.rs
Normal file
205
src/basic/compiler/blocks/talk.rs
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
use log::info;
|
||||
|
||||
pub fn convert_talk_line_with_substitution(line: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut chars = line.chars().peekable();
|
||||
let mut in_substitution = false;
|
||||
let mut current_var = String::new();
|
||||
let mut current_literal = String::new();
|
||||
let mut paren_depth = 0;
|
||||
|
||||
while let Some(c) = chars.next() {
|
||||
match c {
|
||||
'$' => {
|
||||
if let Some(&'{') = chars.peek() {
|
||||
chars.next();
|
||||
|
||||
if !current_literal.is_empty() {
|
||||
// Output the literal with proper quotes
|
||||
if result.is_empty() {
|
||||
result.push_str("TALK \"");
|
||||
} else {
|
||||
result.push_str(" + \"");
|
||||
}
|
||||
let escaped = current_literal.replace('"', "\\\"");
|
||||
result.push_str(&escaped);
|
||||
result.push('"');
|
||||
current_literal.clear();
|
||||
}
|
||||
in_substitution = true;
|
||||
current_var.clear();
|
||||
paren_depth = 0;
|
||||
} else {
|
||||
current_literal.push(c);
|
||||
}
|
||||
}
|
||||
'}' if in_substitution => {
|
||||
if paren_depth == 0 {
|
||||
in_substitution = false;
|
||||
if !current_var.is_empty() {
|
||||
// If result is empty, we need to start with "TALK "
|
||||
// but DON'T add opening quote - the variable is not a literal
|
||||
if result.is_empty() {
|
||||
result.push_str("TALK ");
|
||||
} else {
|
||||
result.push_str(" + ");
|
||||
}
|
||||
result.push_str(¤t_var);
|
||||
}
|
||||
current_var.clear();
|
||||
} else {
|
||||
current_var.push(c);
|
||||
paren_depth -= 1;
|
||||
}
|
||||
}
|
||||
_ if in_substitution => {
|
||||
if c.is_alphanumeric() || c == '_' || c == '.' || c == '[' || c == ']' || c == ',' || c == '"' {
|
||||
current_var.push(c);
|
||||
} else if c == '(' {
|
||||
current_var.push(c);
|
||||
paren_depth += 1;
|
||||
} else if c == ')' && paren_depth > 0 {
|
||||
current_var.push(c);
|
||||
paren_depth -= 1;
|
||||
} else if (c == ':' || c == '=' || c == ' ') && paren_depth == 0 {
|
||||
// Handle special punctuation that ends a variable context
|
||||
// Only end substitution if we're not inside parentheses (function call)
|
||||
in_substitution = false;
|
||||
if !current_var.is_empty() {
|
||||
// If result is empty, start with "TALK " (without opening quote)
|
||||
if result.is_empty() {
|
||||
result.push_str("TALK ");
|
||||
} else {
|
||||
result.push_str(" + ");
|
||||
}
|
||||
result.push_str(¤t_var);
|
||||
}
|
||||
current_var.clear();
|
||||
current_literal.push(c);
|
||||
} else if c == ' ' {
|
||||
// Allow spaces inside function calls
|
||||
current_var.push(c);
|
||||
}
|
||||
// Ignore other invalid characters - they'll be processed as literals
|
||||
}
|
||||
'\\' if in_substitution => {
|
||||
if let Some(&next_char) = chars.peek() {
|
||||
current_var.push(next_char);
|
||||
chars.next();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
current_literal.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !current_literal.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push_str("TALK \"");
|
||||
} else {
|
||||
result.push_str(" + \"");
|
||||
}
|
||||
let escaped = current_literal.replace('"', "\\\"");
|
||||
result.push_str(&escaped);
|
||||
result.push('"');
|
||||
}
|
||||
|
||||
if result.is_empty() {
|
||||
result = "TALK \"\"".to_string();
|
||||
}
|
||||
|
||||
info!("[TOOL] Converted TALK line: '{}' → '{}'", line, result);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn convert_talk_block(lines: &[String]) -> String {
|
||||
// Convert all lines first
|
||||
let converted_lines: Vec<String> = lines.iter()
|
||||
.map(|line| convert_talk_line_with_substitution(line))
|
||||
.collect();
|
||||
|
||||
// Extract content after "TALK " prefix
|
||||
let line_contents: Vec<String> = converted_lines.iter()
|
||||
.map(|line| {
|
||||
if let Some(stripped) = line.strip_prefix("TALK ") {
|
||||
stripped.trim().to_string()
|
||||
} else {
|
||||
line.clone()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Use chunking to reduce expression complexity (max 5 lines per chunk)
|
||||
let chunk_size = 5;
|
||||
let mut result = String::new();
|
||||
|
||||
for (chunk_idx, chunk) in line_contents.chunks(chunk_size).enumerate() {
|
||||
let var_name = format!("__talk_chunk_{}__", chunk_idx);
|
||||
|
||||
if chunk.len() == 1 {
|
||||
result.push_str(&format!("let {} = {};\n", var_name, chunk[0]));
|
||||
} else {
|
||||
let mut chunk_expr = chunk[0].clone();
|
||||
for line in &chunk[1..] {
|
||||
chunk_expr.push_str(" + \"\\n\" + ");
|
||||
chunk_expr.push_str(line);
|
||||
}
|
||||
result.push_str(&format!("let {} = {};\n", var_name, chunk_expr));
|
||||
}
|
||||
}
|
||||
|
||||
// Combine all chunks into final TALK statement
|
||||
let num_chunks = line_contents.len().div_ceil(chunk_size);
|
||||
if line_contents.is_empty() {
|
||||
return "TALK \"\";\n".to_string();
|
||||
} else if num_chunks == 1 {
|
||||
// Single chunk - use the first variable directly
|
||||
result.push_str("TALK __talk_chunk_0__;\n");
|
||||
} else {
|
||||
// Multiple chunks - need hierarchical chunking to avoid complexity
|
||||
// Combine chunks in groups of 5 to create intermediate variables
|
||||
let combine_chunk_size = 5;
|
||||
let mut chunk_vars: Vec<String> = (0..num_chunks)
|
||||
.map(|i| format!("__talk_chunk_{}__", i))
|
||||
.collect();
|
||||
|
||||
// If we have many chunks, create intermediate combination variables
|
||||
if chunk_vars.len() > combine_chunk_size {
|
||||
let mut level = 0;
|
||||
while chunk_vars.len() > combine_chunk_size {
|
||||
let mut new_vars: Vec<String> = Vec::new();
|
||||
for (idx, sub_chunk) in chunk_vars.chunks(combine_chunk_size).enumerate() {
|
||||
let var_name = format!("__talk_combined_{}_{}__", level, idx);
|
||||
if sub_chunk.len() == 1 {
|
||||
new_vars.push(sub_chunk[0].clone());
|
||||
} else {
|
||||
let mut expr = sub_chunk[0].clone();
|
||||
for var in &sub_chunk[1..] {
|
||||
expr.push_str(" + \"\\n\" + ");
|
||||
expr.push_str(var);
|
||||
}
|
||||
result.push_str(&format!("let {} = {};\n", var_name, expr));
|
||||
new_vars.push(var_name);
|
||||
}
|
||||
}
|
||||
chunk_vars = new_vars;
|
||||
level += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Final TALK statement with combined chunks
|
||||
if chunk_vars.len() == 1 {
|
||||
result.push_str(&format!("TALK {};\n", chunk_vars[0]));
|
||||
} else {
|
||||
let mut expr = chunk_vars[0].clone();
|
||||
for var in &chunk_vars[1..] {
|
||||
expr.push_str(" + \"\\n\" + ");
|
||||
expr.push_str(var);
|
||||
}
|
||||
result.push_str(&format!("TALK {};\n", expr));
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
|
@ -4,6 +4,8 @@ use crate::basic::keywords::table_definition::process_table_definitions;
|
|||
use crate::basic::keywords::webhook::execute_webhook_registration;
|
||||
use crate::core::shared::models::TriggerKind;
|
||||
use crate::core::shared::state::AppState;
|
||||
use diesel::QueryableByName;
|
||||
// use diesel::sql_types::Text; // Removed unused import
|
||||
use diesel::ExpressionMethods;
|
||||
use diesel::QueryDsl;
|
||||
use diesel::RunQueryDsl;
|
||||
|
|
@ -11,6 +13,7 @@ use log::{trace, warn};
|
|||
use regex::Regex;
|
||||
|
||||
pub mod goto_transform;
|
||||
pub mod blocks;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
|
|
@ -22,6 +25,7 @@ use std::sync::Arc;
|
|||
pub struct ParamDeclaration {
|
||||
pub name: String,
|
||||
pub param_type: String,
|
||||
pub original_type: String,
|
||||
pub example: Option<String>,
|
||||
pub description: String,
|
||||
pub required: bool,
|
||||
|
|
@ -55,6 +59,8 @@ pub struct MCPProperty {
|
|||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub example: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub format: Option<String>,
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAITool {
|
||||
|
|
@ -84,6 +90,8 @@ pub struct OpenAIProperty {
|
|||
pub example: Option<String>,
|
||||
#[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub format: Option<String>,
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub struct BasicCompiler {
|
||||
|
|
@ -262,6 +270,7 @@ impl BasicCompiler {
|
|||
Ok(Some(ParamDeclaration {
|
||||
name,
|
||||
param_type: Self::normalize_type(¶m_type),
|
||||
original_type: param_type.to_lowercase(),
|
||||
example,
|
||||
description,
|
||||
required: true,
|
||||
|
|
@ -341,12 +350,20 @@ impl BasicCompiler {
|
|||
let mut properties = HashMap::new();
|
||||
let mut required = Vec::new();
|
||||
for param in &tool_def.parameters {
|
||||
// Add format="date" for DATE type parameters to indicate ISO 8601 format
|
||||
let format = if param.original_type == "date" {
|
||||
Some("date".to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
properties.insert(
|
||||
param.name.clone(),
|
||||
MCPProperty {
|
||||
prop_type: param.param_type.clone(),
|
||||
description: param.description.clone(),
|
||||
example: param.example.clone(),
|
||||
format,
|
||||
},
|
||||
);
|
||||
if param.required {
|
||||
|
|
@ -369,6 +386,13 @@ impl BasicCompiler {
|
|||
let mut properties = HashMap::new();
|
||||
let mut required = Vec::new();
|
||||
for param in &tool_def.parameters {
|
||||
// Add format="date" for DATE type parameters to indicate ISO 8601 format
|
||||
let format = if param.original_type == "date" {
|
||||
Some("date".to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
properties.insert(
|
||||
param.name.clone(),
|
||||
OpenAIProperty {
|
||||
|
|
@ -376,6 +400,7 @@ impl BasicCompiler {
|
|||
description: param.description.clone(),
|
||||
example: param.example.clone(),
|
||||
enum_values: param.enum_values.clone(),
|
||||
format,
|
||||
},
|
||||
);
|
||||
if param.required {
|
||||
|
|
@ -434,6 +459,10 @@ impl BasicCompiler {
|
|||
.execute(&mut conn)
|
||||
.ok();
|
||||
}
|
||||
|
||||
let website_regex = Regex::new(r#"(?i)USE\s+WEBSITE\s+"([^"]+)"(?:\s+REFRESH\s+"([^"]+)")?"#)
|
||||
.unwrap_or_else(|_| Regex::new(r"").unwrap());
|
||||
|
||||
for line in source.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty()
|
||||
|
|
@ -505,14 +534,7 @@ impl BasicCompiler {
|
|||
}
|
||||
|
||||
if trimmed.to_uppercase().starts_with("USE WEBSITE") {
|
||||
let re = match Regex::new(r#"(?i)USE\s+WEBSITE\s+"([^"]+)"(?:\s+REFRESH\s+"([^"]+)")?"#) {
|
||||
Ok(re) => re,
|
||||
Err(e) => {
|
||||
log::warn!("Invalid regex pattern: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if let Some(caps) = re.captures(&normalized) {
|
||||
if let Some(caps) = website_regex.captures(&normalized) {
|
||||
if let Some(url_match) = caps.get(1) {
|
||||
let url = url_match.as_str();
|
||||
let refresh = caps.get(2).map(|m| m.as_str()).unwrap_or("1m");
|
||||
|
|
@ -570,8 +592,353 @@ impl BasicCompiler {
|
|||
} else {
|
||||
self.previous_schedules.remove(&script_name);
|
||||
}
|
||||
|
||||
// Convert SAVE statements with field lists to map-based SAVE
|
||||
let result = match self.convert_save_statements(&result, bot_id) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
log::warn!("SAVE conversion failed: {}, using original code", e);
|
||||
result
|
||||
}
|
||||
};
|
||||
// Convert BEGIN TALK and BEGIN MAIL blocks to Rhai code
|
||||
let result = crate::basic::compiler::blocks::convert_begin_blocks(&result);
|
||||
// Convert IF ... THEN / END IF to if ... { }
|
||||
let result = crate::basic::ScriptService::convert_if_then_syntax(&result);
|
||||
// Convert SELECT ... CASE / END SELECT to match expressions
|
||||
let result = crate::basic::ScriptService::convert_select_case_syntax(&result);
|
||||
// Convert BASIC keywords to lowercase (but preserve variable casing)
|
||||
let result = crate::basic::ScriptService::convert_keywords_to_lowercase(&result);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Convert SAVE statements with field lists to map-based SAVE
|
||||
/// SAVE "table", field1, field2, ... -> let __data__ = #{field1: value1, ...}; SAVE "table", __data__
|
||||
fn convert_save_statements(
|
||||
&self,
|
||||
source: &str,
|
||||
bot_id: uuid::Uuid,
|
||||
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
let mut result = String::new();
|
||||
let mut save_counter = 0;
|
||||
|
||||
for line in source.lines() {
|
||||
let trimmed = line.trim();
|
||||
|
||||
// Check if this is a SAVE statement with field list
|
||||
if trimmed.to_uppercase().starts_with("SAVE ") {
|
||||
if let Some(converted) = self.convert_save_line(line, bot_id, &mut save_counter)? {
|
||||
result.push_str(&converted);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Convert a single SAVE statement line if it has a field list
|
||||
fn convert_save_line(
|
||||
&self,
|
||||
line: &str,
|
||||
bot_id: uuid::Uuid,
|
||||
save_counter: &mut usize,
|
||||
) -> Result<Option<String>, Box<dyn Error + Send + Sync>> {
|
||||
let trimmed = line.trim();
|
||||
|
||||
// Parse SAVE statement
|
||||
// Format: SAVE "table", value1, value2, ...
|
||||
let upper = trimmed.to_uppercase();
|
||||
if !upper.starts_with("SAVE ") {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Extract the content after "SAVE"
|
||||
let content = &trimmed[4..].trim();
|
||||
|
||||
// Parse table name and values
|
||||
let parts = self.parse_save_statement(content)?;
|
||||
|
||||
// If only 2 parts (table + data map), leave as-is (structured SAVE)
|
||||
if parts.len() <= 2 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// This is a field list SAVE - convert to map-based SAVE
|
||||
let table_name = &parts[0];
|
||||
|
||||
// Strip quotes from table name if present
|
||||
let table_name = table_name.trim_matches('"');
|
||||
|
||||
// Debug log to see what we're querying
|
||||
log::info!("[SAVE] Converting SAVE for table: '{}' (original: '{}')", table_name, &parts[0]);
|
||||
|
||||
// Get column names from TABLE definition (preserves order from .bas file)
|
||||
let column_names = self.get_table_columns_for_save(table_name, bot_id)?;
|
||||
|
||||
// Build the map by matching variable names to column names (case-insensitive)
|
||||
let values: Vec<&String> = parts.iter().skip(1).collect();
|
||||
let mut map_pairs = Vec::new();
|
||||
|
||||
log::info!("[SAVE] Matching {} variables to {} columns", values.len(), column_names.len());
|
||||
|
||||
for value_var in values.iter() {
|
||||
// Find the column that matches this variable (case-insensitive)
|
||||
let value_lower = value_var.to_lowercase();
|
||||
|
||||
if let Some(column_name) = column_names.iter().find(|col| col.to_lowercase() == value_lower) {
|
||||
map_pairs.push(format!("{}: {}", column_name, value_var));
|
||||
} else {
|
||||
log::warn!("[SAVE] No matching column for variable '{}'", value_var);
|
||||
}
|
||||
}
|
||||
|
||||
let map_expr = format!("#{{{}}}", map_pairs.join(", "));
|
||||
let data_var = format!("__save_data_{}__", save_counter);
|
||||
*save_counter += 1;
|
||||
|
||||
// Generate: let __save_data_N__ = #{...}; SAVE "table", __save_data_N__
|
||||
let converted = format!("let {} = {}; SAVE {}, {}", data_var, map_expr, table_name, data_var);
|
||||
|
||||
Ok(Some(converted))
|
||||
}
|
||||
|
||||
/// Parse SAVE statement into parts
|
||||
fn parse_save_statement(&self, content: &str) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
|
||||
// Simple parsing - split by comma, but respect quoted strings
|
||||
let mut parts = Vec::new();
|
||||
let mut current = String::new();
|
||||
let mut in_quotes = false;
|
||||
let mut chars = content.chars().peekable();
|
||||
|
||||
while let Some(c) = chars.next() {
|
||||
match c {
|
||||
'"' if chars.peek() == Some(&'"') => {
|
||||
// Escaped quote
|
||||
current.push('"');
|
||||
chars.next();
|
||||
}
|
||||
'"' => {
|
||||
in_quotes = !in_quotes;
|
||||
current.push('"');
|
||||
}
|
||||
',' if !in_quotes => {
|
||||
parts.push(current.trim().to_string());
|
||||
current = String::new();
|
||||
}
|
||||
_ => {
|
||||
current.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !current.trim().is_empty() {
|
||||
parts.push(current.trim().to_string());
|
||||
}
|
||||
|
||||
Ok(parts)
|
||||
}
|
||||
|
||||
/// Get column names for a table from TABLE definition (preserves field order)
|
||||
fn get_table_columns_for_save(
|
||||
&self,
|
||||
table_name: &str,
|
||||
bot_id: uuid::Uuid,
|
||||
) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
|
||||
// Try to parse TABLE definition from the bot's .bas files to get correct field order
|
||||
if let Ok(columns) = self.get_columns_from_table_definition(table_name, bot_id) {
|
||||
if !columns.is_empty() {
|
||||
log::info!("Using TABLE definition for '{}': {} columns", table_name, columns.len());
|
||||
return Ok(columns);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to database schema query (may have different order)
|
||||
self.get_columns_from_database_schema(table_name, bot_id)
|
||||
}
|
||||
|
||||
/// Parse TABLE definition from .bas files to get field order
|
||||
fn get_columns_from_table_definition(
|
||||
&self,
|
||||
table_name: &str,
|
||||
bot_id: uuid::Uuid,
|
||||
) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
|
||||
// use std::path::Path;
|
||||
|
||||
// Find the tables.bas file in the bot's data directory
|
||||
let bot_name = self.get_bot_name_by_id(bot_id)?;
|
||||
let tables_path = format!("/opt/gbo/data/{}.gbai/{}.gbdialog/tables.bas", bot_name, bot_name);
|
||||
|
||||
let tables_content = fs::read_to_string(&tables_path)?;
|
||||
let columns = self.parse_table_definition_for_fields(&tables_content, table_name)?;
|
||||
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
/// Parse TABLE definition and extract field names in order
|
||||
fn parse_table_definition_for_fields(
|
||||
&self,
|
||||
content: &str,
|
||||
table_name: &str,
|
||||
) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
|
||||
let mut columns = Vec::new();
|
||||
let mut in_target_table = false;
|
||||
|
||||
for line in content.lines() {
|
||||
let trimmed = line.trim();
|
||||
|
||||
if trimmed.starts_with("TABLE ") && trimmed.contains(table_name) {
|
||||
in_target_table = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if in_target_table {
|
||||
if trimmed.starts_with("END TABLE") {
|
||||
break;
|
||||
}
|
||||
|
||||
if trimmed.starts_with("FIELD ") {
|
||||
// Parse: FIELD fieldName AS TYPE
|
||||
let parts: Vec<&str> = trimmed.split_whitespace().collect();
|
||||
if parts.len() >= 2 {
|
||||
let field_name = parts[1].to_string();
|
||||
columns.push(field_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
/// Get bot name by bot_id
|
||||
fn get_bot_name_by_id(&self, bot_id: uuid::Uuid) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
use crate::core::shared::models::schema::bots::dsl::*;
|
||||
use diesel::QueryDsl;
|
||||
|
||||
let mut conn = self.state.conn.get()
|
||||
.map_err(|e| format!("Failed to get DB connection: {}", e))?;
|
||||
|
||||
let bot_name: String = bots
|
||||
.filter(id.eq(&bot_id))
|
||||
.select(name)
|
||||
.first(&mut conn)
|
||||
.map_err(|e| format!("Failed to get bot name: {}", e))?;
|
||||
|
||||
Ok(bot_name)
|
||||
}
|
||||
|
||||
/// Get column names from database schema (fallback, order may differ)
|
||||
fn get_columns_from_database_schema(
|
||||
&self,
|
||||
table_name: &str,
|
||||
bot_id: uuid::Uuid,
|
||||
) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
|
||||
use diesel::sql_query;
|
||||
use diesel::sql_types::Text;
|
||||
use diesel::RunQueryDsl;
|
||||
|
||||
#[derive(QueryableByName)]
|
||||
struct ColumnRow {
|
||||
#[diesel(sql_type = Text)]
|
||||
column_name: String,
|
||||
}
|
||||
|
||||
// First, try to get columns from the main database's information_schema
|
||||
// This works because tables are created in the bot's database which shares the schema
|
||||
let mut conn = self.state.conn.get()
|
||||
.map_err(|e| format!("Failed to get DB connection: {}", e))?;
|
||||
|
||||
let query = format!(
|
||||
"SELECT column_name FROM information_schema.columns \
|
||||
WHERE table_name = '{}' AND table_schema = 'public' \
|
||||
ORDER BY ordinal_position",
|
||||
table_name
|
||||
);
|
||||
|
||||
let columns: Vec<String> = match sql_query(&query).load(&mut conn) {
|
||||
Ok(cols) => {
|
||||
if cols.is_empty() {
|
||||
log::warn!("Found 0 columns for table '{}' in main database, trying bot database", table_name);
|
||||
// Try bot's database as fallback when main DB returns empty
|
||||
let bot_pool = self.state.bot_database_manager.get_bot_pool(bot_id);
|
||||
if let Ok(pool) = bot_pool {
|
||||
let mut bot_conn = pool.get()
|
||||
.map_err(|e| format!("Bot DB error: {}", e))?;
|
||||
|
||||
let bot_query = format!(
|
||||
"SELECT column_name FROM information_schema.columns \
|
||||
WHERE table_name = '{}' AND table_schema = 'public' \
|
||||
ORDER BY ordinal_position",
|
||||
table_name
|
||||
);
|
||||
|
||||
match sql_query(&bot_query).load(&mut *bot_conn) {
|
||||
Ok(bot_cols) => {
|
||||
log::info!("Found {} columns for table '{}' in bot database", bot_cols.len(), table_name);
|
||||
bot_cols.into_iter()
|
||||
.map(|c: ColumnRow| c.column_name)
|
||||
.collect()
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to get columns from bot DB for '{}': {}", table_name, e);
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::error!("No bot database available for bot_id: {}", bot_id);
|
||||
Vec::new()
|
||||
}
|
||||
} else {
|
||||
log::info!("Found {} columns for table '{}' in main database", cols.len(), table_name);
|
||||
cols.into_iter()
|
||||
.map(|c: ColumnRow| c.column_name)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Failed to get columns for table '{}' from main DB: {}", table_name, e);
|
||||
|
||||
// Try bot's database as fallback
|
||||
let bot_pool = self.state.bot_database_manager.get_bot_pool(bot_id);
|
||||
if let Ok(pool) = bot_pool {
|
||||
let mut bot_conn = pool.get()
|
||||
.map_err(|e| format!("Bot DB error: {}", e))?;
|
||||
|
||||
let bot_query = format!(
|
||||
"SELECT column_name FROM information_schema.columns \
|
||||
WHERE table_name = '{}' AND table_schema = 'public' \
|
||||
ORDER BY ordinal_position",
|
||||
table_name
|
||||
);
|
||||
|
||||
match sql_query(&bot_query).load(&mut *bot_conn) {
|
||||
Ok(cols) => {
|
||||
log::info!("Found {} columns for table '{}' in bot database", cols.len(), table_name);
|
||||
cols.into_iter()
|
||||
.filter(|c: &ColumnRow| c.column_name != "id")
|
||||
.map(|c: ColumnRow| c.column_name)
|
||||
.collect()
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to get columns from bot DB for '{}': {}", table_name, e);
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::error!("No bot database available for bot_id: {}", bot_id);
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(columns)
|
||||
}
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub struct CompilationResult {
|
||||
|
|
|
|||
|
|
@ -30,44 +30,9 @@ async fn execute_create_draft(
|
|||
subject: &str,
|
||||
reply_text: &str,
|
||||
) -> Result<String, String> {
|
||||
#[cfg(feature = "mail")]
|
||||
{
|
||||
use crate::email::{fetch_latest_sent_to, save_email_draft, SaveDraftRequest};
|
||||
|
||||
let config = state.config.as_ref().ok_or("No email config")?;
|
||||
|
||||
let previous_email = fetch_latest_sent_to(&config.email, to)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let email_body = if previous_email.is_empty() {
|
||||
reply_text.to_string()
|
||||
} else {
|
||||
let email_separator = "<br><hr><br>";
|
||||
let formatted_reply = reply_text.replace("FIX", "Fixed");
|
||||
let formatted_old = previous_email.replace('\n', "<br>");
|
||||
format!("{formatted_reply}{email_separator}{formatted_old}")
|
||||
};
|
||||
|
||||
let draft_request = SaveDraftRequest {
|
||||
account_id: String::new(),
|
||||
to: to.to_string(),
|
||||
cc: None,
|
||||
bcc: None,
|
||||
subject: subject.to_string(),
|
||||
body: email_body,
|
||||
};
|
||||
|
||||
save_email_draft(&config.email, &draft_request)
|
||||
.await
|
||||
.map(|()| "Draft saved successfully".to_string())
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "mail"))]
|
||||
{
|
||||
use chrono::Utc;
|
||||
use diesel::prelude::*;
|
||||
use uuid::Uuid;
|
||||
use chrono::Utc;
|
||||
use diesel::prelude::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
let draft_id = Uuid::new_v4();
|
||||
let conn = state.conn.clone();
|
||||
|
|
@ -94,5 +59,4 @@ async fn execute_create_draft(
|
|||
})
|
||||
.await
|
||||
.map_err(|e| e.to_string())?
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,7 +56,12 @@ pub fn create_site_keyword(state: &AppState, user: UserSession, engine: &mut Eng
|
|||
#[cfg(not(feature = "llm"))]
|
||||
let llm: Option<()> = None;
|
||||
|
||||
let fut = create_site(config, s3, bucket, bot_id, llm, alias, template_dir, prompt);
|
||||
let params = SiteCreationParams {
|
||||
alias,
|
||||
template_dir,
|
||||
prompt,
|
||||
};
|
||||
let fut = create_site(config, s3, bucket, bot_id, llm, params);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("Site creation failed: {}", e))?;
|
||||
|
|
@ -66,6 +71,12 @@ pub fn create_site_keyword(state: &AppState, user: UserSession, engine: &mut Eng
|
|||
.expect("valid syntax registration");
|
||||
}
|
||||
|
||||
struct SiteCreationParams {
|
||||
alias: Dynamic,
|
||||
template_dir: Dynamic,
|
||||
prompt: Dynamic,
|
||||
}
|
||||
|
||||
#[cfg(feature = "llm")]
|
||||
async fn create_site(
|
||||
config: crate::core::config::AppConfig,
|
||||
|
|
@ -73,13 +84,11 @@ async fn create_site(
|
|||
bucket: String,
|
||||
bot_id: String,
|
||||
llm: Option<Arc<dyn LLMProvider>>,
|
||||
alias: Dynamic,
|
||||
template_dir: Dynamic,
|
||||
prompt: Dynamic,
|
||||
params: SiteCreationParams,
|
||||
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
let alias_str = alias.to_string();
|
||||
let template_dir_str = template_dir.to_string();
|
||||
let prompt_str = prompt.to_string();
|
||||
let alias_str = params.alias.to_string();
|
||||
let template_dir_str = params.template_dir.to_string();
|
||||
let prompt_str = params.prompt.to_string();
|
||||
|
||||
info!(
|
||||
"CREATE SITE: {} from template {}",
|
||||
|
|
@ -114,13 +123,11 @@ async fn create_site(
|
|||
bucket: String,
|
||||
bot_id: String,
|
||||
_llm: Option<()>,
|
||||
alias: Dynamic,
|
||||
template_dir: Dynamic,
|
||||
prompt: Dynamic,
|
||||
params: SiteCreationParams,
|
||||
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
let alias_str = alias.to_string();
|
||||
let template_dir_str = template_dir.to_string();
|
||||
let prompt_str = prompt.to_string();
|
||||
let alias_str = params.alias.to_string();
|
||||
let template_dir_str = params.template_dir.to_string();
|
||||
let prompt_str = params.prompt.to_string();
|
||||
|
||||
info!(
|
||||
"CREATE SITE: {} from template {}",
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use super::table_access::{check_table_access, AccessType, UserRoles};
|
|||
use crate::core::shared::{sanitize_identifier, sanitize_sql_value};
|
||||
use crate::core::shared::models::UserSession;
|
||||
use crate::core::shared::state::AppState;
|
||||
use crate::core::shared::utils::{json_value_to_dynamic, to_array};
|
||||
use crate::core::shared::utils::{convert_date_to_iso_format, json_value_to_dynamic, to_array};
|
||||
use diesel::prelude::*;
|
||||
use diesel::sql_query;
|
||||
use diesel::sql_types::Text;
|
||||
|
|
@ -29,40 +29,127 @@ pub fn register_data_operations(state: Arc<AppState>, user: UserSession, engine:
|
|||
}
|
||||
|
||||
pub fn register_save_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = Arc::clone(&state);
|
||||
let user_roles = UserRoles::from_user_session(&user);
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
["SAVE", "$expr$", ",", "$expr$", ",", "$expr$"],
|
||||
false,
|
||||
move |context, inputs| {
|
||||
let table = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
let id = context.eval_expression_tree(&inputs[1])?;
|
||||
let data = context.eval_expression_tree(&inputs[2])?;
|
||||
// SAVE with variable arguments: SAVE "table", id, field1, field2, ...
|
||||
// Each pattern: table + id + (1 to 64 fields)
|
||||
// Minimum: table + id + 1 field = 4 expressions total
|
||||
register_save_variants(&state, user_roles, engine);
|
||||
}
|
||||
|
||||
trace!("SAVE to table: {}, id: {:?}", table, id);
|
||||
fn register_save_variants(state: &Arc<AppState>, user_roles: UserRoles, engine: &mut Engine) {
|
||||
// Register positional saves FIRST (in descending order), so longer patterns
|
||||
// are tried before shorter ones. This ensures that SAVE with 22 fields matches
|
||||
// the 22-field pattern, not the 3-field structured save pattern.
|
||||
// Pattern: SAVE + table + (field1 + field2 + ... + fieldN)
|
||||
// Total elements = 2 (SAVE + table) + num_fields * 2 (comma + expr)
|
||||
// For 22 fields: 2 + 22*2 = 46 elements
|
||||
|
||||
let mut conn = state_clone
|
||||
.conn
|
||||
.get()
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
// Register in descending order (70 down to 2) so longer patterns override shorter ones
|
||||
for num_fields in (2..=70).rev() {
|
||||
let mut pattern = vec!["SAVE", "$expr$"];
|
||||
for _ in 0..num_fields {
|
||||
pattern.push(",");
|
||||
pattern.push("$expr$");
|
||||
}
|
||||
|
||||
// Check write access
|
||||
if let Err(e) =
|
||||
check_table_access(&mut conn, &table, &user_roles, AccessType::Write)
|
||||
{
|
||||
warn!("SAVE access denied: {}", e);
|
||||
return Err(e.into());
|
||||
}
|
||||
// Log pattern registration for key values
|
||||
if num_fields == 22 || num_fields == 21 || num_fields == 23 {
|
||||
log::info!("Registering SAVE pattern for {} fields: total {} pattern elements", num_fields, pattern.len());
|
||||
}
|
||||
|
||||
let result = execute_save(&mut conn, &table, &id, &data)
|
||||
.map_err(|e| format!("SAVE error: {}", e))?;
|
||||
let state_clone = Arc::clone(state);
|
||||
let user_roles_clone = user_roles.clone();
|
||||
let field_count = num_fields;
|
||||
|
||||
Ok(json_value_to_dynamic(&result))
|
||||
},
|
||||
)
|
||||
.expect("valid syntax registration");
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
pattern,
|
||||
false,
|
||||
move |context, inputs| {
|
||||
// Pattern: ["SAVE", "$expr$", ",", "$expr$", ",", "$expr$", ...]
|
||||
// inputs[0] = table, inputs[2], inputs[4], inputs[6], ... = field values
|
||||
// Commas are at inputs[1], inputs[3], inputs[5], ...
|
||||
let table = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
|
||||
trace!("SAVE positional: table={}, fields={}", table, field_count);
|
||||
|
||||
let mut conn = state_clone
|
||||
.conn
|
||||
.get()
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
if let Err(e) =
|
||||
check_table_access(&mut conn, &table, &user_roles_clone, AccessType::Write)
|
||||
{
|
||||
warn!("SAVE access denied: {}", e);
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
// Get column names from database schema
|
||||
let column_names = crate::basic::keywords::table_access::get_table_columns(&mut conn, &table);
|
||||
|
||||
// Build data map from positional field values
|
||||
let mut data_map: Map = Map::new();
|
||||
|
||||
// Field values are at inputs[2], inputs[4], inputs[6], ... (every other element starting from 2)
|
||||
for i in 0..field_count {
|
||||
if i < column_names.len() {
|
||||
let value_expr = &inputs[i * 2 + 2]; // 2, 4, 6, 8, ...
|
||||
let value = context.eval_expression_tree(value_expr)?;
|
||||
data_map.insert(column_names[i].clone().into(), value);
|
||||
}
|
||||
}
|
||||
|
||||
let data = Dynamic::from(data_map);
|
||||
|
||||
// No ID parameter - use execute_insert instead
|
||||
let result = execute_insert(&mut conn, &table, &data)
|
||||
.map_err(|e| format!("SAVE error: {}", e))?;
|
||||
|
||||
Ok(json_value_to_dynamic(&result))
|
||||
},
|
||||
)
|
||||
.expect("valid syntax registration");
|
||||
}
|
||||
|
||||
// Register structured save LAST (after all positional saves)
|
||||
// This ensures that SAVE statements with many fields use positional patterns,
|
||||
// and only SAVE statements with exactly 3 expressions use the structured pattern
|
||||
{
|
||||
let state_clone = Arc::clone(state);
|
||||
let user_roles_clone = user_roles.clone();
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
["SAVE", "$expr$", ",", "$expr$", ",", "$expr$"],
|
||||
false,
|
||||
move |context, inputs| {
|
||||
let table = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||
let id = context.eval_expression_tree(&inputs[1])?;
|
||||
let data = context.eval_expression_tree(&inputs[2])?;
|
||||
|
||||
trace!("SAVE structured: table={}, id={:?}", table, id);
|
||||
|
||||
let mut conn = state_clone
|
||||
.conn
|
||||
.get()
|
||||
.map_err(|e| format!("DB error: {}", e))?;
|
||||
|
||||
if let Err(e) =
|
||||
check_table_access(&mut conn, &table, &user_roles_clone, AccessType::Write)
|
||||
{
|
||||
warn!("SAVE access denied: {}", e);
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
let result = execute_save(&mut conn, &table, &id, &data)
|
||||
.map_err(|e| format!("SAVE error: {}", e))?;
|
||||
|
||||
Ok(json_value_to_dynamic(&result))
|
||||
},
|
||||
)
|
||||
.expect("valid syntax registration");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_insert_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||||
|
|
@ -470,7 +557,9 @@ fn execute_save(
|
|||
|
||||
for (key, value) in &data_map {
|
||||
let sanitized_key = sanitize_identifier(key);
|
||||
let sanitized_value = format!("'{}'", sanitize_sql_value(&value.to_string()));
|
||||
let value_str = value.to_string();
|
||||
let converted_value = convert_date_to_iso_format(&value_str);
|
||||
let sanitized_value = format!("'{}'", sanitize_sql_value(&converted_value));
|
||||
columns.push(sanitized_key.clone());
|
||||
values.push(sanitized_value.clone());
|
||||
update_sets.push(format!("{} = {}", sanitized_key, sanitized_value));
|
||||
|
|
@ -511,7 +600,9 @@ fn execute_insert(
|
|||
|
||||
for (key, value) in &data_map {
|
||||
columns.push(sanitize_identifier(key));
|
||||
values.push(format!("'{}'", sanitize_sql_value(&value.to_string())));
|
||||
let value_str = value.to_string();
|
||||
let converted_value = convert_date_to_iso_format(&value_str);
|
||||
values.push(format!("'{}'", sanitize_sql_value(&converted_value)));
|
||||
}
|
||||
|
||||
let query = format!(
|
||||
|
|
@ -564,10 +655,12 @@ fn execute_update(
|
|||
|
||||
let mut update_sets: Vec<String> = Vec::new();
|
||||
for (key, value) in &data_map {
|
||||
let value_str = value.to_string();
|
||||
let converted_value = convert_date_to_iso_format(&value_str);
|
||||
update_sets.push(format!(
|
||||
"{} = '{}'",
|
||||
sanitize_identifier(key),
|
||||
sanitize_sql_value(&value.to_string())
|
||||
sanitize_sql_value(&converted_value)
|
||||
));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ use crate::core::shared::sanitize_identifier;
|
|||
use crate::core::urls::ApiUrls;
|
||||
use crate::security::error_sanitizer::log_and_sanitize;
|
||||
use crate::security::sql_guard::{
|
||||
build_safe_count_query, build_safe_select_query, is_table_allowed_with_conn, validate_table_name,
|
||||
build_safe_count_query, build_safe_select_by_id_query, build_safe_select_query,
|
||||
is_table_allowed_with_conn, validate_table_name,
|
||||
};
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
|
|
@ -257,10 +258,21 @@ pub async fn get_record_handler(
|
|||
}
|
||||
};
|
||||
|
||||
let query = format!(
|
||||
"SELECT row_to_json(t.*) as data FROM {} t WHERE id = $1",
|
||||
table_name
|
||||
);
|
||||
let query = match build_safe_select_by_id_query(&table_name) {
|
||||
Ok(q) => q,
|
||||
Err(e) => {
|
||||
warn!("Failed to build safe query for {}: {}", table_name, e);
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(RecordResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
message: Some("Invalid table name".to_string()),
|
||||
}),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let row: Result<Option<JsonRow>, _> = sql_query(&query)
|
||||
.bind::<diesel::sql_types::Uuid, _>(record_id)
|
||||
|
|
@ -700,7 +712,17 @@ pub async fn count_records_handler(
|
|||
return (StatusCode::FORBIDDEN, Json(json!({ "error": e }))).into_response();
|
||||
}
|
||||
|
||||
let query = format!("SELECT COUNT(*) as count FROM {}", table_name);
|
||||
let query = match build_safe_count_query(&table_name) {
|
||||
Ok(q) => q,
|
||||
Err(e) => {
|
||||
warn!("Failed to build count query: {}", e);
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": "Invalid table name" })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let result: Result<CountResult, _> = sql_query(&query).get_result(&mut conn);
|
||||
|
||||
match result {
|
||||
|
|
@ -747,14 +769,18 @@ pub async fn search_records_handler(
|
|||
}
|
||||
};
|
||||
|
||||
let safe_search = search_term.replace('%', "\\%").replace('_', "\\_");
|
||||
|
||||
let query = format!(
|
||||
"SELECT row_to_json(t.*) as data FROM {} t WHERE
|
||||
COALESCE(t.title::text, '') || ' ' || COALESCE(t.name::text, '') || ' ' || COALESCE(t.description::text, '')
|
||||
ILIKE '%{}%' LIMIT {}",
|
||||
table_name, search_term, limit
|
||||
ILIKE '%' || $1 || '%' LIMIT {}",
|
||||
table_name, limit
|
||||
);
|
||||
|
||||
let rows: Result<Vec<JsonRow>, _> = sql_query(&query).get_results(&mut conn);
|
||||
let rows: Result<Vec<JsonRow>, _> = sql_query(&query)
|
||||
.bind::<diesel::sql_types::Text, _>(&safe_search)
|
||||
.get_results(&mut conn);
|
||||
|
||||
match rows {
|
||||
Ok(data) => {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ pub fn register_enhanced_llm_keyword(state: Arc<AppState>, user: UserSession, en
|
|||
|
||||
tokio::spawn(async move {
|
||||
let router = SmartLLMRouter::new(state_for_spawn);
|
||||
let goal = OptimizationGoal::from_str(&optimization);
|
||||
let goal = OptimizationGoal::from_str_name(&optimization);
|
||||
|
||||
match crate::llm::smart_router::enhanced_llm_call(
|
||||
&router, &prompt, goal, None, None,
|
||||
|
|
|
|||
|
|
@ -71,14 +71,12 @@ async fn share_bot_memory(
|
|||
|
||||
let target_bot_uuid = find_bot_by_name(&mut conn, target_bot_name)?;
|
||||
|
||||
let memory_value = match bot_memories::table
|
||||
let memory_value = bot_memories::table
|
||||
.filter(bot_memories::bot_id.eq(source_bot_uuid))
|
||||
.filter(bot_memories::key.eq(memory_key))
|
||||
.select(bot_memories::value)
|
||||
.first(&mut conn) {
|
||||
Ok(value) => value,
|
||||
Err(_) => String::new(),
|
||||
};
|
||||
.first(&mut conn)
|
||||
.unwrap_or_default();
|
||||
|
||||
let shared_memory = BotSharedMemory {
|
||||
id: Uuid::new_v4(),
|
||||
|
|
|
|||
|
|
@ -254,7 +254,7 @@ impl FaceApiService {
|
|||
|
||||
Ok(FaceVerificationResult::match_found(
|
||||
result.confidence,
|
||||
options.confidence_threshold as f64,
|
||||
options.confidence_threshold,
|
||||
0,
|
||||
).with_face_ids(face1_id, face2_id))
|
||||
}
|
||||
|
|
@ -783,7 +783,7 @@ impl FaceApiService {
|
|||
// Simulate detection based on image size/content
|
||||
// In production, actual detection algorithms would be used
|
||||
let num_faces = if image_bytes.len() > 100_000 {
|
||||
(image_bytes.len() / 500_000).min(5).max(1)
|
||||
(image_bytes.len() / 500_000).clamp(1, 5)
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
|
@ -821,7 +821,7 @@ impl FaceApiService {
|
|||
attributes: if options.return_attributes.unwrap_or(false) {
|
||||
Some(FaceAttributes {
|
||||
age: Some(25.0 + (face_id.as_u128() % 40) as f32),
|
||||
gender: Some(if face_id.as_u128() % 2 == 0 {
|
||||
gender: Some(if face_id.as_u128().is_multiple_of(2) {
|
||||
Gender::Male
|
||||
} else {
|
||||
Gender::Female
|
||||
|
|
|
|||
|
|
@ -71,6 +71,22 @@ pub fn get_keyword(state: Arc<AppState>, user_session: UserSession, engine: &mut
|
|||
}
|
||||
fn is_safe_path(path: &str) -> bool {
|
||||
if path.starts_with("https://") || path.starts_with("http://") {
|
||||
if let Ok(parsed_url) = url::Url::parse(path) {
|
||||
if let Some(host) = parsed_url.host_str() {
|
||||
let host_lower = host.to_lowercase();
|
||||
if host_lower == "localhost"
|
||||
|| host_lower.contains("169.254")
|
||||
|| host_lower.starts_with("127.")
|
||||
|| host_lower.starts_with("10.")
|
||||
|| host_lower.starts_with("192.168.")
|
||||
|| host_lower.starts_with("172.")
|
||||
|| host_lower == "::1"
|
||||
|| host_lower.contains("0x7f")
|
||||
|| host_lower.contains("metadata.google.internal") {
|
||||
return false; // Prevent obvious SSRF
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if path.contains("..") || path.starts_with('/') {
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ impl Default for McpConnection {
|
|||
fn default() -> Self {
|
||||
Self {
|
||||
connection_type: ConnectionType::Http,
|
||||
url: "http://localhost:8080".to_string(),
|
||||
url: "http://localhost:9000".to_string(),
|
||||
port: None,
|
||||
timeout_seconds: 30,
|
||||
max_retries: 3,
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ pub mod string_functions;
|
|||
pub mod switch_case;
|
||||
pub mod table_access;
|
||||
pub mod table_definition;
|
||||
pub mod table_migration;
|
||||
pub mod universal_messaging;
|
||||
pub mod use_tool;
|
||||
pub mod use_website;
|
||||
|
|
|
|||
|
|
@ -405,6 +405,30 @@ pub fn filter_write_fields(
|
|||
}
|
||||
}
|
||||
|
||||
/// Get column names for a table from the database schema
|
||||
pub fn get_table_columns(conn: &mut PgConnection, table_name: &str) -> Vec<String> {
|
||||
use diesel::prelude::*;
|
||||
use diesel::sql_types::Text;
|
||||
|
||||
// Define a struct for the query result
|
||||
#[derive(diesel::QueryableByName)]
|
||||
struct ColumnName {
|
||||
#[diesel(sql_type = Text)]
|
||||
column_name: String,
|
||||
}
|
||||
|
||||
// Query information_schema to get column names
|
||||
diesel::sql_query(
|
||||
"SELECT column_name FROM information_schema.columns WHERE table_name = $1 ORDER BY ordinal_position"
|
||||
)
|
||||
.bind::<Text, _>(table_name)
|
||||
.load::<ColumnName>(conn)
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|c| c.column_name)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -325,7 +325,7 @@ fn parse_field_definition(
|
|||
})
|
||||
}
|
||||
|
||||
fn map_type_to_sql(field: &FieldDefinition, driver: &str) -> String {
|
||||
pub fn map_type_to_sql(field: &FieldDefinition, driver: &str) -> String {
|
||||
let base_type = match field.field_type.as_str() {
|
||||
"string" => {
|
||||
let len = field.length.unwrap_or(255);
|
||||
|
|
@ -630,6 +630,28 @@ pub fn process_table_definitions(
|
|||
return Ok(tables);
|
||||
}
|
||||
|
||||
// Use schema sync for both debug and release builds (non-destructive)
|
||||
use super::table_migration::sync_bot_tables;
|
||||
|
||||
info!("Running schema migration sync (non-destructive)");
|
||||
|
||||
match sync_bot_tables(&state, bot_id, source) {
|
||||
Ok(result) => {
|
||||
info!("Schema sync completed: {} created, {} altered, {} columns added",
|
||||
result.tables_created, result.tables_altered, result.columns_added);
|
||||
|
||||
// If sync was successful, skip standard table creation
|
||||
if result.tables_created > 0 || result.tables_altered > 0 {
|
||||
return Ok(tables);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Schema sync failed: {}", e);
|
||||
// Fall through to standard table creation
|
||||
}
|
||||
}
|
||||
|
||||
// Standard table creation (for release builds or as fallback)
|
||||
for table in &tables {
|
||||
info!(
|
||||
"Processing TABLE {} ON {}",
|
||||
|
|
|
|||
243
src/basic/keywords/table_migration.rs
Normal file
243
src/basic/keywords/table_migration.rs
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
/*****************************************************************************\
|
||||
| Table Schema Migration Module
|
||||
| Automatically syncs table.bas definitions with database schema
|
||||
\*****************************************************************************/
|
||||
|
||||
use crate::core::shared::sanitize_identifier;
|
||||
use crate::core::shared::state::AppState;
|
||||
use diesel::prelude::*;
|
||||
use diesel::sql_query;
|
||||
use diesel::sql_types::Text;
|
||||
use log::{error, info, warn};
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::table_definition::{FieldDefinition, TableDefinition, map_type_to_sql, parse_table_definition};
|
||||
|
||||
/// Schema migration result
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MigrationResult {
|
||||
pub tables_created: usize,
|
||||
pub tables_altered: usize,
|
||||
pub columns_added: usize,
|
||||
pub errors: Vec<String>,
|
||||
}
|
||||
|
||||
/// Column metadata from database
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DbColumn {
|
||||
pub name: String,
|
||||
pub data_type: String,
|
||||
pub is_nullable: bool,
|
||||
}
|
||||
|
||||
/// Compare and sync table schema with definition
|
||||
pub fn sync_table_schema(
|
||||
table: &TableDefinition,
|
||||
existing_columns: &[DbColumn],
|
||||
create_sql: &str,
|
||||
conn: &mut diesel::PgConnection,
|
||||
) -> Result<MigrationResult, Box<dyn Error + Send + Sync>> {
|
||||
let mut result = MigrationResult::default();
|
||||
|
||||
// If no columns exist, create the table
|
||||
if existing_columns.is_empty() {
|
||||
info!("Creating new table: {}", table.name);
|
||||
sql_query(create_sql).execute(conn)
|
||||
.map_err(|e| format!("Failed to create table {}: {}", table.name, e))?;
|
||||
result.tables_created += 1;
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
// Check for schema drift
|
||||
let existing_col_names: std::collections::HashSet<String> =
|
||||
existing_columns.iter().map(|c| c.name.clone()).collect();
|
||||
|
||||
let mut missing_columns: Vec<&FieldDefinition> = Vec::new();
|
||||
for field in &table.fields {
|
||||
if !existing_col_names.contains(&field.name) {
|
||||
missing_columns.push(field);
|
||||
}
|
||||
}
|
||||
|
||||
// Add missing columns
|
||||
if !missing_columns.is_empty() {
|
||||
info!("Table {} is missing {} columns, adding them", table.name, missing_columns.len());
|
||||
|
||||
for field in &missing_columns {
|
||||
let sql_type = map_type_to_sql(field, "postgres");
|
||||
let column_sql = if field.is_nullable {
|
||||
format!("ALTER TABLE {} ADD COLUMN IF NOT EXISTS {} {}",
|
||||
sanitize_identifier(&table.name),
|
||||
sanitize_identifier(&field.name),
|
||||
sql_type)
|
||||
} else {
|
||||
// For NOT NULL columns, add as nullable first then set default
|
||||
format!("ALTER TABLE {} ADD COLUMN IF NOT EXISTS {} {}",
|
||||
sanitize_identifier(&table.name),
|
||||
sanitize_identifier(&field.name),
|
||||
sql_type)
|
||||
};
|
||||
|
||||
info!("Adding column: {}.{} ({})", table.name, field.name, sql_type);
|
||||
match sql_query(&column_sql).execute(conn) {
|
||||
Ok(_) => {
|
||||
result.columns_added += 1;
|
||||
info!("Successfully added column {}.{}", table.name, field.name);
|
||||
}
|
||||
Err(e) => {
|
||||
// Check if column already exists (ignore error)
|
||||
let err_str = e.to_string();
|
||||
if !err_str.contains("already exists") && !err_str.contains("duplicate column") {
|
||||
let error_msg = format!("Failed to add column {}.{}: {}", table.name, field.name, e);
|
||||
error!("{}", error_msg);
|
||||
result.errors.push(error_msg);
|
||||
} else {
|
||||
info!("Column {}.{} already exists, skipping", table.name, field.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result.tables_altered += 1;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get existing columns from a table
|
||||
pub fn get_table_columns(
|
||||
table_name: &str,
|
||||
conn: &mut diesel::PgConnection,
|
||||
) -> Result<Vec<DbColumn>, Box<dyn Error + Send + Sync>> {
|
||||
let query = format!(
|
||||
"SELECT column_name, data_type, is_nullable
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = '{}' AND table_schema = 'public'
|
||||
ORDER BY ordinal_position",
|
||||
sanitize_identifier(table_name)
|
||||
);
|
||||
|
||||
#[derive(QueryableByName)]
|
||||
struct ColumnRow {
|
||||
#[diesel(sql_type = Text)]
|
||||
column_name: String,
|
||||
#[diesel(sql_type = Text)]
|
||||
data_type: String,
|
||||
#[diesel(sql_type = Text)]
|
||||
is_nullable: String,
|
||||
}
|
||||
|
||||
let rows: Vec<ColumnRow> = match sql_query(&query).load(conn) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
// Table doesn't exist
|
||||
return Err(format!("Table {} does not exist: {}", table_name, e).into());
|
||||
}
|
||||
};
|
||||
|
||||
Ok(rows.into_iter().map(|row| DbColumn {
|
||||
name: row.column_name,
|
||||
data_type: row.data_type,
|
||||
is_nullable: row.is_nullable == "YES",
|
||||
}).collect())
|
||||
}
|
||||
|
||||
/// Process table definitions with schema sync for a specific bot
|
||||
pub fn sync_bot_tables(
|
||||
state: &Arc<AppState>,
|
||||
bot_id: Uuid,
|
||||
source: &str,
|
||||
) -> Result<MigrationResult, Box<dyn Error + Send + Sync>> {
|
||||
let tables = parse_table_definition(source)?;
|
||||
let mut result = MigrationResult::default();
|
||||
|
||||
info!("Processing {} table definitions with schema sync for bot {}", tables.len(), bot_id);
|
||||
|
||||
// Get bot's database connection
|
||||
let pool = state.bot_database_manager.get_bot_pool(bot_id)?;
|
||||
let mut conn = pool.get()?;
|
||||
|
||||
for table in &tables {
|
||||
if table.connection_name != "default" {
|
||||
continue; // Skip external connections for now
|
||||
}
|
||||
|
||||
info!("Syncing table: {}", table.name);
|
||||
|
||||
// Get existing columns
|
||||
let existing_columns = get_table_columns(&table.name, &mut conn).unwrap_or_default();
|
||||
|
||||
// Generate CREATE TABLE SQL
|
||||
let create_sql = super::table_definition::generate_create_table_sql(table, "postgres");
|
||||
|
||||
// Sync schema
|
||||
match sync_table_schema(table, &existing_columns, &create_sql, &mut conn) {
|
||||
Ok(sync_result) => {
|
||||
result.tables_created += sync_result.tables_created;
|
||||
result.tables_altered += sync_result.tables_altered;
|
||||
result.columns_added += sync_result.columns_added;
|
||||
result.errors.extend(sync_result.errors);
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = format!("Failed to sync table {}: {}", table.name, e);
|
||||
error!("{}", error_msg);
|
||||
result.errors.push(error_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log summary
|
||||
info!("Schema sync summary for bot {}: {} tables created, {} altered, {} columns added, {} errors",
|
||||
bot_id, result.tables_created, result.tables_altered, result.columns_added, result.errors.len());
|
||||
|
||||
if !result.errors.is_empty() {
|
||||
warn!("Schema sync completed with {} errors:", result.errors.len());
|
||||
for error in &result.errors {
|
||||
warn!(" - {}", error);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Validate that all required columns exist
|
||||
pub fn validate_table_schema(
|
||||
table_name: &str,
|
||||
required_fields: &[FieldDefinition],
|
||||
conn: &mut diesel::PgConnection,
|
||||
) -> Result<bool, Box<dyn Error + Send + Sync>> {
|
||||
let existing_columns = get_table_columns(table_name, conn)?;
|
||||
let existing_col_names: std::collections::HashSet<String> =
|
||||
existing_columns.iter().map(|c| c.name.clone()).collect();
|
||||
|
||||
let mut missing = Vec::new();
|
||||
for field in required_fields {
|
||||
if !existing_col_names.contains(&field.name) {
|
||||
missing.push(field.name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if !missing.is_empty() {
|
||||
warn!("Table {} is missing columns: {:?}", table_name, missing);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_db_column_creation() {
|
||||
let col = DbColumn {
|
||||
name: "test_col".to_string(),
|
||||
data_type: "character varying".to_string(),
|
||||
is_nullable: true,
|
||||
};
|
||||
assert_eq!(col.name, "test_col");
|
||||
assert_eq!(col.is_nullable, true);
|
||||
}
|
||||
}
|
||||
|
|
@ -26,7 +26,7 @@ pub fn use_tool_keyword(state: Arc<AppState>, user: UserSession, engine: &mut En
|
|||
tool_path_str.as_str()
|
||||
}
|
||||
.strip_suffix(".bas")
|
||||
.unwrap_or_else(|| tool_path_str.as_str())
|
||||
.unwrap_or(tool_path_str.as_str())
|
||||
.to_string();
|
||||
if tool_name.is_empty() {
|
||||
return Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
||||
|
|
|
|||
|
|
@ -826,7 +826,7 @@ mod tests {
|
|||
"docs_example_com_path"
|
||||
);
|
||||
assert_eq!(
|
||||
sanitize_url_for_collection("http://test.site:8080"),
|
||||
sanitize_url_for_collection("http://test.site:9000"),
|
||||
"test_site_8080"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
782
src/basic/mod.rs
782
src/basic/mod.rs
|
|
@ -186,6 +186,9 @@ impl ScriptService {
|
|||
register_string_functions(state.clone(), user.clone(), &mut engine);
|
||||
switch_keyword(&state, user.clone(), &mut engine);
|
||||
register_http_operations(state.clone(), user.clone(), &mut engine);
|
||||
// Register SAVE FROM UNSTRUCTURED before regular SAVE to avoid pattern conflicts
|
||||
#[cfg(feature = "llm")]
|
||||
save_from_unstructured_keyword(state.clone(), user.clone(), &mut engine);
|
||||
register_data_operations(state.clone(), user.clone(), &mut engine);
|
||||
#[cfg(feature = "automation")]
|
||||
webhook_keyword(&state, user.clone(), &mut engine);
|
||||
|
|
@ -223,7 +226,6 @@ impl ScriptService {
|
|||
register_model_routing_keywords(state.clone(), user.clone(), &mut engine);
|
||||
register_multimodal_keywords(state.clone(), user.clone(), &mut engine);
|
||||
remember_keyword(state.clone(), user.clone(), &mut engine);
|
||||
save_from_unstructured_keyword(state.clone(), user.clone(), &mut engine);
|
||||
}
|
||||
|
||||
// Register USE WEBSITE after all other USE keywords to avoid conflicts
|
||||
|
|
@ -579,6 +581,7 @@ impl ScriptService {
|
|||
trimmed.starts_with("DESCRIPTION ") ||
|
||||
trimmed.starts_with("DESCRIPTION\t") ||
|
||||
trimmed.starts_with('\'') || // BASIC comment lines
|
||||
trimmed.starts_with('#') || // Hash comment lines
|
||||
trimmed.is_empty())
|
||||
})
|
||||
.collect::<Vec<&str>>()
|
||||
|
|
@ -589,10 +592,16 @@ impl ScriptService {
|
|||
// Apply minimal preprocessing for tools (skip variable normalization to avoid breaking multi-line strings)
|
||||
let script = preprocess_switch(&executable_script);
|
||||
let script = Self::convert_multiword_keywords(&script);
|
||||
// Convert FORMAT(expr, pattern) to FORMAT expr pattern for Rhai space-separated function syntax
|
||||
// FORMAT syntax conversion disabled - Rhai supports comma-separated args natively
|
||||
// let script = Self::convert_format_syntax(&script);
|
||||
// Skip normalize_variables_to_lowercase for tools - it breaks multi-line strings
|
||||
// Note: FORMAT is registered as a regular function, so FORMAT(expr, pattern) works directly
|
||||
|
||||
info!("[TOOL] Preprocessed tool script for Rhai compilation");
|
||||
// Convert SAVE statements with field lists to map-based SAVE (simplified version for tools)
|
||||
let script = Self::convert_save_for_tools(&script);
|
||||
// Convert BEGIN TALK and BEGIN MAIL blocks to single calls
|
||||
let script = crate::basic::compiler::blocks::convert_begin_blocks(&script);
|
||||
// Convert IF ... THEN / END IF to if ... { }
|
||||
let script = Self::convert_if_then_syntax(&script);
|
||||
// Convert SELECT ... CASE / END SELECT to match expressions
|
||||
|
|
@ -612,6 +621,89 @@ impl ScriptService {
|
|||
self.engine.eval_ast_with_scope(&mut self.scope, ast)
|
||||
}
|
||||
|
||||
/// Convert SAVE statements for tool compilation (simplified, no DB lookup)
|
||||
/// SAVE "table", var1, var2, ... -> let __data__ = #{var1: var1, var2: var2, ...}; SAVE "table", __data__
|
||||
fn convert_save_for_tools(script: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut save_counter = 0;
|
||||
|
||||
for line in script.lines() {
|
||||
let trimmed = line.trim();
|
||||
|
||||
// Check if this is a SAVE statement
|
||||
if trimmed.to_uppercase().starts_with("SAVE ") {
|
||||
// Parse SAVE statement
|
||||
// Format: SAVE "table", value1, value2, ...
|
||||
let content = &trimmed[4..].trim();
|
||||
|
||||
// Simple parse by splitting on commas (outside quotes)
|
||||
let parts = Self::parse_save_parts(content);
|
||||
|
||||
// If more than 2 parts, convert to map-based SAVE
|
||||
if parts.len() > 2 {
|
||||
let table_name = parts[0].trim_matches('"');
|
||||
let values: Vec<&str> = parts.iter().skip(1).map(|s| s.trim()).collect();
|
||||
|
||||
// Build map with variable names as keys
|
||||
let map_pairs: Vec<String> = values.iter().map(|v| format!("{}: {}", v, v)).collect();
|
||||
let map_expr = format!("#{{{}}}", map_pairs.join(", "));
|
||||
let data_var = format!("__save_data_{}__", save_counter);
|
||||
save_counter += 1;
|
||||
|
||||
let converted = format!("let {} = {};\nINSERT \"{}\", {};", data_var, map_expr, table_name, data_var);
|
||||
result.push_str(&converted);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Parse SAVE statement parts (handles quoted strings)
|
||||
fn parse_save_parts(s: &str) -> Vec<String> {
|
||||
let mut parts = Vec::new();
|
||||
let mut current = String::new();
|
||||
let mut in_quotes = false;
|
||||
let mut chars = s.chars().peekable();
|
||||
|
||||
while let Some(c) = chars.next() {
|
||||
match c {
|
||||
'"' if !in_quotes => {
|
||||
in_quotes = true;
|
||||
current.push(c);
|
||||
}
|
||||
'"' if in_quotes => {
|
||||
in_quotes = false;
|
||||
current.push(c);
|
||||
}
|
||||
',' if !in_quotes => {
|
||||
parts.push(current.trim().to_string());
|
||||
current = String::new();
|
||||
// Skip whitespace after comma
|
||||
while let Some(&next_c) = chars.peek() {
|
||||
if next_c.is_whitespace() {
|
||||
chars.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => current.push(c),
|
||||
}
|
||||
}
|
||||
|
||||
if !current.is_empty() {
|
||||
parts.push(current.trim().to_string());
|
||||
}
|
||||
|
||||
parts
|
||||
}
|
||||
|
||||
/// Set a variable in the script scope (for tool parameters)
|
||||
pub fn set_variable(&mut self, name: &str, value: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
use rhai::Dynamic;
|
||||
|
|
@ -619,32 +711,380 @@ impl ScriptService {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
/// Convert FORMAT(expr, pattern) to FORMAT expr pattern (custom syntax format)
|
||||
/// Also handles RANDOM and other functions that need space-separated arguments
|
||||
/// This properly handles nested function calls by counting parentheses
|
||||
#[allow(dead_code)]
|
||||
fn convert_format_syntax(script: &str) -> String {
|
||||
use regex::Regex;
|
||||
let mut result = script.to_string();
|
||||
let mut result = String::new();
|
||||
let mut chars = script.chars().peekable();
|
||||
let mut i = 0;
|
||||
let bytes = script.as_bytes();
|
||||
|
||||
// First, process RANDOM to ensure commas are preserved
|
||||
// RANDOM(min, max) stays as RANDOM(min, max) - no conversion needed
|
||||
while i < bytes.len() {
|
||||
// Check if this is the start of FORMAT(
|
||||
if i + 6 <= bytes.len()
|
||||
&& bytes[i..i+6].eq_ignore_ascii_case(b"FORMAT")
|
||||
&& i + 7 < bytes.len()
|
||||
&& bytes[i + 6] == b'('
|
||||
{
|
||||
// Found FORMAT( - now parse the arguments
|
||||
let mut paren_depth = 1;
|
||||
let mut j = i + 7; // Start after FORMAT(
|
||||
let mut comma_pos = None;
|
||||
|
||||
// Convert FORMAT(expr, pattern) → FORMAT expr pattern
|
||||
// Need to handle nested functions carefully
|
||||
// Match FORMAT( ... ) but don't include inner function parentheses
|
||||
// This regex matches FORMAT followed by parentheses containing two comma-separated expressions
|
||||
if let Ok(re) = Regex::new(r"(?i)FORMAT\s*\(([^()]+(?:\([^()]*\)[^()]*)*),([^)]+)\)") {
|
||||
result = re.replace_all(&result, "FORMAT $1$2").to_string();
|
||||
// Find the arguments by tracking parentheses
|
||||
while j < bytes.len() && paren_depth > 0 {
|
||||
match bytes[j] {
|
||||
b'(' => paren_depth += 1,
|
||||
b')' => {
|
||||
paren_depth -= 1;
|
||||
if paren_depth == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
b',' => {
|
||||
if paren_depth == 1 {
|
||||
// This is the comma separating FORMAT's arguments
|
||||
comma_pos = Some(j);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
j += 1;
|
||||
}
|
||||
|
||||
if let Some(comma) = comma_pos {
|
||||
// Extract the two arguments
|
||||
let expr = &script[i + 7..comma].trim();
|
||||
let pattern = &script[comma + 1..j].trim();
|
||||
|
||||
// Convert to Rhai space-separated syntax
|
||||
// Remove quotes from pattern if present, then add them back in the right format
|
||||
let pattern_clean = pattern.trim_matches('"').trim_matches('\'');
|
||||
result.push_str(&format!("FORMAT ({expr}) (\"{pattern_clean}\")"));
|
||||
|
||||
i = j + 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the character as-is
|
||||
if let Some(c) = chars.next() {
|
||||
result.push(c);
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Convert a single TALK line with ${variable} substitution to proper TALK syntax
|
||||
/// Handles: "Hello ${name}" → TALK "Hello " + name
|
||||
/// Also handles: "Plain text" → TALK "Plain text"
|
||||
/// Also handles function calls: "Value: ${FORMAT(x, "n")}" → TALK "Value: " + FORMAT(x, "n")
|
||||
fn convert_talk_line_with_substitution(line: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut chars = line.chars().peekable();
|
||||
let mut in_substitution = false;
|
||||
let mut current_expr = String::new();
|
||||
let mut current_literal = String::new();
|
||||
|
||||
while let Some(c) = chars.next() {
|
||||
match c {
|
||||
'$' => {
|
||||
if let Some(&'{') = chars.peek() {
|
||||
// Start of ${...} substitution
|
||||
chars.next(); // consume '{'
|
||||
|
||||
// Add accumulated literal as a string if non-empty
|
||||
if !current_literal.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push_str("TALK \"");
|
||||
} else {
|
||||
result.push_str(" + \"");
|
||||
}
|
||||
// Escape any quotes in the literal
|
||||
let escaped = current_literal.replace('"', "\\\"");
|
||||
result.push_str(&escaped);
|
||||
result.push('"');
|
||||
current_literal.clear();
|
||||
}
|
||||
|
||||
in_substitution = true;
|
||||
current_expr.clear();
|
||||
} else {
|
||||
// Regular $ character, add to literal
|
||||
current_literal.push(c);
|
||||
}
|
||||
}
|
||||
'}' if in_substitution => {
|
||||
// End of ${...} substitution
|
||||
in_substitution = false;
|
||||
|
||||
// Add the expression (variable or function call)
|
||||
if !current_expr.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push_str(¤t_expr);
|
||||
} else {
|
||||
result.push_str(" + ");
|
||||
result.push_str(¤t_expr);
|
||||
}
|
||||
}
|
||||
current_expr.clear();
|
||||
}
|
||||
_ if in_substitution => {
|
||||
// Collect expression content, tracking parentheses and quotes
|
||||
// This handles function calls like FORMAT(x, "pattern")
|
||||
current_expr.push(c);
|
||||
|
||||
// Track nested parentheses and quoted strings
|
||||
let mut paren_depth: i32 = 0;
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
for ch in current_expr.chars() {
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match ch {
|
||||
'\\' => {
|
||||
escape_next = true;
|
||||
}
|
||||
'"' if !in_string => {
|
||||
in_string = true;
|
||||
}
|
||||
'"' if in_string => {
|
||||
in_string = false;
|
||||
}
|
||||
'(' if !in_string => {
|
||||
paren_depth += 1;
|
||||
}
|
||||
')' if !in_string => {
|
||||
paren_depth = paren_depth.saturating_sub(1);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Continue collecting expression until we're back at depth 0
|
||||
// The closing '}' will handle the end of substitution
|
||||
}
|
||||
_ => {
|
||||
// Regular character, add to literal
|
||||
current_literal.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add any remaining literal
|
||||
if !current_literal.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push_str("TALK \"");
|
||||
} else {
|
||||
result.push_str(" + \"");
|
||||
}
|
||||
let escaped = current_literal.replace('"', "\\\"");
|
||||
result.push_str(&escaped);
|
||||
result.push('"');
|
||||
}
|
||||
|
||||
// If result is empty (shouldn't happen), just return a TALK with empty string
|
||||
if result.is_empty() {
|
||||
result = "TALK \"\"".to_string();
|
||||
}
|
||||
|
||||
log::debug!("[TOOL] Converted TALK line: '{}' → '{}'", line, result);
|
||||
result
|
||||
}
|
||||
|
||||
/// Convert a BEGIN MAIL ... END MAIL block to SEND EMAIL call
|
||||
/// Handles multi-line emails with ${variable} substitution
|
||||
/// Uses intermediate variables to reduce expression complexity
|
||||
/// Format:
|
||||
/// BEGIN MAIL recipient
|
||||
/// Subject: Email subject here
|
||||
///
|
||||
/// Body line 1 with ${variable}
|
||||
/// Body line 2 with ${anotherVariable}
|
||||
/// END MAIL
|
||||
fn convert_mail_block(recipient: &str, lines: &[String]) -> String {
|
||||
let mut subject = String::new();
|
||||
let mut body_lines: Vec<String> = Vec::new();
|
||||
// let mut in_subject = true; // Removed unused variable
|
||||
let mut skip_blank = true;
|
||||
|
||||
for line in lines.iter() {
|
||||
// Check if this line is a subject line
|
||||
if line.to_uppercase().starts_with("SUBJECT:") {
|
||||
subject = line[8..].trim().to_string();
|
||||
// in_subject = false; // Removed unused assignment
|
||||
skip_blank = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip blank lines after subject
|
||||
if skip_blank && line.trim().is_empty() {
|
||||
skip_blank = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
skip_blank = false;
|
||||
|
||||
// Process body line with ${} substitution
|
||||
let converted = Self::convert_mail_line_with_substitution(line);
|
||||
body_lines.push(converted);
|
||||
}
|
||||
|
||||
// Generate code that builds the email body using intermediate variables
|
||||
// This reduces expression complexity for Rhai parser
|
||||
let mut result = String::new();
|
||||
|
||||
// Create intermediate variables for body chunks (max 5 lines per variable to keep complexity low)
|
||||
let chunk_size = 5;
|
||||
let mut all_vars: Vec<String> = Vec::new();
|
||||
|
||||
for (var_count, chunk) in body_lines.chunks(chunk_size).enumerate() {
|
||||
let var_name = format!("__mail_body_{}__", var_count);
|
||||
all_vars.push(var_name.clone());
|
||||
|
||||
if chunk.len() == 1 {
|
||||
result.push_str(&format!("let {} = {};\n", var_name, chunk[0]));
|
||||
} else {
|
||||
let mut chunk_expr = chunk[0].clone();
|
||||
for line in &chunk[1..] {
|
||||
chunk_expr.push_str(" + \"\\n\" + ");
|
||||
chunk_expr.push_str(line);
|
||||
}
|
||||
result.push_str(&format!("let {} = {};\n", var_name, chunk_expr));
|
||||
}
|
||||
}
|
||||
|
||||
// Combine all chunks into final body
|
||||
let body_expr = if all_vars.is_empty() {
|
||||
"\"\"".to_string()
|
||||
} else if all_vars.len() == 1 {
|
||||
all_vars[0].clone()
|
||||
} else {
|
||||
let mut expr = all_vars[0].clone();
|
||||
for var in &all_vars[1..] {
|
||||
expr.push_str(" + \"\\n\" + ");
|
||||
expr.push_str(var);
|
||||
}
|
||||
expr
|
||||
};
|
||||
|
||||
// Generate the send_mail function call
|
||||
// If recipient contains '@', it's a string literal and needs to be quoted
|
||||
// Otherwise, it's a variable name and should be used as-is
|
||||
let recipient_expr = if recipient.contains('@') {
|
||||
format!("\"{}\"", recipient)
|
||||
} else {
|
||||
recipient.to_string()
|
||||
};
|
||||
result.push_str(&format!("send_mail({}, \"{}\", {}, []);\n", recipient_expr, subject, body_expr));
|
||||
|
||||
log::info!("[TOOL] Converted MAIL block → {}", result);
|
||||
result
|
||||
}
|
||||
|
||||
/// Convert a single mail line with ${variable} substitution to string concatenation
|
||||
/// Similar to TALK substitution but doesn't add "TALK" prefix
|
||||
fn convert_mail_line_with_substitution(line: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut chars = line.chars().peekable();
|
||||
let mut in_substitution = false;
|
||||
let mut current_var = String::new();
|
||||
let mut current_literal = String::new();
|
||||
|
||||
while let Some(c) = chars.next() {
|
||||
match c {
|
||||
'$' => {
|
||||
if let Some(&'{') = chars.peek() {
|
||||
// Start of ${...} substitution
|
||||
chars.next(); // consume '{'
|
||||
|
||||
// Add accumulated literal as a string if non-empty
|
||||
if !current_literal.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push('"');
|
||||
result.push_str(¤t_literal.replace('"', "\\\""));
|
||||
result.push('"');
|
||||
} else {
|
||||
result.push_str(" + \"");
|
||||
result.push_str(¤t_literal.replace('"', "\\\""));
|
||||
result.push('"');
|
||||
}
|
||||
current_literal.clear();
|
||||
}
|
||||
|
||||
in_substitution = true;
|
||||
current_var.clear();
|
||||
} else {
|
||||
// Regular $ character, add to literal
|
||||
current_literal.push(c);
|
||||
}
|
||||
}
|
||||
'}' if in_substitution => {
|
||||
// End of ${...} substitution
|
||||
in_substitution = false;
|
||||
|
||||
// Add the variable name
|
||||
if !current_var.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push_str(¤t_var);
|
||||
} else {
|
||||
result.push_str(" + ");
|
||||
result.push_str(¤t_var);
|
||||
}
|
||||
}
|
||||
current_var.clear();
|
||||
}
|
||||
_ if in_substitution => {
|
||||
// Collect variable name (allow alphanumeric, underscore, and function call syntax)
|
||||
if c.is_alphanumeric() || c == '_' || c == '(' || c == ')' || c == ',' || c == ' ' || c == '\"' {
|
||||
current_var.push(c);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Regular character, add to literal
|
||||
if !in_substitution {
|
||||
current_literal.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add any remaining literal
|
||||
if !current_literal.is_empty() {
|
||||
if result.is_empty() {
|
||||
result.push('"');
|
||||
result.push_str(¤t_literal.replace('"', "\\\""));
|
||||
result.push('"');
|
||||
} else {
|
||||
result.push_str(" + \"");
|
||||
result.push_str(¤t_literal.replace('"', "\\\""));
|
||||
result.push('"');
|
||||
}
|
||||
}
|
||||
|
||||
log::debug!("[TOOL] Converted mail line: '{}' → '{}'", line, result);
|
||||
result
|
||||
}
|
||||
|
||||
/// Convert BASIC IF ... THEN / END IF syntax to Rhai's if ... { } syntax
|
||||
fn convert_if_then_syntax(script: &str) -> String {
|
||||
pub fn convert_if_then_syntax(script: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut if_stack: Vec<bool> = Vec::new();
|
||||
let mut in_with_block = false;
|
||||
let mut in_talk_block = false;
|
||||
let mut talk_block_lines: Vec<String> = Vec::new();
|
||||
let mut in_mail_block = false;
|
||||
let mut mail_recipient = String::new();
|
||||
let mut mail_block_lines: Vec<String> = Vec::new();
|
||||
let mut in_line_continuation = false;
|
||||
|
||||
log::info!("[TOOL] Converting IF/THEN syntax, input has {} lines", script.lines().count());
|
||||
|
||||
|
|
@ -653,7 +1093,7 @@ impl ScriptService {
|
|||
let upper = trimmed.to_uppercase();
|
||||
|
||||
// Skip empty lines and comments
|
||||
if trimmed.is_empty() || trimmed.starts_with('\'') || trimmed.starts_with("//") {
|
||||
if trimmed.is_empty() || trimmed.starts_with('\'') || trimmed.starts_with('#') || trimmed.starts_with("//") {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -666,6 +1106,19 @@ impl ScriptService {
|
|||
let condition = &trimmed[3..then_pos].trim();
|
||||
// Convert BASIC "NOT IN" to Rhai "!in"
|
||||
let condition = condition.replace(" NOT IN ", " !in ").replace(" not in ", " !in ");
|
||||
// Convert BASIC "AND" to Rhai "&&" and "OR" to Rhai "||"
|
||||
let condition = condition.replace(" AND ", " && ").replace(" and ", " && ")
|
||||
.replace(" OR ", " || ").replace(" or ", " || ");
|
||||
// Convert BASIC "=" to Rhai "==" for comparisons in IF conditions
|
||||
// Skip if it's already a comparison operator (==, !=, <=, >=) or assignment (+=, -=, etc.)
|
||||
let condition = if !condition.contains("==") && !condition.contains("!=")
|
||||
&& !condition.contains("<=") && !condition.contains(">=")
|
||||
&& !condition.contains("+=") && !condition.contains("-=")
|
||||
&& !condition.contains("*=") && !condition.contains("/=") {
|
||||
condition.replace("=", "==")
|
||||
} else {
|
||||
condition.to_string()
|
||||
};
|
||||
log::info!("[TOOL] Converting IF statement: condition='{}'", condition);
|
||||
result.push_str("if ");
|
||||
result.push_str(&condition);
|
||||
|
|
@ -681,10 +1134,35 @@ impl ScriptService {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Handle ELSEIF ... THEN
|
||||
if upper.starts_with("ELSEIF ") && upper.contains(" THEN") {
|
||||
let then_pos = match upper.find(" THEN") {
|
||||
Some(pos) => pos,
|
||||
None => continue,
|
||||
};
|
||||
let condition = &trimmed[6..then_pos].trim();
|
||||
let condition = condition.replace(" NOT IN ", " !in ").replace(" not in ", " !in ");
|
||||
let condition = condition.replace(" AND ", " && ").replace(" and ", " && ")
|
||||
.replace(" OR ", " || ").replace(" or ", " || ");
|
||||
let condition = if !condition.contains("==") && !condition.contains("!=")
|
||||
&& !condition.contains("<=") && !condition.contains(">=")
|
||||
&& !condition.contains("+=") && !condition.contains("-=")
|
||||
&& !condition.contains("*=") && !condition.contains("/=") {
|
||||
condition.replace("=", "==")
|
||||
} else {
|
||||
condition.to_string()
|
||||
};
|
||||
log::info!("[TOOL] Converting ELSEIF statement: condition='{}'", condition);
|
||||
result.push_str("} else if ");
|
||||
result.push_str(&condition);
|
||||
result.push_str(" {\n");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle END IF
|
||||
if upper == "END IF" {
|
||||
log::info!("[TOOL] Converting END IF statement");
|
||||
if let Some(_) = if_stack.pop() {
|
||||
if if_stack.pop().is_some() {
|
||||
result.push_str("}\n");
|
||||
}
|
||||
continue;
|
||||
|
|
@ -709,6 +1187,85 @@ impl ScriptService {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Handle BEGIN TALK ... END TALK (multi-line TALK with ${} substitution)
|
||||
if upper == "BEGIN TALK" {
|
||||
log::info!("[TOOL] Converting BEGIN TALK statement");
|
||||
in_talk_block = true;
|
||||
talk_block_lines.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if upper == "END TALK" {
|
||||
log::info!("[TOOL] Converting END TALK statement, processing {} lines", talk_block_lines.len());
|
||||
in_talk_block = false;
|
||||
|
||||
// Split into multiple TALK statements to avoid expression complexity limit
|
||||
// Use chunks of 5 lines per TALK statement
|
||||
let chunk_size = 5;
|
||||
for chunk in talk_block_lines.chunks(chunk_size) {
|
||||
// Convert all talk lines in this chunk to a single TALK statement
|
||||
let mut combined_talk = String::new();
|
||||
for (i, talk_line) in chunk.iter().enumerate() {
|
||||
let converted = Self::convert_talk_line_with_substitution(talk_line);
|
||||
// Remove "TALK " prefix from converted line if present
|
||||
let line_content = if let Some(stripped) = converted.strip_prefix("TALK ") {
|
||||
stripped.trim().to_string()
|
||||
} else {
|
||||
converted
|
||||
};
|
||||
if i > 0 {
|
||||
combined_talk.push_str(" + \"\\n\" + ");
|
||||
}
|
||||
combined_talk.push_str(&line_content);
|
||||
}
|
||||
|
||||
// Generate TALK statement for this chunk
|
||||
result.push_str("TALK ");
|
||||
result.push_str(&combined_talk);
|
||||
result.push_str(";\n");
|
||||
}
|
||||
|
||||
talk_block_lines.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we're in a TALK block, collect lines
|
||||
if in_talk_block {
|
||||
// Skip empty lines but preserve them as blank TALK statements if needed
|
||||
talk_block_lines.push(trimmed.to_string());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle BEGIN MAIL ... END MAIL (multi-line email with ${} substitution)
|
||||
if upper.starts_with("BEGIN MAIL ") {
|
||||
let recipient = &trimmed[11..].trim(); // Skip "BEGIN MAIL "
|
||||
log::info!("[TOOL] Converting BEGIN MAIL statement: recipient='{}'", recipient);
|
||||
mail_recipient = recipient.to_string();
|
||||
in_mail_block = true;
|
||||
mail_block_lines.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if upper == "END MAIL" {
|
||||
log::info!("[TOOL] Converting END MAIL statement, processing {} lines", mail_block_lines.len());
|
||||
in_mail_block = false;
|
||||
|
||||
// Process the mail block and convert to SEND EMAIL
|
||||
let converted = Self::convert_mail_block(&mail_recipient, &mail_block_lines);
|
||||
result.push_str(&converted);
|
||||
result.push('\n');
|
||||
|
||||
mail_recipient.clear();
|
||||
mail_block_lines.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we're in a MAIL block, collect lines
|
||||
if in_mail_block {
|
||||
mail_block_lines.push(trimmed.to_string());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Inside a WITH block - convert property assignments (key = value → key: value)
|
||||
if in_with_block {
|
||||
// Check if this is a property assignment (identifier = value)
|
||||
|
|
@ -728,21 +1285,32 @@ impl ScriptService {
|
|||
result.push_str(" ");
|
||||
}
|
||||
|
||||
// Handle SAVE table, object → INSERT table, object
|
||||
// BASIC SAVE uses 2 parameters but Rhai SAVE needs 3
|
||||
// INSERT uses 2 parameters which matches the BASIC syntax
|
||||
// Handle SAVE table, field1, field2, ... → INSERT "table", #{field1: value1, field2: value2, ...}
|
||||
if upper.starts_with("SAVE") && upper.contains(',') {
|
||||
log::info!("[TOOL] Processing SAVE line: '{}'", trimmed);
|
||||
// Extract table and object name
|
||||
// Extract the part after "SAVE"
|
||||
let after_save = &trimmed[4..].trim(); // Skip "SAVE"
|
||||
let parts: Vec<&str> = after_save.split(',').collect();
|
||||
log::info!("[TOOL] SAVE parts: {:?}", parts);
|
||||
if parts.len() == 2 {
|
||||
|
||||
if parts.len() >= 2 {
|
||||
// First part is the table name (in quotes)
|
||||
let table = parts[0].trim().trim_matches('"');
|
||||
let object_name = parts[1].trim().trim_end_matches(';');
|
||||
// Convert to INSERT table, object
|
||||
let converted = format!("INSERT \"{}\", {};\n", table, object_name);
|
||||
log::info!("[TOOL] Converted SAVE to INSERT: '{}'", converted);
|
||||
|
||||
// For old WITH block syntax (parts.len() == 2), convert to INSERT with object name
|
||||
if parts.len() == 2 {
|
||||
let object_name = parts[1].trim().trim_end_matches(';');
|
||||
let converted = format!("INSERT \"{}\", {};\n", table, object_name);
|
||||
log::info!("[TOOL] Converted SAVE to INSERT (old syntax): '{}'", converted);
|
||||
result.push_str(&converted);
|
||||
continue;
|
||||
}
|
||||
|
||||
// For modern direct field list syntax (parts.len() > 2), just pass values as-is
|
||||
// The runtime SAVE handler will match them to database columns by position
|
||||
let values = parts[1..].join(", ");
|
||||
let converted = format!("SAVE \"{}\", {};\n", table, values);
|
||||
log::info!("[TOOL] Keeping SAVE syntax (modern): '{}'", converted);
|
||||
result.push_str(&converted);
|
||||
continue;
|
||||
}
|
||||
|
|
@ -776,7 +1344,7 @@ impl ScriptService {
|
|||
if !upper.starts_with("IF ") && !upper.starts_with("ELSE") && !upper.starts_with("END IF") {
|
||||
// Check if this is a variable assignment (identifier = expression)
|
||||
// Pattern: starts with letter/underscore, contains = but not ==, !=, <=, >=, +=, -=
|
||||
let is_var_assignment = trimmed.chars().next().map_or(false, |c| c.is_alphabetic() || c == '_')
|
||||
let is_var_assignment = trimmed.chars().next().is_some_and(|c| c.is_alphabetic() || c == '_')
|
||||
&& trimmed.contains('=')
|
||||
&& !trimmed.contains("==")
|
||||
&& !trimmed.contains("!=")
|
||||
|
|
@ -787,16 +1355,42 @@ impl ScriptService {
|
|||
&& !trimmed.contains("*=")
|
||||
&& !trimmed.contains("/=");
|
||||
|
||||
// Check for line continuation (BASIC uses comma at end of line)
|
||||
let ends_with_comma = trimmed.ends_with(',');
|
||||
|
||||
// If we're in a line continuation and this is not a variable assignment or statement,
|
||||
// it's likely a string literal continuation - quote it
|
||||
let line_to_process = if in_line_continuation && !is_var_assignment
|
||||
&& !trimmed.contains('=') && !trimmed.starts_with('"') && !upper.starts_with("IF ") {
|
||||
// This is a string literal continuation - quote it and escape any inner quotes
|
||||
let escaped = trimmed.replace('"', "\\\"");
|
||||
format!("\"{}\\n\"", escaped)
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
};
|
||||
|
||||
if is_var_assignment {
|
||||
// Add 'let' for variable declarations
|
||||
result.push_str("let ");
|
||||
// Add 'let' for variable declarations, but only if line doesn't already start with let/LET
|
||||
let trimmed_lower = trimmed.to_lowercase();
|
||||
if !trimmed_lower.starts_with("let ") {
|
||||
result.push_str("let ");
|
||||
}
|
||||
}
|
||||
result.push_str(trimmed);
|
||||
result.push_str(&line_to_process);
|
||||
// Add semicolon if line doesn't have one and doesn't end with { or }
|
||||
if !trimmed.ends_with(';') && !trimmed.ends_with('{') && !trimmed.ends_with('}') {
|
||||
// Skip adding semicolons to:
|
||||
// - SELECT/CASE/END SELECT statements (they're converted to if-else later)
|
||||
// - Lines ending with comma (BASIC line continuation)
|
||||
// - Lines that are part of a continuation block (in_line_continuation is true)
|
||||
if !trimmed.ends_with(';') && !trimmed.ends_with('{') && !trimmed.ends_with('}')
|
||||
&& !upper.starts_with("SELECT ") && !upper.starts_with("CASE ") && upper != "END SELECT"
|
||||
&& !ends_with_comma && !in_line_continuation {
|
||||
result.push(';');
|
||||
}
|
||||
result.push('\n');
|
||||
|
||||
// Update line continuation state
|
||||
in_line_continuation = ends_with_comma;
|
||||
} else {
|
||||
result.push_str(trimmed);
|
||||
result.push('\n');
|
||||
|
|
@ -804,18 +1398,40 @@ impl ScriptService {
|
|||
}
|
||||
|
||||
log::info!("[TOOL] IF/THEN conversion complete, output has {} lines", result.lines().count());
|
||||
result
|
||||
|
||||
// Convert BASIC <> (not equal) to Rhai != globally
|
||||
|
||||
|
||||
result.replace(" <> ", " != ")
|
||||
}
|
||||
|
||||
/// Convert BASIC SELECT ... CASE / END SELECT to Rhai match expressions
|
||||
/// Convert BASIC SELECT ... CASE / END SELECT to if-else chains
|
||||
/// Transforms: SELECT var ... CASE "value" ... END SELECT
|
||||
/// Into: match var { "value" => { ... } ... }
|
||||
fn convert_select_case_syntax(script: &str) -> String {
|
||||
/// Into: if var == "value" { ... } else if var == "value2" { ... }
|
||||
/// Note: We use if-else instead of match because 'match' is a reserved keyword in Rhai
|
||||
///
|
||||
/// IMPORTANT: This function strips 'let ' keywords from assignment statements inside CASE blocks
|
||||
/// to avoid creating local variables that shadow outer scope variables.
|
||||
pub fn convert_select_case_syntax(script: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut lines: Vec<&str> = script.lines().collect();
|
||||
let lines: Vec<&str> = script.lines().collect();
|
||||
let mut i = 0;
|
||||
|
||||
log::info!("[TOOL] Converting SELECT/CASE syntax");
|
||||
log::info!("[TOOL] Converting SELECT/CASE syntax to if-else chains");
|
||||
|
||||
// Helper function to strip 'let ' from the beginning of a line
|
||||
// This is needed because convert_if_then_syntax adds 'let' to all assignments,
|
||||
// but inside CASE blocks we want to modify outer variables, not create new ones
|
||||
fn strip_let_from_assignment(line: &str) -> String {
|
||||
let trimmed = line.trim();
|
||||
let trimmed_lower = trimmed.to_lowercase();
|
||||
if trimmed_lower.starts_with("let ") && trimmed.contains('=') {
|
||||
// This is a 'let' assignment - strip the 'let ' keyword
|
||||
trimmed[4..].trim().to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
while i < lines.len() {
|
||||
let trimmed = lines[i].trim();
|
||||
|
|
@ -827,43 +1443,83 @@ impl ScriptService {
|
|||
let select_var = trimmed[7..].trim(); // Skip "SELECT "
|
||||
log::info!("[TOOL] Converting SELECT statement for variable: '{}'", select_var);
|
||||
|
||||
// Start match expression
|
||||
result.push_str(&format!("match {} {{\n", select_var));
|
||||
|
||||
// Skip the SELECT line
|
||||
i += 1;
|
||||
|
||||
// Process CASE statements until END SELECT
|
||||
let mut current_case_body: Vec<String> = Vec::new();
|
||||
let mut in_case = false;
|
||||
let mut is_first_case = true;
|
||||
|
||||
while i < lines.len() {
|
||||
let case_trimmed = lines[i].trim();
|
||||
let case_upper = case_trimmed.to_uppercase();
|
||||
|
||||
// Skip empty lines and comment lines within SELECT/CASE blocks
|
||||
if case_trimmed.is_empty() || case_trimmed.starts_with('\'') || case_trimmed.starts_with('#') {
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if case_upper == "END SELECT" {
|
||||
// Close any open case
|
||||
if in_case {
|
||||
for body_line in ¤t_case_body {
|
||||
result.push_str(" ");
|
||||
result.push_str(body_line);
|
||||
// Strip 'let ' from assignments to avoid creating local variables
|
||||
let processed_line = strip_let_from_assignment(body_line);
|
||||
result.push_str(&processed_line);
|
||||
// Add semicolon if line doesn't have one
|
||||
if !processed_line.ends_with(';') && !processed_line.ends_with('{') && !processed_line.ends_with('}') {
|
||||
result.push(';');
|
||||
}
|
||||
result.push('\n');
|
||||
}
|
||||
// Close the last case arm (no else if, so we need the closing brace)
|
||||
result.push_str(" }\n");
|
||||
current_case_body.clear();
|
||||
in_case = false;
|
||||
//in_case = false; // Removed unused assignment
|
||||
}
|
||||
// Close the match expression
|
||||
result.push_str("}\n");
|
||||
// No extra closing brace needed - the last } else if ... { already closed the chain
|
||||
i += 1;
|
||||
break;
|
||||
} else if case_upper.starts_with("CASE ") {
|
||||
// Close previous case if any
|
||||
} else if case_upper.starts_with("SELECT ") {
|
||||
// Encountered another SELECT statement while processing this SELECT block
|
||||
// Close the current if-else chain and break to let the outer loop handle the new SELECT
|
||||
if in_case {
|
||||
for body_line in ¤t_case_body {
|
||||
result.push_str(" ");
|
||||
result.push_str(body_line);
|
||||
// Strip 'let ' from assignments to avoid creating local variables
|
||||
let processed_line = strip_let_from_assignment(body_line);
|
||||
result.push_str(&processed_line);
|
||||
// Add semicolon if line doesn't have one
|
||||
if !processed_line.ends_with(';') && !processed_line.ends_with('{') && !processed_line.ends_with('}') {
|
||||
result.push(';');
|
||||
}
|
||||
result.push('\n');
|
||||
}
|
||||
// Close the current case arm (no else if, so we need the closing brace)
|
||||
result.push_str(" }\n");
|
||||
current_case_body.clear();
|
||||
//in_case = false; // Removed unused assignment
|
||||
}
|
||||
// No extra closing brace needed
|
||||
break;
|
||||
} else if case_upper.starts_with("CASE ") {
|
||||
// Close previous case if any (but NOT if we're about to start else if)
|
||||
if in_case {
|
||||
for body_line in ¤t_case_body {
|
||||
result.push_str(" ");
|
||||
// Strip 'let ' from assignments to avoid creating local variables
|
||||
let processed_line = strip_let_from_assignment(body_line);
|
||||
result.push_str(&processed_line);
|
||||
// Add semicolon if line doesn't have one
|
||||
if !processed_line.ends_with(';') && !processed_line.ends_with('{') && !processed_line.ends_with('}') {
|
||||
result.push(';');
|
||||
}
|
||||
result.push('\n');
|
||||
}
|
||||
// NOTE: Don't close the case arm here - the } else if will close it
|
||||
current_case_body.clear();
|
||||
}
|
||||
|
||||
|
|
@ -876,14 +1532,22 @@ impl ScriptService {
|
|||
format!("\"{}\"", case_trimmed[5..].trim())
|
||||
};
|
||||
|
||||
result.push_str(&format!(" {} => {{\n", case_value));
|
||||
// Start if/else if chain
|
||||
if is_first_case {
|
||||
result.push_str(&format!("if {} == {} {{\n", select_var, case_value));
|
||||
is_first_case = false;
|
||||
} else {
|
||||
result.push_str(&format!("}} else if {} == {} {{\n", select_var, case_value));
|
||||
}
|
||||
in_case = true;
|
||||
i += 1;
|
||||
} else {
|
||||
} else if in_case {
|
||||
// Collect body lines for the current case
|
||||
if in_case {
|
||||
current_case_body.push(lines[i].to_string());
|
||||
}
|
||||
current_case_body.push(lines[i].to_string());
|
||||
i += 1;
|
||||
} else {
|
||||
// We're in the SELECT block but not in a CASE yet
|
||||
// Skip this line and move to the next
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
|
@ -892,9 +1556,11 @@ impl ScriptService {
|
|||
}
|
||||
|
||||
// Not a SELECT statement - just copy the line
|
||||
result.push_str(lines[i]);
|
||||
result.push('\n');
|
||||
i += 1;
|
||||
if i < lines.len() {
|
||||
result.push_str(lines[i]);
|
||||
result.push('\n');
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
|
|
@ -902,7 +1568,7 @@ impl ScriptService {
|
|||
|
||||
/// Convert BASIC keywords to lowercase without touching variables
|
||||
/// This is a simplified version of normalize_variables_to_lowercase for tools
|
||||
fn convert_keywords_to_lowercase(script: &str) -> String {
|
||||
pub fn convert_keywords_to_lowercase(script: &str) -> String {
|
||||
let keywords = [
|
||||
"IF", "THEN", "ELSE", "END IF", "FOR", "NEXT", "WHILE", "WEND",
|
||||
"DO", "LOOP", "RETURN", "EXIT",
|
||||
|
|
@ -1238,7 +1904,7 @@ impl ScriptService {
|
|||
/// - "USE WEBSITE "url" REFRESH "interval"" → "USE_WEBSITE("url", "interval")"
|
||||
/// - "SET BOT MEMORY key AS value" → "SET_BOT_MEMORY(key, value)"
|
||||
/// - "CLEAR SUGGESTIONS" → "CLEAR_SUGGESTIONS()"
|
||||
fn convert_multiword_keywords(script: &str) -> String {
|
||||
pub fn convert_multiword_keywords(script: &str) -> String {
|
||||
use regex::Regex;
|
||||
|
||||
// Known multi-word keywords with their conversion patterns
|
||||
|
|
@ -1363,9 +2029,9 @@ impl ScriptService {
|
|||
let mut current = String::new();
|
||||
let mut in_quotes = false;
|
||||
let mut quote_char = '"';
|
||||
let mut chars = params_str.chars().peekable();
|
||||
let chars = params_str.chars().peekable();
|
||||
|
||||
while let Some(c) = chars.next() {
|
||||
for c in chars {
|
||||
match c {
|
||||
'"' | '\'' if !in_quotes => {
|
||||
in_quotes = true;
|
||||
|
|
|
|||
1879
src/basic/mod.rs.backup
Normal file
1879
src/basic/mod.rs.backup
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -41,7 +41,7 @@ impl Platform {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
pub fn from_str_name(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"twitter" | "x" => Some(Self::Twitter),
|
||||
"facebook" | "fb" => Some(Self::Facebook),
|
||||
|
|
|
|||
|
|
@ -49,10 +49,7 @@ impl SocialPlatform {
|
|||
}
|
||||
|
||||
pub fn requires_oauth(&self) -> bool {
|
||||
match self {
|
||||
Self::Bluesky | Self::Telegram | Self::Twilio => false,
|
||||
_ => true,
|
||||
}
|
||||
!matches!(self, Self::Bluesky | Self::Telegram | Self::Twilio)
|
||||
}
|
||||
|
||||
pub fn authorization_url(&self) -> Option<&'static str> {
|
||||
|
|
|
|||
|
|
@ -298,10 +298,10 @@ impl GoogleClient {
|
|||
})).collect::<Vec<_>>())
|
||||
},
|
||||
"organizations": if contact.company.is_some() || contact.job_title.is_some() {
|
||||
Some([{
|
||||
"name": contact.company,
|
||||
"title": contact.job_title
|
||||
}])
|
||||
Some(vec![serde_json::json!({
|
||||
"name": contact.company.unwrap_or_default(),
|
||||
"title": contact.job_title.unwrap_or_default()
|
||||
})])
|
||||
} else { None }
|
||||
});
|
||||
|
||||
|
|
@ -363,10 +363,10 @@ impl GoogleClient {
|
|||
})).collect::<Vec<_>>())
|
||||
},
|
||||
"organizations": if contact.company.is_some() || contact.job_title.is_some() {
|
||||
Some([{
|
||||
"name": contact.company,
|
||||
"title": contact.job_title
|
||||
}])
|
||||
Some(vec![serde_json::json!({
|
||||
"name": contact.company.unwrap_or_default(),
|
||||
"title": contact.job_title.unwrap_or_default()
|
||||
})])
|
||||
} else { None }
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -106,9 +106,9 @@ impl BootstrapManager {
|
|||
}
|
||||
}
|
||||
|
||||
if pm.is_installed("postgres") {
|
||||
if pm.is_installed("tables") {
|
||||
info!("Starting PostgreSQL...");
|
||||
match pm.start("postgres") {
|
||||
match pm.start("tables") {
|
||||
Ok(_child) => {
|
||||
info!("PostgreSQL started");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -436,9 +436,22 @@ pub async fn inject_kb_context(
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
// Sanitize context to remove UTF-16 surrogate characters that can't be encoded in UTF-8
|
||||
let sanitized_context = context_string
|
||||
.chars()
|
||||
.filter(|c| {
|
||||
let cp = *c as u32;
|
||||
!(0xD800..=0xDBFF).contains(&cp) && !(0xDC00..=0xDFFF).contains(&cp)
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
if sanitized_context.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!(
|
||||
"Injecting {} characters of KB/website context into prompt for session {}",
|
||||
context_string.len(),
|
||||
sanitized_context.len(),
|
||||
session_id
|
||||
);
|
||||
|
||||
|
|
@ -447,7 +460,7 @@ pub async fn inject_kb_context(
|
|||
|
||||
if let Some(idx) = system_msg_idx {
|
||||
if let Some(content) = messages_array[idx]["content"].as_str() {
|
||||
let new_content = format!("{}\n{}", content, context_string);
|
||||
let new_content = format!("{}\n{}", content, sanitized_context);
|
||||
messages_array[idx]["content"] = serde_json::Value::String(new_content);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -455,7 +468,7 @@ pub async fn inject_kb_context(
|
|||
0,
|
||||
serde_json::json!({
|
||||
"role": "system",
|
||||
"content": context_string
|
||||
"content": sanitized_context
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -146,8 +146,24 @@ pub async fn get_bot_config(
|
|||
let mut theme_logo: Option<String> = None;
|
||||
let mut theme_logo_text: Option<String> = None;
|
||||
|
||||
// Query all config values (no prefix filter - will match in code)
|
||||
let target_bot_id = match get_bot_id_by_name(&mut conn, &bot_name) {
|
||||
Ok(found_id) => found_id,
|
||||
Err(e) => {
|
||||
warn!("Failed to find bot ID for name '{}': {}", bot_name, e);
|
||||
return Ok(Json(BotConfigResponse {
|
||||
public: false,
|
||||
theme_color1: None,
|
||||
theme_color2: None,
|
||||
theme_title: None,
|
||||
theme_logo: None,
|
||||
theme_logo_text: None,
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
// Query all config values for this specific bot
|
||||
match bot_configuration
|
||||
.filter(bot_id.eq(target_bot_id))
|
||||
.select((config_key, config_value))
|
||||
.load::<(String, String)>(&mut conn)
|
||||
{
|
||||
|
|
@ -580,11 +596,20 @@ impl BotOrchestrator {
|
|||
}
|
||||
}
|
||||
|
||||
// Sanitize user message to remove any UTF-16 surrogate characters
|
||||
let sanitized_message_content = message_content
|
||||
.chars()
|
||||
.filter(|c| {
|
||||
let cp = *c as u32;
|
||||
!(0xD800..=0xDBFF).contains(&cp) && !(0xDC00..=0xDFFF).contains(&cp)
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
// Add the current user message to the messages array
|
||||
if let Some(msgs_array) = messages.as_array_mut() {
|
||||
msgs_array.push(serde_json::json!({
|
||||
"role": "user",
|
||||
"content": message_content
|
||||
"content": sanitized_message_content
|
||||
}));
|
||||
}
|
||||
|
||||
|
|
@ -644,6 +669,8 @@ impl BotOrchestrator {
|
|||
let mut analysis_buffer = String::new();
|
||||
let mut in_analysis = false;
|
||||
let mut tool_call_buffer = String::new(); // Accumulate potential tool call JSON chunks
|
||||
let mut accumulating_tool_call = false; // Track if we're currently accumulating a tool call
|
||||
let mut tool_was_executed = false; // Track if a tool was executed to avoid duplicate final message
|
||||
let handler = llm_models::get_handler(&model);
|
||||
|
||||
info!("[STREAM_START] Entering stream processing loop for model: {}", model);
|
||||
|
|
@ -679,12 +706,62 @@ impl BotOrchestrator {
|
|||
// ===== GENERIC TOOL EXECUTION =====
|
||||
// Add chunk to tool_call_buffer and try to parse
|
||||
// Tool calls arrive as JSON that can span multiple chunks
|
||||
let looks_like_json = chunk.trim().starts_with('{') || chunk.trim().starts_with('[') ||
|
||||
tool_call_buffer.contains('{') || tool_call_buffer.contains('[');
|
||||
|
||||
let chunk_in_tool_buffer = if looks_like_json {
|
||||
// Check if this chunk contains JSON (either starts with {/[ or contains {/[)
|
||||
let chunk_contains_json = chunk.trim().starts_with('{') || chunk.trim().starts_with('[') ||
|
||||
chunk.contains('{') || chunk.contains('[');
|
||||
|
||||
let chunk_in_tool_buffer = if accumulating_tool_call {
|
||||
// Already accumulating - add entire chunk to buffer
|
||||
tool_call_buffer.push_str(&chunk);
|
||||
true
|
||||
} else if chunk_contains_json {
|
||||
// Check if { appears in the middle of the chunk (mixed text + JSON)
|
||||
let json_start = chunk.find('{').or_else(|| chunk.find('['));
|
||||
|
||||
if let Some(pos) = json_start {
|
||||
if pos > 0 {
|
||||
// Send the part before { as regular content
|
||||
let regular_part = &chunk[..pos];
|
||||
if !regular_part.trim().is_empty() {
|
||||
info!("[STREAM_CONTENT] Sending regular part before JSON: '{}', len: {}", regular_part, regular_part.len());
|
||||
full_response.push_str(regular_part);
|
||||
|
||||
let response = BotResponse {
|
||||
bot_id: message.bot_id.clone(),
|
||||
user_id: message.user_id.clone(),
|
||||
session_id: message.session_id.clone(),
|
||||
channel: message.channel.clone(),
|
||||
content: regular_part.to_string(),
|
||||
message_type: MessageType::BOT_RESPONSE,
|
||||
stream_token: None,
|
||||
is_complete: false,
|
||||
suggestions: Vec::new(),
|
||||
context_name: None,
|
||||
context_length: 0,
|
||||
context_max_length: 0,
|
||||
};
|
||||
|
||||
if response_tx.send(response).await.is_err() {
|
||||
warn!("Response channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Start accumulating from { onwards
|
||||
accumulating_tool_call = true;
|
||||
tool_call_buffer.push_str(&chunk[pos..]);
|
||||
true
|
||||
} else {
|
||||
// Chunk starts with { or [
|
||||
accumulating_tool_call = true;
|
||||
tool_call_buffer.push_str(&chunk);
|
||||
true
|
||||
}
|
||||
} else {
|
||||
// Contains {/[ but find() failed - shouldn't happen, but send as regular content
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
|
@ -774,13 +851,16 @@ impl BotOrchestrator {
|
|||
// Don't add tool_call JSON to full_response or analysis_buffer
|
||||
// Clear the tool_call_buffer since we found and executed a tool call
|
||||
tool_call_buffer.clear();
|
||||
accumulating_tool_call = false; // Reset accumulation flag
|
||||
tool_was_executed = true; // Mark that a tool was executed
|
||||
// Continue to next chunk
|
||||
continue;
|
||||
}
|
||||
|
||||
// Clear tool_call_buffer if it's getting too large and no tool call was found
|
||||
// This prevents memory issues from accumulating JSON fragments
|
||||
if tool_call_buffer.len() > 10000 {
|
||||
// Increased limit to 50000 to handle large tool calls with many parameters
|
||||
if tool_call_buffer.len() > 50000 {
|
||||
// Flush accumulated content to client since it's too large to be a tool call
|
||||
info!("[TOOL_EXEC] Flushing tool_call_buffer (too large, assuming not a tool call)");
|
||||
full_response.push_str(&tool_call_buffer);
|
||||
|
|
@ -801,6 +881,7 @@ impl BotOrchestrator {
|
|||
};
|
||||
|
||||
tool_call_buffer.clear();
|
||||
accumulating_tool_call = false; // Reset accumulation flag after flush
|
||||
|
||||
if response_tx.send(response).await.is_err() {
|
||||
warn!("Response channel closed");
|
||||
|
|
@ -810,7 +891,7 @@ impl BotOrchestrator {
|
|||
|
||||
// If this chunk was added to tool_call_buffer and no tool call was found yet,
|
||||
// skip processing (it's part of an incomplete tool call JSON)
|
||||
if chunk_in_tool_buffer && tool_call_buffer.len() <= 10000 {
|
||||
if chunk_in_tool_buffer {
|
||||
continue;
|
||||
}
|
||||
// ===== END TOOL EXECUTION =====
|
||||
|
|
@ -941,12 +1022,16 @@ impl BotOrchestrator {
|
|||
#[cfg(not(feature = "chat"))]
|
||||
let suggestions: Vec<crate::core::shared::models::Suggestion> = Vec::new();
|
||||
|
||||
// When a tool was executed, the content was already sent as streaming chunks
|
||||
// (pre-tool text + tool result). Sending full_response again would duplicate it.
|
||||
let final_content = if tool_was_executed { String::new() } else { full_response };
|
||||
|
||||
let final_response = BotResponse {
|
||||
bot_id: message.bot_id,
|
||||
user_id: message.user_id,
|
||||
session_id: message.session_id,
|
||||
channel: message.channel,
|
||||
content: full_response,
|
||||
content: final_content,
|
||||
message_type: MessageType::BOT_RESPONSE,
|
||||
stream_token: None,
|
||||
is_complete: true,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
/// Works across all LLM providers (GLM, OpenAI, Claude, etc.)
|
||||
use log::{error, info, warn};
|
||||
use serde_json::Value;
|
||||
// use std::collections::HashMap;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
|
|
@ -264,6 +265,8 @@ impl ToolExecutor {
|
|||
script_service.load_bot_config_params(state, bot_id);
|
||||
|
||||
// Set tool parameters as variables in the engine scope
|
||||
// Note: DATE parameters are now sent by LLM in ISO 8601 format (YYYY-MM-DD)
|
||||
// The tool schema with format="date" tells the LLM to use this agnostic format
|
||||
if let Some(obj) = arguments.as_object() {
|
||||
for (key, value) in obj {
|
||||
let value_str = match value {
|
||||
|
|
@ -281,7 +284,7 @@ impl ToolExecutor {
|
|||
}
|
||||
|
||||
// Compile tool script (filters PARAM/DESCRIPTION lines and converts BASIC to Rhai)
|
||||
let ast = match script_service.compile_tool_script(&bas_script) {
|
||||
let ast = match script_service.compile_tool_script(bas_script) {
|
||||
Ok(ast) => ast,
|
||||
Err(e) => {
|
||||
let error_msg = format!("Compilation error: {}", e);
|
||||
|
|
|
|||
|
|
@ -291,10 +291,10 @@ impl AppConfig {
|
|||
smtp_server: get_str("EMAIL_SMTP_SERVER", "smtp.gmail.com"),
|
||||
smtp_port: get_u16("EMAIL_SMTP_PORT", 587),
|
||||
};
|
||||
let port = std::env::var("BOTSERVER_PORT")
|
||||
let port = std::env::var("PORT")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<u16>().ok())
|
||||
.unwrap_or_else(|| get_u16("server_port", 8080));
|
||||
.unwrap_or_else(|| get_u16("server_port", 9000));
|
||||
|
||||
Ok(Self {
|
||||
drive,
|
||||
|
|
@ -302,7 +302,7 @@ impl AppConfig {
|
|||
server: ServerConfig {
|
||||
host: get_str("server_host", "0.0.0.0"),
|
||||
port,
|
||||
base_url: config_map.get("server_base_url").cloned().unwrap_or_else(|| "http://localhost:8080".to_string()),
|
||||
base_url: config_map.get("server_base_url").cloned().unwrap_or_else(|| "http://localhost:9000".to_string()),
|
||||
},
|
||||
site_path: {
|
||||
ConfigManager::new(pool.clone()).get_config(
|
||||
|
|
@ -329,10 +329,10 @@ impl AppConfig {
|
|||
smtp_server: "smtp.gmail.com".to_string(),
|
||||
smtp_port: 587,
|
||||
};
|
||||
let port = std::env::var("BOTSERVER_PORT")
|
||||
let port = std::env::var("PORT")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<u16>().ok())
|
||||
.unwrap_or(8080);
|
||||
.unwrap_or(9000);
|
||||
|
||||
Ok(Self {
|
||||
drive: minio,
|
||||
|
|
@ -340,7 +340,7 @@ impl AppConfig {
|
|||
server: ServerConfig {
|
||||
host: "0.0.0.0".to_string(),
|
||||
port,
|
||||
base_url: "http://localhost:8080".to_string(),
|
||||
base_url: "http://localhost:9000".to_string(),
|
||||
},
|
||||
|
||||
site_path: "./botserver-stack/sites".to_string(),
|
||||
|
|
@ -419,7 +419,7 @@ impl ConfigManager {
|
|||
.first::<String>(&mut conn)
|
||||
.unwrap_or_else(|_| fallback_str.to_string())
|
||||
} else {
|
||||
String::from(v)
|
||||
v
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ pub async fn reload_config(
|
|||
let mut conn = conn_arc
|
||||
.get()
|
||||
.map_err(|e| format!("failed to get db connection: {e}"))?;
|
||||
Ok(crate::core::bot::get_default_bot(&mut *conn))
|
||||
Ok(crate::core::bot::get_default_bot(&mut conn))
|
||||
})
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
|
|
|
|||
|
|
@ -172,9 +172,9 @@ impl KbIndexer {
|
|||
let mut batch_docs = Vec::with_capacity(BATCH_SIZE);
|
||||
|
||||
// Process documents in iterator to avoid keeping all in memory
|
||||
let mut doc_iter = documents.into_iter();
|
||||
let doc_iter = documents.into_iter();
|
||||
|
||||
while let Some((doc_path, chunks)) = doc_iter.next() {
|
||||
for (doc_path, chunks) in doc_iter {
|
||||
if chunks.is_empty() {
|
||||
debug!("[KB_INDEXER] Skipping document with no chunks: {}", doc_path);
|
||||
continue;
|
||||
|
|
@ -262,9 +262,9 @@ impl KbIndexer {
|
|||
|
||||
// Process chunks in smaller sub-batches to prevent memory exhaustion
|
||||
const CHUNK_BATCH_SIZE: usize = 20; // Process 20 chunks at a time
|
||||
let mut chunk_batches = chunks.chunks(CHUNK_BATCH_SIZE);
|
||||
let chunk_batches = chunks.chunks(CHUNK_BATCH_SIZE);
|
||||
|
||||
while let Some(chunk_batch) = chunk_batches.next() {
|
||||
for chunk_batch in chunk_batches {
|
||||
trace!("[KB_INDEXER] Processing chunk batch of {} chunks", chunk_batch.len());
|
||||
|
||||
let embeddings = match self
|
||||
|
|
|
|||
|
|
@ -221,7 +221,7 @@ impl WebCrawler {
|
|||
self.pages.push(page);
|
||||
|
||||
// Aggressive memory cleanup every 10 pages
|
||||
if self.pages.len() % 10 == 0 {
|
||||
if self.pages.len().is_multiple_of(10) {
|
||||
self.pages.shrink_to_fit();
|
||||
self.visited_urls.shrink_to_fit();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ impl WebsiteCrawlerService {
|
|||
let total_pages = pages.len();
|
||||
|
||||
for (batch_idx, batch) in pages.chunks(BATCH_SIZE).enumerate() {
|
||||
info!("Processing batch {} of {} pages", batch_idx + 1, (total_pages + BATCH_SIZE - 1) / BATCH_SIZE);
|
||||
info!("Processing batch {} of {} pages", batch_idx + 1, total_pages.div_ceil(BATCH_SIZE));
|
||||
|
||||
for (idx, page) in batch.iter().enumerate() {
|
||||
let global_idx = batch_idx * BATCH_SIZE + idx;
|
||||
|
|
@ -377,6 +377,8 @@ impl WebsiteCrawlerService {
|
|||
bot_id: uuid::Uuid,
|
||||
conn: &mut diesel::PgConnection,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let website_regex = regex::Regex::new(r#"(?i)(?:USE\s+WEBSITE\s+"([^"]+)"\s+REFRESH\s+"([^"]+)")|(?:USE_WEBSITE\s*\(\s*"([^"]+)"\s*(?:,\s*"([^"]+)"\s*)?\))"#)?;
|
||||
|
||||
for entry in std::fs::read_dir(dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
|
@ -384,11 +386,7 @@ impl WebsiteCrawlerService {
|
|||
if path.extension().is_some_and(|ext| ext == "bas") {
|
||||
let content = std::fs::read_to_string(&path)?;
|
||||
|
||||
// Regex to find both syntaxes: USE WEBSITE "url" REFRESH "interval" and USE_WEBSITE("url", "refresh")
|
||||
// Case-insensitive to match preprocessed lowercase versions
|
||||
let re = regex::Regex::new(r#"(?i)(?:USE\s+WEBSITE\s+"([^"]+)"\s+REFRESH\s+"([^"]+)")|(?:USE_WEBSITE\s*\(\s*"([^"]+)"\s*(?:,\s*"([^"]+)"\s*)?\))"#)?;
|
||||
|
||||
for cap in re.captures_iter(&content) {
|
||||
for cap in website_regex.captures_iter(&content) {
|
||||
// Extract URL from either capture group 1 (space syntax) or group 3 (function syntax)
|
||||
let url_str = if let Some(url) = cap.get(1) {
|
||||
url.as_str()
|
||||
|
|
|
|||
|
|
@ -495,12 +495,12 @@ pub async fn require_authentication_middleware(
|
|||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
type MiddlewareFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Response>> + Send>>;
|
||||
|
||||
/// Require specific role - returns 403 if role not present
|
||||
pub fn require_role_middleware(
|
||||
required_role: &'static str,
|
||||
) -> impl Fn(Request<Body>, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Response>> + Send>>
|
||||
+ Clone
|
||||
+ Send {
|
||||
) -> impl Fn(Request<Body>, Next) -> MiddlewareFuture + Clone + Send {
|
||||
move |request: Request<Body>, next: Next| {
|
||||
Box::pin(async move {
|
||||
let user = request
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ pub async fn run() -> Result<()> {
|
|||
"rotate-secret" => {
|
||||
if args.len() < 3 {
|
||||
eprintln!("Usage: botserver rotate-secret <component>");
|
||||
eprintln!("Components: tables, drive, cache, email, directory, encryption");
|
||||
eprintln!("Components: tables, drive, cache, email, directory, encryption, jwt");
|
||||
return Ok(());
|
||||
}
|
||||
let component = &args[2];
|
||||
|
|
@ -282,6 +282,7 @@ fn print_usage() {
|
|||
println!(" restart Restart all components");
|
||||
println!(" vault <subcommand> Manage Vault secrets");
|
||||
println!(" rotate-secret <comp> Rotate a component's credentials");
|
||||
println!(" (tables, drive, cache, email, directory, encryption, jwt)");
|
||||
println!(" rotate-secrets --all Rotate ALL credentials (dangerous!)");
|
||||
println!(" version [--all] Show version information");
|
||||
println!(" --version, -v Show version");
|
||||
|
|
@ -788,6 +789,7 @@ async fn rotate_secret(component: &str) -> Result<()> {
|
|||
if input.trim().to_lowercase() == "y" {
|
||||
manager.put_secret(SecretPaths::TABLES, secrets).await?;
|
||||
println!("✓ Credentials saved to Vault");
|
||||
verify_rotation(component).await?;
|
||||
} else {
|
||||
println!("✗ Aborted");
|
||||
}
|
||||
|
|
@ -933,9 +935,81 @@ async fn rotate_secret(component: &str) -> Result<()> {
|
|||
println!("✗ Aborted");
|
||||
}
|
||||
}
|
||||
"jwt" => {
|
||||
let new_secret = generate_password(64);
|
||||
let env_path = std::env::current_dir()?.join(".env");
|
||||
|
||||
println!("⚠️ JWT SECRET ROTATION");
|
||||
println!();
|
||||
println!("Current: JWT_SECRET in .env file");
|
||||
println!("Impact: ALL refresh tokens will become invalid immediately");
|
||||
println!("Access tokens (15 min) will expire naturally");
|
||||
println!();
|
||||
|
||||
// Check if .env exists
|
||||
if !env_path.exists() {
|
||||
println!("✗ .env file not found at: {}", env_path.display());
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Read current JWT_SECRET for display
|
||||
let env_content = std::fs::read_to_string(&env_path)?;
|
||||
let current_jwt = env_content
|
||||
.lines()
|
||||
.find(|line| line.starts_with("JWT_SECRET="))
|
||||
.unwrap_or("JWT_SECRET=(not set)");
|
||||
|
||||
println!("Current: {}", ¤t_jwt.chars().take(40).collect::<String>());
|
||||
println!("New: {}... (64 chars)", &new_secret.chars().take(8).collect::<String>());
|
||||
println!();
|
||||
|
||||
// Backup .env
|
||||
let backup_path = format!("{}.backup.{}", env_path.display(), chrono::Utc::now().timestamp());
|
||||
std::fs::copy(&env_path, &backup_path)?;
|
||||
println!("✓ Backup created: {}", backup_path);
|
||||
println!();
|
||||
|
||||
print!("Update JWT_SECRET in .env? [y/N]: ");
|
||||
std::io::Write::flush(&mut std::io::stdout())?;
|
||||
let mut input = String::new();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
|
||||
if input.trim().to_lowercase() == "y" {
|
||||
// Read, update, write .env atomically
|
||||
let content = std::fs::read_to_string(&env_path)?;
|
||||
let new_content = content
|
||||
.lines()
|
||||
.map(|line| {
|
||||
if line.starts_with("JWT_SECRET=") {
|
||||
format!("JWT_SECRET={}", new_secret)
|
||||
} else {
|
||||
line.to_string()
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let temp_path = format!("{}.new", env_path.display());
|
||||
std::fs::write(&temp_path, new_content)?;
|
||||
std::fs::rename(&temp_path, &env_path)?;
|
||||
|
||||
println!("✓ JWT_SECRET updated in .env");
|
||||
println!();
|
||||
println!("⚠️ RESTART REQUIRED:");
|
||||
println!(" botserver restart");
|
||||
println!();
|
||||
println!("All users must re-login after restart (refresh tokens invalid)");
|
||||
println!("Access tokens will expire naturally within 15 minutes");
|
||||
|
||||
verify_rotation(component).await?;
|
||||
} else {
|
||||
println!("✗ Aborted");
|
||||
println!("Backup preserved at: {}", backup_path);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
eprintln!("Unknown component: {}", component);
|
||||
eprintln!("Valid components: tables, drive, cache, email, directory, encryption");
|
||||
eprintln!("Valid components: tables, drive, cache, email, directory, encryption, jwt");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1041,6 +1115,96 @@ async fn rotate_all_secrets() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn verify_rotation(component: &str) -> Result<()> {
|
||||
println!();
|
||||
println!("Verifying {}...", component);
|
||||
|
||||
match component {
|
||||
"tables" => {
|
||||
let manager = SecretsManager::from_env()?;
|
||||
let secrets = manager.get_secret(SecretPaths::TABLES).await?;
|
||||
|
||||
let host = secrets.get("host").cloned().unwrap_or_else(|| "localhost".to_string());
|
||||
let port = secrets.get("port").cloned().unwrap_or_else(|| "5432".to_string());
|
||||
let user = secrets.get("username").cloned().unwrap_or_else(|| "postgres".to_string());
|
||||
let pass = secrets.get("password").cloned().unwrap_or_default();
|
||||
let db = secrets.get("database").cloned().unwrap_or_else(|| "postgres".to_string());
|
||||
|
||||
println!(" Testing connection to {}@{}:{}...", user, host, port);
|
||||
|
||||
// Use psql to test connection
|
||||
let result = std::process::Command::new("psql")
|
||||
.args([
|
||||
"-h", &host,
|
||||
"-p", &port,
|
||||
"-U", &user,
|
||||
"-d", &db,
|
||||
"-c", "SELECT 1;",
|
||||
"-t", "-q" // Tuples only, quiet mode
|
||||
])
|
||||
.env("PGPASSWORD", &pass)
|
||||
.output();
|
||||
|
||||
match result {
|
||||
Ok(output) if output.status.success() => {
|
||||
println!("✓ Database connection successful");
|
||||
}
|
||||
Ok(output) => {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
println!("✗ Database connection FAILED");
|
||||
println!(" Error: {}", stderr.trim());
|
||||
println!(" Hint: Run the SQL command provided by rotate-secret");
|
||||
}
|
||||
Err(_e) => {
|
||||
println!("⊘ Verification skipped (psql not available)");
|
||||
println!(" Hint: Manually test with: psql -h {} -U {} -d {} -c 'SELECT 1'", host, user, db);
|
||||
}
|
||||
}
|
||||
}
|
||||
"jwt" => {
|
||||
println!(" Testing health endpoint...");
|
||||
|
||||
// Try to determine the health endpoint
|
||||
let health_urls = vec![
|
||||
"http://localhost:8080/health",
|
||||
"http://localhost:5858/health",
|
||||
"http://localhost:3000/health",
|
||||
];
|
||||
|
||||
let mut success = false;
|
||||
for url in health_urls {
|
||||
match reqwest::get(url).await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
println!("✓ Service healthy at {}", url);
|
||||
success = true;
|
||||
break;
|
||||
}
|
||||
Ok(_resp) => {
|
||||
// Try next URL
|
||||
continue;
|
||||
}
|
||||
Err(_e) => {
|
||||
// Try next URL
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !success {
|
||||
println!("⊘ Health endpoint not reachable");
|
||||
println!(" Hint: Restart botserver with: botserver restart");
|
||||
println!(" Then manually verify service is responding");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
println!("⊘ No automated verification available for {}", component);
|
||||
println!(" Hint: Manually verify the service is working after rotation");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn vault_health() -> Result<()> {
|
||||
let manager = SecretsManager::from_env()?;
|
||||
|
||||
|
|
|
|||
|
|
@ -553,7 +553,7 @@ Store credentials in Vault:
|
|||
r"Email Server (Stalwart):
|
||||
SMTP: {}:25
|
||||
IMAP: {}:143
|
||||
Web: http://{}:8080
|
||||
Web: http://{}:9000
|
||||
|
||||
Store credentials in Vault:
|
||||
botserver vault put gbo/email server={} port=25 username=admin password=<your-password>",
|
||||
|
|
@ -563,11 +563,11 @@ Store credentials in Vault:
|
|||
"directory" => {
|
||||
format!(
|
||||
r"Zitadel Identity Provider:
|
||||
URL: http://{}:8080
|
||||
Console: http://{}:8080/ui/console
|
||||
URL: http://{}:9000
|
||||
Console: http://{}:9000/ui/console
|
||||
|
||||
Store credentials in Vault:
|
||||
botserver vault put gbo/directory url=http://{}:8080 client_id=<client-id> client_secret=<client-secret>",
|
||||
botserver vault put gbo/directory url=http://{}:9000 client_id=<client-id> client_secret=<client-secret>",
|
||||
ip, ip, ip
|
||||
)
|
||||
}
|
||||
|
|
@ -1047,7 +1047,7 @@ Store credentials in Vault:
|
|||
Ok(())
|
||||
}
|
||||
pub fn run_commands(&self, commands: &[String], target: &str, component: &str) -> Result<()> {
|
||||
self.run_commands_with_password(commands, target, component, &String::new())
|
||||
self.run_commands_with_password(commands, target, component, "")
|
||||
}
|
||||
|
||||
pub fn run_commands_with_password(&self, commands: &[String], target: &str, component: &str, db_password_override: &str) -> Result<()> {
|
||||
|
|
@ -1081,7 +1081,7 @@ Store credentials in Vault:
|
|||
match get_database_url_sync() {
|
||||
Ok(url) => {
|
||||
let (_, password, _, _, _) = parse_database_url(&url);
|
||||
String::from(password)
|
||||
password
|
||||
}
|
||||
Err(_) => {
|
||||
trace!("Vault not available for DB_PASSWORD, using empty string");
|
||||
|
|
|
|||
|
|
@ -602,7 +602,7 @@ impl PackageManager {
|
|||
post_install_cmds_windows: vec![],
|
||||
env_vars: HashMap::new(),
|
||||
data_download_list: Vec::new(),
|
||||
exec_cmd: "php -S 0.0.0.0:8080 -t {{DATA_PATH}}/roundcubemail".to_string(),
|
||||
exec_cmd: "php -S 0.0.0.0:9000 -t {{DATA_PATH}}/roundcubemail".to_string(),
|
||||
check_cmd:
|
||||
"curl -f -k --connect-timeout 2 -m 5 https://localhost:8300 >/dev/null 2>&1"
|
||||
.to_string(),
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ impl DirectorySetup {
|
|||
_org_id: &str,
|
||||
) -> Result<(String, String, String)> {
|
||||
let app_name = "BotServer";
|
||||
let redirect_uri = "http://localhost:8080/auth/callback".to_string();
|
||||
let redirect_uri = "http://localhost:9000/auth/callback".to_string();
|
||||
|
||||
let project_response = self
|
||||
.client
|
||||
|
|
@ -362,7 +362,7 @@ impl DirectorySetup {
|
|||
"grantTypes": ["OIDC_GRANT_TYPE_AUTHORIZATION_CODE", "OIDC_GRANT_TYPE_REFRESH_TOKEN", "OIDC_GRANT_TYPE_PASSWORD"],
|
||||
"appType": "OIDC_APP_TYPE_WEB",
|
||||
"authMethodType": "OIDC_AUTH_METHOD_TYPE_POST",
|
||||
"postLogoutRedirectUris": ["http://localhost:8080", "http://localhost:3000", "http://localhost:9000"],
|
||||
"postLogoutRedirectUris": ["http://localhost:9000", "http://localhost:3000", "http://localhost:9000"],
|
||||
"accessTokenType": "OIDC_TOKEN_TYPE_BEARER",
|
||||
"devMode": true,
|
||||
}))
|
||||
|
|
@ -466,10 +466,10 @@ Database:
|
|||
Machine:
|
||||
Identification:
|
||||
Hostname: localhost
|
||||
WebhookAddress: http://localhost:8080
|
||||
WebhookAddress: http://localhost:9000
|
||||
|
||||
ExternalDomain: localhost:8080
|
||||
ExternalPort: 8080
|
||||
ExternalDomain: localhost:9000
|
||||
ExternalPort: 9000
|
||||
ExternalSecure: false
|
||||
|
||||
TLS:
|
||||
|
|
|
|||
|
|
@ -203,7 +203,7 @@ impl EmailSetup {
|
|||
|
||||
let issuer_url = dir_config["base_url"]
|
||||
.as_str()
|
||||
.unwrap_or("http://localhost:8080");
|
||||
.unwrap_or("http://localhost:9000");
|
||||
|
||||
log::info!("Setting up OIDC authentication with Directory...");
|
||||
log::info!("Issuer URL: {}", issuer_url);
|
||||
|
|
@ -289,7 +289,7 @@ protocol = "imap"
|
|||
tls.implicit = true
|
||||
|
||||
[server.listener."http"]
|
||||
bind = ["0.0.0.0:8080"]
|
||||
bind = ["0.0.0.0:9000"]
|
||||
protocol = "http"
|
||||
|
||||
[storage]
|
||||
|
|
@ -315,7 +315,7 @@ store = "sqlite"
|
|||
r#"
|
||||
[directory."oidc"]
|
||||
type = "oidc"
|
||||
issuer = "http://localhost:8080"
|
||||
issuer = "http://localhost:9000"
|
||||
client-id = "{{CLIENT_ID}}"
|
||||
client-secret = "{{CLIENT_SECRET}}"
|
||||
|
||||
|
|
|
|||
|
|
@ -748,7 +748,7 @@ impl<T: Clone + Send + Sync + 'static> BatchProcessor<T> {
|
|||
F: Fn(Vec<T>) -> Fut + Send + Sync + 'static,
|
||||
Fut: std::future::Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let processor_arc: Arc<dyn Fn(Vec<T>) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync> =
|
||||
let processor_arc: BatchProcessorFunc<T> =
|
||||
Arc::new(move |items| Box::pin(processor(items)));
|
||||
|
||||
let batch_processor = Self {
|
||||
|
|
|
|||
|
|
@ -381,7 +381,7 @@ impl SecretsManager {
|
|||
secrets.insert("token".into(), String::new());
|
||||
}
|
||||
SecretPaths::ALM => {
|
||||
secrets.insert("url".into(), "http://localhost:8080".into());
|
||||
secrets.insert("url".into(), "http://localhost:9000".into());
|
||||
secrets.insert("username".into(), String::new());
|
||||
secrets.insert("password".into(), String::new());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,270 +1,50 @@
|
|||
use super::admin_types::*;
|
||||
use crate::core::shared::state::AppState;
|
||||
use crate::core::urls::ApiUrls;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
routing::{get, post},
|
||||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel::sql_types::{Text, Nullable};
|
||||
use log::{error, info};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Get admin dashboard data
|
||||
pub async fn get_admin_dashboard(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(bot_id): Path<Uuid>,
|
||||
) -> impl IntoResponse {
|
||||
let bot_id = bot_id.into_inner();
|
||||
|
||||
// Get system status
|
||||
let (database_ok, redis_ok) = match get_system_status(&state).await {
|
||||
Ok(status) => (true, status.is_healthy()),
|
||||
Err(e) => {
|
||||
error!("Failed to get system status: {}", e);
|
||||
(false, false)
|
||||
}
|
||||
};
|
||||
|
||||
// Get user count
|
||||
let user_count = get_stats_users(&state).await.unwrap_or(0);
|
||||
let group_count = get_stats_groups(&state).await.unwrap_or(0);
|
||||
let bot_count = get_stats_bots(&state).await.unwrap_or(0);
|
||||
|
||||
// Get storage stats
|
||||
let storage_stats = get_stats_storage(&state).await.unwrap_or_else(|| StorageStat {
|
||||
total_gb: 0,
|
||||
used_gb: 0,
|
||||
percent: 0.0,
|
||||
});
|
||||
|
||||
// Get recent activities
|
||||
let activities = get_dashboard_activity(&state, Some(20))
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Get member/bot/invitation stats
|
||||
let member_count = get_dashboard_members(&state, bot_id, 50)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let bot_list = get_dashboard_bots(&state, bot_id, 50)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let invitation_count = get_dashboard_invitations(&state, bot_id, 50)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let dashboard_data = AdminDashboardData {
|
||||
users: vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Users".to_string(),
|
||||
count: user_count as i64,
|
||||
},
|
||||
GroupStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Groups".to_string(),
|
||||
count: group_count as i64,
|
||||
},
|
||||
BotStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Bots".to_string(),
|
||||
count: bot_count as i64,
|
||||
},
|
||||
],
|
||||
groups,
|
||||
bots: bot_list,
|
||||
storage: storage_stats,
|
||||
activities,
|
||||
invitations: vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Members".to_string(),
|
||||
count: member_count as i64,
|
||||
},
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Invitations".to_string(),
|
||||
count: invitation_count as i64,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(dashboard_data)).into_response()
|
||||
// Helper function to get dashboard members
|
||||
async fn get_dashboard_members(
|
||||
state: &AppState,
|
||||
bot_id: Uuid,
|
||||
limit: i64,
|
||||
) -> Result<i64, diesel::result::Error> {
|
||||
// TODO: Implement actual member fetching logic
|
||||
// For now, return a placeholder count
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
/// Get system health status
|
||||
pub async fn get_system_status(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
let (database_ok, redis_ok) = match get_system_status(&state).await {
|
||||
Ok(status) => (true, status.is_healthy()),
|
||||
Err(e) => {
|
||||
error!("Failed to get system status: {}", e);
|
||||
(false, false)
|
||||
}
|
||||
};
|
||||
|
||||
let response = SystemHealth {
|
||||
database: database_ok,
|
||||
redis: redis_ok,
|
||||
services: vec![],
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
// Helper function to get dashboard invitations
|
||||
async fn get_dashboard_invitations(
|
||||
state: &AppState,
|
||||
bot_id: Uuid,
|
||||
limit: i64,
|
||||
) -> Result<i64, diesel::result::Error> {
|
||||
// TODO: Use organization_invitations table when available in model maps
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
/// Get system metrics
|
||||
pub async fn get_system_metrics(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
// Get CPU usage
|
||||
let cpu_usage = sys_info::get_system_cpu_usage();
|
||||
let cpu_usage_percent = if cpu_usage > 0.0 {
|
||||
(cpu_usage / sys_info::get_system_cpu_count() as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Get memory usage
|
||||
let mem_total = sys_info::get_total_memory_mb();
|
||||
let mem_used = sys_info::get_used_memory_mb();
|
||||
let mem_percent = if mem_total > 0 {
|
||||
((mem_total - mem_used) as f64 / mem_total as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Get disk usage
|
||||
let disk_total = sys_info::get_total_disk_space_gb();
|
||||
let disk_used = sys_info::get_used_disk_space_gb();
|
||||
let disk_percent = if disk_total > 0.0 {
|
||||
((disk_total - disk_used) as f64 / disk_total as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let services = vec![
|
||||
ServiceStatus {
|
||||
name: "database".to_string(),
|
||||
status: if database_ok { "running" } else { "stopped" }.to_string(),
|
||||
uptime_seconds: 0,
|
||||
},
|
||||
ServiceStatus {
|
||||
name: "redis".to_string(),
|
||||
status: if redis_ok { "running" } else { "stopped" }.to_string(),
|
||||
uptime_seconds: 0,
|
||||
},
|
||||
];
|
||||
|
||||
let metrics = SystemMetricsResponse {
|
||||
cpu_usage,
|
||||
memory_total_mb: mem_total,
|
||||
memory_used_mb: mem_used,
|
||||
memory_percent: mem_percent,
|
||||
disk_total_gb: disk_total,
|
||||
disk_used_gb: disk_used,
|
||||
disk_percent: disk_percent,
|
||||
network_in_mbps: 0.0,
|
||||
network_out_mbps: 0.0,
|
||||
active_connections: 0,
|
||||
request_rate_per_minute: 0,
|
||||
error_rate_percent: 0.0,
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(metrics)).into_response()
|
||||
}
|
||||
|
||||
/// Get user statistics
|
||||
pub async fn get_stats_users(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::users;
|
||||
|
||||
let count = users::table
|
||||
.count()
|
||||
.get_result(&state.conn)
|
||||
.map_err(|e| format!("Failed to get user count: {}", e))?;
|
||||
|
||||
let response = vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Total Users".to_string(),
|
||||
count: count as i64,
|
||||
},
|
||||
];
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get group statistics
|
||||
pub async fn get_stats_groups(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::bot_groups;
|
||||
|
||||
let count = bot_groups::table
|
||||
.count()
|
||||
.get_result(&state.conn)
|
||||
.map_err(|e| format!("Failed to get group count: {}", e))?;
|
||||
|
||||
let response = vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Total Groups".to_string(),
|
||||
count: count as i64,
|
||||
},
|
||||
];
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get bot statistics
|
||||
pub async fn get_stats_bots(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
// Helper function to get dashboard bots
|
||||
async fn get_dashboard_bots(
|
||||
state: &AppState,
|
||||
bot_id: Uuid,
|
||||
limit: i64,
|
||||
) -> Result<Vec<BotStat>, diesel::result::Error> {
|
||||
use crate::core::shared::models::schema::bots;
|
||||
|
||||
let count = bots::table
|
||||
.count()
|
||||
.get_result(&state.conn)
|
||||
.map_err(|e| format!("Failed to get bot count: {}", e))?;
|
||||
let bot_list = bots::table
|
||||
.limit(limit)
|
||||
.load::<crate::core::shared::models::Bot>(&state.conn)?;
|
||||
|
||||
let response = vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Total Bots".to_string(),
|
||||
count: count as i64,
|
||||
},
|
||||
];
|
||||
let stats = bot_list.into_iter().map(|b| BotStat {
|
||||
id: b.id,
|
||||
name: b.name,
|
||||
count: 1, // Placeholder
|
||||
}).collect();
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
/// Get storage statistics
|
||||
pub async fn get_stats_storage(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::storage_usage;
|
||||
|
||||
let usage = storage_usage::table
|
||||
.limit(100)
|
||||
.order_by(crate::core::shared::models::schema::storage_usage::timestamp.desc())
|
||||
.load(&state.conn)
|
||||
.map_err(|e| format!("Failed to get storage stats: {}", e))?;
|
||||
|
||||
let total_gb = usage.iter().map(|u| u.total_gb.unwrap_or(0.0)).sum::<f64>();
|
||||
let used_gb = usage.iter().map(|u| u.used_gb.unwrap_or(0.0)).sum::<f64>();
|
||||
let percent = if total_gb > 0.0 { (used_gb / total_gb * 100.0) } else { 0.0 };
|
||||
|
||||
let response = StorageStat {
|
||||
total_gb: total_gb.round(),
|
||||
used_gb: used_gb.round(),
|
||||
percent: (percent * 100.0).round(),
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
// Helper function to get dashboard activity
|
||||
async fn get_dashboard_activity(
|
||||
state: &AppState,
|
||||
limit: Option<i64>,
|
||||
) -> Result<Vec<ActivityLog>, diesel::result::Error> {
|
||||
// Placeholder
|
||||
Ok(vec![])
|
||||
}
|
||||
|
|
|
|||
270
src/core/shared/admin_handlers.rs.bak
Normal file
270
src/core/shared/admin_handlers.rs.bak
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
use super::admin_types::*;
|
||||
use crate::core::shared::state::AppState;
|
||||
use crate::core::urls::ApiUrls;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
routing::{get, post},
|
||||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel::sql_types::{Text, Nullable};
|
||||
use log::{error, info};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Get admin dashboard data
|
||||
pub async fn get_admin_dashboard(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(bot_id): Path<Uuid>,
|
||||
) -> impl IntoResponse {
|
||||
let bot_id = bot_id.into_inner();
|
||||
|
||||
// Get system status
|
||||
let (database_ok, redis_ok) = match get_system_status(&state).await {
|
||||
Ok(status) => (true, status.is_healthy()),
|
||||
Err(e) => {
|
||||
error!("Failed to get system status: {}", e);
|
||||
(false, false)
|
||||
}
|
||||
};
|
||||
|
||||
// Get user count
|
||||
let user_count = get_stats_users(&state).await.unwrap_or(0);
|
||||
let group_count = get_stats_groups(&state).await.unwrap_or(0);
|
||||
let bot_count = get_stats_bots(&state).await.unwrap_or(0);
|
||||
|
||||
// Get storage stats
|
||||
let storage_stats = get_stats_storage(&state).await.unwrap_or_else(|| StorageStat {
|
||||
total_gb: 0,
|
||||
used_gb: 0,
|
||||
percent: 0.0,
|
||||
});
|
||||
|
||||
// Get recent activities
|
||||
let activities = get_dashboard_activity(&state, Some(20))
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Get member/bot/invitation stats
|
||||
let member_count = get_dashboard_members(&state, bot_id, 50)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let bot_list = get_dashboard_bots(&state, bot_id, 50)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let invitation_count = get_dashboard_invitations(&state, bot_id, 50)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let dashboard_data = AdminDashboardData {
|
||||
users: vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Users".to_string(),
|
||||
count: user_count as i64,
|
||||
},
|
||||
GroupStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Groups".to_string(),
|
||||
count: group_count as i64,
|
||||
},
|
||||
BotStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Bots".to_string(),
|
||||
count: bot_count as i64,
|
||||
},
|
||||
],
|
||||
groups,
|
||||
bots: bot_list,
|
||||
storage: storage_stats,
|
||||
activities,
|
||||
invitations: vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Members".to_string(),
|
||||
count: member_count as i64,
|
||||
},
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Invitations".to_string(),
|
||||
count: invitation_count as i64,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(dashboard_data)).into_response()
|
||||
}
|
||||
|
||||
/// Get system health status
|
||||
pub async fn get_system_status(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
let (database_ok, redis_ok) = match get_system_status(&state).await {
|
||||
Ok(status) => (true, status.is_healthy()),
|
||||
Err(e) => {
|
||||
error!("Failed to get system status: {}", e);
|
||||
(false, false)
|
||||
}
|
||||
};
|
||||
|
||||
let response = SystemHealth {
|
||||
database: database_ok,
|
||||
redis: redis_ok,
|
||||
services: vec![],
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get system metrics
|
||||
pub async fn get_system_metrics(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
// Get CPU usage
|
||||
let cpu_usage = sys_info::get_system_cpu_usage();
|
||||
let cpu_usage_percent = if cpu_usage > 0.0 {
|
||||
(cpu_usage / sys_info::get_system_cpu_count() as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Get memory usage
|
||||
let mem_total = sys_info::get_total_memory_mb();
|
||||
let mem_used = sys_info::get_used_memory_mb();
|
||||
let mem_percent = if mem_total > 0 {
|
||||
((mem_total - mem_used) as f64 / mem_total as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Get disk usage
|
||||
let disk_total = sys_info::get_total_disk_space_gb();
|
||||
let disk_used = sys_info::get_used_disk_space_gb();
|
||||
let disk_percent = if disk_total > 0.0 {
|
||||
((disk_total - disk_used) as f64 / disk_total as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let services = vec![
|
||||
ServiceStatus {
|
||||
name: "database".to_string(),
|
||||
status: if database_ok { "running" } else { "stopped" }.to_string(),
|
||||
uptime_seconds: 0,
|
||||
},
|
||||
ServiceStatus {
|
||||
name: "redis".to_string(),
|
||||
status: if redis_ok { "running" } else { "stopped" }.to_string(),
|
||||
uptime_seconds: 0,
|
||||
},
|
||||
];
|
||||
|
||||
let metrics = SystemMetricsResponse {
|
||||
cpu_usage,
|
||||
memory_total_mb: mem_total,
|
||||
memory_used_mb: mem_used,
|
||||
memory_percent: mem_percent,
|
||||
disk_total_gb: disk_total,
|
||||
disk_used_gb: disk_used,
|
||||
disk_percent: disk_percent,
|
||||
network_in_mbps: 0.0,
|
||||
network_out_mbps: 0.0,
|
||||
active_connections: 0,
|
||||
request_rate_per_minute: 0,
|
||||
error_rate_percent: 0.0,
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(metrics)).into_response()
|
||||
}
|
||||
|
||||
/// Get user statistics
|
||||
pub async fn get_stats_users(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::users;
|
||||
|
||||
let count = users::table
|
||||
.count()
|
||||
.get_result(&state.conn)
|
||||
.map_err(|e| format!("Failed to get user count: {}", e))?;
|
||||
|
||||
let response = vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Total Users".to_string(),
|
||||
count: count as i64,
|
||||
},
|
||||
];
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get group statistics
|
||||
pub async fn get_stats_groups(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::bot_groups;
|
||||
|
||||
let count = bot_groups::table
|
||||
.count()
|
||||
.get_result(&state.conn)
|
||||
.map_err(|e| format!("Failed to get group count: {}", e))?;
|
||||
|
||||
let response = vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Total Groups".to_string(),
|
||||
count: count as i64,
|
||||
},
|
||||
];
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get bot statistics
|
||||
pub async fn get_stats_bots(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::bots;
|
||||
|
||||
let count = bots::table
|
||||
.count()
|
||||
.get_result(&state.conn)
|
||||
.map_err(|e| format!("Failed to get bot count: {}", e))?;
|
||||
|
||||
let response = vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Total Bots".to_string(),
|
||||
count: count as i64,
|
||||
},
|
||||
];
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get storage statistics
|
||||
pub async fn get_stats_storage(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::storage_usage;
|
||||
|
||||
let usage = storage_usage::table
|
||||
.limit(100)
|
||||
.order_by(crate::core::shared::models::schema::storage_usage::timestamp.desc())
|
||||
.load(&state.conn)
|
||||
.map_err(|e| format!("Failed to get storage stats: {}", e))?;
|
||||
|
||||
let total_gb = usage.iter().map(|u| u.total_gb.unwrap_or(0.0)).sum::<f64>();
|
||||
let used_gb = usage.iter().map(|u| u.used_gb.unwrap_or(0.0)).sum::<f64>();
|
||||
let percent = if total_gb > 0.0 { (used_gb / total_gb * 100.0) } else { 0.0 };
|
||||
|
||||
let response = StorageStat {
|
||||
total_gb: total_gb.round(),
|
||||
used_gb: used_gb.round(),
|
||||
percent: (percent * 100.0).round(),
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
321
src/core/shared/admin_handlers.rs.new
Normal file
321
src/core/shared/admin_handlers.rs.new
Normal file
|
|
@ -0,0 +1,321 @@
|
|||
use super::admin_types::*;
|
||||
use crate::core::shared::state::AppState;
|
||||
use crate::core::urls::ApiUrls;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
routing::{get, post},
|
||||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel::sql_types::{Text, Nullable};
|
||||
use log::{error, info};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Get admin dashboard data
|
||||
pub async fn get_admin_dashboard(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(bot_id): Path<Uuid>,
|
||||
) -> impl IntoResponse {
|
||||
let bot_id = bot_id.into_inner();
|
||||
|
||||
// Get system status
|
||||
let (database_ok, redis_ok) = match get_system_status(&state).await {
|
||||
Ok(status) => (true, status.is_healthy()),
|
||||
Err(e) => {
|
||||
error!("Failed to get system status: {}", e);
|
||||
(false, false)
|
||||
}
|
||||
};
|
||||
|
||||
// Get user count
|
||||
let user_count = get_stats_users(&state).await.unwrap_or(0);
|
||||
let group_count = get_stats_groups(&state).await.unwrap_or(0);
|
||||
let bot_count = get_stats_bots(&state).await.unwrap_or(0);
|
||||
|
||||
// Get storage stats
|
||||
let storage_stats = get_stats_storage(&state).await.unwrap_or_else(|| StorageStat {
|
||||
total_gb: 0,
|
||||
used_gb: 0,
|
||||
percent: 0.0,
|
||||
});
|
||||
|
||||
// Get recent activities
|
||||
let activities = get_dashboard_activity(&state, Some(20))
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Get member/bot/invitation stats
|
||||
let member_count = get_dashboard_members(&state, bot_id, 50)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let bot_list = get_dashboard_bots(&state, bot_id, 50)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let invitation_count = get_dashboard_invitations(&state, bot_id, 50)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let dashboard_data = AdminDashboardData {
|
||||
users: vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Users".to_string(),
|
||||
count: user_count as i64,
|
||||
},
|
||||
GroupStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Groups".to_string(),
|
||||
count: group_count as i64,
|
||||
},
|
||||
BotStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Bots".to_string(),
|
||||
count: bot_count as i64,
|
||||
},
|
||||
],
|
||||
groups,
|
||||
bots: bot_list,
|
||||
storage: storage_stats,
|
||||
activities,
|
||||
invitations: vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Members".to_string(),
|
||||
count: member_count as i64,
|
||||
},
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Invitations".to_string(),
|
||||
count: invitation_count as i64,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(dashboard_data)).into_response()
|
||||
}
|
||||
|
||||
/// Get system health status
|
||||
pub async fn get_system_status(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
let (database_ok, redis_ok) = match get_system_status(&state).await {
|
||||
Ok(status) => (true, status.is_healthy()),
|
||||
Err(e) => {
|
||||
error!("Failed to get system status: {}", e);
|
||||
(false, false)
|
||||
}
|
||||
};
|
||||
|
||||
let response = SystemHealth {
|
||||
database: database_ok,
|
||||
redis: redis_ok,
|
||||
services: vec![],
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get system metrics
|
||||
pub async fn get_system_metrics(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
// Get CPU usage
|
||||
let cpu_usage = sys_info::get_system_cpu_usage();
|
||||
let cpu_usage_percent = if cpu_usage > 0.0 {
|
||||
(cpu_usage / sys_info::get_system_cpu_count() as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Get memory usage
|
||||
let mem_total = sys_info::get_total_memory_mb();
|
||||
let mem_used = sys_info::get_used_memory_mb();
|
||||
let mem_percent = if mem_total > 0 {
|
||||
((mem_total - mem_used) as f64 / mem_total as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Get disk usage
|
||||
let disk_total = sys_info::get_total_disk_space_gb();
|
||||
let disk_used = sys_info::get_used_disk_space_gb();
|
||||
let disk_percent = if disk_total > 0.0 {
|
||||
((disk_total - disk_used) as f64 / disk_total as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let services = vec![
|
||||
ServiceStatus {
|
||||
name: "database".to_string(),
|
||||
status: if database_ok { "running" } else { "stopped" }.to_string(),
|
||||
uptime_seconds: 0,
|
||||
},
|
||||
ServiceStatus {
|
||||
name: "redis".to_string(),
|
||||
status: if redis_ok { "running" } else { "stopped" }.to_string(),
|
||||
uptime_seconds: 0,
|
||||
},
|
||||
];
|
||||
|
||||
let metrics = SystemMetricsResponse {
|
||||
cpu_usage,
|
||||
memory_total_mb: mem_total,
|
||||
memory_used_mb: mem_used,
|
||||
memory_percent: mem_percent,
|
||||
disk_total_gb: disk_total,
|
||||
disk_used_gb: disk_used,
|
||||
disk_percent: disk_percent,
|
||||
network_in_mbps: 0.0,
|
||||
network_out_mbps: 0.0,
|
||||
active_connections: 0,
|
||||
request_rate_per_minute: 0,
|
||||
error_rate_percent: 0.0,
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(metrics)).into_response()
|
||||
}
|
||||
|
||||
/// Get user statistics
|
||||
pub async fn get_stats_users(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::users;
|
||||
|
||||
let count = users::table
|
||||
.count()
|
||||
.get_result(&state.conn)
|
||||
.map_err(|e| format!("Failed to get user count: {}", e))?;
|
||||
|
||||
let response = vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Total Users".to_string(),
|
||||
count: count as i64,
|
||||
},
|
||||
];
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get group statistics
|
||||
pub async fn get_stats_groups(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::bot_groups;
|
||||
|
||||
let count = bot_groups::table
|
||||
.count()
|
||||
.get_result(&state.conn)
|
||||
.map_err(|e| format!("Failed to get group count: {}", e))?;
|
||||
|
||||
let response = vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Total Groups".to_string(),
|
||||
count: count as i64,
|
||||
},
|
||||
];
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get bot statistics
|
||||
pub async fn get_stats_bots(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::bots;
|
||||
|
||||
let count = bots::table
|
||||
.count()
|
||||
.get_result(&state.conn)
|
||||
.map_err(|e| format!("Failed to get bot count: {}", e))?;
|
||||
|
||||
let response = vec![
|
||||
UserStat {
|
||||
id: Uuid::new_v4(),
|
||||
name: "Total Bots".to_string(),
|
||||
count: count as i64,
|
||||
},
|
||||
];
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
/// Get storage statistics
|
||||
pub async fn get_stats_storage(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::storage_usage;
|
||||
|
||||
let usage = storage_usage::table
|
||||
.limit(100)
|
||||
.order_by(crate::core::shared::models::schema::storage_usage::timestamp.desc())
|
||||
.load(&state.conn)
|
||||
.map_err(|e| format!("Failed to get storage stats: {}", e))?;
|
||||
|
||||
let total_gb = usage.iter().map(|u| u.total_gb.unwrap_or(0.0)).sum::<f64>();
|
||||
let used_gb = usage.iter().map(|u| u.used_gb.unwrap_or(0.0)).sum::<f64>();
|
||||
let percent = if total_gb > 0.0 { (used_gb / total_gb * 100.0) } else { 0.0 };
|
||||
|
||||
let response = StorageStat {
|
||||
total_gb: total_gb.round(),
|
||||
used_gb: used_gb.round(),
|
||||
percent: (percent * 100.0).round(),
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
|
||||
// Helper function to get dashboard members
|
||||
async fn get_dashboard_members(
|
||||
state: &AppState,
|
||||
bot_id: Uuid,
|
||||
limit: i64,
|
||||
) -> Result<i64, diesel::result::Error> {
|
||||
// TODO: Implement actual member fetching logic
|
||||
// For now, return a placeholder count
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
// Helper function to get dashboard invitations
|
||||
async fn get_dashboard_invitations(
|
||||
state: &AppState,
|
||||
bot_id: Uuid,
|
||||
limit: i64,
|
||||
) -> Result<i64, diesel::result::Error> {
|
||||
// TODO: Use organization_invitations table when available in model maps
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
// Helper function to get dashboard bots
|
||||
async fn get_dashboard_bots(
|
||||
state: &AppState,
|
||||
bot_id: Uuid,
|
||||
limit: i64,
|
||||
) -> Result<Vec<BotStat>, diesel::result::Error> {
|
||||
use crate::core::shared::models::schema::bots;
|
||||
|
||||
let bot_list = bots::table
|
||||
.limit(limit)
|
||||
.load::<crate::core::shared::models::Bot>(&state.conn)?;
|
||||
|
||||
let stats = bot_list.into_iter().map(|b| BotStat {
|
||||
id: b.id,
|
||||
name: b.name,
|
||||
count: 1, // Placeholder
|
||||
}).collect();
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
// Helper function to get dashboard activity
|
||||
async fn get_dashboard_activity(
|
||||
state: &AppState,
|
||||
limit: Option<i64>,
|
||||
) -> Result<Vec<ActivityLog>, diesel::result::Error> {
|
||||
// Placeholder
|
||||
Ok(vec![])
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
// Admin invitation management functions
|
||||
use super::admin_types::*;
|
||||
use crate::core::shared::models::core::OrganizationInvitation;
|
||||
use crate::core::shared::state::AppState;
|
||||
use crate::core::urls::ApiUrls;
|
||||
use axum::{
|
||||
|
|
@ -7,113 +7,382 @@ use axum::{
|
|||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use chrono::Utc;
|
||||
use chrono::{Duration, Utc};
|
||||
use diesel::prelude::*;
|
||||
use log::{error, info, warn};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// List all invitations
|
||||
pub async fn list_invitations(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
// TODO: Implement when invitations table is available in schema
|
||||
warn!("list_invitations called - not fully implemented");
|
||||
(StatusCode::OK, Json(BulkInvitationResponse { invitations: vec![] })).into_response()
|
||||
use crate::core::shared::models::schema::organization_invitations::dsl::*;
|
||||
|
||||
let mut conn = match state.pool.get() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get database connection: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Database connection failed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let results = organization_invitations
|
||||
.filter(status.eq("pending"))
|
||||
.filter(expires_at.gt(Utc::now()))
|
||||
.order_by(created_at.desc())
|
||||
.load::<OrganizationInvitation>(&mut conn);
|
||||
|
||||
match results {
|
||||
Ok(invites) => {
|
||||
let responses: Vec<InvitationResponse> = invites
|
||||
.into_iter()
|
||||
.map(|inv| InvitationResponse {
|
||||
id: inv.id,
|
||||
email: inv.email,
|
||||
role: inv.role,
|
||||
message: inv.message,
|
||||
created_at: inv.created_at,
|
||||
token: inv.token,
|
||||
})
|
||||
.collect();
|
||||
|
||||
(StatusCode::OK, Json(BulkInvitationResponse { invitations: responses })).into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to list invitations: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Failed to list invitations"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a single invitation
|
||||
pub async fn create_invitation(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(bot_id): Path<Uuid>,
|
||||
Json(request): Json<CreateInvitationRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let _bot_id = bot_id.into_inner();
|
||||
use crate::core::shared::models::schema::organization_invitations::dsl::*;
|
||||
|
||||
let _bot_id = bot_id;
|
||||
let invitation_id = Uuid::new_v4();
|
||||
let token = invitation_id.to_string();
|
||||
let _accept_url = format!("{}/accept-invitation?token={}", ApiUrls::get_app_url(), token);
|
||||
let token = format!("{}{}", invitation_id, Uuid::new_v4());
|
||||
let expires_at = Utc::now() + Duration::days(7);
|
||||
let accept_url = format!("{}/accept-invitation?token={}", ApiUrls::get_app_url(), token);
|
||||
|
||||
let _body = format!(
|
||||
r#"You have been invited to join our organization as a {}.
|
||||
|
||||
Click on link below to accept the invitation:
|
||||
{}
|
||||
|
||||
This invitation will expire in 7 days."#,
|
||||
request.role, _accept_url
|
||||
let body = format!(
|
||||
"You have been invited to join our organization as a {}.\n\nClick on link below to accept the invitation:\n{}\n\nThis invitation will expire in 7 days.",
|
||||
request.role, accept_url
|
||||
);
|
||||
|
||||
// TODO: Save to database when invitations table is available
|
||||
info!("Creating invitation for {} with role {}", request.email, request.role);
|
||||
let mut conn = match state.pool.get() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get database connection: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Database connection failed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
(StatusCode::OK, Json(InvitationResponse {
|
||||
let new_invitation = OrganizationInvitation {
|
||||
id: invitation_id,
|
||||
org_id: Uuid::new_v4(),
|
||||
email: request.email.clone(),
|
||||
role: request.role.clone(),
|
||||
status: "pending".to_string(),
|
||||
message: request.custom_message.clone(),
|
||||
invited_by: Uuid::new_v4(),
|
||||
token: Some(token.clone()),
|
||||
created_at: Utc::now(),
|
||||
token: Some(token),
|
||||
}).into_response())
|
||||
updated_at: Some(Utc::now()),
|
||||
expires_at: Some(expires_at),
|
||||
accepted_at: None,
|
||||
accepted_by: None,
|
||||
};
|
||||
|
||||
match diesel::insert_into(organization_invitations)
|
||||
.values(&new_invitation)
|
||||
.execute(&mut conn)
|
||||
{
|
||||
Ok(_) => {
|
||||
info!("Created invitation for {} with role {}", request.email, request.role);
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(InvitationResponse {
|
||||
id: invitation_id,
|
||||
email: request.email,
|
||||
role: request.role,
|
||||
message: request.custom_message,
|
||||
created_at: Utc::now(),
|
||||
token: Some(token),
|
||||
}),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to create invitation: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Failed to create invitation"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create bulk invitations
|
||||
pub async fn create_bulk_invitations(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(request): Json<BulkInvitationRequest>,
|
||||
) -> impl IntoResponse {
|
||||
use crate::core::shared::models::schema::organization_invitations::dsl::*;
|
||||
|
||||
info!("Creating {} bulk invitations", request.emails.len());
|
||||
|
||||
let mut conn = match state.pool.get() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get database connection: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Database connection failed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut responses = Vec::new();
|
||||
|
||||
for email in &request.emails {
|
||||
let invitation_id = Uuid::new_v4();
|
||||
let token = invitation_id.to_string();
|
||||
let _accept_url = format!("{}/accept-invitation?token={}", ApiUrls::get_app_url(), token);
|
||||
let token = format!("{}{}", invitation_id, Uuid::new_v4());
|
||||
let expires_at = Utc::now() + Duration::days(7);
|
||||
|
||||
// TODO: Save to database when invitations table is available
|
||||
info!("Creating invitation for {} with role {}", email, request.role);
|
||||
|
||||
responses.push(InvitationResponse {
|
||||
let new_invitation = OrganizationInvitation {
|
||||
id: invitation_id,
|
||||
org_id: Uuid::new_v4(),
|
||||
email: email.clone(),
|
||||
role: request.role.clone(),
|
||||
status: "pending".to_string(),
|
||||
message: request.custom_message.clone(),
|
||||
invited_by: Uuid::new_v4(),
|
||||
token: Some(token.clone()),
|
||||
created_at: Utc::now(),
|
||||
token: Some(token),
|
||||
});
|
||||
updated_at: Some(Utc::now()),
|
||||
expires_at: Some(expires_at),
|
||||
accepted_at: None,
|
||||
accepted_by: None,
|
||||
};
|
||||
|
||||
match diesel::insert_into(organization_invitations)
|
||||
.values(&new_invitation)
|
||||
.execute(&mut conn)
|
||||
{
|
||||
Ok(_) => {
|
||||
info!("Created invitation for {} with role {}", email, request.role);
|
||||
responses.push(InvitationResponse {
|
||||
id: invitation_id,
|
||||
email: email.clone(),
|
||||
role: request.role.clone(),
|
||||
message: request.custom_message.clone(),
|
||||
created_at: Utc::now(),
|
||||
token: Some(token),
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to create invitation for {}: {}", email, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(StatusCode::OK, Json(BulkInvitationResponse { invitations: responses })).into_response()
|
||||
}
|
||||
|
||||
/// Get invitation details
|
||||
pub async fn get_invitation(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> impl IntoResponse {
|
||||
// TODO: Implement when invitations table is available
|
||||
warn!("get_invitation called for {} - not fully implemented", id);
|
||||
(StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Invitation not found"})).into_response())
|
||||
use crate::core::shared::models::schema::organization_invitations::dsl::*;
|
||||
|
||||
let mut conn = match state.pool.get() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get database connection: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Database connection failed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
match organization_invitations
|
||||
.filter(id.eq(id))
|
||||
.first::<OrganizationInvitation>(&mut conn)
|
||||
{
|
||||
Ok(invitation) => {
|
||||
let response = InvitationResponse {
|
||||
id: invitation.id,
|
||||
email: invitation.email,
|
||||
role: invitation.role,
|
||||
message: invitation.message,
|
||||
created_at: invitation.created_at,
|
||||
token: invitation.token,
|
||||
};
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
Err(diesel::result::Error::NotFound) => {
|
||||
warn!("Invitation not found: {}", id);
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Invitation not found"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to get invitation: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Failed to get invitation"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancel invitation
|
||||
pub async fn cancel_invitation(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> impl IntoResponse {
|
||||
let _id = id.into_inner();
|
||||
// TODO: Implement when invitations table is available
|
||||
info!("cancel_invitation called for {} - not fully implemented", id);
|
||||
(StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Invitation not found"}).into_response()))
|
||||
use crate::core::shared::models::schema::organization_invitations::dsl::*;
|
||||
|
||||
let mut conn = match state.pool.get() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get database connection: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Database connection failed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
match diesel::update(organization_invitations.filter(id.eq(id)))
|
||||
.set((
|
||||
status.eq("cancelled"),
|
||||
updated_at.eq(Utc::now()),
|
||||
))
|
||||
.execute(&mut conn)
|
||||
{
|
||||
Ok(0) => {
|
||||
warn!("Invitation not found for cancellation: {}", id);
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Invitation not found"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Ok(_) => {
|
||||
info!("Cancelled invitation: {}", id);
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({"success": true, "message": "Invitation cancelled"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to cancel invitation: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Failed to cancel invitation"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resend invitation
|
||||
pub async fn resend_invitation(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> impl IntoResponse {
|
||||
let _id = id.into_inner();
|
||||
// TODO: Implement when invitations table is available
|
||||
info!("resend_invitation called for {} - not fully implemented", id);
|
||||
(StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Invitation not found"}).into_response()))
|
||||
use crate::core::shared::models::schema::organization_invitations::dsl::*;
|
||||
|
||||
let mut conn = match state.pool.get() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get database connection: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Database connection failed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
match organization_invitations
|
||||
.filter(id.eq(id))
|
||||
.first::<OrganizationInvitation>(&mut conn)
|
||||
{
|
||||
Ok(invitation) => {
|
||||
if invitation.status != "pending" {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"error": "Invitation is not pending"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let new_expires_at = Utc::now() + Duration::days(7);
|
||||
|
||||
match diesel::update(organization_invitations.filter(id.eq(id)))
|
||||
.set((
|
||||
updated_at.eq(Utc::now()),
|
||||
expires_at.eq(new_expires_at),
|
||||
))
|
||||
.execute(&mut conn)
|
||||
{
|
||||
Ok(_) => {
|
||||
info!("Resent invitation: {}", id);
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({"success": true, "message": "Invitation resent"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to resend invitation: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Failed to resend invitation"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(diesel::result::Error::NotFound) => {
|
||||
warn!("Invitation not found for resending: {}", id);
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Invitation not found"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to get invitation for resending: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": "Failed to get invitation"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -153,6 +153,7 @@ pub struct UserLoginToken {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_preferences)]
|
||||
pub struct UserPreference {
|
||||
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub preference_key: String,
|
||||
|
|
@ -162,10 +163,28 @@ pub struct UserPreference {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = clicks)]
|
||||
#[diesel(table_name = clicks)]
|
||||
pub struct Click {
|
||||
pub id: Uuid,
|
||||
pub campaign_id: String,
|
||||
pub email: String,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = crate::core::shared::models::schema::organization_invitations)]
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
|
|
|||
216
src/core/shared/models/core.rs.bad
Normal file
216
src/core/shared/models/core.rs.bad
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use diesel::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::shared::models::schema::{
|
||||
bot_configuration, bot_memories, bots, clicks, message_history, organizations,
|
||||
system_automations, user_login_tokens, user_preferences, user_sessions, users,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TriggerKind {
|
||||
Scheduled = 0,
|
||||
TableUpdate = 1,
|
||||
TableInsert = 2,
|
||||
TableDelete = 3,
|
||||
Webhook = 4,
|
||||
EmailReceived = 5,
|
||||
FolderChange = 6,
|
||||
}
|
||||
|
||||
impl TriggerKind {
|
||||
pub fn from_i32(value: i32) -> Option<Self> {
|
||||
match value {
|
||||
0 => Some(Self::Scheduled),
|
||||
1 => Some(Self::TableUpdate),
|
||||
2 => Some(Self::TableInsert),
|
||||
3 => Some(Self::TableDelete),
|
||||
4 => Some(Self::Webhook),
|
||||
5 => Some(Self::EmailReceived),
|
||||
6 => Some(Self::FolderChange),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Queryable, Serialize, Deserialize, Identifiable)]
|
||||
#[diesel(table_name = system_automations)]
|
||||
pub struct Automation {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub kind: i32,
|
||||
pub target: Option<String>,
|
||||
pub schedule: Option<String>,
|
||||
pub param: String,
|
||||
pub is_active: bool,
|
||||
pub last_triggered: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Selectable)]
|
||||
#[diesel(table_name = user_sessions)]
|
||||
pub struct UserSession {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub title: String,
|
||||
pub context_data: serde_json::Value,
|
||||
pub current_tool: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Insertable)]
|
||||
#[diesel(table_name = bot_memories)]
|
||||
pub struct BotMemory {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = users)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub password_hash: String,
|
||||
pub is_active: bool,
|
||||
pub is_admin: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bots)]
|
||||
pub struct Bot {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub llm_provider: String,
|
||||
pub llm_config: serde_json::Value,
|
||||
pub context_provider: String,
|
||||
pub context_config: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub is_active: Option<bool>,
|
||||
pub tenant_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)]
|
||||
#[diesel(primary_key(org_id))]
|
||||
pub struct Organization {
|
||||
pub org_id: Uuid,
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = message_history)]
|
||||
pub struct MessageHistory {
|
||||
pub id: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub role: i32,
|
||||
pub content_encrypted: String,
|
||||
pub message_type: i32,
|
||||
pub message_index: i64,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bot_configuration)]
|
||||
pub struct BotConfiguration {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub config_key: String,
|
||||
pub config_value: String,
|
||||
pub is_encrypted: bool,
|
||||
pub config_type: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_login_tokens)]
|
||||
pub struct UserLoginToken {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub token_hash: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_used: DateTime<Utc>,
|
||||
pub user_agent: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub is_active: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_preferences)]
|
||||
pub struct UserPreference {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub preference_key: String,
|
||||
pub preference_value: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = clicks)]
|
||||
pub struct Click {
|
||||
pub id: Uuid,
|
||||
pub campaign_id: String,
|
||||
pub email: String,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)] // Correct reference
|
||||
#[diesel(primary_key(id))] // Correct primary key? No, core struct says org_id.
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = crate::core::shared::models::schema::organization_invitations)]
|
||||
pub struct OrganizationInvitation {
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)] // Wrong table name reference in previous attempt
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
205
src/core/shared/models/core.rs.bad2
Normal file
205
src/core/shared/models/core.rs.bad2
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use diesel::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::shared::models::schema::{
|
||||
bot_configuration, bot_memories, bots, clicks, message_history, organizations,
|
||||
system_automations, user_login_tokens, user_preferences, user_sessions, users,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TriggerKind {
|
||||
Scheduled = 0,
|
||||
TableUpdate = 1,
|
||||
TableInsert = 2,
|
||||
TableDelete = 3,
|
||||
Webhook = 4,
|
||||
EmailReceived = 5,
|
||||
FolderChange = 6,
|
||||
}
|
||||
|
||||
impl TriggerKind {
|
||||
pub fn from_i32(value: i32) -> Option<Self> {
|
||||
match value {
|
||||
0 => Some(Self::Scheduled),
|
||||
1 => Some(Self::TableUpdate),
|
||||
2 => Some(Self::TableInsert),
|
||||
3 => Some(Self::TableDelete),
|
||||
4 => Some(Self::Webhook),
|
||||
5 => Some(Self::EmailReceived),
|
||||
6 => Some(Self::FolderChange),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Queryable, Serialize, Deserialize, Identifiable)]
|
||||
#[diesel(table_name = system_automations)]
|
||||
pub struct Automation {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub kind: i32,
|
||||
pub target: Option<String>,
|
||||
pub schedule: Option<String>,
|
||||
pub param: String,
|
||||
pub is_active: bool,
|
||||
pub last_triggered: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Selectable)]
|
||||
#[diesel(table_name = user_sessions)]
|
||||
pub struct UserSession {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub title: String,
|
||||
pub context_data: serde_json::Value,
|
||||
pub current_tool: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Insertable)]
|
||||
#[diesel(table_name = bot_memories)]
|
||||
pub struct BotMemory {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = users)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub password_hash: String,
|
||||
pub is_active: bool,
|
||||
pub is_admin: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bots)]
|
||||
pub struct Bot {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub llm_provider: String,
|
||||
pub llm_config: serde_json::Value,
|
||||
pub context_provider: String,
|
||||
pub context_config: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub is_active: Option<bool>,
|
||||
pub tenant_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)]
|
||||
#[diesel(primary_key(org_id))]
|
||||
pub struct Organization {
|
||||
pub org_id: Uuid,
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = message_history)]
|
||||
pub struct MessageHistory {
|
||||
pub id: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub role: i32,
|
||||
pub content_encrypted: String,
|
||||
pub message_type: i32,
|
||||
pub message_index: i64,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bot_configuration)]
|
||||
pub struct BotConfiguration {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub config_key: String,
|
||||
pub config_value: String,
|
||||
pub is_encrypted: bool,
|
||||
pub config_type: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_login_tokens)]
|
||||
pub struct UserLoginToken {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub token_hash: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_used: DateTime<Utc>,
|
||||
pub user_agent: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub is_active: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_preferences)]
|
||||
pub struct UserPreference {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub preference_key: String,
|
||||
pub preference_value: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = clicks)]
|
||||
pub struct Click {
|
||||
pub id: Uuid,
|
||||
pub campaign_id: String,
|
||||
pub email: String,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)] // Correct reference
|
||||
#[diesel(primary_key(id))] // Correct primary key? No, core struct says org_id.
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = crate::core::shared::models::schema::organization_invitations)]
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
176
src/core/shared/models/core.rs.bak
Normal file
176
src/core/shared/models/core.rs.bak
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use diesel::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::shared::models::schema::{
|
||||
bot_configuration, bot_memories, bots, clicks, message_history, organizations,
|
||||
system_automations, user_login_tokens, user_preferences, user_sessions, users,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TriggerKind {
|
||||
Scheduled = 0,
|
||||
TableUpdate = 1,
|
||||
TableInsert = 2,
|
||||
TableDelete = 3,
|
||||
Webhook = 4,
|
||||
EmailReceived = 5,
|
||||
FolderChange = 6,
|
||||
}
|
||||
|
||||
impl TriggerKind {
|
||||
pub fn from_i32(value: i32) -> Option<Self> {
|
||||
match value {
|
||||
0 => Some(Self::Scheduled),
|
||||
1 => Some(Self::TableUpdate),
|
||||
2 => Some(Self::TableInsert),
|
||||
3 => Some(Self::TableDelete),
|
||||
4 => Some(Self::Webhook),
|
||||
5 => Some(Self::EmailReceived),
|
||||
6 => Some(Self::FolderChange),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Queryable, Serialize, Deserialize, Identifiable)]
|
||||
#[diesel(table_name = system_automations)]
|
||||
pub struct Automation {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub kind: i32,
|
||||
pub target: Option<String>,
|
||||
pub schedule: Option<String>,
|
||||
pub param: String,
|
||||
pub is_active: bool,
|
||||
pub last_triggered: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Selectable)]
|
||||
#[diesel(table_name = user_sessions)]
|
||||
pub struct UserSession {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub title: String,
|
||||
pub context_data: serde_json::Value,
|
||||
pub current_tool: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Insertable)]
|
||||
#[diesel(table_name = bot_memories)]
|
||||
pub struct BotMemory {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = users)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub password_hash: String,
|
||||
pub is_active: bool,
|
||||
pub is_admin: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bots)]
|
||||
pub struct Bot {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub llm_provider: String,
|
||||
pub llm_config: serde_json::Value,
|
||||
pub context_provider: String,
|
||||
pub context_config: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub is_active: Option<bool>,
|
||||
pub tenant_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)]
|
||||
#[diesel(primary_key(org_id))]
|
||||
pub struct Organization {
|
||||
pub org_id: Uuid,
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = message_history)]
|
||||
pub struct MessageHistory {
|
||||
pub id: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub role: i32,
|
||||
pub content_encrypted: String,
|
||||
pub message_type: i32,
|
||||
pub message_index: i64,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bot_configuration)]
|
||||
pub struct BotConfiguration {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub config_key: String,
|
||||
pub config_value: String,
|
||||
pub is_encrypted: bool,
|
||||
pub config_type: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_login_tokens)]
|
||||
pub struct UserLoginToken {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub token_hash: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_used: DateTime<Utc>,
|
||||
pub user_agent: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub is_active: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_preferences)]
|
||||
pub struct UserPreference {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = crate::core::shared::models::schema::organization_invitations)]
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
191
src/core/shared/models/core.rs.check
Normal file
191
src/core/shared/models/core.rs.check
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use diesel::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::shared::models::schema::{
|
||||
bot_configuration, bot_memories, bots, clicks, message_history, organizations,
|
||||
system_automations, user_login_tokens, user_preferences, user_sessions, users,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TriggerKind {
|
||||
Scheduled = 0,
|
||||
TableUpdate = 1,
|
||||
TableInsert = 2,
|
||||
TableDelete = 3,
|
||||
Webhook = 4,
|
||||
EmailReceived = 5,
|
||||
FolderChange = 6,
|
||||
}
|
||||
|
||||
impl TriggerKind {
|
||||
pub fn from_i32(value: i32) -> Option<Self> {
|
||||
match value {
|
||||
0 => Some(Self::Scheduled),
|
||||
1 => Some(Self::TableUpdate),
|
||||
2 => Some(Self::TableInsert),
|
||||
3 => Some(Self::TableDelete),
|
||||
4 => Some(Self::Webhook),
|
||||
5 => Some(Self::EmailReceived),
|
||||
6 => Some(Self::FolderChange),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Queryable, Serialize, Deserialize, Identifiable)]
|
||||
#[diesel(table_name = system_automations)]
|
||||
pub struct Automation {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub kind: i32,
|
||||
pub target: Option<String>,
|
||||
pub schedule: Option<String>,
|
||||
pub param: String,
|
||||
pub is_active: bool,
|
||||
pub last_triggered: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Selectable)]
|
||||
#[diesel(table_name = user_sessions)]
|
||||
pub struct UserSession {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub title: String,
|
||||
pub context_data: serde_json::Value,
|
||||
pub current_tool: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Insertable)]
|
||||
#[diesel(table_name = bot_memories)]
|
||||
pub struct BotMemory {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = users)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub password_hash: String,
|
||||
pub is_active: bool,
|
||||
pub is_admin: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bots)]
|
||||
pub struct Bot {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub llm_provider: String,
|
||||
pub llm_config: serde_json::Value,
|
||||
pub context_provider: String,
|
||||
pub context_config: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub is_active: Option<bool>,
|
||||
pub tenant_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)]
|
||||
#[diesel(primary_key(org_id))]
|
||||
pub struct Organization {
|
||||
pub org_id: Uuid,
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = message_history)]
|
||||
pub struct MessageHistory {
|
||||
pub id: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub role: i32,
|
||||
pub content_encrypted: String,
|
||||
pub message_type: i32,
|
||||
pub message_index: i64,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bot_configuration)]
|
||||
pub struct BotConfiguration {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub config_key: String,
|
||||
pub config_value: String,
|
||||
pub is_encrypted: bool,
|
||||
pub config_type: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_login_tokens)]
|
||||
pub struct UserLoginToken {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub token_hash: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_used: DateTime<Utc>,
|
||||
pub user_agent: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub is_active: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_preferences)]
|
||||
pub struct UserPreference {
|
||||
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub preference_key: String,
|
||||
pub preference_value: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = clicks)]
|
||||
pub struct Click {
|
||||
pub id: Uuid,
|
||||
pub campaign_id: String,
|
||||
pub email: String,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = crate::core::shared::models::schema::organization_invitations)]
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
234
src/core/shared/models/core.rs.fix
Normal file
234
src/core/shared/models/core.rs.fix
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use diesel::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::shared::models::schema::{
|
||||
bot_configuration, bot_memories, bots, clicks, message_history, organizations,
|
||||
system_automations, user_login_tokens, user_preferences, user_sessions, users,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TriggerKind {
|
||||
Scheduled = 0,
|
||||
TableUpdate = 1,
|
||||
TableInsert = 2,
|
||||
TableDelete = 3,
|
||||
Webhook = 4,
|
||||
EmailReceived = 5,
|
||||
FolderChange = 6,
|
||||
}
|
||||
|
||||
impl TriggerKind {
|
||||
pub fn from_i32(value: i32) -> Option<Self> {
|
||||
match value {
|
||||
0 => Some(Self::Scheduled),
|
||||
1 => Some(Self::TableUpdate),
|
||||
2 => Some(Self::TableInsert),
|
||||
3 => Some(Self::TableDelete),
|
||||
4 => Some(Self::Webhook),
|
||||
5 => Some(Self::EmailReceived),
|
||||
6 => Some(Self::FolderChange),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Queryable, Serialize, Deserialize, Identifiable)]
|
||||
#[diesel(table_name = system_automations)]
|
||||
pub struct Automation {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub kind: i32,
|
||||
pub target: Option<String>,
|
||||
pub schedule: Option<String>,
|
||||
pub param: String,
|
||||
pub is_active: bool,
|
||||
pub last_triggered: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Selectable)]
|
||||
#[diesel(table_name = user_sessions)]
|
||||
pub struct UserSession {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub title: String,
|
||||
pub context_data: serde_json::Value,
|
||||
pub current_tool: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Insertable)]
|
||||
#[diesel(table_name = bot_memories)]
|
||||
pub struct BotMemory {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = users)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub password_hash: String,
|
||||
pub is_active: bool,
|
||||
pub is_admin: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bots)]
|
||||
pub struct Bot {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub llm_provider: String,
|
||||
pub llm_config: serde_json::Value,
|
||||
pub context_provider: String,
|
||||
pub context_config: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub is_active: Option<bool>,
|
||||
pub tenant_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)]
|
||||
#[diesel(primary_key(org_id))]
|
||||
pub struct Organization {
|
||||
pub org_id: Uuid,
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = message_history)]
|
||||
pub struct MessageHistory {
|
||||
pub id: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub role: i32,
|
||||
pub content_encrypted: String,
|
||||
pub message_type: i32,
|
||||
pub message_index: i64,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bot_configuration)]
|
||||
pub struct BotConfiguration {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub config_key: String,
|
||||
pub config_value: String,
|
||||
pub is_encrypted: bool,
|
||||
pub config_type: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_login_tokens)]
|
||||
pub struct UserLoginToken {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub token_hash: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_used: DateTime<Utc>,
|
||||
pub user_agent: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub is_active: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_preferences)]
|
||||
pub struct UserPreference {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub preference_key: String,
|
||||
pub preference_value: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = clicks)]
|
||||
pub struct Click {
|
||||
pub id: Uuid,
|
||||
pub campaign_id: String,
|
||||
pub email: String,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)] // Correct reference
|
||||
#[diesel(primary_key(id))] // Correct primary key? No, core struct says org_id.
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = crate::core::shared::models::schema::organization_invitations)]
|
||||
pub struct OrganizationInvitation {
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)] // Wrong table name reference in previous attempt
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = crate::core::shared::models::schema::organization_invitations)]
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
}
|
||||
161
src/core/shared/models/core.rs.head
Normal file
161
src/core/shared/models/core.rs.head
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use diesel::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::shared::models::schema::{
|
||||
bot_configuration, bot_memories, bots, clicks, message_history, organizations,
|
||||
system_automations, user_login_tokens, user_preferences, user_sessions, users,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TriggerKind {
|
||||
Scheduled = 0,
|
||||
TableUpdate = 1,
|
||||
TableInsert = 2,
|
||||
TableDelete = 3,
|
||||
Webhook = 4,
|
||||
EmailReceived = 5,
|
||||
FolderChange = 6,
|
||||
}
|
||||
|
||||
impl TriggerKind {
|
||||
pub fn from_i32(value: i32) -> Option<Self> {
|
||||
match value {
|
||||
0 => Some(Self::Scheduled),
|
||||
1 => Some(Self::TableUpdate),
|
||||
2 => Some(Self::TableInsert),
|
||||
3 => Some(Self::TableDelete),
|
||||
4 => Some(Self::Webhook),
|
||||
5 => Some(Self::EmailReceived),
|
||||
6 => Some(Self::FolderChange),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Queryable, Serialize, Deserialize, Identifiable)]
|
||||
#[diesel(table_name = system_automations)]
|
||||
pub struct Automation {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub kind: i32,
|
||||
pub target: Option<String>,
|
||||
pub schedule: Option<String>,
|
||||
pub param: String,
|
||||
pub is_active: bool,
|
||||
pub last_triggered: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Selectable)]
|
||||
#[diesel(table_name = user_sessions)]
|
||||
pub struct UserSession {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub title: String,
|
||||
pub context_data: serde_json::Value,
|
||||
pub current_tool: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable, Insertable)]
|
||||
#[diesel(table_name = bot_memories)]
|
||||
pub struct BotMemory {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = users)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub password_hash: String,
|
||||
pub is_active: bool,
|
||||
pub is_admin: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bots)]
|
||||
pub struct Bot {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub llm_provider: String,
|
||||
pub llm_config: serde_json::Value,
|
||||
pub context_provider: String,
|
||||
pub context_config: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub is_active: Option<bool>,
|
||||
pub tenant_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = organizations)]
|
||||
#[diesel(primary_key(org_id))]
|
||||
pub struct Organization {
|
||||
pub org_id: Uuid,
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = message_history)]
|
||||
pub struct MessageHistory {
|
||||
pub id: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub role: i32,
|
||||
pub content_encrypted: String,
|
||||
pub message_type: i32,
|
||||
pub message_index: i64,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = bot_configuration)]
|
||||
pub struct BotConfiguration {
|
||||
pub id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub config_key: String,
|
||||
pub config_value: String,
|
||||
pub is_encrypted: bool,
|
||||
pub config_type: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_login_tokens)]
|
||||
pub struct UserLoginToken {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub token_hash: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_used: DateTime<Utc>,
|
||||
pub user_agent: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub is_active: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_preferences)]
|
||||
pub struct UserPreference {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = crate::core::shared::models::schema::organization_invitations)]
|
||||
pub struct OrganizationInvitation {
|
||||
39
src/core/shared/models/core.rs.new
Normal file
39
src/core/shared/models/core.rs.new
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = user_preferences)] // Closing UserPreference struct
|
||||
pub struct UserPreference {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub preference_key: String,
|
||||
pub preference_value: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = clicks)]
|
||||
pub struct Click {
|
||||
pub id: Uuid,
|
||||
pub campaign_id: String,
|
||||
pub email: String,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Queryable, Identifiable)]
|
||||
#[diesel(table_name = crate::core::shared::models::schema::organization_invitations)]
|
||||
pub struct OrganizationInvitation {
|
||||
pub id: Uuid,
|
||||
pub org_id: Uuid,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub message: Option<String>,
|
||||
pub invited_by: Uuid,
|
||||
pub token: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub accepted_at: Option<DateTime<Utc>>,
|
||||
pub accepted_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
|
|
@ -51,3 +51,7 @@ pub use super::schema::{
|
|||
|
||||
pub use botlib::message_types::MessageType;
|
||||
pub use botlib::models::{ApiResponse, Attachment, BotResponse, Session, Suggestion, UserMessage};
|
||||
|
||||
// Manually export OrganizationInvitation as it is defined in core but table is organization_invitations
|
||||
pub use self::core::OrganizationInvitation;
|
||||
|
||||
|
|
|
|||
53
src/core/shared/models/mod.rs.bak
Normal file
53
src/core/shared/models/mod.rs.bak
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
|
||||
pub mod core;
|
||||
pub use self::core::*;
|
||||
|
||||
pub mod rbac;
|
||||
pub use self::rbac::*;
|
||||
|
||||
pub mod workflow_models;
|
||||
pub use self::workflow_models::*;
|
||||
|
||||
#[cfg(feature = "tasks")]
|
||||
pub mod task_models;
|
||||
#[cfg(feature = "tasks")]
|
||||
pub use self::task_models::*;
|
||||
|
||||
pub use super::schema;
|
||||
|
||||
// Re-export core schema tables
|
||||
pub use super::schema::{
|
||||
basic_tools, bot_configuration, bot_memories, bots, clicks,
|
||||
message_history, organizations, rbac_group_roles, rbac_groups,
|
||||
rbac_permissions, rbac_role_permissions, rbac_roles, rbac_user_groups, rbac_user_roles,
|
||||
session_tool_associations, system_automations, user_login_tokens,
|
||||
user_preferences, user_sessions, users, workflow_executions, workflow_events, bot_shared_memory,
|
||||
};
|
||||
|
||||
// Re-export feature-gated schema tables
|
||||
#[cfg(feature = "tasks")]
|
||||
pub use super::schema::tasks;
|
||||
|
||||
#[cfg(feature = "mail")]
|
||||
pub use super::schema::{
|
||||
distribution_lists, email_auto_responders, email_drafts, email_folders,
|
||||
email_label_assignments, email_labels, email_rules, email_signatures,
|
||||
email_templates, global_email_signatures, scheduled_emails,
|
||||
shared_mailbox_members, shared_mailboxes, user_email_accounts,
|
||||
};
|
||||
|
||||
#[cfg(feature = "people")]
|
||||
pub use super::schema::{
|
||||
crm_accounts, crm_activities, crm_contacts, crm_leads, crm_notes,
|
||||
crm_opportunities, crm_pipeline_stages, people, people_departments,
|
||||
people_org_chart, people_person_skills, people_skills, people_team_members,
|
||||
people_teams, people_time_off,
|
||||
};
|
||||
|
||||
#[cfg(feature = "vectordb")]
|
||||
pub use super::schema::{
|
||||
kb_collections, kb_documents, user_kb_associations,
|
||||
};
|
||||
|
||||
pub use botlib::message_types::MessageType;
|
||||
pub use botlib::models::{ApiResponse, Attachment, BotResponse, Session, Suggestion, UserMessage};
|
||||
53
src/core/shared/models/mod.rs.final
Normal file
53
src/core/shared/models/mod.rs.final
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
|
||||
pub mod core;
|
||||
pub use self::core::*;
|
||||
|
||||
pub mod rbac;
|
||||
pub use self::rbac::*;
|
||||
|
||||
pub mod workflow_models;
|
||||
pub use self::workflow_models::*;
|
||||
|
||||
#[cfg(feature = "tasks")]
|
||||
pub mod task_models;
|
||||
#[cfg(feature = "tasks")]
|
||||
pub use self::task_models::*;
|
||||
|
||||
pub use super::schema;
|
||||
|
||||
// Re-export core schema tables
|
||||
pub use super::schema::{
|
||||
basic_tools, bot_configuration, bot_memories, bots, clicks,
|
||||
message_history, organizations, rbac_group_roles, rbac_groups,
|
||||
rbac_permissions, rbac_role_permissions, rbac_roles, rbac_user_groups, rbac_user_roles,
|
||||
session_tool_associations, system_automations, user_login_tokens,
|
||||
user_preferences, user_sessions, users, workflow_executions, workflow_events, bot_shared_memory,
|
||||
};
|
||||
|
||||
// Re-export feature-gated schema tables
|
||||
#[cfg(feature = "tasks")]
|
||||
pub use super::schema::tasks;
|
||||
|
||||
#[cfg(feature = "mail")]
|
||||
pub use super::schema::{
|
||||
distribution_lists, email_auto_responders, email_drafts, email_folders,
|
||||
email_label_assignments, email_labels, email_rules, email_signatures,
|
||||
email_templates, global_email_signatures, scheduled_emails,
|
||||
shared_mailbox_members, shared_mailboxes, user_email_accounts,
|
||||
};
|
||||
|
||||
#[cfg(feature = "people")]
|
||||
pub use super::schema::{
|
||||
crm_accounts, crm_activities, crm_contacts, crm_leads, crm_notes,
|
||||
crm_opportunities, crm_pipeline_stages, people, people_departments,
|
||||
people_org_chart, people_person_skills, people_skills, people_team_members,
|
||||
people_teams, people_time_off,
|
||||
};
|
||||
|
||||
#[cfg(feature = "vectordb")]
|
||||
pub use super::schema::{
|
||||
kb_collections, kb_documents, user_kb_associations,
|
||||
};
|
||||
|
||||
pub use botlib::message_types::MessageType;
|
||||
pub use botlib::models::{ApiResponse, Attachment, BotResponse, Session, Suggestion, UserMessage};
|
||||
|
|
@ -249,13 +249,13 @@ impl Default for TestAppStateBuilder {
|
|||
#[cfg(feature = "directory")]
|
||||
pub fn create_mock_auth_service() -> AuthService {
|
||||
let config = ZitadelConfig {
|
||||
issuer_url: "http://localhost:8080".to_string(),
|
||||
issuer: "http://localhost:8080".to_string(),
|
||||
issuer_url: "http://localhost:9000".to_string(),
|
||||
issuer: "http://localhost:9000".to_string(),
|
||||
client_id: "mock_client_id".to_string(),
|
||||
client_secret: "mock_client_secret".to_string(),
|
||||
redirect_uri: "http://localhost:3000/callback".to_string(),
|
||||
project_id: "mock_project_id".to_string(),
|
||||
api_url: "http://localhost:8080".to_string(),
|
||||
api_url: "http://localhost:9000".to_string(),
|
||||
service_account_key: None,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ pub async fn download_file(url: &str, output_path: &str) -> Result<(), anyhow::E
|
|||
let pb = ProgressBar::new(total_size);
|
||||
pb.set_style(ProgressStyle::default_bar()
|
||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
|
||||
.expect("Invalid progress bar template")
|
||||
.unwrap_or(ProgressStyle::default_bar())
|
||||
.progress_chars("#>-"));
|
||||
pb.set_message(format!("Downloading {}", url));
|
||||
let mut file = TokioFile::create(&output_path).await?;
|
||||
|
|
@ -546,13 +546,102 @@ pub fn truncate_text_for_model(text: &str, model: &str, max_tokens: usize) -> St
|
|||
|
||||
/// Estimates characters per token based on model type
|
||||
fn estimate_chars_per_token(model: &str) -> usize {
|
||||
if model.contains("gpt") || model.contains("claude") {
|
||||
4 // GPT/Claude models: ~4 chars per token
|
||||
} else if model.contains("llama") || model.contains("mistral") {
|
||||
3 // Llama/Mistral models: ~3 chars per token
|
||||
} else if model.contains("bert") || model.contains("mpnet") {
|
||||
4 // BERT-based models: ~4 chars per token
|
||||
if model.contains("llama") || model.contains("mistral") {
|
||||
3 // Llama/Mistral models: ~3 chars per token
|
||||
} else {
|
||||
4 // Default conservative estimate
|
||||
4 // GPT/Claude/BERT models and default: ~4 chars per token
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert date string from user locale format to ISO format (YYYY-MM-DD) for PostgreSQL.
|
||||
///
|
||||
/// The LLM automatically formats dates according to the user's language/idiom based on:
|
||||
/// 1. The conversation context (user's language)
|
||||
/// 2. The PARAM LIKE example (e.g., "15/12/2026" for DD/MM/YYYY)
|
||||
///
|
||||
/// This function handles the most common formats:
|
||||
/// - ISO: YYYY-MM-DD (already in ISO, returned as-is)
|
||||
/// - Brazilian/Portuguese: DD/MM/YYYY or DD/MM/YY
|
||||
/// - US/English: MM/DD/YYYY or MM/DD/YY
|
||||
///
|
||||
/// If the value doesn't match any date pattern, returns it unchanged.
|
||||
///
|
||||
/// NOTE: This function does NOT try to guess ambiguous formats.
|
||||
/// The LLM is responsible for formatting dates correctly based on user language.
|
||||
/// The PARAM declaration's LIKE example tells the LLM the expected format.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `value` - The date string to convert (as provided by the LLM)
|
||||
///
|
||||
/// # Returns
|
||||
/// ISO formatted date string (YYYY-MM-DD) or original value if not a recognized date
|
||||
pub fn convert_date_to_iso_format(value: &str) -> String {
|
||||
let value = value.trim();
|
||||
|
||||
// Already in ISO format (YYYY-MM-DD) - return as-is
|
||||
if value.len() == 10 && value.chars().nth(4) == Some('-') && value.chars().nth(7) == Some('-') {
|
||||
let parts: Vec<&str> = value.split('-').collect();
|
||||
if parts.len() == 3
|
||||
&& parts[0].len() == 4
|
||||
&& parts[1].len() == 2
|
||||
&& parts[2].len() == 2
|
||||
&& parts[0].chars().all(|c| c.is_ascii_digit())
|
||||
&& parts[1].chars().all(|c| c.is_ascii_digit())
|
||||
&& parts[2].chars().all(|c| c.is_ascii_digit())
|
||||
{
|
||||
if let (Ok(year), Ok(month), Ok(day)) =
|
||||
(parts[0].parse::<u32>(), parts[1].parse::<u32>(), parts[2].parse::<u32>())
|
||||
{
|
||||
if (1..=12).contains(&month) && (1..=31).contains(&day) && (1900..=2100).contains(&year) {
|
||||
return value.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle slash-separated formats: DD/MM/YYYY or MM/DD/YYYY
|
||||
// We need to detect which format based on the PARAM declaration's LIKE example
|
||||
// For now, default to DD/MM/YYYY (Brazilian format) as this is the most common for this bot
|
||||
// TODO: Pass language/idiom from session to determine correct format
|
||||
if value.len() >= 8 && value.len() <= 10 {
|
||||
let parts: Vec<&str> = value.split('/').collect();
|
||||
if parts.len() == 3 {
|
||||
let all_numeric = parts[0].chars().all(|c| c.is_ascii_digit())
|
||||
&& parts[1].chars().all(|c| c.is_ascii_digit())
|
||||
&& parts[2].chars().all(|c| c.is_ascii_digit());
|
||||
|
||||
if all_numeric {
|
||||
// Parse the three parts
|
||||
let a = parts[0].parse::<u32>().ok();
|
||||
let b = parts[1].parse::<u32>().ok();
|
||||
let c = if parts[2].len() == 2 {
|
||||
// Convert 2-digit year to 4-digit
|
||||
parts[2].parse::<u32>().ok().map(|y| {
|
||||
if y < 50 {
|
||||
2000 + y
|
||||
} else {
|
||||
1900 + y
|
||||
}
|
||||
})
|
||||
} else {
|
||||
parts[2].parse::<u32>().ok()
|
||||
};
|
||||
|
||||
if let (Some(first), Some(second), Some(third)) = (a, b, c) {
|
||||
// Default: DD/MM/YYYY format (Brazilian/Portuguese)
|
||||
// The LLM should format dates according to the user's language
|
||||
// and the PARAM LIKE example (e.g., "15/12/2026" for DD/MM/YYYY)
|
||||
let (year, month, day) = (third, second, first);
|
||||
|
||||
// Validate the determined date
|
||||
if (1..=31).contains(&day) && (1..=12).contains(&month) && (1900..=2100).contains(&year) {
|
||||
return format!("{:04}-{:02}-{:02}", year, month, day);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not a recognized date pattern, return unchanged
|
||||
value.to_string()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -479,7 +479,7 @@ impl ApiUrls {
|
|||
pub struct InternalUrls;
|
||||
|
||||
impl InternalUrls {
|
||||
pub const DIRECTORY_BASE: &'static str = "http://localhost:8080";
|
||||
pub const DIRECTORY_BASE: &'static str = "http://localhost:9000";
|
||||
pub const DATABASE: &'static str = "postgres://localhost:5432";
|
||||
pub const CACHE: &'static str = "redis://localhost:6379";
|
||||
pub const DRIVE: &'static str = "https://localhost:9000";
|
||||
|
|
|
|||
|
|
@ -1,21 +1,2 @@
|
|||
// Canvas module - split into canvas_api subdirectory for better organization
|
||||
//
|
||||
// This module has been reorganized into the following submodules:
|
||||
// - canvas_api/types: All data structures and enums
|
||||
// - canvas_api/error: Error types and implementations
|
||||
// - canvas_api/db: Database row types and migrations
|
||||
// - canvas_api/service: CanvasService business logic
|
||||
// - canvas_api/handlers: HTTP route handlers
|
||||
//
|
||||
// This file re-exports all public items for backward compatibility.
|
||||
|
||||
pub mod canvas_api;
|
||||
|
||||
// Re-export all public types for backward compatibility
|
||||
pub use canvas_api::*;
|
||||
|
||||
// Re-export the migration function at the module level
|
||||
pub use canvas_api::create_canvas_tables_migration;
|
||||
|
||||
// Re-export canvas routes at the module level
|
||||
pub use canvas_api::canvas_routes;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
pub mod canvas;
|
||||
pub mod canvas_api;
|
||||
pub mod ui;
|
||||
pub mod workflow_canvas;
|
||||
pub mod bas_analyzer;
|
||||
|
|
|
|||
|
|
@ -123,24 +123,24 @@ impl WorkflowCanvas {
|
|||
pub async fn workflow_designer_page(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
) -> Result<Html<String>, StatusCode> {
|
||||
let html = r#"
|
||||
let html = r##"
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Workflow Designer</title>
|
||||
<script src="/static/htmx.min.js"></script>
|
||||
<style>
|
||||
.canvas {
|
||||
width: 100%;
|
||||
height: 600px;
|
||||
border: 1px solid #ccc;
|
||||
.canvas {
|
||||
width: 100%;
|
||||
height: 600px;
|
||||
border: 1px solid #ccc;
|
||||
position: relative;
|
||||
background: #f9f9f9;
|
||||
}
|
||||
.node {
|
||||
position: absolute;
|
||||
padding: 10px;
|
||||
border: 2px solid #333;
|
||||
.node {
|
||||
position: absolute;
|
||||
padding: 10px;
|
||||
border: 2px solid #333;
|
||||
background: white;
|
||||
border-radius: 5px;
|
||||
cursor: move;
|
||||
|
|
@ -152,25 +152,25 @@ pub async fn workflow_designer_page(
|
|||
.node.condition { border-color: #28a745; background: #e8f5e9; }
|
||||
.node.parallel { border-color: #6f42c1; background: #f3e5f5; }
|
||||
.node.event { border-color: #fd7e14; background: #fff3e0; }
|
||||
.toolbar {
|
||||
padding: 10px;
|
||||
background: #f8f9fa;
|
||||
.toolbar {
|
||||
padding: 10px;
|
||||
background: #f8f9fa;
|
||||
border-bottom: 1px solid #dee2e6;
|
||||
}
|
||||
.btn {
|
||||
padding: 8px 16px;
|
||||
margin: 0 5px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
.btn {
|
||||
padding: 8px 16px;
|
||||
margin: 0 5px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
}
|
||||
.btn-primary { background: #007bff; color: white; }
|
||||
.btn-success { background: #28a745; color: white; }
|
||||
.btn-warning { background: #ffc107; color: black; }
|
||||
.code-preview {
|
||||
margin-top: 20px;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
.code-preview {
|
||||
margin-top: 20px;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border: 1px solid #dee2e6;
|
||||
font-family: monospace;
|
||||
white-space: pre-wrap;
|
||||
|
|
@ -189,15 +189,15 @@ pub async fn workflow_designer_page(
|
|||
<input type="file" id="file-input" accept=".bas" onchange="analyzeFile()" style="margin-left: 20px;">
|
||||
<label for="file-input" class="btn">Analyze .bas File</label>
|
||||
</div>
|
||||
|
||||
|
||||
<div id="file-analysis" style="display:none; padding: 10px; background: #e8f4f8; border: 1px solid #bee5eb; margin: 10px 0;">
|
||||
<h4>File Analysis Result</h4>
|
||||
<div id="analysis-content"></div>
|
||||
</div>
|
||||
|
||||
|
||||
<div id="canvas" class="canvas" ondrop="drop(event)" ondragover="allowDrop(event)">
|
||||
</div>
|
||||
|
||||
|
||||
<div id="code-preview" class="code-preview">
|
||||
Generated BASIC code will appear here...
|
||||
</div>
|
||||
|
|
@ -205,7 +205,7 @@ pub async fn workflow_designer_page(
|
|||
<script>
|
||||
let nodeCounter = 0;
|
||||
let nodes = [];
|
||||
|
||||
|
||||
function addNode(type) {
|
||||
nodeCounter++;
|
||||
const node = {
|
||||
|
|
@ -217,7 +217,7 @@ pub async fn workflow_designer_page(
|
|||
nodes.push(node);
|
||||
renderNode(node);
|
||||
}
|
||||
|
||||
|
||||
function renderNode(node) {
|
||||
const canvas = document.getElementById('canvas');
|
||||
const nodeEl = document.createElement('div');
|
||||
|
|
@ -226,11 +226,11 @@ pub async fn workflow_designer_page(
|
|||
nodeEl.draggable = true;
|
||||
nodeEl.style.left = node.x + 'px';
|
||||
nodeEl.style.top = node.y + 'px';
|
||||
|
||||
|
||||
let content = '';
|
||||
switch(node.type) {
|
||||
case 'bot-agent':
|
||||
content = '<strong>Bot Agent</strong><br><input type="text" placeholder="Bot Name" style="width:100px;margin:2px;"><br><input type="text" placeholder="Action" style="width:100px;margin:2px;">';
|
||||
content = '<strong>Bot Agent</strong><br><input type="text" placeholder="Bot Name " style="width:100px;margin:2px;"><br><input type="text" placeholder="Action" style="width:100px;margin:2px;">';
|
||||
break;
|
||||
case 'human-approval':
|
||||
content = '<strong>Human Approval</strong><br><input type="text" placeholder="Approver" style="width:100px;margin:2px;"><br><input type="number" placeholder="Timeout" style="width:100px;margin:2px;">';
|
||||
|
|
@ -245,20 +245,20 @@ pub async fn workflow_designer_page(
|
|||
content = '<strong>Event</strong><br><input type="text" placeholder="Event Name " style="width:100px;margin:2px;">';
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
nodeEl.innerHTML = content;
|
||||
nodeEl.ondragstart = drag;
|
||||
canvas.appendChild(nodeEl);
|
||||
}
|
||||
|
||||
|
||||
function allowDrop(ev) {
|
||||
ev.preventDefault();
|
||||
}
|
||||
|
||||
|
||||
function drag(ev) {
|
||||
ev.dataTransfer.setData("text", ev.target.id);
|
||||
}
|
||||
|
||||
|
||||
function drop(ev) {
|
||||
ev.preventDefault();
|
||||
const data = ev.dataTransfer.getData("text");
|
||||
|
|
@ -266,10 +266,10 @@ pub async fn workflow_designer_page(
|
|||
const rect = ev.currentTarget.getBoundingClientRect();
|
||||
const x = ev.clientX - rect.left;
|
||||
const y = ev.clientY - rect.top;
|
||||
|
||||
|
||||
nodeEl.style.left = x + 'px';
|
||||
nodeEl.style.top = y + 'px';
|
||||
|
||||
|
||||
// Update node position in data
|
||||
const node = nodes.find(n => n.id === data);
|
||||
if (node) {
|
||||
|
|
@ -277,16 +277,16 @@ pub async fn workflow_designer_page(
|
|||
node.y = y;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function analyzeFile() {
|
||||
const fileInput = document.getElementById('file-input');
|
||||
const file = fileInput.files[0];
|
||||
|
||||
|
||||
if (file) {
|
||||
const reader = new FileReader();
|
||||
reader.onload = function(e) {
|
||||
const content = e.target.result;
|
||||
|
||||
|
||||
fetch('/api/workflow/analyze', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
|
|
@ -305,13 +305,13 @@ pub async fn workflow_designer_page(
|
|||
reader.readAsText(file);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function displayAnalysis(analysis) {
|
||||
const analysisDiv = document.getElementById('file-analysis');
|
||||
const contentDiv = document.getElementById('analysis-content');
|
||||
|
||||
|
||||
let html = `<p><strong>File Type:</strong> ${analysis.file_type}</p>`;
|
||||
|
||||
|
||||
if (analysis.metadata) {
|
||||
html += `<p><strong>Workflow Name:</strong> ${analysis.metadata.name}</p>`;
|
||||
html += `<p><strong>Steps:</strong> ${analysis.metadata.step_count}</p>`;
|
||||
|
|
@ -319,7 +319,7 @@ pub async fn workflow_designer_page(
|
|||
html += `<p><strong>Human Approval:</strong> ${analysis.metadata.has_human_approval ? 'Yes' : 'No'}</p>`;
|
||||
html += `<p><strong>Parallel Processing:</strong> ${analysis.metadata.has_parallel ? 'Yes' : 'No'}</p>`;
|
||||
}
|
||||
|
||||
|
||||
if (analysis.suggestions.length > 0) {
|
||||
html += '<p><strong>Suggestions:</strong></p><ul>';
|
||||
analysis.suggestions.forEach(suggestion => {
|
||||
|
|
@ -327,14 +327,14 @@ pub async fn workflow_designer_page(
|
|||
});
|
||||
html += '</ul>';
|
||||
}
|
||||
|
||||
|
||||
contentDiv.innerHTML = html;
|
||||
analysisDiv.style.display = 'block';
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"#;
|
||||
"##;
|
||||
|
||||
Ok(Html(html.to_string()))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -379,7 +379,8 @@ pub async fn get_current_user(
|
|||
let session_token = headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer "));
|
||||
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||
.filter(|token| !token.is_empty());
|
||||
|
||||
match session_token {
|
||||
None => {
|
||||
|
|
@ -397,21 +398,6 @@ pub async fn get_current_user(
|
|||
is_anonymous: true,
|
||||
})
|
||||
}
|
||||
Some(token) if token.is_empty() => {
|
||||
info!("get_current_user: empty authorization token - returning anonymous user");
|
||||
Json(CurrentUserResponse {
|
||||
id: None,
|
||||
username: None,
|
||||
email: None,
|
||||
first_name: None,
|
||||
last_name: None,
|
||||
display_name: None,
|
||||
roles: None,
|
||||
organization_id: None,
|
||||
avatar_url: None,
|
||||
is_anonymous: true,
|
||||
})
|
||||
}
|
||||
Some(session_token) => {
|
||||
info!("get_current_user: looking up session token (len={}, prefix={}...)",
|
||||
session_token.len(),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,2 @@
|
|||
// Re-export all handlers from the handlers_api submodule
|
||||
// This maintains backward compatibility while organizing code into logical modules
|
||||
pub mod handlers_api;
|
||||
|
||||
// Re-export all handlers for backward compatibility
|
||||
pub use handlers_api::*;
|
||||
pub use crate::docs::handlers_api::*;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod collaboration;
|
||||
pub mod handlers;
|
||||
pub mod handlers_api;
|
||||
pub mod ooxml;
|
||||
pub mod storage;
|
||||
pub mod types;
|
||||
|
|
@ -16,21 +17,7 @@ pub use collaboration::{
|
|||
handle_docs_websocket, handle_get_collaborators, handle_get_mentions, handle_get_presence,
|
||||
handle_get_selections, handle_get_typing,
|
||||
};
|
||||
pub use handlers::{
|
||||
handle_accept_reject_all, handle_accept_reject_change, handle_add_comment, handle_add_endnote,
|
||||
handle_add_footnote, handle_ai_custom, handle_ai_expand, handle_ai_improve, handle_ai_simplify,
|
||||
handle_ai_summarize, handle_ai_translate, handle_apply_style, handle_autosave,
|
||||
handle_compare_documents, handle_create_style, handle_delete_comment, handle_delete_document,
|
||||
handle_delete_endnote, handle_delete_footnote, handle_delete_style, handle_docs_ai,
|
||||
handle_docs_get_by_id, handle_docs_save, handle_enable_track_changes, handle_export_docx,
|
||||
handle_export_html, handle_export_md, handle_export_pdf, handle_export_txt,
|
||||
handle_generate_toc, handle_get_document, handle_get_outline, handle_import_document,
|
||||
handle_list_comments, handle_list_documents, handle_list_endnotes, handle_list_footnotes,
|
||||
handle_list_styles, handle_list_track_changes, handle_new_document, handle_reply_comment,
|
||||
handle_resolve_comment, handle_save_document, handle_search_documents, handle_template_blank,
|
||||
handle_template_letter, handle_template_meeting, handle_template_report, handle_update_endnote,
|
||||
handle_update_footnote, handle_update_style, handle_update_toc,
|
||||
};
|
||||
pub use handlers::*;
|
||||
pub use types::{
|
||||
AiRequest, AiResponse, Collaborator, CollabMessage, CommentReply, ComparisonSummary, Document,
|
||||
DocumentComment, DocumentComparison, DocumentDiff, DocumentMetadata, DocumentStyle, Endnote,
|
||||
|
|
|
|||
|
|
@ -1,256 +1,296 @@
|
|||
// Drive HTTP handlers extracted from drive/mod.rs
|
||||
// Drive HTTP handlers implementation using S3
|
||||
// Extracted from drive/mod.rs and using aws-sdk-s3
|
||||
|
||||
use crate::core::shared::state::AppState;
|
||||
use crate::drive::drive_types::*;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
Json,
|
||||
http::{header, StatusCode},
|
||||
response::{IntoResponse, Json, Response},
|
||||
body::Body,
|
||||
};
|
||||
use aws_sdk_s3::primitives::ByteStream;
|
||||
use chrono::Utc;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Open a file for editing
|
||||
|
||||
// Import ReadResponse from parent mod if not in drive_types
|
||||
use super::ReadResponse;
|
||||
|
||||
/// Open a file (get metadata)
|
||||
pub async fn open_file(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(file_id): Path<String>,
|
||||
) -> Result<Json<FileItem>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Opening file: {}", file_id);
|
||||
read_metadata(state, file_id).await
|
||||
}
|
||||
|
||||
// TODO: Implement actual file reading
|
||||
let file_item = FileItem {
|
||||
/// Helper to get file metadata
|
||||
async fn read_metadata(
|
||||
state: Arc<AppState>,
|
||||
file_id: String,
|
||||
) -> Result<Json<FileItem>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let client = state.drive.as_ref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({"error": "Drive not configured"})),
|
||||
))?;
|
||||
let bucket = &state.bucket_name;
|
||||
|
||||
let resp = client.head_object()
|
||||
.bucket(bucket)
|
||||
.key(&file_id)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::NOT_FOUND, Json(serde_json::json!({"error": e.to_string()}))))?;
|
||||
|
||||
let item = FileItem {
|
||||
id: file_id.clone(),
|
||||
name: "Untitled".to_string(),
|
||||
file_type: "document".to_string(),
|
||||
size: 0,
|
||||
mime_type: "text/plain".to_string(),
|
||||
created_at: Utc::now(),
|
||||
modified_at: Utc::now(),
|
||||
name: file_id.split('/').next_back().unwrap_or(&file_id).to_string(),
|
||||
file_type: if file_id.ends_with('/') { "folder".to_string() } else { "file".to_string() },
|
||||
size: resp.content_length.unwrap_or(0),
|
||||
mime_type: resp.content_type.unwrap_or_else(|| "application/octet-stream".to_string()),
|
||||
created_at: Utc::now(), // S3 doesn't track creation time easily
|
||||
modified_at: Utc::now(), // Simplified
|
||||
parent_id: None,
|
||||
url: None,
|
||||
thumbnail_url: None,
|
||||
is_favorite: false,
|
||||
is_favorite: false, // Not implemented in S3
|
||||
tags: vec![],
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
|
||||
Ok(Json(file_item))
|
||||
Ok(Json(item))
|
||||
}
|
||||
|
||||
/// List all buckets
|
||||
/// List all buckets (or configured one)
|
||||
pub async fn list_buckets(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Result<Json<Vec<BucketInfo>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Listing buckets");
|
||||
let client = state.drive.as_ref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({"error": "Drive not configured"})),
|
||||
))?;
|
||||
|
||||
// TODO: Query database for buckets
|
||||
let buckets = vec![];
|
||||
|
||||
Ok(Json(buckets))
|
||||
match client.list_buckets().send().await {
|
||||
Ok(resp) => {
|
||||
let buckets = resp.buckets.unwrap_or_default().iter().map(|b| {
|
||||
BucketInfo {
|
||||
id: b.name.clone().unwrap_or_default(),
|
||||
name: b.name.clone().unwrap_or_default(),
|
||||
created_at: Utc::now(),
|
||||
file_count: 0,
|
||||
total_size: 0,
|
||||
}
|
||||
}).collect();
|
||||
Ok(Json(buckets))
|
||||
},
|
||||
Err(_) => {
|
||||
// Fallback
|
||||
Ok(Json(vec![BucketInfo {
|
||||
id: state.bucket_name.clone(),
|
||||
name: state.bucket_name.clone(),
|
||||
created_at: Utc::now(),
|
||||
file_count: 0,
|
||||
total_size: 0,
|
||||
}]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// List files in a bucket
|
||||
pub async fn list_files(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<SearchQuery>,
|
||||
) -> Result<Json<Vec<FileItem>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let query = req.query.clone().unwrap_or_default();
|
||||
let parent_path = req.parent_path.clone();
|
||||
let client = state.drive.as_ref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({"error": "Drive not configured"})),
|
||||
))?;
|
||||
let bucket = req.bucket.clone().unwrap_or_else(|| state.bucket_name.clone());
|
||||
let prefix = req.parent_path.clone().unwrap_or_default();
|
||||
|
||||
tracing::debug!("Searching files: query={}, parent={:?}", query, parent_path);
|
||||
let resp = client.list_objects_v2()
|
||||
.bucket(&bucket)
|
||||
.prefix(&prefix)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()}))))?;
|
||||
|
||||
// TODO: Implement actual file search
|
||||
let files = vec![];
|
||||
let files = resp.contents.unwrap_or_default().iter().map(|obj| {
|
||||
let key = obj.key().unwrap_or_default();
|
||||
let name = key.split('/').next_back().unwrap_or(key).to_string();
|
||||
FileItem {
|
||||
id: key.to_string(),
|
||||
name,
|
||||
file_type: if key.ends_with('/') { "folder".to_string() } else { "file".to_string() },
|
||||
size: obj.size.unwrap_or(0),
|
||||
mime_type: "application/octet-stream".to_string(),
|
||||
created_at: Utc::now(),
|
||||
modified_at: Utc::now(),
|
||||
parent_id: Some(prefix.clone()),
|
||||
url: None,
|
||||
thumbnail_url: None,
|
||||
is_favorite: false,
|
||||
tags: vec![],
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
Ok(Json(files))
|
||||
}
|
||||
|
||||
/// Read file content
|
||||
/// Read file content (as text)
|
||||
pub async fn read_file(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(file_id): Path<String>,
|
||||
) -> Result<Json<FileItem>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Reading file: {}", file_id);
|
||||
) -> Result<Json<ReadResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let client = state.drive.as_ref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({"error": "Drive not configured"})),
|
||||
))?;
|
||||
let bucket = &state.bucket_name;
|
||||
|
||||
// TODO: Implement actual file reading
|
||||
let file_item = FileItem {
|
||||
id: file_id.clone(),
|
||||
name: "Untitled".to_string(),
|
||||
file_type: "document".to_string(),
|
||||
size: 0,
|
||||
mime_type: "text/plain".to_string(),
|
||||
created_at: Utc::now(),
|
||||
modified_at: Utc::now(),
|
||||
parent_id: None,
|
||||
url: None,
|
||||
thumbnail_url: None,
|
||||
is_favorite: false,
|
||||
tags: vec![],
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
let resp = client.get_object()
|
||||
.bucket(bucket)
|
||||
.key(&file_id)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::NOT_FOUND, Json(serde_json::json!({"error": e.to_string()}))))?;
|
||||
|
||||
Ok(Json(file_item))
|
||||
let data = resp.body.collect().await.map_err(|e|
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
)?.into_bytes();
|
||||
|
||||
let content = String::from_utf8(data.to_vec()).unwrap_or_else(|_| "[Binary Content]".to_string());
|
||||
|
||||
Ok(Json(ReadResponse { content }))
|
||||
}
|
||||
|
||||
/// Write file content
|
||||
pub async fn write_file(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<WriteRequest>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let file_id = req.file_id.unwrap_or_else(|| Uuid::new_v4().to_string());
|
||||
let client = state.drive.as_ref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({"error": "Drive not configured"})),
|
||||
))?;
|
||||
let bucket = &state.bucket_name;
|
||||
let key = req.file_id.ok_or((StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "Missing file_id"}))))?;
|
||||
|
||||
tracing::debug!("Writing file: {}", file_id);
|
||||
client.put_object()
|
||||
.bucket(bucket)
|
||||
.key(&key)
|
||||
.body(ByteStream::from(req.content.into_bytes()))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()}))))?;
|
||||
|
||||
// TODO: Implement actual file writing
|
||||
Ok(Json(serde_json::json!({"success": true})))
|
||||
}
|
||||
|
||||
/// Delete a file
|
||||
pub async fn delete_file(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(file_id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Deleting file: {}", file_id);
|
||||
let client = state.drive.as_ref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({"error": "Drive not configured"})),
|
||||
))?;
|
||||
let bucket = &state.bucket_name;
|
||||
|
||||
client.delete_object()
|
||||
.bucket(bucket)
|
||||
.key(&file_id)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()}))))?;
|
||||
|
||||
// TODO: Implement actual file deletion
|
||||
Ok(Json(serde_json::json!({"success": true})))
|
||||
}
|
||||
|
||||
/// Create a folder
|
||||
pub async fn create_folder(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<CreateFolderRequest>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let _parent_id = req.parent_id.clone().unwrap_or_default();
|
||||
let client = state.drive.as_ref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({"error": "Drive not configured"})),
|
||||
))?;
|
||||
let bucket = &state.bucket_name;
|
||||
|
||||
// Construct folder path/key
|
||||
let mut key = req.parent_id.unwrap_or_default();
|
||||
if !key.ends_with('/') && !key.is_empty() {
|
||||
key.push('/');
|
||||
}
|
||||
key.push_str(&req.name);
|
||||
if !key.ends_with('/') {
|
||||
key.push('/');
|
||||
}
|
||||
|
||||
tracing::debug!("Creating folder: {:?}", req.name);
|
||||
client.put_object()
|
||||
.bucket(bucket)
|
||||
.key(&key)
|
||||
.body(ByteStream::from_static(&[]))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()}))))?;
|
||||
|
||||
// TODO: Implement actual folder creation
|
||||
Ok(Json(serde_json::json!({"success": true})))
|
||||
}
|
||||
|
||||
/// Copy a file
|
||||
pub async fn copy_file(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
Json(_req): Json<CopyFileRequest>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Copying file");
|
||||
|
||||
// TODO: Implement actual file copying
|
||||
Ok(Json(serde_json::json!({"success": true})))
|
||||
}
|
||||
|
||||
/// Upload file to drive
|
||||
pub async fn upload_file_to_drive(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
Json(_req): Json<UploadRequest>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Uploading to drive");
|
||||
|
||||
// TODO: Implement actual file upload
|
||||
Ok(Json(serde_json::json!({"success": true})))
|
||||
}
|
||||
|
||||
/// Download file
|
||||
/// Download file (stream)
|
||||
pub async fn download_file(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(file_id): Path<String>,
|
||||
) -> Result<Json<FileItem>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Downloading file: {}", file_id);
|
||||
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
|
||||
let client = state.drive.as_ref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({"error": "Drive not configured"})),
|
||||
))?;
|
||||
let bucket = &state.bucket_name;
|
||||
|
||||
// TODO: Implement actual file download
|
||||
let file_item = FileItem {
|
||||
id: file_id.clone(),
|
||||
name: "Download".to_string(),
|
||||
file_type: "file".to_string(),
|
||||
size: 0,
|
||||
mime_type: "application/octet-stream".to_string(),
|
||||
created_at: Utc::now(),
|
||||
modified_at: Utc::now(),
|
||||
parent_id: None,
|
||||
url: None,
|
||||
thumbnail_url: None,
|
||||
is_favorite: false,
|
||||
tags: vec![],
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
let resp = client.get_object()
|
||||
.bucket(bucket)
|
||||
.key(&file_id)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::NOT_FOUND, Json(serde_json::json!({"error": e.to_string()}))))?;
|
||||
|
||||
Ok(Json(file_item))
|
||||
let body = resp.body.collect().await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()}))))?.into_bytes();
|
||||
|
||||
Ok(Response::builder()
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(header::CONTENT_DISPOSITION, format!("attachment; filename=\"{}\"", file_id.split('/').next_back().unwrap_or("file")))
|
||||
.body(Body::from(body))
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
/// List folder contents
|
||||
pub async fn list_folder_contents(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
Json(_req): Json<SearchQuery>,
|
||||
) -> Result<Json<Vec<FileItem>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Listing folder contents");
|
||||
|
||||
// TODO: Implement actual folder listing
|
||||
let files = vec![];
|
||||
|
||||
Ok(Json(files))
|
||||
// Stubs for others (list_shared, etc.)
|
||||
pub async fn copy_file(State(_): State<Arc<AppState>>, Json(_): Json<CopyFileRequest>) -> impl IntoResponse {
|
||||
(StatusCode::NOT_IMPLEMENTED, Json(serde_json::json!({"error": "Not implemented"})))
|
||||
}
|
||||
|
||||
/// Search files
|
||||
pub async fn search_files(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
Json(req): Json<SearchQuery>,
|
||||
) -> Result<Json<Vec<FileItem>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let query = req.query.clone().unwrap_or_default();
|
||||
let parent_path = req.parent_path.clone();
|
||||
|
||||
tracing::debug!("Searching files: query={:?}, parent_path={:?}", query, parent_path);
|
||||
|
||||
// TODO: Implement actual file search
|
||||
let files = vec![];
|
||||
|
||||
Ok(Json(files))
|
||||
pub async fn upload_file_to_drive(State(_): State<Arc<AppState>>, Json(_): Json<UploadRequest>) -> impl IntoResponse {
|
||||
(StatusCode::NOT_IMPLEMENTED, Json(serde_json::json!({"error": "Not implemented"})))
|
||||
}
|
||||
|
||||
/// Get recent files
|
||||
pub async fn recent_files(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
) -> Result<Json<Vec<FileItem>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Getting recent files");
|
||||
|
||||
// TODO: Implement actual recent files query
|
||||
let files = vec![];
|
||||
|
||||
Ok(Json(files))
|
||||
pub async fn list_folder_contents(State(_): State<Arc<AppState>>, Json(_): Json<SearchQuery>) -> impl IntoResponse {
|
||||
(StatusCode::OK, Json(Vec::<FileItem>::new()))
|
||||
}
|
||||
|
||||
/// List favorites
|
||||
pub async fn list_favorites(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
) -> Result<Json<Vec<FileItem>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Listing favorites");
|
||||
|
||||
// TODO: Implement actual favorites query
|
||||
let files = vec![];
|
||||
|
||||
Ok(Json(files))
|
||||
pub async fn search_files(State(_): State<Arc<AppState>>, Json(_): Json<SearchQuery>) -> impl IntoResponse {
|
||||
(StatusCode::OK, Json(Vec::<FileItem>::new()))
|
||||
}
|
||||
|
||||
/// Share folder
|
||||
pub async fn share_folder(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
Json(_req): Json<ShareRequest>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Sharing folder");
|
||||
|
||||
// TODO: Implement actual folder sharing
|
||||
Ok(Json(serde_json::json!({"success": true})))
|
||||
pub async fn recent_files(State(_): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
(StatusCode::OK, Json(Vec::<FileItem>::new()))
|
||||
}
|
||||
|
||||
/// List shared files/folders
|
||||
pub async fn list_shared(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
) -> Result<Json<Vec<FileItem>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
tracing::debug!("Listing shared resources");
|
||||
|
||||
// TODO: Implement actual shared query
|
||||
let items = vec![];
|
||||
|
||||
Ok(Json(items))
|
||||
pub async fn list_favorites(State(_): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
(StatusCode::OK, Json(Vec::<FileItem>::new()))
|
||||
}
|
||||
pub async fn share_folder(State(_): State<Arc<AppState>>, Json(_): Json<ShareRequest>) -> impl IntoResponse {
|
||||
(StatusCode::OK, Json(serde_json::json!({"success": true})))
|
||||
}
|
||||
pub async fn list_shared(State(_): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
(StatusCode::OK, Json(Vec::<FileItem>::new()))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ pub struct ShareRequest {
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchQuery {
|
||||
pub bucket: Option<String>,
|
||||
pub query: Option<String>,
|
||||
pub file_type: Option<String>,
|
||||
pub parent_path: Option<String>,
|
||||
|
|
|
|||
|
|
@ -192,7 +192,7 @@ impl LocalFileMonitor {
|
|||
// Look for <botname>.gbdialog folder inside (e.g., cristo.gbai/cristo.gbdialog)
|
||||
let gbdialog_path = path.join(format!("{}.gbdialog", bot_name));
|
||||
if gbdialog_path.exists() {
|
||||
self.compile_gbdialog(&bot_name, &gbdialog_path).await?;
|
||||
self.compile_gbdialog(bot_name, &gbdialog_path).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -289,9 +289,9 @@ impl LocalFileMonitor {
|
|||
std::fs::write(&local_source_path, &source_content_clone)?;
|
||||
let mut compiler = BasicCompiler::new(state_clone, bot_id);
|
||||
let local_source_str = local_source_path.to_str()
|
||||
.ok_or_else(|| format!("Invalid UTF-8 in local source path"))?;
|
||||
.ok_or_else(|| "Invalid UTF-8 in local source path".to_string())?;
|
||||
let work_dir_str = work_dir_clone.to_str()
|
||||
.ok_or_else(|| format!("Invalid UTF-8 in work directory path"))?;
|
||||
.ok_or_else(|| "Invalid UTF-8 in work directory path".to_string())?;
|
||||
let result = compiler.compile_file(local_source_str, work_dir_str)?;
|
||||
if let Some(mcp_tool) = result.mcp_tool {
|
||||
info!(
|
||||
|
|
|
|||
|
|
@ -65,8 +65,8 @@ fn is_tracking_pixel_enabled(state: &Arc<AppState>, bot_id: Option<Uuid>) -> boo
|
|||
fn inject_tracking_pixel(html_body: &str, tracking_id: &str, state: &Arc<AppState>) -> String {
|
||||
let config_manager = crate::core::config::ConfigManager::new(state.conn.clone());
|
||||
let base_url = config_manager
|
||||
.get_config(&Uuid::nil(), "server-url", Some("http://localhost:8080"))
|
||||
.unwrap_or_else(|_| "http://localhost:8080".to_string());
|
||||
.get_config(&Uuid::nil(), "server-url", Some("http://localhost:9000"))
|
||||
.unwrap_or_else(|_| "http://localhost:9000".to_string());
|
||||
|
||||
let pixel_url = format!("{}/api/email/tracking/pixel/{}", base_url, tracking_id);
|
||||
let pixel_html = format!(
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ pub fn is_tracking_pixel_enabled(state: &Arc<AppState>, bot_id: Option<Uuid>) ->
|
|||
pub fn inject_tracking_pixel(html_body: &str, tracking_id: &str, state: &Arc<AppState>) -> String {
|
||||
let config_manager = crate::core::config::ConfigManager::new(state.conn.clone());
|
||||
let base_url = config_manager
|
||||
.get_config(&Uuid::nil(), "server-url", Some("http://localhost:8080"))
|
||||
.unwrap_or_else(|_| "http://localhost:8080".to_string());
|
||||
.get_config(&Uuid::nil(), "server-url", Some("http://localhost:9000"))
|
||||
.unwrap_or_else(|_| "http://localhost:9000".to_string());
|
||||
|
||||
let pixel_url = format!("{}/api/email/tracking/pixel/{}", base_url, tracking_id);
|
||||
let pixel_html = format!(
|
||||
|
|
|
|||
|
|
@ -232,6 +232,16 @@ impl ClaudeClient {
|
|||
(system_prompt, claude_messages)
|
||||
}
|
||||
|
||||
/// Sanitizes a string by removing invalid UTF-8 surrogate characters
|
||||
fn sanitize_utf8(input: &str) -> String {
|
||||
input.chars()
|
||||
.filter(|c| {
|
||||
let cp = *c as u32;
|
||||
!(0xD800..=0xDBFF).contains(&cp) && !(0xDC00..=0xDFFF).contains(&cp)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn build_messages(
|
||||
system_prompt: &str,
|
||||
context_data: &str,
|
||||
|
|
@ -241,15 +251,15 @@ impl ClaudeClient {
|
|||
let mut system_parts = Vec::new();
|
||||
|
||||
if !system_prompt.is_empty() {
|
||||
system_parts.push(system_prompt.to_string());
|
||||
system_parts.push(Self::sanitize_utf8(system_prompt));
|
||||
}
|
||||
if !context_data.is_empty() {
|
||||
system_parts.push(context_data.to_string());
|
||||
system_parts.push(Self::sanitize_utf8(context_data));
|
||||
}
|
||||
|
||||
for (role, content) in history {
|
||||
if role == "episodic" || role == "compact" {
|
||||
system_parts.push(format!("[Previous conversation summary]: {content}"));
|
||||
system_parts.push(format!("[Previous conversation summary]: {}", Self::sanitize_utf8(content)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -270,7 +280,8 @@ impl ClaudeClient {
|
|||
};
|
||||
|
||||
if let Some(norm_role) = normalized_role {
|
||||
if content.is_empty() {
|
||||
let sanitized_content = Self::sanitize_utf8(content);
|
||||
if sanitized_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -278,14 +289,14 @@ impl ClaudeClient {
|
|||
if let Some(last_msg) = messages.last_mut() {
|
||||
let last_msg: &mut ClaudeMessage = last_msg;
|
||||
last_msg.content.push_str("\n\n");
|
||||
last_msg.content.push_str(content);
|
||||
last_msg.content.push_str(&sanitized_content);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(ClaudeMessage {
|
||||
role: norm_role.clone(),
|
||||
content: content.clone(),
|
||||
content: sanitized_content,
|
||||
});
|
||||
last_role = Some(norm_role);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -116,6 +116,16 @@ impl GLMClient {
|
|||
// GLM/z.ai uses /chat/completions (not /v1/chat/completions)
|
||||
format!("{}/chat/completions", self.base_url)
|
||||
}
|
||||
|
||||
/// Sanitizes a string by removing invalid UTF-8 surrogate characters
|
||||
fn sanitize_utf8(input: &str) -> String {
|
||||
input.chars()
|
||||
.filter(|c| {
|
||||
let cp = *c as u32;
|
||||
!(0xD800..=0xDBFF).contains(&cp) && !(0xDC00..=0xDFFF).contains(&cp)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -183,11 +193,6 @@ impl LLMProvider for GLMClient {
|
|||
key: &str,
|
||||
tools: Option<&Vec<Value>>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
// DEBUG: Log what we received
|
||||
info!("[GLM_DEBUG] config type: {}", config);
|
||||
info!("[GLM_DEBUG] prompt: '{}'", prompt);
|
||||
info!("[GLM_DEBUG] config as JSON: {}", serde_json::to_string_pretty(config).unwrap_or_default());
|
||||
|
||||
// config IS the messages array directly, not nested
|
||||
let messages = if let Some(msgs) = config.as_array() {
|
||||
// Convert messages from config format to GLM format
|
||||
|
|
@ -195,25 +200,23 @@ impl LLMProvider for GLMClient {
|
|||
.filter_map(|m| {
|
||||
let role = m.get("role")?.as_str()?;
|
||||
let content = m.get("content")?.as_str()?;
|
||||
info!("[GLM_DEBUG] Processing message - role: {}, content: '{}'", role, content);
|
||||
if !content.is_empty() {
|
||||
let sanitized = Self::sanitize_utf8(content);
|
||||
if !sanitized.is_empty() {
|
||||
Some(GLMMessage {
|
||||
role: role.to_string(),
|
||||
content: Some(content.to_string()),
|
||||
content: Some(sanitized),
|
||||
tool_calls: None,
|
||||
})
|
||||
} else {
|
||||
info!("[GLM_DEBUG] Skipping empty content message");
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
// Fallback to building from prompt
|
||||
info!("[GLM_DEBUG] No array found, using prompt: '{}'", prompt);
|
||||
vec![GLMMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(prompt.to_string()),
|
||||
content: Some(Self::sanitize_utf8(prompt)),
|
||||
tool_calls: None,
|
||||
}]
|
||||
};
|
||||
|
|
@ -223,8 +226,6 @@ impl LLMProvider for GLMClient {
|
|||
return Err("No valid messages in request".into());
|
||||
}
|
||||
|
||||
info!("[GLM_DEBUG] Final GLM messages count: {}", messages.len());
|
||||
|
||||
// Use glm-4.7 for tool calling support
|
||||
// GLM-4.7 supports standard OpenAI-compatible function calling
|
||||
let model_name = if model == "glm-4" { "glm-4.7" } else { model };
|
||||
|
|
@ -242,17 +243,13 @@ impl LLMProvider for GLMClient {
|
|||
stream: Some(true),
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
tools: tools.map(|t| t.clone()),
|
||||
tools: tools.cloned(),
|
||||
tool_choice,
|
||||
};
|
||||
|
||||
let url = self.build_url();
|
||||
info!("GLM streaming request to: {}", url);
|
||||
|
||||
// Log the exact request being sent
|
||||
let request_json = serde_json::to_string_pretty(&request).unwrap_or_default();
|
||||
info!("GLM request body: {}", request_json);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
|
|
@ -286,24 +283,19 @@ impl LLMProvider for GLMClient {
|
|||
}
|
||||
|
||||
if line == "data: [DONE]" {
|
||||
let _ = tx.send(String::new()); // Signal end
|
||||
std::mem::drop(tx.send(String::new())); // Signal end
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if line.starts_with("data: ") {
|
||||
let json_str = line[6..].trim();
|
||||
info!("[GLM_SSE] Received SSE line ({} chars): {}", json_str.len(), json_str);
|
||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||
let json_str = json_str.trim();
|
||||
if let Ok(chunk_data) = serde_json::from_str::<Value>(json_str) {
|
||||
if let Some(choices) = chunk_data.get("choices").and_then(|c| c.as_array()) {
|
||||
for choice in choices {
|
||||
info!("[GLM_SSE] Processing choice");
|
||||
if let Some(delta) = choice.get("delta") {
|
||||
info!("[GLM_SSE] Delta: {}", serde_json::to_string(delta).unwrap_or_default());
|
||||
|
||||
// Handle tool_calls (GLM-4.7 standard function calling)
|
||||
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
||||
for tool_call in tool_calls {
|
||||
info!("[GLM_SSE] Tool call detected: {}", serde_json::to_string(tool_call).unwrap_or_default());
|
||||
// Send tool_calls as JSON for the calling code to process
|
||||
let tool_call_json = serde_json::json!({
|
||||
"type": "tool_call",
|
||||
|
|
@ -323,7 +315,6 @@ impl LLMProvider for GLMClient {
|
|||
// This makes GLM behave like OpenAI-compatible APIs
|
||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||
if !content.is_empty() {
|
||||
info!("[GLM_TX] Sending to channel: '{}'", content);
|
||||
match tx.send(content.to_string()).await {
|
||||
Ok(_) => {},
|
||||
Err(e) => {
|
||||
|
|
@ -331,16 +322,14 @@ impl LLMProvider for GLMClient {
|
|||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("[GLM_SSE] No content field in delta");
|
||||
}
|
||||
} else {
|
||||
info!("[GLM_SSE] No delta in choice");
|
||||
// No delta in choice
|
||||
}
|
||||
if let Some(reason) = choice.get("finish_reason").and_then(|r| r.as_str()) {
|
||||
if !reason.is_empty() {
|
||||
info!("GLM stream finished: {}", reason);
|
||||
let _ = tx.send(String::new());
|
||||
std::mem::drop(tx.send(String::new()));
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
|
@ -356,7 +345,7 @@ impl LLMProvider for GLMClient {
|
|||
}
|
||||
}
|
||||
|
||||
let _ = tx.send(String::new()); // Signal completion
|
||||
std::mem::drop(tx.send(String::new())); // Signal completion
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
|||
188
src/llm/local.rs
188
src/llm/local.rs
|
|
@ -34,7 +34,7 @@ pub async fn ensure_llama_servers_running(
|
|||
let mut conn = conn_arc
|
||||
.get()
|
||||
.map_err(|e| format!("failed to get db connection: {e}"))?;
|
||||
Ok(crate::core::bot::get_default_bot(&mut *conn))
|
||||
Ok(crate::core::bot::get_default_bot(&mut conn))
|
||||
})
|
||||
.await??;
|
||||
let config_manager = ConfigManager::new(app_state.conn.clone());
|
||||
|
|
@ -394,72 +394,76 @@ pub fn start_llm_server(
|
|||
.unwrap_or_else(|_| "32000".to_string());
|
||||
let n_ctx_size = if n_ctx_size.is_empty() { "32000".to_string() } else { n_ctx_size };
|
||||
|
||||
let mut args = format!(
|
||||
"-m {model_path} --host 0.0.0.0 --port {port} --top_p 0.95 --temp 0.6 --repeat-penalty 1.2 --n-gpu-layers {gpu_layers} --ubatch-size 2048"
|
||||
);
|
||||
if !reasoning_format.is_empty() {
|
||||
let _ = write!(args, " --reasoning-format {reasoning_format}");
|
||||
}
|
||||
let cmd_path = if cfg!(windows) {
|
||||
format!("{}\\llama-server.exe", llama_cpp_path)
|
||||
} else {
|
||||
format!("{}/llama-server", llama_cpp_path)
|
||||
};
|
||||
|
||||
let mut command = std::process::Command::new(&cmd_path);
|
||||
command.arg("-m").arg(&model_path)
|
||||
.arg("--host").arg("0.0.0.0")
|
||||
.arg("--port").arg(port)
|
||||
.arg("--top_p").arg("0.95")
|
||||
.arg("--temp").arg("0.6")
|
||||
.arg("--repeat-penalty").arg("1.2")
|
||||
.arg("--n-gpu-layers").arg(&gpu_layers)
|
||||
.arg("--ubatch-size").arg("2048");
|
||||
|
||||
if !reasoning_format.is_empty() {
|
||||
command.arg("--reasoning-format").arg(&reasoning_format);
|
||||
}
|
||||
if n_moe != "0" {
|
||||
let _ = write!(args, " --n-cpu-moe {n_moe}");
|
||||
command.arg("--n-cpu-moe").arg(&n_moe);
|
||||
}
|
||||
if parallel != "1" {
|
||||
let _ = write!(args, " --parallel {parallel}");
|
||||
command.arg("--parallel").arg(¶llel);
|
||||
}
|
||||
if cont_batching == "true" {
|
||||
args.push_str(" --cont-batching");
|
||||
command.arg("--cont-batching");
|
||||
}
|
||||
if mlock == "true" {
|
||||
args.push_str(" --mlock");
|
||||
command.arg("--mlock");
|
||||
}
|
||||
if no_mmap == "true" {
|
||||
args.push_str(" --no-mmap");
|
||||
command.arg("--no-mmap");
|
||||
}
|
||||
if n_predict != "0" {
|
||||
let _ = write!(args, " --n-predict {n_predict}");
|
||||
command.arg("--n-predict").arg(&n_predict);
|
||||
}
|
||||
let _ = write!(args, " --ctx-size {n_ctx_size}");
|
||||
command.arg("--ctx-size").arg(&n_ctx_size);
|
||||
command.arg("--verbose");
|
||||
|
||||
if cfg!(windows) {
|
||||
let cmd_arg = format!("cd {llama_cpp_path} && .\\llama-server.exe {args}");
|
||||
info!(
|
||||
"Executing LLM server command: cd {llama_cpp_path} && .\\llama-server.exe {args} --verbose"
|
||||
);
|
||||
let cmd = SafeCommand::new("cmd")
|
||||
.and_then(|c| c.arg("/C"))
|
||||
.and_then(|c| c.trusted_shell_script_arg(&cmd_arg))
|
||||
.map_err(|e| {
|
||||
Box::new(std::io::Error::other(
|
||||
e.to_string(),
|
||||
)) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
cmd.execute().map_err(|e| {
|
||||
Box::new(std::io::Error::other(
|
||||
e.to_string(),
|
||||
)) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
} else {
|
||||
let cmd_arg = format!(
|
||||
"{llama_cpp_path}/llama-server {args} --verbose >{llama_cpp_path}/llm-stdout.log 2>&1 &"
|
||||
);
|
||||
info!(
|
||||
"Executing LLM server command: {llama_cpp_path}/llama-server {args} --verbose"
|
||||
);
|
||||
let cmd = SafeCommand::new("sh")
|
||||
.and_then(|c| c.arg("-c"))
|
||||
.and_then(|c| c.trusted_shell_script_arg(&cmd_arg))
|
||||
.map_err(|e| {
|
||||
Box::new(std::io::Error::other(
|
||||
e.to_string(),
|
||||
)) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
cmd.execute().map_err(|e| {
|
||||
Box::new(std::io::Error::other(
|
||||
e.to_string(),
|
||||
)) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
command.current_dir(&llama_cpp_path);
|
||||
}
|
||||
|
||||
let log_file_path = if cfg!(windows) {
|
||||
format!("{}\\llm-stdout.log", llama_cpp_path)
|
||||
} else {
|
||||
format!("{}/llm-stdout.log", llama_cpp_path)
|
||||
};
|
||||
|
||||
match std::fs::File::create(&log_file_path) {
|
||||
Ok(log_file) => {
|
||||
if let Ok(clone) = log_file.try_clone() {
|
||||
command.stdout(std::process::Stdio::from(clone));
|
||||
} else {
|
||||
command.stdout(std::process::Stdio::null());
|
||||
}
|
||||
command.stderr(std::process::Stdio::from(log_file));
|
||||
}
|
||||
Err(_) => {
|
||||
command.stdout(std::process::Stdio::null());
|
||||
command.stderr(std::process::Stdio::null());
|
||||
}
|
||||
}
|
||||
|
||||
info!("Executing LLM server command: {:?}", command);
|
||||
|
||||
command.spawn().map_err(|e| {
|
||||
Box::new(std::io::Error::other(e.to_string())) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn start_embedding_server(
|
||||
|
|
@ -486,45 +490,55 @@ pub async fn start_embedding_server(
|
|||
|
||||
info!("Starting embedding server on port {port} with model: {model_path}");
|
||||
|
||||
if cfg!(windows) {
|
||||
let cmd_arg = format!(
|
||||
"cd {llama_cpp_path} && .\\llama-server.exe -m {model_path} --verbose --host 0.0.0.0 --port {port} --embedding --n-gpu-layers 99 >stdout.log 2>&1"
|
||||
);
|
||||
let cmd = SafeCommand::new("cmd")
|
||||
.and_then(|c| c.arg("/c"))
|
||||
.and_then(|c| c.trusted_shell_script_arg(&cmd_arg))
|
||||
.map_err(|e| {
|
||||
Box::new(std::io::Error::other(
|
||||
e.to_string(),
|
||||
)) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
cmd.execute().map_err(|e| {
|
||||
Box::new(std::io::Error::other(
|
||||
e.to_string(),
|
||||
)) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
let cmd_path = if cfg!(windows) {
|
||||
format!("{}\\llama-server.exe", llama_cpp_path)
|
||||
} else {
|
||||
let cmd_arg = format!(
|
||||
"{llama_cpp_path}/llama-server -m {model_path} --verbose --host 0.0.0.0 --port {port} --embedding --n-gpu-layers 99 --ubatch-size 2048 >{llama_cpp_path}/llmembd-stdout.log 2>&1 &"
|
||||
);
|
||||
info!(
|
||||
"Executing embedding server command: {llama_cpp_path}/llama-server -m {model_path} --host 0.0.0.0 --port {port} --embedding"
|
||||
);
|
||||
let cmd = SafeCommand::new("sh")
|
||||
.and_then(|c| c.arg("-c"))
|
||||
.and_then(|c| c.trusted_shell_script_arg(&cmd_arg))
|
||||
.map_err(|e| {
|
||||
Box::new(std::io::Error::other(
|
||||
e.to_string(),
|
||||
)) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
cmd.execute().map_err(|e| {
|
||||
Box::new(std::io::Error::other(
|
||||
e.to_string(),
|
||||
)) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
format!("{}/llama-server", llama_cpp_path)
|
||||
};
|
||||
|
||||
let mut command = std::process::Command::new(&cmd_path);
|
||||
command.arg("-m").arg(&model_path)
|
||||
.arg("--host").arg("0.0.0.0")
|
||||
.arg("--port").arg(port)
|
||||
.arg("--embedding")
|
||||
.arg("--n-gpu-layers").arg("99")
|
||||
.arg("--verbose");
|
||||
|
||||
if !cfg!(windows) {
|
||||
command.arg("--ubatch-size").arg("2048");
|
||||
}
|
||||
|
||||
if cfg!(windows) {
|
||||
command.current_dir(&llama_cpp_path);
|
||||
}
|
||||
|
||||
let log_file_path = if cfg!(windows) {
|
||||
format!("{}\\stdout.log", llama_cpp_path)
|
||||
} else {
|
||||
format!("{}/llmembd-stdout.log", llama_cpp_path)
|
||||
};
|
||||
|
||||
match std::fs::File::create(&log_file_path) {
|
||||
Ok(log_file) => {
|
||||
if let Ok(clone) = log_file.try_clone() {
|
||||
command.stdout(std::process::Stdio::from(clone));
|
||||
} else {
|
||||
command.stdout(std::process::Stdio::null());
|
||||
}
|
||||
command.stderr(std::process::Stdio::from(log_file));
|
||||
}
|
||||
Err(_) => {
|
||||
command.stdout(std::process::Stdio::null());
|
||||
command.stderr(std::process::Stdio::null());
|
||||
}
|
||||
}
|
||||
|
||||
info!("Executing embedding server command: {:?}", command);
|
||||
|
||||
command.spawn().map_err(|e| {
|
||||
Box::new(std::io::Error::other(e.to_string())) as Box<dyn std::error::Error + Send + Sync>
|
||||
})?;
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
|
||||
Ok(())
|
||||
|
|
|
|||
|
|
@ -185,6 +185,17 @@ impl OpenAIClient {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sanitizes a string by removing invalid UTF-8 surrogate characters
|
||||
/// that cannot be encoded in valid UTF-8 (surrogates are only valid in UTF-16)
|
||||
fn sanitize_utf8(input: &str) -> String {
|
||||
input.chars()
|
||||
.filter(|c| {
|
||||
let cp = *c as u32;
|
||||
!(0xD800..=0xDBFF).contains(&cp) && !(0xDC00..=0xDFFF).contains(&cp)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn build_messages(
|
||||
system_prompt: &str,
|
||||
context_data: &str,
|
||||
|
|
@ -194,19 +205,19 @@ impl OpenAIClient {
|
|||
if !system_prompt.is_empty() {
|
||||
messages.push(serde_json::json!({
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
"content": Self::sanitize_utf8(system_prompt)
|
||||
}));
|
||||
}
|
||||
if !context_data.is_empty() {
|
||||
messages.push(serde_json::json!({
|
||||
"role": "system",
|
||||
"content": context_data
|
||||
"content": Self::sanitize_utf8(context_data)
|
||||
}));
|
||||
}
|
||||
for (role, content) in history {
|
||||
messages.push(serde_json::json!({
|
||||
"role": role,
|
||||
"content": content
|
||||
"content": Self::sanitize_utf8(content)
|
||||
}));
|
||||
}
|
||||
serde_json::Value::Array(messages)
|
||||
|
|
@ -747,10 +758,10 @@ mod tests {
|
|||
fn test_openai_client_new_custom_url() {
|
||||
let client = OpenAIClient::new(
|
||||
"test_key".to_string(),
|
||||
Some("http://localhost:8080".to_string()),
|
||||
Some("http://localhost:9000".to_string()),
|
||||
None,
|
||||
);
|
||||
assert_eq!(client.base_url, "http://localhost:8080");
|
||||
assert_eq!(client.base_url, "http://localhost:9000");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -1,720 +0,0 @@
|
|||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use log::{error, info};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
|
||||
pub mod cache;
|
||||
pub mod claude;
|
||||
pub mod episodic_memory;
|
||||
pub mod llm_models;
|
||||
pub mod local;
|
||||
pub mod smart_router;
|
||||
|
||||
pub use claude::ClaudeClient;
|
||||
pub use llm_models::get_handler;
|
||||
|
||||
#[async_trait]
|
||||
pub trait LLMProvider: Send + Sync {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
async fn cancel_job(
|
||||
&self,
|
||||
session_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OpenAIClient {
|
||||
client: reqwest::Client,
|
||||
base_url: String,
|
||||
endpoint_path: String,
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
/// Estimates token count for a text string (roughly 4 characters per token for English)
|
||||
fn estimate_tokens(text: &str) -> usize {
|
||||
// Rough estimate: ~4 characters per token for English text
|
||||
// This is a heuristic and may not be accurate for all languages
|
||||
text.len().div_ceil(4)
|
||||
}
|
||||
|
||||
/// Estimates total tokens for a messages array
|
||||
fn estimate_messages_tokens(messages: &Value) -> usize {
|
||||
if let Some(msg_array) = messages.as_array() {
|
||||
msg_array
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
|
||||
Self::estimate_tokens(content)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.sum()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates messages to fit within the max_tokens limit
|
||||
/// Keeps system messages and the most recent user/assistant messages
|
||||
fn truncate_messages(messages: &Value, max_tokens: usize) -> Value {
|
||||
let mut result = Vec::new();
|
||||
let mut token_count = 0;
|
||||
|
||||
if let Some(msg_array) = messages.as_array() {
|
||||
// First pass: keep all system messages
|
||||
for msg in msg_array {
|
||||
if let Some(role) = msg.get("role").and_then(|r| r.as_str()) {
|
||||
if role == "system" {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
|
||||
let msg_tokens = Self::estimate_tokens(content);
|
||||
if token_count + msg_tokens <= max_tokens {
|
||||
result.push(msg.clone());
|
||||
token_count += msg_tokens;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: add user/assistant messages from newest to oldest
|
||||
let mut recent_messages: Vec<&Value> = msg_array
|
||||
.iter()
|
||||
.filter(|msg| msg.get("role").and_then(|r| r.as_str()) != Some("system"))
|
||||
.collect();
|
||||
|
||||
// Reverse to get newest first
|
||||
recent_messages.reverse();
|
||||
|
||||
for msg in recent_messages {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
|
||||
let msg_tokens = Self::estimate_tokens(content);
|
||||
if token_count + msg_tokens <= max_tokens {
|
||||
result.push(msg.clone());
|
||||
token_count += msg_tokens;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse back to chronological order for non-system messages
|
||||
// But keep system messages at the beginning
|
||||
let system_count = result.len()
|
||||
- result
|
||||
.iter()
|
||||
.filter(|m| m.get("role").and_then(|r| r.as_str()) != Some("system"))
|
||||
.count();
|
||||
let mut user_messages: Vec<Value> = result.drain(system_count..).collect();
|
||||
user_messages.reverse();
|
||||
result.extend(user_messages);
|
||||
}
|
||||
|
||||
serde_json::Value::Array(result)
|
||||
}
|
||||
|
||||
/// Ensures messages fit within model's context limit
|
||||
fn ensure_token_limit(messages: &Value, model_context_limit: usize) -> Value {
|
||||
let estimated_tokens = Self::estimate_messages_tokens(messages);
|
||||
|
||||
// Use 90% of context limit to leave room for response
|
||||
let safe_limit = (model_context_limit as f64 * 0.9) as usize;
|
||||
|
||||
if estimated_tokens > safe_limit {
|
||||
log::warn!(
|
||||
"Messages exceed token limit ({} > {}), truncating...",
|
||||
estimated_tokens,
|
||||
safe_limit
|
||||
);
|
||||
Self::truncate_messages(messages, safe_limit)
|
||||
} else {
|
||||
messages.clone()
|
||||
}
|
||||
}
|
||||
pub fn new(_api_key: String, base_url: Option<String>, endpoint_path: Option<String>) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
base_url: base_url.unwrap_or_else(|| "https://api.openai.com".to_string()),
|
||||
endpoint_path: endpoint_path.unwrap_or_else(|| "/v1/chat/completions".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_messages(
|
||||
system_prompt: &str,
|
||||
context_data: &str,
|
||||
history: &[(String, String)],
|
||||
) -> Value {
|
||||
let mut messages = Vec::new();
|
||||
if !system_prompt.is_empty() {
|
||||
messages.push(serde_json::json!({
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
}));
|
||||
}
|
||||
if !context_data.is_empty() {
|
||||
messages.push(serde_json::json!({
|
||||
"role": "system",
|
||||
"content": context_data
|
||||
}));
|
||||
}
|
||||
for (role, content) in history {
|
||||
messages.push(serde_json::json!({
|
||||
"role": role,
|
||||
"content": content
|
||||
}));
|
||||
}
|
||||
serde_json::Value::Array(messages)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for OpenAIClient {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
messages: &Value,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||
|
||||
// Get the messages to use
|
||||
let raw_messages =
|
||||
if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
};
|
||||
|
||||
// Ensure messages fit within model's context limit
|
||||
// GLM-4.7 has 202750 tokens, other models vary
|
||||
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
|
||||
202750
|
||||
} else if model.contains("gpt-4") {
|
||||
128000
|
||||
} else if model.contains("gpt-3.5") {
|
||||
16385
|
||||
} else {
|
||||
model.starts_with("http://localhost:808") ? 768 : 4096 // Local llama.cpp or default limit
|
||||
};
|
||||
|
||||
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}{}", self.base_url, self.endpoint_path))
|
||||
.header("Authorization", format!("Bearer {}", key))
|
||||
.json(&serde_json::json!({
|
||||
"model": model,
|
||||
"messages": messages
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status != reqwest::StatusCode::OK {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("LLM generate error: {}", error_text);
|
||||
return Err(format!("LLM request failed with status: {}", status).into());
|
||||
}
|
||||
|
||||
let result: Value = response.json().await?;
|
||||
let raw_content = result["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.unwrap_or("");
|
||||
|
||||
let handler = get_handler(model);
|
||||
let content = handler.process_content(raw_content);
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
messages: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||
|
||||
// Get the messages to use
|
||||
let raw_messages =
|
||||
if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
info!("Using provided messages: {:?}", messages);
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
};
|
||||
|
||||
// Ensure messages fit within model's context limit
|
||||
// GLM-4.7 has 202750 tokens, other models vary
|
||||
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
|
||||
202750
|
||||
} else if model.contains("gpt-4") {
|
||||
128000
|
||||
} else if model.contains("gpt-3.5") {
|
||||
16385
|
||||
} else {
|
||||
model.starts_with("http://localhost:808") ? 768 : 4096 // Local llama.cpp or default limit
|
||||
};
|
||||
|
||||
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}{}", self.base_url, self.endpoint_path))
|
||||
.header("Authorization", format!("Bearer {}", key))
|
||||
.json(&serde_json::json!({
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status != reqwest::StatusCode::OK {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("LLM generate_stream error: {}", error_text);
|
||||
return Err(format!("LLM request failed with status: {}", status).into());
|
||||
}
|
||||
|
||||
let handler = get_handler(model);
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result?;
|
||||
let chunk_str = String::from_utf8_lossy(&chunk);
|
||||
for line in chunk_str.lines() {
|
||||
if line.starts_with("data: ") && !line.contains("[DONE]") {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
||||
if let Some(content) = data["choices"][0]["delta"]["content"].as_str() {
|
||||
let processed = handler.process_content(content);
|
||||
if !processed.is_empty() {
|
||||
let _ = tx.send(processed).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_job(
|
||||
&self,
|
||||
_session_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_llm_services(state: &std::sync::Arc<crate::shared::state::AppState>) {
|
||||
episodic_memory::start_episodic_memory_scheduler(std::sync::Arc::clone(state));
|
||||
info!("LLM services started (episodic memory scheduler)");
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum LLMProviderType {
|
||||
OpenAI,
|
||||
Claude,
|
||||
AzureClaude,
|
||||
}
|
||||
|
||||
impl From<&str> for LLMProviderType {
|
||||
fn from(s: &str) -> Self {
|
||||
let lower = s.to_lowercase();
|
||||
if lower.contains("claude") || lower.contains("anthropic") {
|
||||
if lower.contains("azure") {
|
||||
Self::AzureClaude
|
||||
} else {
|
||||
Self::Claude
|
||||
}
|
||||
} else {
|
||||
Self::OpenAI
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_llm_provider(
|
||||
provider_type: LLMProviderType,
|
||||
base_url: String,
|
||||
deployment_name: Option<String>,
|
||||
endpoint_path: Option<String>,
|
||||
) -> std::sync::Arc<dyn LLMProvider> {
|
||||
match provider_type {
|
||||
LLMProviderType::OpenAI => {
|
||||
info!("Creating OpenAI LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(OpenAIClient::new(
|
||||
"empty".to_string(),
|
||||
Some(base_url),
|
||||
endpoint_path,
|
||||
))
|
||||
}
|
||||
LLMProviderType::Claude => {
|
||||
info!("Creating Claude LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(ClaudeClient::new(base_url, deployment_name))
|
||||
}
|
||||
LLMProviderType::AzureClaude => {
|
||||
let deployment = deployment_name.unwrap_or_else(|| "claude-opus-4-5".to_string());
|
||||
info!(
|
||||
"Creating Azure Claude LLM provider with URL: {}, deployment: {}",
|
||||
base_url, deployment
|
||||
);
|
||||
std::sync::Arc::new(ClaudeClient::azure(base_url, deployment))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_llm_provider_from_url(
|
||||
url: &str,
|
||||
model: Option<String>,
|
||||
endpoint_path: Option<String>,
|
||||
) -> std::sync::Arc<dyn LLMProvider> {
|
||||
let provider_type = LLMProviderType::from(url);
|
||||
create_llm_provider(provider_type, url.to_string(), model, endpoint_path)
|
||||
}
|
||||
|
||||
pub struct DynamicLLMProvider {
|
||||
inner: RwLock<Arc<dyn LLMProvider>>,
|
||||
}
|
||||
|
||||
impl DynamicLLMProvider {
|
||||
pub fn new(provider: Arc<dyn LLMProvider>) -> Self {
|
||||
Self {
|
||||
inner: RwLock::new(provider),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update_provider(&self, new_provider: Arc<dyn LLMProvider>) {
|
||||
let mut guard = self.inner.write().await;
|
||||
*guard = new_provider;
|
||||
info!("LLM provider updated dynamically");
|
||||
}
|
||||
|
||||
pub async fn update_from_config(
|
||||
&self,
|
||||
url: &str,
|
||||
model: Option<String>,
|
||||
endpoint_path: Option<String>,
|
||||
) {
|
||||
let new_provider = create_llm_provider_from_url(url, model, endpoint_path);
|
||||
self.update_provider(new_provider).await;
|
||||
}
|
||||
|
||||
async fn get_provider(&self) -> Arc<dyn LLMProvider> {
|
||||
self.inner.read().await.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for DynamicLLMProvider {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.get_provider()
|
||||
.await
|
||||
.generate(prompt, config, model, key)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.get_provider()
|
||||
.await
|
||||
.generate_stream(prompt, config, tx, model, key)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn cancel_job(
|
||||
&self,
|
||||
session_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.get_provider().await.cancel_job(session_id).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub r#type: String,
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolFunction {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: i64,
|
||||
model: String,
|
||||
choices: Vec<ChatChoice>,
|
||||
usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ChatChoice {
|
||||
index: i32,
|
||||
message: ChatMessage,
|
||||
finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Usage {
|
||||
#[serde(rename = "prompt_tokens")]
|
||||
prompt: i32,
|
||||
#[serde(rename = "completion_tokens")]
|
||||
completion: i32,
|
||||
#[serde(rename = "total_tokens")]
|
||||
total: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: ErrorDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ErrorDetail {
|
||||
message: String,
|
||||
#[serde(rename = "type")]
|
||||
r#type: String,
|
||||
code: String,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_serialization() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "get_weather".to_string(),
|
||||
arguments: r#"{"location": "NYC"}"#.to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&tool_call).unwrap();
|
||||
assert!(json.contains("get_weather"));
|
||||
assert!(json.contains("call_123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_completion_response_serialization() {
|
||||
let response = ChatCompletionResponse {
|
||||
id: "test-id".to_string(),
|
||||
object: "chat.completion".to_string(),
|
||||
created: 1_234_567_890,
|
||||
model: "gpt-4".to_string(),
|
||||
choices: vec![ChatChoice {
|
||||
index: 0,
|
||||
message: ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hello!".to_string()),
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt: 10,
|
||||
completion: 5,
|
||||
total: 15,
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
assert!(json.contains("chat.completion"));
|
||||
assert!(json.contains("Hello!"));
|
||||
assert!(json.contains("gpt-4"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_response_serialization() {
|
||||
let error = ErrorResponse {
|
||||
error: ErrorDetail {
|
||||
message: "Test error".to_string(),
|
||||
r#type: "test_error".to_string(),
|
||||
code: "test_code".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&error).unwrap();
|
||||
assert!(json.contains("Test error"));
|
||||
assert!(json.contains("test_code"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_messages_empty() {
|
||||
let messages = OpenAIClient::build_messages("", "", &[]);
|
||||
assert!(messages.is_array());
|
||||
assert!(messages.as_array().unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_messages_with_system_prompt() {
|
||||
let messages = OpenAIClient::build_messages("You are a helpful assistant.", "", &[]);
|
||||
let arr = messages.as_array().unwrap();
|
||||
assert_eq!(arr.len(), 1);
|
||||
assert_eq!(arr[0]["role"], "system");
|
||||
assert_eq!(arr[0]["content"], "You are a helpful assistant.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_messages_with_context() {
|
||||
let messages = OpenAIClient::build_messages("System prompt", "Context data", &[]);
|
||||
let arr = messages.as_array().unwrap();
|
||||
assert_eq!(arr.len(), 2);
|
||||
assert_eq!(arr[0]["content"], "System prompt");
|
||||
assert_eq!(arr[1]["content"], "Context data");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_messages_with_history() {
|
||||
let history = vec![
|
||||
("user".to_string(), "Hello".to_string()),
|
||||
("assistant".to_string(), "Hi there!".to_string()),
|
||||
];
|
||||
let messages = OpenAIClient::build_messages("", "", &history);
|
||||
let arr = messages.as_array().unwrap();
|
||||
assert_eq!(arr.len(), 2);
|
||||
assert_eq!(arr[0]["role"], "user");
|
||||
assert_eq!(arr[0]["content"], "Hello");
|
||||
assert_eq!(arr[1]["role"], "assistant");
|
||||
assert_eq!(arr[1]["content"], "Hi there!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_messages_full() {
|
||||
let history = vec![("user".to_string(), "What is the weather?".to_string())];
|
||||
let messages = OpenAIClient::build_messages(
|
||||
"You are a weather bot.",
|
||||
"Current location: NYC",
|
||||
&history,
|
||||
);
|
||||
let arr = messages.as_array().unwrap();
|
||||
assert_eq!(arr.len(), 3);
|
||||
assert_eq!(arr[0]["role"], "system");
|
||||
assert_eq!(arr[1]["role"], "system");
|
||||
assert_eq!(arr[2]["role"], "user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_client_new_default_url() {
|
||||
let client = OpenAIClient::new("test_key".to_string(), None, None);
|
||||
assert_eq!(client.base_url, "https://api.openai.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_client_new_custom_url() {
|
||||
let client = OpenAIClient::new(
|
||||
"test_key".to_string(),
|
||||
Some("http://localhost:8080".to_string()),
|
||||
None,
|
||||
);
|
||||
assert_eq!(client.base_url, "http://localhost:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_with_tool_calls() {
|
||||
let message = ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: "call_1".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "search".to_string(),
|
||||
arguments: r#"{"query": "test"}"#.to_string(),
|
||||
},
|
||||
}]),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&message).unwrap();
|
||||
assert!(json.contains("tool_calls"));
|
||||
assert!(json.contains("search"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_usage_calculation() {
|
||||
let usage = Usage {
|
||||
prompt: 100,
|
||||
completion: 50,
|
||||
total: 150,
|
||||
};
|
||||
assert_eq!(usage.prompt + usage.completion, usage.total);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_choice_finish_reasons() {
|
||||
let stop_choice = ChatChoice {
|
||||
index: 0,
|
||||
message: ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Done".to_string()),
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
};
|
||||
assert_eq!(stop_choice.finish_reason, "stop");
|
||||
|
||||
let tool_choice = ChatChoice {
|
||||
index: 0,
|
||||
message: ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![]),
|
||||
},
|
||||
finish_reason: "tool_calls".to_string(),
|
||||
};
|
||||
assert_eq!(tool_choice.finish_reason, "tool_calls");
|
||||
}
|
||||
}
|
||||
|
|
@ -96,7 +96,7 @@ impl ApiRateLimiter {
|
|||
pub fn new(limits: RateLimits) -> Self {
|
||||
// Requests per minute limiter
|
||||
let rpm_quota = NonZeroU32::new(limits.requests_per_minute)
|
||||
.unwrap_or_else(|| unsafe { NonZeroU32::new_unchecked(1) });
|
||||
.unwrap_or_else(|| NonZeroU32::new(1).unwrap());
|
||||
let requests_per_minute = Arc::new(RateLimiter::direct(Quota::per_minute(rpm_quota)));
|
||||
|
||||
// Tokens per minute (using semaphore as we need to track token count)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ pub enum OptimizationGoal {
|
|||
}
|
||||
|
||||
impl OptimizationGoal {
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
pub fn from_str_name(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"speed" => Self::Speed,
|
||||
"cost" => Self::Cost,
|
||||
|
|
|
|||
|
|
@ -60,8 +60,6 @@ pub mod research;
|
|||
pub mod search;
|
||||
pub mod security;
|
||||
pub mod settings;
|
||||
#[cfg(feature = "dashboards")]
|
||||
pub mod shared;
|
||||
#[cfg(feature = "sheet")]
|
||||
pub mod sheet;
|
||||
#[cfg(feature = "slides")]
|
||||
|
|
@ -229,8 +227,9 @@ async fn main() -> std::io::Result<()> {
|
|||
if args.len() > 1 {
|
||||
let command = &args[1];
|
||||
match command.as_str() {
|
||||
"install" | "remove" | "list" | "status" | "start" | "stop" | "restart" | "--help"
|
||||
| "-h" => match crate::core::package_manager::cli::run().await {
|
||||
"install" | "remove" | "list" | "status" | "start" | "stop" | "restart"
|
||||
| "rotate-secret" | "rotate-secrets" | "vault"
|
||||
| "--version" | "-v" | "--help" | "-h" => match crate::core::package_manager::cli::run().await {
|
||||
Ok(_) => return Ok(()),
|
||||
Err(e) => {
|
||||
eprintln!("CLI error: {e}");
|
||||
|
|
|
|||
|
|
@ -216,7 +216,9 @@ pub async fn init_database(
|
|||
progress_tx.send(BootstrapProgress::ConnectingDatabase).ok();
|
||||
|
||||
// Ensure secrets manager is initialized before creating database connection
|
||||
crate::core::shared::utils::init_secrets_manager().await;
|
||||
crate::core::shared::utils::init_secrets_manager()
|
||||
.await
|
||||
.expect("Failed to initialize secrets manager");
|
||||
|
||||
let pool = match create_conn() {
|
||||
Ok(pool) => {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ pub mod recording;
|
|||
pub mod service;
|
||||
pub mod ui;
|
||||
pub mod webinar;
|
||||
pub mod webinar_api;
|
||||
pub mod webinar_types;
|
||||
pub mod whiteboard;
|
||||
pub mod whiteboard_export;
|
||||
|
|
|
|||
|
|
@ -1,35 +1,3 @@
|
|||
// Webinar API module - re-exports for backward compatibility
|
||||
// This module has been split into the webinar_api subdirectory for better organization
|
||||
use crate::meet::webinar_api::*;
|
||||
use crate::meet::webinar_types::*;
|
||||
|
||||
pub mod webinar_api {
|
||||
pub use super::webinar_api::*;
|
||||
}
|
||||
|
||||
// Re-export all public items for backward compatibility
|
||||
pub use webinar_api::{
|
||||
// Constants
|
||||
MAX_RAISED_HANDS_VISIBLE, MAX_WEBINAR_PARTICIPANTS, QA_QUESTION_MAX_LENGTH,
|
||||
|
||||
// Types
|
||||
AnswerQuestionRequest, CreatePollRequest, CreateWebinarRequest, FieldType,
|
||||
GetTranscriptionRequest, PanelistInvite, PollOption, PollStatus, PollType, PollVote,
|
||||
QAQuestion, QuestionStatus, RecordingQuality, RecordingStatus, RegisterRequest,
|
||||
RegistrationField, RegistrationStatus, RetentionPoint, RoleChangeRequest,
|
||||
StartRecordingRequest, SubmitQuestionRequest, TranscriptionFormat,
|
||||
TranscriptionSegment, TranscriptionStatus, TranscriptionWord, Webinar,
|
||||
WebinarAnalytics, WebinarEvent, WebinarEventType, WebinarParticipant,
|
||||
WebinarPoll, WebinarRecording, WebinarRegistration, WebinarSettings,
|
||||
WebinarStatus, WebinarTranscription, ParticipantRole, ParticipantStatus,
|
||||
|
||||
// Error
|
||||
WebinarError,
|
||||
|
||||
// Service
|
||||
WebinarService,
|
||||
|
||||
// Routes
|
||||
webinar_routes,
|
||||
|
||||
// Migrations
|
||||
create_webinar_tables_migration,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -74,6 +74,8 @@ static ALLOWED_COMMANDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
|
|||
"systemctl",
|
||||
"sudo",
|
||||
"visudo",
|
||||
"id",
|
||||
"netsh",
|
||||
])
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -101,13 +101,13 @@ impl CorsConfig {
|
|||
Self {
|
||||
allowed_origins: vec![
|
||||
"http://localhost:3000".to_string(),
|
||||
"http://localhost:8080".to_string(),
|
||||
"http://localhost:9000".to_string(),
|
||||
"http://localhost:8300".to_string(),
|
||||
"http://127.0.0.1:3000".to_string(),
|
||||
"http://127.0.0.1:8080".to_string(),
|
||||
"http://127.0.0.1:9000".to_string(),
|
||||
"http://127.0.0.1:8300".to_string(),
|
||||
"https://localhost:3000".to_string(),
|
||||
"https://localhost:8080".to_string(),
|
||||
"https://localhost:9000".to_string(),
|
||||
"https://localhost:8300".to_string(),
|
||||
],
|
||||
allowed_methods: vec![
|
||||
|
|
@ -308,7 +308,7 @@ fn is_valid_origin_format(origin: &str) -> bool {
|
|||
return false;
|
||||
}
|
||||
|
||||
if origin.contains("..") || origin.contains("//", ) && origin.matches("//").count() > 1 {
|
||||
if origin.contains("..") || origin.matches("//").count() > 1 {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -576,7 +576,7 @@ mod tests {
|
|||
assert!(is_localhost_origin("http://localhost:3000"));
|
||||
assert!(is_localhost_origin("https://localhost:8443"));
|
||||
assert!(is_localhost_origin("http://127.0.0.1"));
|
||||
assert!(is_localhost_origin("http://127.0.0.1:8080"));
|
||||
assert!(is_localhost_origin("http://127.0.0.1:9000"));
|
||||
assert!(!is_localhost_origin("http://example.com"));
|
||||
}
|
||||
|
||||
|
|
|
|||
238
src/security/file_validation.rs
Normal file
238
src/security/file_validation.rs
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
use std::sync::LazyLock;
|
||||
|
||||
const MAX_FILE_SIZE: usize = 100 * 1024 * 1024;
|
||||
|
||||
static MAGIC_BYTES: LazyLock<Vec<(&'static [u8], &'static str)>> = LazyLock::new(|| {
|
||||
vec![
|
||||
(&[0xFF, 0xD8, 0xFF], "image/jpeg"),
|
||||
(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A], "image/png"),
|
||||
(b"GIF87a", "image/gif"),
|
||||
(b"GIF89a", "image/gif"),
|
||||
(b"BM", "image/bmp"),
|
||||
(b"II*\x00", "image/tiff"),
|
||||
(b"MM\x00*", "image/tiff"),
|
||||
(b"%PDF-", "application/pdf"),
|
||||
(b"PK\x03\x04", "application/zip"),
|
||||
(b"PK\x05\x06", "application/zip"),
|
||||
(b"PK\x07\x08", "application/zip"),
|
||||
(b"Rar!\x1A\x07", "application/vnd.rar"),
|
||||
(&[0x1F, 0x8B, 0x08], "application/gzip"),
|
||||
(b"BZh", "application/x-bzip2"),
|
||||
(&[0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00], "application/x-xz"),
|
||||
(&[0x37, 0x7A, 0xBC, 0xAF, 0x27, 0x1C], "application/7z"),
|
||||
(b"ftyp", "video/mp4"),
|
||||
(&[0x1A, 0x45, 0xDF, 0xA3], "video/webm"),
|
||||
(&[0x30, 0x26, 0xB2, 0x75, 0x8E, 0x66, 0xCF, 0x11, 0xA6, 0xD9, 0x00, 0xAA, 0x00, 0x62, 0xCE, 0x6C], "video/asf"),
|
||||
(&[0x00, 0x00, 0x00, 0x1C, 0x66, 0x74, 0x79, 0x70], "video/mp4"),
|
||||
(&[0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70], "video/mp4"),
|
||||
(b"ID3", "audio/mpeg"),
|
||||
(&[0xFF, 0xFB], "audio/mpeg"),
|
||||
(&[0xFF, 0xFA], "audio/mpeg"),
|
||||
(&[0xFF, 0xF3], "audio/mpeg"),
|
||||
(&[0xFF, 0xF2], "audio/mpeg"),
|
||||
(b"OggS", "audio/ogg"),
|
||||
(b"fLaC", "audio/flac"),
|
||||
(&[0x00, 0x00, 0x00, 0x14, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D], "audio/mp4"),
|
||||
(&[0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x6D, 0x70, 0x34, 0x32], "audio/mp4"),
|
||||
(&[0x00, 0x00, 0x00, 0x18, 0x66, 0x74, 0x79, 0x70, 0x6D, 0x70, 0x34, 0x32], "audio/mp4"),
|
||||
(&[0x00, 0x00, 0x00, 0x1C, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D], "audio/mp4"),
|
||||
(b"RIFF", "audio/wav"),
|
||||
(&[0xE0, 0x00, 0x00, 0x00], "audio/aiff"),
|
||||
]
|
||||
});
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileValidationConfig {
|
||||
pub max_size: usize,
|
||||
pub allowed_types: Vec<String>,
|
||||
pub block_executables: bool,
|
||||
pub check_magic_bytes: bool,
|
||||
defang_pdf: bool,
|
||||
#[allow(dead_code)]
|
||||
scan_for_malware: bool,
|
||||
}
|
||||
|
||||
impl Default for FileValidationConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_size: MAX_FILE_SIZE,
|
||||
allowed_types: vec![
|
||||
"image/jpeg".into(),
|
||||
"image/png".into(),
|
||||
"image/gif".into(),
|
||||
"application/pdf".into(),
|
||||
"text/plain".into(),
|
||||
"application/zip".into(),
|
||||
],
|
||||
block_executables: true,
|
||||
check_magic_bytes: true,
|
||||
defang_pdf: true,
|
||||
scan_for_malware: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileValidationResult {
|
||||
pub is_valid: bool,
|
||||
pub detected_type: Option<String>,
|
||||
pub errors: Vec<String>,
|
||||
pub warnings: Vec<String>,
|
||||
}
|
||||
|
||||
pub fn validate_file_upload(
|
||||
filename: &str,
|
||||
content_type: &str,
|
||||
data: &[u8],
|
||||
config: &FileValidationConfig,
|
||||
) -> FileValidationResult {
|
||||
let mut result = FileValidationResult {
|
||||
is_valid: true,
|
||||
detected_type: None,
|
||||
errors: Vec::new(),
|
||||
warnings: Vec::new(),
|
||||
};
|
||||
|
||||
if data.len() > config.max_size {
|
||||
result.is_valid = false;
|
||||
result.errors.push(format!(
|
||||
"File size {} bytes exceeds maximum allowed size of {} bytes",
|
||||
data.len(),
|
||||
config.max_size
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(extensions) = get_blocked_extensions() {
|
||||
if let Some(ext) = filename.split('.').next_back() {
|
||||
if extensions.contains(&ext.to_lowercase().as_str()) {
|
||||
result.is_valid = false;
|
||||
result.errors.push(format!(
|
||||
"File extension .{} is blocked for security reasons",
|
||||
ext
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.check_magic_bytes {
|
||||
if let Some(detected) = detect_file_type(data) {
|
||||
result.detected_type = Some(detected.clone());
|
||||
|
||||
if !config.allowed_types.is_empty() && !config.allowed_types.contains(&detected) {
|
||||
result.is_valid = false;
|
||||
result.errors.push(format!(
|
||||
"Detected file type '{}' is not in the allowed types list",
|
||||
detected
|
||||
));
|
||||
}
|
||||
|
||||
if content_type != detected && !content_type.starts_with("text/plain") && !content_type.starts_with("application/octet-stream") {
|
||||
result.warnings.push(format!(
|
||||
"Content-Type header '{}' does not match detected file type '{}'",
|
||||
content_type, detected
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.block_executables && is_potentially_executable(data) {
|
||||
result.is_valid = false;
|
||||
result.errors.push(
|
||||
"File appears to be executable or contains executable code, which is blocked".into(),
|
||||
);
|
||||
}
|
||||
|
||||
if config.defang_pdf && content_type == "application/pdf"
|
||||
&& has_potential_malicious_pdf_content(data) {
|
||||
result.warnings.push(
|
||||
"PDF file may contain potentially malicious content (JavaScript, forms, or embedded files)".into(),
|
||||
);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn detect_file_type(data: &[u8]) -> Option<String> {
|
||||
for (magic, mime_type) in MAGIC_BYTES.iter() {
|
||||
if data.starts_with(magic) {
|
||||
return Some(mime_type.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if data.starts_with(b"<") || data.starts_with(b"<!DOCTYPE") {
|
||||
if data.to_ascii_lowercase().windows(5).any(|w| w == b"<html") {
|
||||
return Some("text/html".into());
|
||||
}
|
||||
if data.windows(5).any(|w| w == b"<?xml") {
|
||||
return Some("text/xml".into());
|
||||
}
|
||||
return Some("text/plain".into());
|
||||
}
|
||||
|
||||
if data.iter().all(|&b| b.is_ascii() && !b.is_ascii_control()) {
|
||||
return Some("text/plain".into());
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn get_blocked_extensions() -> Option<Vec<&'static str>> {
|
||||
Some(vec![
|
||||
"exe", "dll", "so", "dylib", "app", "deb", "rpm", "dmg", "pkg", "msi", "scr", "bat",
|
||||
"cmd", "com", "pif", "vbs", "vbe", "js", "jse", "ws", "wsf", "wsc", "wsh", "ps1",
|
||||
"ps1xml", "ps2", "ps2xml", "psc1", "psc2", "msh", "msh1", "msh2", "mshxml", "msh1xml",
|
||||
"msh2xml", "scf", "lnk", "inf", "reg", "docm", "dotm", "xlsm", "xltm", "xlam",
|
||||
"pptm", "potm", "ppam", "ppsm", "sldm", "jar", "appx", "appxbundle", "msix",
|
||||
"msixbundle", "sh", "csh", "bash", "zsh", "fish",
|
||||
])
|
||||
}
|
||||
|
||||
fn is_potentially_executable(data: &[u8]) -> bool {
|
||||
if data.len() < 2 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let magic = &data[0..2];
|
||||
|
||||
if matches!(magic, [0x4D, 0x5A]) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if data.len() >= 4 {
|
||||
let header = &data[0..4];
|
||||
if matches!(header, [0x7F, 0x45, 0x4C, 0x46]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if data.len() >= 8 {
|
||||
let header = &data[0..8];
|
||||
if matches!(header, [0xFE, 0xED, 0xFA, 0xCF, 0x00, 0x00, 0x00, 0x01])
|
||||
|| matches!(header, [0xCF, 0xFA, 0xED, 0xFE, 0x01, 0x00, 0x00, 0x00])
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if data.len() >= 4 {
|
||||
let text_content = String::from_utf8_lossy(&data[0..data.len().min(4096)]);
|
||||
let lower = text_content.to_lowercase();
|
||||
if lower.contains("#!/bin/") || lower.contains("#!/usr/bin/") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn has_potential_malicious_pdf_content(data: &[u8]) -> bool {
|
||||
let text_content = String::from_utf8_lossy(data);
|
||||
let lower = text_content.to_lowercase();
|
||||
|
||||
lower.contains("/javascript")
|
||||
|| lower.contains("/action")
|
||||
|| lower.contains("/launch")
|
||||
|| lower.contains("/embeddedfile")
|
||||
|| lower.contains("/efilename")
|
||||
}
|
||||
|
||||
|
|
@ -35,9 +35,9 @@ impl TlsIntegration {
|
|||
services.insert(
|
||||
"api".to_string(),
|
||||
ServiceUrls {
|
||||
original: "http://localhost:8080".to_string(),
|
||||
original: "http://localhost:9000".to_string(),
|
||||
secure: "https://localhost:8443".to_string(),
|
||||
port: 8080,
|
||||
port: 9000,
|
||||
tls_port: 8443,
|
||||
},
|
||||
);
|
||||
|
|
@ -105,9 +105,9 @@ impl TlsIntegration {
|
|||
services.insert(
|
||||
"directory".to_string(),
|
||||
ServiceUrls {
|
||||
original: "http://localhost:8080".to_string(),
|
||||
original: "http://localhost:9000".to_string(),
|
||||
secure: "https://localhost:8446".to_string(),
|
||||
port: 8080,
|
||||
port: 9000,
|
||||
tls_port: 8446,
|
||||
},
|
||||
);
|
||||
|
|
|
|||
|
|
@ -512,12 +512,31 @@ impl JwtManager {
|
|||
}
|
||||
|
||||
pub async fn cleanup_blacklist(&self, _expired_before: DateTime<Utc>) -> usize {
|
||||
let mut blacklist = self.blacklist.write().await;
|
||||
let blacklist = self.blacklist.read().await;
|
||||
let initial_count = blacklist.len();
|
||||
blacklist.clear();
|
||||
let removed = initial_count;
|
||||
if removed > 0 {
|
||||
info!("Cleaned up {removed} entries from token blacklist");
|
||||
|
||||
// Store expiration times with JTIs for proper cleanup
|
||||
// For now, we need a different approach - track when tokens were revoked
|
||||
// Since we can't determine expiration from JTI alone, we'll use a time-based heuristic
|
||||
|
||||
// Proper fix: Store (JTI, expiration_time) tuples instead of just JTI strings
|
||||
// For backward compatibility, implement conservative cleanup that preserves all tokens
|
||||
// and log this limitation
|
||||
|
||||
// For production: Reimplement blacklist as HashMap<String, DateTime<Utc>>
|
||||
// to store revocation timestamp, then cleanup tokens where both revocation and
|
||||
// original expiration are before expired_before
|
||||
|
||||
// Conservative approach: don't remove anything until we have proper timestamp tracking
|
||||
// This is safe - the blacklist will grow but won't cause security issues
|
||||
let removed = 0;
|
||||
|
||||
// TODO: Reimplement blacklist storage to track revocation timestamps
|
||||
// Suggested: HashMap<String, (DateTime<Utc>, DateTime<Utc>)> storing (revoked_at, expires_at)
|
||||
// Then cleanup can check: revoked_at < expired_before AND expires_at < expired_before
|
||||
|
||||
if initial_count > 0 {
|
||||
info!("Token blacklist has {} entries (cleanup deferred pending timestamp tracking implementation)", initial_count);
|
||||
}
|
||||
removed
|
||||
}
|
||||
|
|
|
|||
33
src/security/log_sanitizer.rs
Normal file
33
src/security/log_sanitizer.rs
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
use std::sync::LazyLock;
|
||||
|
||||
static SANITIZATION_PATTERNS: LazyLock<Vec<(&'static str, &'static str)>> = LazyLock::new(|| {
|
||||
vec![
|
||||
("\n", "\\n"),
|
||||
("\r", "\\r"),
|
||||
("\t", "\\t"),
|
||||
("\\", "\\\\"),
|
||||
("\"", "\\\""),
|
||||
("'", "\\'"),
|
||||
("\x00", "\\x00"),
|
||||
("\x1B", "\\x1B"),
|
||||
]
|
||||
});
|
||||
|
||||
pub fn sanitize_for_log(input: &str) -> String {
|
||||
let mut result = input.to_string();
|
||||
|
||||
for (pattern, replacement) in SANITIZATION_PATTERNS.iter() {
|
||||
result = result.replace(pattern, replacement);
|
||||
}
|
||||
|
||||
if result.len() > 10000 {
|
||||
result.truncate(10000);
|
||||
result.push_str("... [truncated]");
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn sanitize_log_value<T: std::fmt::Display>(value: T) -> String {
|
||||
sanitize_for_log(&value.to_string())
|
||||
}
|
||||
|
|
@ -11,10 +11,12 @@ pub mod cors;
|
|||
pub mod csrf;
|
||||
pub mod dlp;
|
||||
pub mod encryption;
|
||||
pub mod file_validation;
|
||||
pub mod error_sanitizer;
|
||||
pub mod headers;
|
||||
pub mod integration;
|
||||
pub mod jwt;
|
||||
pub mod log_sanitizer;
|
||||
pub mod mfa;
|
||||
pub mod mutual_tls;
|
||||
pub mod panic_handler;
|
||||
|
|
@ -25,11 +27,15 @@ pub mod panic_handler;
|
|||
// pub mod passkey_types;
|
||||
pub mod password;
|
||||
pub mod path_guard;
|
||||
pub mod redis_csrf_store;
|
||||
pub mod redis_session_store;
|
||||
pub mod prompt_security;
|
||||
pub mod protection;
|
||||
pub mod rate_limiter;
|
||||
pub mod rbac_middleware;
|
||||
pub mod request_id;
|
||||
pub mod request_limits;
|
||||
pub mod safe_unwrap;
|
||||
pub mod secrets;
|
||||
pub mod security_monitoring;
|
||||
pub mod session;
|
||||
|
|
@ -167,9 +173,23 @@ pub use tls::{create_https_server, ServiceTlsConfig, TlsConfig, TlsManager, TlsR
|
|||
pub use validation::{
|
||||
sanitize_html, strip_html_tags, validate_alphanumeric, validate_email, validate_length,
|
||||
validate_no_html, validate_no_script_injection, validate_one_of, validate_password_strength,
|
||||
validate_phone, validate_range, validate_required, validate_slug, validate_url,
|
||||
validate_phone, validate_range, validate_required, validate_slug, validate_url, validate_url_ssrf,
|
||||
validate_username, validate_uuid, ValidationError, ValidationResult, Validator,
|
||||
};
|
||||
pub use file_validation::{
|
||||
FileValidationConfig, FileValidationResult, validate_file_upload,
|
||||
};
|
||||
pub use request_limits::{
|
||||
request_size_middleware, upload_size_middleware, DEFAULT_MAX_REQUEST_SIZE, MAX_UPLOAD_SIZE,
|
||||
};
|
||||
pub use log_sanitizer::sanitize_log_value as sanitize_log_value_compact;
|
||||
|
||||
#[cfg(feature = "cache")]
|
||||
pub use redis_session_store::RedisSessionStore;
|
||||
|
||||
#[cfg(feature = "cache")]
|
||||
pub use redis_csrf_store::RedisCsrfManager;
|
||||
pub use safe_unwrap::{safe_unwrap_or, safe_unwrap_or_default, safe_unwrap_none_or};
|
||||
|
||||
use anyhow::Result;
|
||||
use std::path::PathBuf;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
use anyhow::{Context, Result};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::security::command_guard::SafeCommand;
|
||||
|
|
@ -84,12 +83,12 @@ impl ProtectionInstaller {
|
|||
|
||||
#[cfg(windows)]
|
||||
pub fn check_admin() -> bool {
|
||||
let result = Command::new("powershell")
|
||||
.args([
|
||||
let result = SafeCommand::new("powershell")
|
||||
.and_then(|cmd| cmd.args(&[
|
||||
"-Command",
|
||||
"([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()).IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)"
|
||||
])
|
||||
.output();
|
||||
]))
|
||||
.and_then(|cmd| cmd.execute());
|
||||
|
||||
match result {
|
||||
Ok(output) => {
|
||||
|
|
@ -102,9 +101,9 @@ impl ProtectionInstaller {
|
|||
|
||||
#[cfg(not(windows))]
|
||||
pub fn check_root() -> bool {
|
||||
Command::new("id")
|
||||
.arg("-u")
|
||||
.output()
|
||||
SafeCommand::new("id")
|
||||
.and_then(|cmd| cmd.arg("-u"))
|
||||
.and_then(|cmd| cmd.execute())
|
||||
.map(|o| String::from_utf8_lossy(&o.stdout).trim() == "0")
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
|
@ -268,26 +267,23 @@ impl ProtectionInstaller {
|
|||
fn configure_windows_security(&self) -> Result<()> {
|
||||
info!("Configuring Windows security settings...");
|
||||
|
||||
// Enable Windows Defender real-time protection
|
||||
let _ = Command::new("powershell")
|
||||
.args([
|
||||
let _ = SafeCommand::new("powershell")
|
||||
.and_then(|cmd| cmd.args(&[
|
||||
"-Command",
|
||||
"Set-MpPreference -DisableRealtimeMonitoring $false; Set-MpPreference -DisableIOAVProtection $false; Set-MpPreference -DisableScriptScanning $false"
|
||||
])
|
||||
.output();
|
||||
]))
|
||||
.and_then(|cmd| cmd.execute());
|
||||
|
||||
// Enable Windows Firewall
|
||||
let _ = Command::new("netsh")
|
||||
.args(["advfirewall", "set", "allprofiles", "state", "on"])
|
||||
.output();
|
||||
let _ = SafeCommand::new("netsh")
|
||||
.and_then(|cmd| cmd.args(&["advfirewall", "set", "allprofiles", "state", "on"]))
|
||||
.and_then(|cmd| cmd.execute());
|
||||
|
||||
// Enable Windows Defender scanning for mapped drives
|
||||
let _ = Command::new("powershell")
|
||||
.args([
|
||||
let _ = SafeCommand::new("powershell")
|
||||
.and_then(|cmd| cmd.args(&[
|
||||
"-Command",
|
||||
"Set-MpPreference -DisableRemovableDriveScanning $false -DisableScanningMappedNetworkDrivesForFullScan $false"
|
||||
])
|
||||
.output();
|
||||
]))
|
||||
.and_then(|cmd| cmd.execute());
|
||||
|
||||
info!("Windows security configuration completed");
|
||||
Ok(())
|
||||
|
|
@ -313,12 +309,11 @@ impl ProtectionInstaller {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[cfg(not(windows))]
|
||||
fn validate_sudoers(&self) -> Result<()> {
|
||||
let output = std::process::Command::new("visudo")
|
||||
.args(["-c", "-f", SUDOERS_FILE])
|
||||
.output()
|
||||
let output = SafeCommand::new("visudo")
|
||||
.and_then(|cmd| cmd.args(&["-c", "-f", SUDOERS_FILE]))
|
||||
.and_then(|cmd| cmd.execute())
|
||||
.context("Failed to run visudo validation")?;
|
||||
|
||||
if !output.status.success() {
|
||||
|
|
@ -330,7 +325,6 @@ impl ProtectionInstaller {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[cfg(not(windows))]
|
||||
fn install_lmd(&self) -> Result<bool> {
|
||||
let maldet_path = Path::new("/usr/local/sbin/maldet");
|
||||
|
|
@ -398,7 +392,6 @@ impl ProtectionInstaller {
|
|||
Ok(true)
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[cfg(not(windows))]
|
||||
fn update_databases(&self) -> Result<()> {
|
||||
info!("Updating security tool databases...");
|
||||
|
|
@ -442,12 +435,12 @@ impl ProtectionInstaller {
|
|||
fn update_windows_signatures(&self) -> Result<()> {
|
||||
info!("Updating Windows Defender signatures...");
|
||||
|
||||
let result = Command::new("powershell")
|
||||
.args([
|
||||
let result = SafeCommand::new("powershell")
|
||||
.and_then(|cmd| cmd.args(&[
|
||||
"-Command",
|
||||
"Update-MpSignature; Write-Host 'Windows Defender signatures updated'",
|
||||
])
|
||||
.output();
|
||||
]))
|
||||
.and_then(|cmd| cmd.execute());
|
||||
|
||||
match result {
|
||||
Ok(output) => {
|
||||
|
|
@ -571,13 +564,9 @@ impl ProtectionInstaller {
|
|||
#[cfg(windows)]
|
||||
{
|
||||
for (tool_name, tool_cmd) in WINDOWS_TOOLS {
|
||||
let check = Command::new(tool_cmd)
|
||||
.arg("--version")
|
||||
.or_else(|_| {
|
||||
Command::new("powershell")
|
||||
.args(["-Command", &format!("Get-Command {}", tool_cmd)])
|
||||
})
|
||||
.output();
|
||||
let check = SafeCommand::new(tool_cmd)
|
||||
.and_then(|cmd| cmd.arg("--version"))
|
||||
.and_then(|cmd| cmd.execute());
|
||||
|
||||
let installed = check.map(|o| o.status.success()).unwrap_or(false);
|
||||
result.tools.push(ToolVerification {
|
||||
|
|
|
|||
|
|
@ -68,11 +68,20 @@ pub struct CombinedRateLimiter {
|
|||
|
||||
impl CombinedRateLimiter {
|
||||
pub fn new(http_config: HttpRateLimitConfig, system_limits: SystemLimits) -> Self {
|
||||
const DEFAULT_RPS: NonZeroU32 = match NonZeroU32::new(100) {
|
||||
Some(v) => v,
|
||||
None => unreachable!(),
|
||||
};
|
||||
const DEFAULT_BURST: NonZeroU32 = match NonZeroU32::new(200) {
|
||||
Some(v) => v,
|
||||
None => unreachable!(),
|
||||
};
|
||||
|
||||
let quota = Quota::per_second(
|
||||
NonZeroU32::new(http_config.requests_per_second).unwrap_or(NonZeroU32::new(100).expect("100 is non-zero")),
|
||||
NonZeroU32::new(http_config.requests_per_second).unwrap_or(DEFAULT_RPS),
|
||||
)
|
||||
.allow_burst(
|
||||
NonZeroU32::new(http_config.burst_size).unwrap_or(NonZeroU32::new(200).expect("200 is non-zero")),
|
||||
NonZeroU32::new(http_config.burst_size).unwrap_or(DEFAULT_BURST),
|
||||
);
|
||||
|
||||
Self {
|
||||
|
|
|
|||
208
src/security/redis_csrf_store.rs
Normal file
208
src/security/redis_csrf_store.rs
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use std::sync::Arc;
|
||||
use super::csrf::{CsrfToken, CsrfValidationResult, CsrfConfig};
|
||||
|
||||
const CSRF_KEY_PREFIX: &str = "csrf:";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RedisCsrfStore {
|
||||
client: Arc<redis::Client>,
|
||||
config: CsrfConfig,
|
||||
}
|
||||
|
||||
impl RedisCsrfStore {
|
||||
pub async fn new(redis_url: &str, config: CsrfConfig) -> Result<Self> {
|
||||
let client = redis::Client::open(redis_url)
|
||||
.map_err(|e| anyhow!("Failed to create Redis client: {}", e))?;
|
||||
|
||||
let _ = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
Ok(Self {
|
||||
client: Arc::new(client),
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
fn token_key(&self, token: &str) -> String {
|
||||
format!("{}{}", CSRF_KEY_PREFIX, token)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RedisCsrfManager {
|
||||
store: RedisCsrfStore,
|
||||
#[allow(dead_code)]
|
||||
secret: Vec<u8>,
|
||||
}
|
||||
|
||||
impl RedisCsrfManager {
|
||||
pub async fn new(redis_url: &str, config: CsrfConfig, secret: &[u8]) -> Result<Self> {
|
||||
if secret.len() < 32 {
|
||||
return Err(anyhow!("CSRF secret must be at least 32 bytes"));
|
||||
}
|
||||
|
||||
let store = RedisCsrfStore::new(redis_url, config).await?;
|
||||
|
||||
Ok(Self {
|
||||
store,
|
||||
secret: secret.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn generate_token(&self) -> Result<CsrfToken> {
|
||||
let token = CsrfToken::new(self.store.config.token_expiry_minutes);
|
||||
let key = self.store.token_key(&token.token);
|
||||
let value = serde_json::to_string(&token)?;
|
||||
let ttl_secs = self.store.config.token_expiry_minutes * 60;
|
||||
|
||||
let client = self.store.client.clone();
|
||||
let mut conn = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
redis::cmd("SETEX")
|
||||
.arg(&key)
|
||||
.arg(ttl_secs)
|
||||
.arg(&value)
|
||||
.query_async::<()>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to create CSRF token: {}", e))?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub async fn generate_token_with_session(&self, session_id: &str) -> Result<CsrfToken> {
|
||||
let token = CsrfToken::new(self.store.config.token_expiry_minutes)
|
||||
.with_session(session_id.to_string());
|
||||
let key = self.store.token_key(&token.token);
|
||||
let value = serde_json::to_string(&token)?;
|
||||
let ttl_secs = self.store.config.token_expiry_minutes * 60;
|
||||
|
||||
let client = self.store.client.clone();
|
||||
let mut conn = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
redis::cmd("SETEX")
|
||||
.arg(&key)
|
||||
.arg(ttl_secs)
|
||||
.arg(&value)
|
||||
.query_async::<()>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to create CSRF token: {}", e))?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub async fn validate_token(&self, token_value: &str) -> CsrfValidationResult {
|
||||
if token_value.is_empty() {
|
||||
return CsrfValidationResult::Missing;
|
||||
}
|
||||
|
||||
let client = self.store.client.clone();
|
||||
let key = self.store.token_key(token_value);
|
||||
|
||||
let mut conn = match client.get_multiplexed_async_connection().await {
|
||||
Ok(c) => c,
|
||||
Err(_) => return CsrfValidationResult::Invalid,
|
||||
};
|
||||
|
||||
let value: Option<String> = match redis::cmd("GET")
|
||||
.arg(&key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v,
|
||||
Err(_) => return CsrfValidationResult::Invalid,
|
||||
};
|
||||
|
||||
match value {
|
||||
Some(v) => {
|
||||
let token: CsrfToken = match serde_json::from_str(&v) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return CsrfValidationResult::Invalid,
|
||||
};
|
||||
|
||||
if token.is_expired() {
|
||||
CsrfValidationResult::Expired
|
||||
} else {
|
||||
CsrfValidationResult::Valid
|
||||
}
|
||||
}
|
||||
None => CsrfValidationResult::Invalid,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn validate_token_with_session(
|
||||
&self,
|
||||
token_value: &str,
|
||||
session_id: &str,
|
||||
) -> CsrfValidationResult {
|
||||
if token_value.is_empty() {
|
||||
return CsrfValidationResult::Missing;
|
||||
}
|
||||
|
||||
let client = self.store.client.clone();
|
||||
let key = self.store.token_key(token_value);
|
||||
|
||||
let mut conn = match client.get_multiplexed_async_connection().await {
|
||||
Ok(c) => c,
|
||||
Err(_) => return CsrfValidationResult::Invalid,
|
||||
};
|
||||
|
||||
let value: Option<String> = match redis::cmd("GET")
|
||||
.arg(&key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v,
|
||||
Err(_) => return CsrfValidationResult::Invalid,
|
||||
};
|
||||
|
||||
match value {
|
||||
Some(v) => {
|
||||
let token: CsrfToken = match serde_json::from_str(&v) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return CsrfValidationResult::Invalid,
|
||||
};
|
||||
|
||||
if token.is_expired() {
|
||||
return CsrfValidationResult::Expired;
|
||||
}
|
||||
|
||||
match &token.session_id {
|
||||
Some(sid) if sid == session_id => CsrfValidationResult::Valid,
|
||||
Some(_) => CsrfValidationResult::SessionMismatch,
|
||||
None => CsrfValidationResult::Valid,
|
||||
}
|
||||
}
|
||||
None => CsrfValidationResult::Invalid,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn revoke_token(&self, token_value: &str) -> Result<()> {
|
||||
let client = self.store.client.clone();
|
||||
let key = self.store.token_key(token_value);
|
||||
|
||||
let mut conn = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
redis::cmd("DEL")
|
||||
.arg(&key)
|
||||
.query_async::<()>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to revoke CSRF token: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired(&self) -> Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
185
src/security/redis_session_store.rs
Normal file
185
src/security/redis_session_store.rs
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::session::{Session, SessionStore};
|
||||
|
||||
const SESSION_KEY_PREFIX: &str = "session:";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RedisSessionStore {
|
||||
client: Arc<redis::Client>,
|
||||
}
|
||||
|
||||
impl RedisSessionStore {
|
||||
pub async fn new(redis_url: &str) -> Result<Self> {
|
||||
let client = redis::Client::open(redis_url)
|
||||
.map_err(|e| anyhow!("Failed to create Redis client: {}", e))?;
|
||||
|
||||
let _ = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
Ok(Self {
|
||||
client: Arc::new(client),
|
||||
})
|
||||
}
|
||||
|
||||
fn session_key(&self, session_id: &str) -> String {
|
||||
format!("{}{}", SESSION_KEY_PREFIX, session_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionStore for RedisSessionStore {
|
||||
fn create(&self, session: Session) -> impl std::future::Future<Output = Result<()>> + Send {
|
||||
let client = self.client.clone();
|
||||
let key = self.session_key(&session.id);
|
||||
let ttl = session.time_until_expiry();
|
||||
let ttl_secs = ttl.num_seconds().max(0) as usize;
|
||||
|
||||
async move {
|
||||
let mut conn = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
let value = serde_json::to_string(&session)?;
|
||||
|
||||
redis::cmd("SETEX")
|
||||
.arg(&key)
|
||||
.arg(ttl_secs)
|
||||
.arg(&value)
|
||||
.query_async::<()>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to create session: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self, session_id: &str) -> impl std::future::Future<Output = Result<Option<Session>>> + Send {
|
||||
let client = self.client.clone();
|
||||
let key = self.session_key(session_id);
|
||||
|
||||
async move {
|
||||
let mut conn = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
let value: Option<String> = redis::cmd("GET")
|
||||
.arg(&key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get session: {}", e))?;
|
||||
|
||||
match value {
|
||||
Some(v) => {
|
||||
let session: Session = serde_json::from_str(&v)
|
||||
.map_err(|e| anyhow!("Failed to deserialize session: {}", e))?;
|
||||
Ok(Some(session))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&self, session: &Session) -> impl std::future::Future<Output = Result<()>> + Send {
|
||||
let client = self.client.clone();
|
||||
let key = self.session_key(&session.id);
|
||||
let session = session.clone();
|
||||
let ttl = session.time_until_expiry();
|
||||
let ttl_secs = ttl.num_seconds().max(0) as usize;
|
||||
|
||||
async move {
|
||||
let mut conn = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
let value = serde_json::to_string(&session)?;
|
||||
|
||||
redis::cmd("SETEX")
|
||||
.arg(&key)
|
||||
.arg(ttl_secs)
|
||||
.arg(&value)
|
||||
.query_async::<()>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to update session: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn delete(&self, session_id: &str) -> impl std::future::Future<Output = Result<()>> + Send {
|
||||
let client = self.client.clone();
|
||||
let key = self.session_key(session_id);
|
||||
|
||||
async move {
|
||||
let mut conn = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
redis::cmd("DEL")
|
||||
.arg(&key)
|
||||
.query_async::<()>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to delete session: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_user_sessions(&self, user_id: uuid::Uuid) -> impl std::future::Future<Output = Result<Vec<Session>>> + Send {
|
||||
let client = self.client.clone();
|
||||
let prefix = SESSION_KEY_PREFIX.to_string();
|
||||
|
||||
async move {
|
||||
let mut conn = client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Redis connection error: {}", e))?;
|
||||
|
||||
let pattern = format!("{}*", prefix);
|
||||
let keys: Vec<String> = redis::cmd("KEYS")
|
||||
.arg(&pattern)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to list sessions: {}", e))?;
|
||||
|
||||
let mut sessions = Vec::new();
|
||||
|
||||
for key in keys {
|
||||
let session_id = key.trim_start_matches(&prefix);
|
||||
let store = Self { client: client.clone() };
|
||||
if let Ok(Some(session)) = store.get(session_id).await {
|
||||
if session.user_id == user_id && session.is_valid() {
|
||||
sessions.push(session);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(sessions)
|
||||
}
|
||||
}
|
||||
|
||||
fn delete_user_sessions(&self, user_id: uuid::Uuid) -> impl std::future::Future<Output = Result<usize>> + Send {
|
||||
let client = self.client.clone();
|
||||
|
||||
async move {
|
||||
let sessions = Self { client: client.clone() }.get_user_sessions(user_id).await?;
|
||||
let count = sessions.len();
|
||||
|
||||
for session in sessions {
|
||||
Self { client: client.clone() }.delete(&session.id).await?;
|
||||
}
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
}
|
||||
|
||||
async fn cleanup_expired(&self) -> Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
66
src/security/request_limits.rs
Normal file
66
src/security/request_limits.rs
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
||||
pub const DEFAULT_MAX_REQUEST_SIZE: usize = 10 * 1024 * 1024;
|
||||
|
||||
pub const MAX_UPLOAD_SIZE: usize = 100 * 1024 * 1024;
|
||||
|
||||
pub async fn request_size_middleware(
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let content_length = req
|
||||
.headers()
|
||||
.get("content-length")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<usize>().ok());
|
||||
|
||||
if let Some(len) = content_length {
|
||||
if len > DEFAULT_MAX_REQUEST_SIZE {
|
||||
return (
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
axum::Json(serde_json::json!({
|
||||
"error": "request_too_large",
|
||||
"message": format!("Request body {} bytes exceeds maximum {}", len, DEFAULT_MAX_REQUEST_SIZE),
|
||||
"max_size": DEFAULT_MAX_REQUEST_SIZE
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
pub async fn upload_size_middleware(
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let content_length = req
|
||||
.headers()
|
||||
.get("content-length")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<usize>().ok());
|
||||
|
||||
if let Some(len) = content_length {
|
||||
if len > MAX_UPLOAD_SIZE {
|
||||
return (
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
axum::Json(serde_json::json!({
|
||||
"error": "upload_too_large",
|
||||
"message": format!("Upload {} bytes exceeds maximum {}", len, MAX_UPLOAD_SIZE),
|
||||
"max_size": MAX_UPLOAD_SIZE
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
23
src/security/safe_unwrap.rs
Normal file
23
src/security/safe_unwrap.rs
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
pub fn safe_unwrap_or_default<T: Default>(result: Result<T, impl std::fmt::Display>, context: &str) -> T {
|
||||
result.unwrap_or_else(|e| {
|
||||
tracing::error!("{}: {}", context, e);
|
||||
T::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn safe_unwrap_or<T>(result: Result<T, impl std::fmt::Display>, context: &str, default: T) -> T {
|
||||
result.unwrap_or_else(|e| {
|
||||
tracing::error!("{}: {}", context, e);
|
||||
default
|
||||
})
|
||||
}
|
||||
|
||||
pub fn safe_unwrap_none_or<T>(result: Result<T, impl std::fmt::Display>, context: &str, value: T) -> T {
|
||||
match result {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::error!("{}: {}", context, e);
|
||||
value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -436,6 +436,42 @@ impl<S: SessionStore> SessionManager<S> {
|
|||
Ok(sessions.into_iter().filter(|s| s.is_valid()).collect())
|
||||
}
|
||||
|
||||
pub async fn regenerate_session(&self, old_session_id: &str, ip_address: Option<String>, user_agent: Option<&str>) -> Result<Option<Session>> {
|
||||
let old_session = match self.store.get(old_session_id).await? {
|
||||
Some(s) if s.is_valid() => s,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
let user_id = old_session.user_id;
|
||||
|
||||
let mut new_session = Session::new(user_id, &self.config)
|
||||
.with_remember_me(old_session.remember_me)
|
||||
.with_metadata("regenerated_from".to_string(), old_session.id.clone());
|
||||
|
||||
if let Some(ip) = ip_address {
|
||||
new_session = new_session.with_ip(ip);
|
||||
}
|
||||
|
||||
if self.config.enable_device_tracking {
|
||||
if let Some(ua) = user_agent {
|
||||
new_session = new_session.with_device(DeviceInfo::from_user_agent(ua));
|
||||
}
|
||||
}
|
||||
|
||||
for (key, value) in old_session.metadata {
|
||||
if key != "regenerated_from" {
|
||||
new_session = new_session.with_metadata(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
self.store.delete(old_session_id).await?;
|
||||
self.store.create(new_session.clone()).await?;
|
||||
|
||||
info!("Regenerated session {} -> {} for user {user_id}", old_session_id, new_session.id);
|
||||
|
||||
Ok(Some(new_session))
|
||||
}
|
||||
|
||||
pub async fn invalidate_on_password_change(&self, user_id: Uuid) -> Result<usize> {
|
||||
let count = self.store.delete_user_sessions(user_id).await?;
|
||||
info!("Invalidated {count} sessions for user {user_id} due to password change");
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue