using System.Security.Claims;
using Fengling.Platform.Domain.AggregatesModel.GatewayAggregate;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Memory;
using Yarp.ReverseProxy.Forwarder;
using Yarp.ReverseProxy.Transforms;
using YarpGateway.Data;

namespace YarpGateway.Transforms;

/// <summary>
/// Tenant routing transform.
/// </summary>
/// <remarks>
/// <list type="number">
/// <item>Reads the tenant ID from the authenticated user's (JWT) claims.</item>
/// <item>Looks up the destination configured for that tenant in the cluster selected by TenantRoutingMiddleware.</item>
/// <item>Prefers a tenant-specific destination and falls back to the cluster's default destination.</item>
/// <item>Caches lookups, including "not found" results, in <see cref="IMemoryCache"/>.</item>
/// </list>
/// The proxied request keeps its original path and query string; only the destination
/// prefix (scheme, host, optional base path) is replaced by the selected destination address.
/// </remarks>
public class TenantRoutingTransform : RequestTransform
{
    // Key under which TenantRoutingMiddleware stores the target cluster ID in HttpContext.Items.
    private const string ClusterIdItemKey = "DynamicClusterId";

    // Cache key prefix.
    private const string CacheKeyPrefix = "tenant_destination";

    // Status value treated as "enabled" for both clusters and destinations.
    private const int ActiveStatus = 1;

    // Absolute cache lifetime (5 minutes): configuration changes, including newly added
    // destinations, become visible after at most this long.
    private static readonly TimeSpan CacheExpiration = TimeSpan.FromMinutes(5);

    // Claim types checked for the tenant ID, in priority order (matched case-insensitively).
    private static readonly string[] TenantClaimTypes = { "tenant", "tenant_id", "tenant_code", "tid" };

    private readonly IDbContextFactory _dbContextFactory;
    private readonly IMemoryCache _cache;
    private readonly ILogger _logger;

