diff --git a/server/ai_session.go b/server/ai_session.go index 42dddbde92..62512c144b 100644 --- a/server/ai_session.go +++ b/server/ai_session.go @@ -222,10 +222,10 @@ func NewAISessionSelector(ctx context.Context, cap core.Capability, modelID stri var warmSel, coldSel BroadcastSessionsSelector var autoClear bool if cap == core.Capability_LiveVideoToVideo { - // For Realtime Video AI, we don't use any features of MinLSSelector (preferring known sessions, etc.), - // We always select a fresh session which has the lowest initial latency - warmSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, warmCaps) - coldSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, coldCaps) + // For Realtime Video AI, we use a dedicated selection algorithm + selAlg := LiveSelectionAlgorithm{} + warmSel = NewSelector(stakeRdr, selAlg, node.OrchPerfScore, warmCaps) + coldSel = NewSelector(stakeRdr, selAlg, node.OrchPerfScore, coldCaps) // we don't use penalties for not in Realtime Video AI penalty = 0 // Automatically clear the session pool from old sessions during the discovery diff --git a/server/selection_algorithm.go b/server/selection_algorithm.go index 8ea304002b..be9ab83d11 100644 --- a/server/selection_algorithm.go +++ b/server/selection_algorithm.go @@ -58,6 +58,17 @@ func (sa ProbabilitySelectionAlgorithm) filterByPerfScore(ctx context.Context, a } func (sa ProbabilitySelectionAlgorithm) filterByMaxPrice(ctx context.Context, addrs []ethcommon.Address, maxPrice *big.Rat, prices map[ethcommon.Address]*big.Rat) []ethcommon.Address { + res := filterByMaxPrice(ctx, addrs, maxPrice, prices) + if len(res) == 0 && sa.IgnoreMaxPriceIfNeeded { + // If no orchestrators pass the filter, return all Orchestrators + // It means that no orchestrators are below the max price + clog.Warningf(ctx, "No Orchestrators passed max price filter, not using the filter, numAddrs=%d, maxPrice=%v, prices=%v, addrs=%v", len(addrs), maxPrice, prices, addrs) + return addrs + } + return res +} + +func filterByMaxPrice(ctx context.Context, addrs []ethcommon.Address, maxPrice *big.Rat, prices map[ethcommon.Address]*big.Rat) []ethcommon.Address { if maxPrice == nil || len(prices) == 0 { // Max price filter not defined, return all Orchestrators return addrs @@ -70,13 +81,6 @@ func (sa ProbabilitySelectionAlgorithm) filterByMaxPrice(ctx context.Context, ad res = append(res, addr) } } - - if len(res) == 0 && sa.IgnoreMaxPriceIfNeeded { - // If no orchestrators pass the filter, return all Orchestrators - // It means that no orchestrators are below the max price - clog.Warningf(ctx, "No Orchestrators passed max price filter, not using the filter, numAddrs=%d, maxPrice=%v, prices=%v, addrs=%v", len(addrs), maxPrice, prices, addrs) - return addrs - } return res } @@ -137,3 +141,15 @@ func selectBy(probabilities map[ethcommon.Address]float64) ethcommon.Address { // number precision return addrs[0] } + +// LiveSelectionAlgorithm is the Selection Algorithm used for Realtime Video AI +type LiveSelectionAlgorithm struct{} + +func (sa LiveSelectionAlgorithm) Select(ctx context.Context, addrs []ethcommon.Address, stakes map[ethcommon.Address]int64, maxPrice *big.Rat, prices map[ethcommon.Address]*big.Rat, perfScores map[ethcommon.Address]float64) ethcommon.Address { + filtered := filterByMaxPrice(ctx, addrs, maxPrice, prices) + if len(filtered) == 0 { + return ethcommon.Address{} + } + // Return the first address that satisfies the max price filter + return filtered[0] +} diff --git a/server/selection_algorithm_test.go b/server/selection_algorithm_test.go index 61b966b42e..ba62eb3697 100644 --- a/server/selection_algorithm_test.go +++ b/server/selection_algorithm_test.go @@ -342,3 +342,69 @@ func TestSelectByProbability(t *testing.T) { require.InDelta(t, prob, selectedRatio, 0.01) } } + +func TestLiveSelectionAlgorithm(t *testing.T) { + tests := []struct { + name string + maxPrice float64 + prices map[string]float64 + orchestrators []string + want string + }{ + { + name: "First Orchestrator with price below maxPrice", + maxPrice: 2000, + prices: map[string]float64{ + "0x0000000000000000000000000000000000000002": 2500, + "0x0000000000000000000000000000000000000003": 500, + "0x0000000000000000000000000000000000000004": 1000, + }, + orchestrators: []string{ + "0x0000000000000000000000000000000000000001", + "0x0000000000000000000000000000000000000002", + "0x0000000000000000000000000000000000000003", + "0x0000000000000000000000000000000000000004", + }, + want: "0x0000000000000000000000000000000000000003", + }, + { + name: "No Orchestrator with price below maxPrice", + maxPrice: 2000, + prices: map[string]float64{ + "0x0000000000000000000000000000000000000002": 2500, + "0x0000000000000000000000000000000000000003": 3500, + "0x0000000000000000000000000000000000000004": 4000, + }, + orchestrators: []string{ + "0x0000000000000000000000000000000000000001", + "0x0000000000000000000000000000000000000002", + "0x0000000000000000000000000000000000000003", + "0x0000000000000000000000000000000000000004", + }, + want: "0x0000000000000000000000000000000000000000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var addrs []ethcommon.Address + var maxPrice *big.Rat + prices := map[ethcommon.Address]*big.Rat{} + perfScores := map[ethcommon.Address]float64{} + for _, o := range tt.orchestrators { + addr := ethcommon.HexToAddress(o) + addrs = append(addrs, addr) + if price, ok := tt.prices[o]; ok { + prices[addr] = new(big.Rat).SetFloat64(price) + } + } + if tt.maxPrice > 0 { + maxPrice = new(big.Rat).SetFloat64(tt.maxPrice) + } + sa := &LiveSelectionAlgorithm{} + + res := sa.Select(context.Background(), addrs, nil, maxPrice, prices, perfScores) + require.Equal(t, ethcommon.HexToAddress(tt.want), res) + }) + } +}