diff --git a/byoc/job_orchestrator.go b/byoc/job_orchestrator.go index 210474992e..ddc930a39f 100644 --- a/byoc/job_orchestrator.go +++ b/byoc/job_orchestrator.go @@ -491,6 +491,10 @@ func (bso *BYOCOrchestratorServer) confirmPayment(ctx context.Context, sender et orchBal, pmtErr := bso.processPayment(ctx, sender, capability, paymentHdr) if pmtErr != nil { //log if there are payment errors but continue, balance will runout and clean up + if paymentHdr != "" { + clog.Errorf(ctx, "rejecting request: payment header present but invalid: %v", pmtErr) + return errPaymentError + } clog.Infof(ctx, "job payment error: %v", pmtErr) } @@ -502,25 +506,22 @@ func (bso *BYOCOrchestratorServer) confirmPayment(ctx context.Context, sender et return nil } -// process payment and return balance +// processPayment decodes and applies the payment header if present. +// Always returns a non-nil balance so callers can safely compare. func (bso *BYOCOrchestratorServer) processPayment(ctx context.Context, sender ethcommon.Address, capability string, paymentHdr string) (*big.Rat, error) { if paymentHdr != "" { payment, err := getPayment(paymentHdr) if err != nil { clog.Errorf(ctx, "job payment invalid: %v", err) - return nil, errPaymentError + return bso.getPaymentBalance(sender, capability), errPaymentError } if err := bso.orch.ProcessPayment(ctx, payment, core.ManifestID(capability)); err != nil { - bso.orch.FreeExternalCapabilityCapacity(capability) clog.Errorf(ctx, "Error processing payment: %v", err) - return nil, errPaymentError + return bso.getPaymentBalance(sender, capability), errPaymentError } } - orchBal := bso.getPaymentBalance(sender, capability) - - return orchBal, nil - + return bso.getPaymentBalance(sender, capability), nil } func (bso *BYOCOrchestratorServer) chargeForCompute(start time.Time, price *net.PriceInfo, sender ethcommon.Address, jobId string) { @@ -566,18 +567,26 @@ func (bso *BYOCOrchestratorServer) verifyJobCreds(ctx context.Context, jobCreds return nil, errSegSig } - if !bso.orch.VerifySig(ethcommon.HexToAddress(jobData.Sender), jobData.Request+jobData.Parameters, sigByte) { - clog.Errorf(ctx, "Sig check failed sender=%v", jobData.Sender) - return nil, errSegSig - } + sender := ethcommon.HexToAddress(jobData.Sender) - if reserveCapacity && bso.orch.ReserveExternalCapabilityCapacity(jobData.Capability) != nil { - return nil, errZeroCapacity + // Verify V1 structured binary format + v1Payload := FlattenBYOCJob(&BYOCJobSigningInput{ + ID: jobData.ID, + Capability: jobData.Capability, + Request: jobData.Request, + Parameters: jobData.Parameters, + TimeoutSeconds: jobData.Timeout, + }) + if bso.orch.VerifySig(sender, string(v1Payload), sigByte) { + if reserveCapacity && bso.orch.ReserveExternalCapabilityCapacity(jobData.Capability) != nil { + return nil, errZeroCapacity + } + jobData.CapabilityUrl = bso.orch.GetUrlForCapability(jobData.Capability) + return jobData, nil } - jobData.CapabilityUrl = bso.orch.GetUrlForCapability(jobData.Capability) - - return jobData, nil + clog.Errorf(ctx, "Sig check failed sender=%v", jobData.Sender) + return nil, errSegSig } func (bso *BYOCOrchestratorServer) verifyTokenCreds(ctx context.Context, tokenCreds string) (*JobSender, error) { diff --git a/byoc/types.go b/byoc/types.go index fa9b4a5e8b..71367f50d3 100644 --- a/byoc/types.go +++ b/byoc/types.go @@ -3,6 +3,7 @@ package byoc import ( "context" "crypto/tls" + "encoding/binary" "errors" "math/big" gonet "net" @@ -279,3 +280,66 @@ type byocLiveRequestParams struct { // when the write for the last segment started lastSegmentTime time.Time } + +// BYOCJobSigV1Prefix is the 16-byte domain separator for BYOC job signatures (V1). +// Prevents cross-protocol signature replay. +const BYOCJobSigV1Prefix = "LP_BYOC_JOB_V1\x00\x00" + +// BYOCJobSigningInput holds the fields that are bound into a BYOC job signature. +type BYOCJobSigningInput struct { + ID string + Capability string + Request string + Parameters string + TimeoutSeconds int +} + +// FlattenBYOCJob produces a deterministic binary representation of a BYOC job +// for signing, similar to SegTranscodingMetadata.Flatten() used by LV2V. +// +// Wire format: +// +// version(16) || timeout(4,BE) || len(id)(4,BE) || id || len(cap)(4,BE) || cap +// || len(req)(4,BE) || req || len(params)(4,BE) || params +func FlattenBYOCJob(job *BYOCJobSigningInput) []byte { + idBytes := []byte(job.ID) + capBytes := []byte(job.Capability) + reqBytes := []byte(job.Request) + paramsBytes := []byte(job.Parameters) + + size := 16 + 4 + + 4 + len(idBytes) + + 4 + len(capBytes) + + 4 + len(reqBytes) + + 4 + len(paramsBytes) + + buf := make([]byte, size) + offset := 0 + + copy(buf[offset:], []byte(BYOCJobSigV1Prefix)) + offset += 16 + + binary.BigEndian.PutUint32(buf[offset:], uint32(job.TimeoutSeconds)) + offset += 4 + + binary.BigEndian.PutUint32(buf[offset:], uint32(len(idBytes))) + offset += 4 + copy(buf[offset:], idBytes) + offset += len(idBytes) + + binary.BigEndian.PutUint32(buf[offset:], uint32(len(capBytes))) + offset += 4 + copy(buf[offset:], capBytes) + offset += len(capBytes) + + binary.BigEndian.PutUint32(buf[offset:], uint32(len(reqBytes))) + offset += 4 + copy(buf[offset:], reqBytes) + offset += len(reqBytes) + + binary.BigEndian.PutUint32(buf[offset:], uint32(len(paramsBytes))) + offset += 4 + copy(buf[offset:], paramsBytes) + + return buf +} diff --git a/byoc/utils.go b/byoc/utils.go index 64374cd62a..edc8aabb9f 100644 --- a/byoc/utils.go +++ b/byoc/utils.go @@ -33,7 +33,16 @@ var sendJobReqWithTimeout = sendReqWithTimeout func (g *gatewayJob) sign() error { //sign the request gateway := g.node.OrchestratorPool.Broadcaster() - sig, err := gateway.Sign([]byte(g.Job.Req.Request + g.Job.Req.Parameters)) + + sigPayload := FlattenBYOCJob(&BYOCJobSigningInput{ + ID: g.Job.Req.ID, + Capability: g.Job.Req.Capability, + Request: g.Job.Req.Request, + Parameters: g.Job.Req.Parameters, + TimeoutSeconds: g.Job.Req.Timeout, + }) + + sig, err := gateway.Sign(sigPayload) if err != nil { return errors.New(fmt.Sprintf("Unable to sign request err=%v", err)) } diff --git a/core/orchestrator.go b/core/orchestrator.go index fd33eddc5a..ee5997595a 100644 --- a/core/orchestrator.go +++ b/core/orchestrator.go @@ -382,7 +382,7 @@ func (orch *orchestrator) PriceInfoForCaps(sender ethcommon.Address, manifestID // priceInfo returns price per pixel as a fixed point number wrapped in a big.Rat func (orch *orchestrator) priceInfo(sender ethcommon.Address, manifestID ManifestID, caps *net.Capabilities) (*big.Rat, error) { // If there is already a fixed price for the given session, use this price - if manifestID != "" { + if manifestID != "" && orch.node.Balances != nil { if balances, ok := orch.node.Balances.balances[sender]; ok { fixedPrice := balances.FixedPrice(manifestID) if fixedPrice != nil { @@ -412,9 +412,18 @@ func (orch *orchestrator) priceInfo(sender ethcommon.Address, manifestID Manifes continue } for modelID := range constraints.Models { - price := orch.node.GetBasePriceForCap(sender.String(), Capability(cap), modelID) - if price == nil { - price = orch.node.GetBasePriceForCap("default", Capability(cap), modelID) + var price *big.Rat + if Capability(cap) == Capability_BYOC { + // BYOC prices are stored in jobPriceInfo, keyed by capability name + price = orch.node.GetPriceForJob(sender.String(), modelID) + if price == nil || price.Sign() == 0 { + price = orch.node.GetPriceForJob("default", modelID) + } + } else { + price = orch.node.GetBasePriceForCap(sender.String(), Capability(cap), modelID) + if price == nil { + price = orch.node.GetBasePriceForCap("default", Capability(cap), modelID) + } } if price != nil { diff --git a/server/remote_signer.go b/server/remote_signer.go index 59a925828f..00b70fc75d 100644 --- a/server/remote_signer.go +++ b/server/remote_signer.go @@ -16,6 +16,7 @@ import ( ethcommon "github.com/ethereum/go-ethereum/common" "github.com/golang/glog" "github.com/golang/protobuf/proto" + "github.com/livepeer/go-livepeer/byoc" "github.com/livepeer/go-livepeer/clog" "github.com/livepeer/go-livepeer/core" lpcrypto "github.com/livepeer/go-livepeer/crypto" @@ -29,6 +30,7 @@ const HTTPStatusPriceExceeded = 481 const HTTPStatusNoTickets = 482 const RemoteType_LiveVideoToVideo = "lv2v" const PipelineLiveVideoToVideo = "live-video-to-video" +const RemoteType_BYOC = "byoc" // SignOrchestratorInfo handles signing GetOrchestratorInfo requests for multiple orchestrators func (ls *LivepeerServer) SignOrchestratorInfo(w http.ResponseWriter, r *http.Request) { @@ -68,11 +70,76 @@ func (ls *LivepeerServer) SignOrchestratorInfo(w http.ResponseWriter, r *http.Re _ = json.NewEncoder(w).Encode(results) } +// SignBYOCJobRequest signs a BYOC job using the V1 binary format (FlattenBYOCJob). +type SignBYOCJobRequestInput struct { + ID string `json:"id"` + Capability string `json:"capability"` + Request string `json:"request"` + Parameters string `json:"parameters"` + TimeoutSeconds int `json:"timeout_seconds"` +} + +type SignBYOCJobRequestResponse struct { + Sender string `json:"sender"` + Signature string `json:"signature"` +} + +func (ls *LivepeerServer) SignBYOCJobRequest(w http.ResponseWriter, r *http.Request) { + ctx := clog.AddVal(r.Context(), "request_id", string(core.RandomManifestID())) + remoteAddr := getRemoteAddr(r) + clog.Info(ctx, "BYOC job signing request", "ip", remoteAddr) + + gw := core.NewBroadcaster(ls.LivepeerNode) + + var req SignBYOCJobRequestInput + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + clog.Errorf(ctx, "Failed to decode SignBYOCJobRequest err=%q", err) + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + + if req.ID == "" || req.Capability == "" { + err := fmt.Errorf("sign-byoc-job requires non-empty id and capability") + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + if req.TimeoutSeconds <= 0 { + err := fmt.Errorf("sign-byoc-job requires positive timeout_seconds") + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + + sigPayload := byoc.FlattenBYOCJob(&byoc.BYOCJobSigningInput{ + ID: req.ID, + Capability: req.Capability, + Request: req.Request, + Parameters: req.Parameters, + TimeoutSeconds: req.TimeoutSeconds, + }) + + sig, err := gw.Sign(sigPayload) + if err != nil { + clog.Errorf(ctx, "Failed to sign BYOC job request err=%q", err) + respondJsonError(ctx, w, err, http.StatusInternalServerError) + return + } + + response := SignBYOCJobRequestResponse{ + Sender: gw.Address().Hex(), + Signature: "0x" + hex.EncodeToString(sig), + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(response) +} + // StartRemoteSignerServer starts the HTTP server for remote signer mode func StartRemoteSignerServer(ls *LivepeerServer, bind string) error { // Register the remote signer endpoints ls.HTTPMux.Handle("POST /sign-orchestrator-info", http.HandlerFunc(ls.SignOrchestratorInfo)) ls.HTTPMux.Handle("POST /generate-live-payment", http.HandlerFunc(ls.GenerateLivePayment)) + ls.HTTPMux.Handle("POST /sign-byoc-job", http.HandlerFunc(ls.SignBYOCJobRequest)) if ls.LivepeerNode.RemoteDiscovery { rdp := RemoteDiscoveryConfig{ Pool: ls.LivepeerNode.OrchestratorPool, @@ -164,12 +231,12 @@ type RemotePaymentRequest struct { Orchestrator []byte `json:"orchestrator"` // Set if an ID is needed to tie into orch accounting for a session. Optional - ManifestID string + ManifestID string `json:"manifestId,omitempty"` // Number of pixels to generate a ticket for. Required if `type` is not set. InPixels int64 `json:"inPixels"` - // Job type to automatically calculate payments. Valid values: `lv2v`. Optional. + // Job type to automatically calculate payments. Valid values: `lv2v`, `byoc`. Optional. Type string `json:"type"` // Capabilities to include in the ticket. Optional; may be set for the lv2v job type. @@ -208,6 +275,36 @@ func verifyStateSignature(ls *LivepeerServer, stateBytes []byte, sig []byte) err return nil } +// resolvePriceInfo returns the effective PriceInfo for a payment request. +// For BYOC, pricing may only be advertised in CapabilitiesPrices rather than the +// top-level PriceInfo, so we search for a matching capability-specific entry. +func resolvePriceInfo(oInfo *net.OrchestratorInfo, reqType string, manifestID string) *net.PriceInfo { + top := oInfo.PriceInfo + if reqType != RemoteType_BYOC { + return top + } + + if top != nil && top.PricePerUnit > 0 && top.PixelsPerUnit > 0 && + top.Capability == uint32(core.Capability_BYOC) && top.Constraint != "" { + return top + } + + for _, cp := range oInfo.CapabilitiesPrices { + if cp == nil || cp.PricePerUnit == 0 || cp.PixelsPerUnit == 0 { + continue + } + if cp.Capability != uint32(core.Capability_BYOC) { + continue + } + if manifestID != "" && cp.Constraint != manifestID { + continue + } + return cp + } + + return top +} + // GenerateLivePayment handles remote generation of a payment for live streams. func (ls *LivepeerServer) GenerateLivePayment(w http.ResponseWriter, r *http.Request) { requestID := string(core.RandomManifestID()) @@ -242,9 +339,17 @@ func (ls *LivepeerServer) GenerateLivePayment(w http.ResponseWriter, r *http.Req respondJsonError(ctx, w, err, http.StatusBadRequest) return } - priceInfo := oInfo.PriceInfo + + priceInfo := resolvePriceInfo(&oInfo, req.Type, req.ManifestID) if priceInfo == nil || priceInfo.PricePerUnit == 0 || priceInfo.PixelsPerUnit == 0 { - err := fmt.Errorf("missing or zero priceInfo") + detail := "missing or zero priceInfo" + if req.Type == RemoteType_BYOC { + detail = fmt.Sprintf("missing or zero priceInfo for BYOC capability %q; "+ + "ensure the orchestrator advertises capability-specific pricing "+ + "(CapabilitiesPrices with Capability=BYOC and Constraint=)", + req.ManifestID) + } + err := errors.New(detail) respondJsonError(ctx, w, err, http.StatusBadRequest) return } @@ -294,6 +399,12 @@ func (ls *LivepeerServer) GenerateLivePayment(w http.ResponseWriter, r *http.Req ctx = clog.AddVal(ctx, "seqNo", fmt.Sprintf("%d", state.SequenceNumber)) manifestID := req.ManifestID + byocCapability := "" + if req.Type == RemoteType_BYOC { + if priceInfo.Capability == uint32(core.Capability_BYOC) && priceInfo.Constraint != "" { + byocCapability = priceInfo.Constraint + } + } if manifestID == "" { if hasState { // Required for lv2v so stateful requests stay tied to the same id. @@ -301,7 +412,12 @@ func (ls *LivepeerServer) GenerateLivePayment(w http.ResponseWriter, r *http.Req respondJsonError(ctx, w, err, http.StatusBadRequest) return } - manifestID = string(core.RandomManifestID()) + if req.Type == RemoteType_BYOC && byocCapability != "" { + // For BYOC, use capability name as manifest ID for shared balance tracking + manifestID = byocCapability + } else { + manifestID = string(core.RandomManifestID()) + } } ctx = clog.AddVal(ctx, "manifest_id", manifestID) @@ -388,6 +504,29 @@ func (ls *LivepeerServer) GenerateLivePayment(w http.ResponseWriter, r *http.Req } pixelsPerSec := float64(info.Height) * float64(info.Width) * float64(info.FPS) pixels = int64(pixelsPerSec * billableSecs) // pixels to charge for + } else if req.Type == RemoteType_BYOC { + // BYOC uses time-based pricing: price per unit of time (typically seconds) + // The pixelsPerUnit in the price info represents the time scaling factor + if byocCapability == "" { + err = errors.New("missing BYOC capability in OrchestratorInfo price_info.constraint") + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + now := time.Now() + lastUpdate := state.LastUpdate + if lastUpdate.IsZero() { + // Preload with 120 seconds (2 minutes) of data by default. + // The orchestrator requires minimum 60 seconds balance, so we use 2 minutes + // to have a buffer (matching the Go gateway's approach). + lastUpdate = now.Add(-120 * time.Second) + } + secSinceLastProcessed := now.Sub(lastUpdate).Seconds() + // For BYOC, "pixels" represents time units; pixelsPerUnit is typically 1 (per second) + // We calculate units as: seconds × pixelsPerUnit (which is typically 1) + pixels = int64(secSinceLastProcessed * float64(priceInfo.PixelsPerUnit)) + if pixels < priceInfo.PixelsPerUnit { + pixels = priceInfo.PixelsPerUnit // Minimum 1 unit + } } else if req.Type != "" { err = errors.New("invalid job type") respondJsonError(ctx, w, err, http.StatusBadRequest) diff --git a/server/remote_signer_test.go b/server/remote_signer_test.go index 12aed55461..8a1dbdceed 100644 --- a/server/remote_signer_test.go +++ b/server/remote_signer_test.go @@ -223,6 +223,24 @@ func TestGenerateLivePayment_RequestValidationErrors(t *testing.T) { wantStatus: http.StatusBadRequest, wantMsg: "invalid job type", }, + { + name: "byoc missing capability constraint", + req: func() RemotePaymentRequest { + oInfo := proto.Clone(baseOrchInfo).(*net.OrchestratorInfo) + oInfo.PriceInfo = &net.PriceInfo{ + PricePerUnit: 1, + PixelsPerUnit: 1, + Capability: uint32(core.Capability_BYOC), + Constraint: "", + } + return RemotePaymentRequest{ + Orchestrator: makeOrchBlob(oInfo), + Type: RemoteType_BYOC, + } + }(), + wantStatus: http.StatusBadRequest, + wantMsg: "missing BYOC capability in OrchestratorInfo price_info.constraint", + }, { name: "missing pixels without type", req: func() RemotePaymentRequest { @@ -321,6 +339,32 @@ func TestGenerateLivePayment_RequestValidationErrors(t *testing.T) { } } +func TestResolvePriceInfo_BYOCUsesCapabilitiesPrices(t *testing.T) { + require := require.New(t) + + byocPrice := &net.PriceInfo{ + PricePerUnit: 9, + PixelsPerUnit: 1, + Capability: uint32(core.Capability_BYOC), + Constraint: "acme/model", + } + + oInfo := &net.OrchestratorInfo{ + PriceInfo: &net.PriceInfo{ + PricePerUnit: 3, + PixelsPerUnit: 1, + }, + CapabilitiesPrices: []*net.PriceInfo{ + byocPrice, + }, + } + + require.Same(byocPrice, resolvePriceInfo(oInfo, RemoteType_BYOC, "")) + require.Same(byocPrice, resolvePriceInfo(oInfo, RemoteType_BYOC, "acme/model")) + require.Same(oInfo.PriceInfo, resolvePriceInfo(oInfo, RemoteType_BYOC, "other/model")) + require.Same(oInfo.PriceInfo, resolvePriceInfo(oInfo, RemoteType_LiveVideoToVideo, "")) +} + func TestGenerateLivePayment_StateValidationErrors(t *testing.T) { require := require.New(t) diff --git a/server/rpc.go b/server/rpc.go index da335b98aa..3a0d324f69 100644 --- a/server/rpc.go +++ b/server/rpc.go @@ -441,17 +441,19 @@ func orchestratorInfoWithCaps(orch Orchestrator, addr ethcommon.Address, service var priceInfo *net.PriceInfo var capsPrices []*net.PriceInfo var err error - if caps == nil { - //get capability prices - capsPrices, err = orch.GetCapabilitiesPrices(addr) + //get capability prices + capsPrices, err = orch.GetCapabilitiesPrices(addr) + if err != nil { + return nil, err + } + if caps == nil { //get base price priceInfo, err = getPriceInfo(orch, addr, manifestID) if err != nil { return nil, err } } else { - var err error priceInfo, err = orch.PriceInfoForCaps(addr, manifestID, caps) if err != nil { return nil, err diff --git a/server/rpc_test.go b/server/rpc_test.go index fd81e05acb..ab04476630 100644 --- a/server/rpc_test.go +++ b/server/rpc_test.go @@ -1082,6 +1082,22 @@ func TestGetOrchestrator_PriceInfoError(t *testing.T) { assert.EqualError(t, err, expErr.Error()) } +func TestGetOrchestrator_CapabilitiesPricesError(t *testing.T) { + orch := &mockOrchestrator{} + drivers.NodeStorage = drivers.NewMemoryDriver(nil) + expErr := errors.New("capabilities prices error") + + orch.On("VerifySig", mock.Anything, mock.Anything, mock.Anything).Return(true) + orch.On("ServiceURI").Return(url.Parse("http://someuri.com")) + orch.On("Nodes").Return(nil) + orch.On("Address").Return(ethcommon.Address{}) + orch.On("GetCapabilitiesPrices", mock.Anything).Return(nil, expErr) + + _, err := getOrchestrator(orch, &net.OrchestratorRequest{}) + + assert.EqualError(t, err, expErr.Error()) +} + func TestGetOrchestrator_GivenValidSig_ReturnsAuthToken(t *testing.T) { orch := &mockOrchestrator{} drivers.NodeStorage = drivers.NewMemoryDriver(nil) @@ -1379,12 +1395,13 @@ func TestGetOrchestrator_NoCapabilitiesPrices_NoHardware(t *testing.T) { orch.On("AuthToken", mock.Anything, mock.Anything).Return(&net.AuthToken{}) orch.On("PriceInfo", mock.Anything).Return(expectedPrice, nil) orch.On("TicketParams", mock.Anything, mock.Anything).Return(nil, nil) + orch.On("GetCapabilitiesPrices", mock.Anything).Return([]*net.PriceInfo{}, nil) orchInfo, err := getOrchestrator(orch, &net.OrchestratorRequest{Capabilities: caps.ToNetCapabilities()}) assert.Nil(t, err) assert.Nil(t, orchInfo.Hardware) - assert.Nil(t, orchInfo.CapabilitiesPrices) + assert.Empty(t, orchInfo.CapabilitiesPrices) } type mockAICapacityOrch struct { @@ -1531,10 +1548,10 @@ func (o *mockOrchestrator) PriceInfo(sender ethcommon.Address, manifestID core.M func (o *mockOrchestrator) GetCapabilitiesPrices(sender ethcommon.Address) ([]*net.PriceInfo, error) { args := o.Called(sender) if args.Get(0) != nil { - return args.Get(0).([]*net.PriceInfo), nil + return args.Get(0).([]*net.PriceInfo), args.Error(1) } - return []*net.PriceInfo{}, nil + return nil, args.Error(1) } func (o *mockOrchestrator) CheckCapacity(mid core.ManifestID) error { @@ -1743,7 +1760,7 @@ func Test_setLiveAICapacity(t *testing.T) { } } -func TestOrchestratorInfoWithCaps_NonNilEmptyCaps_DoesNotIncludeCapabilitiesPrices(t *testing.T) { +func TestOrchestratorInfoWithCaps_NonNilEmptyCaps_IncludesCapabilitiesPrices(t *testing.T) { require := require.New(t) oldNodeStorage := drivers.NodeStorage @@ -1752,17 +1769,20 @@ func TestOrchestratorInfoWithCaps_NonNilEmptyCaps_DoesNotIncludeCapabilitiesPric orch := &mockOrchestrator{} addr := ethcommon.HexToAddress("0x1") + capPrices := []*net.PriceInfo{{PricePerUnit: 5, PixelsPerUnit: 1}} + priceInfo := &net.PriceInfo{PricePerUnit: 3, PixelsPerUnit: 1} orch.On("Nodes").Return() orch.On("Address").Return(addr) + orch.On("GetCapabilitiesPrices", addr).Return(capPrices, nil) + orch.On("PriceInfoForCaps", addr, core.ManifestID(""), mock.Anything).Return(priceInfo, nil) orch.On("TicketParams", addr, mock.Anything).Return(&net.TicketParams{Recipient: pm.RandBytes(32)}, nil) orch.On("AuthToken", mock.Anything, mock.Anything).Return(&net.AuthToken{Token: []byte("tok"), SessionId: "sess", Expiration: time.Now().Add(time.Hour).Unix()}) nonNilEmptyCaps := core.NewCapabilities(nil, nil).ToNetCapabilities() info, err := orchestratorInfoWithCaps(orch, addr, "https://orch.example.com", "", nonNilEmptyCaps) require.NoError(err) - require.Nil(info.CapabilitiesPrices, "non-nil (even if empty) caps should not return capabilities prices") + require.Equal(capPrices, info.CapabilitiesPrices) - orch.AssertNotCalled(t, "GetCapabilitiesPrices", mock.Anything) orch.AssertNotCalled(t, "PriceInfo", mock.Anything) }