Skip to content

Commit 5857c2d

Browse files
committed
Move agent providers into shared registry
This will allow us to support agents other than Claude Code
1 parent 80963b7 commit 5857c2d

File tree

10 files changed

+272
-64
lines changed

10 files changed

+272
-64
lines changed

lib/roast/cogs/agent.rb

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,26 +48,14 @@ class MissingPromptError < AgentCogError; end
4848
#: (Input) -> Output
4949
def execute(input)
5050
puts "[USER PROMPT] #{input.valid_prompt!}" if config.show_prompt?
51-
output = provider.invoke(input)
51+
output = config.values[:provider].invoke(input)
5252
# NOTE: If progress is displayed, the agent's response will always be the last progress message,
5353
# so showing it again is duplicative.
5454
puts "[AGENT RESPONSE] #{output.response}" if config.show_response? && !config.show_progress?
5555
puts "[AGENT STATS] #{output.stats}" if config.show_stats?
5656
puts "Session ID: #{output.session}" if config.show_stats?
5757
output
5858
end
59-
60-
private
61-
62-
#: () -> Provider
63-
def provider
64-
@provider ||= case config.valid_provider!
65-
when :claude
66-
Providers::Claude.new(config)
67-
else
68-
raise UnknownProviderError, "Unknown provider: #{config.valid_provider!}"
69-
end
70-
end
7159
end
7260
end
7361
end

lib/roast/cogs/agent/config.rb

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,17 @@ module Roast
55
module Cogs
66
class Agent < Cog
77
class Config < Cog::Config
8-
VALID_PROVIDERS = [:claude].freeze #: Array[Symbol]
8+
#: (?Hash[Symbol, untyped]) -> void
9+
def initialize(initial = {})
10+
super(initial)
11+
12+
@provider_registry = {}
13+
end
14+
15+
def validate!
16+
# Provider registry is responsible for ensuring the validity of providers
17+
@values[:provider] = @provider_registry.fetch(@values[:provider]).new(self)
18+
end
919

1020
# Configure the cog to use a specified provider when invoking an agent
1121
#
@@ -36,27 +46,6 @@ def use_default_provider!
3646
@values[:provider] = nil
3747
end
3848

39-
# Get the validated provider name that the cog is configured to use when invoking an agent
40-
#
41-
# Note: this method will return the name of a valid provider or raise an `InvalidConfigError`.
42-
# It will __not__, however, validate that the agent is properly installed on your system.
43-
# If the agent is not properly installed, you will likely experience a failure when Roast attempts to
44-
# run your workflow.
45-
#
46-
# #### See Also
47-
# - `provider`
48-
# - `use_default_provider!`
49-
#
50-
#: () -> Symbol
51-
def valid_provider!
52-
provider = @values[:provider] || VALID_PROVIDERS.first
53-
unless VALID_PROVIDERS.include?(provider)
54-
raise ArgumentError, "'#{provider}' is not a valid provider. Available providers include: #{VALID_PROVIDERS.join(", ")}"
55-
end
56-
57-
provider
58-
end
59-
6049
# Configure the cog to use a specific base command when invoking the agent
6150
#
6251
# The command format is provider-specific.

lib/roast/config_manager.rb

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ class ConfigManagerNotPreparedError < ConfigManagerError; end
88
class ConfigManagerAlreadyPreparedError < ConfigManagerError; end
99
class IllegalCogNameError < ConfigManagerError; end
1010

11-
#: (Cog::Registry, Array[^() -> void]) -> void
12-
def initialize(cog_registry, config_procs)
11+
#: (Cog::Registry, Array[^() -> void], ProviderRegistry) -> void
12+
def initialize(cog_registry, config_procs, provider_registry)
1313
@cog_registry = cog_registry
14+
@provider_registry = provider_registry
1415
@config_procs = config_procs
1516
@config_context = ConfigContext.new #: ConfigContext
1617
@global_config = Cog::Config.new #: Cog::Config
@@ -54,6 +55,16 @@ def config_for(cog_class, name = nil)
5455
name_scoped_config = fetch_name_scoped_config(cog_class, name)
5556
config = config.merge(name_scoped_config)
5657
end
58+
59+
# Special case for agent cog. Insert the provider registry.
60+
# This is the only cog that needs this right now - revisit in the future to
61+
# see if we need a way to define hooks for custom cog data to be distributed
62+
# at config time.
63+
# NOTE: This must happen after all merges, since merge creates new Config instances.
64+
if config.is_a?(Cogs::Agent::Config)
65+
config.instance_variable_set(:@provider_registry, @provider_registry)
66+
end
67+
5768
config.validate!
5869
config
5970
end

