// fengling-gateway/LoadBalancing/DistributedWeightedRoundRobinPolicy.cs

using System.Globalization;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using StackExchange.Redis;
using Yarp.ReverseProxy.LoadBalancing;
using Yarp.ReverseProxy.Model;
using YarpGateway.Config;
namespace YarpGateway.LoadBalancing;
/// <summary>
/// YARP load-balancing policy implementing smooth weighted round-robin whose rotation state
/// is shared through Redis, so several gateway instances advance one common rotation per cluster.
/// </summary>
/// <remarks>
/// Each pick takes a short-lived Redis lock (SET NX with expiry), reads the per-cluster state,
/// advances it and writes it back. If the lock is held by another request, or any Redis /
/// serialization error occurs, a random destination is chosen locally instead so proxying
/// never fails because of Redis.
/// A destination's weight is read from its <c>Weight</c> metadata entry (default 1).
/// NOTE(review): every Redis-backed pick costs four synchronous round trips (lock, read,
/// write, unlock); under concurrency most picks will take the fallback path. A single Lua
/// script performing the whole update atomically would avoid both — consider if latency matters.
/// </remarks>
public class DistributedWeightedRoundRobinPolicy : ILoadBalancingPolicy
{
    // Expiry on the lock key bounds how long a holder that never releases can block others.
    private static readonly TimeSpan LockTimeout = TimeSpan.FromMilliseconds(500);

    // Expiry on the state key; it is refreshed on every successful pick.
    private static readonly TimeSpan StateTtl = TimeSpan.FromHours(1);

    // Weight used when the metadata entry is missing, unparsable or negative.
    private const int DefaultWeight = 1;

    // Deletes the lock only if it still holds our token, so a lock that expired and was
    // re-acquired by another request is never released by us.
    private const string ReleaseLockScript =
        @"
if redis.call('GET', KEYS[1]) == ARGV[1] then
    return redis.call('DEL', KEYS[1])
else
    return 0
end";

    // Shared serializer options (allocating options per call defeats System.Text.Json's
    // metadata cache). Case-insensitive reading; the written format is the default one.
    private static readonly JsonSerializerOptions StateJsonOptions = new()
    {
        PropertyNameCaseInsensitive = true,
    };

    // FNV-1a 64-bit parameters for the config fingerprint. Unlike string.GetHashCode and
    // System.HashCode, this is identical in every process, which the shared state requires.
    private const ulong FnvOffsetBasis = 14695981039346656037UL;
    private const ulong FnvPrime = 1099511628211UL;

    private readonly IConnectionMultiplexer _redis;
    private readonly RedisConfig _config;
    private readonly ILogger<DistributedWeightedRoundRobinPolicy> _logger;

    /// <summary>Policy name used to reference this policy from cluster configuration.</summary>
    public string Name => "DistributedWeightedRoundRobin";

    /// <summary>Creates the policy.</summary>
    /// <param name="redis">Connection used for the lock and the rotation state.</param>
    /// <param name="config">Supplies <c>InstanceName</c>, which prefixes the Redis keys.</param>
    /// <param name="logger">Logger for selection and failure diagnostics.</param>
    public DistributedWeightedRoundRobinPolicy(
        IConnectionMultiplexer redis,
        RedisConfig config,
        ILogger<DistributedWeightedRoundRobinPolicy> logger
    )
    {
        _redis = redis;
        _config = config;
        _logger = logger;
    }

