Skip to content
Merged
42 changes: 37 additions & 5 deletions core/ai_orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,18 +398,44 @@
}

// CheckAICapacity verifies if the orchestrator can process a request for a specific pipeline and modelID.
func (orch *orchestrator) CheckAICapacity(pipeline, modelID string) bool {
func (orch *orchestrator) CheckAICapacity(pipeline, modelID string) (bool, chan<- bool) {
var hasCapacity bool
if orch.node.AIWorker != nil {
// confirm local worker has capacity
return orch.node.AIWorker.HasCapacity(pipeline, modelID)
if pipeline == "live-video-to-video" {
return orch.node.AIWorker.HasCapacity(pipeline, modelID), nil
}

Check warning on line 407 in core/ai_orchestrator.go

View check run for this annotation

Codecov / codecov/patch

core/ai_orchestrator.go#L406-L407

Added lines #L406 - L407 were not covered by tests

// batch pipelines manage the capacity at the Orchestrator level to manage local ai-worker capacity
err := orch.node.ReserveAICapability(pipeline, modelID)
if err == nil {
hasCapacity = true
}
} else {
// remote workers: RemoteAIWorkerManager only selects remote workers if they have capacity for the pipeline/model
// live-video-to-video is not using remote workers currently
if orch.node.AIWorkerManager != nil {
return orch.node.AIWorkerManager.workerHasCapacity(pipeline, modelID)
} else {
return false
hasCapacity = orch.node.AIWorkerManager.workerHasCapacity(pipeline, modelID)
}
}

if !hasCapacity {
return false, nil
}

// reserve AI capacity for the pipeline and modelID
releaseCapacity := make(chan bool)

go func() {
<-releaseCapacity
orch.node.ReleaseAICapability(pipeline, modelID)
glog.Infof("Released AI capacity for pipeline=%s model_id=%s", pipeline, modelID)
close(releaseCapacity)

}()

return true, releaseCapacity

}

func (orch *orchestrator) GetLiveAICapacity() worker.Capacity {
Expand Down Expand Up @@ -545,6 +571,7 @@
// local AIWorker processes job if combined orchestrator/ai worker
if orch.node.AIWorker != nil {
workerResp, err := orch.node.TextToImage(ctx, req)

Check warning on line 574 in core/ai_orchestrator.go

View check run for this annotation

Codecov / codecov/patch

core/ai_orchestrator.go#L574

Added line #L574 was not covered by tests
if err == nil {
return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "image/png")
} else {
Expand Down Expand Up @@ -578,6 +605,7 @@
// local AIWorker processes job if combined orchestrator/ai worker
if orch.node.AIWorker != nil {
workerResp, err := orch.node.LiveVideoToVideo(ctx, req)

Check warning on line 608 in core/ai_orchestrator.go

View check run for this annotation

Codecov / codecov/patch

core/ai_orchestrator.go#L608

Added line #L608 was not covered by tests
if err == nil {
return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "application/json")
} else {
Expand Down Expand Up @@ -611,6 +639,7 @@
// local AIWorker processes job if combined orchestrator/ai worker
if orch.node.AIWorker != nil {
workerResp, err := orch.node.ImageToImage(ctx, req)

Check warning on line 642 in core/ai_orchestrator.go

View check run for this annotation

Codecov / codecov/patch

core/ai_orchestrator.go#L642

Added line #L642 was not covered by tests
if err == nil {
return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "image/png")
} else {
Expand Down Expand Up @@ -655,6 +684,7 @@
// local AIWorker processes job if combined orchestrator/ai worker
if orch.node.AIWorker != nil {
workerResp, err := orch.node.ImageToVideo(ctx, req)

Check warning on line 687 in core/ai_orchestrator.go

View check run for this annotation

Codecov / codecov/patch

core/ai_orchestrator.go#L687

Added line #L687 was not covered by tests
if err == nil {
return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "video/mp4")
} else {
Expand Down Expand Up @@ -699,6 +729,7 @@
// local AIWorker processes job if combined orchestrator/ai worker
if orch.node.AIWorker != nil {
workerResp, err := orch.node.Upscale(ctx, req)

Check warning on line 732 in core/ai_orchestrator.go

View check run for this annotation

Codecov / codecov/patch

core/ai_orchestrator.go#L732

Added line #L732 was not covered by tests
if err == nil {
return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "image/png")
} else {
Expand Down Expand Up @@ -880,6 +911,7 @@
// local AIWorker processes job if combined orchestrator/ai worker
if orch.node.AIWorker != nil {
workerResp, err := orch.node.TextToSpeech(ctx, req)

Check warning on line 914 in core/ai_orchestrator.go

View check run for this annotation

Codecov / codecov/patch

core/ai_orchestrator.go#L914

Added line #L914 was not covered by tests
if err == nil {
return orch.node.saveLocalAIWorkerResults(ctx, *workerResp, requestID, "audio/wav")
} else {
Expand Down
10 changes: 7 additions & 3 deletions core/ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,9 @@ func TestCheckAICapacity(t *testing.T) {
n.Capabilities = createAIWorkerCapabilities()
n.AIWorker = &wkr
// Test when local AI worker has capacity
hasCapacity := o.CheckAICapacity("text-to-image", "livepeer/model1")
hasCapacity, releaseCapacity := o.CheckAICapacity("text-to-image", "livepeer/model1")
assert.True(t, hasCapacity)
releaseCapacity <- true

o.node.AIWorker = nil
o.node.AIWorkerManager = NewRemoteAIWorkerManager()
Expand All @@ -534,12 +535,15 @@ func TestCheckAICapacity(t *testing.T) {
}()
time.Sleep(1 * time.Millisecond) // allow the workers to activate

hasCapacity = o.CheckAICapacity("text-to-image", "livepeer/model1")
hasCapacity, releaseCapacity = o.CheckAICapacity("text-to-image", "livepeer/model1")
assert.True(t, hasCapacity)
assert.NotNil(t, releaseCapacity)
releaseCapacity <- true

// Test when remote AI worker does not have capacity
hasCapacity = o.CheckAICapacity("text-to-image", "livepeer/model2")
hasCapacity, releaseCapacity = o.CheckAICapacity("text-to-image", "livepeer/model2")
assert.False(t, hasCapacity)
assert.Nil(t, releaseCapacity)
}
func TestRemoteAIWorkerProcessPipelines(t *testing.T) {
drivers.NodeStorage = drivers.NewMemoryDriver(nil)
Expand Down
19 changes: 15 additions & 4 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,10 @@
}

// Check if there is capacity for the request
if !orch.CheckAICapacity(pipeline, modelID) {
respondWithError(w, fmt.Sprintf("Insufficient capacity for pipeline=%v modelID=%v", pipeline, modelID), http.StatusServiceUnavailable)
hasCapacity, _ := orch.CheckAICapacity(pipeline, modelID)
if !hasCapacity {
clog.Errorf(ctx, "Insufficient capacity for pipeline=%v modelID=%v", pipeline, modelID)
respondWithError(w, "insufficient capacity", http.StatusServiceUnavailable)

Check warning on line 137 in server/ai_http.go

View check run for this annotation

Codecov / codecov/patch

server/ai_http.go#L134-L137

Added lines #L134 - L137 were not covered by tests
return
}

Expand Down Expand Up @@ -497,8 +499,11 @@
manifestID := core.ManifestID(strconv.Itoa(int(cap)) + "_" + modelID)

// Check if there is capacity for the request.
if !orch.CheckAICapacity(pipeline, modelID) {
respondWithError(w, fmt.Sprintf("Insufficient capacity for pipeline=%v modelID=%v", pipeline, modelID), http.StatusServiceUnavailable)
// Capability capacity is reserved if available and released when response is received
hasCapacity, releaseCapacity := orch.CheckAICapacity(pipeline, modelID)
if !hasCapacity {
clog.Errorf(ctx, "Insufficient capacity for pipeline=%v modelID=%v", pipeline, modelID)
respondWithError(w, "insufficient capacity", http.StatusServiceUnavailable)

Check warning on line 506 in server/ai_http.go

View check run for this annotation

Codecov / codecov/patch

server/ai_http.go#L502-L506

Added lines #L502 - L506 were not covered by tests
return
}

Expand Down Expand Up @@ -528,6 +533,7 @@

start := time.Now()
resp, err := submitFn(ctx)

Check warning on line 536 in server/ai_http.go

View check run for this annotation

Codecov / codecov/patch

server/ai_http.go#L536

Added line #L536 was not covered by tests
if err != nil {
if monitor.Enabled {
monitor.AIProcessingError(err.Error(), pipeline, modelID, sender.Hex())
Expand Down Expand Up @@ -614,6 +620,7 @@
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
releaseCapacity <- true

Check warning on line 623 in server/ai_http.go

View check run for this annotation

Codecov / codecov/patch

server/ai_http.go#L623

Added line #L623 was not covered by tests
return
}

Expand All @@ -631,8 +638,12 @@
break
}
}
//release capacity after streaming is done
releaseCapacity <- true

Check warning on line 642 in server/ai_http.go

View check run for this annotation

Codecov / codecov/patch

server/ai_http.go#L642

Added line #L642 was not covered by tests

} else {
// Non-streaming response
releaseCapacity <- true

Check warning on line 646 in server/ai_http.go

View check run for this annotation

Codecov / codecov/patch

server/ai_http.go#L646

Added line #L646 was not covered by tests
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
Expand Down
5 changes: 3 additions & 2 deletions server/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
Sign([]byte) ([]byte, error)
VerifySig(ethcommon.Address, string, []byte) bool
CheckCapacity(core.ManifestID) error
CheckAICapacity(pipeline, modelID string) bool
CheckAICapacity(pipeline, modelID string) (bool, chan<- bool)
GetLiveAICapacity() worker.Capacity
TranscodeSeg(context.Context, *core.SegTranscodingMetadata, *stream.HLSSegment) (*core.TranscodeResult, error)
ServeTranscoder(stream net.Transcoder_RegisterTranscoderServer, capacity int, capabilities *net.Capabilities)
Expand Down Expand Up @@ -390,7 +390,8 @@
if liveCap, ok := caps.Constraints.PerCapability[uint32(core.Capability_LiveVideoToVideo)]; ok {
pipeline := "live-video-to-video"
for modelID := range liveCap.GetModels() {
if orch.CheckAICapacity(pipeline, modelID) {
hasCapacity, _ := orch.CheckAICapacity(pipeline, modelID)
if hasCapacity {

Check warning on line 394 in server/rpc.go

View check run for this annotation

Codecov / codecov/patch

server/rpc.go#L393-L394

Added lines #L393 - L394 were not covered by tests
// It has capacity for at least one of the requested models
return nil
}
Expand Down
8 changes: 4 additions & 4 deletions server/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ func (r *stubOrchestrator) LiveVideoToVideo(ctx context.Context, requestID strin
return nil, nil
}

func (r *stubOrchestrator) CheckAICapacity(pipeline, modelID string) bool {
return true
// CheckAICapacity is a test stub that always reports available capacity.
// The nil release channel means there is no reservation to release, so test
// callers must not send a release signal.
func (r *stubOrchestrator) CheckAICapacity(pipeline, modelID string) (bool, chan<- bool) {
	return true, nil
}
func (r *stubOrchestrator) AIResults(job int64, res *core.RemoteAIWorkerResult) {
}
Expand Down Expand Up @@ -1486,8 +1486,8 @@ func (r *mockOrchestrator) TextToSpeech(ctx context.Context, requestID string, r
func (r *mockOrchestrator) LiveVideoToVideo(ctx context.Context, requestID string, req worker.GenLiveVideoToVideoJSONRequestBody) (interface{}, error) {
return nil, nil
}
func (r *mockOrchestrator) CheckAICapacity(pipeline, modelID string) bool {
return true
// CheckAICapacity is a test mock that always reports available capacity.
// The nil release channel means there is no reservation to release, so test
// callers must not send a release signal.
func (r *mockOrchestrator) CheckAICapacity(pipeline, modelID string) (bool, chan<- bool) {
	return true, nil
}
func (r *mockOrchestrator) AIResults(job int64, res *core.RemoteAIWorkerResult) {

Expand Down
Loading