lib/roast/provider_registry.rb

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# typed: true
2+
# frozen_string_literal: true
3+
4+
module Roast
5+
# Maintains the list of registered agent providers
6+
# Built in providers are registered automatically.
7+
# Custom agents are registered per workflow with `use`.)
8+
class ProviderRegistry
9+
class ProviderRegistryError < Roast::Error; end
10+
class DuplicateProviderNameError < ProviderRegistryError; end
11+
class ProviderNotFoundError < ProviderRegistryError; end
12+
13+
#: Symbol
14+
attr_accessor :default
15+
16+
delegate :key?, to: :@providers
17+
18+
def initialize
19+
@providers = {} #: Hash[Symbol, singleton(Cogs::Agent::Provider)]
20+
@default = ENV["ROAST_DEFAULT_AGENT"]&.to_sym || :claude
21+
end
22+
23+
#: (singleton(Cogs::Agent::Provider), ?Symbol?) -> void
24+
def register(provider_class, name = nil)
25+
name = build_provider_name(provider_class) if name.blank?
26+
raise DuplicateProviderNameError if @providers.key?(name)
27+
28+
@providers[name] = provider_class
29+
end
30+
31+
def fetch(name)
32+
name = default if name.nil?
33+
raise ProviderNotFoundError unless @providers.key?(name)
34+
35+
@providers.fetch(name)
36+
end
37+
38+
private
39+
40+
#: (singleton(Cogs::Agent::Provider)) -> Symbol
41+
def build_provider_name(provider_class)
42+
provider_class.name.not_nil!.demodulize.underscore.to_sym
43+
end
44+
end
45+
end

lib/roast/workflow.rb

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def initialize(workflow_path, workflow_context)
2828
@workflow_context = workflow_context #: WorkflowContext
2929
@workflow_definition = File.read(workflow_path) #: String
3030
@cog_registry = Cog::Registry.new #: Cog::Registry
31+
@provider_registry = ProviderRegistry.new #: ProviderRegistry
3132
@config_procs = [] #: Array[^() -> void]
3233
@execution_procs = { nil: [] } #: Hash[Symbol?, Array[^() -> void]]
3334
@config_manager = nil #: ConfigManager?
@@ -40,7 +41,8 @@ def prepare!
4041

4142
@preparing = true
4243
extract_dsl_procs!
43-
@config_manager = ConfigManager.new(@cog_registry, @config_procs)
44+
add_providers!
45+
@config_manager = ConfigManager.new(@cog_registry, @config_procs, @provider_registry)
4446
@config_manager.not_nil!.prepare!
4547
# TODO: probably we should just not pass the params as the top-level scope value anymore
4648
@execution_manager = ExecutionManager.new(@cog_registry, @config_manager.not_nil!, @execution_procs, @workflow_context, scope_value: @workflow_context.params)
@@ -118,5 +120,9 @@ def use(cogs = [], from: nil)
118120
def extract_dsl_procs!
119121
instance_eval(@workflow_definition, @workflow_path.realpath.to_s, 1)
120122
end
123+
124+
def add_providers!
125+
@provider_registry.register(Roast::Cogs::Agent::Providers::Claude, :claude)
126+
end
121127
end
122128
end

test/roast/cogs/agent/config_test.rb

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,63 @@ class Agent < Cog
88
class ConfigTest < ActiveSupport::TestCase
99
def setup
1010
@config = Config.new
11-
@default_provider = Config::VALID_PROVIDERS.first
11+
@default_provider = :claude
1212
end
1313

1414
# Provider configuration tests
1515
test "provider sets provider value" do
1616
@config.provider(@default_provider)
1717

18-
assert_equal @default_provider, @config.valid_provider!
18+
assert_equal @default_provider, @config.values[:provider]
1919
end
2020

