diff --git a/src/Middleware/JwtTransformMiddleware.cs b/src/Middleware/JwtTransformMiddleware.cs
index 19998c2..0d2fdb5 100644
--- a/src/Middleware/JwtTransformMiddleware.cs
+++ b/src/Middleware/JwtTransformMiddleware.cs
@@ -6,6 +6,14 @@ using YarpGateway.Config;
namespace YarpGateway.Middleware;
+///
+/// JWT 转换中间件
+///
+/// 安全说明:
+/// 1. 从已验证的 JWT Claims 中提取用户信息(不是直接解析 token)
+/// 2. 清除请求中所有 X-* Header 以防止 Header 注入攻击
+/// 3. 验证租户 ID 与 JWT 中的 tenant claim 一致
+///
public class JwtTransformMiddleware
{
private readonly RequestDelegate _next;
@@ -25,32 +33,46 @@ public class JwtTransformMiddleware
public async Task InvokeAsync(HttpContext context)
{
- var authHeader = context.Request.Headers["Authorization"].FirstOrDefault();
- if (string.IsNullOrEmpty(authHeader) || !authHeader.StartsWith("Bearer "))
+ // 安全措施:清除所有 X-* Header 防止 Header 注入攻击
+ var xHeaders = context.Request.Headers
+ .Where(h => h.Key.StartsWith("X-", StringComparison.OrdinalIgnoreCase))
+ .Select(h => h.Key)
+ .ToList();
+
+ foreach (var header in xHeaders)
+ {
+ context.Request.Headers.Remove(header);
+ }
+
+ // 检查用户是否已通过 JWT 认证
+ if (context.User?.Identity?.IsAuthenticated != true)
{
await _next(context);
return;
}
- var token = authHeader.Substring("Bearer ".Length).Trim();
-
try
{
- var jwtHandler = new JwtSecurityTokenHandler();
- var jwtToken = jwtHandler.ReadJwtToken(token);
-
- var tenantId = jwtToken.Claims.FirstOrDefault(c => c.Type == "tenant")?.Value;
- var userId = jwtToken
- .Claims.FirstOrDefault(c => c.Type == ClaimTypes.NameIdentifier)
- ?.Value;
- var userName = jwtToken.Claims.FirstOrDefault(c => c.Type == ClaimTypes.Name)?.Value;
- var roles = jwtToken
- .Claims.Where(c => c.Type == ClaimTypes.Role)
+ // 从已验证的 ClaimsPrincipal 中提取信息(安全)
+ var claims = context.User.Claims;
+
+ var tenantId = claims.FirstOrDefault(c => c.Type == "tenant")?.Value
+ ?? claims.FirstOrDefault(c => c.Type == "tenant_id")?.Value;
+
+ var userId = claims.FirstOrDefault(c => c.Type == ClaimTypes.NameIdentifier)?.Value
+ ?? claims.FirstOrDefault(c => c.Type == "sub")?.Value;
+
+ var userName = claims.FirstOrDefault(c => c.Type == ClaimTypes.Name)?.Value
+ ?? claims.FirstOrDefault(c => c.Type == "name")?.Value;
+
+ var roles = claims
+ .Where(c => c.Type == ClaimTypes.Role || c.Type == "role")
.Select(c => c.Value)
.ToList();
if (!string.IsNullOrEmpty(tenantId))
{
+ // 安全地设置 Header(从已验证的 JWT 中提取)
context.Request.Headers["X-Tenant-Id"] = tenantId;
if (!string.IsNullOrEmpty(userId))
@@ -63,7 +85,7 @@ public class JwtTransformMiddleware
context.Request.Headers["X-Roles"] = string.Join(",", roles);
_logger.LogInformation(
- "JWT transformed - Tenant: {Tenant}, User: {User}",
+ "JWT claims transformed - Tenant: {Tenant}, User: {User}",
tenantId,
userId
);
@@ -75,9 +97,9 @@ public class JwtTransformMiddleware
}
catch (Exception ex)
{
- _logger.LogError(ex, "Failed to parse JWT token");
+ _logger.LogError(ex, "Failed to extract claims from authenticated user");
}
await _next(context);
}
-}
+}
\ No newline at end of file
diff --git a/src/Middleware/TenantRoutingMiddleware.cs b/src/Middleware/TenantRoutingMiddleware.cs
index ab0301d..e0bdd77 100644
--- a/src/Middleware/TenantRoutingMiddleware.cs
+++ b/src/Middleware/TenantRoutingMiddleware.cs
@@ -1,9 +1,18 @@
+using System.Security.Claims;
using Microsoft.Extensions.Options;
using System.Text.RegularExpressions;
using YarpGateway.Services;
namespace YarpGateway.Middleware;
+///
+/// 租户路由中间件
+///
+/// 安全说明:
+/// 1. 验证 X-Tenant-Id Header 与 JWT 中的 tenant claim 一致
+/// 2. 防止租户隔离绕过攻击
+/// 3. 只有验证通过后才进行路由查找
+///
public class TenantRoutingMiddleware
{
private readonly RequestDelegate _next;
@@ -22,13 +31,37 @@ public class TenantRoutingMiddleware
public async Task InvokeAsync(HttpContext context)
{
- var tenantId = context.Request.Headers["X-Tenant-Id"].FirstOrDefault();
- if (string.IsNullOrEmpty(tenantId))
+ var headerTenantId = context.Request.Headers["X-Tenant-Id"].FirstOrDefault();
+
+ if (string.IsNullOrEmpty(headerTenantId))
{
await _next(context);
return;
}
+ // 安全验证:检查 Header 中的租户 ID 是否与 JWT 一致
+ if (context.User?.Identity?.IsAuthenticated == true)
+ {
+ var jwtTenantId = context.User.Claims
+ .FirstOrDefault(c => c.Type == "tenant" || c.Type == "tenant_id")?.Value;
+
+ if (!string.IsNullOrEmpty(jwtTenantId) && jwtTenantId != headerTenantId)
+ {
+ // 记录安全事件
+ _logger.LogWarning(
+ "Tenant ID mismatch detected! JWT tenant: {JwtTenant}, Header tenant: {HeaderTenant}, User: {User}",
+ jwtTenantId,
+ headerTenantId,
+ context.User.FindFirst(ClaimTypes.NameIdentifier)?.Value ?? "unknown"
+ );
+
+ // 拒绝请求
+ context.Response.StatusCode = StatusCodes.Status403Forbidden;
+ await context.Response.WriteAsync("Tenant ID verification failed");
+ return;
+ }
+ }
+
var path = context.Request.Path.Value ?? string.Empty;
var serviceName = ExtractServiceName(path);
@@ -38,10 +71,10 @@ public class TenantRoutingMiddleware
return;
}
- var route = _routeCache.GetRoute(tenantId, serviceName);
+ var route = _routeCache.GetRoute(headerTenantId, serviceName);
if (route == null)
{
- _logger.LogWarning("Route not found - Tenant: {Tenant}, Service: {Service}", tenantId, serviceName);
+ _logger.LogDebug("Route not found - Tenant: {Tenant}, Service: {Service}", headerTenantId, serviceName);
await _next(context);
return;
}
@@ -49,8 +82,8 @@ public class TenantRoutingMiddleware
context.Items["DynamicClusterId"] = route.ClusterId;
var routeType = route.IsGlobal ? "global" : "tenant-specific";
- _logger.LogInformation("Tenant routing - Tenant: {Tenant}, Service: {Service}, Cluster: {Cluster}, Type: {Type}",
- tenantId, serviceName, route.ClusterId, routeType);
+ _logger.LogDebug("Tenant routing - Tenant: {Tenant}, Service: {Service}, Cluster: {Cluster}, Type: {Type}",
+ headerTenantId, serviceName, route.ClusterId, routeType);
await _next(context);
}
@@ -60,4 +93,4 @@ public class TenantRoutingMiddleware
var match = Regex.Match(path, @"/api/(\w+)/?");
return match.Success ? match.Groups[1].Value : string.Empty;
}
-}
+}
\ No newline at end of file