fengling-gateway/tests/YarpGateway.Tests/Unit/Middleware/JwtTransformMiddlewareTests.cs
movingsam 52f4b7616e docs: add security audit and test plan
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-02-28 18:38:38 +08:00

239 lines
7.1 KiB
C#

using System.Security.Claims;
using Microsoft.AspNetCore.Authentication.JwtBearer;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Moq;
using Xunit;
using FluentAssertions;
using YarpGateway.Config;
using YarpGateway.Middleware;
namespace YarpGateway.Tests.Unit.Middleware;
/// <summary>
/// Unit tests for <see cref="JwtTransformMiddleware"/>. Each test builds an
/// <c>HttpContext</c> whose <c>User</c> carries a chosen set of claims, invokes the
/// middleware, and then inspects the request headers it is expected to populate
/// (<c>X-Tenant-Id</c>, <c>X-User-Id</c>, <c>X-User-Name</c>, <c>X-Roles</c>).
/// </summary>
public class JwtTransformMiddlewareTests
{
    private readonly Mock<ILogger<JwtTransformMiddleware>> _loggerMock;
    private readonly JwtConfig _jwtConfig;

    /// <summary>
    /// Prepares the JWT configuration and logger mock shared by the helpers below.
    /// </summary>
    public JwtTransformMiddlewareTests()
    {
        _jwtConfig = new JwtConfig
        {
            Authority = "https://auth.example.com",
            Audience = "yarp-gateway"
        };
        _loggerMock = new Mock<ILogger<JwtTransformMiddleware>>();
    }

    /// <summary>
    /// Creates the middleware under test, wired to a Moq-generated
    /// <c>RequestDelegate</c> stand-in, the fixture's JWT options and the logger mock.
    /// </summary>
    private JwtTransformMiddleware CreateMiddleware()
    {
        var jwtConfigOptions = Options.Create(_jwtConfig);

        return new JwtTransformMiddleware(
            next: Mock.Of<RequestDelegate>(),
            jwtConfig: jwtConfigOptions,
            logger: _loggerMock.Object
        );
    }

    /// <summary>
    /// Creates a context whose user is a single identity holding exactly the given
    /// claims, using the JWT bearer scheme name as the authentication type.
    /// </summary>
    /// <param name="claims">Claims to attach, in order; may be empty.</param>
    private static DefaultHttpContext CreateContextWithClaims(params Claim[] claims)
    {
        var identity = new ClaimsIdentity(claims, JwtBearerDefaults.AuthenticationScheme);

        return new DefaultHttpContext
        {
            User = new ClaimsPrincipal(identity)
        };
    }

    /// <summary>
    /// Creates a context for a typical signed-in user: optional tenant and user-id
    /// claims plus fixed name and role claims.
    /// </summary>
    /// <param name="tenantId">Value for the "tenant" claim; the claim is omitted when null or empty.</param>
    /// <param name="userId">Value for the name-identifier and "sub" claims; both are omitted when null or empty.</param>
    private DefaultHttpContext CreateAuthenticatedContext(string? tenantId = "tenant-1", string? userId = "user-1")
    {
        var claims = new List<Claim>();

        if (!string.IsNullOrEmpty(tenantId))
        {
            claims.Add(new Claim("tenant", tenantId));
        }

        if (!string.IsNullOrEmpty(userId))
        {
            claims.Add(new Claim(ClaimTypes.NameIdentifier, userId));
            claims.Add(new Claim("sub", userId));
        }

        claims.Add(new Claim(ClaimTypes.Name, "testuser"));
        claims.Add(new Claim("name", "Test User"));
        claims.Add(new Claim(ClaimTypes.Role, "admin"));
        claims.Add(new Claim("role", "user"));

        return CreateContextWithClaims(claims.ToArray());
    }

    /// <summary>
    /// Creates a context whose user is a principal with no identities at all.
    /// </summary>
    private DefaultHttpContext CreateUnauthenticatedContext()
    {
        return new DefaultHttpContext
        {
            User = new ClaimsPrincipal()
        };
    }

    /// <summary>The "tenant" claim value should surface in the X-Tenant-Id header.</summary>
    [Fact]
    public async Task InvokeAsync_WithAuthenticatedUser_ShouldExtractTenantClaim()
    {
        // Arrange
        var context = CreateAuthenticatedContext(tenantId: "tenant-123");
        var middleware = CreateMiddleware();

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        context.Request.Headers["X-Tenant-Id"].Should().Contain("tenant-123");
    }

    /// <summary>The user identifier claim value should surface in the X-User-Id header.</summary>
    [Fact]
    public async Task InvokeAsync_WithAuthenticatedUser_ShouldExtractUserId()
    {
        // Arrange
        var context = CreateAuthenticatedContext(userId: "user-456");
        var middleware = CreateMiddleware();

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        context.Request.Headers["X-User-Id"].Should().Contain("user-456");
    }

