207 lines
5.3 KiB
Rust
207 lines
5.3 KiB
Rust
use async_trait::async_trait;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::{mpsc, Mutex};
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolResult {
|
|
pub success: bool,
|
|
pub output: String,
|
|
pub requires_input: bool,
|
|
pub session_id: String,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct Tool {
|
|
pub name: String,
|
|
pub description: String,
|
|
pub parameters: HashMap<String, String>,
|
|
pub script: String,
|
|
}
|
|
|
|
#[async_trait]
|
|
pub trait ToolExecutor: Send + Sync {
|
|
async fn execute(
|
|
&self,
|
|
tool_name: &str,
|
|
session_id: &str,
|
|
user_id: &str,
|
|
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>>;
|
|
async fn provide_input(
|
|
&self,
|
|
session_id: &str,
|
|
input: &str,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
|
async fn get_output(
|
|
&self,
|
|
session_id: &str,
|
|
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>>;
|
|
async fn is_waiting_for_input(
|
|
&self,
|
|
session_id: &str,
|
|
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>;
|
|
}
|
|
|
|
pub struct MockToolExecutor;
|
|
|
|
impl MockToolExecutor {
|
|
pub fn new() -> Self {
|
|
Self
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ToolExecutor for MockToolExecutor {
|
|
async fn execute(
|
|
&self,
|
|
tool_name: &str,
|
|
session_id: &str,
|
|
user_id: &str,
|
|
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
|
|
Ok(ToolResult {
|
|
success: true,
|
|
output: format!("Mock tool {} executed for user {}", tool_name, user_id),
|
|
requires_input: false,
|
|
session_id: session_id.to_string(),
|
|
})
|
|
}
|
|
|
|
async fn provide_input(
|
|
&self,
|
|
_session_id: &str,
|
|
_input: &str,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
Ok(())
|
|
}
|
|
|
|
async fn get_output(
|
|
&self,
|
|
_session_id: &str,
|
|
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
|
|
Ok(vec!["Mock output".to_string()])
|
|
}
|
|
|
|
async fn is_waiting_for_input(
|
|
&self,
|
|
_session_id: &str,
|
|
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
|
Ok(false)
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ToolManager {
|
|
tools: HashMap<String, Tool>,
|
|
waiting_responses: Arc<Mutex<HashMap<String, mpsc::Sender<String>>>>,
|
|
}
|
|
|
|
impl ToolManager {
|
|
pub fn new() -> Self {
|
|
let mut tools = HashMap::new();
|
|
|
|
let calculator_tool = Tool {
|
|
name: "calculator".to_string(),
|
|
description: "Perform calculations".to_string(),
|
|
parameters: HashMap::from([
|
|
(
|
|
"operation".to_string(),
|
|
"add|subtract|multiply|divide".to_string(),
|
|
),
|
|
("a".to_string(), "number".to_string()),
|
|
("b".to_string(), "number".to_string()),
|
|
]),
|
|
script: r#"
|
|
print("Calculator started");
|
|
"#
|
|
.to_string(),
|
|
};
|
|
|
|
tools.insert(calculator_tool.name.clone(), calculator_tool);
|
|
Self {
|
|
tools,
|
|
waiting_responses: Arc::new(Mutex::new(HashMap::new())),
|
|
}
|
|
}
|
|
|
|
pub fn get_tool(&self, name: &str) -> Option<&Tool> {
|
|
self.tools.get(name)
|
|
}
|
|
|
|
pub fn list_tools(&self) -> Vec<String> {
|
|
self.tools.keys().cloned().collect()
|
|
}
|
|
|
|
pub async fn execute_tool(
|
|
&self,
|
|
tool_name: &str,
|
|
session_id: &str,
|
|
user_id: &str,
|
|
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
|
|
let _tool = self.get_tool(tool_name).ok_or("Tool not found")?;
|
|
|
|
Ok(ToolResult {
|
|
success: true,
|
|
output: format!("Tool {} started for user {}", tool_name, user_id),
|
|
requires_input: true,
|
|
session_id: session_id.to_string(),
|
|
})
|
|
}
|
|
|
|
pub async fn is_tool_waiting(
|
|
&self,
|
|
session_id: &str,
|
|
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
|
let waiting = self.waiting_responses.lock().await;
|
|
Ok(waiting.contains_key(session_id))
|
|
}
|
|
|
|
pub async fn provide_input(
|
|
&self,
|
|
session_id: &str,
|
|
input: &str,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
self.provide_user_response(session_id, "default_bot", input.to_string())
|
|
}
|
|
|
|
pub async fn get_tool_output(
|
|
&self,
|
|
_session_id: &str,
|
|
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
|
|
Ok(vec![])
|
|
}
|
|
|
|
pub fn provide_user_response(
|
|
&self,
|
|
user_id: &str,
|
|
bot_id: &str,
|
|
response: String,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
let key = format!("{}:{}", user_id, bot_id);
|
|
let waiting = self.waiting_responses.clone();
|
|
|
|
tokio::spawn(async move {
|
|
let mut waiting_lock = waiting.lock().await;
|
|
if let Some(tx) = waiting_lock.get_mut(&key) {
|
|
let _ = tx.send(response).await;
|
|
waiting_lock.remove(&key);
|
|
}
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl Default for ToolManager {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
pub struct ToolApi;
|
|
|
|
impl ToolApi {
|
|
pub fn new() -> Self {
|
|
Self
|
|
}
|
|
}
|