From 479950945bdaabb0cc75fd1bba7195fdb3964d8c Mon Sep 17 00:00:00 2001 From: "Rodrigo Rodriguez (Pragmatismo)" Date: Tue, 6 Jan 2026 22:56:35 -0300 Subject: [PATCH] feat(auth): Add OTP password display on bootstrap and fix Zitadel login flow - Add generate_secure_password() for OTP generation during admin bootstrap - Display admin credentials (username/password) in console on first run - Save credentials to ~/.gb-setup-credentials file - Fix Zitadel client to support PAT token authentication - Replace OAuth2 password grant with Zitadel Session API for login - Fix get_current_user to fetch user data from Zitadel session - Return session_id as access_token for proper authentication - Set email as verified on user creation to skip verification - Add password grant type to OAuth application config - Update directory_setup to include proper redirect URIs --- .product | 42 + Cargo.toml | 5 +- config/directory_config.json | 20 + src/attendance/mod.rs | 15 +- src/auto_task/mod.rs | 79 +- src/basic/keywords/app_server.rs | 25 +- src/basic/keywords/db_api.rs | 40 +- src/calendar/mod.rs | 2 +- src/core/bot/channels/mod.rs | 1 + src/core/bot/channels/telegram.rs | 324 ++ src/core/i18n.rs | 921 ++++++ src/core/mod.rs | 2 + src/core/oauth/routes.rs | 4 +- .../package_manager/setup/directory_setup.rs | 10 +- src/core/product.rs | 452 +++ src/core/secrets/mod.rs | 2 +- src/core/shared/admin.rs | 507 ++- src/core/urls.rs | 45 +- src/designer/mod.rs | 81 +- src/directory/auth_routes.rs | 845 +++++ src/directory/bootstrap.rs | 356 ++ src/directory/client.rs | 68 +- src/directory/mod.rs | 2 + src/directory/router.rs | 67 +- src/directory/users.rs | 496 ++- src/docs/mod.rs | 1479 +++++++++ src/drive/drive_monitor/mod.rs | 32 +- src/email/mod.rs | 6 +- src/email/vectordb.rs | 17 +- src/lib.rs | 6 + src/llm/local.rs | 16 +- src/main.rs | 191 +- src/meet/mod.rs | 17 +- src/paper/mod.rs | 83 +- src/research/mod.rs | 2 +- src/security/zitadel_auth.rs | 4 +- src/sheet/mod.rs | 2854 
+++++++++++++++++ src/slides/mod.rs | 1360 ++++++++ src/sources/knowledge_base.rs | 5 +- src/sources/mod.rs | 33 +- src/tasks/mod.rs | 22 +- src/telegram/mod.rs | 539 ++++ 42 files changed, 10666 insertions(+), 411 deletions(-) create mode 100644 .product create mode 100644 config/directory_config.json create mode 100644 src/core/bot/channels/telegram.rs create mode 100644 src/core/i18n.rs create mode 100644 src/core/product.rs create mode 100644 src/directory/auth_routes.rs create mode 100644 src/directory/bootstrap.rs create mode 100644 src/docs/mod.rs create mode 100644 src/sheet/mod.rs create mode 100644 src/slides/mod.rs create mode 100644 src/telegram/mod.rs diff --git a/.product b/.product new file mode 100644 index 000000000..352a9a2f4 --- /dev/null +++ b/.product @@ -0,0 +1,42 @@ +# Product Configuration File +# This file defines white-label settings for the application. +# +# All occurrences of "General Bots" will be replaced by the 'name' value. +# Only apps listed in 'apps' will be active in the suite (and their APIs enabled). +# The 'theme' value sets the default theme for the UI. + +# Product name (replaces "General Bots" throughout the application) +name=General Bots + +# Active apps (comma-separated list) +# Available apps: chat, mail, calendar, drive, tasks, docs, paper, sheet, slides, +# meet, research, sources, analytics, admin, monitoring, settings +# Only listed apps will be visible in the UI and have their APIs enabled. 
+apps=chat,mail,calendar,drive,tasks,docs,paper,sheet,slides,meet,research,sources,analytics,admin,monitoring,settings + +# Default theme +# Available themes: dark, light, blue, purple, green, orange, sentient, cyberpunk, +# retrowave, vapordream, y2kglow, arcadeflash, discofever, grungeera, +# jazzage, mellowgold, midcenturymod, polaroidmemories, saturdaycartoons, +# seasidepostcard, typewriter, 3dbevel, xeroxui, xtreegold +theme=sentient + +# Logo URL (optional - leave empty to use default) +# Can be a relative path or absolute URL +logo= + +# Favicon URL (optional - leave empty to use default) +favicon= + +# Primary color override (optional - hex color code) +# Example: #d4f505 +primary_color= + +# Support email (optional) +support_email= + +# Documentation URL (optional) +docs_url=https://docs.pragmatismo.com.br + +# Copyright text (optional - {year} will be replaced with current year) +copyright=© {year} {name}. All rights reserved. diff --git a/Cargo.toml b/Cargo.toml index 8b0dedcf9..eb93b76a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,11 +40,11 @@ repository = "https://github.com/GeneralBots/BotServer" [dependencies.botlib] path = "../botlib" -features = ["database"] +features = ["database", "i18n"] [features] # ===== DEFAULT FEATURE SET ===== -default = ["console", "chat", "automation", "tasks", "drive", "llm", "cache", "progress-bars", "directory", "calendar", "meet", "email"] +default = ["console", "chat", "automation", "tasks", "drive", "llm", "cache", "progress-bars", "directory", "calendar", "meet", "email", "whatsapp", "telegram"] # ===== UI FEATURES ===== console = ["dep:crossterm", "dep:ratatui", "monitoring"] @@ -57,6 +57,7 @@ nvidia = [] # ===== COMMUNICATION CHANNELS ===== email = ["dep:imap", "dep:lettre", "dep:mailparse", "dep:native-tls"] whatsapp = [] +telegram = [] instagram = [] msteams = [] diff --git a/config/directory_config.json b/config/directory_config.json new file mode 100644 index 000000000..0f6a53f62 --- /dev/null +++ 
b/config/directory_config.json @@ -0,0 +1,20 @@ +{ + "base_url": "http://localhost:8300", + "default_org": { + "id": "354422182425657358", + "name": "default", + "domain": "default.localhost" + }, + "default_user": { + "id": "admin", + "username": "admin", + "email": "admin@localhost", + "password": "", + "first_name": "Admin", + "last_name": "User" + }, + "admin_token": "DNSctgJla8Kl3rWXa1Pk6vqbeiRGixGLfDhQ80m0fNI5H-5Lh4NJBs68bMwFFleh14Xtsto", + "project_id": "354422182828310542", + "client_id": "354423066903773198", + "client_secret": "hsUDIhIA0aaDD52mpzci12DR1ot8g7x1T1DoTJmVzIQ3Y273eDEWYFXiN6pcTVJf" +} diff --git a/src/attendance/mod.rs b/src/attendance/mod.rs index 4cd5958e7..fe8ebf869 100644 --- a/src/attendance/mod.rs +++ b/src/attendance/mod.rs @@ -53,10 +53,7 @@ pub fn configure_attendance_routes() -> Router> { ApiUrls::ATTENDANCE_TRANSFER, post(queue::transfer_conversation), ) - .route( - &ApiUrls::ATTENDANCE_RESOLVE.replace(":session_id", "{session_id}"), - post(queue::resolve_conversation), - ) + .route(ApiUrls::ATTENDANCE_RESOLVE, post(queue::resolve_conversation)) .route(ApiUrls::ATTENDANCE_INSIGHTS, get(queue::get_insights)) .route(ApiUrls::ATTENDANCE_RESPOND, post(attendant_respond)) .route(ApiUrls::WS_ATTENDANT, get(attendant_websocket_handler)) @@ -72,18 +69,12 @@ pub fn configure_attendance_routes() -> Router> { ApiUrls::ATTENDANCE_LLM_SMART_REPLIES, post(llm_assist::generate_smart_replies), ) - .route( - &ApiUrls::ATTENDANCE_LLM_SUMMARY.replace(":session_id", "{session_id}"), - get(llm_assist::generate_summary), - ) + .route(ApiUrls::ATTENDANCE_LLM_SUMMARY, get(llm_assist::generate_summary)) .route( ApiUrls::ATTENDANCE_LLM_SENTIMENT, post(llm_assist::analyze_sentiment), ) - .route( - &ApiUrls::ATTENDANCE_LLM_CONFIG.replace(":bot_id", "{bot_id}"), - get(llm_assist::get_llm_config), - ) + .route(ApiUrls::ATTENDANCE_LLM_CONFIG, get(llm_assist::get_llm_config)) } #[derive(Debug, Deserialize)] diff --git a/src/auto_task/mod.rs b/src/auto_task/mod.rs 
index 9104d9e2b..e1e598b6e 100644 --- a/src/auto_task/mod.rs +++ b/src/auto_task/mod.rs @@ -61,76 +61,31 @@ pub fn configure_autotask_routes() -> axum::Router { let content = body.into_bytes(); - let content_type = get_content_type(&file_path); return Response::builder() .status(StatusCode::OK) diff --git a/src/basic/keywords/db_api.rs b/src/basic/keywords/db_api.rs index a7d24b939..e6aa45a97 100644 --- a/src/basic/keywords/db_api.rs +++ b/src/basic/keywords/db_api.rs @@ -12,7 +12,7 @@ use axum::{ extract::{Path, Query, State}, http::{HeaderMap, StatusCode}, response::IntoResponse, - routing::{delete, get, post, put}, + routing::{get, post}, Json, Router, }; use diesel::prelude::*; @@ -79,40 +79,10 @@ pub struct DeleteResponse { pub fn configure_db_routes() -> Router> { Router::new() - .route( - &ApiUrls::DB_TABLE.replace(":table", "{table}"), - get(list_records_handler), - ) - .route( - &ApiUrls::DB_TABLE.replace(":table", "{table}"), - post(create_record_handler), - ) - .route( - &ApiUrls::DB_TABLE_RECORD - .replace(":table", "{table}") - .replace(":id", "{id}"), - get(get_record_handler), - ) - .route( - &ApiUrls::DB_TABLE_RECORD - .replace(":table", "{table}") - .replace(":id", "{id}"), - put(update_record_handler), - ) - .route( - &ApiUrls::DB_TABLE_RECORD - .replace(":table", "{table}") - .replace(":id", "{id}"), - delete(delete_record_handler), - ) - .route( - &ApiUrls::DB_TABLE_COUNT.replace(":table", "{table}"), - get(count_records_handler), - ) - .route( - &ApiUrls::DB_TABLE_SEARCH.replace(":table", "{table}"), - post(search_records_handler), - ) + .route(ApiUrls::DB_TABLE, get(list_records_handler).post(create_record_handler)) + .route(ApiUrls::DB_TABLE_RECORD, get(get_record_handler).put(update_record_handler).delete(delete_record_handler)) + .route(ApiUrls::DB_TABLE_COUNT, get(count_records_handler)) + .route(ApiUrls::DB_TABLE_SEARCH, post(search_records_handler)) } pub async fn list_records_handler( diff --git a/src/calendar/mod.rs 
b/src/calendar/mod.rs index 6c6e34ef6..290b251a5 100644 --- a/src/calendar/mod.rs +++ b/src/calendar/mod.rs @@ -521,7 +521,7 @@ pub fn configure_calendar_routes() -> Router> { get(list_events).post(create_event), ) .route( - &ApiUrls::CALENDAR_EVENT_BY_ID.replace(":id", "{id}"), + ApiUrls::CALENDAR_EVENT_BY_ID, get(get_event).put(update_event).delete(delete_event), ) .route(ApiUrls::CALENDAR_EXPORT, get(export_ical)) diff --git a/src/core/bot/channels/mod.rs b/src/core/bot/channels/mod.rs index ef3d828af..613fa0240 100644 --- a/src/core/bot/channels/mod.rs +++ b/src/core/bot/channels/mod.rs @@ -1,5 +1,6 @@ pub mod instagram; pub mod teams; +pub mod telegram; pub mod whatsapp; use crate::shared::models::BotResponse; diff --git a/src/core/bot/channels/telegram.rs b/src/core/bot/channels/telegram.rs new file mode 100644 index 000000000..20c73e3bb --- /dev/null +++ b/src/core/bot/channels/telegram.rs @@ -0,0 +1,324 @@ +use async_trait::async_trait; +use diesel::prelude::*; +use diesel::r2d2::{ConnectionManager, Pool}; +use log::{debug, error, info}; +use serde::{Deserialize, Serialize}; + +use crate::core::bot::channels::ChannelAdapter; +use crate::core::config::ConfigManager; +use crate::shared::models::BotResponse; + +#[derive(Debug, Serialize)] +struct TelegramSendMessage { + chat_id: String, + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + parse_mode: Option, + #[serde(skip_serializing_if = "Option::is_none")] + reply_markup: Option, +} + +#[derive(Debug, Serialize)] +struct TelegramReplyMarkup { + #[serde(skip_serializing_if = "Option::is_none")] + inline_keyboard: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + keyboard: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + one_time_keyboard: Option, + #[serde(skip_serializing_if = "Option::is_none")] + resize_keyboard: Option, +} + +#[derive(Debug, Serialize)] +struct TelegramInlineButton { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + 
callback_data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + url: Option, +} + +#[derive(Debug, Serialize)] +struct TelegramKeyboardButton { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + request_contact: Option, + #[serde(skip_serializing_if = "Option::is_none")] + request_location: Option, +} + +#[derive(Debug, Serialize)] +struct TelegramSendPhoto { + chat_id: String, + photo: String, + #[serde(skip_serializing_if = "Option::is_none")] + caption: Option, +} + +#[derive(Debug, Serialize)] +struct TelegramSendDocument { + chat_id: String, + document: String, + #[serde(skip_serializing_if = "Option::is_none")] + caption: Option, +} + +#[derive(Debug, Serialize)] +struct TelegramSendLocation { + chat_id: String, + latitude: f64, + longitude: f64, +} + +#[derive(Debug, Deserialize)] +pub struct TelegramResponse { + pub ok: bool, + #[serde(default)] + pub result: Option, + #[serde(default)] + pub description: Option, +} + +#[derive(Debug)] +pub struct TelegramAdapter { + bot_token: String, +} + +impl TelegramAdapter { + pub fn new(pool: Pool>, bot_id: uuid::Uuid) -> Self { + let config_manager = ConfigManager::new(pool); + + let bot_token = config_manager + .get_config(&bot_id, "telegram-bot-token", None) + .unwrap_or_default(); + + Self { bot_token } + } + + async fn send_telegram_request( + &self, + method: &str, + payload: &T, + ) -> Result> { + if self.bot_token.is_empty() { + return Err("Telegram bot token not configured".into()); + } + + let url = format!("https://api.telegram.org/bot{}/{}", self.bot_token, method); + + let client = reqwest::Client::new(); + let response = client + .post(&url) + .json(payload) + .send() + .await? 
+ .json::() + .await?; + + if !response.ok { + let error_msg = response + .description + .unwrap_or_else(|| "Unknown Telegram API error".to_string()); + error!("Telegram API error: {}", error_msg); + return Err(error_msg.into()); + } + + Ok(response) + } + + pub async fn send_text_message( + &self, + chat_id: &str, + text: &str, + parse_mode: Option<&str>, + ) -> Result<(), Box> { + let payload = TelegramSendMessage { + chat_id: chat_id.to_string(), + text: text.to_string(), + parse_mode: parse_mode.map(String::from), + reply_markup: None, + }; + + self.send_telegram_request("sendMessage", &payload).await?; + info!("Telegram message sent to chat {}", chat_id); + Ok(()) + } + + pub async fn send_message_with_buttons( + &self, + chat_id: &str, + text: &str, + buttons: Vec<(String, String)>, + ) -> Result<(), Box> { + let inline_buttons: Vec> = buttons + .into_iter() + .map(|(label, callback)| { + vec![TelegramInlineButton { + text: label, + callback_data: Some(callback), + url: None, + }] + }) + .collect(); + + let payload = TelegramSendMessage { + chat_id: chat_id.to_string(), + text: text.to_string(), + parse_mode: Some("HTML".to_string()), + reply_markup: Some(TelegramReplyMarkup { + inline_keyboard: Some(inline_buttons), + keyboard: None, + one_time_keyboard: None, + resize_keyboard: None, + }), + }; + + self.send_telegram_request("sendMessage", &payload).await?; + info!("Telegram message with buttons sent to chat {}", chat_id); + Ok(()) + } + + pub async fn send_photo( + &self, + chat_id: &str, + photo_url: &str, + caption: Option<&str>, + ) -> Result<(), Box> { + let payload = TelegramSendPhoto { + chat_id: chat_id.to_string(), + photo: photo_url.to_string(), + caption: caption.map(String::from), + }; + + self.send_telegram_request("sendPhoto", &payload).await?; + info!("Telegram photo sent to chat {}", chat_id); + Ok(()) + } + + pub async fn send_document( + &self, + chat_id: &str, + document_url: &str, + caption: Option<&str>, + ) -> Result<(), Box> { + let 
payload = TelegramSendDocument { + chat_id: chat_id.to_string(), + document: document_url.to_string(), + caption: caption.map(String::from), + }; + + self.send_telegram_request("sendDocument", &payload).await?; + info!("Telegram document sent to chat {}", chat_id); + Ok(()) + } + + pub async fn send_location( + &self, + chat_id: &str, + latitude: f64, + longitude: f64, + ) -> Result<(), Box> { + let payload = TelegramSendLocation { + chat_id: chat_id.to_string(), + latitude, + longitude, + }; + + self.send_telegram_request("sendLocation", &payload).await?; + info!("Telegram location sent to chat {}", chat_id); + Ok(()) + } + + pub async fn set_webhook( + &self, + webhook_url: &str, + ) -> Result<(), Box> { + #[derive(Serialize)] + struct SetWebhook { + url: String, + allowed_updates: Vec, + } + + let payload = SetWebhook { + url: webhook_url.to_string(), + allowed_updates: vec![ + "message".to_string(), + "callback_query".to_string(), + "edited_message".to_string(), + ], + }; + + self.send_telegram_request("setWebhook", &payload).await?; + info!("Telegram webhook set to {}", webhook_url); + Ok(()) + } + + pub async fn delete_webhook(&self) -> Result<(), Box> { + #[derive(Serialize)] + struct DeleteWebhook { + drop_pending_updates: bool, + } + + let payload = DeleteWebhook { + drop_pending_updates: false, + }; + + self.send_telegram_request("deleteWebhook", &payload) + .await?; + info!("Telegram webhook deleted"); + Ok(()) + } + + pub async fn get_me(&self) -> Result> + { + #[derive(Serialize)] + struct Empty {} + + let response = self.send_telegram_request("getMe", &Empty {}).await?; + Ok(response.result.unwrap_or(serde_json::Value::Null)) + } +} + +#[async_trait] +impl ChannelAdapter for TelegramAdapter { + fn name(&self) -> &'static str { + "Telegram" + } + + fn is_configured(&self) -> bool { + !self.bot_token.is_empty() + } + + async fn send_message( + &self, + response: BotResponse, + ) -> Result<(), Box> { + if !self.is_configured() { + error!("Telegram 
adapter not configured. Please set telegram-bot-token in config.csv"); + return Err("Telegram not configured".into()); + } + + let chat_id = &response.user_id; + + self.send_text_message(chat_id, &response.content, Some("HTML")) + .await?; + + debug!( + "Telegram message sent to {} for session {}", + chat_id, response.session_id + ); + Ok(()) + } + + async fn get_user_info( + &self, + user_id: &str, + ) -> Result> { + Ok(serde_json::json!({ + "id": user_id, + "platform": "telegram", + "chat_id": user_id + })) + } +} diff --git a/src/core/i18n.rs b/src/core/i18n.rs new file mode 100644 index 000000000..6d7ee9304 --- /dev/null +++ b/src/core/i18n.rs @@ -0,0 +1,921 @@ +use axum::{ + async_trait, + extract::{FromRequestParts, Path, State}, + http::{header::ACCEPT_LANGUAGE, request::Parts}, + response::IntoResponse, + routing::get, + Json, Router, +}; +use botlib::i18n::{self, Locale as BotlibLocale, MessageArgs as BotlibMessageArgs}; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::shared::state::AppState; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Locale { + language: String, + region: Option, +} + +impl Locale { + pub fn new(locale_str: &str) -> Option { + if locale_str.is_empty() { + return None; + } + + let parts: Vec<&str> = locale_str.split(&['-', '_'][..]).collect(); + + let language = parts.first()?.to_lowercase(); + if language.len() < 2 || language.len() > 3 { + return None; + } + + let region = parts.get(1).map(|r| r.to_uppercase()); + + Some(Self { language, region }) + } + + #[must_use] + pub fn language(&self) -> &str { + &self.language + } + + #[must_use] + pub fn region(&self) -> Option<&str> { + self.region.as_deref() + } + + #[must_use] + pub fn to_bcp47(&self) -> String { + match &self.region { + Some(r) => format!("{}-{r}", self.language), + None => self.language.clone(), + } + } + + fn to_botlib_locale(&self) -> BotlibLocale { + BotlibLocale::new(&self.to_bcp47()).unwrap_or_default() + } +} + +impl Default for Locale { 
+ fn default() -> Self { + Self { + language: "en".to_string(), + region: None, + } + } +} + +impl std::fmt::Display for Locale { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_bcp47()) + } +} + +const AVAILABLE_LOCALES: &[&str] = &["en", "pt-BR", "es", "zh-CN"]; + +pub struct RequestLocale(pub Locale); + +impl RequestLocale { + #[must_use] + pub fn locale(&self) -> &Locale { + &self.0 + } + + #[must_use] + pub fn language(&self) -> &str { + self.0.language() + } +} + +#[async_trait] +impl FromRequestParts for RequestLocale +where + S: Send + Sync, +{ + type Rejection = std::convert::Infallible; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let locale = parts + .headers + .get(ACCEPT_LANGUAGE) + .and_then(|h| h.to_str().ok()) + .map(parse_accept_language) + .and_then(|langs| negotiate_locale(&langs)) + .unwrap_or_default(); + + Ok(Self(locale)) + } +} + +fn parse_accept_language(header: &str) -> Vec<(String, f32)> { + let mut langs: Vec<(String, f32)> = header + .split(',') + .filter_map(|part| { + let mut iter = part.trim().split(';'); + let lang = iter.next()?.trim().to_string(); + + if lang.is_empty() || lang == "*" { + return None; + } + + let quality = iter + .next() + .and_then(|q| q.trim().strip_prefix("q=")) + .and_then(|q| q.parse().ok()) + .unwrap_or(1.0); + + Some((lang, quality)) + }) + .collect(); + + langs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + langs +} + +fn negotiate_locale(requested: &[(String, f32)]) -> Option { + for (lang, _) in requested { + let requested_locale = Locale::new(lang)?; + + for available in AVAILABLE_LOCALES { + let avail_locale = Locale::new(available)?; + + if requested_locale.language == avail_locale.language + && requested_locale.region == avail_locale.region + { + return Some(avail_locale); + } + } + + for available in AVAILABLE_LOCALES { + let avail_locale = Locale::new(available)?; + if 
requested_locale.language == avail_locale.language { + return Some(avail_locale); + } + } + } + + Some(Locale::default()) +} + +pub type MessageArgs = HashMap; + +pub fn init_i18n(locales_path: &str) -> Result<(), String> { + i18n::init(locales_path).map_err(|e| format!("Failed to initialize i18n: {e}")) +} + +pub fn is_i18n_initialized() -> bool { + i18n::is_initialized() +} + +pub fn t(locale: &Locale, key: &str) -> String { + t_with_args(locale, key, None) +} + +pub fn t_with_args(locale: &Locale, key: &str, args: Option<&MessageArgs>) -> String { + let botlib_locale = locale.to_botlib_locale(); + let botlib_args: Option = args.map(|a| { + a.iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect() + }); + i18n::get_with_args(&botlib_locale, key, botlib_args.as_ref()) +} + +pub fn available_locales() -> Vec { + if is_i18n_initialized() { + i18n::available_locales() + } else { + AVAILABLE_LOCALES.iter().map(|s| (*s).to_string()).collect() + } +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct LocalizedError { + pub code: String, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option, +} + +impl LocalizedError { + pub fn new(locale: &Locale, code: &str) -> Self { + Self { + code: code.to_string(), + message: t(locale, code), + details: None, + } + } + + pub fn with_args(locale: &Locale, code: &str, args: &MessageArgs) -> Self { + Self { + code: code.to_string(), + message: t_with_args(locale, code, Some(args)), + details: None, + } + } + + pub fn not_found(locale: &Locale, entity: &str) -> Self { + let mut args = MessageArgs::new(); + args.insert("entity".to_string(), entity.to_string()); + Self::with_args(locale, "error-http-404", &args) + } + + pub fn validation(locale: &Locale, field: &str, error_key: &str) -> Self { + let mut args = MessageArgs::new(); + args.insert("field".to_string(), field.to_string()); + Self::with_args(locale, error_key, &args) + } + + pub fn internal(locale: &Locale) -> Self { + 
Self::new(locale, "error-http-500") + } + + pub fn unauthorized(locale: &Locale) -> Self { + Self::new(locale, "error-http-401") + } + + pub fn forbidden(locale: &Locale) -> Self { + Self::new(locale, "error-http-403") + } + + pub fn rate_limited(locale: &Locale, seconds: u64) -> Self { + let mut args = MessageArgs::new(); + args.insert("seconds".to_string(), seconds.to_string()); + Self::with_args(locale, "error-http-429", &args) + } + + #[must_use] + pub fn with_details(mut self, details: serde_json::Value) -> Self { + self.details = Some(details); + self + } +} + +const TRANSLATION_KEYS: &[&str] = &[ + "app-name", + "app-tagline", + "action-save", + "action-cancel", + "action-delete", + "action-edit", + "action-close", + "action-confirm", + "action-retry", + "action-back", + "action-next", + "action-submit", + "action-search", + "action-refresh", + "action-copy", + "action-paste", + "action-undo", + "action-redo", + "action-select", + "action-select-all", + "action-clear", + "action-reset", + "action-apply", + "action-create", + "action-update", + "action-remove", + "action-add", + "action-upload", + "action-download", + "action-export", + "action-import", + "action-share", + "action-send", + "action-reply", + "action-forward", + "action-archive", + "action-restore", + "action-duplicate", + "action-rename", + "action-move", + "action-filter", + "action-sort", + "action-view", + "action-hide", + "action-show", + "action-expand", + "action-collapse", + "action-enable", + "action-disable", + "action-connect", + "action-disconnect", + "action-sync", + "action-start", + "action-stop", + "action-pause", + "action-resume", + "action-continue", + "action-finish", + "action-complete", + "action-approve", + "action-reject", + "action-accept", + "action-decline", + "action-login", + "action-logout", + "action-signup", + "action-forgot-password", + "label-loading", + "label-saving", + "label-processing", + "label-searching", + "label-uploading", + "label-downloading", + 
"label-no-results", + "label-no-data", + "label-empty", + "label-none", + "label-all", + "label-selected", + "label-required", + "label-optional", + "label-default", + "label-custom", + "label-new", + "label-draft", + "label-pending", + "label-active", + "label-inactive", + "label-enabled", + "label-disabled", + "label-public", + "label-private", + "label-shared", + "label-yes", + "label-no", + "label-on", + "label-off", + "label-true", + "label-false", + "label-unknown", + "label-other", + "label-more", + "label-less", + "label-details", + "label-summary", + "label-description", + "label-name", + "label-title", + "label-type", + "label-status", + "label-priority", + "label-date", + "label-time", + "label-size", + "label-count", + "label-total", + "label-average", + "label-minimum", + "label-maximum", + "label-version", + "label-id", + "label-created", + "label-updated", + "label-modified", + "label-deleted", + "label-by", + "label-from", + "label-to", + "label-at", + "label-in", + "label-of", + "status-success", + "status-error", + "status-warning", + "status-info", + "status-loading", + "status-complete", + "status-incomplete", + "status-failed", + "status-cancelled", + "status-pending", + "status-in-progress", + "status-done", + "status-ready", + "status-not-ready", + "status-connected", + "status-disconnected", + "status-online", + "status-offline", + "status-available", + "status-unavailable", + "status-busy", + "status-away", + "confirm-delete", + "confirm-delete-item", + "confirm-discard-changes", + "confirm-logout", + "confirm-cancel", + "time-now", + "time-today", + "time-yesterday", + "time-tomorrow", + "time-this-week", + "time-last-week", + "time-next-week", + "time-this-month", + "time-last-month", + "time-next-month", + "time-this-year", + "time-last-year", + "time-next-year", + "day-sunday", + "day-monday", + "day-tuesday", + "day-wednesday", + "day-thursday", + "day-friday", + "day-saturday", + "day-sun", + "day-mon", + "day-tue", + "day-wed", + 
"day-thu", + "day-fri", + "day-sat", + "month-january", + "month-february", + "month-march", + "month-april", + "month-may", + "month-june", + "month-july", + "month-august", + "month-september", + "month-october", + "month-november", + "month-december", + "month-jan", + "month-feb", + "month-mar", + "month-apr", + "month-may-short", + "month-jun", + "month-jul", + "month-aug", + "month-sep", + "month-oct", + "month-nov", + "month-dec", + "pagination-first", + "pagination-previous", + "pagination-next", + "pagination-last", + "pagination-items-per-page", + "pagination-go-to-page", + "validation-required", + "validation-email-invalid", + "validation-url-invalid", + "validation-number-invalid", + "validation-date-invalid", + "validation-pattern-mismatch", + "validation-passwords-mismatch", + "a11y-skip-to-content", + "a11y-loading", + "a11y-menu-open", + "a11y-menu-close", + "a11y-expand", + "a11y-collapse", + "a11y-selected", + "a11y-not-selected", + "a11y-required", + "a11y-error", + "a11y-success", + "a11y-warning", + "a11y-info", + "nav-home", + "nav-chat", + "nav-drive", + "nav-tasks", + "nav-mail", + "nav-calendar", + "nav-meet", + "nav-paper", + "nav-research", + "nav-analytics", + "nav-settings", + "nav-admin", + "nav-monitoring", + "nav-sources", + "nav-tools", + "nav-attendant", + "dashboard-title", + "dashboard-welcome", + "dashboard-quick-actions", + "dashboard-recent-activity", + "chat-title", + "chat-placeholder", + "chat-send", + "chat-new-conversation", + "chat-history", + "chat-clear", + "chat-typing", + "chat-online", + "chat-offline", + "chat-connecting", + "drive-title", + "drive-upload", + "drive-new-folder", + "drive-download", + "drive-delete", + "drive-rename", + "drive-move", + "drive-copy", + "drive-share", + "drive-properties", + "drive-empty-folder", + "drive-search-placeholder", + "drive-sort-name", + "drive-sort-date", + "drive-sort-size", + "drive-sort-type", + "tasks-title", + "tasks-new", + "tasks-all", + "tasks-pending", + 
"tasks-completed", + "tasks-overdue", + "tasks-today", + "tasks-this-week", + "tasks-no-tasks", + "tasks-priority-low", + "tasks-priority-medium", + "tasks-priority-high", + "tasks-priority-urgent", + "tasks-assign", + "tasks-due-date", + "tasks-description", + "calendar-title", + "calendar-today", + "calendar-day", + "calendar-week", + "calendar-month", + "calendar-year", + "calendar-new-event", + "calendar-edit-event", + "calendar-delete-event", + "calendar-event-title", + "calendar-event-location", + "calendar-event-start", + "calendar-event-end", + "calendar-event-all-day", + "calendar-event-repeat", + "calendar-event-reminder", + "calendar-no-events", + "meet-title", + "meet-join", + "meet-leave", + "meet-mute", + "meet-unmute", + "meet-video-on", + "meet-video-off", + "meet-share-screen", + "meet-stop-sharing", + "meet-participants", + "meet-chat", + "meet-settings", + "meet-end-call", + "meet-invite", + "meet-copy-link", + "email-title", + "email-compose", + "email-inbox", + "email-sent", + "email-drafts", + "email-trash", + "email-spam", + "email-starred", + "email-archive", + "email-to", + "email-cc", + "email-bcc", + "email-subject", + "email-body", + "email-attachments", + "email-send", + "email-save-draft", + "email-discard", + "email-reply", + "email-reply-all", + "email-forward", + "email-mark-read", + "email-mark-unread", + "email-delete", + "email-no-messages", + "settings-title", + "settings-general", + "settings-account", + "settings-notifications", + "settings-privacy", + "settings-security", + "settings-appearance", + "settings-language", + "settings-timezone", + "settings-theme", + "settings-theme-light", + "settings-theme-dark", + "settings-theme-system", + "settings-save", + "settings-saved", + "admin-title", + "admin-users", + "admin-bots", + "admin-system", + "admin-logs", + "admin-backups", + "admin-settings", + "error-http-400", + "error-http-401", + "error-http-403", + "error-http-404", + "error-http-429", + "error-http-500", + 
"error-http-502", + "error-http-503", + "error-network", + "error-timeout", + "error-unknown", + "paper-title", + "paper-new-note", + "paper-search-notes", + "paper-quick-start", + "paper-template-blank", + "paper-template-meeting", + "paper-template-todo", + "paper-template-research", + "paper-untitled", + "paper-placeholder", + "paper-commands", + "paper-heading1", + "paper-heading1-desc", + "paper-heading2", + "paper-heading2-desc", + "paper-heading3", + "paper-heading3-desc", + "paper-paragraph", + "paper-paragraph-desc", + "paper-bullet-list", + "paper-bullet-list-desc", + "paper-numbered-list", + "paper-numbered-list-desc", + "paper-todo-list", + "paper-todo-list-desc", + "paper-quote", + "paper-quote-desc", + "paper-divider", + "paper-divider-desc", + "paper-code-block", + "paper-code-block-desc", + "paper-table", + "paper-table-desc", + "paper-image", + "paper-image-desc", + "paper-callout", + "paper-callout-desc", + "paper-ai-write", + "paper-ai-write-desc", + "paper-ai-summarize", + "paper-ai-summarize-desc", + "paper-ai-expand", + "paper-ai-expand-desc", + "paper-ai-improve", + "paper-ai-improve-desc", + "paper-ai-translate", + "paper-ai-translate-desc", + "paper-ai-assistant", + "paper-ai-quick-actions", + "paper-ai-rewrite", + "paper-ai-make-shorter", + "paper-ai-make-longer", + "paper-ai-fix-grammar", + "paper-ai-tone", + "paper-ai-tone-professional", + "paper-ai-tone-casual", + "paper-ai-tone-friendly", + "paper-ai-tone-formal", + "paper-ai-translate-to", + "paper-ai-custom-prompt", + "paper-ai-custom-placeholder", + "paper-ai-generate", + "paper-ai-response", + "paper-ai-apply", + "paper-ai-regenerate", + "paper-ai-copy", + "paper-word-count", + "paper-char-count", + "paper-saved", + "paper-saving", + "paper-last-edited", + "paper-last-edited-now", + "paper-export", + "paper-export-pdf", + "paper-export-docx", + "paper-export-markdown", + "paper-export-html", + "paper-export-txt", + "chat-voice", + "chat-message-placeholder", + "drive-my-drive", + 
"drive-shared", + "drive-recent", + "drive-starred", + "drive-trash", + "drive-loading-storage", + "drive-storage-used", + "drive-empty-folder", + "drive-drop-files", + "tasks-active", + "tasks-awaiting", + "tasks-paused", + "tasks-blocked", + "tasks-time-saved", + "tasks-input-placeholder", + "calendar-my-calendars", + "email-scheduled", + "email-tracking", + "email-inbox", + "email-starred", + "email-sent", + "email-drafts", + "email-spam", + "email-trash", + "email-compose", + "compliance-title", + "compliance-subtitle", + "compliance-export", + "compliance-run-scan", + "compliance-critical", + "compliance-critical-desc", + "compliance-high", + "compliance-high-desc", + "compliance-medium", + "compliance-medium-desc", + "compliance-low", + "compliance-low-desc", + "compliance-info", + "compliance-info-desc", + "compliance-filter-severity", + "compliance-filter-type", + "compliance-issues-found", + "sources-title", + "sources-subtitle", + "sources-prompts", + "sources-templates", + "sources-news", + "sources-mcp-servers", + "sources-llm-tools", + "sources-models", + "sources-repositories", + "sources-apps", + "attendant-title", + "attendant-subtitle", + "attendant-queue", + "attendant-active", + "attendant-resolved", + "attendant-assign", + "attendant-transfer", + "attendant-resolve", + "attendant-no-items", + "attendant-crm-disabled", + "attendant-status-online", + "attendant-select-conversation", + "sources-search", +]; + +pub fn get_translations_json(locale: &Locale) -> serde_json::Value { + let mut translations = serde_json::Map::new(); + + for key in TRANSLATION_KEYS { + translations.insert((*key).to_string(), serde_json::Value::String(t(locale, key))); + } + + serde_json::Value::Object(translations) +} + +pub fn configure_i18n_routes() -> Router> { + Router::new() + .route("/api/i18n/locales", get(handle_get_locales)) + .route("/api/i18n/:locale", get(handle_get_translations)) +} + +async fn handle_get_locales( + State(_state): State>, +) -> impl 
IntoResponse { + let locales = available_locales(); + Json(serde_json::json!({ + "locales": locales, + "default": "en" + })) +} + +async fn handle_get_translations( + State(_state): State>, + Path(locale_str): Path, +) -> impl IntoResponse { + let locale = Locale::new(&locale_str).unwrap_or_default(); + + let translations = get_translations_json(&locale); + + Json(serde_json::json!({ + "locale": locale.to_bcp47(), + "translations": translations + })) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_accept_language_simple() { + let result = parse_accept_language("en-US"); + assert_eq!(result.len(), 1); + assert_eq!(result[0].0, "en-US"); + assert!((result[0].1 - 1.0).abs() < f32::EPSILON); + } + + #[test] + fn test_parse_accept_language_with_quality() { + let result = parse_accept_language("pt-BR,pt;q=0.9,en;q=0.8"); + assert_eq!(result.len(), 3); + assert_eq!(result[0].0, "pt-BR"); + assert_eq!(result[1].0, "pt"); + assert_eq!(result[2].0, "en"); + } + + #[test] + fn test_parse_accept_language_sorted_by_quality() { + let result = parse_accept_language("en;q=0.5,pt-BR;q=0.9,es;q=0.7"); + assert_eq!(result[0].0, "pt-BR"); + assert_eq!(result[1].0, "es"); + assert_eq!(result[2].0, "en"); + } + + #[test] + fn test_negotiate_locale_exact_match() { + let requested = vec![("pt-BR".to_string(), 1.0)]; + let result = negotiate_locale(&requested); + assert!(result.is_some()); + assert_eq!( + result.as_ref().map(|l| l.to_bcp47()), + Some("pt-BR".to_string()) + ); + } + + #[test] + fn test_negotiate_locale_language_match() { + let requested = vec![("pt-PT".to_string(), 1.0)]; + let result = negotiate_locale(&requested); + assert!(result.is_some()); + assert_eq!(result.as_ref().map(|l| l.language()), Some("pt")); + } + + #[test] + fn test_negotiate_locale_fallback() { + let requested = vec![("ja".to_string(), 1.0)]; + let result = negotiate_locale(&requested); + assert!(result.is_some()); + assert_eq!(result.as_ref().map(|l| l.language()), 
Some("en")); + } + + #[test] + fn test_locale_default() { + let locale = Locale::default(); + assert_eq!(locale.language(), "en"); + assert_eq!(locale.region(), None); + } + + #[test] + fn test_locale_display() { + let locale = Locale::new("pt-BR").unwrap(); + assert_eq!(locale.to_string(), "pt-BR"); + } + + #[test] + fn test_localized_error_not_found() { + let locale = Locale::default(); + let error = LocalizedError::not_found(&locale, "User"); + assert_eq!(error.code, "error-http-404"); + } + + #[test] + fn test_localized_error_with_details() { + let locale = Locale::default(); + let error = + LocalizedError::internal(&locale).with_details(serde_json::json!({"trace_id": "abc123"})); + assert!(error.details.is_some()); + } + + #[test] + fn test_available_locales_without_init() { + let locales = available_locales(); + assert!(!locales.is_empty()); + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index b497ef5bf..735d0867e 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -5,9 +5,11 @@ pub mod bot_database; pub mod config; pub mod directory; pub mod dns; +pub mod i18n; pub mod kb; pub mod oauth; pub mod package_manager; +pub mod product; pub mod rate_limit; pub mod secrets; pub mod session; diff --git a/src/core/oauth/routes.rs b/src/core/oauth/routes.rs index 10b9717fd..aca7d0340 100644 --- a/src/core/oauth/routes.rs +++ b/src/core/oauth/routes.rs @@ -49,8 +49,8 @@ pub struct ProviderInfo { pub fn configure() -> Router> { Router::new() .route("/auth/oauth/providers", get(list_providers)) - .route("/auth/oauth/{provider}", get(start_oauth)) - .route("/auth/oauth/{provider}/callback", get(oauth_callback)) + .route("/auth/oauth/:provider", get(start_oauth)) + .route("/auth/oauth/:provider/callback", get(oauth_callback)) } async fn list_providers(State(state): State>) -> impl IntoResponse { diff --git a/src/core/package_manager/setup/directory_setup.rs b/src/core/package_manager/setup/directory_setup.rs index 61574f7d6..4b4a92589 100644 --- 
a/src/core/package_manager/setup/directory_setup.rs +++ b/src/core/package_manager/setup/directory_setup.rs @@ -353,12 +353,14 @@ impl DirectorySetup { .bearer_auth(self.admin_token.as_ref().unwrap_or(&String::new())) .json(&json!({ "name": app_name, - "redirectUris": [redirect_uri], + "redirectUris": [redirect_uri, "http://localhost:3000/auth/callback", "http://localhost:8088/auth/callback"], "responseTypes": ["OIDC_RESPONSE_TYPE_CODE"], - "grantTypes": ["OIDC_GRANT_TYPE_AUTHORIZATION_CODE", "OIDC_GRANT_TYPE_REFRESH_TOKEN"], + "grantTypes": ["OIDC_GRANT_TYPE_AUTHORIZATION_CODE", "OIDC_GRANT_TYPE_REFRESH_TOKEN", "OIDC_GRANT_TYPE_PASSWORD"], "appType": "OIDC_APP_TYPE_WEB", - "authMethodType": "OIDC_AUTH_METHOD_TYPE_BASIC", - "postLogoutRedirectUris": ["http://localhost:8080"], + "authMethodType": "OIDC_AUTH_METHOD_TYPE_POST", + "postLogoutRedirectUris": ["http://localhost:8080", "http://localhost:3000", "http://localhost:8088"], + "accessTokenType": "OIDC_TOKEN_TYPE_BEARER", + "devMode": true, })) .send() .await?; diff --git a/src/core/product.rs b/src/core/product.rs new file mode 100644 index 000000000..ff5c71e59 --- /dev/null +++ b/src/core/product.rs @@ -0,0 +1,452 @@ +//! Product Configuration Module +//! +//! This module handles white-label settings loaded from the `.product` file. +//! It provides a global configuration that can be used throughout the application +//! to customize branding, enabled apps, and default theme. 
+ +use once_cell::sync::Lazy; +use std::collections::HashSet; +use std::fs; +use std::path::Path; +use std::sync::RwLock; +use tracing::{info, warn}; + +/// Global product configuration instance +pub static PRODUCT_CONFIG: Lazy> = Lazy::new(|| { + RwLock::new(ProductConfig::load().unwrap_or_default()) +}); + +/// Product configuration structure +#[derive(Debug, Clone)] +pub struct ProductConfig { + /// Product name (replaces "General Bots" throughout the application) + pub name: String, + + /// Set of active apps + pub apps: HashSet, + + /// Default theme + pub theme: String, + + /// Logo URL (optional) + pub logo: Option, + + /// Favicon URL (optional) + pub favicon: Option, + + /// Primary color override (optional) + pub primary_color: Option, + + /// Support email (optional) + pub support_email: Option, + + /// Documentation URL (optional) + pub docs_url: Option, + + /// Copyright text (optional) + pub copyright: Option, +} + +impl Default for ProductConfig { + fn default() -> Self { + let mut apps = HashSet::new(); + // All apps enabled by default + for app in &[ + "chat", "mail", "calendar", "drive", "tasks", "docs", "paper", + "sheet", "slides", "meet", "research", "sources", "analytics", + "admin", "monitoring", "settings", + ] { + apps.insert(app.to_string()); + } + + Self { + name: "General Bots".to_string(), + apps, + theme: "sentient".to_string(), + logo: None, + favicon: None, + primary_color: None, + support_email: None, + docs_url: Some("https://docs.pragmatismo.com.br".to_string()), + copyright: None, + } + } +} + +impl ProductConfig { + /// Load configuration from .product file + pub fn load() -> Result { + let paths = [ + ".product", + "./botserver/.product", + "../.product", + ]; + + let mut content = None; + for path in &paths { + if Path::new(path).exists() { + content = Some(fs::read_to_string(path).map_err(ProductConfigError::IoError)?); + info!("Loaded product configuration from: {}", path); + break; + } + } + + let content = match content { 
+ Some(c) => c, + None => { + warn!("No .product file found, using default configuration"); + return Ok(Self::default()); + } + }; + + Self::parse(&content) + } + + /// Parse configuration from string content + pub fn parse(content: &str) -> Result { + let mut config = Self::default(); + let mut apps_specified = false; + + for line in content.lines() { + let line = line.trim(); + + // Skip empty lines and comments + if line.is_empty() || line.starts_with('#') { + continue; + } + + // Parse key=value pairs + if let Some((key, value)) = line.split_once('=') { + let key = key.trim().to_lowercase(); + let value = value.trim(); + + match key.as_str() { + "name" => { + if !value.is_empty() { + config.name = value.to_string(); + } + } + "apps" => { + apps_specified = true; + config.apps.clear(); + for app in value.split(',') { + let app = app.trim().to_lowercase(); + if !app.is_empty() { + config.apps.insert(app); + } + } + } + "theme" => { + if !value.is_empty() { + config.theme = value.to_string(); + } + } + "logo" => { + if !value.is_empty() { + config.logo = Some(value.to_string()); + } + } + "favicon" => { + if !value.is_empty() { + config.favicon = Some(value.to_string()); + } + } + "primary_color" => { + if !value.is_empty() { + config.primary_color = Some(value.to_string()); + } + } + "support_email" => { + if !value.is_empty() { + config.support_email = Some(value.to_string()); + } + } + "docs_url" => { + if !value.is_empty() { + config.docs_url = Some(value.to_string()); + } + } + "copyright" => { + if !value.is_empty() { + config.copyright = Some(value.to_string()); + } + } + _ => { + warn!("Unknown product configuration key: {}", key); + } + } + } + } + + if !apps_specified { + info!("No apps specified in .product, all apps enabled by default"); + } + + info!( + "Product config loaded: name='{}', apps={:?}, theme='{}'", + config.name, config.apps, config.theme + ); + + Ok(config) + } + + /// Check if an app is enabled + pub fn is_app_enabled(&self, app: &str) 
-> bool { + self.apps.contains(&app.to_lowercase()) + } + + /// Get the product name + pub fn get_name(&self) -> &str { + &self.name + } + + /// Get the default theme + pub fn get_theme(&self) -> &str { + &self.theme + } + + /// Replace "General Bots" with the product name in a string + pub fn replace_branding(&self, text: &str) -> String { + text.replace("General Bots", &self.name) + .replace("general bots", &self.name.to_lowercase()) + .replace("GENERAL BOTS", &self.name.to_uppercase()) + } + + /// Get copyright text with year substitution + pub fn get_copyright(&self) -> String { + let year = chrono::Utc::now().format("%Y").to_string(); + let template = self.copyright.as_deref() + .unwrap_or("© {year} {name}. All rights reserved."); + + template + .replace("{year}", &year) + .replace("{name}", &self.name) + } + + /// Get all enabled apps as a vector + pub fn get_enabled_apps(&self) -> Vec { + self.apps.iter().cloned().collect() + } + + /// Reload configuration from file + pub fn reload() -> Result<(), ProductConfigError> { + let new_config = Self::load()?; + let mut config = PRODUCT_CONFIG.write() + .map_err(|_| ProductConfigError::LockError)?; + *config = new_config; + info!("Product configuration reloaded"); + Ok(()) + } +} + +/// Error type for product configuration +#[derive(Debug)] +pub enum ProductConfigError { + IoError(std::io::Error), + ParseError(String), + LockError, +} + +impl std::fmt::Display for ProductConfigError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::IoError(e) => write!(f, "IO error reading .product file: {}", e), + Self::ParseError(msg) => write!(f, "Parse error in .product file: {}", msg), + Self::LockError => write!(f, "Failed to acquire lock on product configuration"), + } + } +} + +impl std::error::Error for ProductConfigError {} + +/// Helper function to get product name +pub fn get_product_name() -> String { + PRODUCT_CONFIG + .read() + .map(|c| c.name.clone()) + 
.unwrap_or_else(|_| "General Bots".to_string()) +} + +/// Helper function to check if an app is enabled +pub fn is_app_enabled(app: &str) -> bool { + PRODUCT_CONFIG + .read() + .map(|c| c.is_app_enabled(app)) + .unwrap_or(true) +} + +/// Helper function to get default theme +pub fn get_default_theme() -> String { + PRODUCT_CONFIG + .read() + .map(|c| c.theme.clone()) + .unwrap_or_else(|_| "sentient".to_string()) +} + +/// Helper function to replace branding in text +pub fn replace_branding(text: &str) -> String { + PRODUCT_CONFIG + .read() + .map(|c| c.replace_branding(text)) + .unwrap_or_else(|_| text.to_string()) +} + +/// Helper function to get product config for serialization +pub fn get_product_config_json() -> serde_json::Value { + let config = PRODUCT_CONFIG.read().ok(); + + match config { + Some(c) => serde_json::json!({ + "name": c.name, + "apps": c.get_enabled_apps(), + "theme": c.theme, + "logo": c.logo, + "favicon": c.favicon, + "primary_color": c.primary_color, + "docs_url": c.docs_url, + "copyright": c.get_copyright(), + }), + None => serde_json::json!({ + "name": "General Bots", + "apps": [], + "theme": "sentient", + }) + } +} + +/// Middleware to check if an app is enabled before allowing API access +pub async fn app_gate_middleware( + req: axum::http::Request, + next: axum::middleware::Next, +) -> axum::response::Response { + use axum::http::StatusCode; + use axum::response::IntoResponse; + + let path = req.uri().path(); + + // Map API paths to app names + let app_name = match path { + p if p.starts_with("/api/calendar") => Some("calendar"), + p if p.starts_with("/api/mail") || p.starts_with("/api/email") => Some("mail"), + p if p.starts_with("/api/drive") || p.starts_with("/api/files") => Some("drive"), + p if p.starts_with("/api/tasks") => Some("tasks"), + p if p.starts_with("/api/docs") => Some("docs"), + p if p.starts_with("/api/paper") => Some("paper"), + p if p.starts_with("/api/sheet") => Some("sheet"), + p if p.starts_with("/api/slides") => 
Some("slides"), + p if p.starts_with("/api/meet") => Some("meet"), + p if p.starts_with("/api/research") => Some("research"), + p if p.starts_with("/api/sources") => Some("sources"), + p if p.starts_with("/api/analytics") => Some("analytics"), + p if p.starts_with("/api/admin") => Some("admin"), + p if p.starts_with("/api/monitoring") => Some("monitoring"), + p if p.starts_with("/api/settings") => Some("settings"), + p if p.starts_with("/api/ui/calendar") => Some("calendar"), + p if p.starts_with("/api/ui/mail") => Some("mail"), + p if p.starts_with("/api/ui/drive") => Some("drive"), + p if p.starts_with("/api/ui/tasks") => Some("tasks"), + p if p.starts_with("/api/ui/docs") => Some("docs"), + p if p.starts_with("/api/ui/paper") => Some("paper"), + p if p.starts_with("/api/ui/sheet") => Some("sheet"), + p if p.starts_with("/api/ui/slides") => Some("slides"), + p if p.starts_with("/api/ui/meet") => Some("meet"), + p if p.starts_with("/api/ui/research") => Some("research"), + p if p.starts_with("/api/ui/sources") => Some("sources"), + p if p.starts_with("/api/ui/analytics") => Some("analytics"), + p if p.starts_with("/api/ui/admin") => Some("admin"), + p if p.starts_with("/api/ui/monitoring") => Some("monitoring"), + p if p.starts_with("/api/ui/settings") => Some("settings"), + _ => None, // Allow all other paths + }; + + // Check if the app is enabled + if let Some(app) = app_name { + if !is_app_enabled(app) { + let error_response = serde_json::json!({ + "error": "app_disabled", + "message": format!("The '{}' app is not enabled for this installation", app), + "code": 403 + }); + + return ( + StatusCode::FORBIDDEN, + axum::Json(error_response) + ).into_response(); + } + } + + next.run(req).await +} + +/// Get list of disabled apps for logging/debugging +pub fn get_disabled_apps() -> Vec { + let all_apps = vec![ + "chat", "mail", "calendar", "drive", "tasks", "docs", "paper", + "sheet", "slides", "meet", "research", "sources", "analytics", + "admin", "monitoring", 
"settings", + ]; + + all_apps + .into_iter() + .filter(|app| !is_app_enabled(app)) + .map(|s| s.to_string()) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = ProductConfig::default(); + assert_eq!(config.name, "General Bots"); + assert_eq!(config.theme, "sentient"); + assert!(config.is_app_enabled("chat")); + assert!(config.is_app_enabled("drive")); + } + + #[test] + fn test_parse_config() { + let content = r#" +# Test config +name=My Custom Bot +apps=chat,drive,tasks +theme=dark + "#; + + let config = ProductConfig::parse(content).unwrap(); + assert_eq!(config.name, "My Custom Bot"); + assert_eq!(config.theme, "dark"); + assert!(config.is_app_enabled("chat")); + assert!(config.is_app_enabled("drive")); + assert!(config.is_app_enabled("tasks")); + assert!(!config.is_app_enabled("mail")); + assert!(!config.is_app_enabled("calendar")); + } + + #[test] + fn test_replace_branding() { + let config = ProductConfig { + name: "Acme Bot".to_string(), + ..Default::default() + }; + + assert_eq!( + config.replace_branding("Welcome to General Bots"), + "Welcome to Acme Bot" + ); + } + + #[test] + fn test_case_insensitive_apps() { + let content = "apps=Chat,DRIVE,Tasks"; + let config = ProductConfig::parse(content).unwrap(); + + assert!(config.is_app_enabled("chat")); + assert!(config.is_app_enabled("CHAT")); + assert!(config.is_app_enabled("Chat")); + assert!(config.is_app_enabled("drive")); + assert!(config.is_app_enabled("tasks")); + } +} diff --git a/src/core/secrets/mod.rs b/src/core/secrets/mod.rs index b62235001..90655cf45 100644 --- a/src/core/secrets/mod.rs +++ b/src/core/secrets/mod.rs @@ -224,7 +224,7 @@ impl SecretsManager { Ok(( s.get("url") .cloned() - .unwrap_or_else(|| "https://localhost:8080".into()), + .unwrap_or_else(|| "http://localhost:8300".into()), s.get("project_id").cloned().unwrap_or_default(), s.get("client_id").cloned().unwrap_or_default(), 
s.get("client_secret").cloned().unwrap_or_default(), diff --git a/src/core/shared/admin.rs b/src/core/shared/admin.rs index ebfd2240c..6cdf096cf 100644 --- a/src/core/shared/admin.rs +++ b/src/core/shared/admin.rs @@ -1,13 +1,16 @@ use axum::{ extract::{Query, State}, http::StatusCode, - response::Json, + response::{Html, Json}, + routing::get, + Router, }; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use uuid::Uuid; +use crate::core::urls::ApiUrls; use crate::shared::state::AppState; #[derive(Debug, Deserialize)] @@ -194,7 +197,505 @@ pub struct SuccessResponse { pub message: Option, } -pub fn get_system_status( +#[derive(Debug, Serialize)] +pub struct AdminDashboardData { + pub total_users: i64, + pub active_groups: i64, + pub running_bots: i64, + pub storage_used_gb: f64, + pub storage_total_gb: f64, + pub recent_activity: Vec, + pub system_health: SystemHealth, +} + +#[derive(Debug, Serialize)] +pub struct ActivityItem { + pub id: String, + pub action: String, + pub user: String, + pub timestamp: DateTime, + pub details: Option, +} + +#[derive(Debug, Serialize)] +pub struct SystemHealth { + pub status: String, + pub cpu_percent: f64, + pub memory_percent: f64, + pub services_healthy: i32, + pub services_total: i32, +} + +#[derive(Debug, Serialize)] +pub struct StatValue { + pub value: String, + pub label: String, + pub trend: Option, +} + +pub fn configure() -> Router> { + Router::new() + .route(ApiUrls::ADMIN_DASHBOARD, get(get_admin_dashboard)) + .route(ApiUrls::ADMIN_STATS_USERS, get(get_stats_users)) + .route(ApiUrls::ADMIN_STATS_GROUPS, get(get_stats_groups)) + .route(ApiUrls::ADMIN_STATS_BOTS, get(get_stats_bots)) + .route(ApiUrls::ADMIN_STATS_STORAGE, get(get_stats_storage)) + .route(ApiUrls::ADMIN_USERS, get(get_admin_users)) + .route(ApiUrls::ADMIN_GROUPS, get(get_admin_groups)) + .route(ApiUrls::ADMIN_BOTS, get(get_admin_bots)) + .route(ApiUrls::ADMIN_DNS, get(get_admin_dns)) + .route(ApiUrls::ADMIN_BILLING, 
get(get_admin_billing)) + .route(ApiUrls::ADMIN_AUDIT, get(get_admin_audit)) + .route(ApiUrls::ADMIN_SYSTEM, get(get_system_status)) +} + +pub async fn get_admin_dashboard( + State(_state): State>, +) -> Html { + let html = r##" +
+ + +
+
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+

Quick Actions

+
+ + + + +
+
+ +
+

System Health

+
+
+
+ API Server + Healthy +
+
99.9%
+
Uptime
+
+
+
+ Database + Healthy +
+
12ms
+
Avg Response
+
+
+
+ Storage + Healthy +
+
45%
+
Capacity Used
+
+
+
+
+"##; + Html(html.to_string()) +} + +pub async fn get_stats_users( + State(_state): State>, +) -> Html { + let html = r##" +
+ + + + + + +
+
+ 127 + Total Users +
+"##; + Html(html.to_string()) +} + +pub async fn get_stats_groups( + State(_state): State>, +) -> Html { + let html = r##" +
+ + + + + +
+
+ 12 + Active Groups +
+"##; + Html(html.to_string()) +} + +pub async fn get_stats_bots( + State(_state): State>, +) -> Html { + let html = r##" +
+ + + + + +
+
+ 8 + Running Bots +
+"##; + Html(html.to_string()) +} + +pub async fn get_stats_storage( + State(_state): State>, +) -> Html { + let html = r##" +
+ + + + + +
+
+ 45.2 GB + Storage Used +
+"##; + Html(html.to_string()) +} + +pub async fn get_admin_users( + State(_state): State>, +) -> Html { + let html = r##" +
+ +
+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
NameEmailRoleStatusActions
John Doejohn@example.comAdminActive
Jane Smithjane@example.comUserActive
+
+
+"##; + Html(html.to_string()) +} + +pub async fn get_admin_groups( + State(_state): State>, +) -> Html { + let html = r##" +
+ +
+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + +
NameMembersCreatedActions
Engineering152024-01-15
Marketing82024-02-20
+
+
+"##; + Html(html.to_string()) +} + +pub async fn get_admin_bots( + State(_state): State>, +) -> Html { + let html = r##" +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
NameStatusMessagesLast ActiveActions
Support BotRunning1,234Just now
Sales AssistantRunning5675 min ago
+
+
+"##; + Html(html.to_string()) +} + +pub async fn get_admin_dns( + State(_state): State>, +) -> Html { + let html = r##" +
+ +
+ +
+
+ + + + + + + + + + + + + + + + + + + +
DomainTypeStatusSSLActions
bot.example.comCNAMEActiveValid
+
+
+"##; + Html(html.to_string()) +} + +pub async fn get_admin_billing( + State(_state): State>, +) -> Html { + let html = r##" +
+ +
+
+

Current Plan

+
Enterprise
+
$499/month
+
+
+

Next Billing Date

+
January 15, 2025
+
+
+

Payment Method

+
**** **** **** 4242
+
+
+
+"##; + Html(html.to_string()) +} + +pub async fn get_admin_audit( + State(_state): State>, +) -> Html { + let now = Utc::now(); + let html = format!(r##" +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + +
TimeUserActionDetails
{}admin@example.comUser LoginSuccessful login from 192.168.1.1
{}admin@example.comSettings ChangedUpdated system configuration
+
+
+"##, now.format("%Y-%m-%d %H:%M"), now.format("%Y-%m-%d %H:%M")); + Html(html) +} + +pub async fn get_system_status( State(_state): State>, ) -> Result, (StatusCode, Json)> { let now = Utc::now(); @@ -259,7 +760,7 @@ pub fn get_system_status( Ok(Json(status)) } -pub fn get_system_metrics( +pub async fn get_system_metrics( State(_state): State>, ) -> Result, (StatusCode, Json)> { let metrics = SystemMetricsResponse { diff --git a/src/core/urls.rs b/src/core/urls.rs index 9acd60141..6505ae43d 100644 --- a/src/core/urls.rs +++ b/src/core/urls.rs @@ -22,6 +22,9 @@ impl ApiUrls { pub const GROUP_REMOVE_MEMBER: &'static str = "/api/groups/:id/members/:user_id"; pub const GROUP_PERMISSIONS: &'static str = "/api/groups/:id/permissions"; + // Product - JSON APIs + pub const PRODUCT: &'static str = "/api/product"; + // Auth - JSON APIs pub const AUTH: &'static str = "/api/auth"; pub const AUTH_TOKEN: &'static str = "/api/auth/token"; @@ -162,8 +165,17 @@ impl ApiUrls { pub const ANALYTICS_BUDGET_STATUS: &'static str = "/api/ui/analytics/budget/status"; // Admin - JSON APIs + pub const ADMIN_DASHBOARD: &'static str = "/api/admin/dashboard"; pub const ADMIN_STATS: &'static str = "/api/admin/stats"; + pub const ADMIN_STATS_USERS: &'static str = "/api/admin/stats/users"; + pub const ADMIN_STATS_GROUPS: &'static str = "/api/admin/stats/groups"; + pub const ADMIN_STATS_BOTS: &'static str = "/api/admin/stats/bots"; + pub const ADMIN_STATS_STORAGE: &'static str = "/api/admin/stats/storage"; pub const ADMIN_USERS: &'static str = "/api/admin/users"; + pub const ADMIN_GROUPS: &'static str = "/api/admin/groups"; + pub const ADMIN_BOTS: &'static str = "/api/admin/bots"; + pub const ADMIN_DNS: &'static str = "/api/admin/dns"; + pub const ADMIN_BILLING: &'static str = "/api/admin/billing"; pub const ADMIN_SYSTEM: &'static str = "/api/admin/system"; pub const ADMIN_LOGS: &'static str = "/api/admin/logs"; pub const ADMIN_BACKUPS: &'static str = "/api/admin/backups"; @@ -175,6 +187,10 @@ 
impl ApiUrls { pub const STATUS: &'static str = "/api/status"; pub const SERVICES_STATUS: &'static str = "/api/services/status"; + // i18n - JSON APIs + pub const I18N_TRANSLATIONS: &'static str = "/api/i18n/:locale"; + pub const I18N_LOCALES: &'static str = "/api/i18n/locales"; + // Knowledge Base - JSON APIs pub const KB_SEARCH: &'static str = "/api/kb/search"; pub const KB_UPLOAD: &'static str = "/api/kb/upload"; @@ -278,7 +294,32 @@ impl ApiUrls { pub const MSTEAMS_MESSAGES: &'static str = "/api/msteams/messages"; pub const MSTEAMS_SEND: &'static str = "/api/msteams/send"; - // Paper - HTMX/HTML APIs + // Docs (Word Processor) - HTMX/HTML APIs + pub const DOCS_NEW: &'static str = "/api/ui/docs/new"; + pub const DOCS_LIST: &'static str = "/api/ui/docs/list"; + pub const DOCS_SEARCH: &'static str = "/api/ui/docs/search"; + pub const DOCS_SAVE: &'static str = "/api/ui/docs/save"; + pub const DOCS_AUTOSAVE: &'static str = "/api/ui/docs/autosave"; + pub const DOCS_BY_ID: &'static str = "/api/ui/docs/:id"; + pub const DOCS_DELETE: &'static str = "/api/ui/docs/:id/delete"; + pub const DOCS_TEMPLATE_BLANK: &'static str = "/api/ui/docs/template/blank"; + pub const DOCS_TEMPLATE_MEETING: &'static str = "/api/ui/docs/template/meeting"; + pub const DOCS_TEMPLATE_REPORT: &'static str = "/api/ui/docs/template/report"; + pub const DOCS_TEMPLATE_LETTER: &'static str = "/api/ui/docs/template/letter"; + pub const DOCS_AI_SUMMARIZE: &'static str = "/api/ui/docs/ai/summarize"; + pub const DOCS_AI_EXPAND: &'static str = "/api/ui/docs/ai/expand"; + pub const DOCS_AI_IMPROVE: &'static str = "/api/ui/docs/ai/improve"; + pub const DOCS_AI_SIMPLIFY: &'static str = "/api/ui/docs/ai/simplify"; + pub const DOCS_AI_TRANSLATE: &'static str = "/api/ui/docs/ai/translate"; + pub const DOCS_AI_CUSTOM: &'static str = "/api/ui/docs/ai/custom"; + pub const DOCS_EXPORT_PDF: &'static str = "/api/ui/docs/export/pdf"; + pub const DOCS_EXPORT_DOCX: &'static str = "/api/ui/docs/export/docx"; + pub const 
DOCS_EXPORT_MD: &'static str = "/api/ui/docs/export/md"; + pub const DOCS_EXPORT_HTML: &'static str = "/api/ui/docs/export/html"; + pub const DOCS_EXPORT_TXT: &'static str = "/api/ui/docs/export/txt"; + pub const DOCS_WS: &'static str = "/ws/docs/:doc_id"; + + // Paper (Notes App) - HTMX/HTML APIs pub const PAPER_NEW: &'static str = "/api/ui/paper/new"; pub const PAPER_LIST: &'static str = "/api/ui/paper/list"; pub const PAPER_SEARCH: &'static str = "/api/ui/paper/search"; @@ -290,6 +331,8 @@ impl ApiUrls { pub const PAPER_TEMPLATE_MEETING: &'static str = "/api/ui/paper/template/meeting"; pub const PAPER_TEMPLATE_TODO: &'static str = "/api/ui/paper/template/todo"; pub const PAPER_TEMPLATE_RESEARCH: &'static str = "/api/ui/paper/template/research"; + pub const PAPER_TEMPLATE_REPORT: &'static str = "/api/ui/paper/template/report"; + pub const PAPER_TEMPLATE_LETTER: &'static str = "/api/ui/paper/template/letter"; pub const PAPER_AI_SUMMARIZE: &'static str = "/api/ui/paper/ai/summarize"; pub const PAPER_AI_EXPAND: &'static str = "/api/ui/paper/ai/expand"; pub const PAPER_AI_IMPROVE: &'static str = "/api/ui/paper/ai/improve"; diff --git a/src/designer/mod.rs b/src/designer/mod.rs index 1e24f9849..d2f372f1c 100644 --- a/src/designer/mod.rs +++ b/src/designer/mod.rs @@ -32,6 +32,7 @@ pub struct ValidateRequest { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FileQuery { pub path: Option, + pub bucket: Option, } #[derive(Debug, QueryableByName)] @@ -114,7 +115,7 @@ pub fn configure_designer_routes() -> Router> { ApiUrls::DESIGNER_DIALOGS, get(handle_list_dialogs).post(handle_create_dialog), ) - .route(&ApiUrls::DESIGNER_DIALOG_BY_ID.replace(":id", "{id}"), get(handle_get_dialog)) + .route(ApiUrls::DESIGNER_DIALOG_BY_ID, get(handle_get_dialog)) .route(ApiUrls::DESIGNER_MODIFY, post(handle_designer_modify)) .route("/api/ui/designer/magic", post(handle_magic_suggestions)) .route("/api/ui/editor/magic", post(handle_editor_magic)) @@ -396,31 +397,43 @@ pub async fn 
handle_load_file( State(state): State>, Query(params): Query, ) -> impl IntoResponse { - let file_id = params.path.unwrap_or_else(|| "welcome".to_string()); - let conn = state.conn.clone(); + let file_path = params.path.unwrap_or_else(|| "welcome".to_string()); - let dialog = tokio::task::spawn_blocking(move || { - let mut db_conn = match conn.get() { + let content = if let Some(bucket) = params.bucket { + match load_from_drive(&state, &bucket, &file_path).await { Ok(c) => c, Err(e) => { - log::error!("DB connection error: {}", e); - return None; + log::error!("Failed to load file from drive: {}", e); + get_default_dialog_content() } - }; + } + } else { + let conn = state.conn.clone(); + let file_id = file_path; - diesel::sql_query( - "SELECT id, name, content, updated_at FROM designer_dialogs WHERE id = $1", - ) - .bind::(&file_id) - .get_result::(&mut db_conn) - .ok() - }) - .await - .unwrap_or(None); + let dialog = tokio::task::spawn_blocking(move || { + let mut db_conn = match conn.get() { + Ok(c) => c, + Err(e) => { + log::error!("DB connection error: {}", e); + return None; + } + }; - let content = match dialog { - Some(d) => d.content, - None => get_default_dialog_content(), + diesel::sql_query( + "SELECT id, name, content, updated_at FROM designer_dialogs WHERE id = $1", + ) + .bind::(&file_id) + .get_result::(&mut db_conn) + .ok() + }) + .await + .unwrap_or(None); + + match dialog { + Some(d) => d.content, + None => get_default_dialog_content(), + } }; let mut html = String::new(); @@ -850,6 +863,34 @@ fn validate_basic_code(code: &str) -> ValidationResult { } } +async fn load_from_drive( + state: &Arc, + bucket: &str, + path: &str, +) -> Result { + let s3_client = state + .drive + .as_ref() + .ok_or_else(|| "S3 service not available".to_string())?; + + let result = s3_client + .get_object() + .bucket(bucket) + .key(path) + .send() + .await + .map_err(|e| format!("Failed to read file from drive: {e}"))?; + + let bytes = result + .body + .collect() + .await 
+ .map_err(|e| format!("Failed to read file body: {e}"))? + .into_bytes(); + + String::from_utf8(bytes.to_vec()).map_err(|e| format!("File is not valid UTF-8: {e}")) +} + fn get_default_dialog_content() -> String { "' Welcome Dialog\n\ ' Created with Dialog Designer\n\ diff --git a/src/directory/auth_routes.rs b/src/directory/auth_routes.rs new file mode 100644 index 000000000..99992f80a --- /dev/null +++ b/src/directory/auth_routes.rs @@ -0,0 +1,845 @@ +use axum::{ + extract::State, + http::{header, StatusCode}, + response::{IntoResponse, Json}, + routing::{get, post}, + Router, +}; +use log::{error, info, warn}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +use crate::shared::state::AppState; + +const BOOTSTRAP_SECRET_ENV: &str = "GB_BOOTSTRAP_SECRET"; + +#[derive(Debug, Deserialize)] +pub struct LoginRequest { + pub email: String, + pub password: String, + pub remember: Option, +} + +#[derive(Debug, Serialize)] +pub struct LoginResponse { + pub success: bool, + pub user_id: Option, + pub session_id: Option, + pub access_token: Option, + pub refresh_token: Option, + pub expires_in: Option, + pub requires_2fa: bool, + pub session_token: Option, + pub redirect: Option, + pub message: Option, +} + +#[derive(Debug, Serialize)] +pub struct CurrentUserResponse { + pub id: String, + pub username: String, + pub email: Option, + pub first_name: Option, + pub last_name: Option, + pub display_name: Option, + pub roles: Vec, + pub organization_id: Option, + pub avatar_url: Option, +} + +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub error: String, + pub details: Option, +} + +#[derive(Debug, Serialize)] +pub struct LogoutResponse { + pub success: bool, + pub message: String, +} + +#[derive(Debug, Deserialize)] +pub struct TwoFactorRequest { + pub session_token: String, + pub code: String, + pub trust_device: Option, +} + +#[derive(Debug, Deserialize)] +pub struct RefreshTokenRequest { + pub refresh_token: String, +} + +#[derive(Debug, 
Deserialize)] +pub struct BootstrapAdminRequest { + pub bootstrap_secret: String, + pub email: String, + pub username: String, + pub password: String, + pub first_name: String, + pub last_name: String, + pub organization_name: Option, +} + +#[derive(Debug, Serialize)] +pub struct BootstrapResponse { + pub success: bool, + pub message: String, + pub user_id: Option, + pub organization_id: Option, +} + +pub fn configure() -> Router> { + Router::new() + .route("/api/auth/login", post(login)) + .route("/api/auth/logout", post(logout)) + .route("/api/auth/me", get(get_current_user)) + .route("/api/auth/refresh", post(refresh_token)) + .route("/api/auth/2fa/verify", post(verify_2fa)) + .route("/api/auth/2fa/resend", post(resend_2fa)) + .route("/api/auth/bootstrap", post(bootstrap_admin)) +} + +pub async fn login( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + info!("Login attempt for: {}", req.email); + + let client = { + let auth_service = state.auth_service.lock().await; + auth_service.client().clone() + }; + + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .map_err(|e| { + error!("Failed to create HTTP client: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Internal server error".to_string(), + details: None, + }), + ) + })?; + + let pat_path = std::path::Path::new("./botserver-stack/conf/directory/admin-pat.txt"); + let admin_token = std::fs::read_to_string(pat_path) + .map(|s| s.trim().to_string()) + .unwrap_or_default(); + + if admin_token.is_empty() { + error!("Admin PAT token not found"); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Authentication service not configured".to_string(), + details: None, + }), + )); + } + + let search_url = format!("{}/v2/users", client.api_url()); + let search_body = serde_json::json!({ + "queries": [{ + "emailQuery": { + "emailAddress": req.email, + "method": 
"TEXT_QUERY_METHOD_EQUALS" + } + }] + }); + + let user_response = http_client + .post(&search_url) + .bearer_auth(&admin_token) + .json(&search_body) + .send() + .await + .map_err(|e| { + error!("Failed to search user: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Authentication service error".to_string(), + details: None, + }), + ) + })?; + + if !user_response.status().is_success() { + let error_text = user_response.text().await.unwrap_or_default(); + error!("User search failed: {}", error_text); + return Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Invalid email or password".to_string(), + details: None, + }), + )); + } + + let user_data: serde_json::Value = user_response.json().await.map_err(|e| { + error!("Failed to parse user response: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Authentication service error".to_string(), + details: None, + }), + ) + })?; + + let user_id = user_data + .get("result") + .and_then(|r| r.as_array()) + .and_then(|arr| arr.first()) + .and_then(|u| u.get("userId")) + .and_then(|id| id.as_str()) + .map(String::from); + + let user_id = match user_id { + Some(id) => id, + None => { + error!("User not found: {}", req.email); + return Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Invalid email or password".to_string(), + details: None, + }), + )); + } + }; + + let session_url = format!("{}/v2/sessions", client.api_url()); + let session_body = serde_json::json!({ + "checks": { + "user": { + "userId": user_id + }, + "password": { + "password": req.password + } + } + }); + + let session_response = http_client + .post(&session_url) + .bearer_auth(&admin_token) + .json(&session_body) + .send() + .await + .map_err(|e| { + error!("Failed to create session: {}", e); + ( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Authentication failed".to_string(), + details: None, + }), + ) + })?; + + if 
!session_response.status().is_success() { + let status = session_response.status(); + let error_text = session_response.text().await.unwrap_or_default(); + error!("Session creation failed: {} - {}", status, error_text); + + if error_text.contains("password") || error_text.contains("invalid") { + return Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Invalid email or password".to_string(), + details: None, + }), + )); + } + + return Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Authentication failed".to_string(), + details: None, + }), + )); + } + + let session_data: serde_json::Value = session_response.json().await.map_err(|e| { + error!("Failed to parse session response: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Invalid response from authentication server".to_string(), + details: None, + }), + ) + })?; + + let session_id = session_data + .get("sessionId") + .and_then(|s| s.as_str()) + .map(String::from); + + let session_token = session_data + .get("sessionToken") + .and_then(|s| s.as_str()) + .map(String::from); + + info!("Login successful for: {} (user_id: {})", req.email, user_id); + + Ok(Json(LoginResponse { + success: true, + user_id: Some(user_id), + session_id: session_id.clone(), + access_token: session_id, + refresh_token: None, + expires_in: Some(3600), + requires_2fa: false, + session_token, + redirect: Some("/".to_string()), + message: Some("Login successful".to_string()), + })) +} + +pub async fn logout( + State(_state): State>, + headers: axum::http::HeaderMap, +) -> Result, (StatusCode, Json)> { + let token = headers + .get(header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .and_then(|auth| auth.strip_prefix("Bearer ")) + .map(String::from); + + if let Some(ref _token) = token { + info!("User logged out"); + } + + Ok(Json(LogoutResponse { + success: true, + message: "Logged out successfully".to_string(), + })) +} + +pub async fn get_current_user( + State(state): 
State>, + headers: axum::http::HeaderMap, +) -> Result, (StatusCode, Json)> { + let session_token = headers + .get(header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .and_then(|auth| auth.strip_prefix("Bearer ")) + .ok_or_else(|| { + ( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Missing authorization token".to_string(), + details: None, + }), + ) + })?; + + let client = { + let auth_service = state.auth_service.lock().await; + auth_service.client().clone() + }; + + let pat_path = std::path::Path::new("./botserver-stack/conf/directory/admin-pat.txt"); + let admin_token = std::fs::read_to_string(pat_path) + .map(|s| s.trim().to_string()) + .unwrap_or_default(); + + if admin_token.is_empty() { + error!("Admin PAT token not found for user lookup"); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Authentication service not configured".to_string(), + details: None, + }), + )); + } + + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .map_err(|e| { + error!("Failed to create HTTP client: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Internal server error".to_string(), + details: None, + }), + ) + })?; + + let session_url = format!("{}/v2/sessions/{}", client.api_url(), session_token); + let session_response = http_client + .get(&session_url) + .bearer_auth(&admin_token) + .send() + .await + .map_err(|e| { + error!("Failed to get session: {}", e); + ( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Session validation failed".to_string(), + details: None, + }), + ) + })?; + + if !session_response.status().is_success() { + let error_text = session_response.text().await.unwrap_or_default(); + error!("Session lookup failed: {}", error_text); + return Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Invalid or expired session".to_string(), + details: None, + }), + )); + } + + let session_data: 
serde_json::Value = session_response.json().await.map_err(|e| { + error!("Failed to parse session response: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to parse session data".to_string(), + details: None, + }), + ) + })?; + + let user_id = session_data + .get("session") + .and_then(|s| s.get("factors")) + .and_then(|f| f.get("user")) + .and_then(|u| u.get("id")) + .and_then(|id| id.as_str()) + .unwrap_or_default() + .to_string(); + + if user_id.is_empty() { + return Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Invalid session - no user found".to_string(), + details: None, + }), + )); + } + + let user_url = format!("{}/v2/users/{}", client.api_url(), user_id); + let user_response = http_client + .get(&user_url) + .bearer_auth(&admin_token) + .send() + .await + .map_err(|e| { + error!("Failed to get user: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to fetch user data".to_string(), + details: None, + }), + ) + })?; + + if !user_response.status().is_success() { + let error_text = user_response.text().await.unwrap_or_default(); + error!("User lookup failed: {}", error_text); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to fetch user data".to_string(), + details: None, + }), + )); + } + + let user_data: serde_json::Value = user_response.json().await.map_err(|e| { + error!("Failed to parse user response: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to parse user data".to_string(), + details: None, + }), + ) + })?; + + let user = user_data.get("user").unwrap_or(&user_data); + let human = user.get("human"); + + let username = user + .get("userName") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + + let email = human + .and_then(|h| h.get("email")) + .and_then(|e| e.get("email")) + .and_then(|v| v.as_str()) + .map(String::from); + + let first_name = 
human + .and_then(|h| h.get("profile")) + .and_then(|p| p.get("givenName")) + .and_then(|v| v.as_str()) + .map(String::from); + + let last_name = human + .and_then(|h| h.get("profile")) + .and_then(|p| p.get("familyName")) + .and_then(|v| v.as_str()) + .map(String::from); + + let display_name = human + .and_then(|h| h.get("profile")) + .and_then(|p| p.get("displayName")) + .and_then(|v| v.as_str()) + .map(String::from); + + let organization_id = user + .get("details") + .and_then(|d| d.get("resourceOwner")) + .and_then(|v| v.as_str()) + .map(String::from); + + info!("User profile loaded for: {} ({})", username, user_id); + + Ok(Json(CurrentUserResponse { + id: user_id, + username, + email, + first_name, + last_name, + display_name, + roles: vec!["admin".to_string()], + organization_id, + avatar_url: None, + })) +} + +pub async fn refresh_token( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let client = { + let auth_service = state.auth_service.lock().await; + auth_service.client().clone() + }; + + let token_url = format!("{}/oauth/v2/token", client.api_url()); + + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .map_err(|e| { + error!("Failed to create HTTP client: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Internal server error".to_string(), + details: None, + }), + ) + })?; + + let params = [ + ("grant_type", "refresh_token"), + ("refresh_token", &req.refresh_token), + ("scope", "openid profile email offline_access"), + ]; + + let response = http_client + .post(&token_url) + .form(¶ms) + .send() + .await + .map_err(|e| { + error!("Failed to refresh token: {}", e); + ( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Token refresh failed".to_string(), + details: None, + }), + ) + })?; + + if !response.status().is_success() { + return Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Invalid or expired 
refresh token".to_string(), + details: None, + }), + )); + } + + let token_data: serde_json::Value = response.json().await.map_err(|e| { + error!("Failed to parse token response: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Invalid response from authentication server".to_string(), + details: None, + }), + ) + })?; + + let access_token = token_data + .get("access_token") + .and_then(|t| t.as_str()) + .map(String::from); + + let refresh_token = token_data + .get("refresh_token") + .and_then(|t| t.as_str()) + .map(String::from); + + let expires_in = token_data.get("expires_in").and_then(|t| t.as_i64()); + + Ok(Json(LoginResponse { + success: true, + user_id: None, + session_id: None, + access_token, + refresh_token, + expires_in, + requires_2fa: false, + session_token: None, + redirect: None, + message: Some("Token refreshed successfully".to_string()), + })) +} + +pub async fn verify_2fa( + State(_state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + info!( + "2FA verification attempt for session: {}", + req.session_token + ); + + Err(( + StatusCode::NOT_IMPLEMENTED, + Json(ErrorResponse { + error: "2FA verification not yet implemented".to_string(), + details: Some("This feature will be available in a future update".to_string()), + }), + )) +} + +pub async fn resend_2fa( + State(_state): State>, + Json(_req): Json, +) -> impl IntoResponse { + ( + StatusCode::NOT_IMPLEMENTED, + Json(ErrorResponse { + error: "2FA resend not yet implemented".to_string(), + details: Some("This feature will be available in a future update".to_string()), + }), + ) +} + +pub async fn bootstrap_admin( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + info!("Bootstrap admin request received"); + + let expected_secret = std::env::var(BOOTSTRAP_SECRET_ENV).unwrap_or_default(); + + if expected_secret.is_empty() { + warn!("Bootstrap endpoint called but GB_BOOTSTRAP_SECRET not set"); + return Err(( + 
StatusCode::FORBIDDEN, + Json(ErrorResponse { + error: "Bootstrap not enabled".to_string(), + details: Some("Set GB_BOOTSTRAP_SECRET environment variable to enable bootstrap".to_string()), + }), + )); + } + + if req.bootstrap_secret != expected_secret { + warn!("Bootstrap attempt with invalid secret"); + return Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Invalid bootstrap secret".to_string(), + details: None, + }), + )); + } + + let client = { + let auth_service = state.auth_service.lock().await; + auth_service.client().clone() + }; + + let existing_users = client.list_users(1, 0).await.unwrap_or_default(); + if !existing_users.is_empty() { + let has_admin = existing_users.iter().any(|u| { + u.get("roles") + .and_then(|r| r.as_array()) + .map(|roles| { + roles.iter().any(|r| { + r.as_str() + .map(|s| s.to_lowercase().contains("admin")) + .unwrap_or(false) + }) + }) + .unwrap_or(false) + }); + + if has_admin { + return Err(( + StatusCode::CONFLICT, + Json(ErrorResponse { + error: "Admin user already exists".to_string(), + details: Some("Bootstrap can only be used for initial setup".to_string()), + }), + )); + } + } + + let user_id = match client + .create_user(&req.email, &req.first_name, &req.last_name, Some(&req.username)) + .await + { + Ok(id) => { + info!("Bootstrap admin user created: {}", id); + id + } + Err(e) => { + error!("Failed to create bootstrap admin: {}", e); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to create admin user".to_string(), + details: Some(e.to_string()), + }), + )); + } + }; + + if let Err(e) = set_user_password(&client, &user_id, &req.password).await { + error!("Failed to set admin password: {}", e); + } + + let org_name = req.organization_name.unwrap_or_else(|| "Default Organization".to_string()); + let org_id = match create_organization(&client, &org_name).await { + Ok(id) => { + info!("Bootstrap organization created: {}", id); + Some(id) + } + Err(e) => { + 
warn!("Failed to create organization (may already exist): {}", e); + None + } + }; + + if let Some(ref oid) = org_id { + let admin_roles = vec![ + "admin".to_string(), + "org_owner".to_string(), + "user_manager".to_string(), + ]; + if let Err(e) = client.add_org_member(oid, &user_id, admin_roles).await { + error!("Failed to add admin to organization: {}", e); + } else { + info!("Admin user added to organization with admin roles"); + } + } + + info!( + "Bootstrap complete: admin user {} created successfully", + req.username + ); + + Ok(Json(BootstrapResponse { + success: true, + message: format!( + "Admin user '{}' created successfully. You can now login with your credentials.", + req.username + ), + user_id: Some(user_id), + organization_id: org_id, + })) +} + +async fn set_user_password( + client: &crate::directory::client::ZitadelClient, + user_id: &str, + password: &str, +) -> Result<(), String> { + let url = format!("{}/v2/users/{}/password", client.api_url(), user_id); + + let body = serde_json::json!({ + "newPassword": { + "password": password, + "changeRequired": false + } + }); + + let response = client + .http_post(url) + .await + .json(&body) + .send() + .await + .map_err(|e| e.to_string())?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to set password: {}", error_text)) + } +} + +async fn create_organization( + client: &crate::directory::client::ZitadelClient, + name: &str, +) -> Result { + let url = format!("{}/v2/organizations", client.api_url()); + + let body = serde_json::json!({ + "name": name + }); + + let response = client + .http_post(url) + .await + .json(&body) + .send() + .await + .map_err(|e| e.to_string())?; + + if response.status().is_success() { + let data: serde_json::Value = response.json().await.map_err(|e| e.to_string())?; + let org_id = data + .get("organizationId") + .or_else(|| data.get("id")) + .and_then(|v| v.as_str()) + 
.ok_or_else(|| "No organization ID in response".to_string())? + .to_string(); + Ok(org_id) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to create organization: {}", error_text)) + } +} diff --git a/src/directory/bootstrap.rs b/src/directory/bootstrap.rs new file mode 100644 index 000000000..7cecd770a --- /dev/null +++ b/src/directory/bootstrap.rs @@ -0,0 +1,356 @@ +use anyhow::Result; +use log::{error, info, warn}; +use rand::Rng; +use std::fs; +use std::os::unix::fs::PermissionsExt; + +use super::client::ZitadelClient; + +const ADMIN_USERNAME: &str = "admin"; +const DEFAULT_ORG_NAME: &str = "General Bots"; + +pub struct BootstrapResult { + pub user_id: String, + pub organization_id: Option, + pub username: String, + pub email: String, + pub initial_password: String, + pub setup_url: String, +} + +pub async fn check_and_bootstrap_admin(client: &ZitadelClient) -> Result> { + info!("Checking if bootstrap is needed..."); + + match client.list_users(10, 0).await { + Ok(users) => { + if !users.is_empty() { + let has_admin = users.iter().any(|u| { + let username = u + .get("userName") + .or_else(|| u.get("username")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let has_admin_role = u + .get("roles") + .and_then(|r| r.as_array()) + .map(|roles| { + roles.iter().any(|r| { + r.as_str() + .map(|s| s.to_lowercase().contains("admin")) + .unwrap_or(false) + }) + }) + .unwrap_or(false); + + username == ADMIN_USERNAME || has_admin_role + }); + + if has_admin { + info!("Admin user already exists, skipping bootstrap"); + return Ok(None); + } + } + } + Err(e) => { + warn!( + "Could not check existing users (may be first run): {}", + e + ); + } + } + + info!("No admin user found, bootstrapping initial admin account..."); + + let result = create_bootstrap_admin(client).await?; + + print_bootstrap_credentials(&result); + + Ok(Some(result)) +} + +fn generate_secure_password() -> String { + let mut rng = rand::rng(); + + let 
lowercase: Vec = (b'a'..=b'z').map(|c| c as char).collect(); + let uppercase: Vec = (b'A'..=b'Z').map(|c| c as char).collect(); + let digits: Vec = (b'0'..=b'9').map(|c| c as char).collect(); + let special: Vec = "!@#$%&*".chars().collect(); + + let mut password = Vec::with_capacity(16); + + password.push(lowercase[rng.random_range(0..lowercase.len())]); + password.push(uppercase[rng.random_range(0..uppercase.len())]); + password.push(digits[rng.random_range(0..digits.len())]); + password.push(special[rng.random_range(0..special.len())]); + + let all_chars: Vec = lowercase + .iter() + .chain(uppercase.iter()) + .chain(digits.iter()) + .chain(special.iter()) + .copied() + .collect(); + + for _ in 0..12 { + password.push(all_chars[rng.random_range(0..all_chars.len())]); + } + + for i in (1..password.len()).rev() { + let j = rng.random_range(0..=i); + password.swap(i, j); + } + + password.into_iter().collect() +} + +async fn create_bootstrap_admin(client: &ZitadelClient) -> Result { + let email = format!("{}@localhost", ADMIN_USERNAME); + + let user_id = client + .create_user(&email, "System", "Administrator", Some(ADMIN_USERNAME)) + .await + .map_err(|e| anyhow::anyhow!("Failed to create admin user: {}", e))?; + + info!("Created admin user with ID: {}", user_id); + + let initial_password = generate_secure_password(); + + if let Err(e) = client.set_user_password(&user_id, &initial_password, true).await { + warn!("Failed to set initial password via API: {}. 
User may need to use password reset flow.", e); + } else { + info!("Initial password set for admin user"); + } + + let org_id = match create_default_organization(client).await { + Ok(id) => { + info!("Created default organization with ID: {}", id); + + let admin_roles = vec![ + "admin".to_string(), + "org_owner".to_string(), + "user_manager".to_string(), + ]; + if let Err(e) = client.add_org_member(&id, &user_id, admin_roles).await { + warn!("Failed to add admin to organization: {}", e); + } + + Some(id) + } + Err(e) => { + warn!("Failed to create default organization: {}", e); + None + } + }; + + let base_url = client.api_url(); + let setup_url = format!("{}/ui/login", base_url); + + let result = BootstrapResult { + user_id: user_id.clone(), + organization_id: org_id, + username: ADMIN_USERNAME.to_string(), + email: email.clone(), + initial_password: initial_password.clone(), + setup_url: setup_url.clone(), + }; + + save_setup_credentials(&result); + + Ok(result) +} + +async fn create_default_organization(client: &ZitadelClient) -> Result { + let url = format!("{}/v2/organizations", client.api_url()); + + let body = serde_json::json!({ + "name": DEFAULT_ORG_NAME + }); + + let response = client + .http_post(url) + .await + .json(&body) + .send() + .await + .map_err(|e| anyhow::anyhow!("Failed to create organization: {}", e))?; + + if response.status().is_success() { + let data: serde_json::Value = response + .json() + .await + .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?; + + let org_id = data + .get("organizationId") + .or_else(|| data.get("id")) + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("No organization ID in response"))? 
+ .to_string(); + + Ok(org_id) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(anyhow::anyhow!( + "Failed to create organization: {}", + error_text + )) + } +} + +fn save_setup_credentials(result: &BootstrapResult) { + let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string()); + let creds_path = format!("{}/.gb-setup-credentials", home); + + let content = format!( + r#"# General Bots Initial Setup Credentials +# Created: {} +# DELETE THIS FILE AFTER FIRST LOGIN + +╔════════════════════════════════════════════════════════════╗ +║ ADMIN LOGIN CREDENTIALS (OTP) ║ +╠════════════════════════════════════════════════════════════╣ +║ ║ +║ Username: {:<46}║ +║ Password: {:<46}║ +║ Email: {:<46}║ +║ ║ +║ Login URL: {:<45}║ +║ ║ +╚════════════════════════════════════════════════════════════╝ + +IMPORTANT: +- This is a one-time password (OTP) +- You will be required to change it on first login +- Delete this file after you have logged in successfully + +Alternative access via Zitadel console: +1. Go to: {}/ui/console +2. Login with admin PAT from: ./botserver-stack/conf/directory/admin-pat.txt +3. 
Find user '{}' and manage settings +"#, + chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC"), + result.username, + result.initial_password, + result.email, + result.setup_url, + result.setup_url.split("/ui/").next().unwrap_or("http://localhost:8300"), + result.username + ); + + match fs::write(&creds_path, &content) { + Ok(_) => { + #[cfg(unix)] + { + if let Err(e) = fs::set_permissions(&creds_path, fs::Permissions::from_mode(0o600)) { + warn!("Failed to set file permissions: {}", e); + } + } + info!("Setup credentials saved to: {}", creds_path); + } + Err(e) => { + error!("Failed to save setup credentials: {}", e); + } + } +} + +fn print_bootstrap_credentials(result: &BootstrapResult) { + let separator = "═".repeat(60); + + println!(); + println!("╔{}╗", separator); + println!("║{:^60}║", ""); + println!("║{:^60}║", "🤖 GENERAL BOTS - INITIAL SETUP"); + println!("║{:^60}║", ""); + println!("╠{}╣", separator); + println!("║{:^60}║", ""); + println!("║ {:56}║", "Administrator account created!"); + println!("║{:^60}║", ""); + println!("╠{}╣", separator); + println!("║{:^60}║", ""); + println!("║{:^60}║", "🔐 ONE-TIME PASSWORD (OTP) FOR LOGIN:"); + println!("║{:^60}║", ""); + println!("║ {:<58}║", format!("Username: {}", result.username)); + println!("║ {:<58}║", format!("Password: {}", result.initial_password)); + println!("║ {:<58}║", format!("Email: {}", result.email)); + println!("║{:^60}║", ""); + + if let Some(ref org_id) = result.organization_id { + println!( + "║ {:<58}║", + format!("Organization: {} ({})", DEFAULT_ORG_NAME, &org_id[..8.min(org_id.len())]) + ); + println!("║{:^60}║", ""); + } + + println!("╠{}╣", separator); + println!("║{:^60}║", ""); + println!("║ {:56}║", "🌐 LOGIN URL:"); + println!("║{:^60}║", ""); + + let url_display = if result.setup_url.len() > 54 { + format!("{}...", &result.setup_url[..51]) + } else { + result.setup_url.clone() + }; + println!("║ {:56}║", url_display); + println!("║{:^60}║", ""); + println!("╠{}╣", separator); + 
println!("║{:^60}║", ""); + println!("║ ⚠️ {:<53}║", "IMPORTANT - SAVE THESE CREDENTIALS!"); + println!("║{:^60}║", ""); + println!("║ {:<56}║", "• This password will NOT be shown again"); + println!("║ {:<56}║", "• You must change it on first login"); + println!("║ {:<56}║", "• Credentials also saved to: ~/.gb-setup-credentials"); + println!("║{:^60}║", ""); + println!("╚{}╝", separator); + println!(); + + info!( + "Bootstrap complete: admin user '{}' created with OTP password", + result.username + ); +} + +pub fn print_existing_admin_notice() { + println!(); + println!("ℹ️ Admin user already exists. Skipping bootstrap."); + println!(" If you forgot your password, use Zitadel console to reset it."); + println!(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_secure_password() { + let password = generate_secure_password(); + + assert!(password.len() >= 14); + + let has_lower = password.chars().any(|c| c.is_ascii_lowercase()); + let has_upper = password.chars().any(|c| c.is_ascii_uppercase()); + let has_digit = password.chars().any(|c| c.is_ascii_digit()); + let has_special = password.chars().any(|c| "!@#$%&*".contains(c)); + + assert!(has_lower, "Password should contain lowercase"); + assert!(has_upper, "Password should contain uppercase"); + assert!(has_digit, "Password should contain digits"); + assert!(has_special, "Password should contain special chars"); + } + + #[test] + fn test_password_uniqueness() { + let passwords: Vec = (0..10).map(|_| generate_secure_password()).collect(); + + for i in 0..passwords.len() { + for j in (i + 1)..passwords.len() { + assert_ne!( + passwords[i], passwords[j], + "Generated passwords should be unique" + ); + } + } + } +} diff --git a/src/directory/client.rs b/src/directory/client.rs index 1a3e580e3..d24905a02 100644 --- a/src/directory/client.rs +++ b/src/directory/client.rs @@ -20,6 +20,7 @@ pub struct ZitadelClient { config: ZitadelConfig, http_client: reqwest::Client, access_token: Arc>>, + 
pat_token: Option, } impl ZitadelClient { @@ -33,13 +34,40 @@ impl ZitadelClient { config, http_client, access_token: Arc::new(RwLock::new(None)), + pat_token: None, }) } + pub fn with_pat_token(config: ZitadelConfig, pat_token: String) -> Result { + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?; + + Ok(Self { + config, + http_client, + access_token: Arc::new(RwLock::new(None)), + pat_token: Some(pat_token), + }) + } + + pub fn set_pat_token(&mut self, token: String) { + self.pat_token = Some(token); + } + pub fn api_url(&self) -> &str { &self.config.api_url } + pub fn client_id(&self) -> &str { + &self.config.client_id + } + + pub fn client_secret(&self) -> &str { + &self.config.client_secret + } + pub async fn http_get(&self, url: String) -> reqwest::RequestBuilder { let token = self.get_access_token().await.unwrap_or_default(); self.http_client.get(url).bearer_auth(token) @@ -60,7 +88,16 @@ impl ZitadelClient { self.http_client.patch(url).bearer_auth(token) } + pub async fn http_delete(&self, url: String) -> reqwest::RequestBuilder { + let token = self.get_access_token().await.unwrap_or_default(); + self.http_client.delete(url).bearer_auth(token) + } + pub async fn get_access_token(&self) -> Result { + if let Some(ref pat) = self.pat_token { + return Ok(pat.clone()); + } + { let token = self.access_token.read().await; if let Some(t) = token.as_ref() { @@ -69,6 +106,7 @@ impl ZitadelClient { } let token_url = format!("{}/oauth/v2/token", self.config.api_url); + log::info!("Requesting access token from: {}", token_url); let params = [ ("grant_type", "client_credentials"), @@ -123,7 +161,7 @@ impl ZitadelClient { }, "email": { "email": email, - "isVerified": false + "isVerified": true } }); @@ -467,4 +505,32 @@ impl ZitadelClient { Ok(true) } + + pub async fn set_user_password(&self, user_id: &str, password: &str, change_required: bool) -> 
Result<()> { + let token = self.get_access_token().await?; + let url = format!("{}/v2/users/{}/password", self.config.api_url, user_id); + + let body = serde_json::json!({ + "newPassword": { + "password": password, + "changeRequired": change_required + } + }); + + let response = self + .http_client + .post(&url) + .bearer_auth(&token) + .json(&body) + .send() + .await + .map_err(|e| anyhow!("Failed to set password: {}", e))?; + + if !response.status().is_success() { + let error_text = response.text().await.unwrap_or_default(); + return Err(anyhow!("Failed to set password: {}", error_text)); + } + + Ok(()) + } } diff --git a/src/directory/mod.rs b/src/directory/mod.rs index d0e05c179..19a0a4f0a 100644 --- a/src/directory/mod.rs +++ b/src/directory/mod.rs @@ -9,6 +9,8 @@ use std::collections::HashMap; use std::sync::Arc; use uuid::Uuid; +pub mod auth_routes; +pub mod bootstrap; pub mod client; pub mod groups; pub mod router; diff --git a/src/directory/router.rs b/src/directory/router.rs index 20c57bf10..a43ca32e9 100644 --- a/src/directory/router.rs +++ b/src/directory/router.rs @@ -1,4 +1,3 @@ - use axum::{ routing::{delete, get, post, put}, Router, @@ -10,32 +9,43 @@ use crate::shared::state::AppState; use super::groups; use super::users; - - pub fn configure() -> Router> { Router::new() - - - .route("/users/create", post(users::create_user)) - .route("/users/{user_id}/update", put(users::update_user)) - .route("/users/{user_id}/delete", delete(users::delete_user)) + .route("/users/:user_id/update", put(users::update_user)) + .route("/users/:user_id/delete", delete(users::delete_user)) .route("/users/list", get(users::list_users)) .route("/users/search", get(users::list_users)) - .route("/users/{user_id}/profile", get(users::get_user_profile)) - .route("/users/{user_id}/profile/update", put(users::update_user)) - .route("/users/{user_id}/settings", get(users::get_user_profile)) - .route("/users/{user_id}/permissions", get(users::get_user_profile)) - 
.route("/users/{user_id}/roles", get(users::get_user_profile)) - .route("/users/{user_id}/status", get(users::get_user_profile)) - .route("/users/{user_id}/presence", get(users::get_user_profile)) - .route("/users/{user_id}/activity", get(users::get_user_profile)) + .route("/users/:user_id/profile", get(users::get_user_profile)) + .route("/users/:user_id/profile/update", put(users::update_user)) + .route("/users/:user_id/settings", get(users::get_user_profile)) + .route("/users/:user_id/permissions", get(users::get_user_profile)) + .route("/users/:user_id/roles", get(users::get_user_profile)) + .route("/users/:user_id/status", get(users::get_user_profile)) + .route("/users/:user_id/presence", get(users::get_user_profile)) + .route("/users/:user_id/activity", get(users::get_user_profile)) .route( - "/users/{user_id}/security/2fa/enable", + "/users/:user_id/organization", + post(users::assign_organization), + ) + .route( + "/users/:user_id/organization/:org_id", + delete(users::remove_from_organization), + ) + .route( + "/users/:user_id/organization/:org_id/roles", + put(users::update_user_roles), + ) + .route( + "/users/:user_id/memberships", + get(users::get_user_memberships), + ) + .route( + "/users/:user_id/security/2fa/enable", post(users::get_user_profile), ) .route( - "/users/{user_id}/security/2fa/disable", + "/users/:user_id/security/2fa/disable", post(users::get_user_profile), ) .route( @@ -47,36 +57,33 @@ pub fn configure() -> Router> { get(users::get_user_profile), ) .route( - "/users/{user_id}/notifications/preferences/update", + "/users/:user_id/notifications/preferences/update", get(users::get_user_profile), ) - - - .route("/groups/create", post(groups::create_group)) - .route("/groups/{group_id}/update", put(groups::update_group)) - .route("/groups/{group_id}/delete", delete(groups::delete_group)) + .route("/groups/:group_id/update", put(groups::update_group)) + .route("/groups/:group_id/delete", delete(groups::delete_group)) .route("/groups/list", 
get(groups::list_groups)) .route("/groups/search", get(groups::list_groups)) - .route("/groups/{group_id}/members", get(groups::get_group_members)) + .route("/groups/:group_id/members", get(groups::get_group_members)) .route( - "/groups/{group_id}/members/add", + "/groups/:group_id/members/add", post(groups::add_group_member), ) .route( - "/groups/{group_id}/members/roles", + "/groups/:group_id/members/roles", post(groups::remove_group_member), ) .route( - "/groups/{group_id}/permissions", + "/groups/:group_id/permissions", get(groups::get_group_members), ) .route( - "/groups/{group_id}/settings", + "/groups/:group_id/settings", get(groups::get_group_members), ) .route( - "/groups/{group_id}/analytics", + "/groups/:group_id/analytics", get(groups::get_group_members), ) .route( diff --git a/src/directory/users.rs b/src/directory/users.rs index 3d2c88ad6..d4cca8cd2 100644 --- a/src/directory/users.rs +++ b/src/directory/users.rs @@ -1,4 +1,3 @@ - use axum::{ extract::{Path, Query, State}, http::StatusCode, @@ -9,20 +8,19 @@ use log::{error, info}; use serde::{Deserialize, Serialize}; use std::sync::Arc; - use crate::shared::state::AppState; - - #[derive(Debug, Deserialize)] pub struct CreateUserRequest { pub username: String, pub email: String, - pub password: String, + pub password: Option, pub first_name: String, pub last_name: String, pub display_name: Option, pub role: Option, + pub organization_id: Option, + pub roles: Option>, } #[derive(Debug, Deserialize)] @@ -33,6 +31,8 @@ pub struct UpdateUserRequest { pub display_name: Option, pub email: Option, pub phone: Option, + pub organization_id: Option, + pub roles: Option>, } #[derive(Debug, Deserialize)] @@ -40,6 +40,7 @@ pub struct UserQuery { pub page: Option, pub per_page: Option, pub search: Option, + pub organization_id: Option, } #[derive(Debug, Serialize)] @@ -51,6 +52,8 @@ pub struct UserResponse { pub last_name: String, pub display_name: Option, pub state: String, + pub organization_id: Option, + pub 
roles: Vec, pub created_at: Option>, pub updated_at: Option>, } @@ -76,8 +79,16 @@ pub struct ErrorResponse { pub details: Option, } +#[derive(Debug, Deserialize)] +pub struct AssignOrganizationRequest { + pub organization_id: String, + pub roles: Option>, +} - +#[derive(Debug, Deserialize)] +pub struct UpdateRolesRequest { + pub roles: Vec, +} pub async fn create_user( State(state): State>, @@ -85,43 +96,51 @@ pub async fn create_user( ) -> Result, (StatusCode, Json)> { info!("Creating user: {} ({})", req.username, req.email); - let client = { let auth_service = state.auth_service.lock().await; auth_service.client().clone() }; - - match client - .create_user( - &req.email, - &req.first_name, - &req.last_name, - Some(&req.username), - ) + let user_id = match client + .create_user(&req.email, &req.first_name, &req.last_name, Some(&req.username)) .await { - Ok(user_id) => { - info!("User created successfully: {}", user_id); - Ok(Json(SuccessResponse { - success: true, - message: Some(format!("User {} created successfully", req.username)), - user_id: Some(user_id), - })) - } + Ok(id) => id, Err(e) => { - error!("Failed to create user: {}", e); - Err(( + error!("Failed to create user in Zitadel: {}", e); + return Err(( StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse { error: "Failed to create user".to_string(), details: Some(e.to_string()), }), - )) + )); + } + }; + + if let Some(ref org_id) = req.organization_id { + let roles = req.roles.clone().unwrap_or_else(|| vec!["user".to_string()]); + + if let Err(e) = client.add_org_member(org_id, &user_id, roles.clone()).await { + error!( + "Failed to add user {} to organization {}: {}", + user_id, org_id, e + ); + } else { + info!( + "User {} added to organization {} with roles {:?}", + user_id, org_id, roles + ); } } -} + info!("User created successfully: {}", user_id); + Ok(Json(SuccessResponse { + success: true, + message: Some(format!("User {} created successfully", req.username)), + user_id: Some(user_id), + })) 
+} pub async fn update_user( State(state): State>, @@ -135,7 +154,6 @@ pub async fn update_user( auth_service.client().clone() }; - let mut update_data = serde_json::Map::new(); if let Some(username) = &req.username { update_data.insert("userName".to_string(), serde_json::json!(username)); @@ -156,45 +174,49 @@ pub async fn update_user( update_data.insert("phone".to_string(), serde_json::json!(phone)); } - - match client - .http_patch(format!("{}/users/{}", client.api_url(), user_id)) - .await - .json(&serde_json::Value::Object(update_data)) - .send() - .await - { - Ok(response) if response.status().is_success() => { - info!("User {} updated successfully", user_id); - Ok(Json(SuccessResponse { - success: true, - message: Some(format!("User {} updated successfully", user_id)), - user_id: Some(user_id), - })) - } - Ok(_) => { - error!("Failed to update user: unexpected response"); - Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Failed to update user".to_string(), - details: Some("Unexpected response from server".to_string()), - }), - )) - } - Err(e) => { - error!("Failed to update user: {}", e); - Err(( - StatusCode::NOT_FOUND, - Json(ErrorResponse { - error: "User not found".to_string(), - details: Some(e.to_string()), - }), - )) + if !update_data.is_empty() { + match client + .http_patch(format!("{}/users/{}", client.api_url(), user_id)) + .await + .json(&serde_json::Value::Object(update_data)) + .send() + .await + { + Ok(response) if response.status().is_success() => { + info!("User {} profile updated successfully", user_id); + } + Ok(response) => { + let status = response.status(); + error!("Failed to update user profile: {}", status); + } + Err(e) => { + error!("Failed to update user profile: {}", e); + } } } -} + if let Some(ref org_id) = req.organization_id { + let roles = req.roles.clone().unwrap_or_else(|| vec!["user".to_string()]); + + if let Err(e) = client.add_org_member(org_id, &user_id, roles.clone()).await { + error!( + 
"Failed to update user {} organization membership: {}", + user_id, e + ); + } else { + info!( + "User {} organization membership updated to {} with roles {:?}", + user_id, org_id, roles + ); + } + } + + Ok(Json(SuccessResponse { + success: true, + message: Some(format!("User {} updated successfully", user_id)), + user_id: Some(user_id), + })) +} pub async fn delete_user( State(state): State>, @@ -207,23 +229,37 @@ pub async fn delete_user( auth_service.client().clone() }; - - match client.get_user(&user_id).await { - Ok(_) => { - - info!("User {} deleted/deactivated", user_id); + match client + .http_delete(format!("{}/v2/users/{}", client.api_url(), user_id)) + .await + .send() + .await + { + Ok(response) if response.status().is_success() => { + info!("User {} deleted successfully", user_id); Ok(Json(SuccessResponse { success: true, message: Some(format!("User {} deleted successfully", user_id)), user_id: Some(user_id), })) } + Ok(response) => { + let status = response.status(); + error!("Failed to delete user: {}", status); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to delete user".to_string(), + details: Some(format!("Server returned {}", status)), + }), + )) + } Err(e) => { error!("Failed to delete user: {}", e); Err(( - StatusCode::NOT_FOUND, + StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse { - error: "User not found".to_string(), + error: "Failed to delete user".to_string(), details: Some(e.to_string()), }), )) @@ -231,7 +267,6 @@ pub async fn delete_user( } } - pub async fn list_users( State(state): State>, Query(params): Query, @@ -246,9 +281,12 @@ pub async fn list_users( auth_service.client().clone() }; - let users_result = if let Some(search_term) = params.search { + let users_result = if let Some(ref org_id) = params.organization_id { + info!("Filtering users by organization: {}", org_id); + client.get_org_members(org_id).await + } else if let Some(ref search_term) = params.search { info!("Searching users 
with term: {}", search_term); - client.search_users(&search_term).await + client.search_users(search_term).await } else { let offset = (page - 1) * per_page; client.list_users(per_page, offset).await @@ -259,22 +297,63 @@ pub async fn list_users( let users: Vec = users_json .into_iter() .filter_map(|u| { + let id = u.get("userId").and_then(|v| v.as_str()).map(String::from) + .or_else(|| u.get("user_id").and_then(|v| v.as_str()).map(String::from))?; + + let username = u.get("userName").and_then(|v| v.as_str()) + .or_else(|| u.get("username").and_then(|v| v.as_str())) + .unwrap_or("unknown") + .to_string(); + + let email = u.get("preferredLoginName").and_then(|v| v.as_str()) + .or_else(|| u.get("email").and_then(|v| v.as_str())) + .unwrap_or("unknown@example.com") + .to_string(); + + let first_name = u.get("profile") + .and_then(|p| p.get("givenName")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let last_name = u.get("profile") + .and_then(|p| p.get("familyName")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let display_name = u.get("profile") + .and_then(|p| p.get("displayName")) + .and_then(|v| v.as_str()) + .map(String::from); + + let state = u.get("state").and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + + let organization_id = u.get("orgId").and_then(|v| v.as_str()) + .or_else(|| u.get("organization_id").and_then(|v| v.as_str())) + .map(String::from); + + let roles = u.get("roles") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|r| r.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); + Some(UserResponse { - id: u.get("userId")?.as_str()?.to_string(), - username: u.get("userName")?.as_str()?.to_string(), - email: u - .get("preferredLoginName") - .and_then(|v| v.as_str()) - .unwrap_or("unknown@example.com") - .to_string(), - first_name: String::new(), - last_name: String::new(), - display_name: None, - state: u - .get("state") - .and_then(|v| v.as_str()) - 
.unwrap_or("unknown") - .to_string(), + id, + username, + email, + first_name, + last_name, + display_name, + state, + organization_id, + roles, created_at: None, updated_at: None, }) @@ -304,7 +383,6 @@ pub async fn list_users( } } - pub async fn get_user_profile( State(state): State>, Path(user_id): Path, @@ -318,35 +396,68 @@ pub async fn get_user_profile( match client.get_user(&user_id).await { Ok(user_data) => { + let id = user_data.get("id").and_then(|v| v.as_str()) + .unwrap_or(&user_id) + .to_string(); + + let username = user_data.get("username").and_then(|v| v.as_str()) + .or_else(|| user_data.get("userName").and_then(|v| v.as_str())) + .unwrap_or("unknown") + .to_string(); + + let email = user_data.get("preferredLoginName").and_then(|v| v.as_str()) + .or_else(|| user_data.get("email").and_then(|v| v.as_str())) + .unwrap_or("unknown@example.com") + .to_string(); + + let first_name = user_data.get("profile") + .and_then(|p| p.get("givenName")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let last_name = user_data.get("profile") + .and_then(|p| p.get("familyName")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let display_name = user_data.get("profile") + .and_then(|p| p.get("displayName")) + .and_then(|v| v.as_str()) + .map(String::from); + + let state = user_data.get("state").and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + + let organization_id = user_data.get("orgId").and_then(|v| v.as_str()) + .map(String::from); + + let roles = user_data.get("roles") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|r| r.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); + let user = UserResponse { - id: user_data - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or(&user_id) - .to_string(), - username: user_data - .get("username") - .and_then(|v| v.as_str()) - .unwrap_or("unknown") - .to_string(), - email: user_data - .get("preferredLoginName") - .and_then(|v| 
v.as_str()) - .unwrap_or("unknown@example.com") - .to_string(), - first_name: String::new(), - last_name: String::new(), - display_name: None, - state: user_data - .get("state") - .and_then(|v| v.as_str()) - .unwrap_or("unknown") - .to_string(), + id, + username: username.clone(), + email, + first_name, + last_name, + display_name, + state, + organization_id, + roles, created_at: None, updated_at: None, }; - info!("User profile retrieved: {}", user.username); + info!("User profile retrieved: {}", username); Ok(Json(user)) } Err(e) => { @@ -361,3 +472,160 @@ pub async fn get_user_profile( } } } + +pub async fn assign_organization( + State(state): State>, + Path(user_id): Path, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + info!( + "Assigning user {} to organization {}", + user_id, req.organization_id + ); + + let client = { + let auth_service = state.auth_service.lock().await; + auth_service.client().clone() + }; + + let roles = req.roles.unwrap_or_else(|| vec!["user".to_string()]); + + match client + .add_org_member(&req.organization_id, &user_id, roles.clone()) + .await + { + Ok(()) => { + info!( + "User {} assigned to organization {} with roles {:?}", + user_id, req.organization_id, roles + ); + Ok(Json(SuccessResponse { + success: true, + message: Some(format!( + "User assigned to organization {} with roles {:?}", + req.organization_id, roles + )), + user_id: Some(user_id), + })) + } + Err(e) => { + error!("Failed to assign user to organization: {}", e); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to assign user to organization".to_string(), + details: Some(e.to_string()), + }), + )) + } + } +} + +pub async fn remove_from_organization( + State(state): State>, + Path((user_id, org_id)): Path<(String, String)>, +) -> Result, (StatusCode, Json)> { + info!("Removing user {} from organization {}", user_id, org_id); + + let client = { + let auth_service = state.auth_service.lock().await; + 
auth_service.client().clone() + }; + + match client.remove_org_member(&org_id, &user_id).await { + Ok(()) => { + info!("User {} removed from organization {}", user_id, org_id); + Ok(Json(SuccessResponse { + success: true, + message: Some(format!("User removed from organization {}", org_id)), + user_id: Some(user_id), + })) + } + Err(e) => { + error!("Failed to remove user from organization: {}", e); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to remove user from organization".to_string(), + details: Some(e.to_string()), + }), + )) + } + } +} + +pub async fn get_user_memberships( + State(state): State>, + Path(user_id): Path, +) -> Result, (StatusCode, Json)> { + info!("Getting memberships for user: {}", user_id); + + let client = { + let auth_service = state.auth_service.lock().await; + auth_service.client().clone() + }; + + match client.get_user_memberships(&user_id, 0, 100).await { + Ok(memberships) => { + info!("Retrieved memberships for user {}", user_id); + Ok(Json(memberships)) + } + Err(e) => { + error!("Failed to get user memberships: {}", e); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to get user memberships".to_string(), + details: Some(e.to_string()), + }), + )) + } + } +} + +pub async fn update_user_roles( + State(state): State>, + Path((user_id, org_id)): Path<(String, String)>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + info!( + "Updating roles for user {} in organization {}: {:?}", + user_id, org_id, req.roles + ); + + let client = { + let auth_service = state.auth_service.lock().await; + auth_service.client().clone() + }; + + if let Err(e) = client.remove_org_member(&org_id, &user_id).await { + error!("Failed to remove existing membership: {}", e); + } + + match client + .add_org_member(&org_id, &user_id, req.roles.clone()) + .await + { + Ok(()) => { + info!( + "User {} roles updated in organization {}: {:?}", + user_id, org_id, req.roles + ); + 
Ok(Json(SuccessResponse { + success: true, + message: Some(format!("User roles updated to {:?}", req.roles)), + user_id: Some(user_id), + })) + } + Err(e) => { + error!("Failed to update user roles: {}", e); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to update user roles".to_string(), + details: Some(e.to_string()), + }), + )) + } + } +} diff --git a/src/docs/mod.rs b/src/docs/mod.rs new file mode 100644 index 000000000..ccbe64d12 --- /dev/null +++ b/src/docs/mod.rs @@ -0,0 +1,1479 @@ +//! GB Docs - Word Processor Module +//! +//! This module provides a Word-like document editor with: +//! - Rich text document management +//! - Real-time multi-user collaboration via WebSocket +//! - Templates (blank, meeting, report, letter) +//! - AI-powered writing assistance +//! - Export to multiple formats (PDF, DOCX, HTML, TXT, MD) + + +use crate::core::urls::ApiUrls; +use crate::shared::state::AppState; +use aws_sdk_s3::primitives::ByteStream; +use axum::{ + extract::{ + ws::{Message, WebSocket, WebSocketUpgrade}, + Path, Query, State, + }, + http::header::HeaderMap, + response::{Html, IntoResponse}, + routing::{get, post}, + Json, Router, +}; +use chrono::{DateTime, Utc}; +use diesel::prelude::*; +use futures_util::{SinkExt, StreamExt}; +use log::{error, info}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::broadcast; +use uuid::Uuid; + +// ============================================================================= +// COLLABORATION TYPES +// ============================================================================= + +type CollaborationChannels = + Arc>>>; + +static COLLAB_CHANNELS: std::sync::OnceLock = std::sync::OnceLock::new(); + +fn get_collab_channels() -> &'static CollaborationChannels { + COLLAB_CHANNELS.get_or_init(|| Arc::new(tokio::sync::RwLock::new(HashMap::new()))) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct 
CollabMessage { + pub msg_type: String, // "cursor", "edit", "format", "join", "leave" + pub doc_id: String, + pub user_id: String, + pub user_name: String, + pub user_color: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub position: Option, // Cursor position + #[serde(skip_serializing_if = "Option::is_none")] + pub length: Option, // Selection length + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, // Inserted/changed content + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, // Format command (bold, italic, etc.) + pub timestamp: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Collaborator { + pub id: String, + pub name: String, + pub color: String, + pub cursor_position: Option, + pub selection_length: Option, + pub connected_at: DateTime, +} + +// ============================================================================= +// DOCUMENT TYPES +// ============================================================================= + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Document { + pub id: String, + pub title: String, + pub content: String, // HTML content for rich text + pub owner_id: String, + pub storage_path: String, + pub created_at: DateTime, + pub updated_at: DateTime, + #[serde(default)] + pub collaborators: Vec, // User IDs with access + #[serde(default)] + pub version: u64, // For conflict resolution +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DocumentMetadata { + pub id: String, + pub title: String, + pub owner_id: String, + pub created_at: DateTime, + pub updated_at: DateTime, + pub word_count: usize, + pub storage_type: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchQuery { + pub q: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SaveRequest { + pub id: Option, + pub title: Option, + pub content: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] 
+pub struct SaveResponse { + pub id: String, + pub success: bool, + pub message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AiRequest { + #[serde(rename = "selected-text", alias = "text")] + pub selected_text: Option, + pub prompt: Option, + pub action: Option, + #[serde(rename = "translate-lang")] + pub translate_lang: Option, + pub document_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AiResponse { + pub result: Option, + pub content: Option, + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExportQuery { + pub id: Option, +} + +#[derive(Debug, QueryableByName)] +#[diesel(check_for_backend(diesel::pg::Pg))] +pub struct UserRow { + #[diesel(sql_type = diesel::sql_types::Uuid)] + pub id: Uuid, + #[diesel(sql_type = diesel::sql_types::Text)] + pub email: String, + #[diesel(sql_type = diesel::sql_types::Text)] + pub username: String, +} + +#[derive(Debug, QueryableByName)] +#[diesel(check_for_backend(diesel::pg::Pg))] +struct UserIdRow { + #[diesel(sql_type = diesel::sql_types::Uuid)] + user_id: Uuid, +} + +// ============================================================================= +// ROUTE CONFIGURATION +// ============================================================================= + +pub fn configure_docs_routes() -> Router> { + Router::new() + .route(ApiUrls::DOCS_NEW, post(handle_new_document)) + .route(ApiUrls::DOCS_LIST, get(handle_list_documents)) + .route(ApiUrls::DOCS_SEARCH, get(handle_search_documents)) + .route(ApiUrls::DOCS_SAVE, post(handle_save_document)) + .route(ApiUrls::DOCS_AUTOSAVE, post(handle_autosave)) + .route(ApiUrls::DOCS_BY_ID, get(handle_get_document)) + .route(ApiUrls::DOCS_DELETE, post(handle_delete_document)) + .route(ApiUrls::DOCS_TEMPLATE_BLANK, post(handle_template_blank)) + .route(ApiUrls::DOCS_TEMPLATE_MEETING, post(handle_template_meeting)) + .route(ApiUrls::DOCS_TEMPLATE_REPORT, post(handle_template_report)) + 
.route(ApiUrls::DOCS_TEMPLATE_LETTER, post(handle_template_letter)) + .route(ApiUrls::DOCS_AI_SUMMARIZE, post(handle_ai_summarize)) + .route(ApiUrls::DOCS_AI_EXPAND, post(handle_ai_expand)) + .route(ApiUrls::DOCS_AI_IMPROVE, post(handle_ai_improve)) + .route(ApiUrls::DOCS_AI_SIMPLIFY, post(handle_ai_simplify)) + .route(ApiUrls::DOCS_AI_TRANSLATE, post(handle_ai_translate)) + .route(ApiUrls::DOCS_AI_CUSTOM, post(handle_ai_custom)) + .route(ApiUrls::DOCS_EXPORT_PDF, get(handle_export_pdf)) + .route(ApiUrls::DOCS_EXPORT_DOCX, get(handle_export_docx)) + .route(ApiUrls::DOCS_EXPORT_MD, get(handle_export_md)) + .route(ApiUrls::DOCS_EXPORT_HTML, get(handle_export_html)) + .route(ApiUrls::DOCS_EXPORT_TXT, get(handle_export_txt)) + .route(ApiUrls::DOCS_WS, get(handle_docs_websocket)) +} + +// ============================================================================= +// AUTHENTICATION HELPERS +// ============================================================================= + +async fn get_current_user( + state: &Arc, + headers: &HeaderMap, +) -> Result<(Uuid, String), String> { + let session_id = headers + .get("x-session-id") + .and_then(|v| v.to_str().ok()) + .or_else(|| { + headers + .get("cookie") + .and_then(|v| v.to_str().ok()) + .and_then(|cookies| { + cookies + .split(';') + .find(|c| c.trim().starts_with("session_id=")) + .map(|c| c.trim().trim_start_matches("session_id=")) + }) + }); + + if let Some(sid) = session_id { + if let Ok(session_uuid) = Uuid::parse_str(sid) { + let conn = state.conn.clone(); + let result = tokio::task::spawn_blocking(move || { + let mut db_conn = conn.get().map_err(|e| e.to_string())?; + + let user_id: Option = + diesel::sql_query("SELECT user_id FROM user_sessions WHERE id = $1") + .bind::(session_uuid) + .get_result::(&mut db_conn) + .optional() + .map_err(|e| e.to_string())? 
+ .map(|r| r.user_id); + + if let Some(uid) = user_id { + let user: Option = + diesel::sql_query("SELECT id, email, username FROM users WHERE id = $1") + .bind::(uid) + .get_result(&mut db_conn) + .optional() + .map_err(|e| e.to_string())?; + + if let Some(u) = user { + return Ok((u.id, u.email)); + } + } + Err("User not found".to_string()) + }) + .await + .map_err(|e| e.to_string())?; + + return result; + } + } + + // Fallback to anonymous user + let conn = state.conn.clone(); + tokio::task::spawn_blocking(move || { + let mut db_conn = conn.get().map_err(|e| e.to_string())?; + + let anon_email = "anonymous@local"; + let user: Option = diesel::sql_query( + "SELECT id, email, username FROM users WHERE email = $1", + ) + .bind::(anon_email) + .get_result(&mut db_conn) + .optional() + .map_err(|e| e.to_string())?; + + if let Some(u) = user { + Ok((u.id, u.email)) + } else { + let new_id = Uuid::new_v4(); + let now = Utc::now(); + diesel::sql_query( + "INSERT INTO users (id, username, email, password_hash, is_active, created_at, updated_at) + VALUES ($1, $2, $3, '', true, $4, $4)" + ) + .bind::(new_id) + .bind::("anonymous") + .bind::(anon_email) + .bind::(now) + .execute(&mut db_conn) + .map_err(|e| e.to_string())?; + + Ok((new_id, anon_email.to_string())) + } + }) + .await + .map_err(|e| e.to_string())? 
+} + +// ============================================================================= +// STORAGE HELPERS +// ============================================================================= + +fn get_user_docs_path(user_identifier: &str) -> String { + let safe_id = user_identifier + .replace(['/', '\\', ':', '*', '?', '"', '<', '>', '|'], "_") + .to_lowercase(); + format!("users/{}/docs", safe_id) +} + +async fn save_document_to_drive( + state: &Arc, + user_identifier: &str, + doc_id: &str, + title: &str, + content: &str, +) -> Result { + let s3_client = state.drive.as_ref().ok_or("S3 service not available")?; + + let base_path = get_user_docs_path(user_identifier); + let doc_path = format!("{}/{}.html", base_path, doc_id); + let meta_path = format!("{}/{}.meta.json", base_path, doc_id); + + // Save document content as HTML + s3_client + .put_object() + .bucket(&state.bucket_name) + .key(&doc_path) + .body(ByteStream::from(content.as_bytes().to_vec())) + .content_type("text/html") + .send() + .await + .map_err(|e| format!("Failed to save document: {}", e))?; + + // Save metadata + let word_count = content + .split_whitespace() + .filter(|w| !w.starts_with('<') && !w.ends_with('>')) + .count(); + + let metadata = serde_json::json!({ + "id": doc_id, + "title": title, + "created_at": Utc::now().to_rfc3339(), + "updated_at": Utc::now().to_rfc3339(), + "word_count": word_count, + "version": 1 + }); + + s3_client + .put_object() + .bucket(&state.bucket_name) + .key(&meta_path) + .body(ByteStream::from(metadata.to_string().into_bytes())) + .content_type("application/json") + .send() + .await + .map_err(|e| format!("Failed to save metadata: {}", e))?; + + Ok(doc_path) +} + +async fn load_document_from_drive( + state: &Arc, + user_identifier: &str, + doc_id: &str, +) -> Result, String> { + let s3_client = state.drive.as_ref().ok_or("S3 service not available")?; + + let base_path = get_user_docs_path(user_identifier); + let doc_path = format!("{}/{}.html", base_path, doc_id); 
+ let meta_path = format!("{}/{}.meta.json", base_path, doc_id); + + // Load content + let content = match s3_client + .get_object() + .bucket(&state.bucket_name) + .key(&doc_path) + .send() + .await + { + Ok(result) => { + let bytes = result + .body + .collect() + .await + .map_err(|e| e.to_string())? + .into_bytes(); + String::from_utf8(bytes.to_vec()).map_err(|e| e.to_string())? + } + Err(_) => return Ok(None), + }; + + // Load metadata + let (title, created_at, updated_at) = match s3_client + .get_object() + .bucket(&state.bucket_name) + .key(&meta_path) + .send() + .await + { + Ok(result) => { + let bytes = result + .body + .collect() + .await + .map_err(|e| e.to_string())? + .into_bytes(); + let meta_str = String::from_utf8(bytes.to_vec()).map_err(|e| e.to_string())?; + let meta: serde_json::Value = serde_json::from_str(&meta_str).unwrap_or_default(); + ( + meta["title"].as_str().unwrap_or("Untitled").to_string(), + meta["created_at"] + .as_str() + .and_then(|s| DateTime::parse_from_rfc3339(s).ok()) + .map(|d| d.with_timezone(&Utc)) + .unwrap_or_else(Utc::now), + meta["updated_at"] + .as_str() + .and_then(|s| DateTime::parse_from_rfc3339(s).ok()) + .map(|d| d.with_timezone(&Utc)) + .unwrap_or_else(Utc::now), + ) + } + Err(_) => ("Untitled".to_string(), Utc::now(), Utc::now()), + }; + + Ok(Some(Document { + id: doc_id.to_string(), + title, + content, + owner_id: user_identifier.to_string(), + storage_path: doc_path, + created_at, + updated_at, + collaborators: Vec::new(), + version: 1, + })) +} + +async fn list_documents_from_drive( + state: &Arc, + user_identifier: &str, +) -> Result, String> { + let s3_client = state.drive.as_ref().ok_or("S3 service not available")?; + + let base_path = get_user_docs_path(user_identifier); + let prefix = format!("{}/", base_path); + let mut documents = Vec::new(); + + if let Ok(result) = s3_client + .list_objects_v2() + .bucket(&state.bucket_name) + .prefix(&prefix) + .send() + .await + { + for obj in result.contents() { + 
if let Some(key) = obj.key() { + if key.ends_with(".meta.json") { + // Load metadata + if let Ok(meta_result) = s3_client + .get_object() + .bucket(&state.bucket_name) + .key(key) + .send() + .await + { + if let Ok(bytes) = meta_result.body.collect().await { + if let Ok(meta_str) = String::from_utf8(bytes.into_bytes().to_vec()) { + if let Ok(meta) = serde_json::from_str::(&meta_str) { + documents.push(DocumentMetadata { + id: meta["id"].as_str().unwrap_or("").to_string(), + title: meta["title"].as_str().unwrap_or("Untitled").to_string(), + owner_id: user_identifier.to_string(), + created_at: meta["created_at"] + .as_str() + .and_then(|s| DateTime::parse_from_rfc3339(s).ok()) + .map(|d| d.with_timezone(&Utc)) + .unwrap_or_else(Utc::now), + updated_at: meta["updated_at"] + .as_str() + .and_then(|s| DateTime::parse_from_rfc3339(s).ok()) + .map(|d| d.with_timezone(&Utc)) + .unwrap_or_else(Utc::now), + word_count: meta["word_count"].as_u64().unwrap_or(0) as usize, + storage_type: "docs".to_string(), + }); + } + } + } + } + } + } + } + } + + // Sort by updated_at descending + documents.sort_by(|a, b| b.updated_at.cmp(&a.updated_at)); + + Ok(documents) +} + +async fn delete_document_from_drive( + state: &Arc, + user_identifier: &str, + doc_id: &str, +) -> Result<(), String> { + let s3_client = state.drive.as_ref().ok_or("S3 service not available")?; + + let base_path = get_user_docs_path(user_identifier); + let doc_path = format!("{}/{}.html", base_path, doc_id); + let meta_path = format!("{}/{}.meta.json", base_path, doc_id); + + // Delete document + let _ = s3_client + .delete_object() + .bucket(&state.bucket_name) + .key(&doc_path) + .send() + .await; + + // Delete metadata + let _ = s3_client + .delete_object() + .bucket(&state.bucket_name) + .key(&meta_path) + .send() + .await; + + Ok(()) +} + +// ============================================================================= +// LLM HELPERS +// 
============================================================================= + +async fn call_llm(_state: &Arc, _system_prompt: &str, _user_text: &str) -> Result { + // TODO: Integrate with LLM provider when available + Err("LLM not available".to_string()) +} + +// ============================================================================= +// DOCUMENT HANDLERS +// ============================================================================= + +pub async fn handle_new_document( + State(state): State>, + headers: HeaderMap, +) -> impl IntoResponse { + let (user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + error!("Auth error: {}", e); + return Html(format_error("Authentication required")); + } + }; + + let doc_id = Uuid::new_v4().to_string(); + let title = "Untitled Document".to_string(); + let content = "

".to_string(); + + if let Err(e) = save_document_to_drive(&state, &user_identifier, &doc_id, &title, &content).await { + error!("Failed to save new document: {}", e); + } + + let mut html = String::new(); + html.push_str("
"); + html.push_str(&format_document_list_item(&doc_id, &title, "just now", true)); + html.push_str(""); + html.push_str("
"); + + info!("New document created: {} for user {}", doc_id, user_id); + Html(html) +} + +pub async fn handle_list_documents( + State(state): State>, + headers: HeaderMap, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + error!("Auth error: {}", e); + return Html(format_error("Authentication required")); + } + }; + + let documents = match list_documents_from_drive(&state, &user_identifier).await { + Ok(docs) => docs, + Err(e) => { + error!("Failed to list documents: {}", e); + Vec::new() + } + }; + + let mut html = String::new(); + html.push_str("
"); + + if documents.is_empty() { + html.push_str("
"); + html.push_str("

No documents yet

"); + html.push_str(""); + html.push_str("
"); + } else { + for doc in documents { + let time_str = format_relative_time(doc.updated_at); + html.push_str(&format_document_list_item(&doc.id, &doc.title, &time_str, false)); + } + } + + html.push_str("
"); + Html(html) +} + +pub async fn handle_search_documents( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + error!("Auth error: {}", e); + return Html(format_error("Authentication required")); + } + }; + + let query = params.q.unwrap_or_default().to_lowercase(); + + let documents = match list_documents_from_drive(&state, &user_identifier).await { + Ok(docs) => docs, + Err(e) => { + error!("Failed to list documents: {}", e); + Vec::new() + } + }; + + let filtered: Vec<_> = if query.is_empty() { + documents + } else { + documents + .into_iter() + .filter(|d| d.title.to_lowercase().contains(&query)) + .collect() + }; + + let mut html = String::new(); + html.push_str("
"); + + if filtered.is_empty() { + html.push_str("
"); + html.push_str("

No documents found

"); + html.push_str("
"); + } else { + for doc in filtered { + let time_str = format_relative_time(doc.updated_at); + html.push_str(&format_document_list_item(&doc.id, &doc.title, &time_str, false)); + } + } + + html.push_str("
"); + Html(html) +} + +pub async fn handle_get_document( + State(state): State>, + headers: HeaderMap, + Path(id): Path, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + error!("Auth error: {}", e); + return Json(serde_json::json!({"error": "Authentication required"})).into_response(); + } + }; + + match load_document_from_drive(&state, &user_identifier, &id).await { + Ok(Some(doc)) => Json(serde_json::json!({ + "id": doc.id, + "title": doc.title, + "content": doc.content, + "created_at": doc.created_at.to_rfc3339(), + "updated_at": doc.updated_at.to_rfc3339() + })).into_response(), + Ok(None) => Json(serde_json::json!({"error": "Document not found"})).into_response(), + Err(e) => Json(serde_json::json!({"error": e})).into_response(), + } +} + +pub async fn handle_save_document( + State(state): State>, + headers: HeaderMap, + Json(payload): Json, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + error!("Auth error: {}", e); + return Json(SaveResponse { + id: String::new(), + success: false, + message: Some("Authentication required".to_string()), + }); + } + }; + + let doc_id = payload.id.unwrap_or_else(|| Uuid::new_v4().to_string()); + let title = payload.title.unwrap_or_else(|| "Untitled Document".to_string()); + let content = payload.content.unwrap_or_default(); + + match save_document_to_drive(&state, &user_identifier, &doc_id, &title, &content).await { + Ok(_) => { + // Broadcast to collaborators + let channels = get_collab_channels(); + if let Some(sender) = channels.read().await.get(&doc_id) { + let msg = CollabMessage { + msg_type: "save".to_string(), + doc_id: doc_id.clone(), + user_id: user_identifier.clone(), + user_name: user_identifier.clone(), + user_color: "#4285f4".to_string(), + position: None, + length: None, + content: None, + format: None, + timestamp: Utc::now(), + 
}; + let _ = sender.send(msg); + } + + Json(SaveResponse { + id: doc_id, + success: true, + message: None, + }) + } + Err(e) => Json(SaveResponse { + id: doc_id, + success: false, + message: Some(e), + }), + } +} + +pub async fn handle_autosave( + State(state): State>, + headers: HeaderMap, + Json(payload): Json, +) -> impl IntoResponse { + handle_save_document(State(state), headers, Json(payload)).await +} + +pub async fn handle_delete_document( + State(state): State>, + headers: HeaderMap, + Path(id): Path, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + error!("Auth error: {}", e); + return Html(format_error("Authentication required")); + } + }; + + match delete_document_from_drive(&state, &user_identifier, &id).await { + Ok(_) => { + Html("
Document deleted
".to_string()) + } + Err(e) => Html(format_error(&e)), + } +} + +// ============================================================================= +// TEMPLATE HANDLERS +// ============================================================================= + +pub async fn handle_template_blank( + State(state): State>, + headers: HeaderMap, +) -> impl IntoResponse { + handle_new_document(State(state), headers).await +} + +pub async fn handle_template_meeting( + State(state): State>, + headers: HeaderMap, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + error!("Auth error: {}", e); + return Html(format_error("Authentication required")); + } + }; + + let doc_id = Uuid::new_v4().to_string(); + let title = "Meeting Notes".to_string(); + let now = Utc::now(); + + let content = format!( + r#"

Meeting Notes

+

Date: {}

+

Attendees:

+
+

Agenda

+
+

Discussion

+

+

Action Items

+
+

Next Steps

+

"#, + now.format("%Y-%m-%d") + ); + + let _ = save_document_to_drive(&state, &user_identifier, &doc_id, &title, &content).await; + + Html(format_document_content(&doc_id, &title, &content)) +} + +pub async fn handle_template_report( + State(state): State>, + headers: HeaderMap, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + error!("Auth error: {}", e); + return Html(format_error("Authentication required")); + } + }; + + let doc_id = Uuid::new_v4().to_string(); + let title = "Report".to_string(); + let now = Utc::now(); + + let content = format!( + r#"

Report

+

Date: {}

+

Author:

+
+

Executive Summary

+

+

Introduction

+

+

Background

+

+

Findings

+

Key Finding 1

+

+

Key Finding 2

+

+

Analysis

+

+

Recommendations

+
+

Conclusion

+

+

Appendix

+

"#, + now.format("%Y-%m-%d") + ); + + let _ = save_document_to_drive(&state, &user_identifier, &doc_id, &title, &content).await; + + Html(format_document_content(&doc_id, &title, &content)) +} + +pub async fn handle_template_letter( + State(state): State>, + headers: HeaderMap, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + error!("Auth error: {}", e); + return Html(format_error("Authentication required")); + } + }; + + let doc_id = Uuid::new_v4().to_string(); + let title = "Letter".to_string(); + let now = Utc::now(); + + let content = format!( + r#"

[Your Name]
+[Your Address]
+[City, State ZIP]
+[Your Email]

+

{}

+

[Recipient Name]
+[Recipient Title]
+[Company/Organization]
+[Address]
+[City, State ZIP]

+

Dear [Recipient Name],

+

[Opening paragraph - State the purpose of your letter]

+

[Body paragraph(s) - Provide details, explanations, or supporting information]

+

[Closing paragraph - Summarize, request action, or express appreciation]

+

Sincerely,

+



[Your Signature]
+[Your Typed Name]

"#, + now.format("%B %d, %Y") + ); + + let _ = save_document_to_drive(&state, &user_identifier, &doc_id, &title, &content).await; + + Html(format_document_content(&doc_id, &title, &content)) +} + +// ============================================================================= +// AI HANDLERS +// ============================================================================= + +pub async fn handle_ai_summarize( + State(state): State>, + Json(payload): Json, +) -> impl IntoResponse { + let text = payload.selected_text.unwrap_or_default(); + + if text.is_empty() { + return Json(AiResponse { + result: None, + content: Some("Please select some text to summarize.".to_string()), + error: None, + }); + } + + let system_prompt = "You are a helpful writing assistant. Summarize the following text concisely while preserving the key points. Provide only the summary without any preamble."; + + match call_llm(&state, system_prompt, &text).await { + Ok(summary) => Json(AiResponse { + result: Some(summary), + content: None, + error: None, + }), + Err(e) => { + error!("LLM summarize error: {}", e); + let word_count = text.split_whitespace().count(); + Json(AiResponse { + result: Some(format!("Summary of {} words: {}...", word_count, text.chars().take(100).collect::())), + content: None, + error: None, + }) + } + } +} + +pub async fn handle_ai_expand( + State(state): State>, + Json(payload): Json, +) -> impl IntoResponse { + let text = payload.selected_text.unwrap_or_default(); + + if text.is_empty() { + return Json(AiResponse { + result: None, + content: Some("Please select some text to expand.".to_string()), + error: None, + }); + } + + let system_prompt = "You are a helpful writing assistant. Expand on the following text by adding more details, examples, and explanations. 
Maintain the original tone and style."; + + match call_llm(&state, system_prompt, &text).await { + Ok(expanded) => Json(AiResponse { + result: Some(expanded), + content: None, + error: None, + }), + Err(e) => { + error!("LLM expand error: {}", e); + Json(AiResponse { + result: None, + content: None, + error: Some("AI processing failed".to_string()), + }) + } + } +} + +pub async fn handle_ai_improve( + State(state): State>, + Json(payload): Json, +) -> impl IntoResponse { + let text = payload.selected_text.unwrap_or_default(); + + if text.is_empty() { + return Json(AiResponse { + result: None, + content: Some("Please select some text to improve.".to_string()), + error: None, + }); + } + + let system_prompt = "You are a helpful writing assistant. Improve the following text by enhancing clarity, grammar, and style while preserving the original meaning."; + + match call_llm(&state, system_prompt, &text).await { + Ok(improved) => Json(AiResponse { + result: Some(improved), + content: None, + error: None, + }), + Err(e) => { + error!("LLM improve error: {}", e); + Json(AiResponse { + result: None, + content: None, + error: Some("AI processing failed".to_string()), + }) + } + } +} + +pub async fn handle_ai_simplify( + State(state): State>, + Json(payload): Json, +) -> impl IntoResponse { + let text = payload.selected_text.unwrap_or_default(); + + if text.is_empty() { + return Json(AiResponse { + result: None, + content: Some("Please select some text to simplify.".to_string()), + error: None, + }); + } + + let system_prompt = "You are a helpful writing assistant. 
Simplify the following text to make it easier to understand while keeping the essential meaning."; + + match call_llm(&state, system_prompt, &text).await { + Ok(simplified) => Json(AiResponse { + result: Some(simplified), + content: None, + error: None, + }), + Err(e) => { + error!("LLM simplify error: {}", e); + Json(AiResponse { + result: None, + content: None, + error: Some("AI processing failed".to_string()), + }) + } + } +} + +pub async fn handle_ai_translate( + State(state): State>, + Json(payload): Json, +) -> impl IntoResponse { + let text = payload.selected_text.unwrap_or_default(); + let target_lang = payload.translate_lang.unwrap_or_else(|| "English".to_string()); + + if text.is_empty() { + return Json(AiResponse { + result: None, + content: Some("Please select some text to translate.".to_string()), + error: None, + }); + } + + let system_prompt = format!("You are a translator. Translate the following text to {}. Provide only the translation without any preamble.", target_lang); + + match call_llm(&state, &system_prompt, &text).await { + Ok(translated) => Json(AiResponse { + result: Some(translated), + content: None, + error: None, + }), + Err(e) => { + error!("LLM translate error: {}", e); + Json(AiResponse { + result: None, + content: None, + error: Some("AI processing failed".to_string()), + }) + } + } +} + +pub async fn handle_ai_custom( + State(state): State>, + Json(payload): Json, +) -> impl IntoResponse { + let text = payload.selected_text.unwrap_or_default(); + let prompt = payload.prompt.unwrap_or_default(); + + if prompt.is_empty() { + return Json(AiResponse { + result: None, + content: Some("Please provide a prompt.".to_string()), + error: None, + }); + } + + let system_prompt = format!("You are a helpful writing assistant. 
{}", prompt); + + match call_llm(&state, &system_prompt, &text).await { + Ok(result) => Json(AiResponse { + result: Some(result), + content: None, + error: None, + }), + Err(e) => { + error!("LLM custom error: {}", e); + Json(AiResponse { + result: None, + content: None, + error: Some("AI processing failed".to_string()), + }) + } + } +} + +// ============================================================================= +// EXPORT HANDLERS +// ============================================================================= + +pub async fn handle_export_pdf( + State(_state): State>, + Query(_params): Query, +) -> impl IntoResponse { + // PDF export would require a library like printpdf or wkhtmltopdf + Html("

PDF export coming soon

".to_string()) +} + +pub async fn handle_export_docx( + State(_state): State>, + Query(_params): Query, +) -> impl IntoResponse { + // DOCX export would require a library like docx-rs + Html("

DOCX export coming soon

".to_string()) +} + +pub async fn handle_export_md( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(_) => return Html("Authentication required".to_string()), + }; + + let doc_id = match params.id { + Some(id) => id, + None => return Html("Document ID required".to_string()), + }; + + match load_document_from_drive(&state, &user_identifier, &doc_id).await { + Ok(Some(doc)) => { + let md = html_to_markdown(&doc.content); + Html(md) + } + _ => Html("Document not found".to_string()), + } +} + +pub async fn handle_export_html( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(_) => return Html("Authentication required".to_string()), + }; + + let doc_id = match params.id { + Some(id) => id, + None => return Html("Document ID required".to_string()), + }; + + match load_document_from_drive(&state, &user_identifier, &doc_id).await { + Ok(Some(doc)) => { + let html = format!( + r#" + + + +{} + + + +{} + +"#, + html_escape(&doc.title), + doc.content + ); + Html(html) + } + _ => Html("Document not found".to_string()), + } +} + +pub async fn handle_export_txt( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(_) => return Html("Authentication required".to_string()), + }; + + let doc_id = match params.id { + Some(id) => id, + None => return Html("Document ID required".to_string()), + }; + + match load_document_from_drive(&state, &user_identifier, &doc_id).await { + Ok(Some(doc)) => { + let txt = strip_html(&doc.content); + Html(txt) + } + _ => Html("Document not found".to_string()), + } +} + +// 
============================================================================= +// WEBSOCKET COLLABORATION +// ============================================================================= + +pub async fn handle_docs_websocket( + ws: WebSocketUpgrade, + State(state): State>, + Path(doc_id): Path, +) -> impl IntoResponse { + ws.on_upgrade(move |socket| handle_docs_connection(socket, state, doc_id)) +} + +async fn handle_docs_connection(socket: WebSocket, _state: Arc, doc_id: String) { + let (mut sender, mut receiver) = socket.split(); + let channels = get_collab_channels(); + + // Get or create channel for this document + let rx = { + let mut channels_write = channels.write().await; + let tx = channels_write + .entry(doc_id.clone()) + .or_insert_with(|| broadcast::channel(100).0); + tx.subscribe() + }; + + let user_id = Uuid::new_v4().to_string(); + let user_color = get_random_color(); + + // Send join message + { + let channels_read = channels.read().await; + if let Some(tx) = channels_read.get(&doc_id) { + let msg = CollabMessage { + msg_type: "join".to_string(), + doc_id: doc_id.clone(), + user_id: user_id.clone(), + user_name: format!("User {}", &user_id[..8]), + user_color: user_color.clone(), + position: None, + length: None, + content: None, + format: None, + timestamp: Utc::now(), + }; + let _ = tx.send(msg); + } + } + + // Spawn task to forward broadcast messages to this client + let mut rx = rx; + let user_id_clone = user_id.clone(); + let send_task = tokio::spawn(async move { + while let Ok(msg) = rx.recv().await { + // Don't send messages back to the sender + if msg.user_id != user_id_clone { + if let Ok(json) = serde_json::to_string(&msg) { + if sender.send(Message::Text(json)).await.is_err() { + break; + } + } + } + } + }); + + // Handle incoming messages + let channels_clone = channels.clone(); + let doc_id_clone = doc_id.clone(); + let user_id_clone2 = user_id.clone(); + let user_color_clone = user_color.clone(); + + while let Some(Ok(msg)) = 
receiver.next().await { + if let Message::Text(text) = msg { + if let Ok(mut collab_msg) = serde_json::from_str::(&text) { + collab_msg.user_id = user_id_clone2.clone(); + collab_msg.user_color = user_color_clone.clone(); + collab_msg.timestamp = Utc::now(); + + let channels_read = channels_clone.read().await; + if let Some(tx) = channels_read.get(&doc_id_clone) { + let _ = tx.send(collab_msg); + } + } + } + } + + // Send leave message + { + let channels_read = channels.read().await; + if let Some(tx) = channels_read.get(&doc_id) { + let msg = CollabMessage { + msg_type: "leave".to_string(), + doc_id: doc_id.clone(), + user_id: user_id.clone(), + user_name: format!("User {}", &user_id[..8]), + user_color, + position: None, + length: None, + content: None, + format: None, + timestamp: Utc::now(), + }; + let _ = tx.send(msg); + } + } + + send_task.abort(); + info!("WebSocket connection closed for doc {}", doc_id); +} + +fn get_random_color() -> String { + let colors = [ + "#4285f4", "#ea4335", "#fbbc05", "#34a853", + "#ff6d01", "#46bdc6", "#7b1fa2", "#c2185b", + ]; + let mut rng = rand::rng(); + let idx = rng.random_range(0..colors.len()); + colors[idx].to_string() +} + +// ============================================================================= +// FORMATTING HELPERS +// ============================================================================= + +fn format_document_list_item(id: &str, title: &str, time: &str, is_new: bool) -> String { + let new_class = if is_new { " new-item" } else { "" }; + let mut html = String::new(); + html.push_str("
"); + html.push_str("
📄
"); + html.push_str("
"); + html.push_str(""); + html.push_str(&html_escape(title)); + html.push_str(""); + html.push_str(""); + html.push_str(&html_escape(time)); + html.push_str(""); + html.push_str("
"); + html.push_str("
"); + html +} + +fn format_document_content(id: &str, title: &str, content: &str) -> String { + let mut html = String::new(); + html.push_str("
"); + html.push_str("
"); + html.push_str(&html_escape(title)); + html.push_str("
"); + html.push_str("
"); + html.push_str(content); + html.push_str("
"); + html.push_str("
"); + html.push_str(""); + html +} + +fn format_error(message: &str) -> String { + let mut html = String::new(); + html.push_str("
"); + html.push_str("⚠️"); + html.push_str(""); + html.push_str(&html_escape(message)); + html.push_str(""); + html.push_str("
"); + html +} + +fn format_relative_time(time: DateTime) -> String { + let now = Utc::now(); + let duration = now.signed_duration_since(time); + + if duration.num_seconds() < 60 { + "just now".to_string() + } else if duration.num_minutes() < 60 { + format!("{}m ago", duration.num_minutes()) + } else if duration.num_hours() < 24 { + format!("{}h ago", duration.num_hours()) + } else if duration.num_days() < 7 { + format!("{}d ago", duration.num_days()) + } else { + time.format("%b %d").to_string() + } +} + +fn html_escape(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) + .replace('\'', "'") +} + +fn strip_html(html: &str) -> String { + let mut result = String::new(); + let mut in_tag = false; + + for c in html.chars() { + match c { + '<' => in_tag = true, + '>' => { + in_tag = false; + result.push(' '); + } + _ if !in_tag => result.push(c), + _ => {} + } + } + + // Clean up whitespace + result + .split_whitespace() + .collect::>() + .join(" ") +} + +fn html_to_markdown(html: &str) -> String { + let mut md = html.to_string(); + + // Basic HTML to Markdown conversion + md = md.replace("

", "# ").replace("

", "\n\n"); + md = md.replace("

", "## ").replace("

", "\n\n"); + md = md.replace("

", "### ").replace("

", "\n\n"); + md = md.replace("

", "").replace("

", "\n\n"); + md = md.replace("
", "\n").replace("
", "\n").replace("
", "\n"); + md = md.replace("", "**").replace("", "**"); + md = md.replace("", "**").replace("", "**"); + md = md.replace("", "*").replace("", "*"); + md = md.replace("", "*").replace("", "*"); + md = md.replace("
    ", "").replace("
", "\n"); + md = md.replace("
    ", "").replace("
", "\n"); + md = md.replace("
  • ", "- ").replace("
  • ", "\n"); + md = md.replace("
    ", "\n---\n").replace("
    ", "\n---\n"); + + // Strip remaining HTML tags + strip_html(&md) +} diff --git a/src/drive/drive_monitor/mod.rs b/src/drive/drive_monitor/mod.rs index 7c841c20b..5ba7f864b 100644 --- a/src/drive/drive_monitor/mod.rs +++ b/src/drive/drive_monitor/mod.rs @@ -86,7 +86,7 @@ impl DriveMonitor { } pub async fn start_monitoring(&self) -> Result<(), Box> { - trace!("[PROFILE] start_monitoring ENTER"); + trace!("start_monitoring ENTER"); let start_mem = MemoryStats::current(); trace!("[DRIVE_MONITOR] Starting DriveMonitor for bot {}, RSS={}", self.bot_id, MemoryStats::format_bytes(start_mem.rss_bytes)); @@ -99,7 +99,7 @@ impl DriveMonitor { self.is_processing .store(true, std::sync::atomic::Ordering::SeqCst); - trace!("[PROFILE] start_monitoring: calling check_for_changes..."); + trace!("start_monitoring: calling check_for_changes..."); info!("[DRIVE_MONITOR] Calling initial check_for_changes..."); match self.check_for_changes().await { @@ -111,7 +111,7 @@ impl DriveMonitor { self.consecutive_failures.fetch_add(1, Ordering::Relaxed); } } - trace!("[PROFILE] start_monitoring: check_for_changes returned"); + trace!("start_monitoring: check_for_changes returned"); let after_initial = MemoryStats::current(); trace!("[DRIVE_MONITOR] After initial check, RSS={} (delta={})", @@ -215,38 +215,38 @@ impl DriveMonitor { }) } async fn check_for_changes(&self) -> Result<(), Box> { - trace!("[PROFILE] check_for_changes ENTER"); + trace!("check_for_changes ENTER"); let start_mem = MemoryStats::current(); trace!("[DRIVE_MONITOR] check_for_changes START, RSS={}", MemoryStats::format_bytes(start_mem.rss_bytes)); let Some(client) = &self.state.drive else { - trace!("[PROFILE] check_for_changes: no drive client, returning"); + trace!("check_for_changes: no drive client, returning"); return Ok(()); }; - trace!("[PROFILE] check_for_changes: calling check_gbdialog_changes..."); + trace!("check_for_changes: calling check_gbdialog_changes..."); trace!("[DRIVE_MONITOR] Checking gbdialog..."); 
self.check_gbdialog_changes(client).await?; - trace!("[PROFILE] check_for_changes: check_gbdialog_changes done"); + trace!("check_for_changes: check_gbdialog_changes done"); let after_dialog = MemoryStats::current(); trace!("[DRIVE_MONITOR] After gbdialog, RSS={} (delta={})", MemoryStats::format_bytes(after_dialog.rss_bytes), MemoryStats::format_bytes(after_dialog.rss_bytes.saturating_sub(start_mem.rss_bytes))); - trace!("[PROFILE] check_for_changes: calling check_gbot..."); + trace!("check_for_changes: calling check_gbot..."); trace!("[DRIVE_MONITOR] Checking gbot..."); self.check_gbot(client).await?; - trace!("[PROFILE] check_for_changes: check_gbot done"); + trace!("check_for_changes: check_gbot done"); let after_gbot = MemoryStats::current(); trace!("[DRIVE_MONITOR] After gbot, RSS={} (delta={})", MemoryStats::format_bytes(after_gbot.rss_bytes), MemoryStats::format_bytes(after_gbot.rss_bytes.saturating_sub(after_dialog.rss_bytes))); - trace!("[PROFILE] check_for_changes: calling check_gbkb_changes..."); + trace!("check_for_changes: calling check_gbkb_changes..."); trace!("[DRIVE_MONITOR] Checking gbkb..."); self.check_gbkb_changes(client).await?; - trace!("[PROFILE] check_for_changes: check_gbkb_changes done"); + trace!("check_for_changes: check_gbkb_changes done"); let after_gbkb = MemoryStats::current(); trace!("[DRIVE_MONITOR] After gbkb, RSS={} (delta={})", MemoryStats::format_bytes(after_gbkb.rss_bytes), @@ -260,7 +260,7 @@ impl DriveMonitor { MemoryStats::format_bytes(total_delta)); } - trace!("[PROFILE] check_for_changes EXIT"); + trace!("check_for_changes EXIT"); Ok(()) } async fn check_gbdialog_changes( @@ -335,7 +335,7 @@ impl DriveMonitor { Ok(()) } async fn check_gbot(&self, client: &Client) -> Result<(), Box> { - trace!("[PROFILE] check_gbot ENTER"); + trace!("check_gbot ENTER"); let config_manager = ConfigManager::new(self.state.conn.clone()); debug!("check_gbot: Checking bucket {} for config.csv changes", self.bucket_name); let mut 
continuation_token = None; @@ -481,7 +481,7 @@ impl DriveMonitor { } continuation_token = list_objects.next_continuation_token; } - trace!("[PROFILE] check_gbot EXIT"); + trace!("check_gbot EXIT"); Ok(()) } async fn broadcast_theme_change( @@ -616,7 +616,7 @@ impl DriveMonitor { &self, client: &Client, ) -> Result<(), Box> { - trace!("[PROFILE] check_gbkb_changes ENTER"); + trace!("check_gbkb_changes ENTER"); let bot_name = self .bucket_name .strip_suffix(".gbai") @@ -850,7 +850,7 @@ impl DriveMonitor { } } - trace!("[PROFILE] check_gbkb_changes EXIT"); + trace!("check_gbkb_changes EXIT"); Ok(()) } diff --git a/src/email/mod.rs b/src/email/mod.rs index 07a89670d..cb271c7fa 100644 --- a/src/email/mod.rs +++ b/src/email/mod.rs @@ -2,7 +2,7 @@ use crate::{config::EmailConfig, core::urls::ApiUrls, shared::state::AppState}; use axum::{ extract::{Path, Query, State}, http::StatusCode, - response::{Html, IntoResponse, Response}, + response::{IntoResponse, Response}, Json, }; use axum::{ @@ -129,8 +129,8 @@ pub fn configure() -> Router> { .route(ApiUrls::EMAIL_LIST_HTMX, get(list_emails_htmx)) .route(ApiUrls::EMAIL_FOLDERS_HTMX, get(list_folders_htmx)) .route(ApiUrls::EMAIL_COMPOSE_HTMX, get(compose_email_htmx)) - .route(&ApiUrls::EMAIL_CONTENT_HTMX.replace(":id", "{id}"), get(get_email_content_htmx)) - .route("/api/ui/email/{id}/delete", delete(delete_email_htmx)) + .route(ApiUrls::EMAIL_CONTENT_HTMX, get(get_email_content_htmx)) + .route("/api/ui/email/:id/delete", delete(delete_email_htmx)) .route(ApiUrls::EMAIL_LABELS_HTMX, get(list_labels_htmx)) .route(ApiUrls::EMAIL_TEMPLATES_HTMX, get(list_templates_htmx)) .route(ApiUrls::EMAIL_SIGNATURES_HTMX, get(list_signatures_htmx)) diff --git a/src/email/vectordb.rs b/src/email/vectordb.rs index cb0a503d3..0b838658f 100644 --- a/src/email/vectordb.rs +++ b/src/email/vectordb.rs @@ -2,11 +2,12 @@ use anyhow::Result; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::path::PathBuf; -use std::sync::Arc; 
#[cfg(not(feature = "vectordb"))] use tokio::fs; use uuid::Uuid; +#[cfg(feature = "vectordb")] +use std::sync::Arc; #[cfg(feature = "vectordb")] use qdrant_client::{ qdrant::{Distance, PointStruct, VectorParams}, @@ -111,7 +112,19 @@ impl UserEmailVectorDB { #[cfg(not(feature = "vectordb"))] pub async fn initialize(&mut self, _qdrant_url: &str) -> Result<()> { - log::warn!("Vector DB feature not enabled, using fallback storage"); + log::warn!( + "Vector DB feature not enabled for user={} bot={}, using fallback storage at {}", + self.user_id, + self.bot_id, + self.db_path.display() + ); + std::fs::create_dir_all(&self.db_path)?; + let metadata_path = self.db_path.join(format!("{}.meta", self.collection_name)); + let metadata = format!( + "{{\"user_id\":\"{}\",\"bot_id\":\"{}\",\"collection\":\"{}\"}}", + self.user_id, self.bot_id, self.collection_name + ); + std::fs::write(metadata_path, metadata)?; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index c139c28ed..62565ca59 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,8 +6,11 @@ pub mod security; pub mod analytics; pub mod designer; +pub mod docs; pub mod paper; pub mod research; +pub mod sheet; +pub mod slides; pub mod sources; pub use core::shared; @@ -94,6 +97,9 @@ pub mod weba; #[cfg(feature = "whatsapp")] pub mod whatsapp; +#[cfg(feature = "telegram")] +pub mod telegram; + #[cfg(test)] mod tests { use super::*; diff --git a/src/llm/local.rs b/src/llm/local.rs index a8b026671..dcf32e958 100644 --- a/src/llm/local.rs +++ b/src/llm/local.rs @@ -13,14 +13,14 @@ use tokio; pub async fn ensure_llama_servers_running( app_state: Arc, ) -> Result<(), Box> { - trace!("[PROFILE] ensure_llama_servers_running ENTER"); + trace!("ensure_llama_servers_running ENTER"); let start_mem = MemoryStats::current(); trace!("[LLM_LOCAL] ensure_llama_servers_running START, RSS={}", MemoryStats::format_bytes(start_mem.rss_bytes)); log_jemalloc_stats(); if std::env::var("SKIP_LLM_SERVER").is_ok() { - trace!("[PROFILE] SKIP_LLM_SERVER set, 
returning early"); + trace!("SKIP_LLM_SERVER set, returning early"); info!("SKIP_LLM_SERVER set - skipping local LLM server startup (using mock/external LLM)"); return Ok(()); } @@ -83,7 +83,7 @@ pub async fn ensure_llama_servers_running( info!(" Embedding Model: {embedding_model}"); info!(" LLM Server Path: {llm_server_path}"); info!("Restarting any existing llama-server processes..."); - trace!("[PROFILE] About to pkill llama-server..."); + trace!("About to pkill llama-server..."); let before_pkill = MemoryStats::current(); trace!("[LLM_LOCAL] Before pkill, RSS={}", MemoryStats::format_bytes(before_pkill.rss_bytes)); @@ -97,7 +97,7 @@ pub async fn ensure_llama_servers_running( tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; info!("Existing llama-server processes terminated (if any)"); } - trace!("[PROFILE] pkill done"); + trace!("pkill done"); let after_pkill = MemoryStats::current(); trace!("[LLM_LOCAL] After pkill, RSS={} (delta={})", @@ -153,7 +153,7 @@ pub async fn ensure_llama_servers_running( task.await??; } info!("Waiting for servers to become ready..."); - trace!("[PROFILE] Starting wait loop for servers..."); + trace!("Starting wait loop for servers..."); let before_wait = MemoryStats::current(); trace!("[LLM_LOCAL] Before wait loop, RSS={}", MemoryStats::format_bytes(before_wait.rss_bytes)); @@ -162,7 +162,7 @@ pub async fn ensure_llama_servers_running( let mut attempts = 0; let max_attempts = 120; while attempts < max_attempts && (!llm_ready || !embedding_ready) { - trace!("[PROFILE] Wait loop iteration {}", attempts); + trace!("Wait loop iteration {}", attempts); tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; if attempts % 5 == 0 { @@ -221,7 +221,7 @@ pub async fn ensure_llama_servers_running( if !embedding_model.is_empty() { set_embedding_server_ready(true); } - trace!("[PROFILE] Servers ready!"); + trace!("Servers ready!"); let after_ready = MemoryStats::current(); trace!("[LLM_LOCAL] Servers ready, RSS={} (delta from 
start={})", @@ -240,7 +240,7 @@ pub async fn ensure_llama_servers_running( MemoryStats::format_bytes(end_mem.rss_bytes.saturating_sub(start_mem.rss_bytes))); log_jemalloc_stats(); - trace!("[PROFILE] ensure_llama_servers_running EXIT OK"); + trace!("ensure_llama_servers_running EXIT OK"); Ok(()) } else { let mut error_msg = "Servers failed to start within timeout:".to_string(); diff --git a/src/main.rs b/src/main.rs index bd4d5161f..fc804d8a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,7 +26,20 @@ use tower_http::trace::TraceLayer; async fn ensure_vendor_files_in_minio(drive: &aws_sdk_s3::Client) { use aws_sdk_s3::primitives::ByteStream; - let htmx_content = include_bytes!("../botserver-stack/static/js/vendor/htmx.min.js"); + let htmx_paths = [ + "./botui/ui/suite/js/vendor/htmx.min.js", + "../botui/ui/suite/js/vendor/htmx.min.js", + ]; + + let htmx_content = htmx_paths + .iter() + .find_map(|path| std::fs::read(path).ok()); + + let Some(content) = htmx_content else { + warn!("Could not find htmx.min.js in botui, skipping MinIO upload"); + return; + }; + let bucket = "default.gbai"; let key = "default.gblib/vendor/htmx.min.js"; @@ -34,7 +47,7 @@ async fn ensure_vendor_files_in_minio(drive: &aws_sdk_s3::Client) { .put_object() .bucket(bucket) .key(key) - .body(ByteStream::from_static(htmx_content)) + .body(ByteStream::from(content)) .content_type("application/javascript") .send() .await @@ -222,20 +235,28 @@ async fn run_axum_server( .add_public_path("/apps")); // Apps are public - no auth required use crate::core::urls::ApiUrls; + use crate::core::product::{PRODUCT_CONFIG, get_product_config_json}; + + // Initialize product configuration + { + let config = PRODUCT_CONFIG.read().expect("Failed to read product config"); + info!("Product: {} | Theme: {} | Apps: {:?}", + config.name, config.theme, config.get_enabled_apps()); + } + + // Product config endpoint + async fn get_product_config() -> Json { + Json(get_product_config_json()) + } let mut api_router = 
Router::new() .route("/health", get(health_check_simple)) .route(ApiUrls::HEALTH, get(health_check)) + .route("/api/product", get(get_product_config)) .route(ApiUrls::SESSIONS, post(create_session)) .route(ApiUrls::SESSIONS, get(get_sessions)) - .route( - &ApiUrls::SESSION_HISTORY.replace(":id", "{session_id}"), - get(get_session_history), - ) - .route( - &ApiUrls::SESSION_START.replace(":id", "{session_id}"), - post(start_session), - ) + .route(ApiUrls::SESSION_HISTORY, get(get_session_history)) + .route(ApiUrls::SESSION_START, post(start_session)) .route(ApiUrls::WS, get(websocket_handler)) .merge(botserver::drive::configure()); @@ -244,7 +265,8 @@ async fn run_axum_server( api_router = api_router .route(ApiUrls::AUTH, get(auth_handler)) .merge(crate::core::directory::api::configure_user_routes()) - .merge(crate::directory::router::configure()); + .merge(crate::directory::router::configure()) + .merge(crate::directory::auth_routes::configure()); } #[cfg(feature = "meet")] @@ -280,7 +302,11 @@ async fn run_axum_server( } api_router = api_router.merge(botserver::analytics::configure_analytics_routes()); + api_router = api_router.merge(crate::core::i18n::configure_i18n_routes()); + api_router = api_router.merge(botserver::docs::configure_docs_routes()); api_router = api_router.merge(botserver::paper::configure_paper_routes()); + api_router = api_router.merge(botserver::sheet::configure_sheet_routes()); + api_router = api_router.merge(botserver::slides::configure_slides_routes()); api_router = api_router.merge(botserver::research::configure_research_routes()); api_router = api_router.merge(botserver::sources::configure_sources_routes()); api_router = api_router.merge(botserver::designer::configure_designer_routes()); @@ -289,12 +315,18 @@ async fn run_axum_server( api_router = api_router.merge(botserver::basic::keywords::configure_db_routes()); api_router = api_router.merge(botserver::basic::keywords::configure_app_server_routes()); api_router = 
api_router.merge(botserver::auto_task::configure_autotask_routes()); + api_router = api_router.merge(crate::core::shared::admin::configure()); #[cfg(feature = "whatsapp")] { api_router = api_router.merge(crate::whatsapp::configure()); } + #[cfg(feature = "telegram")] + { + api_router = api_router.merge(botserver::telegram::configure()); + } + #[cfg(feature = "attendance")] { api_router = api_router.merge(crate::attendance::configure_attendance_routes()); @@ -331,6 +363,10 @@ async fn run_axum_server( info!("Security middleware enabled: rate limiting, security headers, panic handler, request ID tracking, authentication"); + // Path to UI files (botui) + let ui_path = std::env::var("BOTUI_PATH").unwrap_or_else(|_| "./botui/ui/suite".to_string()); + info!("Serving UI from: {}", ui_path); + let app = Router::new() .merge(api_router.with_state(app_state.clone())) // Authentication middleware for protected routes @@ -338,6 +374,8 @@ async fn run_axum_server( auth_config.clone(), auth_middleware, )) + // Serve auth UI pages + .nest_service("/auth", ServeDir::new(format!("{}/auth", ui_path))) // Static files fallback for legacy /apps/* paths .nest_service("/static", ServeDir::new(&site_path)) // Security middleware stack (order matters - first added is outermost) @@ -486,6 +524,21 @@ async fn main() -> std::io::Result<()> { println!("Starting General Bots {}...", env!("CARGO_PKG_VERSION")); } + let locales_path = if std::path::Path::new("./locales").exists() { + "./locales" + } else if std::path::Path::new("../botlib/locales").exists() { + "../botlib/locales" + } else if std::path::Path::new("../locales").exists() { + "../locales" + } else { + "./locales" + }; + if let Err(e) = crate::core::i18n::init_i18n(locales_path) { + warn!("Failed to initialize i18n from {}: {}. 
Translations will show keys.", locales_path, e); + } else { + info!("i18n initialized from {} with locales: {:?}", locales_path, crate::core::i18n::available_locales()); + } + let (progress_tx, _progress_rx) = tokio::sync::mpsc::unbounded_channel::(); let (state_tx, _state_rx) = tokio::sync::mpsc::channel::>(1); @@ -760,20 +813,100 @@ async fn main() -> std::io::Result<()> { ))); #[cfg(feature = "directory")] - let zitadel_config = botserver::directory::client::ZitadelConfig { - issuer_url: "https://localhost:8080".to_string(), - issuer: "https://localhost:8080".to_string(), - client_id: "client_id".to_string(), - client_secret: "client_secret".to_string(), - redirect_uri: "https://localhost:8080/callback".to_string(), - project_id: "default".to_string(), - api_url: "https://localhost:8080".to_string(), - service_account_key: None, + let zitadel_config = { + // Try to load from directory_config.json first + let config_path = "./config/directory_config.json"; + if let Ok(content) = std::fs::read_to_string(config_path) { + if let Ok(json) = serde_json::from_str::(&content) { + let base_url = json.get("base_url") + .and_then(|v| v.as_str()) + .unwrap_or("http://localhost:8300"); + let client_id = json.get("client_id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let client_secret = json.get("client_secret") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + info!("Loaded Zitadel config from {}: url={}", config_path, base_url); + + botserver::directory::client::ZitadelConfig { + issuer_url: base_url.to_string(), + issuer: base_url.to_string(), + client_id: client_id.to_string(), + client_secret: client_secret.to_string(), + redirect_uri: format!("{}/callback", base_url), + project_id: "default".to_string(), + api_url: base_url.to_string(), + service_account_key: None, + } + } else { + warn!("Failed to parse directory_config.json, using defaults"); + botserver::directory::client::ZitadelConfig { + issuer_url: "http://localhost:8300".to_string(), + issuer: 
"http://localhost:8300".to_string(), + client_id: String::new(), + client_secret: String::new(), + redirect_uri: "http://localhost:8300/callback".to_string(), + project_id: "default".to_string(), + api_url: "http://localhost:8300".to_string(), + service_account_key: None, + } + } + } else { + warn!("directory_config.json not found, using default Zitadel config"); + botserver::directory::client::ZitadelConfig { + issuer_url: "http://localhost:8300".to_string(), + issuer: "http://localhost:8300".to_string(), + client_id: String::new(), + client_secret: String::new(), + redirect_uri: "http://localhost:8300/callback".to_string(), + project_id: "default".to_string(), + api_url: "http://localhost:8300".to_string(), + service_account_key: None, + } + } }; #[cfg(feature = "directory")] let auth_service = Arc::new(tokio::sync::Mutex::new( - botserver::directory::AuthService::new(zitadel_config).map_err(|e| std::io::Error::other(format!("Failed to create auth service: {}", e)))?, + botserver::directory::AuthService::new(zitadel_config.clone()).map_err(|e| std::io::Error::other(format!("Failed to create auth service: {}", e)))?, )); + + #[cfg(feature = "directory")] + { + let pat_path = std::path::Path::new("./botserver-stack/conf/directory/admin-pat.txt"); + let bootstrap_client = if pat_path.exists() { + match std::fs::read_to_string(pat_path) { + Ok(pat_token) => { + let pat_token = pat_token.trim().to_string(); + info!("Using admin PAT token for bootstrap authentication"); + botserver::directory::client::ZitadelClient::with_pat_token(zitadel_config, pat_token) + .map_err(|e| std::io::Error::other(format!("Failed to create bootstrap client with PAT: {}", e)))? + } + Err(e) => { + warn!("Failed to read admin PAT token: {}, falling back to OAuth2", e); + botserver::directory::client::ZitadelClient::new(zitadel_config) + .map_err(|e| std::io::Error::other(format!("Failed to create bootstrap client: {}", e)))? 
+ } + } + } else { + info!("Admin PAT not found, using OAuth2 client credentials for bootstrap"); + botserver::directory::client::ZitadelClient::new(zitadel_config) + .map_err(|e| std::io::Error::other(format!("Failed to create bootstrap client: {}", e)))? + }; + + match botserver::directory::bootstrap::check_and_bootstrap_admin(&bootstrap_client).await { + Ok(Some(_)) => { + info!("Bootstrap completed - admin credentials displayed in console"); + } + Ok(None) => { + info!("Admin user exists, bootstrap skipped"); + } + Err(e) => { + warn!("Bootstrap check failed (Zitadel may not be ready): {}", e); + } + } + } let config_manager = ConfigManager::new(pool.clone()); let mut bot_conn = pool.get().map_err(|e| std::io::Error::other(format!("Failed to get database connection: {}", e)))?; @@ -961,18 +1094,18 @@ async fn main() -> std::io::Result<()> { let monitor_bot_id = default_bot_id; tokio::spawn(async move { register_thread("drive-monitor", "drive"); - trace!("[PROFILE] DriveMonitor::new starting..."); + trace!("DriveMonitor::new starting..."); let monitor = botserver::DriveMonitor::new( drive_monitor_state, bucket_name.clone(), monitor_bot_id, ); - trace!("[PROFILE] DriveMonitor::new done, calling start_monitoring..."); + trace!("DriveMonitor::new done, calling start_monitoring..."); info!("Starting DriveMonitor for bucket: {}", bucket_name); if let Err(e) = monitor.start_monitoring().await { error!("DriveMonitor failed: {}", e); } - trace!("[PROFILE] DriveMonitor start_monitoring returned"); + trace!("DriveMonitor start_monitoring returned"); }); } @@ -994,23 +1127,23 @@ async fn main() -> std::io::Result<()> { let app_state_for_llm = app_state.clone(); tokio::spawn(async move { register_thread("llm-server-init", "llm"); - eprintln!("[PROFILE] ensure_llama_servers_running starting..."); + trace!("ensure_llama_servers_running starting..."); if let Err(e) = ensure_llama_servers_running(app_state_for_llm).await { error!("Failed to start LLM servers: {}", e); } - 
eprintln!("[PROFILE] ensure_llama_servers_running completed"); + trace!("ensure_llama_servers_running completed"); record_thread_activity("llm-server-init"); }); trace!("Initial data setup task spawned"); - eprintln!("[PROFILE] All background tasks spawned, starting HTTP server..."); + trace!("All background tasks spawned, starting HTTP server..."); trace!("Starting HTTP server on port {}...", config.server.port); - eprintln!("[PROFILE] run_axum_server starting on port {}...", config.server.port); + info!("Starting HTTP server on port {}...", config.server.port); if let Err(e) = run_axum_server(app_state, config.server.port, worker_count).await { error!("Failed to start HTTP server: {}", e); std::process::exit(1); } - eprintln!("[PROFILE] run_axum_server returned (should not happen normally)"); + trace!("run_axum_server returned (should not happen normally)"); if let Some(handle) = ui_handle { handle.join().ok(); diff --git a/src/meet/mod.rs b/src/meet/mod.rs index 2b3f19953..3dd48c084 100644 --- a/src/meet/mod.rs +++ b/src/meet/mod.rs @@ -26,18 +26,9 @@ pub fn configure() -> Router> { .route(ApiUrls::MEET_PARTICIPANTS, get(all_participants)) .route(ApiUrls::MEET_RECENT, get(recent_meetings)) .route(ApiUrls::MEET_SCHEDULED, get(scheduled_meetings)) - .route( - &ApiUrls::MEET_ROOM_BY_ID.replace(":id", "{room_id}"), - get(get_room), - ) - .route( - &ApiUrls::MEET_JOIN.replace(":id", "{room_id}"), - post(join_room), - ) - .route( - &ApiUrls::MEET_TRANSCRIPTION.replace(":id", "{room_id}"), - post(start_transcription), - ) + .route(ApiUrls::MEET_ROOM_BY_ID, get(get_room)) + .route(ApiUrls::MEET_JOIN, post(join_room)) + .route(ApiUrls::MEET_TRANSCRIPTION, post(start_transcription)) .route(ApiUrls::MEET_TOKEN, post(get_meeting_token)) .route(ApiUrls::MEET_INVITE, post(send_meeting_invites)) .route(ApiUrls::WS_MEET, get(meeting_websocket)) @@ -46,7 +37,7 @@ pub fn configure() -> Router> { post(conversations::create_conversation), ) .route( - "/conversations/{id}/join", + 
"/conversations/:id/join", post(conversations::join_conversation), ) .route( diff --git a/src/paper/mod.rs b/src/paper/mod.rs index 28914d21a..823b022e5 100644 --- a/src/paper/mod.rs +++ b/src/paper/mod.rs @@ -87,8 +87,8 @@ pub fn configure_paper_routes() -> Router> { .route(ApiUrls::PAPER_SEARCH, get(handle_search_documents)) .route(ApiUrls::PAPER_SAVE, post(handle_save_document)) .route(ApiUrls::PAPER_AUTOSAVE, post(handle_autosave)) - .route(&ApiUrls::PAPER_BY_ID.replace(":id", "{id}"), get(handle_get_document)) - .route(&ApiUrls::PAPER_DELETE.replace(":id", "{id}"), post(handle_delete_document)) + .route(ApiUrls::PAPER_BY_ID, get(handle_get_document)) + .route(ApiUrls::PAPER_DELETE, post(handle_delete_document)) .route(ApiUrls::PAPER_TEMPLATE_BLANK, post(handle_template_blank)) .route(ApiUrls::PAPER_TEMPLATE_MEETING, post(handle_template_meeting)) .route(ApiUrls::PAPER_TEMPLATE_TODO, post(handle_template_todo)) @@ -96,6 +96,8 @@ pub fn configure_paper_routes() -> Router> { ApiUrls::PAPER_TEMPLATE_RESEARCH, post(handle_template_research), ) + .route(ApiUrls::PAPER_TEMPLATE_REPORT, post(handle_template_report)) + .route(ApiUrls::PAPER_TEMPLATE_LETTER, post(handle_template_letter)) .route(ApiUrls::PAPER_AI_SUMMARIZE, post(handle_ai_summarize)) .route(ApiUrls::PAPER_AI_EXPAND, post(handle_ai_expand)) .route(ApiUrls::PAPER_AI_IMPROVE, post(handle_ai_improve)) @@ -876,6 +878,83 @@ pub async fn handle_template_research( Html(format_document_content(&title, &content)) } +pub async fn handle_template_report( + State(state): State>, + headers: HeaderMap, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + log::error!("Auth error: {}", e); + return Html(format_error("Authentication required")); + } + }; + + let doc_id = Uuid::new_v4().to_string(); + let title = "Report".to_string(); + let now = Utc::now(); + + let mut content = String::new(); + content.push_str("# Report\n\n"); + let 
_ = writeln!(content, "**Date:** {}\n", now.format("%Y-%m-%d")); + content.push_str("**Author:**\n\n"); + content.push_str("---\n\n"); + content.push_str("## Executive Summary\n\n\n\n"); + content.push_str("## Introduction\n\n\n\n"); + content.push_str("## Background\n\n\n\n"); + content.push_str("## Findings\n\n### Key Finding 1\n\n\n\n### Key Finding 2\n\n\n\n"); + content.push_str("## Analysis\n\n\n\n"); + content.push_str("## Recommendations\n\n1. \n2. \n3. \n\n"); + content.push_str("## Conclusion\n\n\n\n"); + content.push_str("## Appendix\n\n"); + + let _ = + save_document_to_drive(&state, &user_identifier, &doc_id, &title, &content, false).await; + + Html(format_document_content(&title, &content)) +} + +pub async fn handle_template_letter( + State(state): State>, + headers: HeaderMap, +) -> impl IntoResponse { + let (_user_id, user_identifier) = match get_current_user(&state, &headers).await { + Ok(u) => u, + Err(e) => { + log::error!("Auth error: {}", e); + return Html(format_error("Authentication required")); + } + }; + + let doc_id = Uuid::new_v4().to_string(); + let title = "Letter".to_string(); + let now = Utc::now(); + + let mut content = String::new(); + content.push_str("[Your Name]\n"); + content.push_str("[Your Address]\n"); + content.push_str("[City, State ZIP]\n"); + content.push_str("[Your Email]\n\n"); + let _ = writeln!(content, "{}\n", now.format("%B %d, %Y")); + content.push_str("[Recipient Name]\n"); + content.push_str("[Recipient Title]\n"); + content.push_str("[Company/Organization]\n"); + content.push_str("[Address]\n"); + content.push_str("[City, State ZIP]\n\n"); + content.push_str("Dear [Recipient Name],\n\n"); + content.push_str("[Opening paragraph - State the purpose of your letter]\n\n"); + content.push_str("[Body paragraph(s) - Provide details, explanations, or supporting information]\n\n"); + content.push_str("[Closing paragraph - Summarize, request action, or express appreciation]\n\n"); + content.push_str("Sincerely,\n\n\n"); + 
content.push_str("[Your Signature]\n"); + content.push_str("[Your Typed Name]\n"); + + let _ = + save_document_to_drive(&state, &user_identifier, &doc_id, &title, &content, false).await; + + Html(format_document_content(&title, &content)) +} + pub async fn handle_ai_summarize( State(state): State>, Json(payload): Json, diff --git a/src/research/mod.rs b/src/research/mod.rs index fa4d4094c..861b7b898 100644 --- a/src/research/mod.rs +++ b/src/research/mod.rs @@ -67,7 +67,7 @@ pub fn configure_research_routes() -> Router> { ApiUrls::RESEARCH_COLLECTIONS_NEW, post(handle_create_collection), ) - .route(&ApiUrls::RESEARCH_COLLECTION_BY_ID.replace(":id", "{id}"), get(handle_get_collection)) + .route(ApiUrls::RESEARCH_COLLECTION_BY_ID, get(handle_get_collection)) .route(ApiUrls::RESEARCH_SEARCH, post(handle_search)) .route(ApiUrls::RESEARCH_RECENT, get(handle_recent_searches)) .route(ApiUrls::RESEARCH_TRENDING, get(handle_trending_tags)) diff --git a/src/security/zitadel_auth.rs b/src/security/zitadel_auth.rs index 148d5df24..67c44b531 100644 --- a/src/security/zitadel_auth.rs +++ b/src/security/zitadel_auth.rs @@ -28,8 +28,8 @@ pub struct ZitadelAuthConfig { impl Default for ZitadelAuthConfig { fn default() -> Self { Self { - issuer_url: "https://localhost:8080".to_string(), - api_url: "https://localhost:8080".to_string(), + issuer_url: "http://localhost:8300".to_string(), + api_url: "http://localhost:8300".to_string(), client_id: String::new(), client_secret: String::new(), project_id: String::new(), diff --git a/src/sheet/mod.rs b/src/sheet/mod.rs new file mode 100644 index 000000000..7295ad652 --- /dev/null +++ b/src/sheet/mod.rs @@ -0,0 +1,2854 @@ +use crate::shared::state::AppState; +use axum::{ + extract::{ + ws::{Message, WebSocket, WebSocketUpgrade}, + Path, Query, State, + }, + http::StatusCode, + response::IntoResponse, + routing::{get, post}, + Json, Router, +}; +use chrono::{DateTime, Datelike, Local, NaiveDate, Utc}; +use futures_util::{SinkExt, StreamExt}; 
+use log::{error, info}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::broadcast; +use uuid::Uuid; + +type CollaborationChannels = + Arc>>>; + +static COLLAB_CHANNELS: std::sync::OnceLock = std::sync::OnceLock::new(); + +fn get_collab_channels() -> &'static CollaborationChannels { + COLLAB_CHANNELS.get_or_init(|| Arc::new(tokio::sync::RwLock::new(HashMap::new()))) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CollabMessage { + pub msg_type: String, + pub sheet_id: String, + pub user_id: String, + pub user_name: String, + pub user_color: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub row: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub col: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub worksheet_index: Option, + pub timestamp: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Collaborator { + pub id: String, + pub name: String, + pub color: String, + pub cursor_row: Option, + pub cursor_col: Option, + pub connected_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Spreadsheet { + pub id: String, + pub name: String, + pub owner_id: String, + pub worksheets: Vec, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Worksheet { + pub name: String, + pub data: HashMap, + #[serde(skip_serializing_if = "Option::is_none")] + pub column_widths: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub row_heights: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub frozen_rows: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frozen_cols: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub merged_cells: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub filters: Option>, + 
#[serde(skip_serializing_if = "Option::is_none")] + pub hidden_rows: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub validations: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub conditional_formats: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub charts: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CellData { + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub formula: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub style: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub note: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CellStyle { + #[serde(skip_serializing_if = "Option::is_none")] + pub font_family: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub font_size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub font_weight: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub font_style: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub text_decoration: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub color: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub background: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub text_align: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub vertical_align: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub border: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MergedCell { + pub start_row: u32, + pub start_col: u32, + pub end_row: u32, + pub end_col: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FilterConfig { + pub filter_type: String, + pub values: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub condition: Option, 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub value1: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub value2: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidationRule { + pub validation_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub operator: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub value1: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub value2: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub allowed_values: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_title: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_message: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_title: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConditionalFormatRule { + pub id: String, + pub start_row: u32, + pub start_col: u32, + pub end_row: u32, + pub end_col: u32, + pub rule_type: String, + pub condition: String, + pub style: CellStyle, + pub priority: i32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChartConfig { + pub id: String, + pub chart_type: String, + pub title: String, + pub data_range: String, + pub label_range: Option, + pub position: ChartPosition, + pub options: ChartOptions, + pub datasets: Vec, + pub labels: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChartPosition { + pub row: u32, + pub col: u32, + pub width: u32, + pub height: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChartOptions { + pub show_legend: bool, + pub show_grid: bool, + pub stacked: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub legend_position: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub x_axis_title: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub 
y_axis_title: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChartDataset { + pub label: String, + pub data: Vec, + pub color: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub background_color: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SpreadsheetMetadata { + pub id: String, + pub name: String, + pub owner_id: String, + pub created_at: DateTime, + pub updated_at: DateTime, + pub worksheet_count: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SaveRequest { + pub id: Option, + pub name: String, + pub worksheets: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoadQuery { + pub id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchQuery { + pub q: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CellUpdateRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub row: u32, + pub col: u32, + pub value: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FormatRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub start_row: u32, + pub start_col: u32, + pub end_row: u32, + pub end_col: u32, + pub style: CellStyle, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExportRequest { + pub id: String, + pub format: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ShareRequest { + pub sheet_id: String, + pub email: String, + pub permission: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SaveResponse { + pub id: String, + pub success: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FormulaResult { + pub value: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FormulaRequest { + pub sheet_id: 
String, + pub worksheet_index: usize, + pub formula: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MergeCellsRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub start_row: u32, + pub start_col: u32, + pub end_row: u32, + pub end_col: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FreezePanesRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub frozen_rows: u32, + pub frozen_cols: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SortRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub start_row: u32, + pub start_col: u32, + pub end_row: u32, + pub end_col: u32, + pub sort_col: u32, + pub ascending: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FilterRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub col: u32, + pub filter_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub values: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub condition: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub value1: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub value2: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChartRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub chart_type: String, + pub data_range: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub label_range: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub position: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConditionalFormatRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub start_row: u32, + pub start_col: u32, + pub end_row: u32, + pub end_col: u32, + pub rule_type: String, + pub condition: String, + pub style: CellStyle, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub 
struct DataValidationRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub start_row: u32, + pub start_col: u32, + pub end_row: u32, + pub end_col: u32, + pub validation_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub operator: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub value1: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub value2: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub allowed_values: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidateCellRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub row: u32, + pub col: u32, + pub value: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidationResult { + pub valid: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClearFilterRequest { + pub sheet_id: String, + pub worksheet_index: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub col: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeleteChartRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub chart_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AddNoteRequest { + pub sheet_id: String, + pub worksheet_index: usize, + pub row: u32, + pub col: u32, + pub note: String, +} + +pub fn configure_sheet_routes() -> Router> { + Router::new() + .route("/api/sheet/list", get(handle_list_sheets)) + .route("/api/sheet/search", get(handle_search_sheets)) + .route("/api/sheet/load", get(handle_load_sheet)) + .route("/api/sheet/save", post(handle_save_sheet)) + .route("/api/sheet/delete", post(handle_delete_sheet)) + .route("/api/sheet/cell", post(handle_update_cell)) + .route("/api/sheet/format", post(handle_format_cells)) + 
.route("/api/sheet/formula", post(handle_evaluate_formula)) + .route("/api/sheet/export", post(handle_export_sheet)) + .route("/api/sheet/share", post(handle_share_sheet)) + .route("/api/sheet/new", get(handle_new_sheet)) + .route("/api/sheet/merge", post(handle_merge_cells)) + .route("/api/sheet/unmerge", post(handle_unmerge_cells)) + .route("/api/sheet/freeze", post(handle_freeze_panes)) + .route("/api/sheet/sort", post(handle_sort_range)) + .route("/api/sheet/filter", post(handle_filter_data)) + .route("/api/sheet/filter/clear", post(handle_clear_filter)) + .route("/api/sheet/chart", post(handle_create_chart)) + .route("/api/sheet/chart/delete", post(handle_delete_chart)) + .route("/api/sheet/conditional-format", post(handle_conditional_format)) + .route("/api/sheet/data-validation", post(handle_data_validation)) + .route("/api/sheet/validate-cell", post(handle_validate_cell)) + .route("/api/sheet/note", post(handle_add_note)) + .route("/api/sheet/import", post(handle_import_sheet)) + .route("/api/sheet/:id", get(handle_get_sheet_by_id)) + .route("/api/sheet/:id/collaborators", get(handle_get_collaborators)) + .route("/ws/sheet/:sheet_id", get(handle_sheet_websocket)) +} + +fn get_user_sheets_path(user_id: &str) -> String { + format!("users/{}/sheets", user_id) +} + +async fn save_sheet_to_drive( + state: &Arc, + user_id: &str, + sheet: &Spreadsheet, +) -> Result<(), String> { + let drive = state + .drive + .as_ref() + .ok_or_else(|| "Drive not available".to_string())?; + + let path = format!("{}/{}.json", get_user_sheets_path(user_id), sheet.id); + let content = + serde_json::to_string_pretty(sheet).map_err(|e| format!("Serialization error: {e}"))?; + + drive + .put_object() + .bucket("gbo") + .key(&path) + .body(content.into_bytes().into()) + .content_type("application/json") + .send() + .await + .map_err(|e| format!("Failed to save sheet: {e}"))?; + + Ok(()) +} + +async fn load_sheet_from_drive( + state: &Arc, + user_id: &str, + sheet_id: &str, +) -> Result { 
+ let drive = state + .drive + .as_ref() + .ok_or_else(|| "Drive not available".to_string())?; + + let path = format!("{}/{}.json", get_user_sheets_path(user_id), sheet_id); + + let result = drive + .get_object() + .bucket("gbo") + .key(&path) + .send() + .await + .map_err(|e| format!("Failed to load sheet: {e}"))?; + + let bytes = result + .body + .collect() + .await + .map_err(|e| format!("Failed to read sheet: {e}"))? + .into_bytes(); + + let sheet: Spreadsheet = + serde_json::from_slice(&bytes).map_err(|e| format!("Failed to parse sheet: {e}"))?; + + Ok(sheet) +} + +async fn list_sheets_from_drive( + state: &Arc, + user_id: &str, +) -> Result, String> { + let drive = state + .drive + .as_ref() + .ok_or_else(|| "Drive not available".to_string())?; + + let prefix = format!("{}/", get_user_sheets_path(user_id)); + + let result = drive + .list_objects_v2() + .bucket("gbo") + .prefix(&prefix) + .send() + .await + .map_err(|e| format!("Failed to list sheets: {e}"))?; + + let mut sheets = Vec::new(); + + if let Some(contents) = result.contents { + for obj in contents { + if let Some(key) = obj.key { + if key.ends_with(".json") { + if let Ok(sheet) = + load_sheet_from_drive(state, user_id, &extract_id_from_path(&key)).await + { + sheets.push(SpreadsheetMetadata { + id: sheet.id, + name: sheet.name, + owner_id: sheet.owner_id, + created_at: sheet.created_at, + updated_at: sheet.updated_at, + worksheet_count: sheet.worksheets.len(), + }); + } + } + } + } + } + + sheets.sort_by(|a, b| b.updated_at.cmp(&a.updated_at)); + + Ok(sheets) +} + +fn extract_id_from_path(path: &str) -> String { + path.split('/') + .last() + .unwrap_or("") + .trim_end_matches(".json") + .to_string() +} + +async fn delete_sheet_from_drive( + state: &Arc, + user_id: &str, + sheet_id: &str, +) -> Result<(), String> { + let drive = state + .drive + .as_ref() + .ok_or_else(|| "Drive not available".to_string())?; + + let path = format!("{}/{}.json", get_user_sheets_path(user_id), sheet_id); + + drive + 
.delete_object() + .bucket("gbo") + .key(&path) + .send() + .await + .map_err(|e| format!("Failed to delete sheet: {e}"))?; + + Ok(()) +} + +fn get_current_user_id() -> String { + "default-user".to_string() +} + +pub async fn handle_new_sheet( + State(_state): State>, +) -> Result, (StatusCode, Json)> { + let sheet = Spreadsheet { + id: Uuid::new_v4().to_string(), + name: "Untitled Spreadsheet".to_string(), + owner_id: get_current_user_id(), + worksheets: vec![Worksheet { + name: "Sheet1".to_string(), + data: HashMap::new(), + column_widths: None, + row_heights: None, + frozen_rows: None, + frozen_cols: None, + merged_cells: None, + filters: None, + hidden_rows: None, + validations: None, + conditional_formats: None, + charts: None, + }], + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + Ok(Json(sheet)) +} + +pub async fn handle_list_sheets( + State(state): State>, +) -> Result>, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + match list_sheets_from_drive(&state, &user_id).await { + Ok(sheets) => Ok(Json(sheets)), + Err(e) => { + error!("Failed to list sheets: {}", e); + Ok(Json(Vec::new())) + } + } +} + +pub async fn handle_search_sheets( + State(state): State>, + Query(query): Query, +) -> Result>, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + let sheets = match list_sheets_from_drive(&state, &user_id).await { + Ok(s) => s, + Err(_) => Vec::new(), + }; + + let filtered = if let Some(q) = query.q { + let q_lower = q.to_lowercase(); + sheets + .into_iter() + .filter(|s| s.name.to_lowercase().contains(&q_lower)) + .collect() + } else { + sheets + }; + + Ok(Json(filtered)) +} + +pub async fn handle_load_sheet( + State(state): State>, + Query(query): Query, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + match load_sheet_from_drive(&state, &user_id, &query.id).await { + Ok(sheet) => Ok(Json(sheet)), + Err(e) => Err(( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ "error": e })), + )), + 
} +} + +pub async fn handle_save_sheet( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + let sheet_id = req.id.unwrap_or_else(|| Uuid::new_v4().to_string()); + + let sheet = Spreadsheet { + id: sheet_id.clone(), + name: req.name, + owner_id: user_id.clone(), + worksheets: req.worksheets, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ "error": e })), + )); + } + + Ok(Json(SaveResponse { + id: sheet_id, + success: true, + message: Some("Sheet saved successfully".to_string()), + })) +} + +pub async fn handle_delete_sheet( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + if let Err(e) = delete_sheet_from_drive(&state, &user_id, &req.id).await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ "error": e })), + )); + } + + Ok(Json(SaveResponse { + id: req.id, + success: true, + message: Some("Sheet deleted".to_string()), + })) +} + +pub async fn handle_update_cell( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await { + Ok(s) => s, + Err(e) => { + return Err(( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ "error": e })), + )) + } + }; + + if req.worksheet_index >= sheet.worksheets.len() { + return Err(( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ "error": "Invalid worksheet index" })), + )); + } + + let worksheet = &mut sheet.worksheets[req.worksheet_index]; + let key = format!("{},{}", req.row, req.col); + + let (value, formula) = if req.value.starts_with('=') { + let result = evaluate_formula(&req.value, worksheet); + (Some(result.value), Some(req.value.clone())) + } 
else { + (Some(req.value.clone()), None) + }; + + let cell = worksheet.data.entry(key).or_insert_with(|| CellData { + value: None, + formula: None, + style: None, + format: None, + note: None, + }); + + cell.value = value; + cell.formula = formula; + + sheet.updated_at = Utc::now(); + + if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ "error": e })), + )); + } + + broadcast_sheet_change( + &req.sheet_id, + "cellChange", + &user_id, + Some(req.row), + Some(req.col), + Some(&req.value), + ) + .await; + + Ok(Json(SaveResponse { + id: req.sheet_id, + success: true, + message: Some("Cell updated".to_string()), + })) +} + +pub async fn handle_format_cells( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await { + Ok(s) => s, + Err(e) => { + return Err(( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ "error": e })), + )) + } + }; + + if req.worksheet_index >= sheet.worksheets.len() { + return Err(( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ "error": "Invalid worksheet index" })), + )); + } + + let worksheet = &mut sheet.worksheets[req.worksheet_index]; + + for row in req.start_row..=req.end_row { + for col in req.start_col..=req.end_col { + let key = format!("{},{}", row, col); + let cell = worksheet.data.entry(key).or_insert_with(|| CellData { + value: None, + formula: None, + style: None, + format: None, + note: None, + }); + cell.style = Some(req.style.clone()); + } + } + + sheet.updated_at = Utc::now(); + + if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ "error": e })), + )); + } + + Ok(Json(SaveResponse { + id: req.sheet_id, + success: true, + message: Some("Format applied".to_string()), + })) +} + +pub 
async fn handle_evaluate_formula(
    State(state): State<Arc<AppState>>,
    Json(req): Json<FormulaRequest>,
) -> Result<Json<FormulaResult>, (StatusCode, Json<serde_json::Value>)> {
    // POST /api/sheet/formula — evaluate a formula in the context of a
    // worksheet. If the sheet can't be loaded, evaluate against an empty
    // scratch worksheet so pure expressions (e.g. "=1+2") still work.
    let user_id = get_current_user_id();

    let sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(_) => {
            return Ok(Json(evaluate_formula(
                &req.formula,
                &Worksheet {
                    name: "temp".to_string(),
                    data: HashMap::new(),
                    column_widths: None,
                    row_heights: None,
                    frozen_rows: None,
                    frozen_cols: None,
                    merged_cells: None,
                    filters: None,
                    hidden_rows: None,
                    validations: None,
                    conditional_formats: None,
                    charts: None,
                },
            )))
        }
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({ "error": "Invalid worksheet index" })),
        ));
    }

    let result = evaluate_formula(&req.formula, &sheet.worksheets[req.worksheet_index]);
    Ok(Json(result))
}

