use crate::core::shared::models::UserSession; use crate::core::shared::state::AppState; use diesel::prelude::*; use log::{info, trace}; use rhai::{Dynamic, Engine}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelConfig { pub name: String, pub url: String, pub model_path: String, pub api_key: Option, pub max_tokens: Option, pub temperature: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Default)] pub enum RoutingStrategy { #[default] Manual, Auto, LoadBalanced, Fallback, } #[derive(Debug, Clone)] pub struct ModelRouter { pub models: HashMap, pub default_model: String, pub routing_strategy: RoutingStrategy, } impl Default for ModelRouter { fn default() -> Self { Self { models: HashMap::new(), default_model: "default".to_string(), routing_strategy: RoutingStrategy::Manual, } } } impl ModelRouter { pub fn new() -> Self { Self::default() } pub fn from_config(config_models: &str, bot_id: Uuid, state: &AppState) -> Self { let mut router = Self::new(); let model_names: Vec<&str> = config_models.split(';').collect(); for name in model_names { let name = name.trim(); if name.is_empty() { continue; } if let Ok(mut conn) = state.conn.get() { let model_config = load_model_config(&mut conn, bot_id, name); if let Some(config) = model_config { router.models.insert(name.to_string(), config); } } } if let Some(first_name) = config_models.split(';').next() { router.default_model = first_name.trim().to_string(); } router } pub fn get_model(&self, name: &str) -> Option<&ModelConfig> { self.models.get(name) } pub fn get_default(&self) -> Option<&ModelConfig> { self.models.get(&self.default_model) } pub fn route_query(&self, query: &str) -> &str { match self.routing_strategy { RoutingStrategy::Auto => self.auto_route(query), RoutingStrategy::LoadBalanced => self.load_balanced_route(), RoutingStrategy::Fallback | RoutingStrategy::Manual => &self.default_model, } } fn auto_route(&self, query: &str) -> &str { let query_lower = query.to_lowercase(); if (query_lower.contains("code") || query_lower.contains("program") || query_lower.contains("function") || query_lower.contains("debug") || query_lower.contains("error") || query_lower.contains("syntax")) && self.models.contains_key("code") { return "code"; } if (query_lower.contains("analyze") || query_lower.contains("explain") || query_lower.contains("compare") || query_lower.contains("evaluate") || query.len() > 500) && self.models.contains_key("quality") { return "quality"; } if (query.len() < 100 || query_lower.contains("what is") || query_lower.contains("define") || query_lower.contains("hello")) && self.models.contains_key("fast") { return "fast"; } &self.default_model } fn load_balanced_route(&self) -> &str { &self.default_model } } fn load_model_config( conn: &mut diesel::PgConnection, bot_id: Uuid, model_name: &str, ) -> Option { #[derive(QueryableByName)] struct ConfigRow { #[diesel(sql_type = diesel::sql_types::Text)] config_key: String, #[diesel(sql_type = diesel::sql_types::Text)] config_value: String, } let suffix = if model_name == "default" { "".to_string() } else { format!("-{}", model_name) }; let model_key = format!("llm-model{}", suffix); let url_key = format!("llm-url{}", suffix); let key_key = format!("llm-key{}", suffix); let configs: Vec = diesel::sql_query( "SELECT config_key, config_value FROM bot_configuration \ WHERE bot_id = $1 AND config_key IN ($2, $3, $4)", ) .bind::(bot_id) .bind::(&model_key) .bind::(&url_key) .bind::(&key_key) .load(conn) .ok()?; let mut model_path = String::new(); let mut url = String::new(); let mut api_key = None; for config in configs { if config.config_key == model_key { model_path = config.config_value; } else if config.config_key == url_key { url = config.config_value; } else if config.config_key == key_key && config.config_value != "none" { api_key = Some(config.config_value); } } if model_path.is_empty() && url.is_empty() { return None; } Some(ModelConfig { name: model_name.to_string(), url, model_path, api_key, max_tokens: None, temperature: None, }) } pub fn register_model_routing_keywords( state: Arc, user: UserSession, engine: &mut Engine, ) { use_model_keyword(Arc::clone(&state), user.clone(), engine); set_model_routing_keyword(Arc::clone(&state), user.clone(), engine); get_current_model_keyword(Arc::clone(&state), user.clone(), engine); list_models_keyword(state, user, engine); } pub fn use_model_keyword(state: Arc, user: UserSession, engine: &mut Engine) { let state_clone = Arc::clone(&state); let user_clone = user; engine .register_custom_syntax(["USE", "MODEL", "$expr$"], false, move |context, inputs| { let model_name = context .eval_expression_tree(&inputs[0])? .to_string() .trim_matches('"') .to_string(); trace!("USE MODEL '{}' for session: {}", model_name, user_clone.id); let state_for_task = Arc::clone(&state_clone); let session_id = user_clone.id; let model_name_clone = model_name; let (tx, rx) = std::sync::mpsc::channel(); std::thread::spawn(move || { let _rt = match tokio::runtime::Runtime::new() { Ok(rt) => rt, Err(e) => { let _ = tx.send(Err(format!("Failed to create runtime: {}", e))); return; } }; let result = set_session_model(&state_for_task, session_id, &model_name_clone); let _ = tx.send(result); }); match rx.recv_timeout(std::time::Duration::from_secs(10)) { Ok(Ok(msg)) => Ok(Dynamic::from(msg)), Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime( e.into(), rhai::Position::NONE, ))), Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime( "USE MODEL timed out".into(), rhai::Position::NONE, ))), } }) .expect("Failed to register USE MODEL syntax"); } pub fn set_model_routing_keyword(state: Arc, user: UserSession, engine: &mut Engine) { let state_clone = Arc::clone(&state); let user_clone = user; engine .register_custom_syntax( ["SET", "MODEL", "ROUTING", "$expr$"], false, move |context, inputs| { let strategy_str = context .eval_expression_tree(&inputs[0])? .to_string() .trim_matches('"') .to_lowercase(); let strategy = match strategy_str.as_str() { "auto" => RoutingStrategy::Auto, "load-balanced" | "loadbalanced" => RoutingStrategy::LoadBalanced, "fallback" => RoutingStrategy::Fallback, _ => RoutingStrategy::Manual, }; trace!( "SET MODEL ROUTING {:?} for session: {}", strategy, user_clone.id ); let state_for_task = Arc::clone(&state_clone); let session_id = user_clone.id; let strategy_clone = strategy; let (tx, rx) = std::sync::mpsc::channel(); std::thread::spawn(move || { let result = set_session_routing_strategy(&state_for_task, session_id, strategy_clone); let _ = tx.send(result); }); match rx.recv_timeout(std::time::Duration::from_secs(10)) { Ok(Ok(msg)) => Ok(Dynamic::from(msg)), Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime( e.into(), rhai::Position::NONE, ))), Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime( "SET MODEL ROUTING timed out".into(), rhai::Position::NONE, ))), } }, ) .expect("Failed to register SET MODEL ROUTING syntax"); } pub fn get_current_model_keyword(state: Arc, user: UserSession, engine: &mut Engine) { let state_clone: Arc = Arc::clone(&state); let user_clone = user; engine.register_fn("GET CURRENT MODEL", move || -> String { let state = Arc::::clone(&state_clone); if let Ok(mut conn) = state.conn.get() { get_session_model_sync(&mut conn, user_clone.id) .unwrap_or_else(|_| "default".to_string()) } else { "default".to_string() } }); } pub fn list_models_keyword(state: Arc, user: UserSession, engine: &mut Engine) { let state_clone: Arc = Arc::clone(&state); let user_clone = user; engine.register_fn("LIST MODELS", move || -> rhai::Array { let state = Arc::::clone(&state_clone); if let Ok(mut conn) = state.conn.get() { list_available_models_sync(&mut conn, user_clone.bot_id) .unwrap_or_default() .into_iter() .map(Dynamic::from) .collect() } else { rhai::Array::new() } }); } fn set_session_model( state: &AppState, session_id: Uuid, model_name: &str, ) -> Result { let mut conn = state .conn .get() .map_err(|e| format!("Failed to acquire database connection: {}", e))?; let now = chrono::Utc::now(); diesel::sql_query( "INSERT INTO session_preferences (session_id, preference_key, preference_value, updated_at) \ VALUES ($1, 'current_model', $2, $3) \ ON CONFLICT (session_id, preference_key) DO UPDATE SET \ preference_value = EXCLUDED.preference_value, \ updated_at = EXCLUDED.updated_at", ) .bind::(session_id) .bind::(model_name) .bind::(now) .execute(&mut conn) .map_err(|e| format!("Failed to set session model: {}", e))?; info!("Session {} now using model: {}", session_id, model_name); Ok(format!("Now using model: {}", model_name)) } fn set_session_routing_strategy( state: &AppState, session_id: Uuid, strategy: RoutingStrategy, ) -> Result { let mut conn = state .conn .get() .map_err(|e| format!("Failed to acquire database connection: {}", e))?; let now = chrono::Utc::now(); let strategy_str = match strategy { RoutingStrategy::Manual => "manual", RoutingStrategy::Auto => "auto", RoutingStrategy::LoadBalanced => "load-balanced", RoutingStrategy::Fallback => "fallback", }; diesel::sql_query( "INSERT INTO session_preferences (session_id, preference_key, preference_value, updated_at) \ VALUES ($1, 'model_routing', $2, $3) \ ON CONFLICT (session_id, preference_key) DO UPDATE SET \ preference_value = EXCLUDED.preference_value, \ updated_at = EXCLUDED.updated_at", ) .bind::(session_id) .bind::(strategy_str) .bind::(now) .execute(&mut conn) .map_err(|e| format!("Failed to set routing strategy: {}", e))?; info!( "Session {} routing strategy set to: {}", session_id, strategy_str ); Ok(format!("Model routing set to: {}", strategy_str)) } fn get_session_model_sync( conn: &mut diesel::PgConnection, session_id: Uuid, ) -> Result { #[derive(QueryableByName)] struct PrefValue { #[diesel(sql_type = diesel::sql_types::Text)] preference_value: String, } // 1. Check session preference first (set by USE MODEL) let result: Option = diesel::sql_query( "SELECT preference_value FROM session_preferences \ WHERE session_id = $1 AND preference_key = 'current_model' LIMIT 1", ) .bind::(session_id) .get_result(conn) .optional() .map_err(|e| format!("Failed to get session model: {}", e))?; if let Some(pref) = result { return Ok(pref.preference_value); } // 2. No session preference - get bot's configured model // Need to get bot_id from session first #[derive(QueryableByName)] struct SessionBot { #[diesel(sql_type = diesel::sql_types::Uuid)] bot_id: Uuid, } let bot_result: Option = diesel::sql_query( "SELECT bot_id FROM sessions WHERE id = $1 LIMIT 1", ) .bind::(session_id) .get_result(conn) .optional() .map_err(|e| format!("Failed to get session bot: {}", e))?; if let Some(session_bot) = bot_result { let bot_id = session_bot.bot_id; // Get bot's llm-model config #[derive(QueryableByName)] struct ConfigValue { #[diesel(sql_type = diesel::sql_types::Text)] config_value: String, } let bot_model: Option = diesel::sql_query( "SELECT config_value FROM bot_configuration \ WHERE bot_id = $1 AND config_key = 'llm-model' LIMIT 1", ) .bind::(bot_id) .get_result(conn) .optional() .map_err(|e| format!("Failed to get bot model: {}", e))?; if let Some(model) = bot_model { if !model.config_value.is_empty() && model.config_value != "true" { return Ok(model.config_value); } } // 3. Bot has no model configured - fall back to default bot's model let (default_bot_id, _) = crate::core::bot::get_default_bot(conn); let default_model: Option = diesel::sql_query( "SELECT config_value FROM bot_configuration \ WHERE bot_id = $1 AND config_key = 'llm-model' LIMIT 1", ) .bind::(default_bot_id) .get_result(conn) .optional() .map_err(|e| format!("Failed to get default bot model: {}", e))?; if let Some(model) = default_model { if !model.config_value.is_empty() && model.config_value != "true" { return Ok(model.config_value); } } } // 4. Ultimate fallback Ok("llama-3.3:8b".to_string()) } fn list_available_models_sync( conn: &mut diesel::PgConnection, bot_id: Uuid, ) -> Result, String> { #[derive(QueryableByName)] struct ConfigRow { #[diesel(sql_type = diesel::sql_types::Text)] config_value: String, } let result: Option = diesel::sql_query( "SELECT config_value FROM bot_configuration \ WHERE bot_id = $1 AND config_key = 'llm-models' LIMIT 1", ) .bind::(bot_id) .get_result(conn) .optional() .map_err(|e| format!("Failed to list models: {}", e))?; if let Some(config) = result { Ok(config .config_value .split(';') .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) .collect()) } else { Ok(vec!["default".to_string()]) } } pub fn get_session_model(state: &AppState, session_id: Uuid) -> String { if let Ok(mut conn) = state.conn.get() { get_session_model_sync(&mut conn, session_id).unwrap_or_else(|_| "default".to_string()) } else { "default".to_string() } } pub fn get_session_routing_strategy(state: &AppState, session_id: Uuid) -> RoutingStrategy { if let Ok(mut conn) = state.conn.get() { #[derive(QueryableByName)] struct PrefValue { #[diesel(sql_type = diesel::sql_types::Text)] preference_value: String, } let result: Option = diesel::sql_query( "SELECT preference_value FROM session_preferences \ WHERE session_id = $1 AND preference_key = 'model_routing' LIMIT 1", ) .bind::(session_id) .get_result(&mut conn) .optional() .ok() .flatten(); if let Some(pref) = result { match pref.preference_value.as_str() { "auto" => RoutingStrategy::Auto, "load-balanced" => RoutingStrategy::LoadBalanced, "fallback" => RoutingStrategy::Fallback, _ => RoutingStrategy::Manual, } } else { RoutingStrategy::Manual } } else { RoutingStrategy::Manual } }