feat: add TenantRoutingTransform for multi-tenant routing (IMPL-10)

- Implement YARP RequestTransform to route requests based on tenant
- Extract tenant ID from JWT token (supports multiple claim types)
- Query tenant-specific destination first, fallback to default
- Add IMemoryCache for performance optimization (5min expiration)
- Update Platform packages to 1.0.14 to use GwDestination.TenantCode
This commit is contained in:
movingsam 2026-03-08 01:01:22 +08:00
parent 52eba07097
commit c7cf1d3738
3 changed files with 258 additions and 2 deletions

View File

@ -4,7 +4,7 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<!-- Fengling ServiceDiscovery Packages (from Gitea) --> <!-- Fengling ServiceDiscovery Packages (from Gitea) -->
<PackageVersion Include="Fengling.Platform.Infrastructure" Version="1.0.12" /> <PackageVersion Include="Fengling.Platform.Infrastructure" Version="1.0.14" />
<PackageVersion Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="10.0.2" /> <PackageVersion Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="10.0.2" />
<PackageVersion Include="Microsoft.AspNetCore.Http.Abstractions" Version="10.0.0" /> <PackageVersion Include="Microsoft.AspNetCore.Http.Abstractions" Version="10.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Design" Version="10.0.2" /> <PackageVersion Include="Microsoft.EntityFrameworkCore.Design" Version="10.0.2" />

View File

