Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,27 @@ var containerHostPorts = map[string]string{
"live-video-to-video": "8900",
}

// Mapping for per pipeline container images.
// Default pipeline container image mapping to use if no overrides are provided.
var defaultBaseImage = "livepeer/ai-runner:latest"
var pipelineToImage = map[string]string{
"segment-anything-2": "livepeer/ai-runner:segment-anything-2",
"text-to-speech": "livepeer/ai-runner:text-to-speech",
"audio-to-text": "livepeer/ai-runner:audio-to-text",
"llm": "livepeer/ai-runner:llm",
}

var livePipelineToImage = map[string]string{
"streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion",
"comfyui": "livepeer/ai-runner:live-app-comfyui",
"segment_anything_2": "livepeer/ai-runner:live-app-segment_anything_2",
"noop": "livepeer/ai-runner:live-app-noop",
}

type ImageOverrides struct {
Default string `json:"default"`
Batch map[string]string `json:"batch"`
Live map[string]string `json:"live"`
}

// DockerClient is an interface for the Docker client, allowing for mocking in tests.
// NOTE: ensure any docker.Client methods used in this package are added.
type DockerClient interface {
Expand All @@ -91,9 +97,9 @@ var _ DockerClient = (*docker.Client)(nil)
var dockerWaitUntilRunningFunc = dockerWaitUntilRunning

type DockerManager struct {
defaultImage string
gpus []string
modelDir string
gpus []string
modelDir string
overrides ImageOverrides

dockerClient DockerClient
// gpu ID => container name
Expand All @@ -103,7 +109,7 @@ type DockerManager struct {
mu *sync.Mutex
}

func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
func NewDockerManager(overrides ImageOverrides, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
if err := removeExistingContainers(ctx, client); err != nil {
cancel()
Expand All @@ -112,9 +118,9 @@ func NewDockerManager(defaultImage string, gpus []string, modelDir string, clien
cancel()

manager := &DockerManager{
defaultImage: defaultImage,
gpus: gpus,
modelDir: modelDir,
overrides: overrides,
dockerClient: client,
gpuContainers: make(map[string]string),
containers: make(map[string]*RunnerContainer),
Expand Down Expand Up @@ -215,17 +221,24 @@ func (m *DockerManager) returnContainer(rc *RunnerContainer) {
func (m *DockerManager) getContainerImageName(pipeline, modelID string) (string, error) {
if pipeline == "live-video-to-video" {
// We currently use the model ID as the live pipeline name for legacy reasons.
if image, ok := livePipelineToImage[modelID]; ok {
if image, ok := m.overrides.Live[modelID]; ok {
return image, nil
} else if image, ok := livePipelineToImage[modelID]; ok {
return image, nil
}
return "", fmt.Errorf("no container image found for live pipeline %s", modelID)
}

if image, ok := pipelineToImage[pipeline]; ok {
if image, ok := m.overrides.Batch[pipeline]; ok {
return image, nil
} else if image, ok := pipelineToImage[pipeline]; ok {
return image, nil
}

return m.defaultImage, nil
if m.overrides.Default != "" {
return m.overrides.Default, nil
}
return defaultBaseImage, nil
}

// HasCapacity checks if an unused managed container exists or if a GPU is available for a new container.
Expand Down
95 changes: 89 additions & 6 deletions worker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ func NewMockServer() *MockServer {
// createDockerManager creates a DockerManager with a mock DockerClient.
func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager {
return &DockerManager{
defaultImage: "default-image",
gpus: []string{"gpu0"},
modelDir: "/models",
overrides: ImageOverrides{Default: "default-image"},
dockerClient: mockDockerClient,
gpuContainers: make(map[string]string),
containers: make(map[string]*RunnerContainer),
Expand All @@ -110,10 +110,10 @@ func TestNewDockerManager(t *testing.T) {
mockDockerClient := new(MockDockerClient)

createAndVerifyManager := func() *DockerManager {
manager, err := NewDockerManager("default-image", []string{"gpu0"}, "/models", mockDockerClient)
manager, err := NewDockerManager(ImageOverrides{Default: "default-image"}, []string{"gpu0"}, "/models", mockDockerClient)
require.NoError(t, err)
require.NotNil(t, manager)
require.Equal(t, "default-image", manager.defaultImage)
require.Equal(t, "default-image", manager.overrides.Default)
require.Equal(t, []string{"gpu0"}, manager.gpus)
require.Equal(t, "/models", manager.modelDir)
require.Equal(t, mockDockerClient, manager.dockerClient)
Expand Down Expand Up @@ -301,47 +301,130 @@ func TestDockerManager_returnContainer(t *testing.T) {

func TestDockerManager_getContainerImageName(t *testing.T) {
mockDockerClient := new(MockDockerClient)
manager := createDockerManager(mockDockerClient)
dockerManager := createDockerManager(mockDockerClient)

tests := []struct {
name string
setup func(*DockerManager, *MockDockerClient)
pipeline string
modelID string
expectedImage string
expectError bool
}{
{
name: "live-video-to-video with valid modelID",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "live-video-to-video",
modelID: "streamdiffusion",
expectedImage: "livepeer/ai-runner:live-app-streamdiffusion",
expectError: false,
},
{
name: "live-video-to-video with invalid modelID",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "live-video-to-video",
modelID: "invalid-model",
expectError: true,
},
{
name: "valid pipeline",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "text-to-speech",
modelID: "",
expectedImage: "livepeer/ai-runner:text-to-speech",
expectError: false,
},
{
name: "invalid pipeline",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "invalid-pipeline",
modelID: "",
expectedImage: "default-image",
expectError: false,
},
{
name: "override default image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Default: "custom-image",
}
},
pipeline: "",
modelID: "",
expectedImage: "custom-image",
expectError: false,
},
{
name: "override batch image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Batch: map[string]string{
"text-to-speech": "custom-image",
},
}
},
pipeline: "text-to-speech",
modelID: "",
expectedImage: "custom-image",
expectError: false,
},
{
name: "override live image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Live: map[string]string{
"streamdiffusion": "custom-image",
},
}
},
pipeline: "live-video-to-video",
modelID: "streamdiffusion",
expectedImage: "custom-image",
expectError: false,
},
{
name: "non-overridden batch image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Default: "default-image",
Batch: map[string]string{
"text-to-speech": "custom-batch-image",
},
Live: map[string]string{
"streamdiffusion": "custom-live-image",
},
}
},
pipeline: "audio-to-text",
modelID: "",
expectedImage: "livepeer/ai-runner:audio-to-text",
expectError: false,
},
{
name: "non-overridden live image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Default: "default-image",
Batch: map[string]string{
"text-to-speech": "custom-batch-image",
},
Live: map[string]string{
"streamdiffusion": "custom-live-image",
},
}
},
pipeline: "live-video-to-video",
modelID: "comfyui",
expectedImage: "livepeer/ai-runner:live-app-comfyui",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
image, err := manager.getContainerImageName(tt.pipeline, tt.modelID)
tt.setup(dockerManager, mockDockerClient)

image, err := dockerManager.getContainerImageName(tt.pipeline, tt.modelID)
if tt.expectError {
require.Error(t, err)
require.Equal(t, fmt.Sprintf("no container image found for live pipeline %s", tt.modelID), err.Error())
Expand Down Expand Up @@ -500,7 +583,7 @@ func TestDockerManager_createContainer(t *testing.T) {
dockerManager.gpus = []string{gpu}
dockerManager.gpuContainers = make(map[string]string)
dockerManager.containers = make(map[string]*RunnerContainer)
dockerManager.defaultImage = containerImage
dockerManager.overrides.Default = containerImage

mockDockerClient.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: containerID}, nil)
mockDockerClient.On("ContainerStart", mock.Anything, containerID, mock.Anything).Return(nil)
Expand Down
4 changes: 2 additions & 2 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ type Worker struct {
mu *sync.Mutex
}

func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) {
func NewWorker(imageOverrides ImageOverrides, gpus []string, modelDir string) (*Worker, error) {
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
return nil, err
}

manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient)
manager, err := NewDockerManager(imageOverrides, gpus, modelDir, dockerClient)
if err != nil {
return nil, err
}
Expand Down