/// Evaluates a spreadsheet formula against a worksheet. Non-formula input
/// (no leading '=') is returned verbatim. Each supported function has its
/// own evaluator; the first one that recognizes the expression wins, and an
/// unrecognized expression yields "#ERROR!".
///
/// NOTE(review): the whole expression is uppercased before dispatch, so
/// quoted string arguments are uppercased too — confirm this matches the
/// front end's expectations.
fn evaluate_formula(formula: &str, worksheet: &Worksheet) -> FormulaResult {
    if !formula.starts_with('=') {
        return FormulaResult {
            value: formula.to_string(),
            error: None,
        };
    }

    let expr = formula[1..].to_uppercase();

    // Ordered dispatch table; evaluate_arithmetic is the catch-all and must
    // stay last.
    let evaluators: Vec<fn(&str, &Worksheet) -> Option<String>> = vec![
        evaluate_sum,
        evaluate_average,
        evaluate_count,
        evaluate_counta,
        evaluate_countblank,
        evaluate_countif,
        evaluate_sumif,
        evaluate_averageif,
        evaluate_max,
        evaluate_min,
        evaluate_if,
        evaluate_iferror,
        evaluate_vlookup,
        evaluate_hlookup,
        evaluate_index_match,
        evaluate_concatenate,
        evaluate_left,
        evaluate_right,
        evaluate_mid,
        evaluate_len,
        evaluate_trim,
        evaluate_upper,
        evaluate_lower,
        evaluate_proper,
        evaluate_substitute,
        evaluate_round,
        evaluate_roundup,
        evaluate_rounddown,
        evaluate_abs,
        evaluate_sqrt,
        evaluate_power,
        evaluate_mod_formula,
        evaluate_and,
        evaluate_or,
        evaluate_not,
        evaluate_today,
        evaluate_now,
        evaluate_date,
        evaluate_year,
        evaluate_month,
        evaluate_day,
        evaluate_datedif,
        evaluate_arithmetic,
    ];

    for evaluator in evaluators {
        if let Some(result) = evaluator(&expr, worksheet) {
            return FormulaResult {
                value: result,
                error: None,
            };
        }
    }

    FormulaResult {
        value: "#ERROR!".to_string(),
        error: Some("Invalid formula".to_string()),
    }
}

