using System.Security.Claims;
using Fengling.Platform.Domain.AggregatesModel.GatewayAggregate;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Memory;
using Yarp.ReverseProxy.Forwarder;
using Yarp.ReverseProxy.Transforms;
using YarpGateway.Data;
namespace YarpGateway.Transforms;
/// <summary>
/// Tenant routing transform.
/// </summary>
/// <remarks>
/// Behavior:
/// 1. Resolves the tenant ID from the authenticated user's JWT claims.
/// 2. Looks up the destination configured for that tenant in the cluster whose ID was placed in
///    <c>HttpContext.Items["DynamicClusterId"]</c>.
/// 3. Prefers the tenant-specific destination and falls back to the cluster's default destination
///    (an active destination with an empty <c>TenantCode</c>) when none exists.
/// 4. Caches lookup outcomes (hits and misses) in <see cref="IMemoryCache"/> to keep the database off
///    the request hot path.
/// The chosen destination only replaces the base address of the outgoing request; the request path and
/// query string are preserved.
/// </remarks>
public class TenantRoutingTransform : RequestTransform
{
    // NOTE(review): IDbContextFactory / ILogger appear to have lost their generic arguments
    // (e.g. IDbContextFactory<TContext>, ILogger<TenantRoutingTransform>) — confirm against version control.
    private readonly IDbContextFactory _dbContextFactory;
    private readonly IMemoryCache _cache;
    private readonly ILogger _logger;

    // Prefix shared by every cache key written by this transform.
    private const string CacheKeyPrefix = "tenant_destination";

    // Absolute lifetime of a cached lookup: 5 minutes.
    private static readonly TimeSpan CacheExpiration = TimeSpan.FromMinutes(5);

    // Claim types that may carry the tenant ID, compared case-insensitively.
    private static readonly HashSet<string> TenantClaimTypes = new(StringComparer.OrdinalIgnoreCase)
    {
        "tenant",
        "tenant_id",
        "tenant_code",
        "tid",
    };

    public TenantRoutingTransform(
        IDbContextFactory dbContextFactory,
        IMemoryCache cache,
        ILogger logger)
    {
        _dbContextFactory = dbContextFactory;
        _cache = cache;
        _logger = logger;
    }

    /// <summary>
    /// Rewrites <c>ProxyRequest.RequestUri</c> to point at the tenant-specific destination of the dynamic
    /// cluster, or at the cluster's default destination when the request has no tenant or the tenant has no
    /// dedicated destination. Does nothing when no <c>DynamicClusterId</c> is present.
    /// </summary>
    /// <param name="context">The YARP request transform context.</param>
    public override async ValueTask ApplyAsync(RequestTransformContext context)
    {
        // The cluster ID is set in HttpContext.Items by TenantRoutingMiddleware.
        if (!context.HttpContext.Items.TryGetValue("DynamicClusterId", out var clusterIdObj) || clusterIdObj is not string clusterId)
        {
            _logger.LogDebug("No DynamicClusterId found in HttpContext, skipping tenant routing");
            return;
        }

        // Abort database lookups when the client disconnects.
        var cancellationToken = context.HttpContext.RequestAborted;

        var tenantId = ExtractTenantFromJwt(context.HttpContext);
        if (string.IsNullOrEmpty(tenantId))
        {
            _logger.LogDebug("No tenant ID found in JWT, using default destination for cluster {ClusterId}", clusterId);
            // Without a tenant ID, route to the default destination.
            await RouteToDefaultDestinationAsync(context, clusterId, cancellationToken);
            return;
        }

        _logger.LogDebug("Processing tenant routing for cluster {ClusterId}, tenant {TenantId}", clusterId, tenantId);

        var tenantDestination = await FindTenantDestinationAsync(clusterId, tenantId, cancellationToken);
        if (tenantDestination != null)
        {
            _logger.LogDebug("Using tenant-specific destination {DestinationId} ({Address}) for tenant {TenantId}",
                tenantDestination.DestinationId, tenantDestination.Address, tenantId);
            context.ProxyRequest.RequestUri = BuildRequestUri(context, tenantDestination.Address);
        }
        else
        {
            // Fall back to the default destination.
            _logger.LogDebug("No tenant-specific destination found for tenant {TenantId}, falling back to default", tenantId);
            await RouteToDefaultDestinationAsync(context, clusterId, cancellationToken);
        }
    }

