use anyhow::Result; use chrono::{Duration, Utc}; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; use crate::config::JwtConfig; // Token claims #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Claims { pub sub: String, pub exp: usize, pub iat: usize, pub user_id: String, pub email: String, pub token_version: i32, } impl Claims { pub fn new(user_id: String, email: String, token_version: i32) -> Self { let now = Utc::now(); let exp = now + Duration::minutes(15); // Access token expires in 15 minutes Self { sub: user_id.clone(), exp: exp.timestamp() as usize, iat: now.timestamp() as usize, user_id, email, token_version, } } } // Refresh token claims (longer expiry) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RefreshClaims { pub sub: String, pub exp: usize, pub iat: usize, pub user_id: String, pub token_version: i32, } impl RefreshClaims { pub fn new(user_id: String, token_version: i32) -> Self { let now = Utc::now(); let exp = now + Duration::days(30); // Refresh token expires in 30 days Self { sub: user_id.clone(), exp: exp.timestamp() as usize, iat: now.timestamp() as usize, user_id, token_version, } } } /// JWT Service for token generation and validation #[derive(Clone)] pub struct JwtService { config: JwtConfig, // In-memory storage for refresh tokens (user_id -> set of tokens) refresh_tokens: Arc>>>, encoding_key: EncodingKey, decoding_key: DecodingKey, } impl JwtService { pub fn new(config: JwtConfig) -> Self { let encoding_key = EncodingKey::from_secret(config.secret.as_ref()); let decoding_key = DecodingKey::from_secret(config.secret.as_ref()); Self { config, refresh_tokens: Arc::new(RwLock::new(HashMap::new())), encoding_key, decoding_key, } } /// Generate access and refresh tokens pub fn generate_tokens(&self, claims: Claims) -> Result<(String, String)> { // Generate access token let access_token = encode(&Header::default(), &claims, &self.encoding_key) .map_err(|e| anyhow::anyhow!("Failed to encode access token: {}", e))?; // Generate refresh token let refresh_claims = RefreshClaims::new(claims.user_id.clone(), claims.token_version); let refresh_token = encode(&Header::default(), &refresh_claims, &self.encoding_key) .map_err(|e| anyhow::anyhow!("Failed to encode refresh token: {}", e))?; Ok((access_token, refresh_token)) } /// Validate access token pub fn validate_token(&self, token: &str) -> Result { let token_data = decode::( token, &self.decoding_key, &Validation::default() ) .map_err(|e| anyhow::anyhow!("Invalid token: {}", e))?; Ok(token_data.claims) } /// Validate refresh token pub fn validate_refresh_token(&self, token: &str) -> Result { let token_data = decode::( token, &self.decoding_key, &Validation::default() ) .map_err(|e| anyhow::anyhow!("Invalid refresh token: {}", e))?; Ok(token_data.claims) } /// Store refresh token for a user pub async fn store_refresh_token(&self, user_id: &str, token: &str) -> Result<()> { let mut tokens = self.refresh_tokens.write().await; tokens.entry(user_id.to_string()) .or_insert_with(Vec::new) .push(token.to_string()); // Keep only last 5 tokens per user if let Some(user_tokens) = tokens.get_mut(user_id) { user_tokens.sort(); user_tokens.dedup(); if user_tokens.len() > 5 { *user_tokens = user_tokens.split_off(user_tokens.len() - 5); } } Ok(()) } /// Verify if a refresh token is stored pub async fn verify_refresh_token_stored(&self, user_id: &str, token: &str) -> Result { let tokens = self.refresh_tokens.read().await; if let Some(user_tokens) = tokens.get(user_id) { Ok(user_tokens.contains(&token.to_string())) } else { Ok(false) } } /// Rotate refresh token (remove old, add new) pub async fn rotate_refresh_token(&self, user_id: &str, old_token: &str, new_token: &str) -> Result<()> { // Remove old token self.revoke_refresh_token(old_token).await?; // Add new token self.store_refresh_token(user_id, new_token).await?; Ok(()) } /// Revoke a specific refresh token pub async fn revoke_refresh_token(&self, token: &str) -> Result<()> { let mut tokens = self.refresh_tokens.write().await; for user_tokens in tokens.values_mut() { user_tokens.retain(|t| t != token); } Ok(()) } /// Revoke all refresh tokens for a user pub async fn revoke_all_user_tokens(&self, user_id: &str) -> Result<()> { let mut tokens = self.refresh_tokens.write().await; tokens.remove(user_id); Ok(()) } }