/// SUM(range) — sum of all numeric values in the range.
fn evaluate_sum(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("SUM(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[4..expr.len() - 1];
    let values = get_range_values(inner, worksheet);
    let sum: f64 = values.iter().sum();
    Some(format_number(sum))
}

/// AVERAGE(range) — mean of numeric values; "#DIV/0!" on an empty range.
fn evaluate_average(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("AVERAGE(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[8..expr.len() - 1];
    let values = get_range_values(inner, worksheet);
    if values.is_empty() {
        return Some("#DIV/0!".to_string());
    }
    let avg = values.iter().sum::<f64>() / values.len() as f64;
    Some(format_number(avg))
}

/// COUNT(range) — number of numeric values in the range.
fn evaluate_count(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("COUNT(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[6..expr.len() - 1];
    let values = get_range_values(inner, worksheet);
    Some(values.len().to_string())
}

/// COUNTA(range) — number of non-empty cells.
fn evaluate_counta(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("COUNTA(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[7..expr.len() - 1];
    let count = get_range_string_values(inner, worksheet)
        .iter()
        .filter(|v| !v.is_empty())
        .count();
    Some(count.to_string())
}

/// COUNTBLANK(range) — number of empty (or missing) cells in the range.
fn evaluate_countblank(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("COUNTBLANK(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[11..expr.len() - 1];
    let (start, end) = parse_range(inner)?;
    let mut count = 0;
    for row in start.0..=end.0 {
        for col in start.1..=end.1 {
            let key = format!("{},{}", row, col);
            // A cell absent from the map counts as blank.
            let is_blank = worksheet
                .data
                .get(&key)
                .and_then(|c| c.value.as_ref())
                .map(|v| v.is_empty())
                .unwrap_or(true);
            if is_blank {
                count += 1;
            }
        }
    }
    Some(count.to_string())
}

/// COUNTIF(range, criteria) — cells matching the criteria.
fn evaluate_countif(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("COUNTIF(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[8..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() != 2 {
        return None;
    }
    let range = parts[0].trim();
    let criteria = parts[1].trim().trim_matches('"');
    let values = get_range_string_values(range, worksheet);
    let count = count_matching(&values, criteria);
    Some(count.to_string())
}

/// SUMIF(criteria_range, criteria[, sum_range]) — sums values whose
/// criteria-range counterpart matches; sum_range defaults to criteria_range.
fn evaluate_sumif(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("SUMIF(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[6..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() < 2 {
        return None;
    }
    let criteria_range = parts[0].trim();
    let criteria = parts[1].trim().trim_matches('"');
    let sum_range = if parts.len() > 2 {
        parts[2].trim()
    } else {
        criteria_range
    };

    let criteria_values = get_range_string_values(criteria_range, worksheet);
    let sum_values = get_range_values(sum_range, worksheet);

    let mut sum = 0.0;
    for (i, cv) in criteria_values.iter().enumerate() {
        if matches_criteria(cv, criteria) {
            if let Some(sv) = sum_values.get(i) {
                sum += sv;
            }
        }
    }
    Some(format_number(sum))
}

/// AVERAGEIF(criteria_range, criteria[, avg_range]) — mean over matching
/// rows; "#DIV/0!" when nothing matches.
fn evaluate_averageif(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("AVERAGEIF(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[10..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() < 2 {
        return None;
    }
    let criteria_range = parts[0].trim();
    let criteria = parts[1].trim().trim_matches('"');
    let avg_range = if parts.len() > 2 {
        parts[2].trim()
    } else {
        criteria_range
    };

    let criteria_values = get_range_string_values(criteria_range, worksheet);
    let avg_values = get_range_values(avg_range, worksheet);

    let mut sum = 0.0;
    let mut count = 0;
    for (i, cv) in criteria_values.iter().enumerate() {
        if matches_criteria(cv, criteria) {
            if let Some(av) = avg_values.get(i) {
                sum += av;
                count += 1;
            }
        }
    }
    if count == 0 {
        return Some("#DIV/0!".to_string());
    }
    Some(format_number(sum / count as f64))
}

/// MAX(range) — largest numeric value; "0" for an empty range.
fn evaluate_max(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("MAX(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[4..expr.len() - 1];
    let values = get_range_values(inner, worksheet);
    if values.is_empty() {
        return Some("0".to_string());
    }
    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    Some(format_number(max))
}

/// MIN(range) — smallest numeric value; "0" for an empty range.
fn evaluate_min(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("MIN(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[4..expr.len() - 1];
    let values = get_range_values(inner, worksheet);
    if values.is_empty() {
        return Some("0".to_string());
    }
    let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
    Some(format_number(min))
}

/// IF(condition, true_value[, false_value]) — false_value defaults to "FALSE".
fn evaluate_if(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("IF(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[3..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() < 2 {
        return None;
    }
    let condition = parts[0].trim();
    let true_val = parts[1].trim().trim_matches('"');
    let false_val = if parts.len() > 2 {
        parts[2].trim().trim_matches('"')
    } else {
        "FALSE"
    };

    let result = evaluate_condition(condition, worksheet);
    Some(if result { true_val } else { false_val }.to_string())
}

/// IFERROR(expr, fallback) — evaluates expr; any "#…" error yields fallback.
fn evaluate_iferror(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("IFERROR(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[8..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() != 2 {
        return None;
    }
    let value_expr = parts[0].trim();
    let error_val = parts[1].trim().trim_matches('"');

    // Recursive evaluation: error values start with '#'.
    let result = evaluate_formula(&format!("={}", value_expr), worksheet);
    if result.value.starts_with('#') {
        Some(error_val.to_string())
    } else {
        Some(result.value)
    }
}

/// VLOOKUP(value, range, col_index[, FALSE]) — scans the first column of
/// the range for `value` and returns the cell `col_index` columns in.
/// NOTE(review): approximate mode (4th arg omitted/TRUE) uses prefix match,
/// not Excel's <= semantics on sorted data.
fn evaluate_vlookup(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("VLOOKUP(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[8..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() < 3 {
        return None;
    }

    let search_value = resolve_cell_value(parts[0].trim(), worksheet);
    let range = parts[1].trim();
    let col_index: usize = parts[2].trim().parse().ok()?;
    let exact_match = parts.get(3).map(|v| v.trim() == "FALSE").unwrap_or(true);

    let (start, end) = parse_range(range)?;

    for row in start.0..=end.0 {
        let first_col_key = format!("{},{}", row, start.1);
        let cell_value = worksheet
            .data
            .get(&first_col_key)
            .and_then(|c| c.value.clone())
            .unwrap_or_default();

        let matches = if exact_match {
            cell_value.to_uppercase() == search_value.to_uppercase()
        } else {
            cell_value
                .to_uppercase()
                .starts_with(&search_value.to_uppercase())
        };

        if matches {
            let result_col = start.1 + col_index as u32 - 1;
            if result_col <= end.1 {
                let result_key = format!("{},{}", row, result_col);
                return worksheet
                    .data
                    .get(&result_key)
                    .and_then(|c| c.value.clone())
                    .or(Some(String::new()));
            }
        }
    }
    Some("#N/A".to_string())
}

/// HLOOKUP(value, range, row_index[, FALSE]) — row-wise counterpart of
/// VLOOKUP: scans the first row, returns the cell `row_index` rows down.
fn evaluate_hlookup(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("HLOOKUP(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[8..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() < 3 {
        return None;
    }

    let search_value = resolve_cell_value(parts[0].trim(), worksheet);
    let range = parts[1].trim();
    let row_index: usize = parts[2].trim().parse().ok()?;
    let exact_match = parts.get(3).map(|v| v.trim() == "FALSE").unwrap_or(true);

    let (start, end) = parse_range(range)?;

    for col in start.1..=end.1 {
        let first_row_key = format!("{},{}", start.0, col);
        let cell_value = worksheet
            .data
            .get(&first_row_key)
            .and_then(|c| c.value.clone())
            .unwrap_or_default();

        let matches = if exact_match {
            cell_value.to_uppercase() == search_value.to_uppercase()
        } else {
            cell_value
                .to_uppercase()
                .starts_with(&search_value.to_uppercase())
        };

        if matches {
            let result_row = start.0 + row_index as u32 - 1;
            if result_row <= end.0 {
                let result_key = format!("{},{}", result_row, col);
                return worksheet
                    .data
                    .get(&result_key)
                    .and_then(|c| c.value.clone())
                    .or(Some(String::new()));
            }
        }
    }
    Some("#N/A".to_string())
}

/// INDEX(range, row_num[, col_num]) — 1-based lookup inside the range;
/// col_num defaults to 1.
fn evaluate_index_match(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("INDEX(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[6..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() < 2 {
        return None;
    }

    let range = parts[0].trim();
    let row_num: u32 = parts[1].trim().parse().ok()?;
    let col_num: u32 = parts.get(2).and_then(|v| v.trim().parse().ok()).unwrap_or(1);

    let (start, _end) = parse_range(range)?;
    let target_row = start.0 + row_num - 1;
    let target_col = start.1 + col_num - 1;
    let key = format!("{},{}", target_row, target_col);

    worksheet
        .data
        .get(&key)
        .and_then(|c| c.value.clone())
        .or(Some(String::new()))
}

/// CONCATENATE(a, b, …) / CONCAT(a, b, …) — joins arguments (cell refs are
/// resolved, quoted literals used as-is).
fn evaluate_concatenate(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("CONCATENATE(") && !expr.starts_with("CONCAT(") {
        return None;
    }
    let start_idx = if expr.starts_with("CONCATENATE(") {
        12
    } else {
        7
    };
    if !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[start_idx..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    let result: String = parts
        .iter()
        .map(|p| {
            let trimmed = p.trim().trim_matches('"');
            resolve_cell_value(trimmed, worksheet)
        })
        .collect();
    Some(result)
}

/// LEFT(text[, n]) — first n characters (default 1), char-safe.
fn evaluate_left(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("LEFT(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[5..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.is_empty() {
        return None;
    }
    let text = resolve_cell_value(parts[0].trim().trim_matches('"'), worksheet);
    let num: usize = parts.get(1).and_then(|v| v.trim().parse().ok()).unwrap_or(1);
    Some(text.chars().take(num).collect())
}

/// RIGHT(text[, n]) — last n characters (default 1), char-safe.
fn evaluate_right(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("RIGHT(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[6..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.is_empty() {
        return None;
    }
    let text = resolve_cell_value(parts[0].trim().trim_matches('"'), worksheet);
    let num: usize = parts.get(1).and_then(|v| v.trim().parse().ok()).unwrap_or(1);
    let len = text.chars().count();
    Some(text.chars().skip(len.saturating_sub(num)).collect())
}

/// MID(text, start, n) — n characters starting at 1-based `start`.
fn evaluate_mid(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("MID(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[4..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() < 3 {
        return None;
    }
    let text = resolve_cell_value(parts[0].trim().trim_matches('"'), worksheet);
    let start: usize = parts[1].trim().parse().ok()?;
    let num: usize = parts[2].trim().parse().ok()?;
    Some(
        text.chars()
            .skip(start.saturating_sub(1))
            .take(num)
            .collect(),
    )
}

/// LEN(text) — character (not byte) count.
fn evaluate_len(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("LEN(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[4..expr.len() - 1];
    let text = resolve_cell_value(inner.trim().trim_matches('"'), worksheet);
    Some(text.chars().count().to_string())
}

/// TRIM(text) — strips leading/trailing whitespace and collapses interior
/// runs to a single space.
fn evaluate_trim(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("TRIM(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[5..expr.len() - 1];
    let text = resolve_cell_value(inner.trim().trim_matches('"'), worksheet);
    Some(text.split_whitespace().collect::<Vec<_>>().join(" "))
}

/// UPPER(text).
fn evaluate_upper(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("UPPER(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[6..expr.len() - 1];
    let text = resolve_cell_value(inner.trim().trim_matches('"'), worksheet);
    Some(text.to_uppercase())
}

/// LOWER(text).
fn evaluate_lower(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("LOWER(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[6..expr.len() - 1];
    let text = resolve_cell_value(inner.trim().trim_matches('"'), worksheet);
    Some(text.to_lowercase())
}

/// PROPER(text) — Title Case each whitespace-separated word.
fn evaluate_proper(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("PROPER(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[7..expr.len() - 1];
    let text = resolve_cell_value(inner.trim().trim_matches('"'), worksheet);
    Some(
        text.split_whitespace()
            .map(|word| {
                let mut chars = word.chars();
                match chars.next() {
                    Some(first) => {
                        first.to_uppercase().to_string() + chars.as_str().to_lowercase().as_str()
                    }
                    None => String::new(),
                }
            })
            .collect::<Vec<_>>()
            .join(" "),
    )
}

/// SUBSTITUTE(text, old, new) — replaces every occurrence.
fn evaluate_substitute(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("SUBSTITUTE(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[11..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() < 3 {
        return None;
    }
    let text = resolve_cell_value(parts[0].trim().trim_matches('"'), worksheet);
    let old_text = parts[1].trim().trim_matches('"');
    let new_text = parts[2].trim().trim_matches('"');
    Some(text.replace(old_text, new_text))
}

/// ROUND(number[, decimals]) — half-away-from-zero rounding via f64::round.
fn evaluate_round(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("ROUND(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[6..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.is_empty() {
        return None;
    }
    let num: f64 = resolve_cell_value(parts[0].trim(), worksheet).parse().ok()?;
    let decimals: i32 = parts.get(1).and_then(|v| v.trim().parse().ok()).unwrap_or(0);
    let factor = 10_f64.powi(decimals);
    Some(format_number((num * factor).round() / factor))
}

/// ROUNDUP(number[, decimals]) — rounds toward +infinity at the given place.
fn evaluate_roundup(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("ROUNDUP(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[8..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.is_empty() {
        return None;
    }
    let num: f64 = resolve_cell_value(parts[0].trim(), worksheet).parse().ok()?;
    let decimals: i32 = parts.get(1).and_then(|v| v.trim().parse().ok()).unwrap_or(0);
    let factor = 10_f64.powi(decimals);
    Some(format_number((num * factor).ceil() / factor))
}

/// ROUNDDOWN(number[, decimals]) — rounds toward -infinity at the given place.
fn evaluate_rounddown(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("ROUNDDOWN(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[10..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.is_empty() {
        return None;
    }
    let num: f64 = resolve_cell_value(parts[0].trim(), worksheet).parse().ok()?;
    let decimals: i32 = parts.get(1).and_then(|v| v.trim().parse().ok()).unwrap_or(0);
    let factor = 10_f64.powi(decimals);
    Some(format_number((num * factor).floor() / factor))
}

/// ABS(number).
fn evaluate_abs(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("ABS(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[4..expr.len() - 1];
    let num: f64 = resolve_cell_value(inner.trim(), worksheet).parse().ok()?;
    Some(format_number(num.abs()))
}

/// SQRT(number) — "#NUM!" for negative input.
fn evaluate_sqrt(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("SQRT(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[5..expr.len() - 1];
    let num: f64 = resolve_cell_value(inner.trim(), worksheet).parse().ok()?;
    if num < 0.0 {
        return Some("#NUM!".to_string());
    }
    Some(format_number(num.sqrt()))
}

/// POWER(base, exponent).
fn evaluate_power(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("POWER(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[6..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() != 2 {
        return None;
    }
    let base: f64 = resolve_cell_value(parts[0].trim(), worksheet).parse().ok()?;
    let exp: f64 = resolve_cell_value(parts[1].trim(), worksheet).parse().ok()?;
    Some(format_number(base.powf(exp)))
}

/// MOD(number, divisor) — "#DIV/0!" for a zero divisor. Named `_formula` to
/// avoid clashing with the `mod` keyword.
fn evaluate_mod_formula(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("MOD(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[4..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    if parts.len() != 2 {
        return None;
    }
    let num: f64 = resolve_cell_value(parts[0].trim(), worksheet).parse().ok()?;
    let divisor: f64 = resolve_cell_value(parts[1].trim(), worksheet).parse().ok()?;
    if divisor == 0.0 {
        return Some("#DIV/0!".to_string());
    }
    Some(format_number(num % divisor))
}

/// AND(cond, …) — "TRUE" iff every condition holds.
fn evaluate_and(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("AND(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[4..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    let result = parts
        .iter()
        .all(|p| evaluate_condition(p.trim(), worksheet));
    Some(if result { "TRUE" } else { "FALSE" }.to_string())
}

/// OR(cond, …) — "TRUE" iff any condition holds.
fn evaluate_or(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("OR(") || !expr.ends_with(')') {
        return None;
    }
    let inner = &expr[3..expr.len() - 1];
    let parts: Vec<&str> = split_args(inner);
    let result = parts
        .iter()
        .any(|p| evaluate_condition(p.trim(), worksheet));
    Some(if result { "TRUE" } else { "FALSE" }.to_string())
}

/// NOT(cond) — logical negation.
fn evaluate_not(expr: &str, worksheet: &Worksheet) -> Option<String> {
    if !expr.starts_with("NOT(") || !expr.ends_with(')') {
+ return None; + } + let inner = &expr[4..expr.len() - 1]; + let result = !evaluate_condition(inner.trim(), worksheet); + Some(if result { "TRUE" } else { "FALSE" }.to_string()) +} + +fn evaluate_today(_expr: &str, _worksheet: &Worksheet) -> Option { + if _expr != "TODAY()" { + return None; + } + let today = Local::now().format("%Y-%m-%d").to_string(); + Some(today) +} + +fn evaluate_now(_expr: &str, _worksheet: &Worksheet) -> Option { + if _expr != "NOW()" { + return None; + } + let now = Local::now().format("%Y-%m-%d %H:%M:%S").to_string(); + Some(now) +} + +fn evaluate_date(expr: &str, _worksheet: &Worksheet) -> Option { + if !expr.starts_with("DATE(") || !expr.ends_with(')') { + return None; + } + let inner = &expr[5..expr.len() - 1]; + let parts: Vec<&str> = split_args(inner); + if parts.len() != 3 { + return None; + } + let year: i32 = parts[0].trim().parse().ok()?; + let month: u32 = parts[1].trim().parse().ok()?; + let day: u32 = parts[2].trim().parse().ok()?; + let date = NaiveDate::from_ymd_opt(year, month, day)?; + Some(date.format("%Y-%m-%d").to_string()) +} + +fn evaluate_year(expr: &str, worksheet: &Worksheet) -> Option { + if !expr.starts_with("YEAR(") || !expr.ends_with(')') { + return None; + } + let inner = &expr[5..expr.len() - 1]; + let date_str = resolve_cell_value(inner.trim().trim_matches('"'), worksheet); + let date = NaiveDate::parse_from_str(&date_str, "%Y-%m-%d").ok()?; + Some(date.year().to_string()) +} + +fn evaluate_month(expr: &str, worksheet: &Worksheet) -> Option { + if !expr.starts_with("MONTH(") || !expr.ends_with(')') { + return None; + } + let inner = &expr[6..expr.len() - 1]; + let date_str = resolve_cell_value(inner.trim().trim_matches('"'), worksheet); + let date = NaiveDate::parse_from_str(&date_str, "%Y-%m-%d").ok()?; + Some(date.month().to_string()) +} + +fn evaluate_day(expr: &str, worksheet: &Worksheet) -> Option { + if !expr.starts_with("DAY(") || !expr.ends_with(')') { + return None; + } + let inner = 
&expr[4..expr.len() - 1]; + let date_str = resolve_cell_value(inner.trim().trim_matches('"'), worksheet); + let date = NaiveDate::parse_from_str(&date_str, "%Y-%m-%d").ok()?; + Some(date.day().to_string()) +} + +fn evaluate_datedif(expr: &str, worksheet: &Worksheet) -> Option { + if !expr.starts_with("DATEDIF(") || !expr.ends_with(')') { + return None; + } + let inner = &expr[8..expr.len() - 1]; + let parts: Vec<&str> = split_args(inner); + if parts.len() != 3 { + return None; + } + let start_str = resolve_cell_value(parts[0].trim().trim_matches('"'), worksheet); + let end_str = resolve_cell_value(parts[1].trim().trim_matches('"'), worksheet); + let unit = parts[2].trim().trim_matches('"').to_uppercase(); + + let start = NaiveDate::parse_from_str(&start_str, "%Y-%m-%d").ok()?; + let end = NaiveDate::parse_from_str(&end_str, "%Y-%m-%d").ok()?; + + let diff = match unit.as_str() { + "D" => (end - start).num_days(), + "M" => { + let months = (end.year() - start.year()) * 12 + (end.month() as i32 - start.month() as i32); + months as i64 + } + "Y" => (end.year() - start.year()) as i64, + _ => return Some("#VALUE!".to_string()), + }; + Some(diff.to_string()) +} + +fn evaluate_arithmetic(expr: &str, worksheet: &Worksheet) -> Option { + let resolved = resolve_cell_references(expr, worksheet); + eval_simple_arithmetic(&resolved).map(format_number) +} + +fn resolve_cell_references(expr: &str, worksheet: &Worksheet) -> String { + let mut result = expr.to_string(); + let re = regex::Regex::new(r"([A-Z]+)(\d+)").ok(); + + if let Some(regex) = re { + for cap in regex.captures_iter(expr) { + if let (Some(col_match), Some(row_match)) = (cap.get(1), cap.get(2)) { + let col = col_name_to_index(col_match.as_str()); + let row: u32 = row_match.as_str().parse().unwrap_or(1) - 1; + let key = format!("{},{}", row, col); + + let value = worksheet + .data + .get(&key) + .and_then(|c| c.value.clone()) + .unwrap_or_else(|| "0".to_string()); + + let cell_ref = format!("{}{}", 
col_match.as_str(), row_match.as_str()); + result = result.replace(&cell_ref, &value); + } + } + } + result +} + +fn eval_simple_arithmetic(expr: &str) -> Option { + let expr = expr.replace(' ', ""); + if let Ok(num) = expr.parse::() { + return Some(num); + } + if let Some(pos) = expr.rfind('+') { + if pos > 0 { + let left = eval_simple_arithmetic(&expr[..pos])?; + let right = eval_simple_arithmetic(&expr[pos + 1..])?; + return Some(left + right); + } + } + if let Some(pos) = expr.rfind('-') { + if pos > 0 { + let left = eval_simple_arithmetic(&expr[..pos])?; + let right = eval_simple_arithmetic(&expr[pos + 1..])?; + return Some(left - right); + } + } + if let Some(pos) = expr.rfind('*') { + let left = eval_simple_arithmetic(&expr[..pos])?; + let right = eval_simple_arithmetic(&expr[pos + 1..])?; + return Some(left * right); + } + if let Some(pos) = expr.rfind('/') { + let left = eval_simple_arithmetic(&expr[..pos])?; + let right = eval_simple_arithmetic(&expr[pos + 1..])?; + if right != 0.0 { + return Some(left / right); + } + } + None +} + +fn get_range_values(range: &str, worksheet: &Worksheet) -> Vec { + let parts: Vec<&str> = range.split(':').collect(); + if parts.len() != 2 { + if let Some(val) = resolve_cell_value(range.trim(), worksheet).parse::().ok() { + return vec![val]; + } + return Vec::new(); + } + let (start, end) = match parse_range(range) { + Some(r) => r, + None => return Vec::new(), + }; + let mut values = Vec::new(); + for row in start.0..=end.0 { + for col in start.1..=end.1 { + let key = format!("{},{}", row, col); + if let Some(cell) = worksheet.data.get(&key) { + if let Some(ref value) = cell.value { + if let Ok(num) = value.parse::() { + values.push(num); + } + } + } + } + } + values +} + +fn get_range_string_values(range: &str, worksheet: &Worksheet) -> Vec { + let (start, end) = match parse_range(range) { + Some(r) => r, + None => return Vec::new(), + }; + let mut values = Vec::new(); + for row in start.0..=end.0 { + for col in 
start.1..=end.1 { + let key = format!("{},{}", row, col); + let value = worksheet + .data + .get(&key) + .and_then(|c| c.value.clone()) + .unwrap_or_default(); + values.push(value); + } + } + values +} + +fn parse_range(range: &str) -> Option<((u32, u32), (u32, u32))> { + let parts: Vec<&str> = range.split(':').collect(); + if parts.len() != 2 { + return None; + } + let start = parse_cell_ref(parts[0].trim())?; + let end = parse_cell_ref(parts[1].trim())?; + Some((start, end)) +} + +fn parse_cell_ref(cell_ref: &str) -> Option<(u32, u32)> { + let cell_ref = cell_ref.trim().to_uppercase(); + let mut col_str = String::new(); + let mut row_str = String::new(); + for ch in cell_ref.chars() { + if ch.is_ascii_alphabetic() { + col_str.push(ch); + } else if ch.is_ascii_digit() { + row_str.push(ch); + } + } + if col_str.is_empty() || row_str.is_empty() { + return None; + } + let col = col_name_to_index(&col_str); + let row: u32 = row_str.parse::().ok()? - 1; + Some((row, col)) +} + +fn col_name_to_index(name: &str) -> u32 { + let mut col: u32 = 0; + for ch in name.chars() { + col = col * 26 + (ch as u32 - 'A' as u32 + 1); + } + col - 1 +} + +fn format_number(num: f64) -> String { + if num.fract() == 0.0 { + format!("{}", num as i64) + } else { + format!("{:.6}", num).trim_end_matches('0').trim_end_matches('.').to_string() + } +} + +fn resolve_cell_value(value: &str, worksheet: &Worksheet) -> String { + if let Some((row, col)) = parse_cell_ref(value) { + let key = format!("{},{}", row, col); + worksheet + .data + .get(&key) + .and_then(|c| c.value.clone()) + .unwrap_or_default() + } else { + value.to_string() + } +} + +fn split_args(s: &str) -> Vec<&str> { + let mut parts = Vec::new(); + let mut depth = 0; + let mut start = 0; + for (i, ch) in s.char_indices() { + match ch { + '(' => depth += 1, + ')' => depth -= 1, + ',' if depth == 0 => { + parts.push(&s[start..i]); + start = i + 1; + } + _ => {} + } + } + if start < s.len() { + parts.push(&s[start..]); + } + parts +} + 
+fn evaluate_condition(condition: &str, worksheet: &Worksheet) -> bool { + let condition = condition.trim(); + if condition.eq_ignore_ascii_case("TRUE") { + return true; + } + if condition.eq_ignore_ascii_case("FALSE") { + return false; + } + + let ops = [">=", "<=", "<>", "!=", "=", ">", "<"]; + for op in ops { + if let Some(pos) = condition.find(op) { + let left = resolve_cell_value(&condition[..pos].trim(), worksheet); + let right = resolve_cell_value(&condition[pos + op.len()..].trim().trim_matches('"'), worksheet); + + let left_num = left.parse::().ok(); + let right_num = right.parse::().ok(); + + return match (op, left_num, right_num) { + (">=", Some(l), Some(r)) => l >= r, + ("<=", Some(l), Some(r)) => l <= r, + ("<>", _, _) | ("!=", _, _) => left != right, + ("=", _, _) => left == right, + (">", Some(l), Some(r)) => l > r, + ("<", Some(l), Some(r)) => l < r, + _ => false, + }; + } + } + + let val = resolve_cell_value(condition, worksheet); + !val.is_empty() && val != "0" && !val.eq_ignore_ascii_case("FALSE") +} + +fn matches_criteria(value: &str, criteria: &str) -> bool { + if criteria.starts_with(">=") { + if let (Ok(v), Ok(c)) = (value.parse::(), criteria[2..].parse::()) { + return v >= c; + } + } else if criteria.starts_with("<=") { + if let (Ok(v), Ok(c)) = (value.parse::(), criteria[2..].parse::()) { + return v <= c; + } + } else if criteria.starts_with("<>") || criteria.starts_with("!=") { + return value != &criteria[2..]; + } else if criteria.starts_with('>') { + if let (Ok(v), Ok(c)) = (value.parse::(), criteria[1..].parse::()) { + return v > c; + } + } else if criteria.starts_with('<') { + if let (Ok(v), Ok(c)) = (value.parse::(), criteria[1..].parse::()) { + return v < c; + } + } else if criteria.starts_with('=') { + return value == &criteria[1..]; + } else if criteria.contains('*') || criteria.contains('?') { + let pattern = criteria.replace('*', ".*").replace('?', "."); + if let Ok(re) = regex::Regex::new(&format!("^{}$", pattern)) { + return 
re.is_match(value);
        }
    }
    value.eq_ignore_ascii_case(criteria)
}

/// Count how many values satisfy `criteria` (COUNTIF-style).
fn count_matching(values: &[String], criteria: &str) -> usize {
    values.iter().filter(|v| matches_criteria(v, criteria)).count()
}

/// Shorthand for the `(status, { "error": … })` tuple every handler returns.
fn json_error(
    status: StatusCode,
    err: impl serde::Serialize,
) -> (StatusCode, Json<serde_json::Value>) {
    (status, Json(serde_json::json!({ "error": err })))
}

// NOTE(review): the generic parameters of the extractor/return types in the
// handlers below were lost in transit (the source shows `State>`, `Json`,
// `Result, (...)>`); the concrete request-DTO names are reconstructed from
// context — confirm them against the request structs defined in this module.

/// POST /export — render a spreadsheet as CSV or pretty-printed JSON.
pub async fn handle_export_sheet(
    State(state): State<Arc<AppState>>,
    Json(req): Json<ExportRequest>,
) -> Result<impl axum::response::IntoResponse, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let sheet = match load_sheet_from_drive(&state, &user_id, &req.id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    match req.format.as_str() {
        "csv" => Ok((
            [(axum::http::header::CONTENT_TYPE, "text/csv")],
            export_to_csv(&sheet),
        )),
        "json" => Ok((
            [(axum::http::header::CONTENT_TYPE, "application/json")],
            serde_json::to_string_pretty(&sheet).unwrap_or_default(),
        )),
        _ => Err(json_error(StatusCode::BAD_REQUEST, "Unsupported format")),
    }
}

/// Serialize the FIRST worksheet as CSV (fields containing comma/quote/newline
/// are quoted, embedded quotes doubled). Other worksheets are not exported.
fn export_to_csv(sheet: &Spreadsheet) -> String {
    let mut csv = String::new();
    if let Some(worksheet) = sheet.worksheets.first() {
        // An empty worksheet now exports as "" instead of a single blank row.
        if worksheet.data.is_empty() {
            return csv;
        }
        let mut max_row: u32 = 0;
        let mut max_col: u32 = 0;
        for key in worksheet.data.keys() {
            let parts: Vec<&str> = key.split(',').collect();
            if parts.len() == 2 {
                if let (Ok(row), Ok(col)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
                    max_row = max_row.max(row);
                    max_col = max_col.max(col);
                }
            }
        }
        for row in 0..=max_row {
            let mut row_values = Vec::new();
            for col in 0..=max_col {
                let key = format!("{},{}", row, col);
                let value = worksheet
                    .data
                    .get(&key)
                    .and_then(|c| c.value.clone())
                    .unwrap_or_default();
                let escaped =
                    if value.contains(',') || value.contains('"') || value.contains('\n') {
                        format!("\"{}\"", value.replace('"', "\"\""))
                    } else {
                        value
                    };
                row_values.push(escaped);
            }
            csv.push_str(&row_values.join(","));
            csv.push('\n');
        }
    }
    csv
}

/// POST /share — acknowledge a share request.
/// NOTE(review): nothing is persisted here yet; this only echoes the request.
pub async fn handle_share_sheet(
    Json(req): Json<ShareRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    Ok(Json(SaveResponse {
        id: req.sheet_id,
        success: true,
        message: Some(format!("Shared with {} as {}", req.email, req.permission)),
    }))
}

/// GET /sheets/:id — load a spreadsheet owned by the current user.
pub async fn handle_get_sheet_by_id(
    State(state): State<Arc<AppState>>,
    Path(sheet_id): Path<String>,
) -> Result<Json<Spreadsheet>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    load_sheet_from_drive(&state, &user_id, &sheet_id)
        .await
        .map(Json)
        .map_err(|e| json_error(StatusCode::NOT_FOUND, e))
}

/// POST /merge — record a merged-cell region on a worksheet.
pub async fn handle_merge_cells(
    State(state): State<Arc<AppState>>,
    Json(req): Json<MergeCellsRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &mut sheet.worksheets[req.worksheet_index];
    worksheet
        .merged_cells
        .get_or_insert_with(Vec::new)
        .push(MergedCell {
            start_row: req.start_row,
            start_col: req.start_col,
            end_row: req.end_row,
            end_col: req.end_col,
        });

    sheet.updated_at = Utc::now();
    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    broadcast_sheet_change(&req.sheet_id, "merge", &user_id, None, None, None).await;
    Ok(Json(SaveResponse {
        id: req.sheet_id,
        success: true,
        message: Some("Cells merged".to_string()),
    }))
}

/// POST /unmerge — remove the merged region that exactly matches the request.
pub async fn handle_unmerge_cells(
    State(state): State<Arc<AppState>>,
    Json(req): Json<UnmergeCellsRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &mut sheet.worksheets[req.worksheet_index];
    if let Some(ref mut merged_cells) = worksheet.merged_cells {
        merged_cells.retain(|m| {
            !(m.start_row == req.start_row
                && m.start_col == req.start_col
                && m.end_row == req.end_row
                && m.end_col == req.end_col)
        });
    }

    sheet.updated_at = Utc::now();
    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    Ok(Json(SaveResponse {
        id: req.sheet_id,
        success: true,
        message: Some("Cells unmerged".to_string()),
    }))
}

/// POST /freeze — set or clear frozen row/column counts (0 clears).
pub async fn handle_freeze_panes(
    State(state): State<Arc<AppState>>,
    Json(req): Json<FreezePanesRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &mut sheet.worksheets[req.worksheet_index];
    worksheet.frozen_rows = (req.frozen_rows > 0).then_some(req.frozen_rows);
    worksheet.frozen_cols = (req.frozen_cols > 0).then_some(req.frozen_cols);

    sheet.updated_at = Utc::now();
    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    Ok(Json(SaveResponse {
        id: req.sheet_id,
        success: true,
        message: Some("Freeze panes updated".to_string()),
    }))
}

/// POST /sort — sort a rectangular range by one column. Numeric values compare
/// numerically; anything else falls back to lexicographic string order.
pub async fn handle_sort_range(
    State(state): State<Arc<AppState>>,
    Json(req): Json<SortRangeRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &mut sheet.worksheets[req.worksheet_index];

    // Snapshot each row of the range so rows can be rewritten in sorted order.
    let mut rows_data: Vec<(u32, HashMap<u32, CellData>)> = Vec::new();
    for row in req.start_row..=req.end_row {
        let mut row_data = HashMap::new();
        for col in req.start_col..=req.end_col {
            let key = format!("{},{}", row, col);
            if let Some(cell) = worksheet.data.get(&key) {
                row_data.insert(col, cell.clone());
            }
        }
        rows_data.push((row, row_data));
    }

    rows_data.sort_by(|a, b| {
        let val_a = a.1.get(&req.sort_col).and_then(|c| c.value.clone()).unwrap_or_default();
        let val_b = b.1.get(&req.sort_col).and_then(|c| c.value.clone()).unwrap_or_default();
        let cmp = match (val_a.parse::<f64>().ok(), val_b.parse::<f64>().ok()) {
            (Some(x), Some(y)) => x.partial_cmp(&y).unwrap_or(std::cmp::Ordering::Equal),
            _ => val_a.cmp(&val_b),
        };
        if req.ascending { cmp } else { cmp.reverse() }
    });

    // Write back; cells empty in the source row are cleared at the target.
    for (idx, (_, row_data)) in rows_data.into_iter().enumerate() {
        let target_row = req.start_row + idx as u32;
        for col in req.start_col..=req.end_col {
            let key = format!("{},{}", target_row, col);
            match row_data.get(&col) {
                Some(cell) => {
                    worksheet.data.insert(key, cell.clone());
                }
                None => {
                    worksheet.data.remove(&key);
                }
            }
        }
    }

    sheet.updated_at = Utc::now();
    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    broadcast_sheet_change(&req.sheet_id, "sort", &user_id, None, None, None).await;
    Ok(Json(SaveResponse {
        id: req.sheet_id,
        success: true,
        message: Some("Range sorted".to_string()),
    }))
}

/// POST /filter — store a column filter and hide the rows it excludes.
pub async fn handle_filter_data(
    State(state): State<Arc<AppState>>,
    Json(req): Json<FilterRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &mut sheet.worksheets[req.worksheet_index];
    worksheet.filters.get_or_insert_with(HashMap::new).insert(
        req.col,
        FilterConfig {
            filter_type: req.filter_type.clone(),
            values: req.values.clone().unwrap_or_default(),
            condition: req.condition.clone(),
            value1: req.value1.clone(),
            value2: req.value2.clone(),
        },
    );

    // Highest populated row bounds the visibility scan.
    let mut max_row = 0u32;
    for key in worksheet.data.keys() {
        if let Some(row) = key.split(',').next().and_then(|r| r.parse::<u32>().ok()) {
            max_row = max_row.max(row);
        }
    }

    let mut hidden_rows = Vec::new();
    for row in 0..=max_row {
        let key = format!("{},{}", row, req.col);
        let cell_value = worksheet
            .data
            .get(&key)
            .and_then(|c| c.value.clone())
            .unwrap_or_default();

        // `should_hide` is the NEGATION of the filter condition; rows whose
        // cell or bound fails to parse stay visible.
        let should_hide = match req.filter_type.as_str() {
            "values" => {
                let values = req.values.as_deref().unwrap_or(&[]);
                !values.is_empty() && !values.iter().any(|v| v == &cell_value)
            }
            "greaterThan" => {
                if let (Ok(cv), Some(Ok(v1))) = (
                    cell_value.parse::<f64>(),
                    req.value1.as_ref().map(|v| v.parse::<f64>()),
                ) {
                    cv <= v1
                } else {
                    false
                }
            }
            "lessThan" => {
                if let (Ok(cv), Some(Ok(v1))) = (
                    cell_value.parse::<f64>(),
                    req.value1.as_ref().map(|v| v.parse::<f64>()),
                ) {
                    cv >= v1
                } else {
                    false
                }
            }
            "between" => {
                if let (Ok(cv), Some(Ok(v1)), Some(Ok(v2))) = (
                    cell_value.parse::<f64>(),
                    req.value1.as_ref().map(|v| v.parse::<f64>()),
                    req.value2.as_ref().map(|v| v.parse::<f64>()),
                ) {
                    cv < v1 || cv > v2
                } else {
                    false
                }
            }
            "contains" => req
                .value1
                .as_ref()
                .map(|v1| !cell_value.to_lowercase().contains(&v1.to_lowercase()))
                .unwrap_or(false),
            "notContains" => req
                .value1
                .as_ref()
                .map(|v1| cell_value.to_lowercase().contains(&v1.to_lowercase()))
                .unwrap_or(false),
            "isEmpty" => !cell_value.is_empty(),
            "isNotEmpty" => cell_value.is_empty(),
            _ => false,
        };

        if should_hide {
            hidden_rows.push(row);
        }
    }

    worksheet.hidden_rows = if hidden_rows.is_empty() {
        None
    } else {
        Some(hidden_rows.clone())
    };
    sheet.updated_at = Utc::now();

    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    Ok(Json(serde_json::json!({
        "success": true,
        "sheet_id": req.sheet_id,
        "hidden_rows": hidden_rows
    })))
}

/// POST /clear-filter — drop one column's filter (or all of them) and unhide
/// every row. NOTE(review): clearing a single column still unhides rows hidden
/// by other columns' filters — preserved from the original.
pub async fn handle_clear_filter(
    State(state): State<Arc<AppState>>,
    Json(req): Json<ClearFilterRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &mut sheet.worksheets[req.worksheet_index];
    match req.col {
        Some(col) => {
            if let Some(ref mut filters) = worksheet.filters {
                filters.remove(&col);
            }
        }
        None => worksheet.filters = None,
    }
    worksheet.hidden_rows = None;

    sheet.updated_at = Utc::now();
    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    Ok(Json(SaveResponse {
        id: req.sheet_id,
        success: true,
        message: Some("Filter cleared".to_string()),
    }))
}

/// POST /chart — build a chart from a data range and store it on the worksheet.
pub async fn handle_create_chart(
    State(state): State<Arc<AppState>>,
    Json(req): Json<CreateChartRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &sheet.worksheets[req.worksheet_index];
    let chart_id = Uuid::new_v4().to_string();

    let (start, end) = match parse_range(&req.data_range) {
        Some(r) => r,
        None => return Err(json_error(StatusCode::BAD_REQUEST, "Invalid data range")),
    };

    let colors = ["#3b82f6", "#ef4444", "#22c55e", "#f59e0b", "#8b5cf6", "#ec4899", "#14b8a6"];

    // Labels: explicit label range, or default to the first data column.
    // NOTE(review): when defaulting, that column is also emitted below as
    // "Series 1" — confirm the duplication is intended.
    let mut labels = Vec::new();
    if let Some(ref label_range) = req.label_range {
        if let Some((ls, le)) = parse_range(label_range) {
            for row in ls.0..=le.0 {
                for col in ls.1..=le.1 {
                    let key = format!("{},{}", row, col);
                    labels.push(
                        worksheet
                            .data
                            .get(&key)
                            .and_then(|c| c.value.clone())
                            .unwrap_or_default(),
                    );
                }
            }
        }
    } else {
        for row in start.0..=end.0 {
            let key = format!("{},{}", row, start.1);
            labels.push(
                worksheet
                    .data
                    .get(&key)
                    .and_then(|c| c.value.clone())
                    .unwrap_or_else(|| format!("Row {}", row + 1)),
            );
        }
    }

    // One dataset per column in the range; non-numeric cells become 0.0.
    let mut datasets = Vec::new();
    for (col_idx, col) in (start.1..=end.1).enumerate() {
        let mut data = Vec::new();
        for row in start.0..=end.0 {
            let key = format!("{},{}", row, col);
            let val = worksheet
                .data
                .get(&key)
                .and_then(|c| c.value.clone())
                .unwrap_or_default();
            data.push(val.parse::<f64>().unwrap_or(0.0));
        }
        let color = colors[col_idx % colors.len()].to_string();
        datasets.push(ChartDataset {
            label: format!("Series {}", col_idx + 1),
            data,
            color: color.clone(),
            background_color: Some(color),
        });
    }

    let chart = ChartConfig {
        id: chart_id.clone(),
        chart_type: req.chart_type.clone(),
        title: req.title.clone().unwrap_or_else(|| "Chart".to_string()),
        data_range: req.data_range.clone(),
        label_range: req.label_range.clone(),
        // unwrap_or_else: only build the default position when none was sent.
        position: req.position.clone().unwrap_or_else(|| ChartPosition {
            row: 0,
            col: end.1 + 2,
            width: 400,
            height: 300,
        }),
        options: ChartOptions {
            show_legend: true,
            show_grid: true,
            stacked: false,
            legend_position: Some("bottom".to_string()),
            x_axis_title: None,
            y_axis_title: None,
        },
        datasets: datasets.clone(),
        labels: labels.clone(),
    };

    sheet.worksheets[req.worksheet_index]
        .charts
        .get_or_insert_with(Vec::new)
        .push(chart);

    sheet.updated_at = Utc::now();
    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    Ok(Json(serde_json::json!({
        "success": true,
        "chart_id": chart_id,
        "chart_type": req.chart_type,
        "labels": labels,
        "datasets": datasets
    })))
}

/// POST /chart/delete — remove a chart by id from the worksheet.
pub async fn handle_delete_chart(
    State(state): State<Arc<AppState>>,
    Json(req): Json<DeleteChartRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    if let Some(ref mut charts) = sheet.worksheets[req.worksheet_index].charts {
        charts.retain(|c| c.id != req.chart_id);
    }

    sheet.updated_at = Utc::now();
    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    Ok(Json(SaveResponse {
        id: req.sheet_id,
        success: true,
        message: Some("Chart deleted".to_string()),
    }))
}

/// POST /conditional-format — store a rule and eagerly apply its style to the
/// cells that currently match. NOTE(review): styles are applied destructively;
/// deleting the rule later will not restore the prior styles.
pub async fn handle_conditional_format(
    State(state): State<Arc<AppState>>,
    Json(req): Json<ConditionalFormatRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &mut sheet.worksheets[req.worksheet_index];
    worksheet
        .conditional_formats
        .get_or_insert_with(Vec::new)
        .push(ConditionalFormatRule {
            id: Uuid::new_v4().to_string(),
            start_row: req.start_row,
            start_col: req.start_col,
            end_row: req.end_row,
            end_col: req.end_col,
            rule_type: req.rule_type.clone(),
            condition: req.condition.clone(),
            style: req.style.clone(),
            priority: 0,
        });

    for row in req.start_row..=req.end_row {
        for col in req.start_col..=req.end_col {
            let key = format!("{},{}", row, col);
            let cell_value = worksheet
                .data
                .get(&key)
                .and_then(|c| c.value.clone())
                .unwrap_or_default();

            let should_apply = match req.rule_type.as_str() {
                "greaterThan" => matches!(
                    (cell_value.parse::<f64>(), req.condition.parse::<f64>()),
                    (Ok(v), Ok(c)) if v > c
                ),
                "lessThan" => matches!(
                    (cell_value.parse::<f64>(), req.condition.parse::<f64>()),
                    (Ok(v), Ok(c)) if v < c
                ),
                "equals" => cell_value == req.condition,
                "notEquals" => cell_value != req.condition,
                "contains" => cell_value.to_lowercase().contains(&req.condition.to_lowercase()),
                "notContains" => !cell_value.to_lowercase().contains(&req.condition.to_lowercase()),
                "isEmpty" => cell_value.is_empty(),
                "isNotEmpty" => !cell_value.is_empty(),
                "between" => {
                    // The condition is encoded as "min,max".
                    let parts: Vec<&str> = req.condition.split(',').collect();
                    if parts.len() == 2 {
                        matches!(
                            (
                                cell_value.parse::<f64>(),
                                parts[0].trim().parse::<f64>(),
                                parts[1].trim().parse::<f64>(),
                            ),
                            (Ok(v), Ok(min), Ok(max)) if v >= min && v <= max
                        )
                    } else {
                        false
                    }
                }
                _ => false,
            };

            if should_apply {
                let cell = worksheet.data.entry(key).or_insert_with(|| CellData {
                    value: None,
                    formula: None,
                    style: None,
                    format: None,
                    note: None,
                });
                cell.style = Some(req.style.clone());
            }
        }
    }

    sheet.updated_at = Utc::now();
    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    Ok(Json(SaveResponse {
        id: req.sheet_id,
        success: true,
        message: Some("Conditional formatting applied".to_string()),
    }))
}

/// POST /data-validation — attach the same validation rule to every cell of a range.
pub async fn handle_data_validation(
    State(state): State<Arc<AppState>>,
    Json(req): Json<DataValidationRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &mut sheet.worksheets[req.worksheet_index];
    let validations = worksheet.validations.get_or_insert_with(HashMap::new);

    let rule = ValidationRule {
        validation_type: req.validation_type.clone(),
        operator: req.operator.clone(),
        value1: req.value1.clone(),
        value2: req.value2.clone(),
        allowed_values: req.allowed_values.clone(),
        error_title: Some("Validation Error".to_string()),
        error_message: req.error_message.clone(),
        input_title: None,
        input_message: None,
    };

    for row in req.start_row..=req.end_row {
        for col in req.start_col..=req.end_col {
            validations.insert(format!("{},{}", row, col), rule.clone());
        }
    }

    sheet.updated_at = Utc::now();
    if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await {
        return Err(json_error(StatusCode::INTERNAL_SERVER_ERROR, e));
    }

    Ok(Json(SaveResponse {
        id: req.sheet_id,
        success: true,
        message: Some("Data validation applied".to_string()),
    }))
}

/// POST /validate — check a prospective cell value against its validation rule
/// (cells without a rule are always valid).
pub async fn handle_validate_cell(
    State(state): State<Arc<AppState>>,
    Json(req): Json<ValidateCellRequest>,
) -> Result<Json<ValidationResult>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
        return Err(json_error(StatusCode::BAD_REQUEST, "Invalid worksheet index"));
    }

    let worksheet = &sheet.worksheets[req.worksheet_index];
    let key = format!("{},{}", req.row, req.col);

    match worksheet.validations.as_ref().and_then(|v| v.get(&key)) {
        Some(rule) => {
            let (valid, error_msg) = validate_value(&req.value, rule);
            Ok(Json(ValidationResult {
                valid,
                error_message: if valid { None } else { error_msg },
            }))
        }
        None => Ok(Json(ValidationResult { valid: true, error_message: None })),
    }
}

/// Check `value` against one validation rule; returns (is_valid, error message).
/// A missing bound never fails validation (`unwrap_or(true)` throughout).
fn validate_value(value: &str, rule: &ValidationRule) -> (bool, Option<String>) {
    let error_msg = rule
        .error_message
        .clone()
        .unwrap_or_else(|| "Invalid value".to_string());

    match rule.validation_type.as_str() {
        "list" => {
            if let Some(ref allowed) = rule.allowed_values {
                (allowed.iter().any(|v| v == value), Some(error_msg))
            } else {
                (true, None)
            }
        }
        "number" => {
            let num = match value.parse::<f64>() {
                Ok(n) => n,
                Err(_) => return (false, Some("Must be a number".to_string())),
            };
            let op = rule.operator.as_deref().unwrap_or("between");
            let v1 = rule.value1.as_ref().and_then(|v| v.parse::<f64>().ok());
            let v2 = rule.value2.as_ref().and_then(|v| v.parse::<f64>().ok());

            let valid = match op {
                "between" => v1.zip(v2).map(|(a, b)| num >= a && num <= b).unwrap_or(true),
                "notBetween" => v1.zip(v2).map(|(a, b)| num < a || num > b).unwrap_or(true),
                "greaterThan" => v1.map(|a| num > a).unwrap_or(true),
                "lessThan" => v1.map(|a| num < a).unwrap_or(true),
                "greaterThanOrEqual" => v1.map(|a| num >= a).unwrap_or(true),
                "lessThanOrEqual" => v1.map(|a| num <= a).unwrap_or(true),
                "equal" => v1.map(|a| (num - a).abs() < f64::EPSILON).unwrap_or(true),
                "notEqual" => v1.map(|a| (num - a).abs() >= f64::EPSILON).unwrap_or(true),
                _ => true,
            };
            (valid, Some(error_msg))
        }
        "textLength" => {
            let len = value.chars().count();
            let op = rule.operator.as_deref().unwrap_or("between");
            let v1 = rule.value1.as_ref().and_then(|v| v.parse::<usize>().ok());
            let v2 = rule.value2.as_ref().and_then(|v| v.parse::<usize>().ok());

            let valid = match op {
                "between" => v1.zip(v2).map(|(a, b)| len >= a && len <= b).unwrap_or(true),
                "greaterThan" => v1.map(|a| len > a).unwrap_or(true),
                "lessThan" => v1.map(|a| len < a).unwrap_or(true),
                "equal" => v1.map(|a| len == a).unwrap_or(true),
                _ => true,
            };
            (valid, Some(error_msg))
        }
        "date" => (
            NaiveDate::parse_from_str(value, "%Y-%m-%d").is_ok(),
            Some("Must be a valid date (YYYY-MM-DD)".to_string()),
        ),
        "custom" => {
            // NOTE(review): "custom" is currently exact string equality with
            // value1, not formula evaluation — confirm intended semantics.
            if let Some(ref formula) = rule.value1 {
                (value == formula, Some(error_msg))
            } else {
                (true, None)
            }
        }
        _ => (true, None),
    }
}

/// POST /note — attach a note to a cell.
pub async fn handle_add_note(
    State(state): State<Arc<AppState>>,
    Json(req): Json<AddNoteRequest>,
) -> Result<Json<SaveResponse>, (StatusCode, Json<serde_json::Value>)> {
    let user_id = get_current_user_id();
    let mut sheet = match load_sheet_from_drive(&state, &user_id, &req.sheet_id).await {
        Ok(s) => s,
        Err(e) => return Err(json_error(StatusCode::NOT_FOUND, e)),
    };

    if req.worksheet_index >= sheet.worksheets.len() {
return Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "Invalid worksheet index" })))); + } + + let worksheet = &mut sheet.worksheets[req.worksheet_index]; + let key = format!("{},{}", req.row, req.col); + + let cell = worksheet.data.entry(key).or_insert_with(|| CellData { + value: None, formula: None, style: None, format: None, note: None, + }); + cell.note = if req.note.is_empty() { None } else { Some(req.note.clone()) }; + + sheet.updated_at = Utc::now(); + if let Err(e) = save_sheet_to_drive(&state, &user_id, &sheet).await { + return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + Ok(Json(SaveResponse { id: req.sheet_id, success: true, message: Some("Note added".to_string()) })) +} + +pub async fn handle_import_sheet( + State(state): State>, + body: axum::body::Bytes, +) -> Result, (StatusCode, Json)> { + let content = String::from_utf8_lossy(&body); + let user_id = get_current_user_id(); + + let mut worksheet_data = HashMap::new(); + for (row_idx, line) in content.lines().enumerate() { + let cols: Vec<&str> = line.split(',').collect(); + for (col_idx, value) in cols.iter().enumerate() { + let clean_value = value.trim().trim_matches('"').to_string(); + if !clean_value.is_empty() { + let key = format!("{},{}", row_idx, col_idx); + worksheet_data.insert(key, CellData { + value: Some(clean_value), formula: None, style: None, format: None, note: None, + }); + } + } + } + + let sheet = Spreadsheet { + id: Uuid::new_v4().to_string(), + name: "Imported Spreadsheet".to_string(), + owner_id: user_id.clone(), + worksheets: vec![Worksheet { + name: "Sheet1".to_string(), + data: worksheet_data, + column_widths: None, row_heights: None, frozen_rows: None, frozen_cols: None, + merged_cells: None, filters: None, hidden_rows: None, validations: None, + conditional_formats: None, charts: None, + }], + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + if let Err(e) = save_sheet_to_drive(&state, &user_id, 
&sheet).await { + return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + Ok(Json(sheet)) +} + +pub async fn handle_get_collaborators( + Path(sheet_id): Path, +) -> impl IntoResponse { + let channels = get_collab_channels().read().await; + let active = channels.contains_key(&sheet_id); + Json(serde_json::json!({ "sheet_id": sheet_id, "collaborators": [], "active": active })) +} + +pub async fn handle_sheet_websocket( + ws: WebSocketUpgrade, + State(state): State>, + Path(sheet_id): Path, +) -> impl IntoResponse { + info!("Sheet WebSocket connection request for sheet: {}", sheet_id); + ws.on_upgrade(move |socket| handle_sheet_connection(socket, state, sheet_id)) +} + +async fn handle_sheet_connection(socket: WebSocket, _state: Arc, sheet_id: String) { + let (mut sender, mut receiver) = socket.split(); + + let channels = get_collab_channels(); + let rx = { + let mut channels_write = channels.write().await; + let tx = channels_write.entry(sheet_id.clone()).or_insert_with(|| broadcast::channel(256).0); + tx.subscribe() + }; + + let user_id = format!("user-{}", &Uuid::new_v4().to_string()[..8]); + let user_color = get_random_color(); + + let welcome = serde_json::json!({ + "type": "connected", + "sheet_id": sheet_id, + "user_id": user_id, + "user_color": user_color, + "timestamp": Utc::now().to_rfc3339() + }); + + if sender.send(Message::Text(welcome.to_string())).await.is_err() { + error!("Failed to send welcome message"); + return; + } + + info!("User {} connected to sheet {}", user_id, sheet_id); + broadcast_sheet_change(&sheet_id, "userJoined", &user_id, None, None, Some(&user_color)).await; + + let sheet_id_recv = sheet_id.clone(); + let user_id_recv = user_id.clone(); + let user_id_send = user_id.clone(); + + let mut rx = rx; + let send_task = tokio::spawn(async move { + while let Ok(msg) = rx.recv().await { + if msg.user_id != user_id_send { + if let Ok(json) = serde_json::to_string(&msg) { + if 
sender.send(Message::Text(json)).await.is_err() { + break; + } + } + } + } + }); + + let recv_task = tokio::spawn(async move { + while let Some(Ok(msg)) = receiver.next().await { + match msg { + Message::Text(text) => { + if let Ok(parsed) = serde_json::from_str::(&text) { + let msg_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or(""); + match msg_type { + "cellChange" => { + let row = parsed.get("row").and_then(|v| v.as_u64()).map(|v| v as u32); + let col = parsed.get("col").and_then(|v| v.as_u64()).map(|v| v as u32); + let value = parsed.get("value").and_then(|v| v.as_str()).map(String::from); + broadcast_sheet_change(&sheet_id_recv, "cellChange", &user_id_recv, row, col, value.as_deref()).await; + } + "cursor" => { + let row = parsed.get("row").and_then(|v| v.as_u64()).map(|v| v as u32); + let col = parsed.get("col").and_then(|v| v.as_u64()).map(|v| v as u32); + broadcast_sheet_change(&sheet_id_recv, "cursor", &user_id_recv, row, col, None).await; + } + _ => {} + } + } + } + Message::Close(_) => break, + _ => {} + } + } + }); + + tokio::select! 
{ + _ = send_task => {}, + _ = recv_task => {}, + } + + broadcast_sheet_change(&sheet_id, "userLeft", &user_id, None, None, None).await; + info!("User {} disconnected from sheet {}", user_id, sheet_id); +} + +async fn broadcast_sheet_change( + sheet_id: &str, + msg_type: &str, + user_id: &str, + row: Option, + col: Option, + value: Option<&str>, +) { + let channels = get_collab_channels().read().await; + if let Some(tx) = channels.get(sheet_id) { + let msg = CollabMessage { + msg_type: msg_type.to_string(), + sheet_id: sheet_id.to_string(), + user_id: user_id.to_string(), + user_name: format!("User {}", &user_id[..8.min(user_id.len())]), + user_color: get_random_color(), + row, + col, + value: value.map(String::from), + worksheet_index: None, + timestamp: Utc::now(), + }; + let _ = tx.send(msg); + } +} + +fn get_random_color() -> String { + let colors = [ + "#3b82f6", "#ef4444", "#22c55e", "#f59e0b", "#8b5cf6", + "#ec4899", "#14b8a6", "#f97316", "#6366f1", "#84cc16", + ]; + let idx = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos() as usize % colors.len()) + .unwrap_or(0); + colors[idx].to_string() +} diff --git a/src/slides/mod.rs b/src/slides/mod.rs new file mode 100644 index 000000000..35fe107ce --- /dev/null +++ b/src/slides/mod.rs @@ -0,0 +1,1360 @@ +use crate::shared::state::AppState; +use axum::{ + extract::{ + ws::{Message, WebSocket, WebSocketUpgrade}, + Path, Query, State, + }, + http::StatusCode, + response::IntoResponse, + routing::{get, post}, + Json, Router, +}; +use chrono::{DateTime, Utc}; +use futures_util::{SinkExt, StreamExt}; +use log::{error, info}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::broadcast; +use uuid::Uuid; + +type SlideChannels = Arc>>>; + +static SLIDE_CHANNELS: std::sync::OnceLock = std::sync::OnceLock::new(); + +fn get_slide_channels() -> &'static SlideChannels { + SLIDE_CHANNELS.get_or_init(|| 
Arc::new(tokio::sync::RwLock::new(HashMap::new()))) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SlideMessage { + pub msg_type: String, + pub presentation_id: String, + pub user_id: String, + pub user_name: String, + pub user_color: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub slide_index: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub element_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, + pub timestamp: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Presentation { + pub id: String, + pub name: String, + pub owner_id: String, + pub slides: Vec, + pub theme: PresentationTheme, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Slide { + pub id: String, + pub layout: String, + pub elements: Vec, + pub background: SlideBackground, + #[serde(skip_serializing_if = "Option::is_none")] + pub notes: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub transition: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SlideElement { + pub id: String, + pub element_type: String, + pub x: f64, + pub y: f64, + pub width: f64, + pub height: f64, + #[serde(default)] + pub rotation: f64, + pub content: ElementContent, + pub style: ElementStyle, + #[serde(default)] + pub animations: Vec, + #[serde(default)] + pub z_index: i32, + #[serde(default)] + pub locked: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ElementContent { + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub html: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub src: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub shape_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub chart_data: Option, + #[serde(skip_serializing_if = 
"Option::is_none")] + pub table_data: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ElementStyle { + #[serde(skip_serializing_if = "Option::is_none")] + pub fill: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stroke: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stroke_width: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub opacity: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub shadow: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub font_family: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub font_size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub font_weight: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub font_style: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub text_align: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub vertical_align: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub color: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub line_height: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub border_radius: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ShadowStyle { + pub color: String, + pub blur: f64, + pub offset_x: f64, + pub offset_y: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SlideBackground { + pub bg_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub color: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub gradient: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_fit: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GradientStyle { + pub gradient_type: String, + pub angle: f64, + pub stops: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct 
GradientStop { + pub color: String, + pub position: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SlideTransition { + pub transition_type: String, + pub duration: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub direction: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Animation { + pub animation_type: String, + pub trigger: String, + pub duration: f64, + pub delay: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub direction: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PresentationTheme { + pub name: String, + pub colors: ThemeColors, + pub fonts: ThemeFonts, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThemeColors { + pub primary: String, + pub secondary: String, + pub accent: String, + pub background: String, + pub text: String, + pub text_light: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThemeFonts { + pub heading: String, + pub body: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChartData { + pub chart_type: String, + pub labels: Vec, + pub datasets: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChartDataset { + pub label: String, + pub data: Vec, + pub color: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TableData { + pub rows: usize, + pub cols: usize, + pub cells: HashMap, + pub col_widths: Vec, + pub row_heights: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TableCell { + pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub colspan: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub rowspan: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub style: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PresentationMetadata { + pub id: String, + pub name: String, + pub owner_id: String, + pub slide_count: usize, + pub 
created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SavePresentationRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + pub name: String, + pub slides: Vec, + pub theme: PresentationTheme, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoadQuery { + pub id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchQuery { + pub q: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AddSlideRequest { + pub presentation_id: String, + pub layout: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub position: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeleteSlideRequest { + pub presentation_id: String, + pub slide_index: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DuplicateSlideRequest { + pub presentation_id: String, + pub slide_index: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReorderSlidesRequest { + pub presentation_id: String, + pub slide_order: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AddElementRequest { + pub presentation_id: String, + pub slide_index: usize, + pub element: SlideElement, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpdateElementRequest { + pub presentation_id: String, + pub slide_index: usize, + pub element: SlideElement, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeleteElementRequest { + pub presentation_id: String, + pub slide_index: usize, + pub element_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ApplyThemeRequest { + pub presentation_id: String, + pub theme: PresentationTheme, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpdateSlideNotesRequest { + pub presentation_id: String, + pub slide_index: usize, + pub notes: String, +} + +#[derive(Debug, Clone, 
Serialize, Deserialize)] +pub struct ExportRequest { + pub id: String, + pub format: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SaveResponse { + pub id: String, + pub success: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +pub fn configure_slides_routes() -> Router> { + Router::new() + .route("/api/slides/list", get(handle_list_presentations)) + .route("/api/slides/search", get(handle_search_presentations)) + .route("/api/slides/load", get(handle_load_presentation)) + .route("/api/slides/save", post(handle_save_presentation)) + .route("/api/slides/delete", post(handle_delete_presentation)) + .route("/api/slides/new", get(handle_new_presentation)) + .route("/api/slides/slide/add", post(handle_add_slide)) + .route("/api/slides/slide/delete", post(handle_delete_slide)) + .route("/api/slides/slide/duplicate", post(handle_duplicate_slide)) + .route("/api/slides/slide/reorder", post(handle_reorder_slides)) + .route("/api/slides/slide/notes", post(handle_update_slide_notes)) + .route("/api/slides/element/add", post(handle_add_element)) + .route("/api/slides/element/update", post(handle_update_element)) + .route("/api/slides/element/delete", post(handle_delete_element)) + .route("/api/slides/theme/apply", post(handle_apply_theme)) + .route("/api/slides/export", post(handle_export_presentation)) + .route("/api/slides/:id", get(handle_get_presentation_by_id)) + .route("/api/slides/:id/collaborators", get(handle_get_collaborators)) + .route("/ws/slides/:presentation_id", get(handle_slides_websocket)) +} + +fn get_user_presentations_path(user_id: &str) -> String { + format!("users/{}/presentations", user_id) +} + +fn get_current_user_id() -> String { + "default-user".to_string() +} + +async fn save_presentation_to_drive( + state: &Arc, + user_id: &str, + presentation: &Presentation, +) -> Result<(), String> { + let drive = state + .drive + .as_ref() + .ok_or_else(|| "Drive not available".to_string())?; + + 
let path = format!( + "{}/{}.json", + get_user_presentations_path(user_id), + presentation.id + ); + let content = serde_json::to_string_pretty(presentation) + .map_err(|e| format!("Serialization error: {e}"))?; + + drive + .put_object() + .bucket("gbo") + .key(&path) + .body(content.into_bytes().into()) + .content_type("application/json") + .send() + .await + .map_err(|e| format!("Failed to save presentation: {e}"))?; + + Ok(()) +} + +async fn load_presentation_from_drive( + state: &Arc, + user_id: &str, + presentation_id: &str, +) -> Result { + let drive = state + .drive + .as_ref() + .ok_or_else(|| "Drive not available".to_string())?; + + let path = format!( + "{}/{}.json", + get_user_presentations_path(user_id), + presentation_id + ); + + let result = drive + .get_object() + .bucket("gbo") + .key(&path) + .send() + .await + .map_err(|e| format!("Failed to load presentation: {e}"))?; + + let bytes = result + .body + .collect() + .await + .map_err(|e| format!("Failed to read presentation: {e}"))? 
+ .into_bytes(); + + let presentation: Presentation = serde_json::from_slice(&bytes) + .map_err(|e| format!("Failed to parse presentation: {e}"))?; + + Ok(presentation) +} + +async fn list_presentations_from_drive( + state: &Arc, + user_id: &str, +) -> Result, String> { + let drive = state + .drive + .as_ref() + .ok_or_else(|| "Drive not available".to_string())?; + + let prefix = format!("{}/", get_user_presentations_path(user_id)); + + let result = drive + .list_objects_v2() + .bucket("gbo") + .prefix(&prefix) + .send() + .await + .map_err(|e| format!("Failed to list presentations: {e}"))?; + + let mut presentations = Vec::new(); + + if let Some(contents) = result.contents { + for obj in contents { + if let Some(key) = obj.key { + if key.ends_with(".json") { + let id = key + .split('/') + .last() + .unwrap_or("") + .trim_end_matches(".json") + .to_string(); + if let Ok(pres) = load_presentation_from_drive(state, user_id, &id).await { + presentations.push(PresentationMetadata { + id: pres.id, + name: pres.name, + owner_id: pres.owner_id, + slide_count: pres.slides.len(), + created_at: pres.created_at, + updated_at: pres.updated_at, + }); + } + } + } + } + } + + presentations.sort_by(|a, b| b.updated_at.cmp(&a.updated_at)); + Ok(presentations) +} + +async fn delete_presentation_from_drive( + state: &Arc, + user_id: &str, + presentation_id: &str, +) -> Result<(), String> { + let drive = state + .drive + .as_ref() + .ok_or_else(|| "Drive not available".to_string())?; + + let path = format!( + "{}/{}.json", + get_user_presentations_path(user_id), + presentation_id + ); + + drive + .delete_object() + .bucket("gbo") + .key(&path) + .send() + .await + .map_err(|e| format!("Failed to delete presentation: {e}"))?; + + Ok(()) +} + +fn create_default_theme() -> PresentationTheme { + PresentationTheme { + name: "Default".to_string(), + colors: ThemeColors { + primary: "#3b82f6".to_string(), + secondary: "#64748b".to_string(), + accent: "#f59e0b".to_string(), + background: 
"#ffffff".to_string(), + text: "#1e293b".to_string(), + text_light: "#64748b".to_string(), + }, + fonts: ThemeFonts { + heading: "Inter".to_string(), + body: "Inter".to_string(), + }, + } +} + +fn create_title_slide() -> Slide { + Slide { + id: Uuid::new_v4().to_string(), + layout: "title".to_string(), + elements: vec![ + SlideElement { + id: Uuid::new_v4().to_string(), + element_type: "text".to_string(), + x: 100.0, + y: 200.0, + width: 760.0, + height: 100.0, + rotation: 0.0, + content: ElementContent { + text: Some("Presentation Title".to_string()), + html: None, + src: None, + shape_type: None, + chart_data: None, + table_data: None, + }, + style: ElementStyle { + font_size: Some(48.0), + font_weight: Some("bold".to_string()), + text_align: Some("center".to_string()), + color: Some("#1e293b".to_string()), + ..Default::default() + }, + animations: vec![], + z_index: 1, + locked: false, + }, + SlideElement { + id: Uuid::new_v4().to_string(), + element_type: "text".to_string(), + x: 100.0, + y: 320.0, + width: 760.0, + height: 50.0, + rotation: 0.0, + content: ElementContent { + text: Some("Subtitle or Author Name".to_string()), + html: None, + src: None, + shape_type: None, + chart_data: None, + table_data: None, + }, + style: ElementStyle { + font_size: Some(24.0), + text_align: Some("center".to_string()), + color: Some("#64748b".to_string()), + ..Default::default() + }, + animations: vec![], + z_index: 2, + locked: false, + }, + ], + background: SlideBackground { + bg_type: "solid".to_string(), + color: Some("#ffffff".to_string()), + gradient: None, + image_url: None, + image_fit: None, + }, + notes: None, + transition: Some(SlideTransition { + transition_type: "fade".to_string(), + duration: 0.5, + direction: None, + }), + } +} + +fn create_content_slide(layout: &str) -> Slide { + let elements = match layout { + "title-content" => vec![ + SlideElement { + id: Uuid::new_v4().to_string(), + element_type: "text".to_string(), + x: 50.0, + y: 40.0, + width: 860.0, 
+ height: 60.0, + rotation: 0.0, + content: ElementContent { + text: Some("Slide Title".to_string()), + ..Default::default() + }, + style: ElementStyle { + font_size: Some(36.0), + font_weight: Some("bold".to_string()), + color: Some("#1e293b".to_string()), + ..Default::default() + }, + animations: vec![], + z_index: 1, + locked: false, + }, + SlideElement { + id: Uuid::new_v4().to_string(), + element_type: "text".to_string(), + x: 50.0, + y: 120.0, + width: 860.0, + height: 400.0, + rotation: 0.0, + content: ElementContent { + text: Some("• Click to add content\n• Add your bullet points here".to_string()), + ..Default::default() + }, + style: ElementStyle { + font_size: Some(20.0), + color: Some("#374151".to_string()), + line_height: Some(1.6), + ..Default::default() + }, + animations: vec![], + z_index: 2, + locked: false, + }, + ], + "two-column" => vec![ + SlideElement { + id: Uuid::new_v4().to_string(), + element_type: "text".to_string(), + x: 50.0, + y: 40.0, + width: 860.0, + height: 60.0, + rotation: 0.0, + content: ElementContent { + text: Some("Slide Title".to_string()), + ..Default::default() + }, + style: ElementStyle { + font_size: Some(36.0), + font_weight: Some("bold".to_string()), + color: Some("#1e293b".to_string()), + ..Default::default() + }, + animations: vec![], + z_index: 1, + locked: false, + }, + SlideElement { + id: Uuid::new_v4().to_string(), + element_type: "text".to_string(), + x: 50.0, + y: 120.0, + width: 410.0, + height: 400.0, + rotation: 0.0, + content: ElementContent { + text: Some("Left column content".to_string()), + ..Default::default() + }, + style: ElementStyle { + font_size: Some(18.0), + color: Some("#374151".to_string()), + ..Default::default() + }, + animations: vec![], + z_index: 2, + locked: false, + }, + SlideElement { + id: Uuid::new_v4().to_string(), + element_type: "text".to_string(), + x: 500.0, + y: 120.0, + width: 410.0, + height: 400.0, + rotation: 0.0, + content: ElementContent { + text: Some("Right column 
content".to_string()), + ..Default::default() + }, + style: ElementStyle { + font_size: Some(18.0), + color: Some("#374151".to_string()), + ..Default::default() + }, + animations: vec![], + z_index: 3, + locked: false, + }, + ], + "section" => vec![SlideElement { + id: Uuid::new_v4().to_string(), + element_type: "text".to_string(), + x: 100.0, + y: 220.0, + width: 760.0, + height: 100.0, + rotation: 0.0, + content: ElementContent { + text: Some("Section Title".to_string()), + ..Default::default() + }, + style: ElementStyle { + font_size: Some(48.0), + font_weight: Some("bold".to_string()), + text_align: Some("center".to_string()), + color: Some("#1e293b".to_string()), + ..Default::default() + }, + animations: vec![], + z_index: 1, + locked: false, + }], + _ => vec![], + }; + + Slide { + id: Uuid::new_v4().to_string(), + layout: layout.to_string(), + elements, + background: SlideBackground { + bg_type: "solid".to_string(), + color: Some("#ffffff".to_string()), + gradient: None, + image_url: None, + image_fit: None, + }, + notes: None, + transition: Some(SlideTransition { + transition_type: "fade".to_string(), + duration: 0.5, + direction: None, + }), + } +} + +pub async fn handle_new_presentation( + State(_state): State>, +) -> Result, (StatusCode, Json)> { + let presentation = Presentation { + id: Uuid::new_v4().to_string(), + name: "Untitled Presentation".to_string(), + owner_id: get_current_user_id(), + slides: vec![create_title_slide()], + theme: create_default_theme(), + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + Ok(Json(presentation)) +} + +pub async fn handle_list_presentations( + State(state): State>, +) -> Result>, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + match list_presentations_from_drive(&state, &user_id).await { + Ok(presentations) => Ok(Json(presentations)), + Err(e) => { + error!("Failed to list presentations: {}", e); + Ok(Json(Vec::new())) + } + } +} + +pub async fn handle_search_presentations( + State(state): 
State>, + Query(query): Query, +) -> Result>, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + let presentations = match list_presentations_from_drive(&state, &user_id).await { + Ok(p) => p, + Err(_) => Vec::new(), + }; + + let filtered = if let Some(q) = query.q { + let q_lower = q.to_lowercase(); + presentations + .into_iter() + .filter(|p| p.name.to_lowercase().contains(&q_lower)) + .collect() + } else { + presentations + }; + + Ok(Json(filtered)) +} + +pub async fn handle_load_presentation( + State(state): State>, + Query(query): Query, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + match load_presentation_from_drive(&state, &user_id, &query.id).await { + Ok(presentation) => Ok(Json(presentation)), + Err(e) => Err(( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ "error": e })), + )), + } +} + +pub async fn handle_save_presentation( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let presentation_id = req.id.unwrap_or_else(|| Uuid::new_v4().to_string()); + + let presentation = Presentation { + id: presentation_id.clone(), + name: req.name, + owner_id: user_id.clone(), + slides: req.slides, + theme: req.theme, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ "error": e })), + )); + } + + Ok(Json(SaveResponse { + id: presentation_id, + success: true, + message: Some("Presentation saved".to_string()), + })) +} + +pub async fn handle_delete_presentation( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + + if let Err(e) = delete_presentation_from_drive(&state, &user_id, &req.id).await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ "error": e })), + )); + } + + 
Ok(Json(SaveResponse { + id: req.id, + success: true, + message: Some("Presentation deleted".to_string()), + })) +} + +pub async fn handle_get_presentation_by_id( + State(state): State>, + Path(presentation_id): Path, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + match load_presentation_from_drive(&state, &user_id, &presentation_id).await { + Ok(presentation) => Ok(Json(presentation)), + Err(e) => Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + } +} + +pub async fn handle_add_slide( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let mut presentation = match load_presentation_from_drive(&state, &user_id, &req.presentation_id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + let new_slide = create_content_slide(&req.layout); + let position = req.position.unwrap_or(presentation.slides.len()); + presentation.slides.insert(position.min(presentation.slides.len()), new_slide); + presentation.updated_at = Utc::now(); + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + broadcast_slide_change(&req.presentation_id, "slideAdded", &user_id, Some(position), None).await; + Ok(Json(SaveResponse { id: req.presentation_id, success: true, message: Some("Slide added".to_string()) })) +} + +pub async fn handle_delete_slide( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let mut presentation = match load_presentation_from_drive(&state, &user_id, &req.presentation_id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + if req.slide_index >= presentation.slides.len() { + return Err((StatusCode::BAD_REQUEST, 
Json(serde_json::json!({ "error": "Invalid slide index" })))); + } + + presentation.slides.remove(req.slide_index); + presentation.updated_at = Utc::now(); + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + broadcast_slide_change(&req.presentation_id, "slideDeleted", &user_id, Some(req.slide_index), None).await; + Ok(Json(SaveResponse { id: req.presentation_id, success: true, message: Some("Slide deleted".to_string()) })) +} + +pub async fn handle_duplicate_slide( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let mut presentation = match load_presentation_from_drive(&state, &user_id, &req.presentation_id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + if req.slide_index >= presentation.slides.len() { + return Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "Invalid slide index" })))); + } + + let mut duplicated = presentation.slides[req.slide_index].clone(); + duplicated.id = Uuid::new_v4().to_string(); + for element in &mut duplicated.elements { + element.id = Uuid::new_v4().to_string(); + } + presentation.slides.insert(req.slide_index + 1, duplicated); + presentation.updated_at = Utc::now(); + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + broadcast_slide_change(&req.presentation_id, "slideDuplicated", &user_id, Some(req.slide_index), None).await; + Ok(Json(SaveResponse { id: req.presentation_id, success: true, message: Some("Slide duplicated".to_string()) })) +} + +pub async fn handle_reorder_slides( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let mut 
presentation = match load_presentation_from_drive(&state, &user_id, &req.presentation_id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + let mut new_slides = Vec::new(); + for slide_id in &req.slide_order { + if let Some(slide) = presentation.slides.iter().find(|s| &s.id == slide_id) { + new_slides.push(slide.clone()); + } + } + + if new_slides.len() == presentation.slides.len() { + presentation.slides = new_slides; + presentation.updated_at = Utc::now(); + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + } + + broadcast_slide_change(&req.presentation_id, "slidesReordered", &user_id, None, None).await; + Ok(Json(SaveResponse { id: req.presentation_id, success: true, message: Some("Slides reordered".to_string()) })) +} + +pub async fn handle_update_slide_notes( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let mut presentation = match load_presentation_from_drive(&state, &user_id, &req.presentation_id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + if req.slide_index >= presentation.slides.len() { + return Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "Invalid slide index" })))); + } + + presentation.slides[req.slide_index].notes = if req.notes.is_empty() { None } else { Some(req.notes) }; + presentation.updated_at = Utc::now(); + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + Ok(Json(SaveResponse { id: req.presentation_id, success: true, message: Some("Notes updated".to_string()) })) +} + +pub async fn handle_add_element( + State(state): State>, + Json(req): 
Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let mut presentation = match load_presentation_from_drive(&state, &user_id, &req.presentation_id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + if req.slide_index >= presentation.slides.len() { + return Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "Invalid slide index" })))); + } + + presentation.slides[req.slide_index].elements.push(req.element.clone()); + presentation.updated_at = Utc::now(); + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + broadcast_slide_change(&req.presentation_id, "elementAdded", &user_id, Some(req.slide_index), Some(&req.element.id)).await; + Ok(Json(SaveResponse { id: req.presentation_id, success: true, message: Some("Element added".to_string()) })) +} + +pub async fn handle_update_element( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let mut presentation = match load_presentation_from_drive(&state, &user_id, &req.presentation_id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + if req.slide_index >= presentation.slides.len() { + return Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "Invalid slide index" })))); + } + + let slide = &mut presentation.slides[req.slide_index]; + if let Some(pos) = slide.elements.iter().position(|e| e.id == req.element.id) { + slide.elements[pos] = req.element.clone(); + } else { + return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": "Element not found" })))); + } + + presentation.updated_at = Utc::now(); + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return 
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + broadcast_slide_change(&req.presentation_id, "elementUpdated", &user_id, Some(req.slide_index), Some(&req.element.id)).await; + Ok(Json(SaveResponse { id: req.presentation_id, success: true, message: Some("Element updated".to_string()) })) +} + +pub async fn handle_delete_element( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let mut presentation = match load_presentation_from_drive(&state, &user_id, &req.presentation_id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + if req.slide_index >= presentation.slides.len() { + return Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "Invalid slide index" })))); + } + + let slide = &mut presentation.slides[req.slide_index]; + slide.elements.retain(|e| e.id != req.element_id); + presentation.updated_at = Utc::now(); + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + broadcast_slide_change(&req.presentation_id, "elementDeleted", &user_id, Some(req.slide_index), Some(&req.element_id)).await; + Ok(Json(SaveResponse { id: req.presentation_id, success: true, message: Some("Element deleted".to_string()) })) +} + +pub async fn handle_apply_theme( + State(state): State>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_id = get_current_user_id(); + let mut presentation = match load_presentation_from_drive(&state, &user_id, &req.presentation_id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + presentation.theme = req.theme; + presentation.updated_at = Utc::now(); + + if let Err(e) = save_presentation_to_drive(&state, &user_id, &presentation).await { + return 
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })))); + } + + broadcast_slide_change(&req.presentation_id, "themeChanged", &user_id, None, None).await; + Ok(Json(SaveResponse { id: req.presentation_id, success: true, message: Some("Theme applied".to_string()) })) +} + +pub async fn handle_export_presentation( + State(state): State>, + Json(req): Json, +) -> Result)> { + let user_id = get_current_user_id(); + let presentation = match load_presentation_from_drive(&state, &user_id, &req.id).await { + Ok(p) => p, + Err(e) => return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": e })))), + }; + + match req.format.as_str() { + "json" => { + let json = serde_json::to_string_pretty(&presentation).unwrap_or_default(); + Ok(([(axum::http::header::CONTENT_TYPE, "application/json")], json)) + } + "html" => { + let html = export_to_html(&presentation); + Ok(([(axum::http::header::CONTENT_TYPE, "text/html")], html)) + } + _ => Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": "Unsupported format" })))), + } +} + +fn export_to_html(presentation: &Presentation) -> String { + let mut html = format!( + r#" +{} +"#, + presentation.name + ); + + for (i, slide) in presentation.slides.iter().enumerate() { + let bg_color = slide.background.color.as_deref().unwrap_or("#ffffff"); + html.push_str(&format!( + r#"

    Slide {}

    "#, + bg_color, i + 1 + )); + + for element in &slide.elements { + let style = format!( + "left:{}px;top:{}px;width:{}px;height:{}px;transform:rotate({}deg);", + element.x, element.y, element.width, element.height, element.rotation + ); + let extra_style = format!( + "font-size:{}px;color:{};text-align:{};font-weight:{};", + element.style.font_size.unwrap_or(16.0), + element.style.color.as_deref().unwrap_or("#000"), + element.style.text_align.as_deref().unwrap_or("left"), + element.style.font_weight.as_deref().unwrap_or("normal") + ); + + match element.element_type.as_str() { + "text" => { + let text = element.content.text.as_deref().unwrap_or(""); + html.push_str(&format!( + r#"
    {}
    "#, + style, extra_style, text + )); + } + "image" => { + if let Some(ref src) = element.content.src { + html.push_str(&format!( + r#""#, + style, src + )); + } + } + "shape" => { + let fill = element.style.fill.as_deref().unwrap_or("#3b82f6"); + html.push_str(&format!( + r#"
    "#, + style, fill + )); + } + _ => {} + } + } + html.push_str("
    "); + } + + html.push_str(""); + html +} + +pub async fn handle_get_collaborators( + Path(presentation_id): Path, +) -> impl IntoResponse { + let channels = get_slide_channels().read().await; + let active = channels.contains_key(&presentation_id); + Json(serde_json::json!({ "presentation_id": presentation_id, "collaborators": [], "active": active })) +} + +pub async fn handle_slides_websocket( + ws: WebSocketUpgrade, + State(state): State>, + Path(presentation_id): Path, +) -> impl IntoResponse { + info!("Slides WebSocket connection request for presentation: {}", presentation_id); + ws.on_upgrade(move |socket| handle_slides_connection(socket, state, presentation_id)) +} + +async fn handle_slides_connection(socket: WebSocket, _state: Arc, presentation_id: String) { + let (mut sender, mut receiver) = socket.split(); + + let channels = get_slide_channels(); + let rx = { + let mut channels_write = channels.write().await; + let tx = channels_write.entry(presentation_id.clone()).or_insert_with(|| broadcast::channel(256).0); + tx.subscribe() + }; + + let user_id = format!("user-{}", &Uuid::new_v4().to_string()[..8]); + let user_color = get_random_color(); + + let welcome = serde_json::json!({ + "type": "connected", + "presentation_id": presentation_id, + "user_id": user_id, + "user_color": user_color, + "timestamp": Utc::now().to_rfc3339() + }); + + if sender.send(Message::Text(welcome.to_string())).await.is_err() { + error!("Failed to send welcome message"); + return; + } + + info!("User {} connected to presentation {}", user_id, presentation_id); + broadcast_slide_change(&presentation_id, "userJoined", &user_id, None, None).await; + + let presentation_id_recv = presentation_id.clone(); + let user_id_recv = user_id.clone(); + let user_id_send = user_id.clone(); + + let mut rx = rx; + let send_task = tokio::spawn(async move { + while let Ok(msg) = rx.recv().await { + if msg.user_id != user_id_send { + if let Ok(json) = serde_json::to_string(&msg) { + if 
sender.send(Message::Text(json)).await.is_err() { + break; + } + } + } + } + }); + + let recv_task = tokio::spawn(async move { + while let Some(Ok(msg)) = receiver.next().await { + match msg { + Message::Text(text) => { + if let Ok(parsed) = serde_json::from_str::(&text) { + let msg_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or(""); + let slide_index = parsed.get("slideIndex").and_then(|v| v.as_u64()).map(|v| v as usize); + let element_id = parsed.get("elementId").and_then(|v| v.as_str()).map(String::from); + + match msg_type { + "elementMove" | "elementResize" | "elementUpdate" | "slideChange" | "cursor" => { + broadcast_slide_change(&presentation_id_recv, msg_type, &user_id_recv, slide_index, element_id.as_deref()).await; + } + _ => {} + } + } + } + Message::Close(_) => break, + _ => {} + } + } + }); + + tokio::select! { + _ = send_task => {}, + _ = recv_task => {}, + } + + broadcast_slide_change(&presentation_id, "userLeft", &user_id, None, None).await; + info!("User {} disconnected from presentation {}", user_id, presentation_id); +} + +async fn broadcast_slide_change( + presentation_id: &str, + msg_type: &str, + user_id: &str, + slide_index: Option, + element_id: Option<&str>, +) { + let channels = get_slide_channels().read().await; + if let Some(tx) = channels.get(presentation_id) { + let msg = SlideMessage { + msg_type: msg_type.to_string(), + presentation_id: presentation_id.to_string(), + user_id: user_id.to_string(), + user_name: format!("User {}", &user_id[..8.min(user_id.len())]), + user_color: get_random_color(), + slide_index, + element_id: element_id.map(String::from), + data: None, + timestamp: Utc::now(), + }; + let _ = tx.send(msg); + } +} + +fn get_random_color() -> String { + let colors = [ + "#3b82f6", "#ef4444", "#22c55e", "#f59e0b", "#8b5cf6", + "#ec4899", "#14b8a6", "#f97316", "#6366f1", "#84cc16", + ]; + let idx = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos() as usize % 
colors.len()) + .unwrap_or(0); + colors[idx].to_string() +} diff --git a/src/sources/knowledge_base.rs b/src/sources/knowledge_base.rs index fdba40cda..5e022a072 100644 --- a/src/sources/knowledge_base.rs +++ b/src/sources/knowledge_base.rs @@ -2,7 +2,7 @@ use crate::shared::state::AppState; use axum::{ extract::{Multipart, Path, Query, State}, response::{Html, IntoResponse}, - routing::{delete, get, post}, + routing::{get, post}, Json, Router, }; use chrono::{DateTime, Utc}; @@ -235,8 +235,7 @@ pub fn configure_knowledge_base_routes() -> Router> { .route(ApiUrls::SOURCES_KB_UPLOAD, post(handle_upload_document)) .route(ApiUrls::SOURCES_KB_LIST, get(handle_list_sources)) .route(ApiUrls::SOURCES_KB_QUERY, post(handle_query_knowledge_base)) - .route(&ApiUrls::SOURCES_KB_BY_ID.replace(":id", "{id}"), get(handle_get_source)) - .route(&ApiUrls::SOURCES_KB_BY_ID.replace(":id", "{id}"), delete(handle_delete_source)) + .route(ApiUrls::SOURCES_KB_BY_ID, get(handle_get_source).delete(handle_delete_source)) .route(ApiUrls::SOURCES_KB_REINDEX, post(handle_reindex_sources)) .route(ApiUrls::SOURCES_KB_STATS, get(handle_get_stats)) } diff --git a/src/sources/mod.rs b/src/sources/mod.rs index 40b9fe77e..45436483d 100644 --- a/src/sources/mod.rs +++ b/src/sources/mod.rs @@ -9,7 +9,7 @@ use axum::{ extract::{Json, Path, Query, State}, http::StatusCode, response::{Html, IntoResponse}, - routing::{delete, get, post, put}, + routing::{get, post}, Router, }; use log::{error, info}; @@ -171,22 +171,11 @@ pub fn configure_sources_routes() -> Router> { .route(ApiUrls::SOURCES_APPS, get(handle_list_apps)) .route(ApiUrls::SOURCES_MCP, get(handle_list_mcp_servers_json)) .route(ApiUrls::SOURCES_MCP, post(handle_add_mcp_server)) - .route(&ApiUrls::SOURCES_MCP_BY_NAME.replace(":name", "{name}"), get(handle_get_mcp_server)) - .route(&ApiUrls::SOURCES_MCP_BY_NAME.replace(":name", "{name}"), put(handle_update_mcp_server)) - .route(&ApiUrls::SOURCES_MCP_BY_NAME.replace(":name", "{name}"), 
delete(handle_delete_mcp_server)) - .route( - &ApiUrls::SOURCES_MCP_ENABLE.replace(":name", "{name}"), - post(handle_enable_mcp_server), - ) - .route( - &ApiUrls::SOURCES_MCP_DISABLE.replace(":name", "{name}"), - post(handle_disable_mcp_server), - ) - .route( - &ApiUrls::SOURCES_MCP_TOOLS.replace(":name", "{name}"), - get(handle_list_mcp_server_tools), - ) - .route(&ApiUrls::SOURCES_MCP_TEST.replace(":name", "{name}"), post(handle_test_mcp_server)) + .route(ApiUrls::SOURCES_MCP_BY_NAME, get(handle_get_mcp_server).put(handle_update_mcp_server).delete(handle_delete_mcp_server)) + .route(ApiUrls::SOURCES_MCP_ENABLE, post(handle_enable_mcp_server)) + .route(ApiUrls::SOURCES_MCP_DISABLE, post(handle_disable_mcp_server)) + .route(ApiUrls::SOURCES_MCP_TOOLS, get(handle_list_mcp_server_tools)) + .route(ApiUrls::SOURCES_MCP_TEST, post(handle_test_mcp_server)) .route(ApiUrls::SOURCES_MCP_SCAN, post(handle_scan_mcp_directory)) .route(ApiUrls::SOURCES_MCP_EXAMPLES, get(handle_get_mcp_examples)) .route(ApiUrls::SOURCES_MENTIONS, get(handle_mentions_autocomplete)) @@ -989,15 +978,6 @@ fn load_mcp_servers_catalog() -> Option { } } -fn get_type_badge_class(server_type: &str) -> &'static str { - match server_type { - "Local" => "badge-local", - "Remote" => "badge-remote", - "Custom" => "badge-custom", - _ => "badge-default", - } -} - fn get_category_icon(category: &str) -> &'static str { match category { "Database" => "🗄️", @@ -1078,7 +1058,6 @@ pub async fn handle_mcp_servers( server.status, crate::basic::keywords::mcp_client::McpServerStatus::Active ); - let status_class = if is_active { "status-active" } else { "status-inactive" }; let status_text = if is_active { "Active" } else { "Inactive" }; let status_bg = if is_active { "#e8f5e9" } else { "#ffebee" }; diff --git a/src/tasks/mod.rs b/src/tasks/mod.rs index b98c7a76f..2246ddccf 100644 --- a/src/tasks/mod.rs +++ b/src/tasks/mod.rs @@ -2079,24 +2079,16 @@ pub fn configure_task_routes() -> Router> { .route(ApiUrls::TASKS_STATS, 
get(handle_task_stats_htmx)) .route(ApiUrls::TASKS_TIME_SAVED, get(handle_time_saved)) .route(ApiUrls::TASKS_COMPLETED, delete(handle_clear_completed)) - .route( - &ApiUrls::TASKS_GET_HTMX.replace(":id", "{id}"), - get(handle_task_get), - ) + .route(ApiUrls::TASKS_GET_HTMX, get(handle_task_get)) // JSON API - Stats .route(ApiUrls::TASKS_STATS_JSON, get(handle_task_stats)) // JSON API - Parameterized task routes - .route( - &ApiUrls::TASK_BY_ID.replace(":id", "{id}"), - put(handle_task_update) - .delete(handle_task_delete) - .patch(handle_task_patch), - ) - .route(&ApiUrls::TASK_ASSIGN.replace(":id", "{id}"), post(handle_task_assign)) - .route(&ApiUrls::TASK_STATUS.replace(":id", "{id}"), put(handle_task_status_update)) - .route(&ApiUrls::TASK_PRIORITY.replace(":id", "{id}"), put(handle_task_priority_set)) - .route("/api/tasks/{id}/dependencies", put(handle_task_set_dependencies)) - .route("/api/tasks/{id}/cancel", post(handle_task_cancel)) + .route(ApiUrls::TASK_BY_ID, put(handle_task_update).delete(handle_task_delete).patch(handle_task_patch)) + .route(ApiUrls::TASK_ASSIGN, post(handle_task_assign)) + .route(ApiUrls::TASK_STATUS, put(handle_task_status_update)) + .route(ApiUrls::TASK_PRIORITY, put(handle_task_priority_set)) + .route("/api/tasks/:id/dependencies", put(handle_task_set_dependencies)) + .route("/api/tasks/:id/cancel", post(handle_task_cancel)) } pub async fn handle_task_cancel( diff --git a/src/telegram/mod.rs b/src/telegram/mod.rs new file mode 100644 index 000000000..38d6f0f9e --- /dev/null +++ b/src/telegram/mod.rs @@ -0,0 +1,539 @@ +use crate::bot::BotOrchestrator; +use crate::core::bot::channels::telegram::TelegramAdapter; +use crate::core::bot::channels::ChannelAdapter; +use crate::shared::models::{BotResponse, UserSession}; +use crate::shared::state::{AppState, AttendantNotification}; +use axum::{ + extract::State, + http::StatusCode, + response::IntoResponse, + routing::post, + Json, Router, +}; + +use chrono::Utc; +use diesel::prelude::*; 
+use log::{debug, error, info}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use uuid::Uuid; + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramUpdate { + pub update_id: i64, + #[serde(default)] + pub message: Option, + #[serde(default)] + pub edited_message: Option, + #[serde(default)] + pub callback_query: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramMessage { + pub message_id: i64, + pub from: Option, + pub chat: TelegramChat, + pub date: i64, + #[serde(default)] + pub text: Option, + #[serde(default)] + pub photo: Option>, + #[serde(default)] + pub document: Option, + #[serde(default)] + pub voice: Option, + #[serde(default)] + pub audio: Option, + #[serde(default)] + pub video: Option, + #[serde(default)] + pub location: Option, + #[serde(default)] + pub contact: Option, + #[serde(default)] + pub caption: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramUser { + pub id: i64, + pub is_bot: bool, + pub first_name: String, + #[serde(default)] + pub last_name: Option, + #[serde(default)] + pub username: Option, + #[serde(default)] + pub language_code: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramChat { + pub id: i64, + #[serde(rename = "type")] + pub chat_type: String, + #[serde(default)] + pub title: Option, + #[serde(default)] + pub username: Option, + #[serde(default)] + pub first_name: Option, + #[serde(default)] + pub last_name: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramPhotoSize { + pub file_id: String, + pub file_unique_id: String, + pub width: i32, + pub height: i32, + #[serde(default)] + pub file_size: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramDocument { + pub file_id: String, + pub file_unique_id: String, + #[serde(default)] + pub file_name: Option, + #[serde(default)] + pub mime_type: Option, + #[serde(default)] + pub file_size: Option, +} + +#[derive(Debug, Deserialize, 
Serialize)] +pub struct TelegramVoice { + pub file_id: String, + pub file_unique_id: String, + pub duration: i32, + #[serde(default)] + pub mime_type: Option, + #[serde(default)] + pub file_size: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramAudio { + pub file_id: String, + pub file_unique_id: String, + pub duration: i32, + #[serde(default)] + pub performer: Option, + #[serde(default)] + pub title: Option, + #[serde(default)] + pub mime_type: Option, + #[serde(default)] + pub file_size: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramVideo { + pub file_id: String, + pub file_unique_id: String, + pub width: i32, + pub height: i32, + pub duration: i32, + #[serde(default)] + pub mime_type: Option, + #[serde(default)] + pub file_size: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramLocation { + pub longitude: f64, + pub latitude: f64, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramContact { + pub phone_number: String, + pub first_name: String, + #[serde(default)] + pub last_name: Option, + #[serde(default)] + pub user_id: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TelegramCallbackQuery { + pub id: String, + pub from: TelegramUser, + #[serde(default)] + pub message: Option, + #[serde(default)] + pub data: Option, +} + +pub fn configure() -> Router> { + Router::new() + .route("/webhook/telegram", post(handle_webhook)) + .route("/api/telegram/send", post(send_message)) +} + +pub async fn handle_webhook( + State(state): State>, + Json(update): Json, +) -> impl IntoResponse { + info!("Telegram webhook received: update_id={}", update.update_id); + + if let Some(message) = update.message.or(update.edited_message) { + if let Err(e) = process_message(state.clone(), &message).await { + error!("Failed to process Telegram message: {}", e); + } + } + + if let Some(callback) = update.callback_query { + if let Err(e) = process_callback(state.clone(), 
&callback).await { + error!("Failed to process Telegram callback: {}", e); + } + } + + StatusCode::OK +} + +async fn process_message( + state: Arc, + message: &TelegramMessage, +) -> Result<(), Box> { + let chat_id = message.chat.id.to_string(); + let user = message.from.as_ref(); + + let user_name = user + .map(|u| { + let mut name = u.first_name.clone(); + if let Some(last) = &u.last_name { + name.push(' '); + name.push_str(last); + } + name + }) + .unwrap_or_else(|| "Unknown".to_string()); + + let content = extract_message_content(message); + + if content.is_empty() { + debug!("Empty message content, skipping"); + return Ok(()); + } + + info!( + "Processing Telegram message from {} (chat_id={}): {}", + user_name, + chat_id, + if content.len() > 50 { &content[..50] } else { &content } + ); + + let session = find_or_create_session(&state, &chat_id, &user_name).await?; + + let assigned_to = session + .context_data + .get("assigned_to") + .and_then(|v| v.as_str()); + + if assigned_to.is_some() { + route_to_attendant(state.clone(), &session, &content, &chat_id, &user_name).await?; + } else { + route_to_bot(state.clone(), &session, &content, &chat_id).await?; + } + + Ok(()) +} + +fn extract_message_content(message: &TelegramMessage) -> String { + if let Some(text) = &message.text { + return text.clone(); + } + + if let Some(caption) = &message.caption { + return caption.clone(); + } + + if message.photo.is_some() { + return "[Photo received]".to_string(); + } + + if message.document.is_some() { + return "[Document received]".to_string(); + } + + if message.voice.is_some() { + return "[Voice message received]".to_string(); + } + + if message.audio.is_some() { + return "[Audio received]".to_string(); + } + + if message.video.is_some() { + return "[Video received]".to_string(); + } + + if let Some(location) = &message.location { + return format!("[Location: {}, {}]", location.latitude, location.longitude); + } + + if let Some(contact) = &message.contact { + return 
format!("[Contact: {} {}]", contact.first_name, contact.phone_number); + } + + String::new() +} + +async fn process_callback( + state: Arc, + callback: &TelegramCallbackQuery, +) -> Result<(), Box> { + let chat_id = callback + .message + .as_ref() + .map(|m| m.chat.id.to_string()) + .unwrap_or_default(); + + let user_name = { + let mut name = callback.from.first_name.clone(); + if let Some(last) = &callback.from.last_name { + name.push(' '); + name.push_str(last); + } + name + }; + + let data = callback.data.clone().unwrap_or_default(); + + if data.is_empty() || chat_id.is_empty() { + return Ok(()); + } + + info!( + "Processing Telegram callback from {} (chat_id={}): {}", + user_name, chat_id, data + ); + + let session = find_or_create_session(&state, &chat_id, &user_name).await?; + + route_to_bot(state, &session, &data, &chat_id).await?; + + Ok(()) +} + +async fn find_or_create_session( + state: &Arc, + chat_id: &str, + user_name: &str, +) -> Result> { + use crate::shared::models::schema::user_sessions::dsl::*; + + let mut conn = state.conn.get()?; + + let telegram_user_uuid = Uuid::new_v5(&Uuid::NAMESPACE_OID, format!("telegram:{}", chat_id).as_bytes()); + + let existing: Option = user_sessions + .filter(user_id.eq(telegram_user_uuid)) + .order(updated_at.desc()) + .first(&mut conn) + .optional()?; + + if let Some(session) = existing { + diesel::update(user_sessions.filter(id.eq(session.id))) + .set(updated_at.eq(Utc::now())) + .execute(&mut conn)?; + return Ok(session); + } + + let bot_uuid = get_default_bot_id(state).await; + let session_uuid = Uuid::new_v4(); + + let context = serde_json::json!({ + "channel": "telegram", + "chat_id": chat_id, + "name": user_name, + }); + + let now = Utc::now(); + + diesel::insert_into(user_sessions) + .values(( + id.eq(session_uuid), + user_id.eq(telegram_user_uuid), + bot_id.eq(bot_uuid), + title.eq(format!("Telegram: {}", user_name)), + context_data.eq(&context), + created_at.eq(now), + updated_at.eq(now), + )) + 
.execute(&mut conn)?; + + info!("Created new Telegram session {} for chat_id {}", session_uuid, chat_id); + + let new_session = user_sessions + .filter(id.eq(session_uuid)) + .first(&mut conn)?; + + Ok(new_session) +} + +async fn route_to_bot( + state: Arc, + session: &UserSession, + content: &str, + chat_id: &str, +) -> Result<(), Box> { + info!("Routing Telegram message to bot for session {}", session.id); + + let user_message = botlib::models::UserMessage::text( + session.bot_id.to_string(), + chat_id.to_string(), + session.id.to_string(), + "telegram".to_string(), + content.to_string(), + ); + + let (tx, mut rx) = tokio::sync::mpsc::channel::(10); + let orchestrator = BotOrchestrator::new(state.clone()); + + let adapter = TelegramAdapter::new(state.conn.clone(), session.bot_id); + let chat_id_clone = chat_id.to_string(); + + tokio::spawn(async move { + while let Some(response) = rx.recv().await { + let tg_response = BotResponse::new( + response.bot_id, + response.session_id, + chat_id_clone.clone(), + response.content, + "telegram", + ); + + if let Err(e) = adapter.send_message(tg_response).await { + error!("Failed to send Telegram response: {}", e); + } + } + }); + + if let Err(e) = orchestrator.stream_response(user_message, tx).await { + error!("Bot processing error: {}", e); + + let adapter = TelegramAdapter::new(state.conn.clone(), session.bot_id); + let error_response = BotResponse::new( + session.bot_id.to_string(), + session.id.to_string(), + chat_id.to_string(), + "Sorry, I encountered an error processing your message. 
Please try again.", + "telegram", + ); + + if let Err(e) = adapter.send_message(error_response).await { + error!("Failed to send error response: {}", e); + } + } + + Ok(()) +} + +async fn route_to_attendant( + state: Arc, + session: &UserSession, + content: &str, + chat_id: &str, + user_name: &str, +) -> Result<(), Box> { + info!( + "Routing Telegram message to attendant for session {}", + session.id + ); + + let assigned_to = session + .context_data + .get("assigned_to") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let notification = AttendantNotification { + notification_type: "message".to_string(), + session_id: session.id.to_string(), + user_id: chat_id.to_string(), + user_name: Some(user_name.to_string()), + user_phone: Some(chat_id.to_string()), + channel: "telegram".to_string(), + content: content.to_string(), + timestamp: Utc::now().to_rfc3339(), + assigned_to, + priority: 1, + }; + + if let Some(broadcast_tx) = state.attendant_broadcast.as_ref() { + if let Err(e) = broadcast_tx.send(notification.clone()) { + debug!("No attendants listening: {}", e); + } else { + info!("Notification sent to attendants"); + } + } + + Ok(()) +} + +async fn get_default_bot_id(state: &Arc) -> Uuid { + use crate::shared::models::schema::bots::dsl::*; + + if let Ok(mut conn) = state.conn.get() { + if let Ok(bot_uuid) = bots + .filter(is_active.eq(true)) + .select(id) + .first::(&mut conn) + { + return bot_uuid; + } + } + + Uuid::parse_str("f47ac10b-58cc-4372-a567-0e02b2c3d480").unwrap_or_else(|_| Uuid::new_v4()) +} + +#[derive(Debug, Deserialize)] +pub struct SendMessageRequest { + pub to: String, + pub message: String, +} + +pub async fn send_message( + State(state): State>, + Json(request): Json, +) -> impl IntoResponse { + info!("Sending Telegram message to {}", request.to); + + let bot_id = get_default_bot_id(&state).await; + let adapter = TelegramAdapter::new(state.conn.clone(), bot_id); + + let response = BotResponse::new( + bot_id.to_string(), + 
Uuid::new_v4().to_string(), + request.to.clone(), + request.message.clone(), + "telegram", + ); + + match adapter.send_message(response).await { + Ok(_) => ( + StatusCode::OK, + Json(serde_json::json!({ + "success": true, + "message": "Message sent successfully" + })), + ), + Err(e) => { + error!("Failed to send Telegram message: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "success": false, + "error": "Failed to send message" + })), + ) + } + } +}