938 lines
33 KiB
Rust
938 lines
33 KiB
Rust
|
|
use crate::shared::models::UserSession;
|
||
|
|
use crate::shared::state::AppState;
|
||
|
|
use diesel::prelude::*;
|
||
|
|
use log::{error, info, trace, warn};
|
||
|
|
use rhai::{Dynamic, Engine};
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use std::collections::HashMap;
|
||
|
|
use std::sync::Arc;
|
||
|
|
use uuid::Uuid;
|
||
|
|
|
||
|
|
/// A2A Protocol Message Types
|
||
|
|
/// Based on https://a2a-protocol.org/latest/
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||
|
|
pub enum A2AMessageType {
|
||
|
|
/// Agent requesting action from another agent
|
||
|
|
Request,
|
||
|
|
/// Agent responding to a request
|
||
|
|
Response,
|
||
|
|
/// Message broadcast to all agents in session
|
||
|
|
Broadcast,
|
||
|
|
/// Hand off conversation to another agent
|
||
|
|
Delegate,
|
||
|
|
/// Request collaboration on a task
|
||
|
|
Collaborate,
|
||
|
|
/// Acknowledge receipt of message
|
||
|
|
Ack,
|
||
|
|
/// Error response
|
||
|
|
Error,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for A2AMessageType {
|
||
|
|
fn default() -> Self {
|
||
|
|
A2AMessageType::Request
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl std::fmt::Display for A2AMessageType {
|
||
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
|
|
match self {
|
||
|
|
A2AMessageType::Request => write!(f, "request"),
|
||
|
|
A2AMessageType::Response => write!(f, "response"),
|
||
|
|
A2AMessageType::Broadcast => write!(f, "broadcast"),
|
||
|
|
A2AMessageType::Delegate => write!(f, "delegate"),
|
||
|
|
A2AMessageType::Collaborate => write!(f, "collaborate"),
|
||
|
|
A2AMessageType::Ack => write!(f, "ack"),
|
||
|
|
A2AMessageType::Error => write!(f, "error"),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl From<&str> for A2AMessageType {
|
||
|
|
fn from(s: &str) -> Self {
|
||
|
|
match s.to_lowercase().as_str() {
|
||
|
|
"request" => A2AMessageType::Request,
|
||
|
|
"response" => A2AMessageType::Response,
|
||
|
|
"broadcast" => A2AMessageType::Broadcast,
|
||
|
|
"delegate" => A2AMessageType::Delegate,
|
||
|
|
"collaborate" => A2AMessageType::Collaborate,
|
||
|
|
"ack" => A2AMessageType::Ack,
|
||
|
|
"error" => A2AMessageType::Error,
|
||
|
|
_ => A2AMessageType::Request,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// A2A Protocol Message
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct A2AMessage {
|
||
|
|
/// Unique message identifier
|
||
|
|
pub id: Uuid,
|
||
|
|
/// Source agent identifier
|
||
|
|
pub from_agent: String,
|
||
|
|
/// Target agent identifier (None for broadcasts)
|
||
|
|
pub to_agent: Option<String>,
|
||
|
|
/// Message type
|
||
|
|
pub message_type: A2AMessageType,
|
||
|
|
/// Message payload (JSON)
|
||
|
|
pub payload: serde_json::Value,
|
||
|
|
/// Correlation ID for request-response matching
|
||
|
|
pub correlation_id: Uuid,
|
||
|
|
/// Session ID
|
||
|
|
pub session_id: Uuid,
|
||
|
|
/// Timestamp
|
||
|
|
pub timestamp: chrono::DateTime<chrono::Utc>,
|
||
|
|
/// Optional metadata
|
||
|
|
pub metadata: HashMap<String, String>,
|
||
|
|
/// TTL in seconds (0 = no expiry)
|
||
|
|
pub ttl_seconds: u32,
|
||
|
|
/// Hop count for preventing infinite loops
|
||
|
|
pub hop_count: u32,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl A2AMessage {
|
||
|
|
pub fn new(
|
||
|
|
from_agent: &str,
|
||
|
|
to_agent: Option<&str>,
|
||
|
|
message_type: A2AMessageType,
|
||
|
|
payload: serde_json::Value,
|
||
|
|
session_id: Uuid,
|
||
|
|
) -> Self {
|
||
|
|
Self {
|
||
|
|
id: Uuid::new_v4(),
|
||
|
|
from_agent: from_agent.to_string(),
|
||
|
|
to_agent: to_agent.map(|s| s.to_string()),
|
||
|
|
message_type,
|
||
|
|
payload,
|
||
|
|
correlation_id: Uuid::new_v4(),
|
||
|
|
session_id,
|
||
|
|
timestamp: chrono::Utc::now(),
|
||
|
|
metadata: HashMap::new(),
|
||
|
|
ttl_seconds: 30,
|
||
|
|
hop_count: 0,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Create a response to this message
|
||
|
|
pub fn create_response(&self, from_agent: &str, payload: serde_json::Value) -> Self {
|
||
|
|
Self {
|
||
|
|
id: Uuid::new_v4(),
|
||
|
|
from_agent: from_agent.to_string(),
|
||
|
|
to_agent: Some(self.from_agent.clone()),
|
||
|
|
message_type: A2AMessageType::Response,
|
||
|
|
payload,
|
||
|
|
correlation_id: self.correlation_id,
|
||
|
|
session_id: self.session_id,
|
||
|
|
timestamp: chrono::Utc::now(),
|
||
|
|
metadata: HashMap::new(),
|
||
|
|
ttl_seconds: 30,
|
||
|
|
hop_count: self.hop_count + 1,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check if message has expired
|
||
|
|
pub fn is_expired(&self) -> bool {
|
||
|
|
if self.ttl_seconds == 0 {
|
||
|
|
return false;
|
||
|
|
}
|
||
|
|
let now = chrono::Utc::now();
|
||
|
|
let expiry = self.timestamp + chrono::Duration::seconds(self.ttl_seconds as i64);
|
||
|
|
now > expiry
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check if max hops exceeded
|
||
|
|
pub fn max_hops_exceeded(&self, max_hops: u32) -> bool {
|
||
|
|
self.hop_count >= max_hops
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// A2A Protocol Configuration
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct A2AConfig {
|
||
|
|
/// Whether A2A protocol is enabled
|
||
|
|
pub enabled: bool,
|
||
|
|
/// Default timeout in seconds
|
||
|
|
pub timeout_seconds: u32,
|
||
|
|
/// Maximum hops to prevent infinite loops
|
||
|
|
pub max_hops: u32,
|
||
|
|
/// Protocol version
|
||
|
|
pub protocol_version: String,
|
||
|
|
/// Enable message persistence
|
||
|
|
pub persist_messages: bool,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for A2AConfig {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
enabled: true,
|
||
|
|
timeout_seconds: 30,
|
||
|
|
max_hops: 5,
|
||
|
|
protocol_version: "1.0".to_string(),
|
||
|
|
persist_messages: true,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Load A2A configuration from bot config
|
||
|
|
pub fn load_a2a_config(state: &AppState, bot_id: Uuid) -> A2AConfig {
|
||
|
|
let mut config = A2AConfig::default();
|
||
|
|
|
||
|
|
if let Ok(mut conn) = state.conn.get() {
|
||
|
|
#[derive(QueryableByName)]
|
||
|
|
struct ConfigRow {
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||
|
|
config_key: String,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||
|
|
config_value: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
let configs: Vec<ConfigRow> = diesel::sql_query(
|
||
|
|
"SELECT config_key, config_value FROM bot_configuration \
|
||
|
|
WHERE bot_id = $1 AND config_key LIKE 'a2a-%'",
|
||
|
|
)
|
||
|
|
.bind::<diesel::sql_types::Uuid, _>(bot_id)
|
||
|
|
.load(&mut conn)
|
||
|
|
.unwrap_or_default();
|
||
|
|
|
||
|
|
for row in configs {
|
||
|
|
match row.config_key.as_str() {
|
||
|
|
"a2a-enabled" => {
|
||
|
|
config.enabled = row.config_value.to_lowercase() == "true";
|
||
|
|
}
|
||
|
|
"a2a-timeout" => {
|
||
|
|
config.timeout_seconds = row.config_value.parse().unwrap_or(30);
|
||
|
|
}
|
||
|
|
"a2a-max-hops" => {
|
||
|
|
config.max_hops = row.config_value.parse().unwrap_or(5);
|
||
|
|
}
|
||
|
|
"a2a-protocol-version" => {
|
||
|
|
config.protocol_version = row.config_value;
|
||
|
|
}
|
||
|
|
"a2a-persist-messages" => {
|
||
|
|
config.persist_messages = row.config_value.to_lowercase() == "true";
|
||
|
|
}
|
||
|
|
_ => {}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
config
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Register all A2A protocol keywords
|
||
|
|
pub fn register_a2a_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||
|
|
send_to_bot_keyword(state.clone(), user.clone(), engine);
|
||
|
|
broadcast_message_keyword(state.clone(), user.clone(), engine);
|
||
|
|
collaborate_with_keyword(state.clone(), user.clone(), engine);
|
||
|
|
wait_for_bot_keyword(state.clone(), user.clone(), engine);
|
||
|
|
delegate_conversation_keyword(state.clone(), user.clone(), engine);
|
||
|
|
get_a2a_messages_keyword(state.clone(), user.clone(), engine);
|
||
|
|
}
|
||
|
|
|
||
|
|
/// SEND TO BOT "bot_name" MESSAGE "message_content"
|
||
|
|
/// Send a message to a specific bot
|
||
|
|
pub fn send_to_bot_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||
|
|
let state_clone = Arc::clone(&state);
|
||
|
|
let user_clone = user.clone();
|
||
|
|
|
||
|
|
engine
|
||
|
|
.register_custom_syntax(
|
||
|
|
&["SEND", "TO", "BOT", "$expr$", "MESSAGE", "$expr$"],
|
||
|
|
false,
|
||
|
|
move |context, inputs| {
|
||
|
|
let target_bot = context
|
||
|
|
.eval_expression_tree(&inputs[0])?
|
||
|
|
.to_string()
|
||
|
|
.trim_matches('"')
|
||
|
|
.to_string();
|
||
|
|
let message_content = context
|
||
|
|
.eval_expression_tree(&inputs[1])?
|
||
|
|
.to_string()
|
||
|
|
.trim_matches('"')
|
||
|
|
.to_string();
|
||
|
|
|
||
|
|
trace!(
|
||
|
|
"SEND TO BOT '{}' MESSAGE for session: {}",
|
||
|
|
target_bot,
|
||
|
|
user_clone.id
|
||
|
|
);
|
||
|
|
|
||
|
|
let state_for_task = Arc::clone(&state_clone);
|
||
|
|
let session_id = user_clone.id;
|
||
|
|
let bot_id = user_clone.bot_id;
|
||
|
|
let from_bot = format!("bot_{}", bot_id);
|
||
|
|
|
||
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
||
|
|
|
||
|
|
std::thread::spawn(move || {
|
||
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
||
|
|
let result = rt.block_on(async {
|
||
|
|
send_a2a_message(
|
||
|
|
&state_for_task,
|
||
|
|
session_id,
|
||
|
|
&from_bot,
|
||
|
|
Some(&target_bot),
|
||
|
|
A2AMessageType::Request,
|
||
|
|
serde_json::json!({ "content": message_content }),
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
});
|
||
|
|
let _ = tx.send(result);
|
||
|
|
});
|
||
|
|
|
||
|
|
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
|
||
|
|
Ok(Ok(msg_id)) => Ok(Dynamic::from(msg_id.to_string())),
|
||
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
||
|
|
e.into(),
|
||
|
|
rhai::Position::NONE,
|
||
|
|
))),
|
||
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
||
|
|
"SEND TO BOT timed out".into(),
|
||
|
|
rhai::Position::NONE,
|
||
|
|
))),
|
||
|
|
}
|
||
|
|
},
|
||
|
|
)
|
||
|
|
.expect("Failed to register SEND TO BOT syntax");
|
||
|
|
}
|
||
|
|
|
||
|
|
/// BROADCAST MESSAGE "message_content"
|
||
|
|
/// Broadcast a message to all bots in the session
|
||
|
|
pub fn broadcast_message_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||
|
|
let state_clone = Arc::clone(&state);
|
||
|
|
let user_clone = user.clone();
|
||
|
|
|
||
|
|
engine
|
||
|
|
.register_custom_syntax(
|
||
|
|
&["BROADCAST", "MESSAGE", "$expr$"],
|
||
|
|
false,
|
||
|
|
move |context, inputs| {
|
||
|
|
let message_content = context
|
||
|
|
.eval_expression_tree(&inputs[0])?
|
||
|
|
.to_string()
|
||
|
|
.trim_matches('"')
|
||
|
|
.to_string();
|
||
|
|
|
||
|
|
trace!("BROADCAST MESSAGE for session: {}", user_clone.id);
|
||
|
|
|
||
|
|
let state_for_task = Arc::clone(&state_clone);
|
||
|
|
let session_id = user_clone.id;
|
||
|
|
let bot_id = user_clone.bot_id;
|
||
|
|
let from_bot = format!("bot_{}", bot_id);
|
||
|
|
|
||
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
||
|
|
|
||
|
|
std::thread::spawn(move || {
|
||
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
||
|
|
let result = rt.block_on(async {
|
||
|
|
send_a2a_message(
|
||
|
|
&state_for_task,
|
||
|
|
session_id,
|
||
|
|
&from_bot,
|
||
|
|
None, // No target = broadcast
|
||
|
|
A2AMessageType::Broadcast,
|
||
|
|
serde_json::json!({ "content": message_content }),
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
});
|
||
|
|
let _ = tx.send(result);
|
||
|
|
});
|
||
|
|
|
||
|
|
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
|
||
|
|
Ok(Ok(msg_id)) => Ok(Dynamic::from(msg_id.to_string())),
|
||
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
||
|
|
e.into(),
|
||
|
|
rhai::Position::NONE,
|
||
|
|
))),
|
||
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
||
|
|
"BROADCAST MESSAGE timed out".into(),
|
||
|
|
rhai::Position::NONE,
|
||
|
|
))),
|
||
|
|
}
|
||
|
|
},
|
||
|
|
)
|
||
|
|
.expect("Failed to register BROADCAST MESSAGE syntax");
|
||
|
|
}
|
||
|
|
|
||
|
|
/// COLLABORATE WITH "bot1", "bot2" ON "task"
|
||
|
|
/// Request collaboration from multiple bots on a task
|
||
|
|
pub fn collaborate_with_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||
|
|
let state_clone = Arc::clone(&state);
|
||
|
|
let user_clone = user.clone();
|
||
|
|
|
||
|
|
engine
|
||
|
|
.register_custom_syntax(
|
||
|
|
&["COLLABORATE", "WITH", "$expr$", "ON", "$expr$"],
|
||
|
|
false,
|
||
|
|
move |context, inputs| {
|
||
|
|
let bots_str = context
|
||
|
|
.eval_expression_tree(&inputs[0])?
|
||
|
|
.to_string()
|
||
|
|
.trim_matches('"')
|
||
|
|
.to_string();
|
||
|
|
let task = context
|
||
|
|
.eval_expression_tree(&inputs[1])?
|
||
|
|
.to_string()
|
||
|
|
.trim_matches('"')
|
||
|
|
.to_string();
|
||
|
|
|
||
|
|
let bots: Vec<String> = bots_str
|
||
|
|
.split(',')
|
||
|
|
.map(|s| s.trim().trim_matches('"').to_string())
|
||
|
|
.filter(|s| !s.is_empty())
|
||
|
|
.collect();
|
||
|
|
|
||
|
|
trace!(
|
||
|
|
"COLLABORATE WITH {:?} ON '{}' for session: {}",
|
||
|
|
bots,
|
||
|
|
task,
|
||
|
|
user_clone.id
|
||
|
|
);
|
||
|
|
|
||
|
|
let state_for_task = Arc::clone(&state_clone);
|
||
|
|
let session_id = user_clone.id;
|
||
|
|
let bot_id = user_clone.bot_id;
|
||
|
|
let from_bot = format!("bot_{}", bot_id);
|
||
|
|
|
||
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
||
|
|
|
||
|
|
std::thread::spawn(move || {
|
||
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
||
|
|
let result = rt.block_on(async {
|
||
|
|
let mut message_ids = Vec::new();
|
||
|
|
for target_bot in &bots {
|
||
|
|
match send_a2a_message(
|
||
|
|
&state_for_task,
|
||
|
|
session_id,
|
||
|
|
&from_bot,
|
||
|
|
Some(target_bot),
|
||
|
|
A2AMessageType::Collaborate,
|
||
|
|
serde_json::json!({
|
||
|
|
"task": task,
|
||
|
|
"collaborators": bots.clone()
|
||
|
|
}),
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
{
|
||
|
|
Ok(id) => message_ids.push(id.to_string()),
|
||
|
|
Err(e) => {
|
||
|
|
warn!(
|
||
|
|
"Failed to send collaboration request to {}: {}",
|
||
|
|
target_bot, e
|
||
|
|
);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Ok::<Vec<String>, String>(message_ids)
|
||
|
|
});
|
||
|
|
let _ = tx.send(result);
|
||
|
|
});
|
||
|
|
|
||
|
|
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
|
||
|
|
Ok(Ok(ids)) => {
|
||
|
|
let array: rhai::Array = ids.into_iter().map(Dynamic::from).collect();
|
||
|
|
Ok(Dynamic::from(array))
|
||
|
|
}
|
||
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
||
|
|
e.into(),
|
||
|
|
rhai::Position::NONE,
|
||
|
|
))),
|
||
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
||
|
|
"COLLABORATE WITH timed out".into(),
|
||
|
|
rhai::Position::NONE,
|
||
|
|
))),
|
||
|
|
}
|
||
|
|
},
|
||
|
|
)
|
||
|
|
.expect("Failed to register COLLABORATE WITH syntax");
|
||
|
|
}
|
||
|
|
|
||
|
|
/// response = WAIT FOR BOT "bot_name" TIMEOUT seconds
|
||
|
|
/// Wait for a response from a specific bot
|
||
|
|
pub fn wait_for_bot_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||
|
|
let state_clone = Arc::clone(&state);
|
||
|
|
let user_clone = user.clone();
|
||
|
|
|
||
|
|
engine
|
||
|
|
.register_custom_syntax(
|
||
|
|
&["WAIT", "FOR", "BOT", "$expr$", "TIMEOUT", "$expr$"],
|
||
|
|
false,
|
||
|
|
move |context, inputs| {
|
||
|
|
let target_bot = context
|
||
|
|
.eval_expression_tree(&inputs[0])?
|
||
|
|
.to_string()
|
||
|
|
.trim_matches('"')
|
||
|
|
.to_string();
|
||
|
|
let timeout_secs: i64 = context
|
||
|
|
.eval_expression_tree(&inputs[1])?
|
||
|
|
.as_int()
|
||
|
|
.unwrap_or(30);
|
||
|
|
|
||
|
|
trace!(
|
||
|
|
"WAIT FOR BOT '{}' TIMEOUT {} for session: {}",
|
||
|
|
target_bot,
|
||
|
|
timeout_secs,
|
||
|
|
user_clone.id
|
||
|
|
);
|
||
|
|
|
||
|
|
let state_for_task = Arc::clone(&state_clone);
|
||
|
|
let session_id = user_clone.id;
|
||
|
|
let bot_id = user_clone.bot_id;
|
||
|
|
let current_bot = format!("bot_{}", bot_id);
|
||
|
|
|
||
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
||
|
|
|
||
|
|
std::thread::spawn(move || {
|
||
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
||
|
|
let result = rt.block_on(async {
|
||
|
|
wait_for_bot_response(
|
||
|
|
&state_for_task,
|
||
|
|
session_id,
|
||
|
|
&target_bot,
|
||
|
|
¤t_bot,
|
||
|
|
timeout_secs as u64,
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
});
|
||
|
|
let _ = tx.send(result);
|
||
|
|
});
|
||
|
|
|
||
|
|
let timeout_duration = std::time::Duration::from_secs(timeout_secs as u64 + 5);
|
||
|
|
match rx.recv_timeout(timeout_duration) {
|
||
|
|
Ok(Ok(response)) => Ok(Dynamic::from(response)),
|
||
|
|
Ok(Err(e)) => Ok(Dynamic::from(format!("Error: {}", e))),
|
||
|
|
Err(_) => Ok(Dynamic::from("Timeout waiting for response".to_string())),
|
||
|
|
}
|
||
|
|
},
|
||
|
|
)
|
||
|
|
.expect("Failed to register WAIT FOR BOT syntax");
|
||
|
|
}
|
||
|
|
|
||
|
|
/// DELEGATE CONVERSATION TO "bot_name"
|
||
|
|
/// Hand off the entire conversation to another bot
|
||
|
|
pub fn delegate_conversation_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||
|
|
let state_clone = Arc::clone(&state);
|
||
|
|
let user_clone = user.clone();
|
||
|
|
|
||
|
|
engine
|
||
|
|
.register_custom_syntax(
|
||
|
|
&["DELEGATE", "CONVERSATION", "TO", "$expr$"],
|
||
|
|
false,
|
||
|
|
move |context, inputs| {
|
||
|
|
let target_bot = context
|
||
|
|
.eval_expression_tree(&inputs[0])?
|
||
|
|
.to_string()
|
||
|
|
.trim_matches('"')
|
||
|
|
.to_string();
|
||
|
|
|
||
|
|
trace!(
|
||
|
|
"DELEGATE CONVERSATION TO '{}' for session: {}",
|
||
|
|
target_bot,
|
||
|
|
user_clone.id
|
||
|
|
);
|
||
|
|
|
||
|
|
let state_for_task = Arc::clone(&state_clone);
|
||
|
|
let session_id = user_clone.id;
|
||
|
|
let bot_id = user_clone.bot_id;
|
||
|
|
let from_bot = format!("bot_{}", bot_id);
|
||
|
|
|
||
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
||
|
|
|
||
|
|
std::thread::spawn(move || {
|
||
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
||
|
|
let result = rt.block_on(async {
|
||
|
|
// Send delegate message
|
||
|
|
send_a2a_message(
|
||
|
|
&state_for_task,
|
||
|
|
session_id,
|
||
|
|
&from_bot,
|
||
|
|
Some(&target_bot),
|
||
|
|
A2AMessageType::Delegate,
|
||
|
|
serde_json::json!({
|
||
|
|
"action": "take_over",
|
||
|
|
"reason": "delegation_request"
|
||
|
|
}),
|
||
|
|
)
|
||
|
|
.await?;
|
||
|
|
|
||
|
|
// Update session to set new active bot
|
||
|
|
set_session_active_bot(&state_for_task, session_id, &target_bot).await
|
||
|
|
});
|
||
|
|
let _ = tx.send(result);
|
||
|
|
});
|
||
|
|
|
||
|
|
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
|
||
|
|
Ok(Ok(msg)) => Ok(Dynamic::from(msg)),
|
||
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
||
|
|
e.into(),
|
||
|
|
rhai::Position::NONE,
|
||
|
|
))),
|
||
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
||
|
|
"DELEGATE CONVERSATION timed out".into(),
|
||
|
|
rhai::Position::NONE,
|
||
|
|
))),
|
||
|
|
}
|
||
|
|
},
|
||
|
|
)
|
||
|
|
.expect("Failed to register DELEGATE CONVERSATION syntax");
|
||
|
|
}
|
||
|
|
|
||
|
|
/// GET A2A MESSAGES()
|
||
|
|
/// Get all A2A messages for this bot in current session
|
||
|
|
pub fn get_a2a_messages_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||
|
|
let state_clone = Arc::clone(&state);
|
||
|
|
let user_clone = user.clone();
|
||
|
|
|
||
|
|
engine.register_fn("GET A2A MESSAGES", move || -> rhai::Array {
|
||
|
|
let state = Arc::clone(&state_clone);
|
||
|
|
let session_id = user_clone.id;
|
||
|
|
let bot_id = user_clone.bot_id;
|
||
|
|
let current_bot = format!("bot_{}", bot_id);
|
||
|
|
|
||
|
|
if let Ok(mut conn) = state.conn.get() {
|
||
|
|
get_pending_messages_sync(&mut conn, session_id, ¤t_bot)
|
||
|
|
.unwrap_or_default()
|
||
|
|
.into_iter()
|
||
|
|
.map(|msg| Dynamic::from(serde_json::to_string(&msg).unwrap_or_default()))
|
||
|
|
.collect()
|
||
|
|
} else {
|
||
|
|
rhai::Array::new()
|
||
|
|
}
|
||
|
|
});
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Database Operations
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
/// Send an A2A message
|
||
|
|
async fn send_a2a_message(
|
||
|
|
state: &AppState,
|
||
|
|
session_id: Uuid,
|
||
|
|
from_agent: &str,
|
||
|
|
to_agent: Option<&str>,
|
||
|
|
message_type: A2AMessageType,
|
||
|
|
payload: serde_json::Value,
|
||
|
|
) -> Result<Uuid, String> {
|
||
|
|
let mut conn = state
|
||
|
|
.conn
|
||
|
|
.get()
|
||
|
|
.map_err(|e| format!("Failed to acquire database connection: {}", e))?;
|
||
|
|
|
||
|
|
let message = A2AMessage::new(from_agent, to_agent, message_type, payload, session_id);
|
||
|
|
let message_id = message.id;
|
||
|
|
|
||
|
|
let payload_str = serde_json::to_string(&message.payload)
|
||
|
|
.map_err(|e| format!("Failed to serialize payload: {}", e))?;
|
||
|
|
|
||
|
|
let metadata_str = serde_json::to_string(&message.metadata)
|
||
|
|
.map_err(|e| format!("Failed to serialize metadata: {}", e))?;
|
||
|
|
|
||
|
|
diesel::sql_query(
|
||
|
|
"INSERT INTO a2a_messages \
|
||
|
|
(id, session_id, from_agent, to_agent, message_type, payload, correlation_id, \
|
||
|
|
timestamp, metadata, ttl_seconds, hop_count, processed) \
|
||
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, false)",
|
||
|
|
)
|
||
|
|
.bind::<diesel::sql_types::Uuid, _>(message.id)
|
||
|
|
.bind::<diesel::sql_types::Uuid, _>(message.session_id)
|
||
|
|
.bind::<diesel::sql_types::Text, _>(&message.from_agent)
|
||
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(message.to_agent.as_deref())
|
||
|
|
.bind::<diesel::sql_types::Text, _>(message.message_type.to_string())
|
||
|
|
.bind::<diesel::sql_types::Text, _>(&payload_str)
|
||
|
|
.bind::<diesel::sql_types::Uuid, _>(message.correlation_id)
|
||
|
|
.bind::<diesel::sql_types::Timestamptz, _>(message.timestamp)
|
||
|
|
.bind::<diesel::sql_types::Text, _>(&metadata_str)
|
||
|
|
.bind::<diesel::sql_types::Integer, _>(message.ttl_seconds as i32)
|
||
|
|
.bind::<diesel::sql_types::Integer, _>(message.hop_count as i32)
|
||
|
|
.execute(&mut conn)
|
||
|
|
.map_err(|e| format!("Failed to insert A2A message: {}", e))?;
|
||
|
|
|
||
|
|
info!(
|
||
|
|
"A2A message sent: {} -> {:?} (type: {})",
|
||
|
|
from_agent, to_agent, message.message_type
|
||
|
|
);
|
||
|
|
|
||
|
|
Ok(message_id)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Wait for a response from a specific bot
|
||
|
|
async fn wait_for_bot_response(
|
||
|
|
state: &AppState,
|
||
|
|
session_id: Uuid,
|
||
|
|
from_bot: &str,
|
||
|
|
to_bot: &str,
|
||
|
|
timeout_secs: u64,
|
||
|
|
) -> Result<String, String> {
|
||
|
|
let start = std::time::Instant::now();
|
||
|
|
let timeout = std::time::Duration::from_secs(timeout_secs);
|
||
|
|
|
||
|
|
loop {
|
||
|
|
if start.elapsed() > timeout {
|
||
|
|
return Err("Timeout waiting for bot response".to_string());
|
||
|
|
}
|
||
|
|
|
||
|
|
// Check for response
|
||
|
|
if let Ok(mut conn) = state.conn.get() {
|
||
|
|
#[derive(QueryableByName)]
|
||
|
|
struct MessageRow {
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Uuid)]
|
||
|
|
id: Uuid,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||
|
|
payload: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
let result: Option<MessageRow> = diesel::sql_query(
|
||
|
|
"SELECT id, payload FROM a2a_messages \
|
||
|
|
WHERE session_id = $1 AND from_agent = $2 AND to_agent = $3 \
|
||
|
|
AND message_type = 'response' AND processed = false \
|
||
|
|
ORDER BY timestamp DESC LIMIT 1",
|
||
|
|
)
|
||
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
||
|
|
.bind::<diesel::sql_types::Text, _>(from_bot)
|
||
|
|
.bind::<diesel::sql_types::Text, _>(to_bot)
|
||
|
|
.get_result(&mut conn)
|
||
|
|
.optional()
|
||
|
|
.map_err(|e| format!("Failed to query messages: {}", e))?;
|
||
|
|
|
||
|
|
if let Some(msg) = result {
|
||
|
|
// Mark as processed
|
||
|
|
let _ = diesel::sql_query("UPDATE a2a_messages SET processed = true WHERE id = $1")
|
||
|
|
.bind::<diesel::sql_types::Uuid, _>(msg.id)
|
||
|
|
.execute(&mut conn);
|
||
|
|
|
||
|
|
// Extract content from payload
|
||
|
|
if let Ok(payload) = serde_json::from_str::<serde_json::Value>(&msg.payload) {
|
||
|
|
if let Some(content) = payload.get("content").and_then(|c| c.as_str()) {
|
||
|
|
return Ok(content.to_string());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return Ok(msg.payload);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Sleep before next check
|
||
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Set the active bot for a session
|
||
|
|
async fn set_session_active_bot(
|
||
|
|
state: &AppState,
|
||
|
|
session_id: Uuid,
|
||
|
|
bot_name: &str,
|
||
|
|
) -> Result<String, String> {
|
||
|
|
let mut conn = state
|
||
|
|
.conn
|
||
|
|
.get()
|
||
|
|
.map_err(|e| format!("Failed to acquire database connection: {}", e))?;
|
||
|
|
|
||
|
|
let now = chrono::Utc::now();
|
||
|
|
|
||
|
|
diesel::sql_query(
|
||
|
|
"INSERT INTO session_preferences (session_id, preference_key, preference_value, updated_at) \
|
||
|
|
VALUES ($1, 'active_bot', $2, $3) \
|
||
|
|
ON CONFLICT (session_id, preference_key) DO UPDATE SET \
|
||
|
|
preference_value = EXCLUDED.preference_value, \
|
||
|
|
updated_at = EXCLUDED.updated_at",
|
||
|
|
)
|
||
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
||
|
|
.bind::<diesel::sql_types::Text, _>(bot_name)
|
||
|
|
.bind::<diesel::sql_types::Timestamptz, _>(now)
|
||
|
|
.execute(&mut conn)
|
||
|
|
.map_err(|e| format!("Failed to set active bot: {}", e))?;
|
||
|
|
|
||
|
|
info!("Session {} delegated to bot: {}", session_id, bot_name);
|
||
|
|
|
||
|
|
Ok(format!("Conversation delegated to {}", bot_name))
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get pending messages for a bot (sync version)
|
||
|
|
fn get_pending_messages_sync(
|
||
|
|
conn: &mut diesel::PgConnection,
|
||
|
|
session_id: Uuid,
|
||
|
|
to_agent: &str,
|
||
|
|
) -> Result<Vec<A2AMessage>, String> {
|
||
|
|
#[derive(QueryableByName)]
|
||
|
|
struct MessageRow {
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Uuid)]
|
||
|
|
id: Uuid,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Uuid)]
|
||
|
|
session_id: Uuid,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||
|
|
from_agent: String,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
||
|
|
to_agent: Option<String>,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||
|
|
message_type: String,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
||
|
|
payload: String,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Uuid)]
|
||
|
|
correlation_id: Uuid,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Timestamptz)]
|
||
|
|
timestamp: chrono::DateTime<chrono::Utc>,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
||
|
|
ttl_seconds: i32,
|
||
|
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
||
|
|
hop_count: i32,
|
||
|
|
}
|
||
|
|
|
||
|
|
let rows: Vec<MessageRow> = diesel::sql_query(
|
||
|
|
"SELECT id, session_id, from_agent, to_agent, message_type, payload, \
|
||
|
|
correlation_id, timestamp, ttl_seconds, hop_count \
|
||
|
|
FROM a2a_messages \
|
||
|
|
WHERE session_id = $1 AND (to_agent = $2 OR to_agent IS NULL) AND processed = false \
|
||
|
|
ORDER BY timestamp ASC",
|
||
|
|
)
|
||
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
||
|
|
.bind::<diesel::sql_types::Text, _>(to_agent)
|
||
|
|
.load(conn)
|
||
|
|
.map_err(|e| format!("Failed to get pending messages: {}", e))?;
|
||
|
|
|
||
|
|
let messages: Vec<A2AMessage> = rows
|
||
|
|
.into_iter()
|
||
|
|
.map(|row| A2AMessage {
|
||
|
|
id: row.id,
|
||
|
|
session_id: row.session_id,
|
||
|
|
from_agent: row.from_agent,
|
||
|
|
to_agent: row.to_agent,
|
||
|
|
message_type: A2AMessageType::from(row.message_type.as_str()),
|
||
|
|
payload: serde_json::from_str(&row.payload).unwrap_or(serde_json::json!({})),
|
||
|
|
correlation_id: row.correlation_id,
|
||
|
|
timestamp: row.timestamp,
|
||
|
|
metadata: HashMap::new(),
|
||
|
|
ttl_seconds: row.ttl_seconds as u32,
|
||
|
|
hop_count: row.hop_count as u32,
|
||
|
|
})
|
||
|
|
.filter(|msg| !msg.is_expired())
|
||
|
|
.collect();
|
||
|
|
|
||
|
|
Ok(messages)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Respond to an A2A message (helper for bots to use)
|
||
|
|
pub async fn respond_to_a2a_message(
|
||
|
|
state: &AppState,
|
||
|
|
original_message: &A2AMessage,
|
||
|
|
from_agent: &str,
|
||
|
|
response_content: &str,
|
||
|
|
) -> Result<Uuid, String> {
|
||
|
|
send_a2a_message(
|
||
|
|
state,
|
||
|
|
original_message.session_id,
|
||
|
|
from_agent,
|
||
|
|
Some(&original_message.from_agent),
|
||
|
|
A2AMessageType::Response,
|
||
|
|
serde_json::json!({ "content": response_content }),
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
}
|
||
|
|
|
||
|
|
// ============================================================================
|
||
|
|
// Tests
|
||
|
|
// ============================================================================
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `bot_a` -> `bot_b` message of the given type with an empty payload.
    fn sample_message(message_type: A2AMessageType) -> A2AMessage {
        A2AMessage::new(
            "bot_a",
            Some("bot_b"),
            message_type,
            serde_json::json!({}),
            Uuid::new_v4(),
        )
    }

    #[test]
    fn test_a2a_message_creation() {
        let msg = A2AMessage::new(
            "bot_a",
            Some("bot_b"),
            A2AMessageType::Request,
            serde_json::json!({"test": "data"}),
            Uuid::new_v4(),
        );

        assert_eq!(msg.from_agent, "bot_a");
        assert_eq!(msg.to_agent, Some("bot_b".to_string()));
        assert_eq!(msg.message_type, A2AMessageType::Request);
        assert_eq!(msg.hop_count, 0);
        assert_ne!(msg.id, msg.correlation_id);
    }

    #[test]
    fn test_a2a_message_response() {
        let original = A2AMessage::new(
            "bot_a",
            Some("bot_b"),
            A2AMessageType::Request,
            serde_json::json!({"question": "test"}),
            Uuid::new_v4(),
        );

        let response = original.create_response("bot_b", serde_json::json!({"answer": "result"}));

        assert_eq!(response.from_agent, "bot_b");
        assert_eq!(response.to_agent, Some("bot_a".to_string()));
        assert_eq!(response.message_type, A2AMessageType::Response);
        assert_eq!(response.correlation_id, original.correlation_id);
        assert_eq!(response.session_id, original.session_id);
        assert_eq!(response.hop_count, 1);
    }

    #[test]
    fn test_message_type_display() {
        assert_eq!(A2AMessageType::Request.to_string(), "request");
        assert_eq!(A2AMessageType::Response.to_string(), "response");
        assert_eq!(A2AMessageType::Broadcast.to_string(), "broadcast");
        assert_eq!(A2AMessageType::Delegate.to_string(), "delegate");
        assert_eq!(A2AMessageType::Collaborate.to_string(), "collaborate");
        assert_eq!(A2AMessageType::Ack.to_string(), "ack");
        assert_eq!(A2AMessageType::Error.to_string(), "error");
    }

    #[test]
    fn test_message_type_from_str() {
        assert_eq!(A2AMessageType::from("request"), A2AMessageType::Request);
        assert_eq!(A2AMessageType::from("RESPONSE"), A2AMessageType::Response);
        assert_eq!(A2AMessageType::from("unknown"), A2AMessageType::Request);
    }

    #[test]
    fn test_message_type_round_trip() {
        // Every variant must survive Display -> From, since that is how the type is
        // written to and read back from the database.
        let all = [
            A2AMessageType::Request,
            A2AMessageType::Response,
            A2AMessageType::Broadcast,
            A2AMessageType::Delegate,
            A2AMessageType::Collaborate,
            A2AMessageType::Ack,
            A2AMessageType::Error,
        ];
        for variant in all.iter() {
            assert_eq!(&A2AMessageType::from(variant.to_string().as_str()), variant);
        }
    }

    #[test]
    fn test_a2a_config_default() {
        let config = A2AConfig::default();
        assert!(config.enabled);
        assert_eq!(config.timeout_seconds, 30);
        assert_eq!(config.max_hops, 5);
        assert_eq!(config.protocol_version, "1.0");
        assert!(config.persist_messages);
    }

    #[test]
    fn test_message_not_expired() {
        let msg = sample_message(A2AMessageType::Request);
        assert!(!msg.is_expired());
    }

    #[test]
    fn test_zero_ttl_never_expires() {
        let mut msg = sample_message(A2AMessageType::Request);
        msg.ttl_seconds = 0;
        msg.timestamp = chrono::Utc::now() - chrono::Duration::seconds(3600);
        assert!(!msg.is_expired());
    }

    #[test]
    fn test_message_expires_after_ttl() {
        let mut msg = sample_message(A2AMessageType::Request);
        msg.ttl_seconds = 1;
        msg.timestamp = chrono::Utc::now() - chrono::Duration::seconds(60);
        assert!(msg.is_expired());
    }

    #[test]
    fn test_broadcast_has_no_target() {
        let msg = A2AMessage::new(
            "bot_a",
            None,
            A2AMessageType::Broadcast,
            serde_json::json!({}),
            Uuid::new_v4(),
        );
        assert_eq!(msg.to_agent, None);
    }

    #[test]
    fn test_max_hops_not_exceeded() {
        let msg = sample_message(A2AMessageType::Request);
        assert!(!msg.max_hops_exceeded(5));
    }

    #[test]
    fn test_max_hops_exceeded() {
        let mut msg = sample_message(A2AMessageType::Request);
        msg.hop_count = 5;
        assert!(msg.max_hops_exceeded(5));
    }
}
|