    /// <summary>The "name" claim value ("Test User") should surface in the X-User-Name header.</summary>
    [Fact]
    public async Task InvokeAsync_WithAuthenticatedUser_ShouldExtractUserName()
    {
        // Arrange
        var context = CreateAuthenticatedContext();
        var middleware = CreateMiddleware();

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        context.Request.Headers["X-User-Name"].Should().Contain("Test User");
    }

    /// <summary>Both role claims should surface, comma-joined, in the X-Roles header.</summary>
    [Fact]
    public async Task InvokeAsync_WithAuthenticatedUser_ShouldExtractRoles()
    {
        // Arrange
        var context = CreateAuthenticatedContext();
        var middleware = CreateMiddleware();

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        context.Request.Headers["X-Roles"].Should().Contain("admin,user");
    }

    /// <summary>A principal without any identity should not get tenant or user headers.</summary>
    [Fact]
    public async Task InvokeAsync_WithUnauthenticatedUser_ShouldNotSetHeaders()
    {
        // Arrange
        var context = CreateUnauthenticatedContext();
        var middleware = CreateMiddleware();

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        context.Request.Headers.Should().NotContainKey("X-Tenant-Id");
        context.Request.Headers.Should().NotContainKey("X-User-Id");
    }

    /// <summary>
    /// An authenticated user without a tenant claim should not get an X-Tenant-Id header.
    /// </summary>
    [Fact]
    public async Task InvokeAsync_WithMissingTenantClaim_ShouldLogWarning()
    {
        // Arrange - the parameter is already nullable, so plain null omits the tenant claim
        var context = CreateAuthenticatedContext(tenantId: null);
        var middleware = CreateMiddleware();

        // Act
        await middleware.InvokeAsync(context);

        // Assert - should not throw, and no tenant header is produced.
        // NOTE(review): despite the test name, the warning itself is not verified here.
        // Consider adding a _loggerMock.Verify(...) once the middleware's log call
        // (level and message) has been confirmed.
        context.Request.Headers.Should().NotContainKey("X-Tenant-Id");
    }

    /// <summary>A "tenant_id" claim alone should also populate the X-Tenant-Id header.</summary>
    [Fact]
    public async Task InvokeAsync_WithTenantClaimUsingTenantIdType_ShouldExtractCorrectly()
    {
        // Arrange
        var context = CreateContextWithClaims(
            new Claim("tenant_id", "tenant-using-id-type"),
            new Claim(ClaimTypes.NameIdentifier, "user-1"));
        var middleware = CreateMiddleware();

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        context.Request.Headers["X-Tenant-Id"].Should().Contain("tenant-using-id-type");
    }

    /// <summary>
    /// A client-supplied X-Tenant-Id header should be superseded by the value taken
    /// from the authenticated user's claims.
    /// </summary>
    [Fact]
    public async Task InvokeAsync_ShouldRemoveExistingXHeaders_PreventHeaderInjection()
    {
        // Arrange
        var context = CreateAuthenticatedContext();
        // Simulate header injection attempt
        context.Request.Headers["X-Tenant-Id"] = "injected-tenant";
        var middleware = CreateMiddleware();

        // Act
        await middleware.InvokeAsync(context);

        // Assert - the injected header should be removed and replaced with JWT value
        context.Request.Headers["X-Tenant-Id"].Should().Contain("tenant-1");
    }

    /// <summary>When both "tenant" and "tenant_id" claims exist, "tenant" should win.</summary>
    [Fact]
    public async Task InvokeAsync_WithMultipleTenantClaims_ShouldPrioritizeTenantType()
    {
        // Arrange
        var context = CreateContextWithClaims(
            new Claim("tenant", "tenant-from-claim"),
            new Claim("tenant_id", "tenant-id-claim"),
            new Claim(ClaimTypes.NameIdentifier, "user-1"));
        var middleware = CreateMiddleware();

        // Act
        await middleware.InvokeAsync(context);

        // Assert - should prioritize "tenant" over "tenant_id"
        context.Request.Headers["X-Tenant-Id"].Should().Contain("tenant-from-claim");
    }

    /// <summary>An identity with an authentication type but no claims must not make the middleware throw.</summary>
    [Fact]
    public async Task InvokeAsync_WithEmptyClaims_ShouldNotThrow()
    {
        // Arrange
        var context = CreateContextWithClaims();
        var middleware = CreateMiddleware();

        // Act
        Func<Task> act = async () => await middleware.InvokeAsync(context);

        // Assert - state the expectation explicitly instead of relying on an implicit pass
        await act.Should().NotThrowAsync();
    }
}