use crate::services::config::AIConfig; use langchain_rust::llm::OpenAI; use langchain_rust::{language_models::llm::LLM, llm::AzureConfig}; use log::{debug, warn}; use rhai::{Array, Dynamic}; use serde_json::{json, Value}; use smartstring::SmartString; use sqlx::Column; // Required for .name() method use sqlx::TypeInfo; // Required for .type_info() method use sqlx::{postgres::PgRow, Row}; use sqlx::{Decode, Type}; use std::error::Error; use std::fs::File; use std::io::BufReader; use std::path::Path; use tokio_stream::StreamExt; use zip::ZipArchive; use tokio::fs::File as TokioFile; use reqwest::Client; use tokio::io::AsyncWriteExt; pub fn azure_from_config(config: &AIConfig) -> AzureConfig { AzureConfig::default() .with_api_key(&config.key) .with_api_base(&config.endpoint) .with_api_version(&config.version) .with_deployment_id(&config.instance) } pub async fn call_llm( text: &str, ai_config: &AIConfig, ) -> Result> { let azure_config = azure_from_config(&ai_config.clone()); let open_ai = OpenAI::new(azure_config); // Directly use the input text as prompt let prompt = text.to_string(); // Call LLM and return the raw text response match open_ai.invoke(&prompt).await { Ok(response_text) => Ok(response_text), Err(err) => { eprintln!("Error invoking LLM API: {}", err); Err(Box::new(std::io::Error::new( std::io::ErrorKind::Other, "Failed to invoke LLM API", ))) } } } pub fn extract_zip_recursive( zip_path: &Path, destination_path: &Path, ) -> Result<(), Box> { let file = File::open(zip_path)?; let buf_reader = BufReader::new(file); let mut archive = ZipArchive::new(buf_reader)?; for i in 0..archive.len() { let mut file = archive.by_index(i)?; let outpath = destination_path.join(file.mangled_name()); if file.is_dir() { std::fs::create_dir_all(&outpath)?; } else { if let Some(parent) = outpath.parent() { if !parent.exists() { std::fs::create_dir_all(&parent)?; } } let mut outfile = File::create(&outpath)?; std::io::copy(&mut file, &mut outfile)?; } } Ok(()) } pub fn row_to_json(row: PgRow) -> Result> { let mut result = serde_json::Map::new(); let columns = row.columns(); debug!("Converting row with {} columns", columns.len()); for (i, column) in columns.iter().enumerate() { let column_name = column.name(); let type_name = column.type_info().name(); let value = match type_name { "INT4" | "int4" => handle_nullable_type::(&row, i, column_name), "INT8" | "int8" => handle_nullable_type::(&row, i, column_name), "FLOAT4" | "float4" => handle_nullable_type::(&row, i, column_name), "FLOAT8" | "float8" => handle_nullable_type::(&row, i, column_name), "TEXT" | "VARCHAR" | "text" | "varchar" => { handle_nullable_type::(&row, i, column_name) } "BOOL" | "bool" => handle_nullable_type::(&row, i, column_name), "JSON" | "JSONB" | "json" | "jsonb" => handle_json(&row, i, column_name), _ => { warn!("Unknown type {} for column {}", type_name, column_name); handle_nullable_type::(&row, i, column_name) } }; result.insert(column_name.to_string(), value); } Ok(Value::Object(result)) } fn handle_nullable_type<'r, T>(row: &'r PgRow, idx: usize, col_name: &str) -> Value where T: Type + Decode<'r, sqlx::Postgres> + serde::Serialize + std::fmt::Debug, { match row.try_get::, _>(idx) { Ok(Some(val)) => { debug!("Successfully read column {} as {:?}", col_name, val); json!(val) } Ok(None) => { debug!("Column {} is NULL", col_name); Value::Null } Err(e) => { warn!("Failed to read column {}: {}", col_name, e); Value::Null } } } fn handle_json(row: &PgRow, idx: usize, col_name: &str) -> Value { // First try to get as Option match row.try_get::, _>(idx) { Ok(Some(val)) => { debug!("Successfully read JSON column {} as Value", col_name); return val; } Ok(None) => return Value::Null, Err(_) => (), // Fall through to other attempts } // Try as Option that might contain JSON match row.try_get::, _>(idx) { Ok(Some(s)) => match serde_json::from_str(&s) { Ok(val) => val, Err(_) => { debug!("Column {} contains string that's not JSON", col_name); json!(s) } }, Ok(None) => Value::Null, Err(e) => { warn!("Failed to read JSON column {}: {}", col_name, e); Value::Null } } } pub fn json_value_to_dynamic(value: &Value) -> Dynamic { match value { Value::Null => Dynamic::UNIT, Value::Bool(b) => Dynamic::from(*b), Value::Number(n) => { if let Some(i) = n.as_i64() { Dynamic::from(i) } else if let Some(f) = n.as_f64() { Dynamic::from(f) } else { Dynamic::UNIT } } Value::String(s) => Dynamic::from(s.clone()), Value::Array(arr) => Dynamic::from( arr.iter() .map(json_value_to_dynamic) .collect::(), ), Value::Object(obj) => Dynamic::from( obj.iter() .map(|(k, v)| (SmartString::from(k), json_value_to_dynamic(v))) .collect::(), ), } } /// Converts any value to an array - single values become single-element arrays pub fn to_array(value: Dynamic) -> Array { if value.is_array() { // Already an array - return as-is value.cast::() } else if value.is_unit() || value.is::<()>() { // Handle empty/unit case Array::new() } else { // Convert single value to single-element array Array::from([value]) } } pub async fn download_file(url: &str, output_path: &str) -> Result<(), Box> { let client = Client::new(); let response = client.get(url).send().await?; if response.status().is_success() { let mut file = TokioFile::create(output_path).await?; let mut stream = response.bytes_stream(); while let Some(chunk) = stream.next().await { file.write_all(&chunk?).await?; } debug!("File downloaded successfully to {}", output_path); } else { return Err("Failed to download file".into()); } Ok(()) } // Helper function to parse the filter string into SQL WHERE clause and parameters pub fn parse_filter(filter_str: &str) -> Result<(String, Vec), Box> { let parts: Vec<&str> = filter_str.split('=').collect(); if parts.len() != 2 { return Err("Invalid filter format. Expected 'KEY=VALUE'".into()); } let column = parts[0].trim(); let value = parts[1].trim(); // Validate column name to prevent SQL injection if !column .chars() .all(|c| c.is_ascii_alphanumeric() || c == '_') { return Err("Invalid column name in filter".into()); } // Return the parameterized query part and the value separately Ok((format!("{} = $1", column), vec![value.to_string()])) } pub fn parse_filter_with_offset(filter_str: &str, offset: usize) -> Result<(String, Vec), Box> { let parts: Vec<&str> = filter_str.split('=').collect(); if parts.len() != 2 { return Err("Invalid filter format. Expected 'KEY=VALUE'".into()); } let column = parts[0].trim(); let value = parts[1].trim(); // Validate column name to prevent SQL injection if !column.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { return Err("Invalid column name in filter".into()); } // Use the offset + 1 for the parameter number Ok((format!("{} = ${}", column, offset + 1), vec![value.to_string()])) }