    /// <summary>
    /// Extracts the tenant ID from the authenticated user's claims.
    /// </summary>
    /// <returns>
    /// The value of the first claim (in claim order) whose type is one of <see cref="TenantClaimTypes"/>,
    /// or <c>null</c> when the user is not authenticated or no such claim exists.
    /// </returns>
    private static string? ExtractTenantFromJwt(HttpContext httpContext)
    {
        if (httpContext.User?.Identity?.IsAuthenticated != true)
        {
            return null;
        }

        return httpContext.User.Claims.FirstOrDefault(c => TenantClaimTypes.Contains(c.Type))?.Value;
    }

    /// <summary>
    /// Finds the active destination whose <c>TenantCode</c> matches <paramref name="tenantId"/>
    /// (case-insensitive) in the active cluster <paramref name="clusterId"/>.
    /// </summary>
    /// <returns>The matching destination, or <c>null</c> when none exists (this outcome is cached too).</returns>
    private async Task<GwDestination?> FindTenantDestinationAsync(string clusterId, string tenantId, CancellationToken cancellationToken)
    {
        // The "tenant:" segment keeps tenant keys disjoint from the default key, so a tenant literally
        // named "default" can no longer read or overwrite the cluster's default-destination entry.
        var cacheKey = $"{CacheKeyPrefix}:{clusterId}:tenant:{tenantId}";

        if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
        {
            _logger.LogDebug("Cache hit for tenant destination: {CacheKey}", cacheKey);
            return cachedDestination;
        }

        var destinations = await LoadActiveClusterDestinationsAsync(clusterId, cancellationToken);
        var tenantDestination = destinations?.FirstOrDefault(d =>
            d.Status == 1 &&
            !string.IsNullOrEmpty(d.TenantCode) &&
            d.TenantCode.Equals(tenantId, StringComparison.OrdinalIgnoreCase));

        // Cache misses as well, otherwise every request of a tenant without a dedicated destination hits the DB.
        CacheDestination(cacheKey, tenantDestination);
        if (tenantDestination != null)
        {
            _logger.LogDebug("Cached tenant destination for key: {CacheKey}", cacheKey);
        }

        return tenantDestination;
    }

    /// <summary>
    /// Finds the default destination (active, with an empty or null <c>TenantCode</c>) of the active
    /// cluster <paramref name="clusterId"/>.
    /// </summary>
    /// <returns>The default destination, or <c>null</c> when none exists (this outcome is cached too).</returns>
    private async Task<GwDestination?> FindDefaultDestinationAsync(string clusterId, CancellationToken cancellationToken)
    {
        var cacheKey = $"{CacheKeyPrefix}:{clusterId}:default";

        if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
        {
            _logger.LogDebug("Cache hit for default destination: {CacheKey}", cacheKey);
            return cachedDestination;
        }

        var destinations = await LoadActiveClusterDestinationsAsync(clusterId, cancellationToken);
        var defaultDestination = destinations?.FirstOrDefault(d =>
            d.Status == 1 &&
            string.IsNullOrEmpty(d.TenantCode));

        CacheDestination(cacheKey, defaultDestination);
        if (defaultDestination != null)
        {
            _logger.LogDebug("Cached default destination for key: {CacheKey}", cacheKey);
        }

        return defaultDestination;
    }

