use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use reqwest;
use serde::{Deserialize, Serialize};
use std::error::Error;
// Base URL of the Hanko API; the JWKS document is fetched from
// `{HANKO_API_URL}/.well-known/jwks.json`.
// NOTE(review): "XXXXX" looks like a placeholder for the tenant id — replace before use.
const HANKO_API_URL: &str = "https://XXXXX.hanko.io";
// Audience value that a token's `aud` claim is validated against when decoding.
// NOTE(review): "example.com" looks like a placeholder — confirm the real audience.
const JWT_AUDIENCE_DOMAIN: &str = "example.com";
#[derive(Debug, Deserialize)]
struct Jwk {
kty: String,
kid: String,
n: String,
e: String,
alg: String,
}
/// The JSON Web Key Set document: a JSON object with a `keys` array.
#[derive(Deserialize, Debug)]
struct Jwks {
    /// Every key published in the set.
    keys: Vec<Jwk>,
}
/// The `email` object embedded in the decoded token claims.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmailInfo {
    /// The email address itself.
    pub address: String,
    /// Value of the `is_primary` flag carried in the claim.
    pub is_primary: bool,
    /// Value of the `is_verified` flag carried in the claim.
    pub is_verified: bool,
}
/// Claims extracted from a successfully validated token.
///
/// All fields are required: decoding fails if any of them is missing from the
/// token payload or has a different JSON shape.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct UserDecodedJwtInfo {
    /// `aud` claim, expected as a JSON array of strings.
    pub aud: Vec<String>,
    /// `email` claim.
    pub email: EmailInfo,
    /// `exp` claim (JWT expiration time).
    pub exp: i64,
    /// `iat` claim (JWT issued-at time).
    pub iat: i64,
    /// `sub` claim (JWT subject).
    pub sub: String,
}
/// Downloads and parses the JSON Web Key Set published under
/// `{HANKO_API_URL}/.well-known/jwks.json`.
///
/// # Errors
///
/// Returns an error if the request cannot be sent, if the server answers with
/// a 4xx/5xx status, or if the response body is not a valid JWKS document.
async fn fetch_jwks() -> Result<Jwks, Box<dyn Error>> {
    let url = format!("{}/.well-known/jwks.json", HANKO_API_URL);
    // Reject error statuses up front: otherwise an error page would be handed
    // to the JSON parser and reported as a confusing decode failure.
    let response = reqwest::get(url).await?.error_for_status()?;
    let jwks = response.json::<Jwks>().await?;
    Ok(jwks)
}
/// Reads the `kid` (key id) from the header of `token`.
///
/// Only the header is decoded here; nothing is verified.
///
/// # Errors
///
/// Fails if the header cannot be decoded or carries no `kid`.
fn get_jwt_kid(token: &str) -> Result<String, Box<dyn Error>> {
    match decode_header(token)?.kid {
        Some(kid) => Ok(kid),
        None => Err("Missing kid in header".into()),
    }
}
/// Verifies `token` against the RSA public key in `jwk` and returns its claims.
///
/// Validation is pinned to RS256 and requires the audience to match
/// `JWT_AUDIENCE_DOMAIN`.
///
/// # Errors
///
/// Fails if the key components are invalid, or if decoding/validation of the
/// token fails.
fn decode_token(jwk: &Jwk, token: &str) -> Result<UserDecodedJwtInfo, Box<dyn Error>> {
    let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;

    let mut rules = Validation::new(Algorithm::RS256);
    rules.set_audience(&[JWT_AUDIENCE_DOMAIN]);

    let claims = decode::<UserDecodedJwtInfo>(token, &key, &rules)?.claims;
    Ok(claims)
}
/// Validates `token` end to end and returns its claims.
///
/// Steps: read the `kid` from the token header, fetch the key set, pick the
/// key with the same `kid`, then verify the token with that key.
///
/// # Errors
///
/// Fails if any of the steps above fails, or if no key in the set has the
/// token's `kid`.
async fn validate_hanko_token(token: &str) -> Result<UserDecodedJwtInfo, Box<dyn Error>> {
    let kid = get_jwt_kid(token)?;
    let key_set = fetch_jwks().await?;

    let key = key_set
        .keys
        .iter()
        .find(|candidate| candidate.kid == kid)
        .ok_or_else(|| format!("No matching JWK found matching token kid: {}", kid))?;

    decode_token(key, token)
}
/// Entry point: validates a hard-coded sample token and prints the outcome.
///
/// The payload goes to stdout on success; the error goes to stderr on failure.
#[tokio::main]
async fn main() {
    let jwt_hanko_token = String::from("some-hanko-jwt-token");
    let outcome = validate_hanko_token(&jwt_hanko_token).await;

    match outcome {
        Err(err) => eprintln!("Token validation failed: {}", err),
        Ok(payload) => println!("Token is valid. Payload: {:?}", payload),
    }
}