554 lines
16 KiB
Rust
554 lines
16 KiB
Rust
use crate::shared::models::UserSession;
|
|
use crate::shared::state::AppState;
|
|
use diesel::prelude::*;
|
|
use log::{info, trace};
|
|
use rhai::{Dynamic, Engine};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use uuid::Uuid;
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ModelConfig {
|
|
pub name: String,
|
|
pub url: String,
|
|
pub model_path: String,
|
|
pub api_key: Option<String>,
|
|
pub max_tokens: Option<u32>,
|
|
pub temperature: Option<f32>,
|
|
}
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
pub enum RoutingStrategy {
|
|
Manual,
|
|
Auto,
|
|
LoadBalanced,
|
|
Fallback,
|
|
}
|
|
|
|
impl Default for RoutingStrategy {
|
|
fn default() -> Self {
|
|
RoutingStrategy::Manual
|
|
}
|
|
}
|
|
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ModelRouter {
|
|
pub models: HashMap<String, ModelConfig>,
|
|
pub default_model: String,
|
|
pub routing_strategy: RoutingStrategy,
|
|
}
|
|
|
|
impl ModelRouter {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
models: HashMap::new(),
|
|
default_model: "default".to_string(),
|
|
routing_strategy: RoutingStrategy::Manual,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
pub fn from_config(config_models: &str, bot_id: Uuid, state: &AppState) -> Self {
|
|
let mut router = Self::new();
|
|
|
|
let model_names: Vec<&str> = config_models.split(';').collect();
|
|
|
|
for name in model_names {
|
|
let name = name.trim();
|
|
if name.is_empty() {
|
|
continue;
|
|
}
|
|
|
|
|
|
|
|
if let Ok(mut conn) = state.conn.get() {
|
|
let model_config = load_model_config(&mut conn, bot_id, name);
|
|
if let Some(config) = model_config {
|
|
router.models.insert(name.to_string(), config);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
if let Some(first_name) = config_models.split(';').next() {
|
|
router.default_model = first_name.trim().to_string();
|
|
}
|
|
|
|
router
|
|
}
|
|
|
|
|
|
pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
|
|
self.models.get(name)
|
|
}
|
|
|
|
|
|
pub fn get_default(&self) -> Option<&ModelConfig> {
|
|
self.models.get(&self.default_model)
|
|
}
|
|
|
|
|
|
pub fn route_query(&self, query: &str) -> &str {
|
|
match self.routing_strategy {
|
|
RoutingStrategy::Auto => self.auto_route(query),
|
|
RoutingStrategy::LoadBalanced => self.load_balanced_route(),
|
|
RoutingStrategy::Fallback => &self.default_model,
|
|
RoutingStrategy::Manual => &self.default_model,
|
|
}
|
|
}
|
|
|
|
|
|
fn auto_route(&self, query: &str) -> &str {
|
|
let query_lower = query.to_lowercase();
|
|
|
|
|
|
if query_lower.contains("code")
|
|
|| query_lower.contains("program")
|
|
|| query_lower.contains("function")
|
|
|| query_lower.contains("debug")
|
|
|| query_lower.contains("error")
|
|
|| query_lower.contains("syntax")
|
|
{
|
|
if self.models.contains_key("code") {
|
|
return "code";
|
|
}
|
|
}
|
|
|
|
|
|
if query_lower.contains("analyze")
|
|
|| query_lower.contains("explain")
|
|
|| query_lower.contains("compare")
|
|
|| query_lower.contains("evaluate")
|
|
|| query.len() > 500
|
|
{
|
|
if self.models.contains_key("quality") {
|
|
return "quality";
|
|
}
|
|
}
|
|
|
|
|
|
if query.len() < 100
|
|
|| query_lower.contains("what is")
|
|
|| query_lower.contains("define")
|
|
|| query_lower.contains("hello")
|
|
{
|
|
if self.models.contains_key("fast") {
|
|
return "fast";
|
|
}
|
|
}
|
|
|
|
&self.default_model
|
|
}
|
|
|
|
|
|
fn load_balanced_route(&self) -> &str {
|
|
|
|
&self.default_model
|
|
}
|
|
}
|
|
|
|
|
|
fn load_model_config(
|
|
conn: &mut diesel::PgConnection,
|
|
bot_id: Uuid,
|
|
model_name: &str,
|
|
) -> Option<ModelConfig> {
|
|
#[derive(QueryableByName)]
|
|
struct ConfigRow {
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
config_key: String,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
config_value: String,
|
|
}
|
|
|
|
|
|
let suffix = if model_name == "default" {
|
|
"".to_string()
|
|
} else {
|
|
format!("-{}", model_name)
|
|
};
|
|
|
|
let model_key = format!("llm-model{}", suffix);
|
|
let url_key = format!("llm-url{}", suffix);
|
|
let key_key = format!("llm-key{}", suffix);
|
|
|
|
let configs: Vec<ConfigRow> = diesel::sql_query(
|
|
"SELECT config_key, config_value FROM bot_configuration \
|
|
WHERE bot_id = $1 AND config_key IN ($2, $3, $4)",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(bot_id)
|
|
.bind::<diesel::sql_types::Text, _>(&model_key)
|
|
.bind::<diesel::sql_types::Text, _>(&url_key)
|
|
.bind::<diesel::sql_types::Text, _>(&key_key)
|
|
.load(conn)
|
|
.ok()?;
|
|
|
|
let mut model_path = String::new();
|
|
let mut url = String::new();
|
|
let mut api_key = None;
|
|
|
|
for config in configs {
|
|
if config.config_key == model_key {
|
|
model_path = config.config_value;
|
|
} else if config.config_key == url_key {
|
|
url = config.config_value;
|
|
} else if config.config_key == key_key && config.config_value != "none" {
|
|
api_key = Some(config.config_value);
|
|
}
|
|
}
|
|
|
|
if model_path.is_empty() && url.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
Some(ModelConfig {
|
|
name: model_name.to_string(),
|
|
url,
|
|
model_path,
|
|
api_key,
|
|
max_tokens: None,
|
|
temperature: None,
|
|
})
|
|
}
|
|
|
|
|
|
pub fn register_model_routing_keywords(
|
|
state: Arc<AppState>,
|
|
user: UserSession,
|
|
engine: &mut Engine,
|
|
) {
|
|
use_model_keyword(state.clone(), user.clone(), engine);
|
|
set_model_routing_keyword(state.clone(), user.clone(), engine);
|
|
get_current_model_keyword(state.clone(), user.clone(), engine);
|
|
list_models_keyword(state.clone(), user.clone(), engine);
|
|
}
|
|
|
|
|
|
|
|
pub fn use_model_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(
|
|
&["USE", "MODEL", "$expr$"],
|
|
false,
|
|
move |context, inputs| {
|
|
let model_name = context
|
|
.eval_expression_tree(&inputs[0])?
|
|
.to_string()
|
|
.trim_matches('"')
|
|
.to_string();
|
|
|
|
trace!("USE MODEL '{}' for session: {}", model_name, user_clone.id);
|
|
|
|
let state_for_task = Arc::clone(&state_clone);
|
|
let session_id = user_clone.id;
|
|
let model_name_clone = model_name.clone();
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
|
let result = rt.block_on(async {
|
|
set_session_model(&state_for_task, session_id, &model_name_clone).await
|
|
});
|
|
let _ = tx.send(result);
|
|
});
|
|
|
|
match rx.recv_timeout(std::time::Duration::from_secs(10)) {
|
|
Ok(Ok(msg)) => Ok(Dynamic::from(msg)),
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"USE MODEL timed out".into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
},
|
|
)
|
|
.expect("Failed to register USE MODEL syntax");
|
|
}
|
|
|
|
|
|
|
|
pub fn set_model_routing_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(
|
|
&["SET", "MODEL", "ROUTING", "$expr$"],
|
|
false,
|
|
move |context, inputs| {
|
|
let strategy_str = context
|
|
.eval_expression_tree(&inputs[0])?
|
|
.to_string()
|
|
.trim_matches('"')
|
|
.to_lowercase();
|
|
|
|
let strategy = match strategy_str.as_str() {
|
|
"auto" => RoutingStrategy::Auto,
|
|
"load-balanced" | "loadbalanced" => RoutingStrategy::LoadBalanced,
|
|
"fallback" => RoutingStrategy::Fallback,
|
|
_ => RoutingStrategy::Manual,
|
|
};
|
|
|
|
trace!(
|
|
"SET MODEL ROUTING {:?} for session: {}",
|
|
strategy,
|
|
user_clone.id
|
|
);
|
|
|
|
let state_for_task = Arc::clone(&state_clone);
|
|
let session_id = user_clone.id;
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
|
let result = rt.block_on(async {
|
|
set_session_routing_strategy(&state_for_task, session_id, strategy).await
|
|
});
|
|
let _ = tx.send(result);
|
|
});
|
|
|
|
match rx.recv_timeout(std::time::Duration::from_secs(10)) {
|
|
Ok(Ok(msg)) => Ok(Dynamic::from(msg)),
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"SET MODEL ROUTING timed out".into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
},
|
|
)
|
|
.expect("Failed to register SET MODEL ROUTING syntax");
|
|
}
|
|
|
|
|
|
|
|
pub fn get_current_model_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine.register_fn("GET CURRENT MODEL", move || -> String {
|
|
let state = Arc::clone(&state_clone);
|
|
|
|
if let Ok(mut conn) = state.conn.get() {
|
|
get_session_model_sync(&mut conn, user_clone.id)
|
|
.unwrap_or_else(|_| "default".to_string())
|
|
} else {
|
|
"default".to_string()
|
|
}
|
|
});
|
|
}
|
|
|
|
|
|
|
|
pub fn list_models_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine.register_fn("LIST MODELS", move || -> rhai::Array {
|
|
let state = Arc::clone(&state_clone);
|
|
|
|
if let Ok(mut conn) = state.conn.get() {
|
|
list_available_models_sync(&mut conn, user_clone.bot_id)
|
|
.unwrap_or_default()
|
|
.into_iter()
|
|
.map(Dynamic::from)
|
|
.collect()
|
|
} else {
|
|
rhai::Array::new()
|
|
}
|
|
});
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn set_session_model(
|
|
state: &AppState,
|
|
session_id: Uuid,
|
|
model_name: &str,
|
|
) -> Result<String, String> {
|
|
let mut conn = state
|
|
.conn
|
|
.get()
|
|
.map_err(|e| format!("Failed to acquire database connection: {}", e))?;
|
|
|
|
let now = chrono::Utc::now();
|
|
|
|
|
|
diesel::sql_query(
|
|
"INSERT INTO session_preferences (session_id, preference_key, preference_value, updated_at) \
|
|
VALUES ($1, 'current_model', $2, $3) \
|
|
ON CONFLICT (session_id, preference_key) DO UPDATE SET \
|
|
preference_value = EXCLUDED.preference_value, \
|
|
updated_at = EXCLUDED.updated_at",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
|
.bind::<diesel::sql_types::Text, _>(model_name)
|
|
.bind::<diesel::sql_types::Timestamptz, _>(now)
|
|
.execute(&mut conn)
|
|
.map_err(|e| format!("Failed to set session model: {}", e))?;
|
|
|
|
info!("Session {} now using model: {}", session_id, model_name);
|
|
|
|
Ok(format!("Now using model: {}", model_name))
|
|
}
|
|
|
|
|
|
async fn set_session_routing_strategy(
|
|
state: &AppState,
|
|
session_id: Uuid,
|
|
strategy: RoutingStrategy,
|
|
) -> Result<String, String> {
|
|
let mut conn = state
|
|
.conn
|
|
.get()
|
|
.map_err(|e| format!("Failed to acquire database connection: {}", e))?;
|
|
|
|
let now = chrono::Utc::now();
|
|
let strategy_str = match strategy {
|
|
RoutingStrategy::Manual => "manual",
|
|
RoutingStrategy::Auto => "auto",
|
|
RoutingStrategy::LoadBalanced => "load-balanced",
|
|
RoutingStrategy::Fallback => "fallback",
|
|
};
|
|
|
|
diesel::sql_query(
|
|
"INSERT INTO session_preferences (session_id, preference_key, preference_value, updated_at) \
|
|
VALUES ($1, 'model_routing', $2, $3) \
|
|
ON CONFLICT (session_id, preference_key) DO UPDATE SET \
|
|
preference_value = EXCLUDED.preference_value, \
|
|
updated_at = EXCLUDED.updated_at",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
|
.bind::<diesel::sql_types::Text, _>(strategy_str)
|
|
.bind::<diesel::sql_types::Timestamptz, _>(now)
|
|
.execute(&mut conn)
|
|
.map_err(|e| format!("Failed to set routing strategy: {}", e))?;
|
|
|
|
info!(
|
|
"Session {} routing strategy set to: {}",
|
|
session_id, strategy_str
|
|
);
|
|
|
|
Ok(format!("Model routing set to: {}", strategy_str))
|
|
}
|
|
|
|
|
|
fn get_session_model_sync(
|
|
conn: &mut diesel::PgConnection,
|
|
session_id: Uuid,
|
|
) -> Result<String, String> {
|
|
#[derive(QueryableByName)]
|
|
struct PrefValue {
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
preference_value: String,
|
|
}
|
|
|
|
let result: Option<PrefValue> = diesel::sql_query(
|
|
"SELECT preference_value FROM session_preferences \
|
|
WHERE session_id = $1 AND preference_key = 'current_model' LIMIT 1",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
|
.get_result(conn)
|
|
.optional()
|
|
.map_err(|e| format!("Failed to get session model: {}", e))?;
|
|
|
|
Ok(result
|
|
.map(|r| r.preference_value)
|
|
.unwrap_or_else(|| "default".to_string()))
|
|
}
|
|
|
|
|
|
fn list_available_models_sync(
|
|
conn: &mut diesel::PgConnection,
|
|
bot_id: Uuid,
|
|
) -> Result<Vec<String>, String> {
|
|
#[derive(QueryableByName)]
|
|
struct ConfigRow {
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
config_value: String,
|
|
}
|
|
|
|
|
|
let result: Option<ConfigRow> = diesel::sql_query(
|
|
"SELECT config_value FROM bot_configuration \
|
|
WHERE bot_id = $1 AND config_key = 'llm-models' LIMIT 1",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(bot_id)
|
|
.get_result(conn)
|
|
.optional()
|
|
.map_err(|e| format!("Failed to list models: {}", e))?;
|
|
|
|
if let Some(config) = result {
|
|
Ok(config
|
|
.config_value
|
|
.split(';')
|
|
.map(|s| s.trim().to_string())
|
|
.filter(|s| !s.is_empty())
|
|
.collect())
|
|
} else {
|
|
|
|
Ok(vec!["default".to_string()])
|
|
}
|
|
}
|
|
|
|
|
|
pub fn get_session_model(state: &AppState, session_id: Uuid) -> String {
|
|
if let Ok(mut conn) = state.conn.get() {
|
|
get_session_model_sync(&mut conn, session_id).unwrap_or_else(|_| "default".to_string())
|
|
} else {
|
|
"default".to_string()
|
|
}
|
|
}
|
|
|
|
|
|
pub fn get_session_routing_strategy(state: &AppState, session_id: Uuid) -> RoutingStrategy {
|
|
if let Ok(mut conn) = state.conn.get() {
|
|
#[derive(QueryableByName)]
|
|
struct PrefValue {
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
preference_value: String,
|
|
}
|
|
|
|
let result: Option<PrefValue> = diesel::sql_query(
|
|
"SELECT preference_value FROM session_preferences \
|
|
WHERE session_id = $1 AND preference_key = 'model_routing' LIMIT 1",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
|
.get_result(&mut conn)
|
|
.optional()
|
|
.ok()
|
|
.flatten();
|
|
|
|
if let Some(pref) = result {
|
|
match pref.preference_value.as_str() {
|
|
"auto" => RoutingStrategy::Auto,
|
|
"load-balanced" => RoutingStrategy::LoadBalanced,
|
|
"fallback" => RoutingStrategy::Fallback,
|
|
_ => RoutingStrategy::Manual,
|
|
}
|
|
} else {
|
|
RoutingStrategy::Manual
|
|
}
|
|
} else {
|
|
RoutingStrategy::Manual
|
|
}
|
|
}
|
|
|
|
|