fengling-gateway/src/yarpgateway/Transforms/TenantRoutingTransform.cs
movingsam c7cf1d3738 feat: add TenantRoutingTransform for multi-tenant routing (IMPL-10)
- Implement YARP RequestTransform to route requests based on tenant
- Extract tenant ID from JWT token (supports multiple claim types)
- Query tenant-specific destination first, fallback to default
- Add IMemoryCache for performance optimization (5min expiration)
- Update Platform packages to 1.0.14 to use GwDestination.TenantCode
2026-03-08 01:01:22 +08:00

238 lines
8.6 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System.Security.Claims;
using Fengling.Platform.Domain.AggregatesModel.GatewayAggregate;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Memory;
using Yarp.ReverseProxy.Forwarder;
using Yarp.ReverseProxy.Transforms;
using YarpGateway.Data;
namespace YarpGateway.Transforms;
/// <summary>
/// Tenant routing request transform.
///
/// Behaviour:
/// 1. Reads the cluster id from <c>HttpContext.Items["DynamicClusterId"]</c> (set by TenantRoutingMiddleware);
///    if it is missing the transform does nothing and YARP's normal destination selection applies.
/// 2. Extracts the tenant id from the authenticated user's claims (JWT).
/// 3. Routes to the tenant's dedicated destination (active destination whose <c>TenantCode</c> matches the
///    tenant id, case-insensitively); otherwise falls back to the cluster's default destination
///    (active destination with an empty <c>TenantCode</c>).
/// 4. Caches lookups in <see cref="IMemoryCache"/> with a 5 minute absolute expiration. The
///    "tenant has no dedicated destination" result is cached too, so the common fallback path does not
///    query the database on every request. Destination changes therefore take up to 5 minutes to apply.
///
/// The rewritten request URI keeps the incoming request path and query string; only the
/// scheme/host (and optional base path from the destination address) change.
/// </summary>
public class TenantRoutingTransform : RequestTransform
{
    private readonly IDbContextFactory<GatewayDbContext> _dbContextFactory;
    private readonly IMemoryCache _cache;
    private readonly ILogger<TenantRoutingTransform> _logger;

    // Prefix shared by every cache entry written by this transform.
    private const string CacheKeyPrefix = "tenant_destination";

    // Absolute lifetime of cached lookups (5 minutes).
    private static readonly TimeSpan CacheExpiration = TimeSpan.FromMinutes(5);

    // Claim types that may carry the tenant id; compared case-insensitively.
    private static readonly string[] TenantClaimTypes = { "tenant", "tenant_id", "tenant_code", "tid" };

    public TenantRoutingTransform(
        IDbContextFactory<GatewayDbContext> dbContextFactory,
        IMemoryCache cache,
        ILogger<TenantRoutingTransform> logger)
    {
        _dbContextFactory = dbContextFactory;
        _cache = cache;
        _logger = logger;
    }

    /// <summary>
    /// Rewrites <c>ProxyRequest.RequestUri</c> to the tenant-specific or default destination of the
    /// cluster recorded in <c>HttpContext.Items["DynamicClusterId"]</c>.
    /// </summary>
    /// <remarks>
    /// Because this sets <c>RequestUri</c> explicitly, YARP will not rebuild it afterwards; path or query
    /// transforms that run after this one are not reflected in the outgoing URI.
    /// </remarks>
    public override async ValueTask ApplyAsync(RequestTransformContext context)
    {
        // The cluster id is placed in HttpContext.Items by TenantRoutingMiddleware.
        if (!context.HttpContext.Items.TryGetValue("DynamicClusterId", out var clusterIdObj) || clusterIdObj is not string clusterId)
        {
            _logger.LogDebug("No DynamicClusterId found in HttpContext, skipping tenant routing");
            return;
        }

        // Abort database work if the client disconnects.
        var cancellationToken = context.HttpContext.RequestAborted;

        var tenantId = ExtractTenantFromJwt(context.HttpContext);
        if (string.IsNullOrEmpty(tenantId))
        {
            _logger.LogDebug("No tenant ID found in JWT, using default destination for cluster {ClusterId}", clusterId);
            await RouteToDefaultDestinationAsync(context, clusterId, cancellationToken);
            return;
        }

        _logger.LogDebug("Processing tenant routing for cluster {ClusterId}, tenant {TenantId}", clusterId, tenantId);

        var tenantDestination = await FindTenantDestinationAsync(clusterId, tenantId, cancellationToken);
        if (tenantDestination != null)
        {
            _logger.LogDebug("Using tenant-specific destination {DestinationId} ({Address}) for tenant {TenantId}",
                tenantDestination.DestinationId, tenantDestination.Address, tenantId);
            ApplyDestination(context, tenantDestination.Address);
        }
        else
        {
            _logger.LogDebug("No tenant-specific destination found for tenant {TenantId}, falling back to default", tenantId);
            await RouteToDefaultDestinationAsync(context, clusterId, cancellationToken);
        }
    }

