security: enhance JWT and tenant routing middleware

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
movingsam 2026-02-28 18:37:48 +08:00
parent 2a4a06ddb8
commit eec65c1e05
2 changed files with 79 additions and 24 deletions

View File

@ -6,6 +6,14 @@ using YarpGateway.Config;
namespace YarpGateway.Middleware; namespace YarpGateway.Middleware;
/// <summary>
/// JWT 转换中间件
///
/// 安全说明:
/// 1. 从已验证的 JWT Claims 中提取用户信息(不是直接解析 token
/// 2. 清除请求中所有 X-* Header 以防止 Header 注入攻击
/// 3. 验证租户 ID 与 JWT 中的 tenant claim 一致
/// </summary>
public class JwtTransformMiddleware public class JwtTransformMiddleware
{ {
private readonly RequestDelegate _next; private readonly RequestDelegate _next;
@ -25,32 +33,46 @@ public class JwtTransformMiddleware
public async Task InvokeAsync(HttpContext context) public async Task InvokeAsync(HttpContext context)
{ {
var authHeader = context.Request.Headers["Authorization"].FirstOrDefault(); // 安全措施:清除所有 X-* Header 防止 Header 注入攻击
if (string.IsNullOrEmpty(authHeader) || !authHeader.StartsWith("Bearer ")) var xHeaders = context.Request.Headers
.Where(h => h.Key.StartsWith("X-", StringComparison.OrdinalIgnoreCase))
.Select(h => h.Key)
.ToList();
foreach (var header in xHeaders)
{
context.Request.Headers.Remove(header);
}
// 检查用户是否已通过 JWT 认证
if (context.User?.Identity?.IsAuthenticated != true)
{ {
await _next(context); await _next(context);
return; return;
} }
var token = authHeader.Substring("Bearer ".Length).Trim();
try try
{ {
var jwtHandler = new JwtSecurityTokenHandler(); // 从已验证的 ClaimsPrincipal 中提取信息(安全)
var jwtToken = jwtHandler.ReadJwtToken(token); var claims = context.User.Claims;
var tenantId = jwtToken.Claims.FirstOrDefault(c => c.Type == "tenant")?.Value; var tenantId = claims.FirstOrDefault(c => c.Type == "tenant")?.Value
var userId = jwtToken ?? claims.FirstOrDefault(c => c.Type == "tenant_id")?.Value;
.Claims.FirstOrDefault(c => c.Type == ClaimTypes.NameIdentifier)
?.Value; var userId = claims.FirstOrDefault(c => c.Type == ClaimTypes.NameIdentifier)?.Value
var userName = jwtToken.Claims.FirstOrDefault(c => c.Type == ClaimTypes.Name)?.Value; ?? claims.FirstOrDefault(c => c.Type == "sub")?.Value;
var roles = jwtToken
.Claims.Where(c => c.Type == ClaimTypes.Role) var userName = claims.FirstOrDefault(c => c.Type == ClaimTypes.Name)?.Value
?? claims.FirstOrDefault(c => c.Type == "name")?.Value;
var roles = claims
.Where(c => c.Type == ClaimTypes.Role || c.Type == "role")
.Select(c => c.Value) .Select(c => c.Value)
.ToList(); .ToList();
if (!string.IsNullOrEmpty(tenantId)) if (!string.IsNullOrEmpty(tenantId))
{ {
// 安全地设置 Header从已验证的 JWT 中提取)
context.Request.Headers["X-Tenant-Id"] = tenantId; context.Request.Headers["X-Tenant-Id"] = tenantId;
if (!string.IsNullOrEmpty(userId)) if (!string.IsNullOrEmpty(userId))
@ -63,7 +85,7 @@ public class JwtTransformMiddleware
context.Request.Headers["X-Roles"] = string.Join(",", roles); context.Request.Headers["X-Roles"] = string.Join(",", roles);
_logger.LogInformation( _logger.LogInformation(
"JWT transformed - Tenant: {Tenant}, User: {User}", "JWT claims transformed - Tenant: {Tenant}, User: {User}",
tenantId, tenantId,
userId userId
); );
@ -75,9 +97,9 @@ public class JwtTransformMiddleware
} }
catch (Exception ex) catch (Exception ex)
{ {
_logger.LogError(ex, "Failed to parse JWT token"); _logger.LogError(ex, "Failed to extract claims from authenticated user");
} }
await _next(context); await _next(context);
} }
} }

View File

@ -1,9 +1,18 @@
using System.Security.Claims;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using YarpGateway.Services; using YarpGateway.Services;
namespace YarpGateway.Middleware; namespace YarpGateway.Middleware;
/// <summary>
/// 租户路由中间件
///
/// 安全说明:
/// 1. 验证 X-Tenant-Id Header 与 JWT 中的 tenant claim 一致
/// 2. 防止租户隔离绕过攻击
/// 3. 只有验证通过后才进行路由查找
/// </summary>
public class TenantRoutingMiddleware public class TenantRoutingMiddleware
{ {
private readonly RequestDelegate _next; private readonly RequestDelegate _next;
@ -22,13 +31,37 @@ public class TenantRoutingMiddleware
public async Task InvokeAsync(HttpContext context) public async Task InvokeAsync(HttpContext context)
{ {
var tenantId = context.Request.Headers["X-Tenant-Id"].FirstOrDefault(); var headerTenantId = context.Request.Headers["X-Tenant-Id"].FirstOrDefault();
if (string.IsNullOrEmpty(tenantId))
if (string.IsNullOrEmpty(headerTenantId))
{ {
await _next(context); await _next(context);
return; return;
} }
// 安全验证:检查 Header 中的租户 ID 是否与 JWT 一致
if (context.User?.Identity?.IsAuthenticated == true)
{
var jwtTenantId = context.User.Claims
.FirstOrDefault(c => c.Type == "tenant" || c.Type == "tenant_id")?.Value;
if (!string.IsNullOrEmpty(jwtTenantId) && jwtTenantId != headerTenantId)
{
// 记录安全事件
_logger.LogWarning(
"Tenant ID mismatch detected! JWT tenant: {JwtTenant}, Header tenant: {HeaderTenant}, User: {User}",
jwtTenantId,
headerTenantId,
context.User.FindFirst(ClaimTypes.NameIdentifier)?.Value ?? "unknown"
);
// 拒绝请求
context.Response.StatusCode = StatusCodes.Status403Forbidden;
await context.Response.WriteAsync("Tenant ID verification failed");
return;
}
}
var path = context.Request.Path.Value ?? string.Empty; var path = context.Request.Path.Value ?? string.Empty;
var serviceName = ExtractServiceName(path); var serviceName = ExtractServiceName(path);
@ -38,10 +71,10 @@ public class TenantRoutingMiddleware
return; return;
} }
var route = _routeCache.GetRoute(tenantId, serviceName); var route = _routeCache.GetRoute(headerTenantId, serviceName);
if (route == null) if (route == null)
{ {
_logger.LogWarning("Route not found - Tenant: {Tenant}, Service: {Service}", tenantId, serviceName); _logger.LogDebug("Route not found - Tenant: {Tenant}, Service: {Service}", headerTenantId, serviceName);
await _next(context); await _next(context);
return; return;
} }
@ -49,8 +82,8 @@ public class TenantRoutingMiddleware
context.Items["DynamicClusterId"] = route.ClusterId; context.Items["DynamicClusterId"] = route.ClusterId;
var routeType = route.IsGlobal ? "global" : "tenant-specific"; var routeType = route.IsGlobal ? "global" : "tenant-specific";
_logger.LogInformation("Tenant routing - Tenant: {Tenant}, Service: {Service}, Cluster: {Cluster}, Type: {Type}", _logger.LogDebug("Tenant routing - Tenant: {Tenant}, Service: {Service}, Cluster: {Cluster}, Type: {Type}",
tenantId, serviceName, route.ClusterId, routeType); headerTenantId, serviceName, route.ClusterId, routeType);
await _next(context); await _next(context);
} }
@ -60,4 +93,4 @@ public class TenantRoutingMiddleware
var match = Regex.Match(path, @"/api/(\w+)/?"); var match = Regex.Match(path, @"/api/(\w+)/?");
return match.Success ? match.Groups[1].Value : string.Empty; return match.Success ? match.Groups[1].Value : string.Empty;
} }
} }