    /// <summary>
    /// Creates the transform.
    /// </summary>
    /// <param name="dbContextFactory">Factory used to create a short-lived DbContext per lookup.</param>
    /// <param name="cache">Memory cache holding resolved destinations.</param>
    /// <param name="logger">Logger for routing diagnostics.</param>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    public TenantRoutingTransform(
        IDbContextFactory dbContextFactory,
        IMemoryCache cache,
        ILogger logger)
    {
        ArgumentNullException.ThrowIfNull(dbContextFactory);
        ArgumentNullException.ThrowIfNull(cache);
        ArgumentNullException.ThrowIfNull(logger);

        _dbContextFactory = dbContextFactory;
        _cache = cache;
        _logger = logger;
    }

    /// <summary>
    /// Rewrites the proxy request URI to the tenant-specific destination of the current cluster,
    /// or to the cluster's default destination when the tenant has none (or no tenant is present).
    /// Does nothing when no cluster ID was placed in <c>HttpContext.Items</c>.
    /// </summary>
    /// <param name="context">The YARP request transform context.</param>
    public override async ValueTask ApplyAsync(RequestTransformContext context)
    {
        // The cluster ID is set by TenantRoutingMiddleware.
        if (!context.HttpContext.Items.TryGetValue(ClusterIdItemKey, out var clusterIdObj) ||
            clusterIdObj is not string clusterId)
        {
            _logger.LogDebug("No DynamicClusterId found in HttpContext, skipping tenant routing");
            return;
        }

        // Abort database lookups if the client goes away.
        var cancellationToken = context.HttpContext.RequestAborted;

        var tenantId = ExtractTenantFromJwt(context.HttpContext);
        if (string.IsNullOrEmpty(tenantId))
        {
            _logger.LogDebug("No tenant ID found in JWT, using default destination for cluster {ClusterId}", clusterId);
            // Without a tenant ID, route to the default destination.
            await RouteToDefaultDestinationAsync(context, clusterId, cancellationToken);
            return;
        }

        _logger.LogDebug("Processing tenant routing for cluster {ClusterId}, tenant {TenantId}", clusterId, tenantId);

        var tenantDestination = await FindTenantDestinationAsync(clusterId, tenantId, cancellationToken);
        if (tenantDestination is null)
        {
            // Fall back to the default destination.
            _logger.LogDebug("No tenant-specific destination found for tenant {TenantId}, falling back to default", tenantId);
            await RouteToDefaultDestinationAsync(context, clusterId, cancellationToken);
            return;
        }

        _logger.LogDebug("Using tenant-specific destination {DestinationId} ({Address}) for tenant {TenantId}",
            tenantDestination.DestinationId, tenantDestination.Address, tenantId);
        context.ProxyRequest.RequestUri = BuildRequestUri(context, tenantDestination.Address);
    }

    /// <summary>
    /// Extracts the tenant ID from the authenticated user's claims.
    /// </summary>
    /// <remarks>
    /// Claim types are tried in the order of <see cref="TenantClaimTypes"/>, so the result does not
    /// depend on the order claims appear in the token. Blank claim values are ignored.
    /// </remarks>
    /// <returns>The tenant ID, or <c>null</c> when the user is unauthenticated or no usable claim exists.</returns>
    private static string? ExtractTenantFromJwt(HttpContext httpContext)
    {
        var user = httpContext.User;
        if (user?.Identity?.IsAuthenticated != true)
        {
            return null;
        }

        foreach (var claimType in TenantClaimTypes)
        {
            var value = user.Claims
                .Where(c => c.Type.Equals(claimType, StringComparison.OrdinalIgnoreCase))
                .Select(c => c.Value)
                .FirstOrDefault(v => !string.IsNullOrWhiteSpace(v));

            if (value is not null)
            {
                return value;
            }
        }

        return null;
    }

    /// <summary>
    /// Finds the active destination whose TenantCode matches <paramref name="tenantId"/>
    /// (case-insensitive) in the given active cluster. Results, including misses, are cached.
    /// </summary>
    /// <returns>The tenant's destination, or <c>null</c> if none exists or the cluster is missing/disabled.</returns>
    private async Task<GwDestination?> FindTenantDestinationAsync(
        string clusterId,
        string tenantId,
        CancellationToken cancellationToken = default)
    {
        // The "tenant:" segment keeps these keys disjoint from the default-destination keys,
        // even for a tenant literally named "default".
        var cacheKey = $"{CacheKeyPrefix}:tenant:{clusterId}:{tenantId}";

        if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
        {
            _logger.LogDebug("Cache hit for tenant destination: {CacheKey}", cacheKey);
            return cachedDestination;
        }

        var tenantDestination = await QueryActiveDestinationAsync(
            clusterId,
            d => !string.IsNullOrEmpty(d.TenantCode) &&
                 d.TenantCode.Equals(tenantId, StringComparison.OrdinalIgnoreCase),
            cancellationToken);

        // Cache misses too, so tenants without a dedicated destination don't query the DB per request.
        CacheDestination(cacheKey, tenantDestination);
        if (tenantDestination is not null)
        {
            _logger.LogDebug("Cached tenant destination for key: {CacheKey}", cacheKey);
        }

        return tenantDestination;
    }

    /// <summary>
    /// Finds the active default destination (empty or null TenantCode) of the given active cluster.
    /// Results, including misses, are cached.
    /// </summary>
    /// <returns>The default destination, or <c>null</c> if none exists or the cluster is missing/disabled.</returns>
    private async Task<GwDestination?> FindDefaultDestinationAsync(
        string clusterId,
        CancellationToken cancellationToken = default)
    {
        var cacheKey = $"{CacheKeyPrefix}:default:{clusterId}";

        if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
        {
            _logger.LogDebug("Cache hit for default destination: {CacheKey}", cacheKey);
            return cachedDestination;
        }

        var defaultDestination = await QueryActiveDestinationAsync(
            clusterId,
            d => string.IsNullOrEmpty(d.TenantCode),
            cancellationToken);

        CacheDestination(cacheKey, defaultDestination);
        if (defaultDestination is not null)
        {
            _logger.LogDebug("Cached default destination for key: {CacheKey}", cacheKey);
        }

        return defaultDestination;
    }

    /// <summary>
    /// Loads the active cluster <paramref name="clusterId"/> from the database and returns its first
    /// active destination matching <paramref name="predicate"/>.
    /// </summary>
    /// <returns>The matching destination, or <c>null</c> (a warning is logged if the cluster is missing/disabled).</returns>
    private async Task<GwDestination?> QueryActiveDestinationAsync(
        string clusterId,
        Func<GwDestination, bool> predicate,
        CancellationToken cancellationToken)
    {
        await using var dbContext = await _dbContextFactory.CreateDbContextAsync(cancellationToken);
        var cluster = await dbContext.GwClusters
            .AsNoTracking()
            .FirstOrDefaultAsync(c => c.ClusterId == clusterId && c.Status == ActiveStatus, cancellationToken);

        if (cluster == null)
        {
            _logger.LogWarning("Cluster {ClusterId} not found or disabled", clusterId);
            return null;
        }

        return cluster.Destinations.FirstOrDefault(d => d.Status == ActiveStatus && predicate(d));
    }

    /// <summary>
    /// Stores a lookup result (possibly <c>null</c>) with the shared expiration and a size of 1.
    /// </summary>
    private void CacheDestination(string cacheKey, GwDestination? destination)
    {
        var cacheOptions = new MemoryCacheEntryOptions()
            .SetAbsoluteExpiration(CacheExpiration)
            .SetSize(1);
        _cache.Set(cacheKey, destination, cacheOptions);
    }

    /// <summary>
    /// Routes the request to the cluster's default destination; logs a warning and leaves the
    /// request untouched when none exists.
    /// </summary>
    private async ValueTask RouteToDefaultDestinationAsync(
        RequestTransformContext context,
        string clusterId,
        CancellationToken cancellationToken = default)
    {
        var defaultDestination = await FindDefaultDestinationAsync(clusterId, cancellationToken);
        if (defaultDestination is null)
        {
            _logger.LogWarning("No default destination found for cluster {ClusterId}", clusterId);
            return;
        }

        _logger.LogDebug("Using default destination {DestinationId} ({Address}) for cluster {ClusterId}",
            defaultDestination.DestinationId, defaultDestination.Address, clusterId);
        context.ProxyRequest.RequestUri = BuildRequestUri(context, defaultDestination.Address);
    }

    /// <summary>
    /// Builds the full proxy request URI: the resolved destination prefix followed by the request's
    /// current path and query string.
    /// </summary>
    /// <remarks>
    /// Assigning <c>ProxyRequest.RequestUri</c> replaces the URI YARP would otherwise build from
    /// DestinationPrefix + Path + Query, so the path and query are appended explicitly here.
    /// </remarks>
    /// <exception cref="InvalidOperationException">No valid prefix can be derived.</exception>
    private static Uri BuildRequestUri(RequestTransformContext context, string address)
    {
        var prefix = ResolveDestinationPrefix(context.DestinationPrefix, address);
        return RequestUtilities.MakeDestinationAddress(prefix, context.Path, context.Query.QueryString);
    }

    /// <summary>
    /// Resolves the destination prefix: an absolute http(s) <paramref name="address"/> is used as-is;
    /// otherwise <paramref name="address"/> is appended as a path to <paramref name="destinationPrefix"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The address is relative and no prefix is available.</exception>
    private static string ResolveDestinationPrefix(string? destinationPrefix, string address)
    {
        // Only http(s) absolute URIs count as complete addresses. On Unix, Uri.TryCreate also treats
        // rooted paths such as "/api" as absolute file:// URIs, which must not be used as targets.
        if (Uri.TryCreate(address, UriKind.Absolute, out var absoluteUri) &&
            (absoluteUri.Scheme == Uri.UriSchemeHttp || absoluteUri.Scheme == Uri.UriSchemeHttps))
        {
            return address;
        }

        if (!string.IsNullOrEmpty(destinationPrefix))
        {
            var baseUri = destinationPrefix.TrimEnd('/');
            var path = address.StartsWith('/') ? address : "/" + address;
            return baseUri + path;
        }

        throw new InvalidOperationException($"Cannot build valid URI from prefix '{destinationPrefix}' and address '{address}'");
    }
}