2121
test "use_default_provider! clears provider value" do
2222
@config.provider(:fake_provider)
2323
@config.use_default_provider!
2424

25-
assert_equal @default_provider, @config.valid_provider!
25+
assert_nil @config.values[:provider]
2626
end
2727

28-
test "valid_provider! returns default when not set" do
29-
assert_equal @default_provider, @config.valid_provider!
28+
test "initialize sets provider_registry to empty hash" do
29+
config = Config.new
30+
31+
assert_equal({}, config.instance_variable_get(:@provider_registry))
3032
end
3133

32-
test "valid_provider! raises on invalid provider" do
33-
@config.provider(:invalid_provider)
34+
test "validate! resolves provider from registry and instantiates with config" do
35+
mock_provider_instance = mock("provider_instance")
36+
mock_provider_class = mock("provider_class")
37+
mock_provider_class.expects(:new).with(@config).returns(mock_provider_instance)
3438

35-
error = assert_raises(ArgumentError) do
36-
@config.valid_provider!
37-
end
39+
registry = { claude: mock_provider_class }
40+
@config.provider(:claude)
41+
@config.instance_variable_set(:@provider_registry, registry)
42+
@config.validate!
43+
44+
assert_equal mock_provider_instance, @config.values[:provider]
45+
end
46+
47+
test "validate! uses nil provider key when provider not explicitly set" do
48+
mock_provider_instance = mock("provider_instance")
49+
mock_provider_class = mock("provider_class")
50+
mock_provider_class.expects(:new).with(@config).returns(mock_provider_instance)
3851

39-
assert_match(/invalid_provider.*not a valid provider/, error.message)
52+
registry = mock("registry")
53+
registry.expects(:fetch).with(nil).returns(mock_provider_class)
54+
@config.instance_variable_set(:@provider_registry, registry)
55+
@config.validate!
56+
57+
assert_equal mock_provider_instance, @config.values[:provider]
58+
end
59+
60+
test "validate! raises when provider not found in registry" do
61+
registry = ProviderRegistry.new
62+
@config.provider(:nonexistent)
63+
@config.instance_variable_set(:@provider_registry, registry)
64+
65+
assert_raises(ProviderRegistry::ProviderNotFoundError) do
66+
@config.validate!
67+
end
4068
end
4169

4270
# Command configuration tests

test/roast/config_manager_test.rb

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def setup
2222
end
2323

2424
test "prepare! transitions to prepared state" do
25-
manager = ConfigManager.new(@registry, [])
25+
manager = ConfigManager.new(@registry, [], ProviderRegistry.new)
2626

2727
refute manager.prepared?
2828
manager.prepare!
2929
assert manager.prepared?
3030
end
3131

3232
test "prepare! raises when called twice" do
33-
manager = ConfigManager.new(@registry, [])
33+
manager = ConfigManager.new(@registry, [], ProviderRegistry.new)
3434
manager.prepare!
3535

3636
assert_raises(ConfigManager::ConfigManagerAlreadyPreparedError) do
@@ -44,22 +44,22 @@ def setup
4444
test_cog { timeout 60 }
4545
timeout_set = true
4646
end
47-
manager = ConfigManager.new(@registry, [config_proc])
47+
manager = ConfigManager.new(@registry, [config_proc], ProviderRegistry.new)
4848
manager.prepare!
4949

5050
assert timeout_set
5151
end
5252

5353
test "config_for raises when not prepared" do
54-
manager = ConfigManager.new(@registry, [])
54+
manager = ConfigManager.new(@registry, [], ProviderRegistry.new)
5555

5656
assert_raises(ConfigManager::ConfigManagerNotPreparedError) do
5757
manager.config_for(TestCog)
5858
end
5959
end
6060

6161
test "config_for returns default config when no config procs are provided" do
62-
manager = ConfigManager.new(@registry, [])
62+
manager = ConfigManager.new(@registry, [], ProviderRegistry.new)
6363
manager.prepare!
6464

6565
config = manager.config_for(TestCog)
@@ -71,7 +71,7 @@ def setup
7171
config_proc = proc do
7272
test_cog { timeout 60 }
7373
end
74-
manager = ConfigManager.new(@registry, [config_proc])
74+
manager = ConfigManager.new(@registry, [config_proc], ProviderRegistry.new)
7575
manager.prepare!
7676