    /// <summary>
    /// Picks the next destination for <paramref name="cluster"/> from the Redis-shared rotation.
    /// </summary>
    /// <param name="context">The current request (not used by this policy).</param>
    /// <param name="cluster">Cluster being proxied; its id scopes the Redis keys.</param>
    /// <param name="availableDestinations">Candidates to choose from.</param>
    /// <returns>
    /// <c>null</c> when there are no candidates; the only candidate when there is one (no Redis
    /// access); otherwise the weighted round-robin choice, or a random candidate when the lock
    /// is busy or Redis/serialization fails.
    /// </returns>
    public DestinationState? PickDestination(
        HttpContext context,
        ClusterState cluster,
        IReadOnlyList<DestinationState> availableDestinations
    )
    {
        if (availableDestinations.Count == 0)
        {
            return null;
        }
        if (availableDestinations.Count == 1)
        {
            return availableDestinations[0];
        }

        var clusterId = cluster.ClusterId;
        var db = _redis.GetDatabase();
        var lockKey = $"lock:{_config.InstanceName}:{clusterId}";
        var stateKey = $"lb:{_config.InstanceName}:{clusterId}:state";
        // Unique token so only this request can release the lock it took.
        var lockValue = Guid.NewGuid().ToString();

        bool lockAcquired;
        try
        {
            lockAcquired = db.StringSet(lockKey, lockValue, LockTimeout, When.NotExists);
        }
        catch (Exception ex)
        {
            // Redis unreachable: degrade to local selection rather than failing the request.
            _logger.LogWarning(
                ex,
                "Could not acquire load-balancing lock for cluster {Cluster}, using fallback selection",
                clusterId
            );
            return FallbackSelection(availableDestinations);
        }

        if (!lockAcquired)
        {
            _logger.LogDebug(
                "Lock busy for cluster {Cluster}, using fallback selection",
                clusterId
            );
            return FallbackSelection(availableDestinations);
        }

        try
        {
            var state = GetOrCreateLoadBalancingState(db, stateKey, availableDestinations);
            var selectedDestination = SelectByWeight(state, availableDestinations);
            UpdateCurrentWeights(db, stateKey, state, selectedDestination);
            _logger.LogDebug(
                "Selected {Destination} for cluster {Cluster}",
                selectedDestination?.DestinationId,
                clusterId
            );
            return selectedDestination;
        }
        catch (Exception ex)
        {
            _logger.LogError(
                ex,
                "Error in distributed load balancing for cluster {Cluster}",
                clusterId
            );
            // Spread load instead of pinning every failed pick to the first destination.
            return FallbackSelection(availableDestinations);
        }
        finally
        {
            ReleaseLock(db, lockKey, lockValue, clusterId);
        }
    }

    /// <summary>
    /// Releases the lock if it is still ours. Failures are logged, never thrown: an exception
    /// here would replace the already-chosen destination, and the lock expires on its own.
    /// </summary>
    private void ReleaseLock(IDatabase db, string lockKey, string lockValue, string clusterId)
    {
        try
        {
            db.ScriptEvaluate(
                ReleaseLockScript,
                new RedisKey[] { lockKey },
                new RedisValue[] { lockValue }
            );
        }
        catch (Exception ex)
        {
            _logger.LogWarning(
                ex,
                "Failed to release load-balancing lock for cluster {Cluster}",
                clusterId
            );
        }
    }

    /// <summary>
    /// Loads the stored rotation state, or builds a fresh all-zero one when nothing usable is
    /// stored: missing key, unreadable JSON, missing weights, or a fingerprint that no longer
    /// matches <paramref name="destinations"/>. A fresh state is not written here; the caller
    /// persists it after advancing the rotation.
    /// </summary>
    private LoadBalancingState GetOrCreateLoadBalancingState(
        IDatabase db,
        string stateKey,
        IReadOnlyList<DestinationState> destinations
    )
    {
        var configHash = ComputeConfigHash(destinations);

        var existingState = db.StringGet(stateKey);
        if (existingState.HasValue)
        {
            var parsedState = TryDeserializeState(existingState.ToString(), stateKey);
            if (parsedState is { CurrentWeights: not null } && parsedState.ConfigHash == configHash)
            {
                return parsedState;
            }
        }

        var newState = new LoadBalancingState
        {
            ConfigHash = configHash,
            CurrentWeights = new Dictionary<string, int>(destinations.Count),
        };
        foreach (var dest in destinations)
        {
            newState.CurrentWeights[dest.DestinationId] = 0;
        }
        return newState;
    }

    /// <summary>Parses stored state; returns <c>null</c> (and logs) for malformed JSON.</summary>
    private LoadBalancingState? TryDeserializeState(string json, string stateKey)
    {
        try
        {
            return JsonSerializer.Deserialize<LoadBalancingState>(json, StateJsonOptions);
        }
        catch (JsonException ex)
        {
            _logger.LogWarning(
                ex,
                "Discarding unreadable load-balancing state at {StateKey}",
                stateKey
            );
            return null;
        }
    }

    /// <summary>
    /// Fingerprint of the destination set and weights, independent of input order and identical
    /// across processes. The set passed in is the available set, so a change in it (including
    /// health changes) resets the rotation.
    /// </summary>
    private static long ComputeConfigHash(IReadOnlyList<DestinationState> destinations)
    {
        var hash = FnvOffsetBasis;
        // Ordinal ordering: culture-sensitive ordering could differ between hosts.
        foreach (var dest in destinations.OrderBy(d => d.DestinationId, StringComparer.Ordinal))
        {
            foreach (var ch in dest.DestinationId)
            {
                hash = FnvMix(hash, ch);
            }
            // Marks the end of the id; no UTF-16 code unit has this value.
            hash = FnvMix(hash, 0xFFFF_FFFF);
            hash = FnvMix(hash, unchecked((uint)GetWeight(dest)));
        }
        return unchecked((long)hash);
    }

    /// <summary>One FNV-1a step over a 32-bit word.</summary>
    private static ulong FnvMix(ulong hash, uint value) => unchecked((hash ^ value) * FnvPrime);

    /// <summary>
    /// Persists <paramref name="state"/> (refreshing its TTL) when a destination was selected.
    /// </summary>
    private static void UpdateCurrentWeights(
        IDatabase db,
        string stateKey,
        LoadBalancingState state,
        DestinationState? selected
    )
    {
        if (selected is null)
        {
            return;
        }
        var json = JsonSerializer.Serialize(state, StateJsonOptions);
        db.StringSet(stateKey, json, StateTtl);
    }

    /// <summary>
    /// One step of smooth weighted round-robin. Every destination's current weight grows by its
    /// weight; the largest (first on ties) is selected and reduced by the total weight.
    /// Mutates <paramref name="state"/> in place.
    /// </summary>
    /// <returns>The selected destination, or <c>null</c> if <paramref name="destinations"/> is empty.</returns>
    private static DestinationState? SelectByWeight(
        LoadBalancingState state,
        IReadOnlyList<DestinationState> destinations
    )
    {
        var maxWeight = int.MinValue;
        var totalWeight = 0;
        DestinationState? selected = null;

        foreach (var dest in destinations)
        {
            var weight = GetWeight(dest);
            // Destinations missing from the stored state start at 0.
            state.CurrentWeights.TryGetValue(dest.DestinationId, out var currentWeight);
            var newWeight = currentWeight + weight;
            state.CurrentWeights[dest.DestinationId] = newWeight;
            totalWeight += weight;

            if (newWeight > maxWeight)
            {
                maxWeight = newWeight;
                selected = dest;
            }
        }

        if (selected is not null)
        {
            state.CurrentWeights[selected.DestinationId] = maxWeight - totalWeight;
        }
        return selected;
    }

    /// <summary>
    /// Local selection used when Redis cannot be used. Uniformly random, so concurrent requests
    /// do not all land on the same destination. Requires a non-empty list.
    /// </summary>
    private static DestinationState? FallbackSelection(IReadOnlyList<DestinationState> destinations)
    {
        return destinations[Random.Shared.Next(destinations.Count)];
    }

    /// <summary>
    /// Reads the non-negative integer <c>Weight</c> metadata entry (invariant culture), falling
    /// back to <see cref="DefaultWeight"/>. Negative weights are rejected because they make the
    /// smooth WRR current weights grow without bound.
    /// </summary>
    private static int GetWeight(DestinationState destination)
    {
        if (
            destination.Model?.Config?.Metadata?.TryGetValue("Weight", out var weightStr) == true
            && int.TryParse(weightStr, NumberStyles.Integer, CultureInfo.InvariantCulture, out var weight)
            && weight >= 0
        )
        {
            return weight;
        }
        return DefaultWeight;
    }

    /// <summary>Rotation state stored as JSON in Redis per cluster.</summary>
    private class LoadBalancingState
    {
        /// <summary>Fingerprint of the destination set/weights this state was built for.</summary>
        public long ConfigHash { get; set; }

        /// <summary>Smooth-WRR current weight per destination id.</summary>
        public Dictionary<string, int> CurrentWeights { get; set; } = new();
    }
}