    /// <summary>
    /// Returns the value of the first claim (in the principal's claim order) whose type is one of
    /// <c>tenant</c>, <c>tenant_id</c>, <c>tenant_code</c> or <c>tid</c>; <c>null</c> for
    /// unauthenticated requests or when no such claim exists.
    /// </summary>
    private static string? ExtractTenantFromJwt(HttpContext httpContext)
    {
        if (httpContext.User?.Identity?.IsAuthenticated != true)
        {
            return null;
        }

        var tenantClaim = httpContext.User.Claims.FirstOrDefault(c =>
            TenantClaimTypes.Contains(c.Type, StringComparer.OrdinalIgnoreCase));

        return tenantClaim?.Value;
    }

    /// <summary>
    /// Finds the active destination dedicated to <paramref name="tenantId"/> in the given cluster.
    /// </summary>
    /// <returns>
    /// The dedicated destination, or <c>null</c> when the tenant has none or the cluster is missing/disabled.
    /// A "no dedicated destination" answer is cached; a missing cluster is not, so it recovers as soon as
    /// the cluster is enabled.
    /// </returns>
    private async Task<GwDestination?> FindTenantDestinationAsync(
        string clusterId, string tenantId, CancellationToken cancellationToken = default)
    {
        // The "tenant:" segment keeps tenant keys apart from the ":default" key (a tenant named "default"
        // must not share an entry with the cluster's default destination).
        var cacheKey = $"{CacheKeyPrefix}:{clusterId}:tenant:{tenantId}";

        // A cached null means "known to have no dedicated destination".
        if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
        {
            _logger.LogDebug("Cache hit for tenant destination: {CacheKey}", cacheKey);
            return cachedDestination;
        }

        var destinations = await LoadActiveDestinationsAsync(clusterId, cancellationToken);
        if (destinations == null)
        {
            return null;
        }

        var tenantDestination = destinations.FirstOrDefault(d =>
            !string.IsNullOrEmpty(d.TenantCode) &&
            d.TenantCode.Equals(tenantId, StringComparison.OrdinalIgnoreCase));

        CacheDestination(cacheKey, tenantDestination);
        _logger.LogDebug("Cached tenant destination for key: {CacheKey}", cacheKey);

        return tenantDestination;
    }

    /// <summary>
    /// Finds the cluster's default destination: the first active destination with an empty or null
    /// <c>TenantCode</c>. Only a found destination is cached.
    /// </summary>
    private async Task<GwDestination?> FindDefaultDestinationAsync(
        string clusterId, CancellationToken cancellationToken = default)
    {
        var cacheKey = $"{CacheKeyPrefix}:{clusterId}:default";

        if (_cache.TryGetValue(cacheKey, out GwDestination? cachedDestination))
        {
            _logger.LogDebug("Cache hit for default destination: {CacheKey}", cacheKey);
            return cachedDestination;
        }

        var destinations = await LoadActiveDestinationsAsync(clusterId, cancellationToken);
        var defaultDestination = destinations?.FirstOrDefault(d => string.IsNullOrEmpty(d.TenantCode));

        if (defaultDestination != null)
        {
            CacheDestination(cacheKey, defaultDestination);
            _logger.LogDebug("Cached default destination for key: {CacheKey}", cacheKey);
        }

        return defaultDestination;
    }

