Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions server/ai_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,10 @@ func NewAISessionSelector(ctx context.Context, cap core.Capability, modelID stri
var warmSel, coldSel BroadcastSessionsSelector
var autoClear bool
if cap == core.Capability_LiveVideoToVideo {
// For Realtime Video AI, we don't use any features of MinLSSelector (preferring known sessions, etc.),
// We always select a fresh session which has the lowest initial latency
warmSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, warmCaps)
coldSel = NewSelector(stakeRdr, node.SelectionAlgorithm, node.OrchPerfScore, coldCaps)
// For Realtime Video AI, we use a dedicated selection algorithm
selAlg := LiveSelectionAlgorithm{}
warmSel = NewSelector(stakeRdr, selAlg, node.OrchPerfScore, warmCaps)
coldSel = NewSelector(stakeRdr, selAlg, node.OrchPerfScore, coldCaps)
// we don't use penalties for not in Realtime Video AI
penalty = 0
// Automatically clear the session pool from old sessions during the discovery
Expand Down
30 changes: 23 additions & 7 deletions server/selection_algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ func (sa ProbabilitySelectionAlgorithm) filterByPerfScore(ctx context.Context, a
}

func (sa ProbabilitySelectionAlgorithm) filterByMaxPrice(ctx context.Context, addrs []ethcommon.Address, maxPrice *big.Rat, prices map[ethcommon.Address]*big.Rat) []ethcommon.Address {
res := filterByMaxPrice(ctx, addrs, maxPrice, prices)
if len(res) == 0 && sa.IgnoreMaxPriceIfNeeded {
// If no orchestrators pass the filter, return all Orchestrators
// It means that no orchestrators are below the max price
clog.Warningf(ctx, "No Orchestrators passed max price filter, not using the filter, numAddrs=%d, maxPrice=%v, prices=%v, addrs=%v", len(addrs), maxPrice, prices, addrs)
return addrs
}
return res
}

func filterByMaxPrice(ctx context.Context, addrs []ethcommon.Address, maxPrice *big.Rat, prices map[ethcommon.Address]*big.Rat) []ethcommon.Address {
if maxPrice == nil || len(prices) == 0 {
// Max price filter not defined, return all Orchestrators
return addrs
Expand All @@ -70,13 +81,6 @@ func (sa ProbabilitySelectionAlgorithm) filterByMaxPrice(ctx context.Context, ad
res = append(res, addr)
}
}

if len(res) == 0 && sa.IgnoreMaxPriceIfNeeded {
// If no orchestrators pass the filter, return all Orchestrators
// It means that no orchestrators are below the max price
clog.Warningf(ctx, "No Orchestrators passed max price filter, not using the filter, numAddrs=%d, maxPrice=%v, prices=%v, addrs=%v", len(addrs), maxPrice, prices, addrs)
return addrs
}
return res
}

Expand Down Expand Up @@ -137,3 +141,15 @@ func selectBy(probabilities map[ethcommon.Address]float64) ethcommon.Address {
// number precision
return addrs[0]
}

// LiveSelectionAlgorithm is the Selection Algorithm used for Realtime Video AI
type LiveSelectionAlgorithm struct{}

func (sa LiveSelectionAlgorithm) Select(ctx context.Context, addrs []ethcommon.Address, stakes map[ethcommon.Address]int64, maxPrice *big.Rat, prices map[ethcommon.Address]*big.Rat, perfScores map[ethcommon.Address]float64) ethcommon.Address {
filtered := filterByMaxPrice(ctx, addrs, maxPrice, prices)
if len(filtered) == 0 {
return ethcommon.Address{}
}
// Return the first address that satisfies the max price filter
return filtered[0]
}
66 changes: 66 additions & 0 deletions server/selection_algorithm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,69 @@ func TestSelectByProbability(t *testing.T) {
require.InDelta(t, prob, selectedRatio, 0.01)
}
}

func TestLiveSelectionAlgorithm(t *testing.T) {
tests := []struct {
name string
maxPrice float64
prices map[string]float64
orchestrators []string
want string
}{
{
name: "First Orchestrator with price below maxPrice",
maxPrice: 2000,
prices: map[string]float64{
"0x0000000000000000000000000000000000000002": 2500,
"0x0000000000000000000000000000000000000003": 500,
"0x0000000000000000000000000000000000000004": 1000,
},
orchestrators: []string{
"0x0000000000000000000000000000000000000001",
"0x0000000000000000000000000000000000000002",
"0x0000000000000000000000000000000000000003",
"0x0000000000000000000000000000000000000004",
},
want: "0x0000000000000000000000000000000000000003",
},
{
name: "No Orchestrator with price below maxPrice",
maxPrice: 2000,
prices: map[string]float64{
"0x0000000000000000000000000000000000000002": 2500,
"0x0000000000000000000000000000000000000003": 3500,
"0x0000000000000000000000000000000000000004": 4000,
},
orchestrators: []string{
"0x0000000000000000000000000000000000000001",
"0x0000000000000000000000000000000000000002",
"0x0000000000000000000000000000000000000003",
"0x0000000000000000000000000000000000000004",
},
want: "0x0000000000000000000000000000000000000000",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var addrs []ethcommon.Address
var maxPrice *big.Rat
prices := map[ethcommon.Address]*big.Rat{}
perfScores := map[ethcommon.Address]float64{}
for _, o := range tt.orchestrators {
addr := ethcommon.HexToAddress(o)
addrs = append(addrs, addr)
if price, ok := tt.prices[o]; ok {
prices[addr] = new(big.Rat).SetFloat64(price)
}
}
if tt.maxPrice > 0 {
maxPrice = new(big.Rat).SetFloat64(tt.maxPrice)
}
sa := &LiveSelectionAlgorithm{}

res := sa.Select(context.Background(), addrs, nil, maxPrice, prices, perfScores)
require.Equal(t, ethcommon.HexToAddress(tt.want), res)
})
}
}
Loading