From c7cf1d3738357facec903170abb43f0511d86398 Mon Sep 17 00:00:00 2001 From: movingsam Date: Sun, 8 Mar 2026 01:01:22 +0800 Subject: [PATCH] feat: add TenantRoutingTransform for multi-tenant routing (IMPL-10) - Implement YARP RequestTransform to route requests based on tenant - Extract tenant ID from JWT token (supports multiple claim types) - Query tenant-specific destination first, fallback to default - Add IMemoryCache for performance optimization (5min expiration) - Update Platform packages to 1.0.14 to use GwDestination.TenantCode --- src/yarpgateway/Directory.Packages.props | 2 +- src/yarpgateway/Program.cs | 21 +- .../Transforms/TenantRoutingTransform.cs | 237 ++++++++++++++++++ 3 files changed, 258 insertions(+), 2 deletions(-) create mode 100644 src/yarpgateway/Transforms/TenantRoutingTransform.cs diff --git a/src/yarpgateway/Directory.Packages.props b/src/yarpgateway/Directory.Packages.props index 50be47e..9bb13b0 100644 --- a/src/yarpgateway/Directory.Packages.props +++ b/src/yarpgateway/Directory.Packages.props @@ -4,7 +4,7 @@ - + diff --git a/src/yarpgateway/Program.cs b/src/yarpgateway/Program.cs index ffe2842..566e36e 100644 --- a/src/yarpgateway/Program.cs +++ b/src/yarpgateway/Program.cs @@ -4,12 +4,14 @@ using Microsoft.Extensions.Options; using Serilog; using Yarp.ReverseProxy.Configuration; using Yarp.ReverseProxy.LoadBalancing; +using Yarp.ReverseProxy.Transforms; using YarpGateway.Config; using YarpGateway.Data; using YarpGateway.DynamicProxy; using YarpGateway.LoadBalancing; using YarpGateway.Middleware; using YarpGateway.Services; +using YarpGateway.Transforms; using StackExchange.Redis; var builder = WebApplication.CreateBuilder(args); @@ -132,7 +134,24 @@ builder.Services.AddCors(options => builder.Services.AddControllers(); builder.Services.AddHttpForwarder(); builder.Services.AddRouting(); -builder.Services.AddReverseProxy(); + +// 添加内存缓存 +builder.Services.AddMemoryCache(); + +// 注册租户路由转换器 +builder.Services.AddSingleton(); + +// 配置 YARP 反向代理 +builder.Services.AddReverseProxy() + .LoadFromConfig(builder.Configuration.GetSection("ReverseProxy")) + .AddTransforms(transformBuilder => + { + transformBuilder.AddRequestTransform(async context => + { + var transform = context.HttpContext.RequestServices.GetRequiredService(); + await transform.ApplyAsync(context); + }); + }); var app = builder.Build(); diff --git a/src/yarpgateway/Transforms/TenantRoutingTransform.cs b/src/yarpgateway/Transforms/TenantRoutingTransform.cs new file mode 100644 index 0000000..857085d --- /dev/null +++ b/src/yarpgateway/Transforms/TenantRoutingTransform.cs @@ -0,0 +1,237 @@ +using System.Security.Claims; +using Fengling.Platform.Domain.AggregatesModel.GatewayAggregate; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Caching.Memory; +using Yarp.ReverseProxy.Transforms; +using YarpGateway.Data; + +namespace YarpGateway.Transforms; + +/// +/// 租户路由转换器 +/// +/// 功能说明: +/// 1. 从 JWT Token 解析租户 ID +/// 2. 根据租户 ID 查找对应的目标端点 +/// 3. 优先使用租户专属目标,不存在则回退到默认目标 +/// 4. 使用内存缓存优化性能 +/// +public class TenantRoutingTransform : RequestTransform +{ + private readonly IDbContextFactory _dbContextFactory; + private readonly IMemoryCache _cache; + private readonly ILogger _logger; + + // 缓存键前缀 + private const string CacheKeyPrefix = "tenant_destination"; + // 缓存过期时间:5分钟 + private static readonly TimeSpan CacheExpiration = TimeSpan.FromMinutes(5); + + public TenantRoutingTransform( + IDbContextFactory dbContextFactory, + IMemoryCache cache, + ILogger logger) + { + _dbContextFactory = dbContextFactory; + _cache = cache; + _logger = logger; + } + + public override async ValueTask ApplyAsync(RequestTransformContext context) + { + // 从 HttpContext 获取 ClusterId(由 TenantRoutingMiddleware 设置) + if (!context.HttpContext.Items.TryGetValue("DynamicClusterId", out var clusterIdObj) || clusterIdObj is not string clusterId) + { + _logger.LogDebug("No DynamicClusterId found in HttpContext, skipping tenant routing"); + return; + } + + // 从 JWT Token 解析租户 ID + var tenantId = ExtractTenantFromJwt(context.HttpContext); + + if (string.IsNullOrEmpty(tenantId)) + { + _logger.LogDebug("No tenant ID found in JWT, using default destination for cluster {ClusterId}", clusterId); + // 没有租户ID时,使用默认目标 + await RouteToDefaultDestinationAsync(context, clusterId); + return; + } + + _logger.LogDebug("Processing tenant routing for cluster {ClusterId}, tenant {TenantId}", clusterId, tenantId); + + // 查找租户专属目标 + var tenantDestination = await FindTenantDestinationAsync(clusterId, tenantId); + + if (tenantDestination != null) + { + _logger.LogDebug("Using tenant-specific destination {DestinationId} ({Address}) for tenant {TenantId}", + tenantDestination.DestinationId, tenantDestination.Address, tenantId); + + context.ProxyRequest.RequestUri = BuildRequestUri(context.DestinationPrefix, tenantDestination.Address); + } + else + { + // 回退到默认目标 + _logger.LogDebug("No tenant-specific destination found for tenant {TenantId}, falling back to default", tenantId); + await RouteToDefaultDestinationAsync(context, clusterId); + } + } + + /// + /// 从 JWT Token 中提取租户 ID + /// + private static string? ExtractTenantFromJwt(HttpContext httpContext) + { + if (httpContext.User?.Identity?.IsAuthenticated != true) + { + return null; + } + + // 尝试从 claims 中获取租户 ID + var tenantClaim = httpContext.User.Claims.FirstOrDefault(c => + c.Type.Equals("tenant", StringComparison.OrdinalIgnoreCase) || + c.Type.Equals("tenant_id", StringComparison.OrdinalIgnoreCase) || + c.Type.Equals("tenant_code", StringComparison.OrdinalIgnoreCase) || + c.Type.Equals("tid", StringComparison.OrdinalIgnoreCase)); + + return tenantClaim?.Value; + } + + /// + /// 查找租户专属的目标端点 + /// + private async Task FindTenantDestinationAsync(string clusterId, string tenantId) + { + var cacheKey = $"{CacheKeyPrefix}:{clusterId}:{tenantId}"; + + // 尝试从缓存获取 + if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination)) + { + _logger.LogDebug("Cache hit for tenant destination: {CacheKey}", cacheKey); + return cachedDestination; + } + + // 从数据库查询 + await using var dbContext = await _dbContextFactory.CreateDbContextAsync(); + + var cluster = await dbContext.GwClusters + .AsNoTracking() + .FirstOrDefaultAsync(c => c.ClusterId == clusterId && c.Status == 1); + + if (cluster == null) + { + _logger.LogWarning("Cluster {ClusterId} not found or disabled", clusterId); + return null; + } + + // 查找租户专属目标(TenantCode 匹配且状态正常) + var tenantDestination = cluster.Destinations + .FirstOrDefault(d => + d.Status == 1 && + !string.IsNullOrEmpty(d.TenantCode) && + d.TenantCode.Equals(tenantId, StringComparison.OrdinalIgnoreCase)); + + if (tenantDestination != null) + { + // 缓存结果 + var cacheOptions = new MemoryCacheEntryOptions() + .SetAbsoluteExpiration(CacheExpiration) + .SetSize(1); + + _cache.Set(cacheKey, tenantDestination, cacheOptions); + _logger.LogDebug("Cached tenant destination for key: {CacheKey}", cacheKey); + } + + return tenantDestination; + } + + /// + /// 查找默认目标端点(TenantCode 为空) + /// + private async Task FindDefaultDestinationAsync(string clusterId) + { + var cacheKey = $"{CacheKeyPrefix}:{clusterId}:default"; + + // 尝试从缓存获取 + if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination)) + { + _logger.LogDebug("Cache hit for default destination: {CacheKey}", cacheKey); + return cachedDestination; + } + + // 从数据库查询 + await using var dbContext = await _dbContextFactory.CreateDbContextAsync(); + + var cluster = await dbContext.GwClusters + .AsNoTracking() + .FirstOrDefaultAsync(c => c.ClusterId == clusterId && c.Status == 1); + + if (cluster == null) + { + _logger.LogWarning("Cluster {ClusterId} not found or disabled", clusterId); + return null; + } + + // 查找默认目标(TenantCode 为空或 null 且状态正常) + var defaultDestination = cluster.Destinations + .FirstOrDefault(d => + d.Status == 1 && + string.IsNullOrEmpty(d.TenantCode)); + + if (defaultDestination != null) + { + // 缓存结果 + var cacheOptions = new MemoryCacheEntryOptions() + .SetAbsoluteExpiration(CacheExpiration) + .SetSize(1); + + _cache.Set(cacheKey, defaultDestination, cacheOptions); + _logger.LogDebug("Cached default destination for key: {CacheKey}", cacheKey); + } + + return defaultDestination; + } + + /// + /// 路由到默认目标 + /// + private async ValueTask RouteToDefaultDestinationAsync(RequestTransformContext context, string clusterId) + { + var defaultDestination = await FindDefaultDestinationAsync(clusterId); + + if (defaultDestination != null) + { + _logger.LogDebug("Using default destination {DestinationId} ({Address}) for cluster {ClusterId}", + defaultDestination.DestinationId, defaultDestination.Address, clusterId); + + context.ProxyRequest.RequestUri = BuildRequestUri(context.DestinationPrefix, defaultDestination.Address); + } + else + { + _logger.LogWarning("No default destination found for cluster {ClusterId}", clusterId); + } + } + + /// + /// 构建完整的请求 URI + /// + private static Uri BuildRequestUri(string? destinationPrefix, string address) + { + // 如果 address 已经是完整 URI,直接使用 + if (Uri.TryCreate(address, UriKind.Absolute, out var absoluteUri)) + { + return absoluteUri; + } + + // 否则,与 destinationPrefix 组合 + if (!string.IsNullOrEmpty(destinationPrefix)) + { + var baseUri = destinationPrefix.TrimEnd('/'); + var path = address.StartsWith('/') ? address : "/" + address; + return new Uri(baseUri + path); + } + + // 如果都无法构建,抛出异常 + throw new InvalidOperationException($"Cannot build valid URI from prefix '{destinationPrefix}' and address '{address}'"); + } +}