7777
config = manager.config_for(TestCog)
@@ -83,7 +83,7 @@ def setup
8383
config_proc = proc do
8484
test_cog(:my_step) { timeout 90 }
8585
end
86-
manager = ConfigManager.new(@registry, [config_proc])
86+
manager = ConfigManager.new(@registry, [config_proc], ProviderRegistry.new)
8787
manager.prepare!
8888

8989
scoped_config = manager.config_for(TestCog, :my_step)
@@ -97,7 +97,7 @@ def setup
9797
config_proc = proc do
9898
test_cog(/^api_/) { timeout 120 }
9999
end
100-
manager = ConfigManager.new(@registry, [config_proc])
100+
manager = ConfigManager.new(@registry, [config_proc], ProviderRegistry.new)
101101
manager.prepare!
102102

103103
matching_config = manager.config_for(TestCog, :api_call)
@@ -112,7 +112,7 @@ def setup
112112
test_cog { async! }
113113
test_cog(:my_step) { timeout 90 }
114114
end
115-
manager = ConfigManager.new(@registry, [config_proc])
115+
manager = ConfigManager.new(@registry, [config_proc], ProviderRegistry.new)
116116
manager.prepare!
117117

118118
config = manager.config_for(TestCog, :my_step)
@@ -125,14 +125,37 @@ def setup
125125
config_proc = proc do
126126
global { abort_on_failure! }
127127
end
128-
manager = ConfigManager.new(@registry, [config_proc])
128+
manager = ConfigManager.new(@registry, [config_proc], ProviderRegistry.new)
129129
manager.prepare!
130130

131131
config = manager.config_for(TestCog)
132132

133133
assert config.abort_on_failure?
134134
end
135135

136+
test "config_for injects provider registry into Agent::Config" do
137+
registry = Cog::Registry.new
138+
provider_registry = ProviderRegistry.new
139+
provider_registry.register(Cogs::Agent::Providers::Claude, :claude)
140+
141+
manager = ConfigManager.new(registry, [], provider_registry)
142+
manager.prepare!
143+
144+
config = manager.config_for(Cogs::Agent)
145+
146+
assert_equal provider_registry, config.instance_variable_get(:@provider_registry)
147+
end
148+
149+
test "config_for does not inject provider registry into non-Agent configs" do
150+
manager = ConfigManager.new(@registry, [], ProviderRegistry.new)
151+
manager.prepare!
152+
153+
config = manager.config_for(TestCog)
154+
155+
refute config.respond_to?(:provider_registry)
156+
refute config.instance_variable_defined?(:@provider_registry)
157+
end
158+
136159
test "prepare! raises IllegalCogNameError when cog name conflicts with existing method" do
137160
# Register a cog whose derived name ("freeze") conflicts with Object#freeze
138161
conflicting_cog = Class.new(Cog) do
@@ -144,7 +167,7 @@ def name
144167
end
145168
@registry.use(conflicting_cog)
146169

147-
manager = ConfigManager.new(@registry, [])
170+
manager = ConfigManager.new(@registry, [], ProviderRegistry.new)
148171

149172
assert_raises(ConfigManager::IllegalCogNameError) do
150173
manager.prepare!

test/roast/execution_manager_test.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def setup
88
@registry = Cog::Registry.new
99
@registry.use(TestCogSupport::TestCog)
1010

11-
@config_manager = ConfigManager.new(@registry, [])
11+
@config_manager = ConfigManager.new(@registry, [], ProviderRegistry.new)
1212
@config_manager.prepare!
1313

1414
@workflow_context = WorkflowContext.new(
@@ -68,7 +68,7 @@ def name
6868
conflicting_registry.use(conflicting_cog)
6969

7070
# Use a clean registry for config_manager so it doesn't hit the same conflict
71-
config_manager = ConfigManager.new(Cog::Registry.new, [])
71+
config_manager = ConfigManager.new(Cog::Registry.new, [], ProviderRegistry.new)
7272
config_manager.prepare!
7373

7474
manager = ExecutionManager.new(

0 commit comments

Comments
 (0)