@ -4,12 +4,14 @@ using Microsoft.Extensions.Options;
using Serilog; using Serilog;
using Yarp.ReverseProxy.Configuration; using Yarp.ReverseProxy.Configuration;
using Yarp.ReverseProxy.LoadBalancing; using Yarp.ReverseProxy.LoadBalancing;
using Yarp.ReverseProxy.Transforms;
using YarpGateway.Config; using YarpGateway.Config;
using YarpGateway.Data; using YarpGateway.Data;
using YarpGateway.DynamicProxy; using YarpGateway.DynamicProxy;
using YarpGateway.LoadBalancing; using YarpGateway.LoadBalancing;
using YarpGateway.Middleware; using YarpGateway.Middleware;
using YarpGateway.Services; using YarpGateway.Services;
using YarpGateway.Transforms;
using StackExchange.Redis; using StackExchange.Redis;
var builder = WebApplication.CreateBuilder(args); var builder = WebApplication.CreateBuilder(args);
@ -132,7 +134,24 @@ builder.Services.AddCors(options =>
builder.Services.AddControllers(); builder.Services.AddControllers();
builder.Services.AddHttpForwarder(); builder.Services.AddHttpForwarder();
builder.Services.AddRouting(); builder.Services.AddRouting();
builder.Services.AddReverseProxy();
// 添加内存缓存
builder.Services.AddMemoryCache();
// 注册租户路由转换器
builder.Services.AddSingleton<TenantRoutingTransform>();
// 配置 YARP 反向代理
builder.Services.AddReverseProxy()
.LoadFromConfig(builder.Configuration.GetSection("ReverseProxy"))
.AddTransforms(transformBuilder =>
{
transformBuilder.AddRequestTransform(async context =>
{
var transform = context.HttpContext.RequestServices.GetRequiredService<TenantRoutingTransform>();
await transform.ApplyAsync(context);
});
});
var app = builder.Build(); var app = builder.Build();

View File

@ -0,0 +1,237 @@
using System.Security.Claims;
using Fengling.Platform.Domain.AggregatesModel.GatewayAggregate;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Memory;
using Yarp.ReverseProxy.Transforms;
using YarpGateway.Data;
namespace YarpGateway.Transforms;
/// <summary>
/// 租户路由转换器
///
/// 功能说明:
/// 1. 从 JWT Token 解析租户 ID
/// 2. 根据租户 ID 查找对应的目标端点
/// 3. 优先使用租户专属目标,不存在则回退到默认目标
/// 4. 使用内存缓存优化性能
/// </summary>
public class TenantRoutingTransform : RequestTransform
{
private readonly IDbContextFactory<GatewayDbContext> _dbContextFactory;
private readonly IMemoryCache _cache;
private readonly ILogger<TenantRoutingTransform> _logger;
// 缓存键前缀
private const string CacheKeyPrefix = "tenant_destination";
// 缓存过期时间5分钟
private static readonly TimeSpan CacheExpiration = TimeSpan.FromMinutes(5);
public TenantRoutingTransform(
IDbContextFactory<GatewayDbContext> dbContextFactory,
IMemoryCache cache,
ILogger<TenantRoutingTransform> logger)
{
_dbContextFactory = dbContextFactory;
_cache = cache;
_logger = logger;
}
public override async ValueTask ApplyAsync(RequestTransformContext context)
{
// 从 HttpContext 获取 ClusterId由 TenantRoutingMiddleware 设置)
if (!context.HttpContext.Items.TryGetValue("DynamicClusterId", out var clusterIdObj) || clusterIdObj is not string clusterId)
{
_logger.LogDebug("No DynamicClusterId found in HttpContext, skipping tenant routing");
return;
}
// 从 JWT Token 解析租户 ID
var tenantId = ExtractTenantFromJwt(context.HttpContext);
if (string.IsNullOrEmpty(tenantId))
{
_logger.LogDebug("No tenant ID found in JWT, using default destination for cluster {ClusterId}", clusterId);
// 没有租户ID时使用默认目标
await RouteToDefaultDestinationAsync(context, clusterId);
return;
}
_logger.LogDebug("Processing tenant routing for cluster {ClusterId}, tenant {TenantId}", clusterId, tenantId);
// 查找租户专属目标
var tenantDestination = await FindTenantDestinationAsync(clusterId, tenantId);
if (tenantDestination != null)
{
_logger.LogDebug("Using tenant-specific destination {DestinationId} ({Address}) for tenant {TenantId}",
tenantDestination.DestinationId, tenantDestination.Address, tenantId);
context.ProxyRequest.RequestUri = BuildRequestUri(context.DestinationPrefix, tenantDestination.Address);
}
else
{
// 回退到默认目标
_logger.LogDebug("No tenant-specific destination found for tenant {TenantId}, falling back to default", tenantId);
await RouteToDefaultDestinationAsync(context, clusterId);
}
}
/// <summary>
/// 从 JWT Token 中提取租户 ID
/// </summary>
private static string? ExtractTenantFromJwt(HttpContext httpContext)
{
if (httpContext.User?.Identity?.IsAuthenticated != true)
{
return null;
}
// 尝试从 claims 中获取租户 ID
var tenantClaim = httpContext.User.Claims.FirstOrDefault(c =>
c.Type.Equals("tenant", StringComparison.OrdinalIgnoreCase) ||
c.Type.Equals("tenant_id", StringComparison.OrdinalIgnoreCase) ||
c.Type.Equals("tenant_code", StringComparison.OrdinalIgnoreCase) ||
c.Type.Equals("tid", StringComparison.OrdinalIgnoreCase));
return tenantClaim?.Value;
}
/// <summary>
/// 查找租户专属的目标端点
/// </summary>
private async Task<GwDestination?> FindTenantDestinationAsync(string clusterId, string tenantId)
{
var cacheKey = $"{CacheKeyPrefix}:{clusterId}:{tenantId}";
// 尝试从缓存获取
if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
{
_logger.LogDebug("Cache hit for tenant destination: {CacheKey}", cacheKey);
return cachedDestination;
}
// 从数据库查询
await using var dbContext = await _dbContextFactory.CreateDbContextAsync();
var cluster = await dbContext.GwClusters
.AsNoTracking()
.FirstOrDefaultAsync(c => c.ClusterId == clusterId && c.Status == 1);
if (cluster == null)
{
_logger.LogWarning("Cluster {ClusterId} not found or disabled", clusterId);
return null;
}
// 查找租户专属目标TenantCode 匹配且状态正常)
var tenantDestination = cluster.Destinations
.FirstOrDefault(d =>
d.Status == 1 &&
!string.IsNullOrEmpty(d.TenantCode) &&
d.TenantCode.Equals(tenantId, StringComparison.OrdinalIgnoreCase));
if (tenantDestination != null)
{
// 缓存结果
var cacheOptions = new MemoryCacheEntryOptions()
.SetAbsoluteExpiration(CacheExpiration)
.SetSize(1);
_cache.Set(cacheKey, tenantDestination, cacheOptions);
_logger.LogDebug("Cached tenant destination for key: {CacheKey}", cacheKey);
}
return tenantDestination;
}
/// <summary>
/// 查找默认目标端点TenantCode 为空)
/// </summary>
private async Task<GwDestination?> FindDefaultDestinationAsync(string clusterId)
{
var cacheKey = $"{CacheKeyPrefix}:{clusterId}:default";
// 尝试从缓存获取
if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
{
_logger.LogDebug("Cache hit for default destination: {CacheKey}", cacheKey);
return cachedDestination;
}
// 从数据库查询
await using var dbContext = await _dbContextFactory.CreateDbContextAsync();
var cluster = await dbContext.GwClusters
.AsNoTracking()
.FirstOrDefaultAsync(c => c.ClusterId == clusterId && c.Status == 1);
if (cluster == null)
{
_logger.LogWarning("Cluster {ClusterId} not found or disabled", clusterId);
return null;
}
// 查找默认目标TenantCode 为空或 null 且状态正常)
var defaultDestination = cluster.Destinations
.FirstOrDefault(d =>
d.Status == 1 &&
string.IsNullOrEmpty(d.TenantCode));
if (defaultDestination != null)
{
// 缓存结果
var cacheOptions = new MemoryCacheEntryOptions()
.SetAbsoluteExpiration(CacheExpiration)
.SetSize(1);
_cache.Set(cacheKey, defaultDestination, cacheOptions);
_logger.LogDebug("Cached default destination for key: {CacheKey}", cacheKey);
}
return defaultDestination;
}
/// <summary>
/// 路由到默认目标
/// </summary>
private async ValueTask RouteToDefaultDestinationAsync(RequestTransformContext context, string clusterId)
{
var defaultDestination = await FindDefaultDestinationAsync(clusterId);
if (defaultDestination != null)
{
_logger.LogDebug("Using default destination {DestinationId} ({Address}) for cluster {ClusterId}",
defaultDestination.DestinationId, defaultDestination.Address, clusterId);
context.ProxyRequest.RequestUri = BuildRequestUri(context.DestinationPrefix, defaultDestination.Address);
}
else
{
_logger.LogWarning("No default destination found for cluster {ClusterId}", clusterId);
}
}
/// <summary>
/// 构建完整的请求 URI
/// </summary>
private static Uri BuildRequestUri(string? destinationPrefix, string address)
{
// 如果 address 已经是完整 URI直接使用
if (Uri.TryCreate(address, UriKind.Absolute, out var absoluteUri))
{
return absoluteUri;
}
// 否则,与 destinationPrefix 组合
if (!string.IsNullOrEmpty(destinationPrefix))
{
var baseUri = destinationPrefix.TrimEnd('/');
var path = address.StartsWith('/') ? address : "/" + address;
return new Uri(baseUri + path);
}
// 如果都无法构建,抛出异常
throw new InvalidOperationException($"Cannot build valid URI from prefix '{destinationPrefix}' and address '{address}'");
}
}