- Implement a YARP RequestTransform that routes requests based on the tenant
- Extract the tenant ID from the JWT token (supports multiple claim types)
- Query the tenant-specific destination first, falling back to the default
- Add IMemoryCache caching for performance (5-minute expiration)
- Update Platform packages to 1.0.14 to use GwDestination.TenantCode
238 lines
8.6 KiB
C#
238 lines
8.6 KiB
C#
using System.Security.Claims;
|
||
using Fengling.Platform.Domain.AggregatesModel.GatewayAggregate;
|
||
using Microsoft.EntityFrameworkCore;
|
||
using Microsoft.Extensions.Caching.Memory;
|
||
using Yarp.ReverseProxy.Transforms;
|
||
using YarpGateway.Data;
|
||
|
||
namespace YarpGateway.Transforms;
|
||
|
||
/// <summary>
/// Tenant-aware routing request transform.
/// </summary>
/// <remarks>
/// Behavior:
/// 1. Reads the target cluster ID from <c>HttpContext.Items["DynamicClusterId"]</c>; does nothing when it is absent.
/// 2. Extracts the tenant ID from the authenticated user's JWT claims.
/// 3. Routes to the cluster's tenant-specific destination (matched on <c>TenantCode</c>), falling back to the
///    cluster's default destination (empty <c>TenantCode</c>) when there is none.
/// 4. Caches lookup results, including "not found" results, in <see cref="IMemoryCache"/> for 5 minutes.
///
/// The chosen destination is applied by rewriting <c>RequestTransformContext.DestinationPrefix</c> rather than
/// assigning <c>ProxyRequest.RequestUri</c>. YARP appends the (possibly transformed) request path and query
/// string to the prefix after all transforms have run. A directly assigned RequestUri is used verbatim, which
/// would drop the path and query.
/// NOTE(review): YARP has already selected a destination (via load balancing) before transforms run.
/// Health / passive-health bookkeeping is therefore attributed to that destination, not to the one chosen here.
/// </remarks>
public class TenantRoutingTransform : RequestTransform
{
    private readonly IDbContextFactory<GatewayDbContext> _dbContextFactory;
    private readonly IMemoryCache _cache;
    private readonly ILogger<TenantRoutingTransform> _logger;

    // Cache key prefix. Keys are "{prefix}:{clusterId}:tenant:{tenantId}" and "{prefix}:{clusterId}:default".
    // The "tenant:" segment keeps a tenant literally named "default" from colliding with the default entry.
    private const string CacheKeyPrefix = "tenant_destination";

    // Status value this transform treats as enabled, for both clusters and destinations.
    private const int ActiveStatus = 1;

    // Absolute lifetime of cached lookups (hits and misses): 5 minutes.
    private static readonly TimeSpan CacheExpiration = TimeSpan.FromMinutes(5);

    // Claim types that may carry the tenant ID, checked in this priority order (case-insensitive).
    private static readonly string[] TenantClaimTypes = new[] { "tenant", "tenant_id", "tenant_code", "tid" };

    public TenantRoutingTransform(
        IDbContextFactory<GatewayDbContext> dbContextFactory,
        IMemoryCache cache,
        ILogger<TenantRoutingTransform> logger)
    {
        _dbContextFactory = dbContextFactory;
        _cache = cache;
        _logger = logger;
    }

    /// <summary>
    /// Redirects the proxied request to the tenant-specific (or default) destination of the dynamic cluster.
    /// </summary>
    /// <param name="context">The YARP request transform context; its <c>DestinationPrefix</c> may be replaced.</param>
    public override async ValueTask ApplyAsync(RequestTransformContext context)
    {
        // The cluster ID is set upstream in HttpContext.Items (by TenantRoutingMiddleware).
        if (!context.HttpContext.Items.TryGetValue("DynamicClusterId", out var clusterIdObj) || clusterIdObj is not string clusterId)
        {
            _logger.LogDebug("No DynamicClusterId found in HttpContext, skipping tenant routing");
            return;
        }

        var cancellationToken = context.HttpContext.RequestAborted;
        var tenantId = ExtractTenantFromJwt(context.HttpContext);

        if (string.IsNullOrEmpty(tenantId))
        {
            _logger.LogDebug("No tenant ID found in JWT, using default destination for cluster {ClusterId}", clusterId);
            await RouteToDefaultDestinationAsync(context, clusterId, cancellationToken);
            return;
        }

        _logger.LogDebug("Processing tenant routing for cluster {ClusterId}, tenant {TenantId}", clusterId, tenantId);

        var tenantDestination = await FindTenantDestinationAsync(clusterId, tenantId, cancellationToken);
        if (tenantDestination != null)
        {
            _logger.LogDebug("Using tenant-specific destination {DestinationId} ({Address}) for tenant {TenantId}",
                tenantDestination.DestinationId, tenantDestination.Address, tenantId);

            context.DestinationPrefix = ResolveDestinationPrefix(context.DestinationPrefix, tenantDestination.Address);
            return;
        }

        _logger.LogDebug("No tenant-specific destination found for tenant {TenantId}, falling back to default", tenantId);
        await RouteToDefaultDestinationAsync(context, clusterId, cancellationToken);
    }

    /// <summary>
    /// Extracts the tenant ID from the authenticated user's claims.
    /// </summary>
    /// <returns>
    /// The value of the first non-blank claim found, checking <see cref="TenantClaimTypes"/> in priority order;
    /// <c>null</c> when the user is unauthenticated or carries no such claim.
    /// </returns>
    private static string? ExtractTenantFromJwt(HttpContext httpContext)
    {
        var user = httpContext.User;
        if (user?.Identity?.IsAuthenticated != true)
        {
            return null;
        }

        // Iterate claim types (not claims) so the winner does not depend on claim order inside the token.
        foreach (var claimType in TenantClaimTypes)
        {
            var claim = user.Claims.FirstOrDefault(c =>
                c.Type.Equals(claimType, StringComparison.OrdinalIgnoreCase) &&
                !string.IsNullOrWhiteSpace(c.Value));

            if (claim != null)
            {
                return claim.Value;
            }
        }

        return null;
    }

    /// <summary>
    /// Finds the active destination dedicated to <paramref name="tenantId"/> in the given cluster.
    /// Both hits and misses are cached.
    /// </summary>
    /// <param name="clusterId">Cluster to search.</param>
    /// <param name="tenantId">Non-empty tenant ID, compared to <c>TenantCode</c> case-insensitively.</param>
    /// <param name="cancellationToken">Cancels the database lookup.</param>
    /// <returns>The tenant destination, or <c>null</c> when none exists or the cluster is missing/disabled.</returns>
    private async Task<GwDestination?> FindTenantDestinationAsync(string clusterId, string tenantId, CancellationToken cancellationToken)
    {
        var cacheKey = $"{CacheKeyPrefix}:{clusterId}:tenant:{tenantId}";

        if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
        {
            _logger.LogDebug("Cache hit for tenant destination: {CacheKey}", cacheKey);
            return cachedDestination;
        }

        var destinations = await LoadActiveDestinationsAsync(clusterId, cancellationToken);

        // string.Equals is null-safe, so destinations without a TenantCode never match a non-empty tenantId.
        var tenantDestination = destinations?.FirstOrDefault(d =>
            string.Equals(d.TenantCode, tenantId, StringComparison.OrdinalIgnoreCase));

        // Cache misses too: otherwise every request from a tenant without a dedicated destination hits the DB.
        CacheLookupResult(cacheKey, tenantDestination);
        if (tenantDestination != null)
        {
            _logger.LogDebug("Cached tenant destination for key: {CacheKey}", cacheKey);
        }
        else
        {
            _logger.LogDebug("Cached missing tenant destination for key: {CacheKey}", cacheKey);
        }

        return tenantDestination;
    }

    /// <summary>
    /// Finds the cluster's active default destination (one with an empty or null <c>TenantCode</c>).
    /// Both hits and misses are cached.
    /// </summary>
    /// <param name="clusterId">Cluster to search.</param>
    /// <param name="cancellationToken">Cancels the database lookup.</param>
    /// <returns>The default destination, or <c>null</c> when none exists or the cluster is missing/disabled.</returns>
    private async Task<GwDestination?> FindDefaultDestinationAsync(string clusterId, CancellationToken cancellationToken)
    {
        var cacheKey = $"{CacheKeyPrefix}:{clusterId}:default";

        if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
        {
            _logger.LogDebug("Cache hit for default destination: {CacheKey}", cacheKey);
            return cachedDestination;
        }

        var destinations = await LoadActiveDestinationsAsync(clusterId, cancellationToken);
        var defaultDestination = destinations?.FirstOrDefault(d => string.IsNullOrEmpty(d.TenantCode));

        CacheLookupResult(cacheKey, defaultDestination);
        if (defaultDestination != null)
        {
            _logger.LogDebug("Cached default destination for key: {CacheKey}", cacheKey);
        }
        else
        {
            _logger.LogDebug("Cached missing default destination for key: {CacheKey}", cacheKey);
        }

        return defaultDestination;
    }

