use anyhow::{anyhow, Result}; use chrono::{DateTime, Duration, Utc}; use jsonwebtoken::{ decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, }; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::sync::Arc; use tokio::sync::RwLock; use tracing::{debug, info}; use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct JwtConfig { pub issuer: String, pub audience: String, pub access_token_expiry_minutes: i64, pub refresh_token_expiry_days: i64, pub algorithm: JwtAlgorithm, pub leeway_seconds: u64, } impl Default for JwtConfig { fn default() -> Self { Self { issuer: "general-bots".into(), audience: "general-bots-api".into(), access_token_expiry_minutes: 15, refresh_token_expiry_days: 7, algorithm: JwtAlgorithm::HS256, leeway_seconds: 60, } } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum JwtAlgorithm { HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, } impl JwtAlgorithm { pub fn to_jsonwebtoken(&self) -> Algorithm { match self { Self::HS256 => Algorithm::HS256, Self::HS384 => Algorithm::HS384, Self::HS512 => Algorithm::HS512, Self::RS256 => Algorithm::RS256, Self::RS384 => Algorithm::RS384, Self::RS512 => Algorithm::RS512, Self::ES256 => Algorithm::ES256, Self::ES384 => Algorithm::ES384, } } pub fn is_symmetric(&self) -> bool { matches!(self, Self::HS256 | Self::HS384 | Self::HS512) } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum TokenType { Access, Refresh, IdToken, } impl TokenType { pub fn as_str(&self) -> &'static str { match self { Self::Access => "access", Self::Refresh => "refresh", Self::IdToken => "id_token", } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Claims { pub sub: String, pub iss: String, pub aud: String, pub exp: i64, pub iat: i64, pub nbf: i64, pub jti: String, #[serde(rename = "type")] pub token_type: String, #[serde(skip_serializing_if = "Option::is_none")] pub email: Option, #[serde(skip_serializing_if = "Option::is_none")] pub username: Option, #[serde(skip_serializing_if = "Option::is_none")] pub roles: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub permissions: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub session_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub organization_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub device_id: Option, } impl Claims { pub fn new( user_id: Uuid, issuer: &str, audience: &str, token_type: TokenType, expiry: DateTime, ) -> Self { let now = Utc::now(); Self { sub: user_id.to_string(), iss: issuer.to_string(), aud: audience.to_string(), exp: expiry.timestamp(), iat: now.timestamp(), nbf: now.timestamp(), jti: Uuid::new_v4().to_string(), token_type: token_type.as_str().to_string(), email: None, username: None, roles: None, permissions: None, session_id: None, organization_id: None, device_id: None, } } pub fn with_email(mut self, email: String) -> Self { self.email = Some(email); self } pub fn with_username(mut self, username: String) -> Self { self.username = Some(username); self } pub fn with_roles(mut self, roles: Vec) -> Self { self.roles = Some(roles); self } pub fn with_permissions(mut self, permissions: Vec) -> Self { self.permissions = Some(permissions); self } pub fn with_session_id(mut self, session_id: String) -> Self { self.session_id = Some(session_id); self } pub fn with_organization_id(mut self, org_id: String) -> Self { self.organization_id = Some(org_id); self } pub fn with_device_id(mut self, device_id: String) -> Self { self.device_id = Some(device_id); self } pub fn user_id(&self) -> Result { Uuid::parse_str(&self.sub).map_err(|e| anyhow!("Invalid user ID in claims: {e}")) } pub fn is_expired(&self) -> bool { Utc::now().timestamp() > self.exp } pub fn is_access_token(&self) -> bool { self.token_type == TokenType::Access.as_str() } pub fn is_refresh_token(&self) -> bool { self.token_type == TokenType::Refresh.as_str() } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenPair { pub access_token: String, pub refresh_token: String, pub token_type: String, pub expires_in: i64, pub refresh_expires_in: i64, #[serde(skip_serializing_if = "Option::is_none")] pub id_token: Option, #[serde(skip_serializing_if = "Option::is_none")] pub scope: Option, } #[derive(Debug, Clone)] pub enum JwtKey { Symmetric(Vec), RsaPrivate(Vec), RsaPublic(Vec), EcPrivate(Vec), EcPublic(Vec), } impl JwtKey { pub fn from_secret(secret: &str) -> Self { Self::Symmetric(secret.as_bytes().to_vec()) } pub fn from_rsa_pem(private_pem: &str, public_pem: &str) -> Result<(Self, Self)> { Ok(( Self::RsaPrivate(private_pem.as_bytes().to_vec()), Self::RsaPublic(public_pem.as_bytes().to_vec()), )) } pub fn from_ec_pem(private_pem: &str, public_pem: &str) -> Result<(Self, Self)> { Ok(( Self::EcPrivate(private_pem.as_bytes().to_vec()), Self::EcPublic(public_pem.as_bytes().to_vec()), )) } } pub struct JwtManager { config: JwtConfig, encoding_key: EncodingKey, decoding_key: DecodingKey, blacklist: Arc>>, } impl JwtManager { pub fn new(config: JwtConfig, key: JwtKey) -> Result { let (encoding_key, decoding_key) = match (&config.algorithm, key) { (JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512, JwtKey::Symmetric(secret)) => { ( EncodingKey::from_secret(&secret), DecodingKey::from_secret(&secret), ) } (JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512, JwtKey::RsaPrivate(pem)) => { let encoding = EncodingKey::from_rsa_pem(&pem) .map_err(|e| anyhow!("Invalid RSA private key: {e}"))?; let decoding = DecodingKey::from_rsa_pem(&pem) .map_err(|e| anyhow!("Invalid RSA key for decoding: {e}"))?; (encoding, decoding) } (JwtAlgorithm::ES256 | JwtAlgorithm::ES384, JwtKey::EcPrivate(pem)) => { let encoding = EncodingKey::from_ec_pem(&pem) .map_err(|e| anyhow!("Invalid EC private key: {e}"))?; let decoding = DecodingKey::from_ec_pem(&pem) .map_err(|e| anyhow!("Invalid EC key for decoding: {e}"))?; (encoding, decoding) } _ => return Err(anyhow!("Key type does not match algorithm")), }; Ok(Self { config, encoding_key, decoding_key, blacklist: Arc::new(RwLock::new(HashSet::new())), }) } pub fn with_separate_keys( config: JwtConfig, signing_key: JwtKey, verification_key: JwtKey, ) -> Result { let encoding_key = match (&config.algorithm, signing_key) { (JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512, JwtKey::Symmetric(secret)) => { EncodingKey::from_secret(&secret) } (JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512, JwtKey::RsaPrivate(pem)) => { EncodingKey::from_rsa_pem(&pem) .map_err(|e| anyhow!("Invalid RSA private key: {e}"))? } (JwtAlgorithm::ES256 | JwtAlgorithm::ES384, JwtKey::EcPrivate(pem)) => { EncodingKey::from_ec_pem(&pem) .map_err(|e| anyhow!("Invalid EC private key: {e}"))? } _ => return Err(anyhow!("Signing key type does not match algorithm")), }; let decoding_key = match (&config.algorithm, verification_key) { (JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512, JwtKey::Symmetric(secret)) => { DecodingKey::from_secret(&secret) } (JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512, JwtKey::RsaPublic(pem)) => { DecodingKey::from_rsa_pem(&pem) .map_err(|e| anyhow!("Invalid RSA public key: {e}"))? } (JwtAlgorithm::ES256 | JwtAlgorithm::ES384, JwtKey::EcPublic(pem)) => { DecodingKey::from_ec_pem(&pem) .map_err(|e| anyhow!("Invalid EC public key: {e}"))? } _ => return Err(anyhow!("Verification key type does not match algorithm")), }; Ok(Self { config, encoding_key, decoding_key, blacklist: Arc::new(RwLock::new(HashSet::new())), }) } pub fn from_secret(secret: &str) -> Result { if secret.len() < 32 { return Err(anyhow!("JWT secret must be at least 32 characters")); } Self::new(JwtConfig::default(), JwtKey::from_secret(secret)) } pub fn generate_access_token(&self, claims: Claims) -> Result { let header = Header::new(self.config.algorithm.to_jsonwebtoken()); encode(&header, &claims, &self.encoding_key) .map_err(|e| anyhow!("Failed to encode access token: {e}")) } pub fn generate_refresh_token(&self, claims: Claims) -> Result { let header = Header::new(self.config.algorithm.to_jsonwebtoken()); encode(&header, &claims, &self.encoding_key) .map_err(|e| anyhow!("Failed to encode refresh token: {e}")) } pub fn generate_token_pair(&self, user_id: Uuid) -> Result { let now = Utc::now(); let access_expiry = now + Duration::minutes(self.config.access_token_expiry_minutes); let refresh_expiry = now + Duration::days(self.config.refresh_token_expiry_days); let access_claims = Claims::new( user_id, &self.config.issuer, &self.config.audience, TokenType::Access, access_expiry, ); let refresh_claims = Claims::new( user_id, &self.config.issuer, &self.config.audience, TokenType::Refresh, refresh_expiry, ); let access_token = self.generate_access_token(access_claims)?; let refresh_token = self.generate_refresh_token(refresh_claims)?; Ok(TokenPair { access_token, refresh_token, token_type: "Bearer".into(), expires_in: self.config.access_token_expiry_minutes * 60, refresh_expires_in: self.config.refresh_token_expiry_days * 24 * 60 * 60, id_token: None, scope: None, }) } pub fn generate_token_pair_with_claims( &self, user_id: Uuid, email: Option, username: Option, roles: Option>, session_id: Option, ) -> Result { let now = Utc::now(); let access_expiry = now + Duration::minutes(self.config.access_token_expiry_minutes); let refresh_expiry = now + Duration::days(self.config.refresh_token_expiry_days); let mut access_claims = Claims::new( user_id, &self.config.issuer, &self.config.audience, TokenType::Access, access_expiry, ); if let Some(e) = email.clone() { access_claims = access_claims.with_email(e); } if let Some(u) = username.clone() { access_claims = access_claims.with_username(u); } if let Some(r) = roles.clone() { access_claims = access_claims.with_roles(r); } if let Some(s) = session_id.clone() { access_claims = access_claims.with_session_id(s); } let mut refresh_claims = Claims::new( user_id, &self.config.issuer, &self.config.audience, TokenType::Refresh, refresh_expiry, ); if let Some(s) = session_id { refresh_claims = refresh_claims.with_session_id(s); } let access_token = self.generate_access_token(access_claims)?; let refresh_token = self.generate_refresh_token(refresh_claims)?; Ok(TokenPair { access_token, refresh_token, token_type: "Bearer".into(), expires_in: self.config.access_token_expiry_minutes * 60, refresh_expires_in: self.config.refresh_token_expiry_days * 24 * 60 * 60, id_token: None, scope: None, }) } pub fn validate_token(&self, token: &str) -> Result> { let mut validation = Validation::new(self.config.algorithm.to_jsonwebtoken()); validation.set_issuer(&[&self.config.issuer]); validation.set_audience(&[&self.config.audience]); validation.leeway = self.config.leeway_seconds; decode::(token, &self.decoding_key, &validation) .map_err(|e| anyhow!("Token validation failed: {e}")) } pub async fn validate_token_with_blacklist(&self, token: &str) -> Result> { let token_data = self.validate_token(token)?; let blacklist = self.blacklist.read().await; if blacklist.contains(&token_data.claims.jti) { return Err(anyhow!("Token has been revoked")); } Ok(token_data) } pub fn validate_access_token(&self, token: &str) -> Result { let token_data = self.validate_token(token)?; if !token_data.claims.is_access_token() { return Err(anyhow!("Token is not an access token")); } Ok(token_data.claims) } pub fn validate_refresh_token(&self, token: &str) -> Result { let token_data = self.validate_token(token)?; if !token_data.claims.is_refresh_token() { return Err(anyhow!("Token is not a refresh token")); } Ok(token_data.claims) } pub async fn refresh_tokens(&self, refresh_token: &str) -> Result { let claims = self.validate_refresh_token(refresh_token)?; { let blacklist = self.blacklist.read().await; if blacklist.contains(&claims.jti) { return Err(anyhow!("Refresh token has been revoked")); } } let user_id = claims.user_id()?; self.revoke_token(&claims.jti).await?; let new_pair = self.generate_token_pair_with_claims( user_id, claims.email, claims.username, claims.roles, claims.session_id, )?; debug!("Refreshed tokens for user {user_id}"); Ok(new_pair) } pub async fn revoke_token(&self, jti: &str) -> Result<()> { let mut blacklist = self.blacklist.write().await; blacklist.insert(jti.to_string()); debug!("Revoked token {jti}"); Ok(()) } pub async fn revoke_by_token(&self, token: &str) -> Result<()> { let token_data = self.validate_token(token)?; self.revoke_token(&token_data.claims.jti).await } pub async fn is_revoked(&self, jti: &str) -> bool { let blacklist = self.blacklist.read().await; blacklist.contains(jti) } pub async fn cleanup_blacklist(&self, _expired_before: DateTime) -> usize { let mut blacklist = self.blacklist.write().await; let initial_count = blacklist.len(); blacklist.clear(); let removed = initial_count; if removed > 0 { info!("Cleaned up {removed} entries from token blacklist"); } removed } pub fn decode_without_validation(&self, token: &str) -> Result { let mut validation = Validation::new(self.config.algorithm.to_jsonwebtoken()); validation.insecure_disable_signature_validation(); validation.validate_exp = false; validation.validate_aud = false; let token_data = decode::(token, &self.decoding_key, &validation) .map_err(|e| anyhow!("Failed to decode token: {e}"))?; Ok(token_data.claims) } pub fn config(&self) -> &JwtConfig { &self.config } } pub fn extract_bearer_token(auth_header: &str) -> Option<&str> { auth_header.strip_prefix("Bearer ").or_else(|| auth_header.strip_prefix("bearer ")) } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct JwkSet { pub keys: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Jwk { pub kty: String, pub use_: Option, pub kid: Option, pub alg: Option, pub n: Option, pub e: Option, pub x: Option, pub y: Option, pub crv: Option, } impl JwkSet { pub fn new() -> Self { Self { keys: Vec::new() } } pub fn add_key(&mut self, key: Jwk) { self.keys.push(key); } } impl Default for JwkSet { fn default() -> Self { Self::new() } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenIntrospectionResponse { pub active: bool, #[serde(skip_serializing_if = "Option::is_none")] pub scope: Option, #[serde(skip_serializing_if = "Option::is_none")] pub client_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub username: Option, #[serde(skip_serializing_if = "Option::is_none")] pub token_type: Option, #[serde(skip_serializing_if = "Option::is_none")] pub exp: Option, #[serde(skip_serializing_if = "Option::is_none")] pub iat: Option, #[serde(skip_serializing_if = "Option::is_none")] pub nbf: Option, #[serde(skip_serializing_if = "Option::is_none")] pub sub: Option, #[serde(skip_serializing_if = "Option::is_none")] pub aud: Option, #[serde(skip_serializing_if = "Option::is_none")] pub iss: Option, #[serde(skip_serializing_if = "Option::is_none")] pub jti: Option, } impl TokenIntrospectionResponse { pub fn inactive() -> Self { Self { active: false, scope: None, client_id: None, username: None, token_type: None, exp: None, iat: None, nbf: None, sub: None, aud: None, iss: None, jti: None, } } pub fn from_claims(claims: &Claims, active: bool) -> Self { Self { active, scope: None, client_id: None, username: claims.username.clone(), token_type: Some(claims.token_type.clone()), exp: Some(claims.exp), iat: Some(claims.iat), nbf: Some(claims.nbf), sub: Some(claims.sub.clone()), aud: Some(claims.aud.clone()), iss: Some(claims.iss.clone()), jti: Some(claims.jti.clone()), } } } #[cfg(test)] mod tests { use super::*; fn create_test_manager() -> JwtManager { JwtManager::from_secret("this-is-a-very-long-secret-key-for-testing-purposes-only") .expect("Failed to create manager") } #[test] fn test_generate_token_pair() { let manager = create_test_manager(); let user_id = Uuid::new_v4(); let pair = manager.generate_token_pair(user_id).expect("Failed to generate"); assert!(!pair.access_token.is_empty()); assert!(!pair.refresh_token.is_empty()); assert_eq!(pair.token_type, "Bearer"); } #[test] fn test_validate_access_token() { let manager = create_test_manager(); let user_id = Uuid::new_v4(); let pair = manager.generate_token_pair(user_id).expect("Failed to generate"); let claims = manager .validate_access_token(&pair.access_token) .expect("Validation failed"); assert_eq!(claims.user_id().expect("Invalid user ID"), user_id); assert!(claims.is_access_token()); } #[test] fn test_validate_refresh_token() { let manager = create_test_manager(); let user_id = Uuid::new_v4(); let pair = manager.generate_token_pair(user_id).expect("Failed to generate"); let claims = manager .validate_refresh_token(&pair.refresh_token) .expect("Validation failed"); assert_eq!(claims.user_id().expect("Invalid user ID"), user_id); assert!(claims.is_refresh_token()); } #[test] fn test_token_with_claims() { let manager = create_test_manager(); let user_id = Uuid::new_v4(); let pair = manager .generate_token_pair_with_claims( user_id, Some("test@example.com".into()), Some("testuser".into()), Some(vec!["admin".into(), "user".into()]), Some("session-123".into()), ) .expect("Failed to generate"); let claims = manager .validate_access_token(&pair.access_token) .expect("Validation failed"); assert_eq!(claims.email, Some("test@example.com".into())); assert_eq!(claims.username, Some("testuser".into())); assert_eq!(claims.roles, Some(vec!["admin".into(), "user".into()])); assert_eq!(claims.session_id, Some("session-123".into())); } #[test] fn test_invalid_token() { let manager = create_test_manager(); let result = manager.validate_token("invalid.token.here"); assert!(result.is_err()); } #[test] fn test_wrong_token_type() { let manager = create_test_manager(); let user_id = Uuid::new_v4(); let pair = manager.generate_token_pair(user_id).expect("Failed to generate"); let result = manager.validate_refresh_token(&pair.access_token); assert!(result.is_err()); let result = manager.validate_access_token(&pair.refresh_token); assert!(result.is_err()); } #[tokio::test] async fn test_token_revocation() { let manager = create_test_manager(); let user_id = Uuid::new_v4(); let pair = manager.generate_token_pair(user_id).expect("Failed to generate"); let token_data = manager .validate_token(&pair.access_token) .expect("Validation failed"); manager .revoke_token(&token_data.claims.jti) .await .expect("Revoke failed"); let result = manager .validate_token_with_blacklist(&pair.access_token) .await; assert!(result.is_err()); } #[tokio::test] async fn test_refresh_tokens() { let manager = create_test_manager(); let user_id = Uuid::new_v4(); let pair = manager .generate_token_pair_with_claims( user_id, Some("test@example.com".into()), Some("testuser".into()), None, None, ) .expect("Failed to generate"); let new_pair = manager .refresh_tokens(&pair.refresh_token) .await .expect("Refresh failed"); assert_ne!(new_pair.access_token, pair.access_token); assert_ne!(new_pair.refresh_token, pair.refresh_token); let claims = manager .validate_access_token(&new_pair.access_token) .expect("Validation failed"); assert_eq!(claims.email, Some("test@example.com".into())); } #[test] fn test_extract_bearer_token() { assert_eq!( extract_bearer_token("Bearer abc123"), Some("abc123") ); assert_eq!( extract_bearer_token("bearer abc123"), Some("abc123") ); assert_eq!(extract_bearer_token("Basic abc123"), None); } #[test] fn test_claims_builder() { let user_id = Uuid::new_v4(); let claims = Claims::new( user_id, "issuer", "audience", TokenType::Access, Utc::now() + Duration::hours(1), ) .with_email("test@example.com".into()) .with_username("testuser".into()) .with_roles(vec!["admin".into()]) .with_organization_id("org-123".into()); assert_eq!(claims.email, Some("test@example.com".into())); assert_eq!(claims.username, Some("testuser".into())); assert_eq!(claims.roles, Some(vec!["admin".into()])); assert_eq!(claims.organization_id, Some("org-123".into())); } #[test] fn test_token_type() { assert_eq!(TokenType::Access.as_str(), "access"); assert_eq!(TokenType::Refresh.as_str(), "refresh"); assert_eq!(TokenType::IdToken.as_str(), "id_token"); } #[test] fn test_jwt_algorithm() { assert!(JwtAlgorithm::HS256.is_symmetric()); assert!(JwtAlgorithm::HS384.is_symmetric()); assert!(JwtAlgorithm::HS512.is_symmetric()); assert!(!JwtAlgorithm::RS256.is_symmetric()); assert!(!JwtAlgorithm::ES256.is_symmetric()); } #[test] fn test_token_introspection_response() { let inactive = TokenIntrospectionResponse::inactive(); assert!(!inactive.active); let user_id = Uuid::new_v4(); let claims = Claims::new( user_id, "issuer", "audience", TokenType::Access, Utc::now() + Duration::hours(1), ); let active = TokenIntrospectionResponse::from_claims(&claims, true); assert!(active.active); assert_eq!(active.sub, Some(user_id.to_string())); } #[test] fn test_jwk_set() { let mut jwk_set = JwkSet::new(); assert!(jwk_set.keys.is_empty()); jwk_set.add_key(Jwk { kty: "RSA".into(), use_: Some("sig".into()), kid: Some("key-1".into()), alg: Some("RS256".into()), n: None, e: None, x: None, y: None, crv: None, }); assert_eq!(jwk_set.keys.len(), 1); } }