    /// <summary>
    /// Loads the active (<c>Status == 1</c>) destinations of an active cluster from the database.
    /// </summary>
    /// <returns>The active destinations, or <c>null</c> when the cluster is not found or disabled.</returns>
    private async Task<List<GwDestination>?> LoadActiveDestinationsAsync(
        string clusterId, CancellationToken cancellationToken)
    {
        await using var dbContext = await _dbContextFactory.CreateDbContextAsync(cancellationToken);

        // NOTE(review): there is no .Include(c => c.Destinations) here. This only works if Destinations is
        // loaded with the cluster automatically (e.g. an owned collection or auto-include). If it is a plain
        // navigation, it will be empty — confirm against the GatewayDbContext model configuration.
        var cluster = await dbContext.GwClusters
            .AsNoTracking()
            .FirstOrDefaultAsync(c => c.ClusterId == clusterId && c.Status == 1, cancellationToken);

        if (cluster == null)
        {
            _logger.LogWarning("Cluster {ClusterId} not found or disabled", clusterId);
            return null;
        }

        return cluster.Destinations
            .Where(d => d.Status == 1)
            .ToList();
    }

    /// <summary>
    /// Stores a lookup result (possibly <c>null</c>) with the shared expiration and a size of 1.
    /// </summary>
    private void CacheDestination(string cacheKey, GwDestination? destination)
    {
        var cacheOptions = new MemoryCacheEntryOptions()
            .SetAbsoluteExpiration(CacheExpiration)
            .SetSize(1);

        _cache.Set(cacheKey, destination, cacheOptions);
    }

    /// <summary>
    /// Routes the request to the cluster's default destination. If none exists, it logs a warning and
    /// leaves the request URI untouched.
    /// </summary>
    private async ValueTask RouteToDefaultDestinationAsync(
        RequestTransformContext context, string clusterId, CancellationToken cancellationToken = default)
    {
        var defaultDestination = await FindDefaultDestinationAsync(clusterId, cancellationToken);
        if (defaultDestination != null)
        {
            _logger.LogDebug("Using default destination {DestinationId} ({Address}) for cluster {ClusterId}",
                defaultDestination.DestinationId, defaultDestination.Address, clusterId);
            ApplyDestination(context, defaultDestination.Address);
        }
        else
        {
            _logger.LogWarning("No default destination found for cluster {ClusterId}", clusterId);
        }
    }

    /// <summary>
    /// Points the outgoing request at <paramref name="address"/>, keeping the incoming request path and
    /// query string (the same way YARP builds the URI from a destination prefix).
    /// </summary>
    private static void ApplyDestination(RequestTransformContext context, string address)
    {
        var prefix = BuildDestinationPrefix(context.DestinationPrefix, address);
        context.ProxyRequest.RequestUri =
            RequestUtilities.MakeDestinationAddress(prefix, context.Path, context.Query.QueryString);
    }

    /// <summary>
    /// Resolves a destination address to a base URI string.
    /// </summary>
    /// <remarks>
    /// An absolute http/https address is used as-is. Any other address is treated as a path relative
    /// to <paramref name="destinationPrefix"/> (the YARP-selected destination).
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// The address is not absolute and there is no prefix to combine it with.
    /// </exception>
    private static string BuildDestinationPrefix(string? destinationPrefix, string address)
    {
        // Restrict to http(s): on Unix, Uri.TryCreate treats "/path" as an absolute file:// URI,
        // which would otherwise bypass the relative-address branch below.
        if (Uri.TryCreate(address, UriKind.Absolute, out var absoluteUri) &&
            (absoluteUri.Scheme == Uri.UriSchemeHttp || absoluteUri.Scheme == Uri.UriSchemeHttps))
        {
            return absoluteUri.AbsoluteUri;
        }

        if (!string.IsNullOrEmpty(destinationPrefix))
        {
            var baseUri = destinationPrefix.TrimEnd('/');
            var path = address.StartsWith('/') ? address : "/" + address;
            return baseUri + path;
        }

        throw new InvalidOperationException($"Cannot build valid URI from prefix '{destinationPrefix}' and address '{address}'");
    }
}