Refactor LLM flow, add prompts, fix UI streaming
- Extract LLM generation into `execute_llm_generation` and simplify keyword handling. - Prepend system prompt and session context to LLM prompts in `BotOrchestrator`. - Parse incoming WebSocket messages as JSON and use the `content` field. - Add async `get_session_context` and stop injecting Redis context into conversation history. - Change default LLM URL to `http://48.217.66.81:8080` throughout the project. - Use the existing DB pool instead of creating a separate custom connection. - Update `start.bas` to call LLM and set a new context string. - Refactor web client message handling: separate event processing, improve streaming logic, reset streaming state on thinking end, and remove unused test functions.
This commit is contained in:
parent
e144e013d7
commit
f626793d77
9 changed files with 171 additions and 141 deletions
|
|
@ -25,7 +25,7 @@ dirs=(
|
||||||
#"automation"
|
#"automation"
|
||||||
#"basic"
|
#"basic"
|
||||||
"bot"
|
"bot"
|
||||||
"channels"
|
#"channels"
|
||||||
#"config"
|
#"config"
|
||||||
#"context"
|
#"context"
|
||||||
#"email"
|
#"email"
|
||||||
|
|
@ -62,7 +62,6 @@ files=(
|
||||||
"$PROJECT_ROOT/src/basic/keywords/hear_talk.rs"
|
"$PROJECT_ROOT/src/basic/keywords/hear_talk.rs"
|
||||||
"$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/start.bas"
|
"$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/start.bas"
|
||||||
"$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/auth.bas"
|
"$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/auth.bas"
|
||||||
"$PROJECT_ROOT/web/index.html"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for file in "${files[@]}"; do
|
for file in "${files[@]}"; do
|
||||||
|
|
|
||||||
|
|
@ -8,25 +8,31 @@ pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||||
|
|
||||||
engine
|
engine
|
||||||
.register_custom_syntax(&["LLM", "$expr$"], false, move |context, inputs| {
|
.register_custom_syntax(&["LLM", "$expr$"], false, move |context, inputs| {
|
||||||
let text = context.eval_expression_tree(&inputs[0])?;
|
let text = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||||
let text_str = text.to_string();
|
|
||||||
|
|
||||||
info!("LLM processing text: {}", text_str);
|
info!("LLM processing text: {}", text);
|
||||||
|
|
||||||
let state_inner = state_clone.clone();
|
let state_inner = state_clone.clone();
|
||||||
let fut = async move {
|
let fut = execute_llm_generation(state_inner, text);
|
||||||
let prompt = text_str;
|
|
||||||
state_inner
|
|
||||||
.llm_provider
|
|
||||||
.generate(&prompt, &serde_json::Value::Null)
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("LLM call failed: {}", e))
|
|
||||||
};
|
|
||||||
|
|
||||||
let result =
|
let result =
|
||||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))?;
|
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||||
|
.map_err(|e| format!("LLM generation failed: {}", e))?;
|
||||||
|
|
||||||
Ok(Dynamic::from(result))
|
Ok(Dynamic::from(result))
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn execute_llm_generation(
|
||||||
|
state: AppState,
|
||||||
|
prompt: String,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
info!("Starting LLM generation for prompt: '{}'", prompt);
|
||||||
|
|
||||||
|
state
|
||||||
|
.llm_provider
|
||||||
|
.generate(&prompt, &serde_json::Value::Null)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("LLM call failed: {}", e).into())
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ use actix_web::{web, HttpRequest, HttpResponse, Result};
|
||||||
use actix_ws::Message as WsMessage;
|
use actix_ws::Message as WsMessage;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use log::{debug, error, info, warn};
|
use log::{debug, error, info, warn};
|
||||||
|
|
||||||
use serde_json;
|
use serde_json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
@ -274,12 +273,27 @@ impl BotOrchestrator {
|
||||||
message: &UserMessage,
|
message: &UserMessage,
|
||||||
session: &UserSession,
|
session: &UserSession,
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let system_prompt = std::env::var("SYSTEM_PROMPT").unwrap_or_default();
|
||||||
|
let context_data = {
|
||||||
|
let session_manager = self.state.session_manager.lock().await;
|
||||||
|
session_manager
|
||||||
|
.get_session_context(&session.id, &session.user_id)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut prompt = String::new();
|
||||||
|
if !system_prompt.is_empty() {
|
||||||
|
prompt.push_str(&format!("System: {}\n", system_prompt));
|
||||||
|
}
|
||||||
|
if !context_data.is_empty() {
|
||||||
|
prompt.push_str(&format!("Context: {}\n", context_data));
|
||||||
|
}
|
||||||
|
|
||||||
let history = {
|
let history = {
|
||||||
let mut session_manager = self.state.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.get_conversation_history(session.id, session.user_id)?
|
session_manager.get_conversation_history(session.id, session.user_id)?
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut prompt = String::new();
|
|
||||||
for (role, content) in history {
|
for (role, content) in history {
|
||||||
prompt.push_str(&format!("{}: {}\n", role, content));
|
prompt.push_str(&format!("{}: {}\n", role, content));
|
||||||
}
|
}
|
||||||
|
|
@ -351,14 +365,31 @@ impl BotOrchestrator {
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let system_prompt = std::env::var("SYSTEM_PROMPT").unwrap_or_default();
|
||||||
|
let context_data = {
|
||||||
|
let session_manager = self.state.session_manager.lock().await;
|
||||||
|
session_manager
|
||||||
|
.get_session_context(&session.id, &session.user_id)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
let prompt = {
|
let prompt = {
|
||||||
let mut sm = self.state.session_manager.lock().await;
|
let mut sm = self.state.session_manager.lock().await;
|
||||||
let history = sm.get_conversation_history(session.id, user_id)?;
|
let history = sm.get_conversation_history(session.id, user_id)?;
|
||||||
let mut p = String::new();
|
let mut p = String::new();
|
||||||
|
|
||||||
|
if !system_prompt.is_empty() {
|
||||||
|
p.push_str(&format!("System: {}\n", system_prompt));
|
||||||
|
}
|
||||||
|
if !context_data.is_empty() {
|
||||||
|
p.push_str(&format!("Context: {}\n", context_data));
|
||||||
|
}
|
||||||
|
|
||||||
for (role, content) in &history {
|
for (role, content) in &history {
|
||||||
p.push_str(&format!("{}: {}\n", role, content));
|
p.push_str(&format!("{}: {}\n", role, content));
|
||||||
}
|
}
|
||||||
p.push_str(&format!("User: {}\nAssistant:", message.content));
|
p.push_str(&format!("User: {}\nAssistant:", message.content));
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"Stream prompt constructed with {} history entries",
|
"Stream prompt constructed with {} history entries",
|
||||||
history.len()
|
history.len()
|
||||||
|
|
@ -406,11 +437,17 @@ impl BotOrchestrator {
|
||||||
let mut analysis_buffer = String::new();
|
let mut analysis_buffer = String::new();
|
||||||
let mut in_analysis = false;
|
let mut in_analysis = false;
|
||||||
let mut chunk_count = 0;
|
let mut chunk_count = 0;
|
||||||
|
let mut first_word_received = false;
|
||||||
|
|
||||||
while let Some(chunk) = stream_rx.recv().await {
|
while let Some(chunk) = stream_rx.recv().await {
|
||||||
chunk_count += 1;
|
chunk_count += 1;
|
||||||
analysis_buffer.push_str(&chunk);
|
|
||||||
|
|
||||||
|
if !first_word_received && !chunk.trim().is_empty() {
|
||||||
|
first_word_received = true;
|
||||||
|
debug!("First word received in stream: '{}'", chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
analysis_buffer.push_str(&chunk);
|
||||||
if analysis_buffer.contains("<|channel|>") && !in_analysis {
|
if analysis_buffer.contains("<|channel|>") && !in_analysis {
|
||||||
in_analysis = true;
|
in_analysis = true;
|
||||||
}
|
}
|
||||||
|
|
@ -567,7 +604,6 @@ impl BotOrchestrator {
|
||||||
"Sending warning to session {} on channel {}: {}",
|
"Sending warning to session {} on channel {}: {}",
|
||||||
session_id, channel, message
|
session_id, channel, message
|
||||||
);
|
);
|
||||||
|
|
||||||
if channel == "web" {
|
if channel == "web" {
|
||||||
self.send_event(
|
self.send_event(
|
||||||
"system",
|
"system",
|
||||||
|
|
@ -615,7 +651,6 @@ impl BotOrchestrator {
|
||||||
"Triggering auto welcome for user: {}, session: {}, token: {:?}",
|
"Triggering auto welcome for user: {}, session: {}, token: {:?}",
|
||||||
user_id, session_id, token
|
user_id, session_id, token
|
||||||
);
|
);
|
||||||
|
|
||||||
let session_uuid = Uuid::parse_str(session_id).map_err(|e| {
|
let session_uuid = Uuid::parse_str(session_id).map_err(|e| {
|
||||||
error!("Invalid session ID: {}", e);
|
error!("Invalid session ID: {}", e);
|
||||||
e
|
e
|
||||||
|
|
@ -676,8 +711,8 @@ async fn websocket_handler(
|
||||||
|
|
||||||
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
||||||
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
||||||
|
|
||||||
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||||
|
|
||||||
orchestrator
|
orchestrator
|
||||||
.register_response_channel(session_id.clone(), tx.clone())
|
.register_response_channel(session_id.clone(), tx.clone())
|
||||||
.await;
|
.await;
|
||||||
|
|
@ -685,11 +720,13 @@ async fn websocket_handler(
|
||||||
data.web_adapter
|
data.web_adapter
|
||||||
.add_connection(session_id.clone(), tx.clone())
|
.add_connection(session_id.clone(), tx.clone())
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
data.voice_adapter
|
data.voice_adapter
|
||||||
.add_connection(session_id.clone(), tx.clone())
|
.add_connection(session_id.clone(), tx.clone())
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let bot_id = std::env::var("BOT_GUID").unwrap_or_else(|_| "default_bot".to_string());
|
let bot_id = std::env::var("BOT_GUID").unwrap_or_else(|_| "default_bot".to_string());
|
||||||
|
|
||||||
orchestrator
|
orchestrator
|
||||||
.send_event(
|
.send_event(
|
||||||
&user_id,
|
&user_id,
|
||||||
|
|
@ -749,12 +786,28 @@ async fn websocket_handler(
|
||||||
message_count += 1;
|
message_count += 1;
|
||||||
let bot_id =
|
let bot_id =
|
||||||
std::env::var("BOT_GUID").unwrap_or_else(|_| "default_bot".to_string());
|
std::env::var("BOT_GUID").unwrap_or_else(|_| "default_bot".to_string());
|
||||||
|
|
||||||
|
// Parse the text as JSON to extract the content field
|
||||||
|
let json_value: serde_json::Value = match serde_json::from_str(&text) {
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Error parsing JSON message {}: {}", message_count, e);
|
||||||
|
continue; // Skip processing this message
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract content from JSON, fallback to original text if content field doesn't exist
|
||||||
|
let content = json_value["content"]
|
||||||
|
.as_str()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let user_message = UserMessage {
|
let user_message = UserMessage {
|
||||||
bot_id: bot_id,
|
bot_id: bot_id,
|
||||||
user_id: user_id_clone.clone(),
|
user_id: user_id_clone.clone(),
|
||||||
session_id: session_id_clone2.clone(),
|
session_id: session_id_clone2.clone(),
|
||||||
channel: "web".to_string(),
|
channel: "web".to_string(),
|
||||||
content: text.to_string(),
|
content: content,
|
||||||
message_type: 1,
|
message_type: 1,
|
||||||
media_url: None,
|
media_url: None,
|
||||||
timestamp: Utc::now(),
|
timestamp: Utc::now(),
|
||||||
|
|
@ -767,6 +820,7 @@ async fn websocket_handler(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
WsMessage::Close(_) => {
|
WsMessage::Close(_) => {
|
||||||
let bot_id =
|
let bot_id =
|
||||||
std::env::var("BOT_GUID").unwrap_or_else(|_| "default_bot".to_string());
|
std::env::var("BOT_GUID").unwrap_or_else(|_| "default_bot".to_string());
|
||||||
|
|
@ -809,7 +863,6 @@ async fn auth_handler(
|
||||||
web::Query(params): web::Query<HashMap<String, String>>,
|
web::Query(params): web::Query<HashMap<String, String>>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let token = params.get("token").cloned().unwrap_or_default();
|
let token = params.get("token").cloned().unwrap_or_default();
|
||||||
info!("Auth handler called with token: {}", token);
|
|
||||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap();
|
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap();
|
||||||
let bot_id = if let Ok(bot_guid) = std::env::var("BOT_GUID") {
|
let bot_id = if let Ok(bot_guid) = std::env::var("BOT_GUID") {
|
||||||
match Uuid::parse_str(&bot_guid) {
|
match Uuid::parse_str(&bot_guid) {
|
||||||
|
|
@ -827,7 +880,6 @@ async fn auth_handler(
|
||||||
|
|
||||||
let session = {
|
let session = {
|
||||||
let mut sm = data.session_manager.lock().await;
|
let mut sm = data.session_manager.lock().await;
|
||||||
|
|
||||||
match sm.get_or_create_user_session(user_id, bot_id, "Auth Session") {
|
match sm.get_or_create_user_session(user_id, bot_id, "Auth Session") {
|
||||||
Ok(Some(s)) => s,
|
Ok(Some(s)) => s,
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
|
|
@ -842,6 +894,7 @@ async fn auth_handler(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let session_id_clone = session.id.clone();
|
let session_id_clone = session.id.clone();
|
||||||
let auth_script_path = "./templates/annoucements.gbai/annoucements.gbdialog/auth.bas";
|
let auth_script_path = "./templates/annoucements.gbai/annoucements.gbdialog/auth.bas";
|
||||||
let auth_script = match std::fs::read_to_string(auth_script_path) {
|
let auth_script = match std::fs::read_to_string(auth_script_path) {
|
||||||
|
|
@ -870,7 +923,6 @@ async fn auth_handler(
|
||||||
|
|
||||||
let session = {
|
let session = {
|
||||||
let mut sm = data.session_manager.lock().await;
|
let mut sm = data.session_manager.lock().await;
|
||||||
|
|
||||||
match sm.get_session_by_id(session_id_clone) {
|
match sm.get_session_by_id(session_id_clone) {
|
||||||
Ok(Some(s)) => s,
|
Ok(Some(s)) => s,
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
|
|
@ -902,7 +954,6 @@ async fn whatsapp_webhook_verify(
|
||||||
let mode = params.get("hub.mode").unwrap_or(&empty);
|
let mode = params.get("hub.mode").unwrap_or(&empty);
|
||||||
let token = params.get("hub.verify_token").unwrap_or(&empty);
|
let token = params.get("hub.verify_token").unwrap_or(&empty);
|
||||||
let challenge = params.get("hub.challenge").unwrap_or(&empty);
|
let challenge = params.get("hub.challenge").unwrap_or(&empty);
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Verification params - mode: {}, token: {}, challenge: {}",
|
"Verification params - mode: {}, token: {}, challenge: {}",
|
||||||
mode, token, challenge
|
mode, token, challenge
|
||||||
|
|
@ -930,7 +981,6 @@ async fn voice_start(
|
||||||
.get("user_id")
|
.get("user_id")
|
||||||
.and_then(|u| u.as_str())
|
.and_then(|u| u.as_str())
|
||||||
.unwrap_or("user");
|
.unwrap_or("user");
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Voice session start request - session: {}, user: {}",
|
"Voice session start request - session: {}, user: {}",
|
||||||
session_id, user_id
|
session_id, user_id
|
||||||
|
|
@ -968,7 +1018,6 @@ async fn voice_stop(
|
||||||
.get("session_id")
|
.get("session_id")
|
||||||
.and_then(|s| s.as_str())
|
.and_then(|s| s.as_str())
|
||||||
.unwrap_or("");
|
.unwrap_or("");
|
||||||
|
|
||||||
match data.voice_adapter.stop_voice_session(session_id).await {
|
match data.voice_adapter.stop_voice_session(session_id).await {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
info!(
|
info!(
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get configuration from environment variables
|
// Get configuration from environment variables
|
||||||
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://48.217.66.81:8080".to_string());
|
||||||
let embedding_url =
|
let embedding_url =
|
||||||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||||
let llama_cpp_path = env::var("LLM_CPP_PATH").unwrap_or_else(|_| "~/llama.cpp".to_string());
|
let llama_cpp_path = env::var("LLM_CPP_PATH").unwrap_or_else(|_| "~/llama.cpp".to_string());
|
||||||
|
|
@ -259,7 +259,7 @@ pub async fn chat_completions_local(
|
||||||
dotenv().ok().unwrap();
|
dotenv().ok().unwrap();
|
||||||
|
|
||||||
// Get llama.cpp server URL
|
// Get llama.cpp server URL
|
||||||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://48.217.66.81:8080".to_string());
|
||||||
|
|
||||||
// Convert OpenAI format to llama.cpp format
|
// Convert OpenAI format to llama.cpp format
|
||||||
let prompt = messages_to_prompt(&req_body.messages);
|
let prompt = messages_to_prompt(&req_body.messages);
|
||||||
|
|
|
||||||
29
src/main.rs
29
src/main.rs
|
|
@ -72,19 +72,20 @@ async fn main() -> std::io::Result<()> {
|
||||||
cfg.database_custom.database
|
cfg.database_custom.database
|
||||||
);
|
);
|
||||||
|
|
||||||
let db_custom_pool = match diesel::Connection::establish(&custom_db_url) {
|
let db_custom_pool = db_pool.clone();
|
||||||
Ok(conn) => {
|
// match diesel::Connection::establish(&custom_db_url) {
|
||||||
info!("Connected to custom database successfully");
|
// Ok(conn) => {
|
||||||
Arc::new(Mutex::new(conn))
|
// info!("Connected to custom database successfully");
|
||||||
}
|
// Arc::new(Mutex::new(conn))
|
||||||
Err(e2) => {
|
// }
|
||||||
log::error!("Failed to connect to custom database: {}", e2);
|
// Err(e2) => {
|
||||||
return Err(std::io::Error::new(
|
// log::error!("Failed to connect to custom database: {}", e2);
|
||||||
std::io::ErrorKind::ConnectionRefused,
|
// return Err(std::io::Error::new(
|
||||||
format!("Custom Database connection failed: {}", e2),
|
// std::io::ErrorKind::ConnectionRefused,
|
||||||
));
|
// format!("Custom Database connection failed: {}", e2),
|
||||||
}
|
// ));
|
||||||
};
|
// }
|
||||||
|
// };
|
||||||
|
|
||||||
ensure_llama_servers_running()
|
ensure_llama_servers_running()
|
||||||
.await
|
.await
|
||||||
|
|
@ -103,7 +104,7 @@ async fn main() -> std::io::Result<()> {
|
||||||
|
|
||||||
let tool_manager = Arc::new(tools::ToolManager::new());
|
let tool_manager = Arc::new(tools::ToolManager::new());
|
||||||
let llama_url =
|
let llama_url =
|
||||||
std::env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
std::env::var("LLM_URL").unwrap_or_else(|_| "http://48.217.66.81:8080".to_string());
|
||||||
let llm_provider = Arc::new(crate::llm::OpenAIClient::new(
|
let llm_provider = Arc::new(crate::llm::OpenAIClient::new(
|
||||||
"empty".to_string(),
|
"empty".to_string(),
|
||||||
Some(llama_url.clone()),
|
Some(llama_url.clone()),
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,15 @@
|
||||||
|
use crate::shared::models::UserSession;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use diesel::prelude::*;
|
use diesel::prelude::*;
|
||||||
use diesel::PgConnection;
|
use diesel::PgConnection;
|
||||||
use log::{debug, error, info, warn};
|
use log::{debug, error, info, warn};
|
||||||
use redis::Client;
|
use redis::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::shared::models::UserSession;
|
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize)]
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
pub struct SessionData {
|
pub struct SessionData {
|
||||||
pub id: Uuid,
|
pub id: Uuid,
|
||||||
|
|
@ -28,7 +26,6 @@ pub struct SessionManager {
|
||||||
|
|
||||||
impl SessionManager {
|
impl SessionManager {
|
||||||
pub fn new(conn: PgConnection, redis_client: Option<Arc<Client>>) -> Self {
|
pub fn new(conn: PgConnection, redis_client: Option<Arc<Client>>) -> Self {
|
||||||
info!("Initializing SessionManager");
|
|
||||||
SessionManager {
|
SessionManager {
|
||||||
conn,
|
conn,
|
||||||
sessions: HashMap::new(),
|
sessions: HashMap::new(),
|
||||||
|
|
@ -46,7 +43,6 @@ impl SessionManager {
|
||||||
"SessionManager.provide_input called for session {}",
|
"SessionManager.provide_input called for session {}",
|
||||||
session_id
|
session_id
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some(sess) = self.sessions.get_mut(&session_id) {
|
if let Some(sess) = self.sessions.get_mut(&session_id) {
|
||||||
sess.data = input;
|
sess.data = input;
|
||||||
self.waiting_for_input.remove(&session_id);
|
self.waiting_for_input.remove(&session_id);
|
||||||
|
|
@ -69,7 +65,6 @@ impl SessionManager {
|
||||||
|
|
||||||
pub fn mark_waiting(&mut self, session_id: Uuid) {
|
pub fn mark_waiting(&mut self, session_id: Uuid) {
|
||||||
self.waiting_for_input.insert(session_id);
|
self.waiting_for_input.insert(session_id);
|
||||||
info!("Session {} marked as waiting for input", session_id);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_session_by_id(
|
pub fn get_session_by_id(
|
||||||
|
|
@ -77,12 +72,10 @@ impl SessionManager {
|
||||||
session_id: Uuid,
|
session_id: Uuid,
|
||||||
) -> Result<Option<UserSession>, Box<dyn Error + Send + Sync>> {
|
) -> Result<Option<UserSession>, Box<dyn Error + Send + Sync>> {
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
use crate::shared::models::user_sessions::dsl::*;
|
||||||
|
|
||||||
let result = user_sessions
|
let result = user_sessions
|
||||||
.filter(id.eq(session_id))
|
.filter(id.eq(session_id))
|
||||||
.first::<UserSession>(&mut self.conn)
|
.first::<UserSession>(&mut self.conn)
|
||||||
.optional()?;
|
.optional()?;
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -92,14 +85,12 @@ impl SessionManager {
|
||||||
bid: Uuid,
|
bid: Uuid,
|
||||||
) -> Result<Option<UserSession>, Box<dyn Error + Send + Sync>> {
|
) -> Result<Option<UserSession>, Box<dyn Error + Send + Sync>> {
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
use crate::shared::models::user_sessions::dsl::*;
|
||||||
|
|
||||||
let result = user_sessions
|
let result = user_sessions
|
||||||
.filter(user_id.eq(uid))
|
.filter(user_id.eq(uid))
|
||||||
.filter(bot_id.eq(bid))
|
.filter(bot_id.eq(bid))
|
||||||
.order(created_at.desc())
|
.order(created_at.desc())
|
||||||
.first::<UserSession>(&mut self.conn)
|
.first::<UserSession>(&mut self.conn)
|
||||||
.optional()?;
|
.optional()?;
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,11 +101,8 @@ impl SessionManager {
|
||||||
session_title: &str,
|
session_title: &str,
|
||||||
) -> Result<Option<UserSession>, Box<dyn Error + Send + Sync>> {
|
) -> Result<Option<UserSession>, Box<dyn Error + Send + Sync>> {
|
||||||
if let Some(existing) = self.get_user_session(uid, bid)? {
|
if let Some(existing) = self.get_user_session(uid, bid)? {
|
||||||
debug!("Found existing session: {}", existing.id);
|
|
||||||
return Ok(Some(existing));
|
return Ok(Some(existing));
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Creating new session for user {} with bot {}", uid, bid);
|
|
||||||
self.create_session(uid, bid, session_title).map(Some)
|
self.create_session(uid, bid, session_title).map(Some)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -128,7 +116,6 @@ impl SessionManager {
|
||||||
use crate::shared::models::users::dsl as users_dsl;
|
use crate::shared::models::users::dsl as users_dsl;
|
||||||
|
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
|
|
||||||
let user_exists: Option<Uuid> = users_dsl::users
|
let user_exists: Option<Uuid> = users_dsl::users
|
||||||
.filter(users_dsl::id.eq(uid))
|
.filter(users_dsl::id.eq(uid))
|
||||||
.select(users_dsl::id)
|
.select(users_dsl::id)
|
||||||
|
|
@ -151,7 +138,6 @@ impl SessionManager {
|
||||||
users_dsl::updated_at.eq(now),
|
users_dsl::updated_at.eq(now),
|
||||||
))
|
))
|
||||||
.execute(&mut self.conn)?;
|
.execute(&mut self.conn)?;
|
||||||
info!("Created placeholder user: {}", uid);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let inserted: UserSession = diesel::insert_into(user_sessions)
|
let inserted: UserSession = diesel::insert_into(user_sessions)
|
||||||
|
|
@ -173,7 +159,6 @@ impl SessionManager {
|
||||||
e
|
e
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
info!("New session created: {}", inserted.id);
|
|
||||||
Ok(inserted)
|
Ok(inserted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -213,19 +198,18 @@ impl SessionManager {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_conversation_history(
|
pub async fn get_session_context(
|
||||||
&mut self,
|
&self,
|
||||||
sess_id: Uuid,
|
session_id: &Uuid,
|
||||||
_uid: Uuid,
|
user_id: &Uuid,
|
||||||
) -> Result<Vec<(String, String)>, Box<dyn Error + Send + Sync>> {
|
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||||
use crate::shared::models::message_history::dsl::*;
|
// Bring the Redis command trait into scope so we can call `get`.
|
||||||
use redis::Commands; // Import trait that provides the `get` method
|
use redis::Commands;
|
||||||
|
|
||||||
let redis_key = format!("context:{}:{}", _uid, sess_id);
|
let redis_key = format!("context:{}:{}", user_id, session_id);
|
||||||
|
if let Some(redis_client) = &self.redis {
|
||||||
// Fetch context from Redis and append to history on first retrieval
|
// Attempt to obtain a Redis connection; log and ignore errors, returning `None`.
|
||||||
let redis_context = if let Some(redis_client) = &self.redis {
|
let conn_option = redis_client
|
||||||
let conn = redis_client
|
|
||||||
.get_connection()
|
.get_connection()
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
warn!("Failed to get Redis connection: {}", e);
|
warn!("Failed to get Redis connection: {}", e);
|
||||||
|
|
@ -233,31 +217,35 @@ impl SessionManager {
|
||||||
})
|
})
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
if let Some(mut connection) = conn {
|
if let Some(mut connection) = conn_option {
|
||||||
match connection.get::<_, Option<String>>(&redis_key) {
|
match connection.get::<_, Option<String>>(&redis_key) {
|
||||||
Ok(Some(context)) => {
|
Ok(Some(context)) => {
|
||||||
info!(
|
debug!(
|
||||||
"Retrieved context from Redis for key {}: {} chars",
|
"Retrieved context from Redis for key {}: {} chars",
|
||||||
redis_key,
|
redis_key,
|
||||||
context.len()
|
context.len()
|
||||||
);
|
);
|
||||||
Some(context)
|
return Ok(context);
|
||||||
}
|
}
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
debug!("No context found in Redis for key {}", redis_key);
|
debug!("No context found in Redis for key {}", redis_key);
|
||||||
None
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Failed to retrieve context from Redis: {}", e);
|
warn!("Failed to retrieve context from Redis: {}", e);
|
||||||
None
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
None
|
// If Redis is unavailable or the key is missing, return an empty context.
|
||||||
};
|
Ok(String::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_conversation_history(
|
||||||
|
&mut self,
|
||||||
|
sess_id: Uuid,
|
||||||
|
_uid: Uuid,
|
||||||
|
) -> Result<Vec<(String, String)>, Box<dyn Error + Send + Sync>> {
|
||||||
|
use crate::shared::models::message_history::dsl::*;
|
||||||
|
|
||||||
let messages = message_history
|
let messages = message_history
|
||||||
.filter(session_id.eq(sess_id))
|
.filter(session_id.eq(sess_id))
|
||||||
|
|
@ -265,13 +253,7 @@ impl SessionManager {
|
||||||
.select((role, content_encrypted))
|
.select((role, content_encrypted))
|
||||||
.load::<(i32, String)>(&mut self.conn)?;
|
.load::<(i32, String)>(&mut self.conn)?;
|
||||||
|
|
||||||
// Build conversation history, inserting Redis context as a system (role 2) message if it exists
|
|
||||||
let mut history: Vec<(String, String)> = Vec::new();
|
let mut history: Vec<(String, String)> = Vec::new();
|
||||||
|
|
||||||
if let Some(ctx) = redis_context {
|
|
||||||
history.push(("system".to_string(), ctx));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (other_role, content) in messages {
|
for (other_role, content) in messages {
|
||||||
let role_str = match other_role {
|
let role_str = match other_role {
|
||||||
0 => "user".to_string(),
|
0 => "user".to_string(),
|
||||||
|
|
@ -289,12 +271,10 @@ impl SessionManager {
|
||||||
uid: Uuid,
|
uid: Uuid,
|
||||||
) -> Result<Vec<UserSession>, Box<dyn Error + Send + Sync>> {
|
) -> Result<Vec<UserSession>, Box<dyn Error + Send + Sync>> {
|
||||||
use crate::shared::models::user_sessions;
|
use crate::shared::models::user_sessions;
|
||||||
|
|
||||||
let sessions = user_sessions::table
|
let sessions = user_sessions::table
|
||||||
.filter(user_sessions::user_id.eq(uid))
|
.filter(user_sessions::user_id.eq(uid))
|
||||||
.order(user_sessions::created_at.desc())
|
.order(user_sessions::created_at.desc())
|
||||||
.load::<UserSession>(&mut self.conn)?;
|
.load::<UserSession>(&mut self.conn)?;
|
||||||
|
|
||||||
Ok(sessions)
|
Ok(sessions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -326,12 +306,11 @@ impl SessionManager {
|
||||||
if updated_count == 0 {
|
if updated_count == 0 {
|
||||||
warn!("No session found for user {} and bot {}", uid, bid);
|
warn!("No session found for user {} and bot {}", uid, bid);
|
||||||
} else {
|
} else {
|
||||||
debug!(
|
info!(
|
||||||
"Answer mode updated to {} for user {} and bot {}",
|
"Answer mode updated to {} for user {} and bot {}",
|
||||||
mode, uid, bid
|
mode, uid, bid
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -349,9 +328,8 @@ impl SessionManager {
|
||||||
if updated_count == 0 {
|
if updated_count == 0 {
|
||||||
warn!("No session found with ID: {}", session_id);
|
warn!("No session found with ID: {}", session_id);
|
||||||
} else {
|
} else {
|
||||||
info!("Updated session {} to user ID: {}", session_id, new_user_id);
|
debug!("Updated user ID for session {}", session_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ impl Default for AppState {
|
||||||
tool_manager: Arc::new(ToolManager::new()),
|
tool_manager: Arc::new(ToolManager::new()),
|
||||||
llm_provider: Arc::new(crate::llm::OpenAIClient::new(
|
llm_provider: Arc::new(crate::llm::OpenAIClient::new(
|
||||||
"empty".to_string(),
|
"empty".to_string(),
|
||||||
Some("http://localhost:8081".to_string()),
|
Some("http://48.217.66.81:8080".to_string()),
|
||||||
)),
|
)),
|
||||||
auth_service: Arc::new(tokio::sync::Mutex::new(AuthService::new(
|
auth_service: Arc::new(tokio::sync::Mutex::new(AuthService::new(
|
||||||
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
TALK "Olá, estou preparando um resumo para você."
|
TALK "Olá, estou preparando um resumo para você."
|
||||||
|
let x = LLM "Quando é 5+5?"
|
||||||
SET_CONTEXT "azul bolinha"
|
TALK x
|
||||||
|
SET_CONTEXT "Este é o documento que você deve usar para responder dúvidas: O céu é azul."
|
||||||
|
|
||||||
REM text = GET "default.pdf"
|
REM text = GET "default.pdf"
|
||||||
REM resume = LLM "Say Hello and present a a resume from " + text
|
REM resume = LLM "Say Hello and present a a resume from " + text
|
||||||
|
|
|
||||||
|
|
@ -278,38 +278,26 @@
|
||||||
|
|
||||||
ws.onmessage = function (event) {
|
ws.onmessage = function (event) {
|
||||||
const response = JSON.parse(event.data);
|
const response = JSON.parse(event.data);
|
||||||
|
|
||||||
|
// Handle events first
|
||||||
if (response.message_type === 2) {
|
if (response.message_type === 2) {
|
||||||
const eventData = JSON.parse(response.content);
|
const eventData = JSON.parse(response.content);
|
||||||
handleEvent(eventData.event, eventData.data);
|
handleEvent(eventData.event, eventData.data);
|
||||||
|
|
||||||
|
// Don't return - allow regular message processing to continue
|
||||||
|
// in case this event packet also contains response content
|
||||||
|
if (
|
||||||
|
response.content &&
|
||||||
|
response.content.trim() !== "" &&
|
||||||
|
!response.is_complete
|
||||||
|
) {
|
||||||
|
processMessageContent(response);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle normal complete messages (like from TALK)
|
// Process regular messages
|
||||||
if (response.is_complete) {
|
processMessageContent(response);
|
||||||
isStreaming = false;
|
|
||||||
streamingMessageId = null;
|
|
||||||
// Only add message if there's actual content
|
|
||||||
if (
|
|
||||||
response.content &&
|
|
||||||
response.content.trim() !== ""
|
|
||||||
) {
|
|
||||||
addMessage("assistant", response.content);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Handle streaming messages
|
|
||||||
if (!isStreaming) {
|
|
||||||
isStreaming = true;
|
|
||||||
streamingMessageId = "streaming-" + Date.now();
|
|
||||||
addMessage(
|
|
||||||
"assistant",
|
|
||||||
response.content,
|
|
||||||
true,
|
|
||||||
streamingMessageId,
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
updateLastMessage(response.content);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
ws.onopen = function () {
|
ws.onopen = function () {
|
||||||
|
|
@ -330,6 +318,29 @@
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function processMessageContent(response) {
|
||||||
|
if (response.is_complete) {
|
||||||
|
isStreaming = false;
|
||||||
|
streamingMessageId = null;
|
||||||
|
if (response.content && response.content.trim() !== "") {
|
||||||
|
addMessage("assistant", response.content);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (!isStreaming) {
|
||||||
|
isStreaming = true;
|
||||||
|
streamingMessageId = "streaming-" + Date.now();
|
||||||
|
addMessage(
|
||||||
|
"assistant",
|
||||||
|
response.content,
|
||||||
|
true,
|
||||||
|
streamingMessageId,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
updateLastMessage(response.content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function handleEvent(eventType, eventData) {
|
function handleEvent(eventType, eventData) {
|
||||||
console.log("Event received:", eventType, eventData);
|
console.log("Event received:", eventType, eventData);
|
||||||
switch (eventType) {
|
switch (eventType) {
|
||||||
|
|
@ -340,6 +351,9 @@
|
||||||
case "thinking_end":
|
case "thinking_end":
|
||||||
hideThinkingIndicator();
|
hideThinkingIndicator();
|
||||||
isThinking = false;
|
isThinking = false;
|
||||||
|
// Reset streaming state to ensure clean start for next message
|
||||||
|
isStreaming = false;
|
||||||
|
streamingMessageId = null;
|
||||||
break;
|
break;
|
||||||
case "warn":
|
case "warn":
|
||||||
showWarning(eventData.message);
|
showWarning(eventData.message);
|
||||||
|
|
@ -647,9 +661,7 @@
|
||||||
window.testWarning = function () {
|
window.testWarning = function () {
|
||||||
fetch("/api/warn", {
|
fetch("/api/warn", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: { "Content-Type": "application/json" },
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
session_id: currentSessionId || "default",
|
session_id: currentSessionId || "default",
|
||||||
channel: "web",
|
channel: "web",
|
||||||
|
|
@ -658,22 +670,6 @@
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
window.testMarkdown = function () {
|
|
||||||
const markdownContent = `Tabela de Frutas
|
|
||||||
| Nome da fruta | Cor predominante | Estação em que costuma estar em flor/colheita |
|
|
||||||
|---------------|------------------|-----------------------------------------------|
|
|
||||||
| Maçã | Verde, vermelho | Outono (início do verão em países de clima temperado) |
|
|
||||||
| Banana | Amarelo | Todo ano (principalmente nas regiões tropicais) |
|
|
||||||
| Laranja | Laranja | Inverno (pico de colheita em países de clima temperado) |
|
|
||||||
**Nota**: As informações sobre a estação de colheita são gerais e podem variar de acordo com a região e variedade da fruta.`;
|
|
||||||
addMessage("assistant", markdownContent);
|
|
||||||
};
|
|
||||||
|
|
||||||
window.testProblematicTable = function () {
|
|
||||||
const problematicContent = `Tabela de Frutas**| Nome da fruta | Cor predominante | Estação em que costuma estar em flor/colheita |||||| Maçã | Verde, vermelho | Outono (início do verão em países de clima temperado) || Banana | Amarelo | Todo ano (principalmente nas regiões tropicais) || Laranja | Laranja | Inverno (pico de colheita em países de clima temperado) | Nota: As informações sobre a estação de colheita são gerais e podem variar de acordo com a região e variedade da fruta`;
|
|
||||||
addMessage("assistant", problematicContent);
|
|
||||||
};
|
|
||||||
|
|
||||||
initializeAuth();
|
initializeAuth();
|
||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue