diff --git a/src/yarpgateway/Middleware/TenantRoutingMiddleware.cs b/src/yarpgateway/Middleware/TenantRoutingMiddleware.cs
index e0bdd77..a5f0c38 100644
--- a/src/yarpgateway/Middleware/TenantRoutingMiddleware.cs
+++ b/src/yarpgateway/Middleware/TenantRoutingMiddleware.cs
@@ -1,5 +1,4 @@
using System.Security.Claims;
-using Microsoft.Extensions.Options;
using System.Text.RegularExpressions;
using YarpGateway.Services;
@@ -8,10 +7,14 @@ namespace YarpGateway.Middleware;
///
/// 租户路由中间件
///
-/// 安全说明:
+/// 职责:
/// 1. 验证 X-Tenant-Id Header 与 JWT 中的 tenant claim 一致
/// 2. 防止租户隔离绕过攻击
-/// 3. 只有验证通过后才进行路由查找
+/// 3. 设置 DynamicClusterId 供 TenantRoutingTransform 使用
+/// 4. 设置 TenantId 供 TenantRoutingTransform 使用
+///
+/// 协作:
+/// - TenantRoutingTransform 负责从数据库查询 Destination 并设置 ProxyRequest.RequestUri
///
public class TenantRoutingMiddleware
{
@@ -31,6 +34,7 @@ public class TenantRoutingMiddleware
public async Task InvokeAsync(HttpContext context)
{
+ // 1. 从 Header 获取 TenantId
var headerTenantId = context.Request.Headers["X-Tenant-Id"].FirstOrDefault();
if (string.IsNullOrEmpty(headerTenantId))
@@ -39,7 +43,7 @@ public class TenantRoutingMiddleware
return;
}
- // 安全验证:检查 Header 中的租户 ID 是否与 JWT 一致
+ // 2. JWT Token 解析与验证:检查 Header 中的租户 ID 是否与 JWT 一致
if (context.User?.Identity?.IsAuthenticated == true)
{
var jwtTenantId = context.User.Claims
@@ -62,6 +66,7 @@ public class TenantRoutingMiddleware
}
}
+ // 3. 从路径提取服务名
var path = context.Request.Path.Value ?? string.Empty;
var serviceName = ExtractServiceName(path);
@@ -71,6 +76,7 @@ public class TenantRoutingMiddleware
return;
}
+ // 4. 从 RouteCache 获取路由信息
var route = _routeCache.GetRoute(headerTenantId, serviceName);
if (route == null)
{
@@ -79,18 +85,25 @@ public class TenantRoutingMiddleware
return;
}
+ // 5. 设置 ClusterId 和 TenantId(Transform 会使用这些信息选择 Destination)
context.Items["DynamicClusterId"] = route.ClusterId;
+ context.Items["TenantId"] = headerTenantId; // 供 Transform 使用
var routeType = route.IsGlobal ? "global" : "tenant-specific";
_logger.LogDebug("Tenant routing - Tenant: {Tenant}, Service: {Service}, Cluster: {Cluster}, Type: {Type}",
headerTenantId, serviceName, route.ClusterId, routeType);
+ // 6. 继续执行,由 TenantRoutingTransform 处理 Destination 选择
await _next(context);
}
- private string ExtractServiceName(string path)
+ ///
+ /// 从请求路径提取服务名
+ /// 例如:/api/user-service/users -> user-service
+ ///
+ private static string ExtractServiceName(string path)
{
- var match = Regex.Match(path, @"/api/(\w+)/?");
+ var match = Regex.Match(path, @"/api/([^/]+)/?");
return match.Success ? match.Groups[1].Value : string.Empty;
}
-}
\ No newline at end of file
+}
diff --git a/src/yarpgateway/Transforms/TenantRoutingTransform.cs b/src/yarpgateway/Transforms/TenantRoutingTransform.cs
index 857085d..d978453 100644
--- a/src/yarpgateway/Transforms/TenantRoutingTransform.cs
+++ b/src/yarpgateway/Transforms/TenantRoutingTransform.cs
@@ -46,8 +46,9 @@ public class TenantRoutingTransform : RequestTransform
return;
}
- // 从 JWT Token 解析租户 ID
- var tenantId = ExtractTenantFromJwt(context.HttpContext);
+ // 优先从 HttpContext.Items 获取 TenantId(由 TenantRoutingMiddleware 设置)
+ // 如果 Middleware 没有设置,则从 JWT Token 解析
+ var tenantId = ExtractTenantId(context.HttpContext);
if (string.IsNullOrEmpty(tenantId))
{
@@ -77,6 +78,23 @@ public class TenantRoutingTransform : RequestTransform
}
}
+ ///
+ /// 提取租户 ID
+ /// 优先从 HttpContext.Items 获取(由 TenantRoutingMiddleware 设置)
+ /// 如果不存在,则从 JWT Token 解析
+ ///
+ private static string? ExtractTenantId(HttpContext httpContext)
+ {
+ // 1. 优先从 HttpContext.Items 获取(Middleware 已验证过)
+ if (httpContext.Items.TryGetValue("TenantId", out var tenantIdObj) && tenantIdObj is string tenantId)
+ {
+ return tenantId;
+ }
+
+ // 2. 从 JWT Token 解析(作为回退)
+ return ExtractTenantFromJwt(httpContext);
+ }
+
///
/// 从 JWT Token 中提取租户 ID
///
diff --git a/tests/YarpGateway.Tests/Unit/Middleware/TenantRoutingMiddlewareTests.cs b/tests/YarpGateway.Tests/Unit/Middleware/TenantRoutingMiddlewareTests.cs
index 9c50956..d6576fd 100644
--- a/tests/YarpGateway.Tests/Unit/Middleware/TenantRoutingMiddlewareTests.cs
+++ b/tests/YarpGateway.Tests/Unit/Middleware/TenantRoutingMiddlewareTests.cs
@@ -9,6 +9,15 @@ using YarpGateway.Services;
namespace YarpGateway.Tests.Unit.Middleware;
+///
+/// TenantRoutingMiddleware 单元测试
+///
+/// 测试范围:
+/// 1. JWT 验证和 TenantId 验证
+/// 2. ClusterId 设置(供 Transform 使用)
+/// 3. TenantId 设置(供 Transform 使用)
+/// 4. 职责分离验证(Middleware 不负责 Destination 选择)
+///
public class TenantRoutingMiddlewareTests
{
private readonly Mock _routeCacheMock;
@@ -50,13 +59,13 @@ public class TenantRoutingMiddlewareTests
return context;
}
- private DefaultHttpContext CreateAuthenticatedContext(string tenantId, string headerTenantId)
+ private DefaultHttpContext CreateAuthenticatedContext(string jwtTenantId, string headerTenantId)
{
var context = CreateContext(headerTenantId);
var claims = new List
{
- new Claim("tenant", tenantId),
+ new Claim("tenant", jwtTenantId),
new Claim(ClaimTypes.NameIdentifier, "user-1")
};
@@ -66,6 +75,8 @@ public class TenantRoutingMiddlewareTests
return context;
}
+ #region JWT 验证测试
+
[Fact]
public async Task InvokeAsync_WithoutTenantHeader_ShouldCallNext()
{
@@ -81,51 +92,6 @@ public class TenantRoutingMiddlewareTests
nextCalled.Should().BeTrue();
}
- [Fact]
- public async Task InvokeAsync_WithValidTenantAndRoute_ShouldSetClusterId()
- {
- // Arrange
- var routeInfo = new RouteInfo
- {
- Id = "1",
- ClusterId = "cluster-user-service",
- PathPattern = "/api/user-service/**",
- Priority = 1,
- IsGlobal = false
- };
-
- _routeCacheMock
- .Setup(x => x.GetRoute("tenant-1", "user-service"))
- .Returns(routeInfo);
-
- var middleware = CreateMiddleware();
- var context = CreateContext(tenantId: "tenant-1");
-
- // Act
- await middleware.InvokeAsync(context);
-
- // Assert
- context.Items["DynamicClusterId"].Should().Be("cluster-user-service");
- }
-
- [Fact]
- public async Task InvokeAsync_WhenRouteNotFound_ShouldCallNext()
- {
- // Arrange
- _routeCacheMock
- .Setup(x => x.GetRoute(It.IsAny(), It.IsAny()))
- .Returns((RouteInfo?)null);
-
- var middleware = CreateMiddleware();
- var context = CreateContext(tenantId: "tenant-1");
-
- // Act
- await middleware.InvokeAsync(context);
-
- // Assert - should not throw, just continue
- context.Items.Should().NotContainKey("DynamicClusterId");
- }
-
[Fact]
public async Task InvokeAsync_WithTenantIdMismatch_ShouldReturn403()
{
@@ -194,38 +160,20 @@ public class TenantRoutingMiddlewareTests
context.Items["DynamicClusterId"].Should().Be("cluster-1");
}
- [Theory]
- [InlineData("/api/user-service/users", "user-service")]
- [InlineData("/api/order-service/orders", "order-service")]
- [InlineData("/api/payment/", "payment")]
- [InlineData("/api/auth", "auth")]
- [InlineData("/other/path", "")]
- public async Task InvokeAsync_ShouldExtractServiceNameFromPath(string path, string expectedServiceName)
- {
- // Arrange
- var middleware = CreateMiddleware();
- var context = CreateContext(tenantId: "tenant-1", path: path);
+ #endregion
- // Act
- await middleware.InvokeAsync(context);
-
- // Assert
- if (!string.IsNullOrEmpty(expectedServiceName))
- {
- _routeCacheMock.Verify(
- x => x.GetRoute("tenant-1", expectedServiceName),
- Times.Once);
- }
- }
+ #region ClusterId 和 TenantId 设置测试
[Fact]
- public async Task InvokeAsync_WithTenantRoute_ShouldLogAsTenantSpecific()
+ public async Task InvokeAsync_WithValidTenantAndRoute_ShouldSetClusterId()
{
// Arrange
var routeInfo = new RouteInfo
{
Id = "1",
- ClusterId = "cluster-1",
+ ClusterId = "cluster-user-service",
+ PathPattern = "/api/user-service/**",
+ Priority = 1,
IsGlobal = false
};
@@ -239,19 +187,21 @@ public class TenantRoutingMiddlewareTests
// Act
await middleware.InvokeAsync(context);
- // Assert - just verify it completes without error
- context.Items["DynamicClusterId"].Should().Be("cluster-1");
+ // Assert
+ context.Items["DynamicClusterId"].Should().Be("cluster-user-service");
}
[Fact]
- public async Task InvokeAsync_WithGlobalRoute_ShouldLogAsGlobal()
+ public async Task InvokeAsync_WithValidTenantAndRoute_ShouldSetTenantId()
{
// Arrange
var routeInfo = new RouteInfo
{
Id = "1",
- ClusterId = "global-cluster",
- IsGlobal = true
+ ClusterId = "cluster-user-service",
+ PathPattern = "/api/user-service/**",
+ Priority = 1,
+ IsGlobal = false
};
_routeCacheMock
@@ -264,24 +214,59 @@ public class TenantRoutingMiddlewareTests
// Act
await middleware.InvokeAsync(context);
- // Assert
- context.Items["DynamicClusterId"].Should().Be("global-cluster");
+ // Assert - TenantId 应该被设置供 Transform 使用
+ context.Items["TenantId"].Should().Be("tenant-1");
}
[Fact]
- public async Task InvokeAsync_WithEmptyPath_ShouldCallNext()
+ public async Task InvokeAsync_WithValidRoute_ShouldSetBothClusterIdAndTenantId()
{
// Arrange
+ var routeInfo = new RouteInfo
+ {
+ Id = "1",
+ ClusterId = "cluster-order-service",
+ IsGlobal = false
+ };
+
+ _routeCacheMock
+ .Setup(x => x.GetRoute("tenant-abc", "order-service"))
+ .Returns(routeInfo);
+
var middleware = CreateMiddleware();
- var context = CreateContext(tenantId: "tenant-1", path: "");
+ var context = CreateContext(tenantId: "tenant-abc", path: "/api/order-service/orders");
// Act
await middleware.InvokeAsync(context);
- // Assert - should not try to extract service name
- _routeCacheMock.Verify(
- x => x.GetRoute(It.IsAny(), It.IsAny()),
- Times.Never);
+ // Assert
+ context.Items.Should().ContainKey("DynamicClusterId");
+ context.Items.Should().ContainKey("TenantId");
+ context.Items["DynamicClusterId"].Should().Be("cluster-order-service");
+ context.Items["TenantId"].Should().Be("tenant-abc");
+ }
+
+ #endregion
+
+ #region 路由查找测试
+
+ [Fact]
+ public async Task InvokeAsync_WhenRouteNotFound_ShouldCallNext()
+ {
+ // Arrange
+ _routeCacheMock
+ .Setup(x => x.GetRoute(It.IsAny(), It.IsAny()))
+ .Returns((RouteInfo?)null);
+
+ var middleware = CreateMiddleware();
+ var context = CreateContext(tenantId: "tenant-1");
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert - should not throw, just continue
+ context.Items.Should().NotContainKey("DynamicClusterId");
+ context.Items.Should().NotContainKey("TenantId");
}
[Fact]
@@ -310,4 +295,253 @@ public class TenantRoutingMiddlewareTests
// Assert
context.Items["DynamicClusterId"].Should().Be("tenant-specific-cluster");
}
+
+ [Theory]
+ [InlineData("/api/user-service/users", "user-service")]
+ [InlineData("/api/order-service/orders", "order-service")]
+ [InlineData("/api/payment-service/", "payment-service")]
+ [InlineData("/api/auth-service", "auth-service")]
+ [InlineData("/api/user_service", "user_service")]
+ [InlineData("/api/UserService123", "UserService123")]
+ [InlineData("/other/path", "")]
+ public async Task InvokeAsync_ShouldExtractServiceNameFromPath(string path, string expectedServiceName)
+ {
+ // Arrange
+ var middleware = CreateMiddleware();
+ var context = CreateContext(tenantId: "tenant-1", path: path);
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert
+ if (!string.IsNullOrEmpty(expectedServiceName))
+ {
+ _routeCacheMock.Verify(
+ x => x.GetRoute("tenant-1", expectedServiceName),
+ Times.Once);
+ }
+ }
+
+ [Fact]
+ public async Task InvokeAsync_WithEmptyPath_ShouldCallNext()
+ {
+ // Arrange
+ var middleware = CreateMiddleware();
+ var context = CreateContext(tenantId: "tenant-1", path: "");
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert - should not try to extract service name
+ _routeCacheMock.Verify(
+ x => x.GetRoute(It.IsAny(), It.IsAny()),
+ Times.Never);
+ }
+
+ #endregion
+
+ #region 职责分离测试
+
+ [Fact]
+ public async Task InvokeAsync_ShouldNotSetDestinationUri()
+ {
+ // Arrange - Middleware 不应该设置 Destination 相关信息
+ var routeInfo = new RouteInfo
+ {
+ Id = "1",
+ ClusterId = "cluster-1",
+ IsGlobal = false
+ };
+
+ _routeCacheMock
+ .Setup(x => x.GetRoute("tenant-1", "user-service"))
+ .Returns(routeInfo);
+
+ var middleware = CreateMiddleware();
+ var context = CreateContext(tenantId: "tenant-1");
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert - Middleware 不应该设置任何 Destination 相关的 Items
+ // 这些应该由 TenantRoutingTransform 处理
+ context.Items.Should().NotContainKey("DestinationUri");
+ context.Items.Should().NotContainKey("TargetAddress");
+ }
+
+ [Fact]
+ public async Task InvokeAsync_ShouldOnlySetClusterIdAndTenantId()
+ {
+ // Arrange - 验证 Middleware 只负责设置 ClusterId 和 TenantId
+ var routeInfo = new RouteInfo
+ {
+ Id = "1",
+ ClusterId = "cluster-payment",
+ IsGlobal = false
+ };
+
+ _routeCacheMock
+ .Setup(x => x.GetRoute("tenant-x", "payment"))
+ .Returns(routeInfo);
+
+ var middleware = CreateMiddleware();
+ var context = CreateContext(tenantId: "tenant-x", path: "/api/payment/process");
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert - 只设置 ClusterId 和 TenantId
+ context.Items.Count.Should().Be(2);
+ context.Items.Should().ContainKey("DynamicClusterId");
+ context.Items.Should().ContainKey("TenantId");
+ }
+
+ [Fact]
+ public async Task InvokeAsync_WithGlobalRoute_ShouldSetClusterIdWithoutDestinationSelection()
+ {
+ // Arrange - 全局路由也应该只设置 ClusterId,不处理 Destination
+ var routeInfo = new RouteInfo
+ {
+ Id = "1",
+ ClusterId = "global-cluster",
+ IsGlobal = true
+ };
+
+ _routeCacheMock
+ .Setup(x => x.GetRoute("tenant-1", "user-service"))
+ .Returns(routeInfo);
+
+ var middleware = CreateMiddleware();
+ var context = CreateContext(tenantId: "tenant-1");
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert - 只设置 ClusterId,Destination 选择由 Transform 处理
+ context.Items["DynamicClusterId"].Should().Be("global-cluster");
+ context.Items["TenantId"].Should().Be("tenant-1");
+ }
+
+ #endregion
+
+ #region 与 Transform 协作测试
+
+ [Fact]
+ public async Task InvokeAsync_ShouldSetItemsForTransformConsumption()
+ {
+ // Arrange - 模拟 Transform 从 Items 读取 ClusterId 和 TenantId
+ var routeInfo = new RouteInfo
+ {
+ Id = "1",
+ ClusterId = "cluster-inventory",
+ IsGlobal = false
+ };
+
+ _routeCacheMock
+ .Setup(x => x.GetRoute("tenant-cooperation", "inventory"))
+ .Returns(routeInfo);
+
+ string? capturedClusterId = null;
+ string? capturedTenantId = null;
+
+ var nextDelegate = new RequestDelegate(ctx =>
+ {
+ // 模拟 Transform 的行为:从 Items 读取
+ capturedClusterId = ctx.Items.TryGetValue("DynamicClusterId", out var c) ? c as string : null;
+ capturedTenantId = ctx.Items.TryGetValue("TenantId", out var t) ? t as string : null;
+ return Task.CompletedTask;
+ });
+
+ var middleware = CreateMiddleware(next: nextDelegate);
+ var context = CreateContext(tenantId: "tenant-cooperation", path: "/api/inventory/items");
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert - 验证 Transform 可以正确读取 Middleware 设置的值
+ capturedClusterId.Should().Be("cluster-inventory");
+ capturedTenantId.Should().Be("tenant-cooperation");
+ }
+
+ [Fact]
+ public async Task InvokeAsync_WhenRouteNotFound_ShouldNotSetItems()
+ {
+ // Arrange
+ _routeCacheMock
+ .Setup(x => x.GetRoute("unknown-tenant", "user-service"))
+ .Returns((RouteInfo?)null);
+
+ var nextDelegate = new RequestDelegate(ctx =>
+ {
+ // 模拟 Transform 检查 Items
+ ctx.Items.ContainsKey("DynamicClusterId").Should().BeFalse();
+ ctx.Items.ContainsKey("TenantId").Should().BeFalse();
+ return Task.CompletedTask;
+ });
+
+ var middleware = CreateMiddleware(next: nextDelegate);
+ var context = CreateContext(tenantId: "unknown-tenant");
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert - Transform 应该跳过处理(因为没有 ClusterId)
+ }
+
+ #endregion
+
+ #region 边界条件测试
+
+ [Fact]
+ public async Task InvokeAsync_WithTenantRoute_ShouldWorkCorrectly()
+ {
+ // Arrange
+ var routeInfo = new RouteInfo
+ {
+ Id = "1",
+ ClusterId = "cluster-1",
+ IsGlobal = false
+ };
+
+ _routeCacheMock
+ .Setup(x => x.GetRoute("tenant-1", "user-service"))
+ .Returns(routeInfo);
+
+ var middleware = CreateMiddleware();
+ var context = CreateContext(tenantId: "tenant-1");
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert - just verify it completes without error
+ context.Items["DynamicClusterId"].Should().Be("cluster-1");
+ }
+
+ [Fact]
+ public async Task InvokeAsync_WithSpecialCharactersInTenantId_ShouldSetTenantId()
+ {
+ // Arrange - 测试特殊字符的 TenantId
+ var specialTenantId = "tenant-123_abc.XYZ";
+ var routeInfo = new RouteInfo
+ {
+ Id = "1",
+ ClusterId = "cluster-1",
+ IsGlobal = false
+ };
+
+ _routeCacheMock
+ .Setup(x => x.GetRoute(specialTenantId, "user-service"))
+ .Returns(routeInfo);
+
+ var middleware = CreateMiddleware();
+ var context = CreateContext(tenantId: specialTenantId);
+
+ // Act
+ await middleware.InvokeAsync(context);
+
+ // Assert
+ context.Items["TenantId"].Should().Be(specialTenantId);
+ }
+
+ #endregion
}