    /// <summary>
    /// Loads the destinations of the cluster with <c>ClusterId == clusterId</c> and <c>Status == 1</c>.
    /// </summary>
    /// <returns>
    /// The cluster's destinations, materialized while the DbContext is still alive, or <c>null</c> (after
    /// logging a warning) when the cluster does not exist or is disabled.
    /// </returns>
    private async Task<IReadOnlyList<GwDestination>?> LoadActiveClusterDestinationsAsync(string clusterId, CancellationToken cancellationToken)
    {
        await using var dbContext = await _dbContextFactory.CreateDbContextAsync(cancellationToken);

        // NOTE(review): there is no Include() here; this relies on Destinations being loaded together with
        // the cluster (e.g. an owned or auto-included navigation). Otherwise it is always empty — confirm the mapping.
        var cluster = await dbContext.GwClusters
            .AsNoTracking()
            .FirstOrDefaultAsync(c => c.ClusterId == clusterId && c.Status == 1, cancellationToken);

        if (cluster == null)
        {
            _logger.LogWarning("Cluster {ClusterId} not found or disabled", clusterId);
            return null;
        }

        return cluster.Destinations.ToList();
    }

    /// <summary>
    /// Stores a lookup outcome (which may be <c>null</c>) with the shared expiration and a size of 1.
    /// </summary>
    private void CacheDestination(string cacheKey, GwDestination? destination)
    {
        var cacheOptions = new MemoryCacheEntryOptions()
            .SetAbsoluteExpiration(CacheExpiration)
            .SetSize(1);
        _cache.Set(cacheKey, destination, cacheOptions);
    }

    /// <summary>
    /// Routes the request to the cluster's default destination; logs a warning and leaves the request
    /// untouched when the cluster has none.
    /// </summary>
    private async ValueTask RouteToDefaultDestinationAsync(RequestTransformContext context, string clusterId, CancellationToken cancellationToken)
    {
        var defaultDestination = await FindDefaultDestinationAsync(clusterId, cancellationToken);
        if (defaultDestination != null)
        {
            _logger.LogDebug("Using default destination {DestinationId} ({Address}) for cluster {ClusterId}",
                defaultDestination.DestinationId, defaultDestination.Address, clusterId);
            context.ProxyRequest.RequestUri = BuildRequestUri(context, defaultDestination.Address);
        }
        else
        {
            _logger.LogWarning("No default destination found for cluster {ClusterId}", clusterId);
        }
    }

    /// <summary>
    /// Builds the outgoing request URI.
    /// </summary>
    /// <remarks>
    /// The destination address is used as the prefix, and the request's current path and query string are
    /// appended. Assigning <c>ProxyRequest.RequestUri</c> stops YARP from building the URI itself after the
    /// transform pipeline. As a result, Path/Query changes made by transforms that run after this one are
    /// not reflected.
    /// </remarks>
    /// <param name="context">Transform context supplying DestinationPrefix, Path and Query.</param>
    /// <param name="address">An absolute http(s) URI, or a path relative to DestinationPrefix.</param>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="address"/> is not an absolute http(s) URI and there is no DestinationPrefix.
    /// </exception>
    private static Uri BuildRequestUri(RequestTransformContext context, string address)
    {
        var prefix = ResolveDestinationPrefix(context.DestinationPrefix, address);
        return RequestUtilities.MakeDestinationAddress(prefix, context.Path, context.Query.QueryString);
    }

    /// <summary>
    /// Resolves the base address (scheme, host and optional base path) the request should be sent to.
    /// </summary>
    private static string ResolveDestinationPrefix(string? destinationPrefix, string address)
    {
        // Only http(s) counts as absolute: on Unix, Uri.TryCreate("/api", UriKind.Absolute, ...) succeeds
        // with a file:// URI, which would otherwise swallow relative addresses.
        if (Uri.TryCreate(address, UriKind.Absolute, out var absoluteUri) &&
            (absoluteUri.Scheme == Uri.UriSchemeHttp || absoluteUri.Scheme == Uri.UriSchemeHttps))
        {
            return address;
        }

        // Otherwise treat the address as a path under the YARP-selected destination prefix.
        if (!string.IsNullOrEmpty(destinationPrefix))
        {
            var baseUri = destinationPrefix.TrimEnd('/');
            var path = address.StartsWith('/') ? address : "/" + address;
            return baseUri + path;
        }

        throw new InvalidOperationException($"Cannot build valid URI from prefix '{destinationPrefix}' and address '{address}'");
    }
}