Fix Bedrock config for OpenAI GPT-OSS models
All checks were successful
BotServer CI / build (push) Successful in 12m35s
All checks were successful
BotServer CI / build (push) Successful in 12m35s
This commit is contained in:
parent
c523cee177
commit
82bfd0a443
10 changed files with 857 additions and 44 deletions
|
|
@ -23,7 +23,7 @@ impl LlmAssistConfig {
|
|||
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
for line in content.lines() {
|
||||
let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
|
||||
let parts: Vec<&str> = line.splitn(2, ',').map(|s| s.trim()).collect();
|
||||
|
||||
if parts.len() < 2 {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -160,30 +160,25 @@ impl BootstrapManager {
|
|||
|
||||
if pm.is_installed("directory") {
|
||||
// Wait for Zitadel to be ready - it might have been started during installation
|
||||
// Use incremental backoff: check frequently at first, then slow down
|
||||
// Use very aggressive backoff for fastest startup detection
|
||||
let mut directory_already_running = zitadel_health_check();
|
||||
if !directory_already_running {
|
||||
info!("Zitadel not responding to health check, waiting with incremental backoff...");
|
||||
// Check intervals: 1s x5, 2s x5, 5s x5, 10s x3 = ~60s total
|
||||
let intervals: [(u64, u32); 4] = [(1, 5), (2, 5), (5, 5), (10, 3)];
|
||||
let mut total_waited: u64 = 0;
|
||||
for (interval_secs, count) in intervals {
|
||||
for i in 0..count {
|
||||
if zitadel_health_check() {
|
||||
info!("Zitadel/Directory service is now responding (waited {}s)", total_waited);
|
||||
directory_already_running = true;
|
||||
break;
|
||||
}
|
||||
sleep(Duration::from_secs(interval_secs)).await;
|
||||
total_waited += interval_secs;
|
||||
// Show incremental progress every ~10s
|
||||
if total_waited % 10 == 0 || i == 0 {
|
||||
info!("Zitadel health check: {}s elapsed, retrying...", total_waited);
|
||||
}
|
||||
}
|
||||
if directory_already_running {
|
||||
info!("Zitadel not responding to health check, waiting...");
|
||||
// Check every 500ms for fast detection (was: 1s, 2s, 5s, 10s)
|
||||
let mut checks = 0;
|
||||
let max_checks = 120; // 60 seconds max
|
||||
while checks < max_checks {
|
||||
if zitadel_health_check() {
|
||||
info!("Zitadel/Directory service is now responding (checked {} times)", checks);
|
||||
directory_already_running = true;
|
||||
break;
|
||||
}
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
checks += 1;
|
||||
// Log progress every 10 checks (5 seconds)
|
||||
if checks % 10 == 0 {
|
||||
info!("Zitadel health check: {}s elapsed, retrying...", checks / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -207,16 +207,27 @@ pub enum BotExistsResult {
|
|||
/// Check if Zitadel directory is healthy
|
||||
pub fn zitadel_health_check() -> bool {
|
||||
// Check if Zitadel is responding on port 8300
|
||||
if let Ok(output) = Command::new("curl")
|
||||
.args(["-f", "-s", "--connect-timeout", "2", "http://localhost:8300/debug/healthz"])
|
||||
.output()
|
||||
{
|
||||
if output.status.success() {
|
||||
return true;
|
||||
// Use very short timeout for fast startup detection
|
||||
let output = Command::new("curl")
|
||||
.args(["-f", "-s", "--connect-timeout", "1", "-m", "2", "http://localhost:8300/debug/healthz"])
|
||||
.output();
|
||||
|
||||
match output {
|
||||
Ok(result) => {
|
||||
if result.status.success() {
|
||||
let response = String::from_utf8_lossy(&result.stdout);
|
||||
debug!("Zitadel health check response: {}", response);
|
||||
return response.trim() == "ok";
|
||||
}
|
||||
let stderr = String::from_utf8_lossy(&result.stderr);
|
||||
debug!("Zitadel health check failed: {}", stderr);
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Zitadel health check error: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: just check if port 8300 is listening
|
||||
// Fast fallback: just check if port 8300 is listening
|
||||
match Command::new("nc")
|
||||
.args(["-z", "-w", "1", "127.0.0.1", "8300"])
|
||||
.output()
|
||||
|
|
|
|||
|
|
@ -503,7 +503,7 @@ impl WhatsAppAdapter {
|
|||
}
|
||||
|
||||
if !list_block.is_empty() {
|
||||
if !current_part.is_empty() && current_part.len() + list_block.len() + 1 <= max_length {
|
||||
if !current_part.is_empty() && current_part.len() + list_block.len() < max_length {
|
||||
current_part.push('\n');
|
||||
current_part.push_str(&list_block);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -416,7 +416,7 @@ impl BotOrchestrator {
|
|||
let session_id = Uuid::parse_str(&message.session_id)?;
|
||||
let message_content = message.content.clone();
|
||||
|
||||
let (session, context_data, history, model, key, system_prompt) = {
|
||||
let (session, context_data, history, model, key, system_prompt, bot_llm_url) = {
|
||||
let state_clone = self.state.clone();
|
||||
tokio::task::spawn_blocking(
|
||||
move || -> Result<_, Box<dyn std::error::Error + Send + Sync>> {
|
||||
|
|
@ -453,14 +453,19 @@ impl BotOrchestrator {
|
|||
.get_config(&session.bot_id, "llm-key", Some(""))
|
||||
.unwrap_or_default();
|
||||
|
||||
// Load bot-specific llm-url (may differ from global default)
|
||||
let bot_llm_url = config_manager
|
||||
.get_bot_config_value(&session.bot_id, "llm-url")
|
||||
.ok();
|
||||
|
||||
// Load system-prompt from config.csv, fallback to default
|
||||
let system_prompt = config_manager
|
||||
.get_config(&session.bot_id, "system-prompt", Some("You are a helpful assistant with access to tools that can help you complete tasks. When a user's request matches one of your available tools, use the appropriate tool instead of providing a generic response."))
|
||||
.unwrap_or_else(|_| "You are a helpful assistant.".to_string());
|
||||
.unwrap_or_else(|_| "You are a helpful General Bots assistant.".to_string());
|
||||
|
||||
info!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(200)]);
|
||||
info!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(500)]);
|
||||
|
||||
Ok((session, context_data, history, model, key, system_prompt))
|
||||
Ok((session, context_data, history, model, key, system_prompt, bot_llm_url))
|
||||
},
|
||||
)
|
||||
.await??
|
||||
|
|
@ -615,7 +620,14 @@ impl BotOrchestrator {
|
|||
}
|
||||
|
||||
let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
|
||||
let llm = self.state.llm_provider.clone();
|
||||
|
||||
// Use bot-specific LLM provider if the bot has its own llm-url configured
|
||||
let llm: std::sync::Arc<dyn crate::llm::LLMProvider> = if let Some(ref url) = bot_llm_url {
|
||||
info!("Bot has custom llm-url: {}, creating per-bot LLM provider", url);
|
||||
crate::llm::create_llm_provider_from_url(url, Some(model.clone()), None)
|
||||
} else {
|
||||
self.state.llm_provider.clone()
|
||||
};
|
||||
|
||||
let model_clone = model.clone();
|
||||
let key_clone = key.clone();
|
||||
|
|
|
|||
|
|
@ -482,7 +482,7 @@ impl ConfigManager {
|
|||
|
||||
let mut updated = 0;
|
||||
for line in content.lines().skip(1) {
|
||||
let parts: Vec<&str> = line.split(',').collect();
|
||||
let parts: Vec<&str> = line.splitn(2, ',').collect();
|
||||
if parts.len() >= 2 {
|
||||
let key = parts[0].trim();
|
||||
let value = parts[1].trim();
|
||||
|
|
|
|||
239
src/llm/bedrock.rs
Normal file
239
src/llm/bedrock.rs
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use log::{error, info};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::llm::LLMProvider;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BedrockClient {
|
||||
client: reqwest::Client,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl BedrockClient {
|
||||
pub fn new(base_url: String) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
base_url,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for BedrockClient {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
messages: &Value,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||
|
||||
let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
};
|
||||
|
||||
|
||||
let mut messages_limited = Vec::new();
|
||||
if let Some(msg_array) = raw_messages.as_array() {
|
||||
for msg in msg_array {
|
||||
messages_limited.push(msg.clone());
|
||||
}
|
||||
}
|
||||
let formatted_messages = serde_json::Value::Array(messages_limited);
|
||||
|
||||
let auth_header = if key.starts_with("Bearer ") {
|
||||
key.to_string()
|
||||
} else {
|
||||
format!("Bearer {}", key)
|
||||
};
|
||||
|
||||
let request_body = serde_json::json!({
|
||||
"model": model,
|
||||
"messages": formatted_messages,
|
||||
"stream": false
|
||||
});
|
||||
|
||||
info!("Sending request to Bedrock endpoint: {}", self.base_url);
|
||||
|
||||
let response = self.client
|
||||
.post(&self.base_url)
|
||||
.header("Authorization", &auth_header)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("Bedrock generate error: {}", error_text);
|
||||
return Err(format!("Bedrock API error ({}): {}", status, error_text).into());
|
||||
}
|
||||
|
||||
let json: Value = response.json().await?;
|
||||
|
||||
if let Some(choices) = json.get("choices") {
|
||||
if let Some(first_choice) = choices.get(0) {
|
||||
if let Some(message) = first_choice.get("message") {
|
||||
if let Some(content) = message.get("content") {
|
||||
if let Some(content_str) = content.as_str() {
|
||||
return Ok(content_str.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err("Failed to parse response from Bedrock".into())
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
messages: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
model: &str,
|
||||
key: &str,
|
||||
tools: Option<&Vec<Value>>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||
|
||||
let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
};
|
||||
|
||||
let mut messages_limited = Vec::new();
|
||||
if let Some(msg_array) = raw_messages.as_array() {
|
||||
for msg in msg_array {
|
||||
messages_limited.push(msg.clone());
|
||||
}
|
||||
}
|
||||
let formatted_messages = serde_json::Value::Array(messages_limited);
|
||||
|
||||
let auth_header = if key.starts_with("Bearer ") {
|
||||
key.to_string()
|
||||
} else {
|
||||
format!("Bearer {}", key)
|
||||
};
|
||||
|
||||
let mut request_body = serde_json::json!({
|
||||
"model": model,
|
||||
"messages": formatted_messages,
|
||||
"stream": true
|
||||
});
|
||||
|
||||
if let Some(tools_value) = tools {
|
||||
if !tools_value.is_empty() {
|
||||
request_body["tools"] = serde_json::json!(tools_value);
|
||||
info!("Added {} tools to Bedrock request", tools_value.len());
|
||||
}
|
||||
}
|
||||
|
||||
let stream_url = if self.base_url.ends_with("/invoke") {
|
||||
self.base_url.replace("/invoke", "/invoke-with-response-stream")
|
||||
} else {
|
||||
self.base_url.clone()
|
||||
};
|
||||
|
||||
info!("Sending streaming request to Bedrock endpoint: {}", stream_url);
|
||||
|
||||
let response = self.client
|
||||
.post(&stream_url)
|
||||
.header("Authorization", &auth_header)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("Bedrock generate_stream error: {}", error_text);
|
||||
return Err(format!("Bedrock API error ({}): {}", status, error_text).into());
|
||||
}
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut tool_call_buffer = String::new();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if let Ok(text) = std::str::from_utf8(&chunk) {
|
||||
for line in text.split('\n') {
|
||||
let line = line.trim();
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(json) = serde_json::from_str::<Value>(data) {
|
||||
if let Some(choices) = json.get("choices") {
|
||||
if let Some(first_choice) = choices.get(0) {
|
||||
if let Some(delta) = first_choice.get("delta") {
|
||||
// Handle standard content streaming
|
||||
if let Some(content) = delta.get("content") {
|
||||
if let Some(content_str) = content.as_str() {
|
||||
if !content_str.is_empty() && tx.send(content_str.to_string()).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls streaming
|
||||
if let Some(tool_calls) = delta.get("tool_calls") {
|
||||
if let Some(calls_array) = tool_calls.as_array() {
|
||||
if let Some(first_call) = calls_array.first() {
|
||||
if let Some(function) = first_call.get("function") {
|
||||
// Stream function JSON representation just like OpenAI does
|
||||
if let Some(name) = function.get("name") {
|
||||
if let Some(name_str) = name.as_str() {
|
||||
tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name_str);
|
||||
}
|
||||
}
|
||||
if let Some(args) = function.get("arguments") {
|
||||
if let Some(args_str) = args.as_str() {
|
||||
tool_call_buffer.push_str(args_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Bedrock stream reading error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize tool call JSON parsing
|
||||
if !tool_call_buffer.is_empty() {
|
||||
tool_call_buffer.push_str("\"}");
|
||||
if tx.send(format!("`tool_call`: {}", tool_call_buffer)).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_job(&self, _session_id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
@ -14,11 +14,15 @@ pub mod llm_models;
|
|||
pub mod local;
|
||||
pub mod rate_limiter;
|
||||
pub mod smart_router;
|
||||
pub mod vertex;
|
||||
pub mod bedrock;
|
||||
|
||||
pub use claude::ClaudeClient;
|
||||
pub use glm::GLMClient;
|
||||
pub use llm_models::get_handler;
|
||||
pub use rate_limiter::{ApiRateLimiter, RateLimits};
|
||||
pub use vertex::VertexTokenManager;
|
||||
pub use bedrock::BedrockClient;
|
||||
|
||||
#[async_trait]
|
||||
pub trait LLMProvider: Send + Sync {
|
||||
|
|
@ -248,14 +252,16 @@ impl LLMProvider for OpenAIClient {
|
|||
// GLM-4.7 has 202750 tokens, other models vary
|
||||
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
|
||||
202750
|
||||
} else if model.contains("gpt-4") {
|
||||
128000
|
||||
} else if model.contains("gemini") {
|
||||
1000000 // Google Gemini models have 1M+ token context
|
||||
} else if model.contains("gpt-oss") || model.contains("gpt-4") {
|
||||
128000 // Cerebras gpt-oss models and GPT-4 variants
|
||||
} else if model.contains("gpt-3.5") {
|
||||
16385
|
||||
} else if model.starts_with("http://localhost:808") || model == "local" {
|
||||
768 // Local llama.cpp server context limit
|
||||
} else {
|
||||
4096 // Default conservative limit
|
||||
32768 // Default conservative limit for modern models
|
||||
};
|
||||
|
||||
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
||||
|
|
@ -329,14 +335,16 @@ impl LLMProvider for OpenAIClient {
|
|||
// GLM-4.7 has 202750 tokens, other models vary
|
||||
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
|
||||
202750
|
||||
} else if model.contains("gpt-4") {
|
||||
128000
|
||||
} else if model.contains("gemini") {
|
||||
1000000 // Google Gemini models have 1M+ token context
|
||||
} else if model.contains("gpt-oss") || model.contains("gpt-4") {
|
||||
128000 // Cerebras gpt-oss models and GPT-4 variants
|
||||
} else if model.contains("gpt-3.5") {
|
||||
16385
|
||||
} else if model.starts_with("http://localhost:808") || model == "local" {
|
||||
768 // Local llama.cpp server context limit
|
||||
} else {
|
||||
4096 // Default conservative limit
|
||||
32768 // Default conservative limit for modern models
|
||||
};
|
||||
|
||||
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
||||
|
|
@ -448,6 +456,8 @@ pub enum LLMProviderType {
|
|||
Claude,
|
||||
AzureClaude,
|
||||
GLM,
|
||||
Bedrock,
|
||||
Vertex,
|
||||
}
|
||||
|
||||
impl From<&str> for LLMProviderType {
|
||||
|
|
@ -461,6 +471,10 @@ impl From<&str> for LLMProviderType {
|
|||
}
|
||||
} else if lower.contains("z.ai") || lower.contains("glm") {
|
||||
Self::GLM
|
||||
} else if lower.contains("bedrock") {
|
||||
Self::Bedrock
|
||||
} else if lower.contains("googleapis.com") || lower.contains("vertex") || lower.contains("generativelanguage") {
|
||||
Self::Vertex
|
||||
} else {
|
||||
Self::OpenAI
|
||||
}
|
||||
|
|
@ -498,6 +512,15 @@ pub fn create_llm_provider(
|
|||
info!("Creating GLM/z.ai LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(GLMClient::new(base_url))
|
||||
}
|
||||
LLMProviderType::Bedrock => {
|
||||
info!("Creating Bedrock LLM provider with exact URL: {}", base_url);
|
||||
std::sync::Arc::new(BedrockClient::new(base_url))
|
||||
}
|
||||
LLMProviderType::Vertex => {
|
||||
info!("Creating Vertex/Gemini LLM provider with URL: {}", base_url);
|
||||
// Re-export the struct if we haven't already
|
||||
std::sync::Arc::new(crate::llm::vertex::VertexClient::new(base_url, endpoint_path))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
515
src/llm/vertex.rs
Normal file
515
src/llm/vertex.rs
Normal file
|
|
@ -0,0 +1,515 @@
|
|||
//! Google Vertex AI and Gemini Native API Integration
|
||||
//! Support for both OpenAI-compatible endpoints and native Google AI Studio / Vertex AI endpoints.
|
||||
|
||||
use log::{error, info, trace};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
|
||||
const GOOGLE_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ServiceAccountKey {
|
||||
client_email: String,
|
||||
private_key: String,
|
||||
token_uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CachedToken {
|
||||
access_token: String,
|
||||
expires_at: std::time::Instant,
|
||||
}
|
||||
|
||||
/// Manages OAuth2 access tokens for Google Vertex AI
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VertexTokenManager {
|
||||
credentials_path: String,
|
||||
cached_token: Arc<RwLock<Option<CachedToken>>>,
|
||||
}
|
||||
|
||||
impl VertexTokenManager {
|
||||
pub fn new(credentials_path: &str) -> Self {
|
||||
Self {
|
||||
credentials_path: credentials_path.to_string(),
|
||||
cached_token: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a valid access token, refreshing if necessary
|
||||
pub async fn get_access_token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Check cached token
|
||||
{
|
||||
let cache = self.cached_token.read().await;
|
||||
if let Some(ref token) = *cache {
|
||||
// Return cached token if it expires in more than 60 seconds
|
||||
if token.expires_at > std::time::Instant::now() + std::time::Duration::from_secs(60) {
|
||||
trace!("Using cached Vertex AI access token");
|
||||
return Ok(token.access_token.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate new token
|
||||
info!("Generating new Vertex AI access token from service account");
|
||||
let token = self.generate_token().await?;
|
||||
|
||||
// Cache it
|
||||
{
|
||||
let mut cache = self.cached_token.write().await;
|
||||
*cache = Some(token.clone());
|
||||
}
|
||||
|
||||
Ok(token.access_token)
|
||||
}
|
||||
|
||||
async fn generate_token(&self) -> Result<CachedToken, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let content = if self.credentials_path.trim().starts_with('{') {
|
||||
self.credentials_path.clone()
|
||||
} else {
|
||||
// Expand ~ to home directory
|
||||
let expanded_path = if self.credentials_path.starts_with('~') {
|
||||
let home = std::env::var("HOME").unwrap_or_else(|_| "/root".to_string());
|
||||
self.credentials_path.replacen('~', &home, 1)
|
||||
} else {
|
||||
self.credentials_path.clone()
|
||||
};
|
||||
|
||||
tokio::fs::read_to_string(&expanded_path)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read service account key from '{}': {}", expanded_path, e))?
|
||||
};
|
||||
|
||||
let sa_key: ServiceAccountKey = serde_json::from_str(&content)
|
||||
.map_err(|e| format!("Failed to parse service account JSON: {}", e))?;
|
||||
|
||||
let token_uri = sa_key.token_uri.as_deref().unwrap_or(GOOGLE_TOKEN_URL);
|
||||
|
||||
// Create JWT assertion
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map_err(|e| format!("System time error: {}", e))?
|
||||
.as_secs();
|
||||
|
||||
let jwt_claims = JwtClaims {
|
||||
iss: sa_key.client_email.clone(),
|
||||
scope: GOOGLE_SCOPE.to_string(),
|
||||
aud: token_uri.to_string(),
|
||||
iat: now,
|
||||
exp: now + 3600,
|
||||
};
|
||||
|
||||
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
|
||||
let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(sa_key.private_key.as_bytes())
|
||||
.map_err(|e| format!("Failed to parse RSA private key: {}", e))?;
|
||||
|
||||
let jwt = jsonwebtoken::encode(&header, &jwt_claims, &encoding_key)
|
||||
.map_err(|e| format!("Failed to encode JWT: {}", e))?;
|
||||
|
||||
// Exchange JWT for access token
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.post(token_uri)
|
||||
.form(&[
|
||||
("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
|
||||
("assertion", &jwt),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to request token: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
error!("Google OAuth2 token request failed: {} - {}", status, body);
|
||||
return Err(format!("Token request failed with status {}: {}", status, body).into());
|
||||
}
|
||||
|
||||
let token_response: TokenResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse token response: {}", e))?;
|
||||
|
||||
info!("Successfully obtained Vertex AI access token (expires in {}s)", token_response.expires_in);
|
||||
|
||||
Ok(CachedToken {
|
||||
access_token: token_response.access_token,
|
||||
expires_at: std::time::Instant::now() + std::time::Duration::from_secs(token_response.expires_in),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct JwtClaims {
|
||||
iss: String,
|
||||
scope: String,
|
||||
aud: String,
|
||||
iat: u64,
|
||||
exp: u64,
|
||||
}
|
||||
|
||||
/// Builds the Vertex AI OpenAI-compatible endpoint URL
|
||||
pub fn build_vertex_url(project: &str, location: &str) -> String {
|
||||
format!(
|
||||
"https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/endpoints/openapi"
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if a URL is a Vertex AI endpoint
|
||||
pub fn is_vertex_ai_url(url: &str) -> bool {
|
||||
url.contains("aiplatform.googleapis.com") || url.contains("generativelanguage.googleapis.com")
|
||||
}
|
||||
|
||||
use crate::llm::LLMProvider;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct VertexClient {
|
||||
client: reqwest::Client,
|
||||
base_url: String,
|
||||
endpoint_path: String,
|
||||
token_manager: Arc<tokio::sync::RwLock<Option<Arc<VertexTokenManager>>>>,
|
||||
}
|
||||
|
||||
impl VertexClient {
|
||||
pub fn new(base_url: String, endpoint_path: Option<String>) -> Self {
|
||||
let endpoint = endpoint_path.unwrap_or_else(|| {
|
||||
if base_url.contains("generativelanguage.googleapis.com") {
|
||||
// Default to OpenAI compatibility if possible, but generate_stream will auto-detect
|
||||
"/v1beta/openai/chat/completions".to_string()
|
||||
} else if base_url.contains("aiplatform.googleapis.com") && !base_url.contains("openapi") {
|
||||
// If it's a naked aiplatform URL, it might be missing the project/location path
|
||||
// We'll leave it empty and let the caller provide a full URL if needed
|
||||
"".to_string()
|
||||
} else {
|
||||
"".to_string()
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
base_url,
|
||||
endpoint_path: endpoint,
|
||||
token_manager: Arc::new(tokio::sync::RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_auth_header(&self, key: &str) -> (&'static str, String) {
|
||||
// If the key is a path or JSON (starts with {), it's a Service Account
|
||||
if key.starts_with('{') || key.starts_with('~') || key.starts_with('/') || key.ends_with(".json") {
|
||||
let mut manager_opt = self.token_manager.write().await;
|
||||
if manager_opt.is_none() {
|
||||
*manager_opt = Some(Arc::new(VertexTokenManager::new(key)));
|
||||
}
|
||||
if let Some(manager) = manager_opt.as_ref() {
|
||||
match manager.get_access_token().await {
|
||||
Ok(token) => return ("Authorization", format!("Bearer {}", token)),
|
||||
Err(e) => error!("Failed to get Vertex OAuth token: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default to Google API Key (Google AI Studio / Gemini Developer APIs)
|
||||
("x-goog-api-key", key.to_string())
|
||||
}
|
||||
|
||||
fn convert_messages_to_gemini(&self, messages: &Value) -> Value {
|
||||
let mut contents = Vec::new();
|
||||
if let Some(msg_array) = messages.as_array() {
|
||||
for msg in msg_array {
|
||||
let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("user");
|
||||
let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or("");
|
||||
|
||||
let gemini_role = match role {
|
||||
"assistant" => "model",
|
||||
"system" => "user", // Gemini doesn't have system role in 'contents' by default, often wrapped in systemInstruction
|
||||
_ => "user",
|
||||
};
|
||||
|
||||
contents.push(serde_json::json!({
|
||||
"role": gemini_role,
|
||||
"parts": [{"text": content}]
|
||||
}));
|
||||
}
|
||||
}
|
||||
serde_json::json!(contents)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for VertexClient {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
messages: &Value,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||
let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
};
|
||||
|
||||
let (header_name, auth_value) = self.get_auth_header(key).await;
|
||||
|
||||
// Auto-detect if we should use native Gemini API
|
||||
let is_native = self.endpoint_path.contains(":generateContent") ||
|
||||
self.base_url.contains("generativelanguage") && !self.endpoint_path.contains("openai");
|
||||
|
||||
let full_url = if is_native && !self.endpoint_path.contains(":generateContent") {
|
||||
// Build native URL for generativelanguage
|
||||
format!("{}/v1beta/models/{}:generateContent", self.base_url.trim_end_matches('/'), model)
|
||||
} else {
|
||||
format!("{}{}", self.base_url, self.endpoint_path)
|
||||
}.trim_end_matches('/').to_string();
|
||||
|
||||
let request_body = if is_native {
|
||||
serde_json::json!({
|
||||
"contents": self.convert_messages_to_gemini(raw_messages),
|
||||
"generationConfig": {
|
||||
"temperature": 0.7,
|
||||
}
|
||||
})
|
||||
} else {
|
||||
serde_json::json!({
|
||||
"model": model,
|
||||
"messages": raw_messages,
|
||||
"stream": false
|
||||
})
|
||||
};
|
||||
|
||||
info!("Sending request to Vertex/Gemini endpoint: {}", full_url);
|
||||
|
||||
let response = self.client
|
||||
.post(full_url)
|
||||
.header(header_name, &auth_value)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("Vertex/Gemini generate error: {}", error_text);
|
||||
return Err(format!("Google API error ({}): {}", status, error_text).into());
|
||||
}
|
||||
|
||||
let json: Value = response.json().await?;
|
||||
|
||||
// Try parsing OpenAI format
|
||||
if let Some(choices) = json.get("choices") {
|
||||
if let Some(first_choice) = choices.get(0) {
|
||||
if let Some(message) = first_choice.get("message") {
|
||||
if let Some(content) = message.get("content") {
|
||||
if let Some(content_str) = content.as_str() {
|
||||
return Ok(content_str.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try parsing Native Gemini format
|
||||
if let Some(candidates) = json.get("candidates") {
|
||||
if let Some(first_candidate) = candidates.get(0) {
|
||||
if let Some(content) = first_candidate.get("content") {
|
||||
if let Some(parts) = content.get("parts") {
|
||||
if let Some(first_part) = parts.get(0) {
|
||||
if let Some(text) = first_part.get("text") {
|
||||
if let Some(text_str) = text.as_str() {
|
||||
return Ok(text_str.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err("Failed to parse response from Vertex/Gemini (unknown format)".into())
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
messages: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
model: &str,
|
||||
key: &str,
|
||||
tools: Option<&Vec<Value>>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||
let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
};
|
||||
|
||||
let (header_name, auth_value) = self.get_auth_header(key).await;
|
||||
|
||||
// Auto-detect if we should use native Gemini API
|
||||
let is_native = self.endpoint_path.contains("GenerateContent") ||
|
||||
(self.base_url.contains("googleapis.com") &&
|
||||
!self.endpoint_path.contains("openai") &&
|
||||
!self.endpoint_path.contains("completions"));
|
||||
|
||||
let full_url = if is_native && !self.endpoint_path.contains("GenerateContent") {
|
||||
// Build native URL for Gemini if it looks like we need it
|
||||
if self.base_url.contains("aiplatform") {
|
||||
// Global Vertex endpoint format
|
||||
format!("{}/v1/publishers/google/models/{}:streamGenerateContent", self.base_url.trim_end_matches('/'), model)
|
||||
} else {
|
||||
// Google AI Studio format
|
||||
format!("{}/v1beta/models/{}:streamGenerateContent", self.base_url.trim_end_matches('/'), model)
|
||||
}
|
||||
} else {
|
||||
format!("{}{}", self.base_url, self.endpoint_path)
|
||||
}.trim_end_matches('/').to_string();
|
||||
|
||||
let mut request_body = if is_native {
|
||||
let mut body = serde_json::json!({
|
||||
"contents": self.convert_messages_to_gemini(raw_messages),
|
||||
"generationConfig": {
|
||||
"temperature": 0.7,
|
||||
}
|
||||
});
|
||||
|
||||
// Handle thinking models if requested
|
||||
if model.contains("preview") || model.contains("3.1-flash") {
|
||||
body["generationConfig"]["thinkingConfig"] = serde_json::json!({
|
||||
"thinkingLevel": "LOW"
|
||||
});
|
||||
}
|
||||
body
|
||||
} else {
|
||||
serde_json::json!({
|
||||
"model": model,
|
||||
"messages": raw_messages,
|
||||
"stream": true
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(tools_value) = tools {
|
||||
if !tools_value.is_empty() {
|
||||
request_body["tools"] = serde_json::json!(tools_value);
|
||||
info!("Added {} tools to Vertex request", tools_value.len());
|
||||
}
|
||||
}
|
||||
|
||||
info!("Sending streaming request to Vertex/Gemini endpoint: {}", full_url);
|
||||
|
||||
let response = self.client
|
||||
.post(full_url)
|
||||
.header(header_name, &auth_value)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("Vertex/Gemini generate_stream error: {}", error_text);
|
||||
return Err(format!("Google API error ({}): {}", status, error_text).into());
|
||||
}
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut tool_call_buffer = String::new();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if let Ok(text) = std::str::from_utf8(&chunk) {
|
||||
for line in text.split('\n') {
|
||||
let line = line.trim();
|
||||
if line.is_empty() { continue; }
|
||||
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
// --- OpenAI SSE Format ---
|
||||
if data == "[DONE]" { continue; }
|
||||
if let Ok(json) = serde_json::from_str::<Value>(data) {
|
||||
if let Some(choices) = json.get("choices") {
|
||||
if let Some(first_choice) = choices.get(0) {
|
||||
if let Some(delta) = first_choice.get("delta") {
|
||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||
if !content.is_empty() {
|
||||
let _ = tx.send(content.to_string()).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls...
|
||||
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
||||
if let Some(first_call) = tool_calls.first() {
|
||||
if let Some(function) = first_call.get("function") {
|
||||
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
|
||||
tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name);
|
||||
}
|
||||
if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
|
||||
tool_call_buffer.push_str(args);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// --- Native Gemini JSON format ---
|
||||
// It usually arrives as raw JSON objects, sometimes with leading commas or brackets in a stream array
|
||||
let trimmed = line.trim_start_matches(|c| c == ',' || c == '[' || c == ']').trim_end_matches(']');
|
||||
if trimmed.is_empty() { continue; }
|
||||
|
||||
if let Ok(json) = serde_json::from_str::<Value>(trimmed) {
|
||||
if let Some(candidates) = json.get("candidates").and_then(|c| c.as_array()) {
|
||||
if let Some(candidate) = candidates.first() {
|
||||
if let Some(content) = candidate.get("content") {
|
||||
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
|
||||
for part in parts {
|
||||
if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
|
||||
if !text.is_empty() {
|
||||
let _ = tx.send(text.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Vertex stream reading error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !tool_call_buffer.is_empty() {
|
||||
tool_call_buffer.push_str("\"}");
|
||||
let _ = tx.send(format!("`tool_call`: {}", tool_call_buffer)).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_job(&self, _session_id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
@ -880,12 +880,30 @@ async fn start_drive_monitors(
|
|||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let local_dev_bots = tokio::task::spawn_blocking(move || {
|
||||
let mut bots = std::collections::HashSet::new();
|
||||
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/opt/gbo/data".to_string());
|
||||
if let Ok(entries) = std::fs::read_dir(data_dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()).map(|e| e.eq_ignore_ascii_case("gbai")).unwrap_or(false) {
|
||||
if let Some(bot_name) = path.file_stem().and_then(|s| s.to_str()) {
|
||||
bots.insert(bot_name.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
bots
|
||||
})
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
info!("Found {} active bots to monitor", bots_to_monitor.len());
|
||||
|
||||
for (bot_id, bot_name) in bots_to_monitor {
|
||||
// Skip default bot - it's managed locally via ConfigWatcher
|
||||
if bot_name == "default" {
|
||||
info!("Skipping DriveMonitor for 'default' bot - managed via ConfigWatcher");
|
||||
// Skip default bot and local dev bots - they are managed locally
|
||||
if bot_name == "default" || local_dev_bots.contains(&bot_name) {
|
||||
info!("Skipping DriveMonitor for '{}' bot - managed locally via ConfigWatcher/LocalFileMonitor", bot_name);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue