using System.Globalization;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using StackExchange.Redis;
using Yarp.ReverseProxy.LoadBalancing;
using Yarp.ReverseProxy.Model;
using YarpGateway.Config;

namespace YarpGateway.LoadBalancing;

/// <summary>
/// Smooth weighted round-robin load balancing whose per-cluster state is shared through Redis,
/// so that several gateway instances cooperate on one rotation.
/// </summary>
/// <remarks>
/// Each destination's weight is read from its <c>Weight</c> metadata entry (default 1; negative or
/// unparsable values fall back to the default, 0 means "only if every weight is 0").
/// A selection takes a short-lived Redis lock, loads the shared state, runs one smooth-WRR step and
/// writes the state back. If the lock is busy or Redis fails, a local weighted-random pick is used
/// instead, so the balancer itself never fails a request.
/// </remarks>
public class DistributedWeightedRoundRobinPolicy : ILoadBalancingPolicy
{
    /// <summary>Maximum time the per-cluster lock is held before Redis expires it.</summary>
    private static readonly TimeSpan LockTimeout = TimeSpan.FromMilliseconds(500);

    /// <summary>Expiry of the shared state; refreshed on every write.</summary>
    private static readonly TimeSpan StateTtl = TimeSpan.FromHours(1);

    /// <summary>Weight used when a destination has no valid <c>Weight</c> metadata.</summary>
    private const int DefaultWeight = 1;

    // Compare-and-delete: only release the lock if this call still owns it (it may have expired
    // and been re-acquired by another instance in the meantime).
    private const string ReleaseLockScript =
        "if redis.call('GET', KEYS[1]) == ARGV[1] then return redis.call('DEL', KEYS[1]) else return 0 end";

    // Cached and shared by reads and writes; building options per call discards serializer metadata caches.
    private static readonly JsonSerializerOptions SerializerOptions = new()
    {
        PropertyNameCaseInsensitive = true,
    };

    private readonly IConnectionMultiplexer _redis;
    private readonly RedisConfig _config;
    private readonly ILogger _logger;

    /// <inheritdoc />
    public string Name => "DistributedWeightedRoundRobin";

    /// <summary>Creates the policy.</summary>
    /// <param name="redis">Connection used for the lock and the shared state.</param>
    /// <param name="config">Supplies <c>InstanceName</c>, used to namespace the Redis keys.</param>
    /// <param name="logger">Diagnostic logger.</param>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    public DistributedWeightedRoundRobinPolicy(
        IConnectionMultiplexer redis,
        RedisConfig config,
        ILogger<DistributedWeightedRoundRobinPolicy> logger)
    {
        ArgumentNullException.ThrowIfNull(redis);
        ArgumentNullException.ThrowIfNull(config);
        ArgumentNullException.ThrowIfNull(logger);

        _redis = redis;
        _config = config;
        _logger = logger;
    }

    /// <summary>
    /// Picks the next destination for <paramref name="cluster"/>.
    /// </summary>
    /// <returns>
    /// <c>null</c> when no destination is available; otherwise the smooth-WRR choice, or a local
    /// weighted-random choice when the shared state cannot be used.
    /// </returns>
    /// <remarks>
    /// YARP's policy API is synchronous, so the Redis round trips here block the calling thread.
    /// </remarks>
    public DestinationState? PickDestination(
        HttpContext context,
        ClusterState cluster,
        IReadOnlyList<DestinationState> availableDestinations)
    {
        if (availableDestinations.Count == 0)
        {
            return null;
        }

        if (availableDestinations.Count == 1)
        {
            return availableDestinations[0];
        }

        var clusterId = cluster.ClusterId;
        var lockKey = $"lock:{_config.InstanceName}:{clusterId}";
        var stateKey = $"lb:{_config.InstanceName}:{clusterId}:state";
        var lockValue = Guid.NewGuid().ToString();

        IDatabase db;
        bool lockAcquired;
        try
        {
            db = _redis.GetDatabase();
            lockAcquired = db.StringSet(lockKey, lockValue, LockTimeout, When.NotExists);
        }
        catch (Exception ex)
        {
            // Redis unavailable or timing out: degrade to local selection instead of failing the request.
            _logger.LogWarning(
                ex,
                "Could not acquire load-balancing lock for cluster {Cluster}, using fallback selection",
                clusterId);
            return FallbackSelection(availableDestinations);
        }

        if (!lockAcquired)
        {
            _logger.LogDebug(
                "Lock busy for cluster {Cluster}, using fallback selection",
                clusterId);
            return FallbackSelection(availableDestinations);
        }

        try
        {
            var state = GetOrCreateLoadBalancingState(db, stateKey, availableDestinations);
            var selectedDestination = SelectByWeight(state, availableDestinations);
            UpdateCurrentWeights(db, stateKey, state, selectedDestination);

            _logger.LogDebug(
                "Selected {Destination} for cluster {Cluster}",
                selectedDestination?.DestinationId,
                clusterId);

            return selectedDestination;
        }
        catch (Exception ex)
        {
            _logger.LogError(
                ex,
                "Error in distributed load balancing for cluster {Cluster}",
                clusterId);

            // Spread load by weight rather than pinning every failed request to the first destination.
            return FallbackSelection(availableDestinations);
        }
        finally
        {
            ReleaseLock(db, lockKey, lockValue, clusterId);
        }
    }

    /// <summary>
    /// Releases the lock if still owned. Failures are logged, not thrown, so they cannot replace
    /// the selection result; the lock then simply expires after <see cref="LockTimeout"/>.
    /// </summary>
    private void ReleaseLock(IDatabase db, string lockKey, string lockValue, string clusterId)
    {
        try
        {
            db.ScriptEvaluate(
                ReleaseLockScript,
                new RedisKey[] { lockKey },
                new RedisValue[] { lockValue });
        }
        catch (Exception ex)
        {
            _logger.LogWarning(
                ex,
                "Failed to release load-balancing lock for cluster {Cluster}; it will expire on its own",
                clusterId);
        }
    }

    /// <summary>
    /// Loads the shared state from <paramref name="stateKey"/>, or builds a fresh all-zero state when
    /// none exists, it cannot be parsed, or it was built for a different destination/weight set.
    /// </summary>
    /// <remarks>A fresh state is not written here; <see cref="UpdateCurrentWeights"/> persists it.</remarks>
    private LoadBalancingState GetOrCreateLoadBalancingState(
        IDatabase db,
        string stateKey,
        IReadOnlyList<DestinationState> destinations)
    {
        var configHash = ComputeConfigHash(destinations);

        var existingState = db.StringGet(stateKey);
        if (existingState.HasValue)
        {
            LoadBalancingState? parsedState = null;
            try
            {
                parsedState = JsonSerializer.Deserialize<LoadBalancingState>(
                    existingState.ToString(),
                    SerializerOptions);
            }
            catch (JsonException ex)
            {
                // Corrupt state would otherwise fail every request until the TTL expires; rebuild it instead.
                _logger.LogWarning(ex, "Discarding unreadable load-balancing state at {StateKey}", stateKey);
            }

            if (parsedState is { CurrentWeights: not null } && parsedState.ConfigHash == configHash)
            {
                return parsedState;
            }
        }

        var newState = new LoadBalancingState { ConfigHash = configHash };
        foreach (var dest in destinations)
        {
            newState.CurrentWeights[dest.DestinationId] = 0;
        }

        return newState;
    }

    /// <summary>
    /// Computes a fingerprint of the destination ids and weights that is identical in every process.
    /// </summary>
    /// <remarks>
    /// <see cref="string.GetHashCode()"/> and <see cref="HashCode"/> are randomized per process, so they
    /// cannot be compared across gateway instances; a 64-bit FNV-1a hash over an ordinal-sorted
    /// canonical string is used instead.
    /// </remarks>
    private static long ComputeConfigHash(IReadOnlyList<DestinationState> destinations)
    {
        const ulong FnvOffsetBasis = 14695981039346656037UL;
        const ulong FnvPrime = 1099511628211UL;

        var hash = FnvOffsetBasis;
        foreach (var dest in destinations.OrderBy(d => d.DestinationId, StringComparer.Ordinal))
        {
            // Unit/record separators keep e.g. ("a1", 2) distinct from ("a", 12).
            var entry = string.Concat(
                dest.DestinationId,
                "\u001F",
                GetWeight(dest).ToString(CultureInfo.InvariantCulture),
                "\u001E");

            foreach (var ch in entry)
            {
                hash = unchecked((hash ^ ch) * FnvPrime);
            }
        }

        return unchecked((long)hash);
    }

    /// <summary>Persists <paramref name="state"/> (refreshing its TTL) when a destination was selected.</summary>
    private static void UpdateCurrentWeights(
        IDatabase db,
        string stateKey,
        LoadBalancingState state,
        DestinationState? selected)
    {
        if (selected is null)
        {
            return;
        }

        var json = JsonSerializer.Serialize(state, SerializerOptions);
        db.StringSet(stateKey, json, StateTtl);
    }

    /// <summary>
    /// One smooth weighted round-robin step: add each destination's weight to its current weight,
    /// pick the largest (first wins ties), then subtract the total weight from the winner.
    /// Mutates <paramref name="state"/> in place.
    /// </summary>
    private static DestinationState? SelectByWeight(
        LoadBalancingState state,
        IReadOnlyList<DestinationState> destinations)
    {
        var currentWeights = state.CurrentWeights;
        DestinationState? selected = null;
        var maxWeight = int.MinValue;
        var totalWeight = 0;

        foreach (var dest in destinations)
        {
            var weight = GetWeight(dest);
            currentWeights.TryGetValue(dest.DestinationId, out var currentWeight);

            var newWeight = currentWeight + weight;
            currentWeights[dest.DestinationId] = newWeight;
            totalWeight += weight;

            if (newWeight > maxWeight)
            {
                maxWeight = newWeight;
                selected = dest;
            }
        }

        if (selected is not null)
        {
            currentWeights[selected.DestinationId] = maxWeight - totalWeight;
        }

        return selected;
    }

    /// <summary>
    /// Local weighted-random selection used when the shared state is unavailable.
    /// Honors the same weights as the distributed path; uniform when all weights are 0.
    /// </summary>
    private static DestinationState? FallbackSelection(IReadOnlyList<DestinationState> destinations)
    {
        if (destinations.Count == 0)
        {
            return null;
        }

        var weights = new int[destinations.Count];
        long totalWeight = 0;
        for (var i = 0; i < weights.Length; i++)
        {
            weights[i] = GetWeight(destinations[i]);
            totalWeight += weights[i];
        }

        if (totalWeight <= 0)
        {
            return destinations[Random.Shared.Next(destinations.Count)];
        }

        var ticket = Random.Shared.NextInt64(totalWeight);
        for (var i = 0; i < weights.Length; i++)
        {
            ticket -= weights[i];
            if (ticket < 0)
            {
                return destinations[i];
            }
        }

        // Unreachable: ticket < totalWeight guarantees a hit above.
        return destinations[^1];
    }

    /// <summary>
    /// Reads the destination's <c>Weight</c> metadata as an invariant-culture integer.
    /// Missing, unparsable, or negative values yield <see cref="DefaultWeight"/>.
    /// </summary>
    private static int GetWeight(DestinationState destination)
    {
        if (
            destination.Model?.Config?.Metadata?.TryGetValue("Weight", out var weightStr) == true
            && int.TryParse(weightStr, NumberStyles.Integer, CultureInfo.InvariantCulture, out var weight)
            && weight >= 0
        )
        {
            return weight;
        }

        return DefaultWeight;
    }

    /// <summary>Shared smooth-WRR state stored as JSON in Redis.</summary>
    private sealed class LoadBalancingState
    {
        /// <summary>Fingerprint of the destination/weight set this state belongs to.</summary>
        public long ConfigHash { get; set; }

        /// <summary>Current smooth-WRR weight per destination id.</summary>
        public Dictionary<string, int> CurrentWeights { get; set; } = new();
    }
}