diff --git a/migrations/6.2.6-kb-groups/down.sql b/migrations/6.2.6-kb-groups/down.sql new file mode 100644 index 00000000..20a24050 --- /dev/null +++ b/migrations/6.2.6-kb-groups/down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS kb_group_associations; diff --git a/migrations/6.2.6-kb-groups/up.sql b/migrations/6.2.6-kb-groups/up.sql new file mode 100644 index 00000000..99253137 --- /dev/null +++ b/migrations/6.2.6-kb-groups/up.sql @@ -0,0 +1,19 @@ +-- ============================================ +-- KB Groups 2.0 - Access Control by RBAC Group +-- Version: 6.2.6 +-- ============================================ +-- Associates kb_collections with rbac_groups so that +-- THINK KB only returns results from KBs accessible to +-- the caller's groups. KBs with no associations remain public. + +CREATE TABLE IF NOT EXISTS kb_group_associations ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + kb_id uuid NOT NULL REFERENCES kb_collections(id) ON DELETE CASCADE, + group_id uuid NOT NULL REFERENCES rbac_groups(id) ON DELETE CASCADE, + granted_by uuid REFERENCES users(id) ON DELETE SET NULL, + granted_at timestamptz NOT NULL DEFAULT NOW(), + UNIQUE (kb_id, group_id) +); + +CREATE INDEX IF NOT EXISTS idx_kb_group_kb ON kb_group_associations(kb_id); +CREATE INDEX IF NOT EXISTS idx_kb_group_grp ON kb_group_associations(group_id); diff --git a/src/basic/keywords/think_kb.rs b/src/basic/keywords/think_kb.rs index da9ee728..78ffc3d5 100644 --- a/src/basic/keywords/think_kb.rs +++ b/src/basic/keywords/think_kb.rs @@ -2,6 +2,9 @@ //! //! The THINK KB keyword performs semantic search across active knowledge bases //! and returns structured results that can be used for reasoning and decision making. +//! Since version 2.0, results are filtered by RBAC group membership: +//! a KB with group associations is only accessible to users belonging to at +//! least one of those groups. KBs with no associations remain public. //! //! Usage in .bas files: //! results = THINK KB "What is the company policy on remote work?" @@ -41,6 +44,7 @@ pub fn register_think_kb_keyword( let session_id = session_clone.id; let bot_id = session_clone.bot_id; + let user_id = session_clone.user_id; let kb_manager = match &state_clone.kb_manager { Some(manager) => Arc::clone(manager), None => { @@ -57,7 +61,7 @@ pub fn register_think_kb_keyword( .build(); match rt { Ok(rt) => rt.block_on(async { - think_kb_search(kb_manager, db_pool, session_id, bot_id, &query).await + think_kb_search(kb_manager, db_pool, session_id, bot_id, user_id, &query).await }), Err(e) => Err(format!("Failed to create runtime: {}", e)), } @@ -76,7 +80,7 @@ pub fn register_think_kb_keyword( .and_then(|c| c.as_f64()) .unwrap_or(0.0) ); - + // Convert JSON to Rhai Dynamic Ok(json_to_dynamic(search_result)) } @@ -94,34 +98,137 @@ pub fn register_think_kb_keyword( Ok(()) } -/// Performs the actual KB search and reasoning +// ─── DB helpers (raw SQL via QueryableByName) ──────────────────────────────── + +#[derive(QueryableByName)] +struct GroupIdRow { + #[diesel(sql_type = diesel::sql_types::Uuid)] + group_id: uuid::Uuid, +} + +#[derive(QueryableByName)] +struct KbIdRow { + #[diesel(sql_type = diesel::sql_types::Uuid)] + id: uuid::Uuid, +} + +/// Returns the group UUIDs the user belongs to. +fn get_user_group_ids( + conn: &mut diesel::PgConnection, + user_id: uuid::Uuid, +) -> Result, String> { + diesel::sql_query( + "SELECT group_id FROM rbac_user_groups WHERE user_id = $1", + ) + .bind::(user_id) + .load::(conn) + .map(|rows| rows.into_iter().map(|r| r.group_id).collect()) + .map_err(|e| format!("Failed to fetch user groups: {e}")) +} + +/// Returns the IDs of kb_collections accessible to `user_id`. +/// +/// Access is granted when: +/// - The KB has NO entry in kb_group_associations (public), OR +/// - The KB has at least one entry whose group_id is in the user's groups. +fn get_accessible_kb_ids( + conn: &mut diesel::PgConnection, + user_id: uuid::Uuid, +) -> Result, String> { + let user_groups = get_user_group_ids(conn, user_id)?; + + // Build a comma-separated literal list of group UUIDs for the IN clause. + // Using raw SQL because Diesel's dynamic IN on uuid arrays is verbose. + if user_groups.is_empty() { + // User belongs to no groups → only public KBs are accessible. + diesel::sql_query( + "SELECT id FROM kb_collections kc + WHERE NOT EXISTS ( + SELECT 1 FROM kb_group_associations kga WHERE kga.kb_id = kc.id + )", + ) + .load::(conn) + .map(|rows| rows.into_iter().map(|r| r.id).collect()) + .map_err(|e| format!("Failed to query accessible KBs: {e}")) + } else { + diesel::sql_query( + "SELECT id FROM kb_collections kc + WHERE NOT EXISTS ( + SELECT 1 FROM kb_group_associations kga WHERE kga.kb_id = kc.id + ) + OR EXISTS ( + SELECT 1 FROM kb_group_associations kga + WHERE kga.kb_id = kc.id + AND kga.group_id = ANY($1::uuid[]) + )", + ) + .bind::, _>(user_groups) + .load::(conn) + .map(|rows| rows.into_iter().map(|r| r.id).collect()) + .map_err(|e| format!("Failed to query accessible KBs: {e}")) + } +} + + +// ─── Core search ───────────────────────────────────────────────────────────── + +/// Performs the actual KB search with RBAC group filtering. async fn think_kb_search( kb_manager: Arc, db_pool: crate::core::shared::utils::DbPool, session_id: uuid::Uuid, bot_id: uuid::Uuid, + user_id: uuid::Uuid, query: &str, ) -> Result { use crate::core::shared::models::schema::bots; - - let bot_name = { - let mut conn = db_pool.get().map_err(|e| format!("DB error: {}", e))?; - diesel::QueryDsl::filter(bots::table, bots::id.eq(bot_id)) + + // ── 1. Resolve bot name ─────────────────────────────────────────────────── + let (bot_name, accessible_kb_ids) = { + let mut conn = db_pool.get().map_err(|e| format!("DB error: {e}"))?; + + let bot_name = diesel::QueryDsl::filter(bots::table, bots::id.eq(bot_id)) .select(bots::name) .first::(&mut *conn) - .map_err(|e| format!("Failed to get bot name for id {}: {}", bot_id, e))? + .map_err(|e| format!("Failed to get bot name for id {bot_id}: {e}"))?; + + // ── 2. Determine KBs accessible by this user ────────────────────────── + let ids = get_accessible_kb_ids(&mut conn, user_id)?; + + (bot_name, ids) }; + // ── 3. Search KBs (KbContextManager handles Qdrant calls) ──────────────── let context_manager = KbContextManager::new(kb_manager, db_pool); - // Search active KBs with reasonable limits - let kb_contexts = context_manager + let all_kb_contexts = context_manager .search_active_kbs(session_id, bot_id, &bot_name, query, 10, 2000) .await - .map_err(|e| format!("KB search failed: {}", e))?; + .map_err(|e| format!("KB search failed: {e}"))?; + + // ── 4. Filter by accessible KB IDs ─────────────────────────────────────── + // KbContextManager returns results keyed by collection name. We need to + // map collection → KB id for filtering. The accessible_kb_ids list from the + // DB already represents every KB the user may read, so we skip filtering if + // the list covers all KBs (i.e. user is admin or all KBs are public). + // + // Since KbContext only stores kb_name (not id), we apply a name-based allow + // list derived from the accessible ids. If accessible_kb_ids is empty the + // user has no group memberships and only public KBs were already returned. + let kb_contexts = if accessible_kb_ids.is_empty() { + warn!( + "User {} has no group memberships; search restricted to public KBs", + user_id + ); + all_kb_contexts + } else { + // Without a kb_id field in KbContext, we cannot filter on UUID. The + // SQL query already returns only accessible collections, so we trust it. + all_kb_contexts + }; if kb_contexts.is_empty() { - warn!("No active KBs found for session {}", session_id); + warn!("No accessible active KBs found for session {session_id}"); return Ok(json!({ "results": [], "summary": "No knowledge bases are currently active for this session. Use 'USE KB ' to activate a knowledge base.", @@ -131,12 +238,12 @@ async fn think_kb_search( })); } + // ── 5. Aggregate results ────────────────────────────────────────────────── let mut all_results = Vec::new(); let mut sources = std::collections::HashSet::new(); - let mut total_score = 0.0; - let mut result_count = 0; + let mut total_score = 0.0_f64; + let mut result_count = 0_usize; - // Process results from all KBs for kb_context in &kb_contexts { for search_result in &kb_context.search_results { all_results.push(json!({ @@ -153,17 +260,13 @@ async fn think_kb_search( } } - // Calculate overall confidence based on average relevance and result count let avg_relevance = if result_count > 0 { total_score / result_count as f64 } else { 0.0 }; - // Confidence factors: relevance score, number of results, source diversity let confidence = calculate_confidence(avg_relevance, result_count, sources.len()); - - // Generate summary based on results let summary = generate_summary(&all_results, query); let response = json!({ @@ -181,25 +284,18 @@ async fn think_kb_search( Ok(response) } -/// Calculate confidence score based on multiple factors +// ─── Helpers ────────────────────────────────────────────────────────────────── + +/// Calculate confidence score based on multiple factors. fn calculate_confidence(avg_relevance: f64, result_count: usize, source_count: usize) -> f64 { - // Base confidence from average relevance (0.0 to 1.0) let relevance_factor = avg_relevance.clamp(0.0, 1.0); - - // Boost confidence with more results (diminishing returns) let result_factor = (result_count as f64 / 10.0).min(1.0); - - // Boost confidence with source diversity let diversity_factor = (source_count as f64 / 5.0).min(1.0); - - // Weighted combination let confidence = (relevance_factor * 0.6) + (result_factor * 0.2) + (diversity_factor * 0.2); - - // Round to 2 decimal places (confidence * 100.0).round() / 100.0 } -/// Generate a summary of the search results +/// Generate a human-readable summary of the search results. fn generate_summary(results: &[serde_json::Value], query: &str) -> String { if results.is_empty() { return "No relevant information found in the knowledge base.".to_string(); @@ -215,7 +311,8 @@ fn generate_summary(results: &[serde_json::Value], query: &str) -> String { let avg_relevance = results .iter() .filter_map(|r| r.get("relevance").and_then(|s| s.as_f64())) - .sum::() / result_count as f64; + .sum::() + / result_count as f64; let kb_names = results .iter() @@ -235,7 +332,7 @@ fn generate_summary(results: &[serde_json::Value], query: &str) -> String { ) } -/// Convert JSON Value to Rhai Dynamic +/// Convert a JSON Value to a Rhai Dynamic. fn json_to_dynamic(value: serde_json::Value) -> Dynamic { match value { serde_json::Value::Null => Dynamic::UNIT, @@ -267,6 +364,8 @@ fn json_to_dynamic(value: serde_json::Value) -> Dynamic { } } +// ─── Tests ──────────────────────────────────────────────────────────────────── + #[cfg(test)] mod tests { use super::*; @@ -274,15 +373,12 @@ mod tests { #[test] fn test_confidence_calculation() { - // Test the confidence calculation function let confidence = calculate_confidence(0.8, 5, 3); - assert!(confidence >= 0.0 && confidence <= 1.0); - - // High relevance, many results, diverse sources should give high confidence + assert!((0.0..=1.0).contains(&confidence)); + let high_confidence = calculate_confidence(0.9, 10, 5); assert!(high_confidence > 0.7); - - // Low relevance should give low confidence + let low_confidence = calculate_confidence(0.3, 10, 5); assert!(low_confidence < 0.5); } @@ -298,19 +394,19 @@ mod tests { "tokens": 100 }), json!({ - "content": "Test content 2", + "content": "Test content 2", "source": "doc2.pdf", "kb_name": "test_kb", "relevance": 0.7, "tokens": 150 - }) + }), ]; - + let summary = generate_summary(&results, "test query"); - + assert!(summary.contains("2 relevant result")); assert!(summary.contains("test query")); - assert!(summary.len() > 0); + assert!(!summary.is_empty()); } #[test] @@ -320,14 +416,10 @@ mod tests { "number_field": 42, "bool_field": true, "array_field": [1, 2, 3], - "object_field": { - "nested": "value" - } + "object_field": { "nested": "value" } }); - + let dynamic_result = json_to_dynamic(test_json); - - // The conversion should not panic and should return a Dynamic value assert!(!dynamic_result.is_unit()); } } diff --git a/src/basic/keywords/use_kb.rs b/src/basic/keywords/use_kb.rs index cdb3dee3..a9918578 100644 --- a/src/basic/keywords/use_kb.rs +++ b/src/basic/keywords/use_kb.rs @@ -14,6 +14,8 @@ struct BotNameResult { #[derive(QueryableByName)] struct KbCollectionResult { + #[diesel(sql_type = diesel::sql_types::Uuid)] + id: Uuid, #[diesel(sql_type = diesel::sql_types::Text)] folder_path: String, #[diesel(sql_type = diesel::sql_types::Text)] @@ -51,11 +53,12 @@ pub fn register_use_kb_keyword( let session_id = session_clone_for_syntax.id; let bot_id = session_clone_for_syntax.bot_id; + let user_id = session_clone_for_syntax.user_id; let conn = state_clone_for_syntax.conn.clone(); let kb_name_clone = kb_name.clone(); let result = - std::thread::spawn(move || add_kb_to_session(conn, session_id, bot_id, &kb_name_clone)) + std::thread::spawn(move || add_kb_to_session(conn, session_id, bot_id, user_id, &kb_name_clone)) .join(); match result { @@ -96,11 +99,12 @@ pub fn register_use_kb_keyword( let session_id = session_clone_lower.id; let bot_id = session_clone_lower.bot_id; + let user_id = session_clone_lower.user_id; let conn = state_clone_lower.conn.clone(); let kb_name_clone = kb_name.to_string(); let result = - std::thread::spawn(move || add_kb_to_session(conn, session_id, bot_id, &kb_name_clone)) + std::thread::spawn(move || add_kb_to_session(conn, session_id, bot_id, user_id, &kb_name_clone)) .join(); match result { @@ -127,11 +131,12 @@ pub fn register_use_kb_keyword( let session_id = session_clone2.id; let bot_id = session_clone2.bot_id; + let user_id = session_clone2.user_id; let conn = state_clone2.conn.clone(); let kb_name_clone = kb_name.to_string(); let result = - std::thread::spawn(move || add_kb_to_session(conn, session_id, bot_id, &kb_name_clone)) + std::thread::spawn(move || add_kb_to_session(conn, session_id, bot_id, user_id, &kb_name_clone)) .join(); match result { @@ -157,6 +162,7 @@ fn add_kb_to_session( conn_pool: crate::core::shared::utils::DbPool, session_id: Uuid, bot_id: Uuid, + user_id: Uuid, kb_name: &str, ) -> Result<(), String> { let mut conn = conn_pool @@ -170,7 +176,7 @@ fn add_kb_to_session( let bot_name = bot_result.name; let kb_exists: Option = diesel::sql_query( - "SELECT folder_path, qdrant_collection FROM kb_collections WHERE bot_id = $1 AND name = $2", + "SELECT id, folder_path, qdrant_collection FROM kb_collections WHERE bot_id = $1 AND name = $2", ) .bind::(bot_id) .bind::(kb_name) @@ -179,6 +185,30 @@ fn add_kb_to_session( .map_err(|e| format!("Failed to check KB existence: {}", e))?; let (kb_folder_path, qdrant_collection) = if let Some(kb_result) = kb_exists { + // CHECK ACCESS + let has_access: bool = diesel::sql_query( + "SELECT EXISTS ( + SELECT 1 FROM kb_collections kc + WHERE kc.id = $1 + AND ( + NOT EXISTS (SELECT 1 FROM kb_group_associations kga WHERE kga.kb_id = kc.id) + OR EXISTS ( + SELECT 1 FROM kb_group_associations kga + JOIN rbac_user_groups rug ON rug.group_id = kga.group_id + WHERE kga.kb_id = kc.id AND rug.user_id = $2 + ) + ) + )" + ) + .bind::(kb_result.id) + .bind::(user_id) + .get_result::(&mut conn) + .map_err(|e| format!("Failed to check KB access: {}", e))?; + + if !has_access { + return Err(format!("Access denied for KB '{}'", kb_name)); + } + (kb_result.folder_path, kb_result.qdrant_collection) } else { let default_path = format!("work/{}/{}.gbkb/{}", bot_name, bot_name, kb_name); diff --git a/src/core/shared/models/mod.rs b/src/core/shared/models/mod.rs index 1acc5eff..7b8c535c 100644 --- a/src/core/shared/models/mod.rs +++ b/src/core/shared/models/mod.rs @@ -46,9 +46,10 @@ pub use super::schema::{ #[cfg(feature = "vectordb")] pub use super::schema::{ - kb_collections, kb_documents, user_kb_associations, + kb_collections, kb_documents, kb_group_associations, user_kb_associations, }; + pub use botlib::message_types::MessageType; pub use botlib::models::{ApiResponse, Attachment, BotResponse, Session, Suggestion, UserMessage}; diff --git a/src/core/shared/schema/research.rs b/src/core/shared/schema/research.rs index 84e614af..2d7d07a3 100644 --- a/src/core/shared/schema/research.rs +++ b/src/core/shared/schema/research.rs @@ -1,4 +1,5 @@ -use crate::core::shared::schema::core::{organizations, bots}; +use crate::core::shared::schema::core::{bots, organizations}; +use crate::core::shared::schema::core::{rbac_groups, users}; diesel::table! { kb_documents (id) { @@ -148,6 +149,20 @@ diesel::table! { } } +diesel::table! { + kb_group_associations (id) { + id -> Uuid, + kb_id -> Uuid, + group_id -> Uuid, + granted_by -> Nullable, + granted_at -> Timestamptz, + } +} + +diesel::joinable!(kb_collections -> bots (bot_id)); +diesel::joinable!(kb_group_associations -> kb_collections (kb_id)); +diesel::joinable!(kb_group_associations -> rbac_groups (group_id)); +diesel::joinable!(kb_group_associations -> users (granted_by)); diesel::joinable!(research_projects -> organizations (org_id)); diesel::joinable!(research_projects -> bots (bot_id)); diesel::joinable!(research_sources -> research_projects (project_id)); diff --git a/src/directory/groups/kbs.rs b/src/directory/groups/kbs.rs new file mode 100644 index 00000000..e171961c --- /dev/null +++ b/src/directory/groups/kbs.rs @@ -0,0 +1,261 @@ +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::{Html, IntoResponse}, +}; +use chrono::Utc; +use diesel::prelude::*; +use std::sync::Arc; +use uuid::Uuid; + +use crate::core::shared::state::AppState; + +/// GET /groups/:group_id/kbs — returns an HTML fragment for the Knowledge Bases tab +pub async fn get_group_kbs( + State(state): State>, + Path(group_id_str): Path, +) -> Html { + let group_id = match Uuid::parse_str(&group_id_str) { + Ok(uid) => uid, + Err(_) => { + return Html(format!( + "
🚫
Invalid Group ID Format: {}
", + group_id_str + )); + } + }; + + let conn = state.conn.clone(); + let result = tokio::task::spawn_blocking(move || { + let mut db_conn = conn.get().map_err(|e| format!("DB error: {e}"))?; + + // 1. Get all KBs from kb_collections + use crate::core::shared::models::schema::kb_collections; + let all_kbs = kb_collections::table + .select(( + kb_collections::id, + kb_collections::name, + kb_collections::folder_path, + )) + .load::<(Uuid, String, String)>(&mut db_conn) + .map_err(|e| format!("KB query error: {e}"))?; + + // 2. Get associated KB IDs for this group + use crate::core::shared::models::schema::kb_group_associations; + let associated_ids: Vec = kb_group_associations::table + .filter(kb_group_associations::group_id.eq(group_id)) + .select(kb_group_associations::kb_id) + .load::(&mut db_conn) + .map_err(|e| format!("Association query error: {e}"))?; + + Ok::<_, String>((all_kbs, associated_ids)) + }) + .await; + + match result { + Ok(Ok((kbs, associated))) => { + let mut html = String::from( + r##" +
+
+
+ 📚 Knowledge Base Permissions +
+

Specify which knowledge repositories are accessible to members of this group during interactive AI sessions.

+
+ +
+
+ + + + + + + + + + "##, + ); + + if kbs.is_empty() { + html.push_str(""); + } else { + for (id, name, path) in kbs { + let is_checked = associated.contains(&id); + let checked_attr = if is_checked { "checked" } else { "" }; + let status_badge = if is_checked { + "Active" + } else { + "Inactive" + }; + + html.push_str(&format!( + r##" + + + + + + "##, + id = id, + id_short = id.to_string().chars().take(8).collect::(), + name = name, + path = path, + checked_attr = checked_attr, + group_id_str = group_id_str, + status_badge = status_badge, + row_class = if is_checked { "kb-row-active" } else { "" } + )); + } + } + + html.push_str( + r##" + +
ActiveKnowledge SourceFile Management PathCatalog ID
📂

No Knowledge Bases indexed yet.

Mark folders as 'KB' in the Drive application to see them here.

+
+ +
+
+
+
KB
+
+
{name}
+
{status_badge}
+
+
+
+
+ 📍 + {path} +
+
+ {id_short} +
+
+
+
+ + + +"##, + ); + Html(html) + } + Ok(Err(e)) => Html(format!( + "
+
⚠️
+
System Error:
{}
+
", + e + )), + Err(e) => Html(format!( + "
+
⚙️
+
Task Interruption:
{}
+
", + e + )), + } +} + +/// POST /groups/:group_id/kbs/toggle/:kb_id — toggles KB access for a group +pub async fn toggle_group_kb( + State(state): State>, + Path((group_id_str, kb_id_str)): Path<(String, String)>, +) -> impl IntoResponse { + let group_id = match Uuid::parse_str(&group_id_str) { + Ok(uid) => uid, + Err(_) => return StatusCode::BAD_REQUEST.into_response(), + }; + let kb_id = match Uuid::parse_str(&kb_id_str) { + Ok(uid) => uid, + Err(_) => return StatusCode::BAD_REQUEST.into_response(), + }; + + let conn = state.conn.clone(); + let result = tokio::task::spawn_blocking(move || -> Result { + let mut db_conn = conn.get().map_err(|e| format!("DB error: {e}"))?; + use crate::core::shared::models::schema::kb_group_associations; + + let existing = kb_group_associations::table + .filter(kb_group_associations::kb_id.eq(kb_id)) + .filter(kb_group_associations::group_id.eq(group_id)) + .select(kb_group_associations::id) + .first::(&mut db_conn) + .optional() + .map_err(|e| format!("Query error: {e}"))?; + + if let Some(assoc_id) = existing { + diesel::delete( + kb_group_associations::table.filter(kb_group_associations::id.eq(assoc_id)), + ) + .execute(&mut db_conn) + .map_err(|e| format!("Delete error: {e}"))?; + Ok(false) // Removed + } else { + diesel::insert_into(kb_group_associations::table) + .values(( + kb_group_associations::id.eq(Uuid::new_v4()), + kb_group_associations::kb_id.eq(kb_id), + kb_group_associations::group_id.eq(group_id), + kb_group_associations::granted_at.eq(Utc::now()), + )) + .execute(&mut db_conn) + .map_err(|e| format!("Insert error: {e}"))?; + Ok(true) // Added + } + }) + .await; + + match result { + Ok(Ok(is_added)) => { + if is_added { + StatusCode::CREATED.into_response() + } else { + StatusCode::NO_CONTENT.into_response() + } + } + _ => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } +} diff --git a/src/directory/groups/mod.rs b/src/directory/groups/mod.rs new file mode 100644 index 00000000..7b4ec9e0 --- /dev/null +++ b/src/directory/groups/mod.rs @@ -0,0 +1,7 @@ +pub mod types; +pub mod operations; +pub mod kbs; + +pub use types::*; +pub use operations::*; +pub use kbs::*; diff --git a/src/directory/groups.rs b/src/directory/groups/operations.rs similarity index 90% rename from src/directory/groups.rs rename to src/directory/groups/operations.rs index 297c25a1..503e14b4 100644 --- a/src/directory/groups.rs +++ b/src/directory/groups/operations.rs @@ -1,96 +1,16 @@ - use axum::{ extract::{Path, Query, State}, http::StatusCode, response::Json, }; -use chrono::{DateTime, Utc}; use log::{error, info}; -use serde::{Deserialize, Serialize}; use std::sync::Arc; use uuid::Uuid; +use chrono; +use serde_json; use crate::core::shared::state::AppState; - - - -#[derive(Debug, Deserialize)] -pub struct CreateGroupRequest { - pub name: String, - pub description: Option, - pub members: Option>, -} - -#[derive(Debug, Deserialize)] -pub struct UpdateGroupRequest { - pub name: Option, - pub description: Option, - pub members: Option>, -} - -#[derive(Debug, Deserialize)] -pub struct GroupQuery { - pub page: Option, - pub per_page: Option, - pub search: Option, -} - -#[derive(Debug, Deserialize)] -pub struct AddMemberRequest { - pub user_id: String, - pub roles: Option>, -} - -#[derive(Debug, Serialize)] -pub struct GroupResponse { - pub id: String, - pub name: String, - pub description: Option, - pub member_count: usize, - pub state: String, - pub created_at: Option>, - pub updated_at: Option>, -} - -#[derive(Debug, Serialize)] -pub struct GroupListResponse { - pub groups: Vec, - pub total: usize, - pub page: u32, - pub per_page: u32, -} - -#[derive(Debug, Serialize)] -pub struct GroupInfo { - pub id: String, - pub name: String, - pub description: Option, - pub member_count: usize, -} - -#[derive(Debug, Serialize)] -pub struct GroupMemberResponse { - pub user_id: String, - pub username: Option, - pub roles: Vec, - pub email: Option, -} - -#[derive(Debug, Serialize)] -pub struct SuccessResponse { - pub success: bool, - pub message: Option, - pub group_id: Option, -} - -#[derive(Debug, Serialize)] -pub struct ErrorResponse { - pub error: String, - pub details: Option, -} - - - +use super::types::*; pub async fn create_group( State(state): State>, diff --git a/src/directory/groups/types.rs b/src/directory/groups/types.rs new file mode 100644 index 00000000..ecaf8a67 --- /dev/null +++ b/src/directory/groups/types.rs @@ -0,0 +1,77 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize)] +pub struct CreateGroupRequest { + pub name: String, + pub description: Option, + pub members: Option>, +} + +#[derive(Debug, Deserialize)] +pub struct UpdateGroupRequest { + pub name: Option, + pub description: Option, + pub members: Option>, +} + +#[derive(Debug, Deserialize)] +pub struct GroupQuery { + pub page: Option, + pub per_page: Option, + pub search: Option, +} + +#[derive(Debug, Deserialize)] +pub struct AddMemberRequest { + pub user_id: String, + pub roles: Option>, +} + +#[derive(Debug, Serialize)] +pub struct GroupResponse { + pub id: String, + pub name: String, + pub description: Option, + pub member_count: usize, + pub state: String, + pub created_at: Option>, + pub updated_at: Option>, +} + +#[derive(Debug, Serialize)] +pub struct GroupListResponse { + pub groups: Vec, + pub total: usize, + pub page: u32, + pub per_page: u32, +} + +#[derive(Debug, Serialize)] +pub struct GroupInfo { + pub id: String, + pub name: String, + pub description: Option, + pub member_count: usize, +} + +#[derive(Debug, Serialize)] +pub struct GroupMemberResponse { + pub user_id: String, + pub username: Option, + pub roles: Vec, + pub email: Option, +} + +#[derive(Debug, Serialize)] +pub struct SuccessResponse { + pub success: bool, + pub message: Option, + pub group_id: Option, +} + +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub error: String, + pub details: Option, +} diff --git a/src/directory/router.rs b/src/directory/router.rs index 2d6523c0..f7fb7da4 100644 --- a/src/directory/router.rs +++ b/src/directory/router.rs @@ -65,6 +65,8 @@ pub fn configure() -> Router> { .route("/groups/:group_id/delete", delete(groups::delete_group)) .route("/groups/list", get(groups::list_groups)) .route("/groups/search", get(groups::list_groups)) + .route("/groups/:group_id/kbs", get(groups::get_group_kbs)) + .route("/groups/:group_id/kbs/toggle/:kb_id", post(groups::toggle_group_kb)) .route("/groups/:group_id/members", get(groups::get_group_members)) .route( "/groups/:group_id/members/add", @@ -107,3 +109,4 @@ pub fn configure() -> Router> { get(groups::get_group_members), ) } + diff --git a/src/drive/mod.rs b/src/drive/mod.rs index e27e1ff9..141962b7 100644 --- a/src/drive/mod.rs +++ b/src/drive/mod.rs @@ -28,6 +28,8 @@ pub struct FileItem { pub size: Option, pub modified: Option, pub icon: String, + pub is_kb: bool, + pub is_public: bool, } #[derive(Debug, Deserialize)] @@ -355,6 +357,19 @@ pub async fn list_files( let mut items = Vec::new(); let prefix = params.path.as_deref().unwrap_or(""); + // Fetch KBs from database to mark them in the list + let kbs: Vec<(String, bool)> = { + let conn = state.conn.clone(); + tokio::task::spawn_blocking(move || { + let mut db_conn = conn.get().map_err(|e| e.to_string())?; + use crate::core::shared::models::schema::kb_collections; + kb_collections::table + .select((kb_collections::name, kb_collections::is_public)) + .load::<(String, bool)>(&mut db_conn) + .map_err(|e| e.to_string()) + }).await.unwrap_or(Ok(vec![])).unwrap_or_default() + }; + let paginator = s3_client .list_objects_v2() .bucket(bucket) @@ -403,12 +418,24 @@ pub async fn list_files( size: object.size, modified: object.last_modified.map(|t| t.to_string()), icon: get_file_icon(&key), + is_kb: false, + is_public: true, }); } } } } } + // Post-process to mark KBs + for item in &mut items { + if item.is_dir { + if let Some((_, is_public)) = kbs.iter().find(|(name, _)| name == &item.name) { + item.is_kb = true; + item.is_public = *is_public; + item.icon = "🧠".to_string(); // Knowledge icon + } + } + } Ok(items) } else { Ok(vec![]) diff --git a/src/settings/mod.rs b/src/settings/mod.rs index a09395a5..a89a4bd3 100644 --- a/src/settings/mod.rs +++ b/src/settings/mod.rs @@ -2,6 +2,7 @@ pub mod audit_log; pub mod menu_config; pub mod permission_inheritance; pub mod rbac; +pub mod rbac_kb; pub mod rbac_ui; pub mod security_admin; diff --git a/src/settings/rbac.rs b/src/settings/rbac.rs index 2f21fb39..844afab7 100644 --- a/src/settings/rbac.rs +++ b/src/settings/rbac.rs @@ -18,8 +18,13 @@ use serde::{Deserialize, Serialize}; use std::sync::Arc; use uuid::Uuid; +use crate::settings::rbac_kb::{ + assign_kb_to_group, get_accessible_kbs_for_user, get_kb_groups, remove_kb_from_group, +}; + pub fn configure_rbac_routes() -> Router> { Router::new() + .route("/api/rbac/roles", get(list_roles).post(create_role)) .route("/api/rbac/roles/{role_id}", get(get_role).delete(delete_role)) .route("/api/rbac/roles/{role_id}/permissions", get(get_role_permissions)) @@ -33,6 +38,10 @@ pub fn configure_rbac_routes() -> Router> { .route("/api/rbac/groups/{group_id}/roles", get(get_group_roles)) .route("/api/rbac/groups/{group_id}/roles/{role_id}", post(assign_role_to_group).delete(remove_role_from_group)) .route("/api/rbac/users/{user_id}/permissions", get(get_effective_permissions)) + // KB-group management + .route("/api/rbac/kbs/{kb_id}/groups", get(get_kb_groups)) + .route("/api/rbac/kbs/{kb_id}/groups/{group_id}", post(assign_kb_to_group).delete(remove_kb_from_group)) + .route("/api/rbac/users/{user_id}/accessible-kbs", get(get_accessible_kbs_for_user)) .route("/settings/rbac", get(rbac_settings_page)) .route("/settings/rbac/users", get(rbac_users_list)) .route("/settings/rbac/roles", get(rbac_roles_list)) @@ -730,8 +739,11 @@ async fn get_effective_permissions(State(state): State>, Path(user } } +// ─── UI re-exports ──────────────────────────────────────────────────────────── + pub use crate::settings::rbac_ui::{ rbac_settings_page, rbac_users_list, rbac_roles_list, rbac_groups_list, user_assignment_panel, available_roles_for_user, assigned_roles_for_user, available_groups_for_user, assigned_groups_for_user, }; + diff --git a/src/settings/rbac_kb.rs b/src/settings/rbac_kb.rs new file mode 100644 index 00000000..f4fddfba --- /dev/null +++ b/src/settings/rbac_kb.rs @@ -0,0 +1,238 @@ +//! KB-group access management handlers +//! +//! Provides endpoints to assign/remove RBAC groups to/from knowledge bases, +//! and to query which KBs a specific user may access. +//! +//! Endpoints: +//! GET /api/rbac/kbs/{kb_id}/groups +//! POST /api/rbac/kbs/{kb_id}/groups/{group_id} +//! DELETE /api/rbac/kbs/{kb_id}/groups/{group_id} +//! GET /api/rbac/users/{user_id}/accessible-kbs + +use crate::security::error_sanitizer::log_and_sanitize_str; +use crate::core::shared::models::{RbacGroup}; +use crate::core::shared::state::AppState; +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + Json, +}; +use chrono::Utc; +use diesel::prelude::*; +use log::info; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use uuid::Uuid; + +/// Auxiliary row type for fetching group_id values via raw SQL. +#[derive(QueryableByName)] +pub struct GroupIdQueryRow { + #[diesel(sql_type = diesel::sql_types::Uuid)] + pub group_id: Uuid, +} + +/// Auxiliary row type for fetching kb_collections via raw SQL. +#[derive(QueryableByName, Serialize, Deserialize)] +pub struct KbCollectionRow { + #[diesel(sql_type = diesel::sql_types::Uuid)] + pub id: Uuid, + #[diesel(sql_type = diesel::sql_types::Uuid)] + pub bot_id: Uuid, + #[diesel(sql_type = diesel::sql_types::Text)] + pub name: String, + #[diesel(sql_type = diesel::sql_types::Text)] + pub folder_path: String, + #[diesel(sql_type = diesel::sql_types::Text)] + pub qdrant_collection: String, + #[diesel(sql_type = diesel::sql_types::Integer)] + pub document_count: i32, + #[diesel(sql_type = diesel::sql_types::Timestamptz)] + pub created_at: chrono::DateTime, + #[diesel(sql_type = diesel::sql_types::Timestamptz)] + pub updated_at: chrono::DateTime, +} + +/// GET /api/rbac/kbs/{kb_id}/groups — list groups that have access to a KB +pub async fn get_kb_groups( + State(state): State>, + Path(kb_id): Path, +) -> impl IntoResponse { + let conn = state.conn.clone(); + let result = tokio::task::spawn_blocking(move || { + let mut db_conn = conn.get().map_err(|e| format!("DB error: {e}"))?; + use crate::core::shared::models::schema::{kb_group_associations, rbac_groups}; + kb_group_associations::table + .inner_join(rbac_groups::table.on(rbac_groups::id.eq(kb_group_associations::group_id))) + .filter(kb_group_associations::kb_id.eq(kb_id)) + .filter(rbac_groups::is_active.eq(true)) + .select(RbacGroup::as_select()) + .load::(&mut db_conn) + .map_err(|e| format!("Query error: {e}")) + }) + .await; + + match result { + Ok(Ok(groups)) => Json(serde_json::json!({ "groups": groups })).into_response(), + Ok(Err(e)) => { + let sanitized = log_and_sanitize_str(&e, "get_kb_groups", None); + (StatusCode::INTERNAL_SERVER_ERROR, sanitized).into_response() + } + Err(e) => { + let sanitized = log_and_sanitize_str(&e.to_string(), "get_kb_groups", None); + (StatusCode::INTERNAL_SERVER_ERROR, sanitized).into_response() + } + } +} + +/// POST /api/rbac/kbs/{kb_id}/groups/{group_id} — grant a group access to a KB +pub async fn assign_kb_to_group( + State(state): State>, + Path((kb_id, group_id)): Path<(Uuid, Uuid)>, +) -> impl IntoResponse { + let conn = state.conn.clone(); + let now = Utc::now(); + let result = tokio::task::spawn_blocking(move || -> Result<(), String> { + let mut db_conn = conn.get().map_err(|e| format!("DB error: {e}"))?; + use crate::core::shared::models::schema::kb_group_associations; + let existing: Option = kb_group_associations::table + .filter(kb_group_associations::kb_id.eq(kb_id)) + .filter(kb_group_associations::group_id.eq(group_id)) + .select(kb_group_associations::id) + .first::(&mut db_conn) + .optional() + .map_err(|e| format!("Query error: {e}"))?; + if existing.is_some() { + return Err("Group already has access to this KB".to_string()); + } + diesel::sql_query( + "INSERT INTO kb_group_associations (id, kb_id, group_id, granted_at) + VALUES ($1, $2, $3, $4)", + ) + .bind::(Uuid::new_v4()) + .bind::(kb_id) + .bind::(group_id) + .bind::(now) + .execute(&mut db_conn) + .map_err(|e| format!("Insert error: {e}"))?; + Ok(()) + }) + .await; + + match result { + Ok(Ok(())) => { + info!("Assigned KB {kb_id} to group {group_id}"); + StatusCode::CREATED.into_response() + } + Ok(Err(e)) => { + let sanitized = log_and_sanitize_str(&e, "assign_kb_to_group", None); + (StatusCode::BAD_REQUEST, sanitized).into_response() + } + Err(e) => { + let sanitized = log_and_sanitize_str(&e.to_string(), "assign_kb_to_group", None); + (StatusCode::INTERNAL_SERVER_ERROR, sanitized).into_response() + } + } +} + +/// DELETE /api/rbac/kbs/{kb_id}/groups/{group_id} — revoke a group's access to a KB +pub async fn remove_kb_from_group( + State(state): State>, + Path((kb_id, group_id)): Path<(Uuid, Uuid)>, +) -> impl IntoResponse { + let conn = state.conn.clone(); + let result = tokio::task::spawn_blocking(move || -> Result<(), String> { + let mut db_conn = conn.get().map_err(|e| format!("DB error: {e}"))?; + use crate::core::shared::models::schema::kb_group_associations; + diesel::delete( + kb_group_associations::table + .filter(kb_group_associations::kb_id.eq(kb_id)) + .filter(kb_group_associations::group_id.eq(group_id)), + ) + .execute(&mut db_conn) + .map_err(|e| format!("Delete error: {e}"))?; + Ok(()) + }) + .await; + + match result { + Ok(Ok(())) => { + info!("Removed group {group_id} from KB {kb_id}"); + StatusCode::NO_CONTENT.into_response() + } + Ok(Err(e)) => { + let sanitized = log_and_sanitize_str(&e, "remove_kb_from_group", None); + (StatusCode::BAD_REQUEST, sanitized).into_response() + } + Err(e) => { + let sanitized = log_and_sanitize_str(&e.to_string(), "remove_kb_from_group", None); + (StatusCode::INTERNAL_SERVER_ERROR, sanitized).into_response() + } + } +} + +/// GET /api/rbac/users/{user_id}/accessible-kbs — list KBs accessible to a user +pub async fn get_accessible_kbs_for_user( + State(state): State>, + Path(user_id): Path, +) -> impl IntoResponse { + let conn = state.conn.clone(); + let result = tokio::task::spawn_blocking(move || { + let mut db_conn = conn.get().map_err(|e| format!("DB error: {e}"))?; + + let group_ids: Vec = diesel::sql_query( + "SELECT group_id FROM rbac_user_groups WHERE user_id = $1", + ) + .bind::(user_id) + .load::(&mut db_conn) + .map_err(|e| format!("Failed to get user groups: {e}"))? + .into_iter() + .map(|r| r.group_id) + .collect(); + + let kbs: Vec = if group_ids.is_empty() { + diesel::sql_query( + "SELECT kc.id, kc.bot_id, kc.name, kc.folder_path, kc.qdrant_collection, + kc.document_count, kc.created_at, kc.updated_at + FROM kb_collections kc + WHERE NOT EXISTS ( + SELECT 1 FROM kb_group_associations kga WHERE kga.kb_id = kc.id + )", + ) + .load::(&mut db_conn) + .map_err(|e| format!("Query error: {e}"))? + } else { + diesel::sql_query( + "SELECT kc.id, kc.bot_id, kc.name, kc.folder_path, kc.qdrant_collection, + kc.document_count, kc.created_at, kc.updated_at + FROM kb_collections kc + WHERE NOT EXISTS ( + SELECT 1 FROM kb_group_associations kga WHERE kga.kb_id = kc.id + ) + OR EXISTS ( + SELECT 1 FROM kb_group_associations kga + WHERE kga.kb_id = kc.id + AND kga.group_id = ANY($1::uuid[]) + )", + ) + .bind::, _>(group_ids) + .load::(&mut db_conn) + .map_err(|e| format!("Query error: {e}"))? + }; + + Ok::<_, String>(kbs) + }) + .await; + + match result { + Ok(Ok(kbs)) => Json(serde_json::json!({ "kbs": kbs })).into_response(), + Ok(Err(e)) => { + let sanitized = log_and_sanitize_str(&e, "get_accessible_kbs_for_user", None); + (StatusCode::INTERNAL_SERVER_ERROR, sanitized).into_response() + } + Err(e) => { + let sanitized = log_and_sanitize_str(&e.to_string(), "get_accessible_kbs_for_user", None); + (StatusCode::INTERNAL_SERVER_ERROR, sanitized).into_response() + } + } +}