    /// <summary>
    /// Loads the active destinations of an enabled cluster from the database (no caching at this level).
    /// </summary>
    /// <returns>The destinations whose status is active, or <c>null</c> when the cluster is missing or disabled.</returns>
    private async Task<List<GwDestination>?> LoadActiveDestinationsAsync(string clusterId, CancellationToken cancellationToken)
    {
        await using var dbContext = await _dbContextFactory.CreateDbContextAsync(cancellationToken);

        // NOTE(review): Destinations is read without an explicit Include. This only works if the mapping loads it
        // automatically (e.g. an owned collection) — confirm against GatewayDbContext's model configuration.
        var cluster = await dbContext.GwClusters
            .AsNoTracking()
            .FirstOrDefaultAsync(c => c.ClusterId == clusterId && c.Status == ActiveStatus, cancellationToken);

        if (cluster == null)
        {
            _logger.LogWarning("Cluster {ClusterId} not found or disabled", clusterId);
            return null;
        }

        return cluster.Destinations
            .Where(d => d.Status == ActiveStatus)
            .ToList();
    }

    /// <summary>
    /// Stores a lookup result (possibly <c>null</c>) under <paramref name="cacheKey"/> with the standard expiration.
    /// A stored <c>null</c> is still reported as a hit by <c>TryGetValue</c>, which is what makes miss caching work.
    /// </summary>
    private void CacheLookupResult(string cacheKey, GwDestination? destination)
    {
        // Size is set so the entry is accepted when the shared cache has a SizeLimit configured.
        var cacheOptions = new MemoryCacheEntryOptions()
            .SetAbsoluteExpiration(CacheExpiration)
            .SetSize(1);

        _cache.Set(cacheKey, destination, cacheOptions);
    }

    /// <summary>
    /// Routes the request to the cluster's default destination, if one exists.
    /// When none exists, the context is left untouched, so the prefix YARP originally selected is used.
    /// </summary>
    private async ValueTask RouteToDefaultDestinationAsync(RequestTransformContext context, string clusterId, CancellationToken cancellationToken)
    {
        var defaultDestination = await FindDefaultDestinationAsync(clusterId, cancellationToken);

        if (defaultDestination == null)
        {
            _logger.LogWarning("No default destination found for cluster {ClusterId}", clusterId);
            return;
        }

        _logger.LogDebug("Using default destination {DestinationId} ({Address}) for cluster {ClusterId}",
            defaultDestination.DestinationId, defaultDestination.Address, clusterId);

        context.DestinationPrefix = ResolveDestinationPrefix(context.DestinationPrefix, defaultDestination.Address);
    }

    /// <summary>
    /// Computes the destination prefix (scheme, host, optional port and path base) for a configured address.
    /// YARP appends the request path and query to the returned value.
    /// </summary>
    /// <param name="destinationPrefix">The prefix YARP provided; used as the base when <paramref name="address"/> is relative.</param>
    /// <param name="address">An absolute http(s) URI, or a path to be appended to <paramref name="destinationPrefix"/>.</param>
    /// <returns>The new destination prefix.</returns>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="address"/> is not an absolute http(s) URI and no <paramref name="destinationPrefix"/> is available.
    /// </exception>
    private static string ResolveDestinationPrefix(string? destinationPrefix, string address)
    {
        // Require an http(s) scheme: on Unix, Uri.TryCreate(..., Absolute) accepts "/path" as a file:// URI,
        // which would otherwise hijack the relative-address branch below.
        if (Uri.TryCreate(address, UriKind.Absolute, out var absoluteUri) &&
            (absoluteUri.Scheme == Uri.UriSchemeHttp || absoluteUri.Scheme == Uri.UriSchemeHttps))
        {
            return absoluteUri.AbsoluteUri;
        }

        // Relative address: treat it as a path base under the prefix YARP selected.
        if (!string.IsNullOrEmpty(destinationPrefix))
        {
            var baseUri = destinationPrefix.TrimEnd('/');
            var path = address.StartsWith('/') ? address : "/" + address;
            return baseUri + path;
        }

        throw new InvalidOperationException($"Cannot build valid URI from prefix '{destinationPrefix}' and address '{address}'");
    }
}
|