Drop image (with ravif/paste), sqlx, zitadel, and related dependencies that were causing compilation issues. Replace image processing with direct png crate usage. Update rcgen to 0.14 with new API changes. Refactor CA certificate generation to use Issuer pattern.
949 lines
33 KiB
Rust
949 lines
33 KiB
Rust
use crate::shared::models::UserSession;
|
|
use crate::shared::state::AppState;
|
|
use diesel::prelude::*;
|
|
use log::{info, trace, warn};
|
|
use rhai::{Dynamic, Engine};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use uuid::Uuid;
|
|
|
|
/// A2A Protocol Message Types
|
|
/// Based on https://a2a-protocol.org/latest/
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
pub enum A2AMessageType {
|
|
/// Agent requesting action from another agent
|
|
Request,
|
|
/// Agent responding to a request
|
|
Response,
|
|
/// Message broadcast to all agents in session
|
|
Broadcast,
|
|
/// Hand off conversation to another agent
|
|
Delegate,
|
|
/// Request collaboration on a task
|
|
Collaborate,
|
|
/// Acknowledge receipt of message
|
|
Ack,
|
|
/// Error response
|
|
Error,
|
|
}
|
|
|
|
impl Default for A2AMessageType {
|
|
fn default() -> Self {
|
|
A2AMessageType::Request
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for A2AMessageType {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
A2AMessageType::Request => write!(f, "request"),
|
|
A2AMessageType::Response => write!(f, "response"),
|
|
A2AMessageType::Broadcast => write!(f, "broadcast"),
|
|
A2AMessageType::Delegate => write!(f, "delegate"),
|
|
A2AMessageType::Collaborate => write!(f, "collaborate"),
|
|
A2AMessageType::Ack => write!(f, "ack"),
|
|
A2AMessageType::Error => write!(f, "error"),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<&str> for A2AMessageType {
|
|
fn from(s: &str) -> Self {
|
|
match s.to_lowercase().as_str() {
|
|
"request" => A2AMessageType::Request,
|
|
"response" => A2AMessageType::Response,
|
|
"broadcast" => A2AMessageType::Broadcast,
|
|
"delegate" => A2AMessageType::Delegate,
|
|
"collaborate" => A2AMessageType::Collaborate,
|
|
"ack" => A2AMessageType::Ack,
|
|
"error" => A2AMessageType::Error,
|
|
_ => A2AMessageType::Request,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A2A Protocol Message
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct A2AMessage {
|
|
/// Unique message identifier
|
|
pub id: Uuid,
|
|
/// Source agent identifier
|
|
pub from_agent: String,
|
|
/// Target agent identifier (None for broadcasts)
|
|
pub to_agent: Option<String>,
|
|
/// Message type
|
|
pub message_type: A2AMessageType,
|
|
/// Message payload (JSON)
|
|
pub payload: serde_json::Value,
|
|
/// Correlation ID for request-response matching
|
|
pub correlation_id: Uuid,
|
|
/// Session ID
|
|
pub session_id: Uuid,
|
|
/// Timestamp
|
|
pub timestamp: chrono::DateTime<chrono::Utc>,
|
|
/// Optional metadata
|
|
pub metadata: HashMap<String, String>,
|
|
/// TTL in seconds (0 = no expiry)
|
|
pub ttl_seconds: u32,
|
|
/// Hop count for preventing infinite loops
|
|
pub hop_count: u32,
|
|
}
|
|
|
|
impl A2AMessage {
|
|
pub fn new(
|
|
from_agent: &str,
|
|
to_agent: Option<&str>,
|
|
message_type: A2AMessageType,
|
|
payload: serde_json::Value,
|
|
session_id: Uuid,
|
|
) -> Self {
|
|
Self {
|
|
id: Uuid::new_v4(),
|
|
from_agent: from_agent.to_string(),
|
|
to_agent: to_agent.map(|s| s.to_string()),
|
|
message_type,
|
|
payload,
|
|
correlation_id: Uuid::new_v4(),
|
|
session_id,
|
|
timestamp: chrono::Utc::now(),
|
|
metadata: HashMap::new(),
|
|
ttl_seconds: 30,
|
|
hop_count: 0,
|
|
}
|
|
}
|
|
|
|
/// Create a response to this message
|
|
pub fn create_response(&self, from_agent: &str, payload: serde_json::Value) -> Self {
|
|
Self {
|
|
id: Uuid::new_v4(),
|
|
from_agent: from_agent.to_string(),
|
|
to_agent: Some(self.from_agent.clone()),
|
|
message_type: A2AMessageType::Response,
|
|
payload,
|
|
correlation_id: self.correlation_id,
|
|
session_id: self.session_id,
|
|
timestamp: chrono::Utc::now(),
|
|
metadata: HashMap::new(),
|
|
ttl_seconds: 30,
|
|
hop_count: self.hop_count + 1,
|
|
}
|
|
}
|
|
|
|
/// Check if message has expired
|
|
pub fn is_expired(&self) -> bool {
|
|
if self.ttl_seconds == 0 {
|
|
return false;
|
|
}
|
|
let now = chrono::Utc::now();
|
|
let expiry = self.timestamp + chrono::Duration::seconds(self.ttl_seconds as i64);
|
|
now > expiry
|
|
}
|
|
|
|
/// Check if max hops exceeded
|
|
pub fn max_hops_exceeded(&self, max_hops: u32) -> bool {
|
|
self.hop_count >= max_hops
|
|
}
|
|
}
|
|
|
|
/// A2A Protocol Configuration
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct A2AConfig {
|
|
/// Whether A2A protocol is enabled
|
|
pub enabled: bool,
|
|
/// Default timeout in seconds
|
|
pub timeout_seconds: u32,
|
|
/// Maximum hops to prevent infinite loops
|
|
pub max_hops: u32,
|
|
/// Protocol version
|
|
pub protocol_version: String,
|
|
/// Enable message persistence
|
|
pub persist_messages: bool,
|
|
/// Retry attempts on failure
|
|
pub retry_count: u32,
|
|
/// Maximum pending messages in queue
|
|
pub queue_size: u32,
|
|
}
|
|
|
|
impl Default for A2AConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
enabled: true,
|
|
timeout_seconds: 30,
|
|
max_hops: 5,
|
|
protocol_version: "1.0".to_string(),
|
|
persist_messages: true,
|
|
retry_count: 3,
|
|
queue_size: 100,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Load A2A configuration from bot config
|
|
pub fn load_a2a_config(state: &AppState, bot_id: Uuid) -> A2AConfig {
|
|
let mut config = A2AConfig::default();
|
|
|
|
if let Ok(mut conn) = state.conn.get() {
|
|
#[derive(QueryableByName)]
|
|
struct ConfigRow {
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
config_key: String,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
config_value: String,
|
|
}
|
|
|
|
let configs: Vec<ConfigRow> = diesel::sql_query(
|
|
"SELECT config_key, config_value FROM bot_configuration \
|
|
WHERE bot_id = $1 AND config_key LIKE 'a2a-%'",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(bot_id)
|
|
.load(&mut conn)
|
|
.unwrap_or_default();
|
|
|
|
for row in configs {
|
|
match row.config_key.as_str() {
|
|
"a2a-enabled" => {
|
|
config.enabled = row.config_value.to_lowercase() == "true";
|
|
}
|
|
"a2a-timeout" => {
|
|
config.timeout_seconds = row.config_value.parse().unwrap_or(30);
|
|
}
|
|
"a2a-max-hops" => {
|
|
config.max_hops = row.config_value.parse().unwrap_or(5);
|
|
}
|
|
"a2a-protocol-version" => {
|
|
config.protocol_version = row.config_value;
|
|
}
|
|
"a2a-persist-messages" => {
|
|
config.persist_messages = row.config_value.to_lowercase() == "true";
|
|
}
|
|
"a2a-retry-count" => {
|
|
config.retry_count = row.config_value.parse().unwrap_or(3);
|
|
}
|
|
"a2a-queue-size" => {
|
|
config.queue_size = row.config_value.parse().unwrap_or(100);
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
}
|
|
|
|
config
|
|
}
|
|
|
|
/// Register all A2A protocol keywords
|
|
pub fn register_a2a_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
send_to_bot_keyword(state.clone(), user.clone(), engine);
|
|
broadcast_message_keyword(state.clone(), user.clone(), engine);
|
|
collaborate_with_keyword(state.clone(), user.clone(), engine);
|
|
wait_for_bot_keyword(state.clone(), user.clone(), engine);
|
|
delegate_conversation_keyword(state.clone(), user.clone(), engine);
|
|
get_a2a_messages_keyword(state.clone(), user.clone(), engine);
|
|
}
|
|
|
|
/// SEND TO BOT "bot_name" MESSAGE "message_content"
|
|
/// Send a message to a specific bot
|
|
pub fn send_to_bot_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(
|
|
&["SEND", "TO", "BOT", "$expr$", "MESSAGE", "$expr$"],
|
|
false,
|
|
move |context, inputs| {
|
|
let target_bot = context
|
|
.eval_expression_tree(&inputs[0])?
|
|
.to_string()
|
|
.trim_matches('"')
|
|
.to_string();
|
|
let message_content = context
|
|
.eval_expression_tree(&inputs[1])?
|
|
.to_string()
|
|
.trim_matches('"')
|
|
.to_string();
|
|
|
|
trace!(
|
|
"SEND TO BOT '{}' MESSAGE for session: {}",
|
|
target_bot,
|
|
user_clone.id
|
|
);
|
|
|
|
let state_for_task = Arc::clone(&state_clone);
|
|
let session_id = user_clone.id;
|
|
let bot_id = user_clone.bot_id;
|
|
let from_bot = format!("bot_{}", bot_id);
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
|
let result = rt.block_on(async {
|
|
send_a2a_message(
|
|
&state_for_task,
|
|
session_id,
|
|
&from_bot,
|
|
Some(&target_bot),
|
|
A2AMessageType::Request,
|
|
serde_json::json!({ "content": message_content }),
|
|
)
|
|
.await
|
|
});
|
|
let _ = tx.send(result);
|
|
});
|
|
|
|
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
|
|
Ok(Ok(msg_id)) => Ok(Dynamic::from(msg_id.to_string())),
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"SEND TO BOT timed out".into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
},
|
|
)
|
|
.expect("Failed to register SEND TO BOT syntax");
|
|
}
|
|
|
|
/// BROADCAST MESSAGE "message_content"
|
|
/// Broadcast a message to all bots in the session
|
|
pub fn broadcast_message_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(
|
|
&["BROADCAST", "MESSAGE", "$expr$"],
|
|
false,
|
|
move |context, inputs| {
|
|
let message_content = context
|
|
.eval_expression_tree(&inputs[0])?
|
|
.to_string()
|
|
.trim_matches('"')
|
|
.to_string();
|
|
|
|
trace!("BROADCAST MESSAGE for session: {}", user_clone.id);
|
|
|
|
let state_for_task = Arc::clone(&state_clone);
|
|
let session_id = user_clone.id;
|
|
let bot_id = user_clone.bot_id;
|
|
let from_bot = format!("bot_{}", bot_id);
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
|
let result = rt.block_on(async {
|
|
send_a2a_message(
|
|
&state_for_task,
|
|
session_id,
|
|
&from_bot,
|
|
None, // No target = broadcast
|
|
A2AMessageType::Broadcast,
|
|
serde_json::json!({ "content": message_content }),
|
|
)
|
|
.await
|
|
});
|
|
let _ = tx.send(result);
|
|
});
|
|
|
|
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
|
|
Ok(Ok(msg_id)) => Ok(Dynamic::from(msg_id.to_string())),
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"BROADCAST MESSAGE timed out".into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
},
|
|
)
|
|
.expect("Failed to register BROADCAST MESSAGE syntax");
|
|
}
|
|
|
|
/// COLLABORATE WITH "bot1", "bot2" ON "task"
|
|
/// Request collaboration from multiple bots on a task
|
|
pub fn collaborate_with_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(
|
|
&["COLLABORATE", "WITH", "$expr$", "ON", "$expr$"],
|
|
false,
|
|
move |context, inputs| {
|
|
let bots_str = context
|
|
.eval_expression_tree(&inputs[0])?
|
|
.to_string()
|
|
.trim_matches('"')
|
|
.to_string();
|
|
let task = context
|
|
.eval_expression_tree(&inputs[1])?
|
|
.to_string()
|
|
.trim_matches('"')
|
|
.to_string();
|
|
|
|
let bots: Vec<String> = bots_str
|
|
.split(',')
|
|
.map(|s| s.trim().trim_matches('"').to_string())
|
|
.filter(|s| !s.is_empty())
|
|
.collect();
|
|
|
|
trace!(
|
|
"COLLABORATE WITH {:?} ON '{}' for session: {}",
|
|
bots,
|
|
task,
|
|
user_clone.id
|
|
);
|
|
|
|
let state_for_task = Arc::clone(&state_clone);
|
|
let session_id = user_clone.id;
|
|
let bot_id = user_clone.bot_id;
|
|
let from_bot = format!("bot_{}", bot_id);
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
|
let result = rt.block_on(async {
|
|
let mut message_ids = Vec::new();
|
|
for target_bot in &bots {
|
|
match send_a2a_message(
|
|
&state_for_task,
|
|
session_id,
|
|
&from_bot,
|
|
Some(target_bot),
|
|
A2AMessageType::Collaborate,
|
|
serde_json::json!({
|
|
"task": task,
|
|
"collaborators": bots.clone()
|
|
}),
|
|
)
|
|
.await
|
|
{
|
|
Ok(id) => message_ids.push(id.to_string()),
|
|
Err(e) => {
|
|
warn!(
|
|
"Failed to send collaboration request to {}: {}",
|
|
target_bot, e
|
|
);
|
|
}
|
|
}
|
|
}
|
|
Ok::<Vec<String>, String>(message_ids)
|
|
});
|
|
let _ = tx.send(result);
|
|
});
|
|
|
|
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
|
|
Ok(Ok(ids)) => {
|
|
let array: rhai::Array = ids.into_iter().map(Dynamic::from).collect();
|
|
Ok(Dynamic::from(array))
|
|
}
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"COLLABORATE WITH timed out".into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
},
|
|
)
|
|
.expect("Failed to register COLLABORATE WITH syntax");
|
|
}
|
|
|
|
/// response = WAIT FOR BOT "bot_name" TIMEOUT seconds
|
|
/// Wait for a response from a specific bot
|
|
pub fn wait_for_bot_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(
|
|
&["WAIT", "FOR", "BOT", "$expr$", "TIMEOUT", "$expr$"],
|
|
false,
|
|
move |context, inputs| {
|
|
let target_bot = context
|
|
.eval_expression_tree(&inputs[0])?
|
|
.to_string()
|
|
.trim_matches('"')
|
|
.to_string();
|
|
let timeout_secs: i64 = context
|
|
.eval_expression_tree(&inputs[1])?
|
|
.as_int()
|
|
.unwrap_or(30);
|
|
|
|
trace!(
|
|
"WAIT FOR BOT '{}' TIMEOUT {} for session: {}",
|
|
target_bot,
|
|
timeout_secs,
|
|
user_clone.id
|
|
);
|
|
|
|
let state_for_task = Arc::clone(&state_clone);
|
|
let session_id = user_clone.id;
|
|
let bot_id = user_clone.bot_id;
|
|
let current_bot = format!("bot_{}", bot_id);
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
|
let result = rt.block_on(async {
|
|
wait_for_bot_response(
|
|
&state_for_task,
|
|
session_id,
|
|
&target_bot,
|
|
¤t_bot,
|
|
timeout_secs as u64,
|
|
)
|
|
.await
|
|
});
|
|
let _ = tx.send(result);
|
|
});
|
|
|
|
let timeout_duration = std::time::Duration::from_secs(timeout_secs as u64 + 5);
|
|
match rx.recv_timeout(timeout_duration) {
|
|
Ok(Ok(response)) => Ok(Dynamic::from(response)),
|
|
Ok(Err(e)) => Ok(Dynamic::from(format!("Error: {}", e))),
|
|
Err(_) => Ok(Dynamic::from("Timeout waiting for response".to_string())),
|
|
}
|
|
},
|
|
)
|
|
.expect("Failed to register WAIT FOR BOT syntax");
|
|
}
|
|
|
|
/// DELEGATE CONVERSATION TO "bot_name"
|
|
/// Hand off the entire conversation to another bot
|
|
pub fn delegate_conversation_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(
|
|
&["DELEGATE", "CONVERSATION", "TO", "$expr$"],
|
|
false,
|
|
move |context, inputs| {
|
|
let target_bot = context
|
|
.eval_expression_tree(&inputs[0])?
|
|
.to_string()
|
|
.trim_matches('"')
|
|
.to_string();
|
|
|
|
trace!(
|
|
"DELEGATE CONVERSATION TO '{}' for session: {}",
|
|
target_bot,
|
|
user_clone.id
|
|
);
|
|
|
|
let state_for_task = Arc::clone(&state_clone);
|
|
let session_id = user_clone.id;
|
|
let bot_id = user_clone.bot_id;
|
|
let from_bot = format!("bot_{}", bot_id);
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
|
|
let result = rt.block_on(async {
|
|
// Send delegate message
|
|
send_a2a_message(
|
|
&state_for_task,
|
|
session_id,
|
|
&from_bot,
|
|
Some(&target_bot),
|
|
A2AMessageType::Delegate,
|
|
serde_json::json!({
|
|
"action": "take_over",
|
|
"reason": "delegation_request"
|
|
}),
|
|
)
|
|
.await?;
|
|
|
|
// Update session to set new active bot
|
|
set_session_active_bot(&state_for_task, session_id, &target_bot).await
|
|
});
|
|
let _ = tx.send(result);
|
|
});
|
|
|
|
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
|
|
Ok(Ok(msg)) => Ok(Dynamic::from(msg)),
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"DELEGATE CONVERSATION timed out".into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
},
|
|
)
|
|
.expect("Failed to register DELEGATE CONVERSATION syntax");
|
|
}
|
|
|
|
/// GET A2A MESSAGES()
|
|
/// Get all A2A messages for this bot in current session
|
|
pub fn get_a2a_messages_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine.register_fn("GET A2A MESSAGES", move || -> rhai::Array {
|
|
let state = Arc::clone(&state_clone);
|
|
let session_id = user_clone.id;
|
|
let bot_id = user_clone.bot_id;
|
|
let current_bot = format!("bot_{}", bot_id);
|
|
|
|
if let Ok(mut conn) = state.conn.get() {
|
|
get_pending_messages_sync(&mut conn, session_id, ¤t_bot)
|
|
.unwrap_or_default()
|
|
.into_iter()
|
|
.map(|msg| Dynamic::from(serde_json::to_string(&msg).unwrap_or_default()))
|
|
.collect()
|
|
} else {
|
|
rhai::Array::new()
|
|
}
|
|
});
|
|
}
|
|
|
|
// ============================================================================
|
|
// Database Operations
|
|
// ============================================================================
|
|
|
|
/// Send an A2A message
|
|
async fn send_a2a_message(
|
|
state: &AppState,
|
|
session_id: Uuid,
|
|
from_agent: &str,
|
|
to_agent: Option<&str>,
|
|
message_type: A2AMessageType,
|
|
payload: serde_json::Value,
|
|
) -> Result<Uuid, String> {
|
|
let mut conn = state
|
|
.conn
|
|
.get()
|
|
.map_err(|e| format!("Failed to acquire database connection: {}", e))?;
|
|
|
|
let message = A2AMessage::new(from_agent, to_agent, message_type, payload, session_id);
|
|
let message_id = message.id;
|
|
|
|
let payload_str = serde_json::to_string(&message.payload)
|
|
.map_err(|e| format!("Failed to serialize payload: {}", e))?;
|
|
|
|
let metadata_str = serde_json::to_string(&message.metadata)
|
|
.map_err(|e| format!("Failed to serialize metadata: {}", e))?;
|
|
|
|
diesel::sql_query(
|
|
"INSERT INTO a2a_messages \
|
|
(id, session_id, from_agent, to_agent, message_type, payload, correlation_id, \
|
|
timestamp, metadata, ttl_seconds, hop_count, processed) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, false)",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(message.id)
|
|
.bind::<diesel::sql_types::Uuid, _>(message.session_id)
|
|
.bind::<diesel::sql_types::Text, _>(&message.from_agent)
|
|
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(message.to_agent.as_deref())
|
|
.bind::<diesel::sql_types::Text, _>(message.message_type.to_string())
|
|
.bind::<diesel::sql_types::Text, _>(&payload_str)
|
|
.bind::<diesel::sql_types::Uuid, _>(message.correlation_id)
|
|
.bind::<diesel::sql_types::Timestamptz, _>(message.timestamp)
|
|
.bind::<diesel::sql_types::Text, _>(&metadata_str)
|
|
.bind::<diesel::sql_types::Integer, _>(message.ttl_seconds as i32)
|
|
.bind::<diesel::sql_types::Integer, _>(message.hop_count as i32)
|
|
.execute(&mut conn)
|
|
.map_err(|e| format!("Failed to insert A2A message: {}", e))?;
|
|
|
|
info!(
|
|
"A2A message sent: {} -> {:?} (type: {})",
|
|
from_agent, to_agent, message.message_type
|
|
);
|
|
|
|
Ok(message_id)
|
|
}
|
|
|
|
/// Wait for a response from a specific bot
|
|
async fn wait_for_bot_response(
|
|
state: &AppState,
|
|
session_id: Uuid,
|
|
from_bot: &str,
|
|
to_bot: &str,
|
|
timeout_secs: u64,
|
|
) -> Result<String, String> {
|
|
let start = std::time::Instant::now();
|
|
let timeout = std::time::Duration::from_secs(timeout_secs);
|
|
|
|
loop {
|
|
if start.elapsed() > timeout {
|
|
return Err("Timeout waiting for bot response".to_string());
|
|
}
|
|
|
|
// Check for response
|
|
if let Ok(mut conn) = state.conn.get() {
|
|
#[derive(QueryableByName)]
|
|
struct MessageRow {
|
|
#[diesel(sql_type = diesel::sql_types::Uuid)]
|
|
id: Uuid,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
payload: String,
|
|
}
|
|
|
|
let result: Option<MessageRow> = diesel::sql_query(
|
|
"SELECT id, payload FROM a2a_messages \
|
|
WHERE session_id = $1 AND from_agent = $2 AND to_agent = $3 \
|
|
AND message_type = 'response' AND processed = false \
|
|
ORDER BY timestamp DESC LIMIT 1",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
|
.bind::<diesel::sql_types::Text, _>(from_bot)
|
|
.bind::<diesel::sql_types::Text, _>(to_bot)
|
|
.get_result(&mut conn)
|
|
.optional()
|
|
.map_err(|e| format!("Failed to query messages: {}", e))?;
|
|
|
|
if let Some(msg) = result {
|
|
// Mark as processed
|
|
let _ = diesel::sql_query("UPDATE a2a_messages SET processed = true WHERE id = $1")
|
|
.bind::<diesel::sql_types::Uuid, _>(msg.id)
|
|
.execute(&mut conn);
|
|
|
|
// Extract content from payload
|
|
if let Ok(payload) = serde_json::from_str::<serde_json::Value>(&msg.payload) {
|
|
if let Some(content) = payload.get("content").and_then(|c| c.as_str()) {
|
|
return Ok(content.to_string());
|
|
}
|
|
}
|
|
return Ok(msg.payload);
|
|
}
|
|
}
|
|
|
|
// Sleep before next check
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
|
}
|
|
}
|
|
|
|
/// Set the active bot for a session
|
|
async fn set_session_active_bot(
|
|
state: &AppState,
|
|
session_id: Uuid,
|
|
bot_name: &str,
|
|
) -> Result<String, String> {
|
|
let mut conn = state
|
|
.conn
|
|
.get()
|
|
.map_err(|e| format!("Failed to acquire database connection: {}", e))?;
|
|
|
|
let now = chrono::Utc::now();
|
|
|
|
diesel::sql_query(
|
|
"INSERT INTO session_preferences (session_id, preference_key, preference_value, updated_at) \
|
|
VALUES ($1, 'active_bot', $2, $3) \
|
|
ON CONFLICT (session_id, preference_key) DO UPDATE SET \
|
|
preference_value = EXCLUDED.preference_value, \
|
|
updated_at = EXCLUDED.updated_at",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
|
.bind::<diesel::sql_types::Text, _>(bot_name)
|
|
.bind::<diesel::sql_types::Timestamptz, _>(now)
|
|
.execute(&mut conn)
|
|
.map_err(|e| format!("Failed to set active bot: {}", e))?;
|
|
|
|
info!("Session {} delegated to bot: {}", session_id, bot_name);
|
|
|
|
Ok(format!("Conversation delegated to {}", bot_name))
|
|
}
|
|
|
|
/// Get pending messages for a bot (sync version)
|
|
fn get_pending_messages_sync(
|
|
conn: &mut diesel::PgConnection,
|
|
session_id: Uuid,
|
|
to_agent: &str,
|
|
) -> Result<Vec<A2AMessage>, String> {
|
|
#[derive(QueryableByName)]
|
|
struct MessageRow {
|
|
#[diesel(sql_type = diesel::sql_types::Uuid)]
|
|
id: Uuid,
|
|
#[diesel(sql_type = diesel::sql_types::Uuid)]
|
|
session_id: Uuid,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
from_agent: String,
|
|
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
|
|
to_agent: Option<String>,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
message_type: String,
|
|
#[diesel(sql_type = diesel::sql_types::Text)]
|
|
payload: String,
|
|
#[diesel(sql_type = diesel::sql_types::Uuid)]
|
|
correlation_id: Uuid,
|
|
#[diesel(sql_type = diesel::sql_types::Timestamptz)]
|
|
timestamp: chrono::DateTime<chrono::Utc>,
|
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
|
ttl_seconds: i32,
|
|
#[diesel(sql_type = diesel::sql_types::Integer)]
|
|
hop_count: i32,
|
|
}
|
|
|
|
let rows: Vec<MessageRow> = diesel::sql_query(
|
|
"SELECT id, session_id, from_agent, to_agent, message_type, payload, \
|
|
correlation_id, timestamp, ttl_seconds, hop_count \
|
|
FROM a2a_messages \
|
|
WHERE session_id = $1 AND (to_agent = $2 OR to_agent IS NULL) AND processed = false \
|
|
ORDER BY timestamp ASC",
|
|
)
|
|
.bind::<diesel::sql_types::Uuid, _>(session_id)
|
|
.bind::<diesel::sql_types::Text, _>(to_agent)
|
|
.load(conn)
|
|
.map_err(|e| format!("Failed to get pending messages: {}", e))?;
|
|
|
|
let messages: Vec<A2AMessage> = rows
|
|
.into_iter()
|
|
.map(|row| A2AMessage {
|
|
id: row.id,
|
|
session_id: row.session_id,
|
|
from_agent: row.from_agent,
|
|
to_agent: row.to_agent,
|
|
message_type: A2AMessageType::from(row.message_type.as_str()),
|
|
payload: serde_json::from_str(&row.payload).unwrap_or(serde_json::json!({})),
|
|
correlation_id: row.correlation_id,
|
|
timestamp: row.timestamp,
|
|
metadata: HashMap::new(),
|
|
ttl_seconds: row.ttl_seconds as u32,
|
|
hop_count: row.hop_count as u32,
|
|
})
|
|
.filter(|msg| !msg.is_expired())
|
|
.collect();
|
|
|
|
Ok(messages)
|
|
}
|
|
|
|
/// Respond to an A2A message (helper for bots to use)
|
|
pub async fn respond_to_a2a_message(
|
|
state: &AppState,
|
|
original_message: &A2AMessage,
|
|
from_agent: &str,
|
|
response_content: &str,
|
|
) -> Result<Uuid, String> {
|
|
send_a2a_message(
|
|
state,
|
|
original_message.session_id,
|
|
from_agent,
|
|
Some(&original_message.from_agent),
|
|
A2AMessageType::Response,
|
|
serde_json::json!({ "content": response_content }),
|
|
)
|
|
.await
|
|
}
|
|
|
|
// ============================================================================
|
|
// Tests
|
|
// ============================================================================
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request from `bot_a` to `bot_b` in a fresh session.
    fn request_with(payload: serde_json::Value) -> A2AMessage {
        A2AMessage::new(
            "bot_a",
            Some("bot_b"),
            A2AMessageType::Request,
            payload,
            Uuid::new_v4(),
        )
    }

    #[test]
    fn test_a2a_message_creation() {
        let msg = request_with(serde_json::json!({"test": "data"}));

        assert_eq!(msg.from_agent, "bot_a");
        assert_eq!(msg.to_agent.as_deref(), Some("bot_b"));
        assert_eq!(msg.message_type, A2AMessageType::Request);
        assert_eq!(msg.hop_count, 0);
    }

    #[test]
    fn test_a2a_message_response() {
        let original = request_with(serde_json::json!({"question": "test"}));
        let response = original.create_response("bot_b", serde_json::json!({"answer": "result"}));

        assert_eq!(response.from_agent, "bot_b");
        assert_eq!(response.to_agent.as_deref(), Some("bot_a"));
        assert_eq!(response.message_type, A2AMessageType::Response);
        assert_eq!(response.correlation_id, original.correlation_id);
        assert_eq!(response.hop_count, 1);
    }

    #[test]
    fn test_message_type_display() {
        let expectations = [
            (A2AMessageType::Request, "request"),
            (A2AMessageType::Response, "response"),
            (A2AMessageType::Broadcast, "broadcast"),
            (A2AMessageType::Delegate, "delegate"),
        ];
        for (message_type, label) in expectations.iter() {
            assert_eq!(message_type.to_string(), *label);
        }
    }

    #[test]
    fn test_message_type_from_str() {
        assert_eq!(A2AMessageType::from("request"), A2AMessageType::Request);
        // Matching is case-insensitive.
        assert_eq!(A2AMessageType::from("RESPONSE"), A2AMessageType::Response);
        // Unknown names fall back to `Request`.
        assert_eq!(A2AMessageType::from("unknown"), A2AMessageType::Request);
    }

    #[test]
    fn test_a2a_config_default() {
        let config = A2AConfig::default();

        assert!(config.enabled);
        assert_eq!(config.timeout_seconds, 30);
        assert_eq!(config.max_hops, 5);
        assert_eq!(config.protocol_version, "1.0");
    }

    #[test]
    fn test_message_not_expired() {
        let msg = request_with(serde_json::json!({}));

        assert!(!msg.is_expired());
    }

    #[test]
    fn test_max_hops_not_exceeded() {
        let msg = request_with(serde_json::json!({}));

        assert!(!msg.max_hops_exceeded(5));
    }

    #[test]
    fn test_max_hops_exceeded() {
        let mut msg = request_with(serde_json::json!({}));
        msg.hop_count = 5;

        assert!(msg.max_hops_exceeded(5));
    }
}
|