use crate::config::DriveConfig; use crate::core::secrets::SecretsManager; use anyhow::{Context, Result}; use aws_config::BehaviorVersion; use aws_sdk_s3::{config::Builder as S3ConfigBuilder, Client as S3Client}; use diesel::Connection; use diesel::{ r2d2::{ConnectionManager, Pool}, PgConnection, }; use futures_util::StreamExt; #[cfg(feature = "progress-bars")] use indicatif::{ProgressBar, ProgressStyle}; use once_cell::sync::Lazy; use reqwest::Client; use rhai::{Array, Dynamic}; use serde_json::Value; use smartstring::SmartString; use std::error::Error; use std::sync::Arc; use tokio::fs::File as TokioFile; use tokio::io::AsyncWriteExt; use tokio::sync::RwLock; /// Global SecretsManager instance - initialized once, used everywhere static SECRETS_MANAGER: Lazy>>> = Lazy::new(|| Arc::new(RwLock::new(None))); /// Initialize the global secrets manager (call once at startup) pub async fn init_secrets_manager() -> Result<()> { let manager = SecretsManager::from_env()?; let mut guard = SECRETS_MANAGER.write().await; *guard = Some(manager); Ok(()) } /// Get database URL from Vault - NO FALLBACK pub async fn get_database_url() -> Result { let guard = SECRETS_MANAGER.read().await; if let Some(ref manager) = *guard { if manager.is_enabled() { return manager.get_database_url().await; } } // NO FALLBACK - Vault is mandatory Err(anyhow::anyhow!("Vault not configured. Set VAULT_ADDR and VAULT_TOKEN in .env")) } /// Get database URL synchronously (blocking) for diesel connections - NO FALLBACK pub fn get_database_url_sync() -> Result { // Check if we're in an async runtime context if let Ok(handle) = tokio::runtime::Handle::try_current() { // We're inside a tokio runtime - use block_in_place to avoid nesting let result = tokio::task::block_in_place(|| { handle.block_on(async { get_database_url().await }) }); if let Ok(url) = result { return Ok(url); } } else { // Not in a runtime - create a new one let rt = tokio::runtime::Runtime::new().map_err(|e| anyhow::anyhow!("Failed to create runtime: {}", e))?; if let Ok(url) = rt.block_on(async { get_database_url().await }) { return Ok(url); } } // NO FALLBACK - Vault is mandatory Err(anyhow::anyhow!("Vault not configured. Set VAULT_ADDR and VAULT_TOKEN in .env")) } /// Get the global SecretsManager instance pub async fn get_secrets_manager() -> Option { let guard = SECRETS_MANAGER.read().await; guard.clone() } pub async fn create_s3_operator( config: &DriveConfig, ) -> Result> { let endpoint = if !config.server.ends_with('/') { format!("{}/", config.server) } else { config.server.clone() }; let base_config = aws_config::defaults(BehaviorVersion::latest()) .endpoint_url(endpoint) .region("auto") .credentials_provider(aws_sdk_s3::config::Credentials::new( config.access_key.clone(), config.secret_key.clone(), None, None, "static", )) .load() .await; let s3_config = S3ConfigBuilder::from(&base_config) .force_path_style(true) .build(); Ok(S3Client::from_conf(s3_config)) } pub fn json_value_to_dynamic(value: &Value) -> Dynamic { match value { Value::Null => Dynamic::UNIT, Value::Bool(b) => Dynamic::from(*b), Value::Number(n) => { if let Some(i) = n.as_i64() { Dynamic::from(i) } else if let Some(f) = n.as_f64() { Dynamic::from(f) } else { Dynamic::UNIT } } Value::String(s) => Dynamic::from(s.clone()), Value::Array(arr) => Dynamic::from( arr.iter() .map(json_value_to_dynamic) .collect::(), ), Value::Object(obj) => Dynamic::from( obj.iter() .map(|(k, v)| (SmartString::from(k), json_value_to_dynamic(v))) .collect::(), ), } } pub fn to_array(value: Dynamic) -> Array { if value.is_array() { value.cast::() } else if value.is_unit() || value.is::<()>() { Array::new() } else { Array::from([value]) } } /// Download a file from a URL with progress bar (when progress-bars feature is enabled) #[cfg(feature = "progress-bars")] pub async fn download_file(url: &str, output_path: &str) -> Result<(), anyhow::Error> { use std::time::Duration; let url = url.to_string(); let output_path = output_path.to_string(); let download_handle = tokio::spawn(async move { let client = Client::builder() .user_agent("Mozilla/5.0 (compatible; BotServer/1.0)") .connect_timeout(Duration::from_secs(30)) .read_timeout(Duration::from_secs(300)) .pool_idle_timeout(Duration::from_secs(90)) .tcp_keepalive(Duration::from_secs(60)) .build()?; let response = client.get(&url).send().await?; if response.status().is_success() { let total_size = response.content_length().unwrap_or(0); let pb = ProgressBar::new(total_size); pb.set_style(ProgressStyle::default_bar() .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})") .unwrap() .progress_chars("#>-")); pb.set_message(format!("Downloading {}", url)); let mut file = TokioFile::create(&output_path).await?; let mut downloaded: u64 = 0; let mut stream = response.bytes_stream(); while let Some(chunk_result) = stream.next().await { let chunk = chunk_result?; file.write_all(&chunk).await?; downloaded += chunk.len() as u64; pb.set_position(downloaded); } pb.finish_with_message(format!("Downloaded {}", output_path)); Ok(()) } else { Err(anyhow::anyhow!("HTTP {}: {}", response.status(), url)) } }); download_handle.await? } /// Download a file from a URL (without progress bar when progress-bars feature is disabled) #[cfg(not(feature = "progress-bars"))] pub async fn download_file(url: &str, output_path: &str) -> Result<(), anyhow::Error> { use std::time::Duration; let url = url.to_string(); let output_path = output_path.to_string(); let download_handle = tokio::spawn(async move { let client = Client::builder() .user_agent("Mozilla/5.0 (compatible; BotServer/1.0)") .connect_timeout(Duration::from_secs(30)) .read_timeout(Duration::from_secs(300)) .pool_idle_timeout(Duration::from_secs(90)) .tcp_keepalive(Duration::from_secs(60)) .build()?; let response = client.get(&url).send().await?; if response.status().is_success() { let mut file = TokioFile::create(&output_path).await?; let mut stream = response.bytes_stream(); while let Some(chunk_result) = stream.next().await { let chunk = chunk_result?; file.write_all(&chunk).await?; } Ok(()) } else { Err(anyhow::anyhow!("HTTP {}: {}", response.status(), url)) } }); download_handle.await? } pub fn parse_filter(filter_str: &str) -> Result<(String, Vec), Box> { let parts: Vec<&str> = filter_str.split('=').collect(); if parts.len() != 2 { return Err("Invalid filter format. Expected 'KEY=VALUE'".into()); } let column = parts[0].trim(); let value = parts[1].trim(); if !column .chars() .all(|c| c.is_ascii_alphanumeric() || c == '_') { return Err("Invalid column name in filter".into()); } Ok((format!("{} = $1", column), vec![value.to_string()])) } pub fn estimate_token_count(text: &str) -> usize { let char_count = text.chars().count(); (char_count / 4).max(1) } pub fn establish_pg_connection() -> Result { let database_url = get_database_url_sync() .expect("Vault not configured. Set VAULT_ADDR and VAULT_TOKEN in .env"); PgConnection::establish(&database_url) .with_context(|| format!("Failed to connect to database at {}", database_url)) } pub type DbPool = Pool>; pub fn create_conn() -> Result { let database_url = get_database_url_sync() .expect("Vault not configured. Set VAULT_ADDR and VAULT_TOKEN in .env"); let manager = ConnectionManager::::new(database_url); Pool::builder().build(manager) } /// Create database connection pool using SecretsManager (async version) pub async fn create_conn_async() -> Result { let database_url = get_database_url().await .expect("Vault not configured. Set VAULT_ADDR and VAULT_TOKEN in .env"); let manager = ConnectionManager::::new(database_url); Pool::builder().build(manager) } pub fn parse_database_url(url: &str) -> (String, String, String, u32, String) { if let Some(stripped) = url.strip_prefix("postgres://") { let parts: Vec<&str> = stripped.split('@').collect(); if parts.len() == 2 { let user_pass: Vec<&str> = parts[0].split(':').collect(); let host_db: Vec<&str> = parts[1].split('/').collect(); if user_pass.len() >= 2 && host_db.len() >= 2 { let username = user_pass[0].to_string(); let password = user_pass[1].to_string(); let host_port: Vec<&str> = host_db[0].split(':').collect(); let server = host_port[0].to_string(); let port = host_port .get(1) .and_then(|p| p.parse().ok()) .unwrap_or(5432); let database = host_db[1].to_string(); return (username, password, server, port, database); } } } ( "".to_string(), "".to_string(), "".to_string(), 5432, "".to_string(), ) } /// Run database migrations pub fn run_migrations(pool: &DbPool) -> Result<(), Box> { use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations"); let mut conn = pool.get()?; conn.run_pending_migrations(MIGRATIONS).map_err( |e| -> Box { Box::new(std::io::Error::new( std::io::ErrorKind::Other, format!("Migration error: {}", e), )) }, )?; Ok(()) }