-
Notifications
You must be signed in to change notification settings - Fork 15.8k
feat: add variant support for subagents (#7138) [alt of #7140] #7156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,7 @@ import { useToast } from "../../ui/toast" | |
| import { useKV } from "../../context/kv" | ||
| import { useTextareaKeybindings } from "../textarea-keybindings" | ||
| import { DialogSkill } from "../dialog-skill" | ||
| import { getConfiguredAgentVariant } from "@tui/context/model-variant" | ||
|
|
||
| export type PromptProps = { | ||
| sessionID?: string | ||
|
|
@@ -197,7 +198,20 @@ export function Prompt(props: PromptProps) { | |
| if (msg.agent && isPrimaryAgent) { | ||
| local.agent.set(msg.agent) | ||
| if (msg.model) local.model.set(msg.model) | ||
| if (msg.variant) local.model.variant.set(msg.variant) | ||
| if (msg.variant) { | ||
| const info = local.agent.list().find((x) => x.name === msg.agent) | ||
| const provider = msg.model ? sync.data.provider.find((x) => x.id === msg.model.providerID) : undefined | ||
| const model = msg.model ? provider?.models[msg.model.modelID] : undefined | ||
| const configured = getConfiguredAgentVariant({ | ||
| agent: { model: info?.model, variant: info?.variant }, | ||
| model: msg.model | ||
| ? { providerID: msg.model.providerID, modelID: msg.model.modelID, variants: model?.variants } | ||
| : undefined, | ||
| }) | ||
| if (msg.variant === configured) local.model.variant.set(undefined) | ||
| if (msg.variant !== configured) local.model.variant.set(msg.variant) | ||
| } | ||
| if (!msg.variant) local.model.variant.set(undefined) | ||
|
Comment on lines
+210
to
+214
|
||
| } | ||
| } | ||
| }) | ||
|
|
@@ -805,8 +819,8 @@ export function Prompt(props: PromptProps) { | |
| const showVariant = createMemo(() => { | ||
| const variants = local.model.variant.list() | ||
| if (variants.length === 0) return false | ||
| const current = local.model.variant.current() | ||
| return !!current | ||
| const effective = local.model.variant.effective() | ||
| return !!effective | ||
| }) | ||
|
|
||
| const placeholderText = createMemo(() => { | ||
|
|
@@ -1073,7 +1087,7 @@ export function Prompt(props: PromptProps) { | |
| <Show when={showVariant()}> | ||
| <text fg={theme.textMuted}>·</text> | ||
| <text> | ||
| <span style={{ fg: theme.warning, bold: true }}>{local.model.variant.current()}</span> | ||
| <span style={{ fg: theme.warning, bold: true }}>{local.model.variant.effective()}</span> | ||
| </text> | ||
| </Show> | ||
| </box> | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ import { useArgs } from "./args" | |||||||||||||||||||||||||||||||||||||
| import { useSDK } from "./sdk" | ||||||||||||||||||||||||||||||||||||||
| import { RGBA } from "@opentui/core" | ||||||||||||||||||||||||||||||||||||||
| import { Filesystem } from "@/util/filesystem" | ||||||||||||||||||||||||||||||||||||||
| import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ | ||||||||||||||||||||||||||||||||||||||
| name: "Local", | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -321,17 +322,32 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ | |||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||
| variant: { | ||||||||||||||||||||||||||||||||||||||
| configured() { | ||||||||||||||||||||||||||||||||||||||
| const a = agent.current() | ||||||||||||||||||||||||||||||||||||||
| const m = currentModel() | ||||||||||||||||||||||||||||||||||||||
| if (!m) return undefined | ||||||||||||||||||||||||||||||||||||||
| const provider = sync.data.provider.find((x) => x.id === m.providerID) | ||||||||||||||||||||||||||||||||||||||
| const info = provider?.models[m.modelID] | ||||||||||||||||||||||||||||||||||||||
| return getConfiguredAgentVariant({ | ||||||||||||||||||||||||||||||||||||||
| agent: { model: a.model, variant: a.variant }, | ||||||||||||||||||||||||||||||||||||||
| model: { providerID: m.providerID, modelID: m.modelID, variants: info?.variants }, | ||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||
| selected() { | ||||||||||||||||||||||||||||||||||||||
| const m = currentModel() | ||||||||||||||||||||||||||||||||||||||
| if (!m) return undefined | ||||||||||||||||||||||||||||||||||||||
| const key = `${m.providerID}/${m.modelID}` | ||||||||||||||||||||||||||||||||||||||
| return modelStore.variant[key] | ||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||
| current() { | ||||||||||||||||||||||||||||||||||||||
| const v = this.selected() | ||||||||||||||||||||||||||||||||||||||
| if (!v) return undefined | ||||||||||||||||||||||||||||||||||||||
| if (!this.list().includes(v)) return undefined | ||||||||||||||||||||||||||||||||||||||
| return v | ||||||||||||||||||||||||||||||||||||||
| return resolveModelVariant({ | ||||||||||||||||||||||||||||||||||||||
| variants: this.list(), | ||||||||||||||||||||||||||||||||||||||
| selected: this.selected(), | ||||||||||||||||||||||||||||||||||||||
| configured: this.configured(), | ||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||
| effective() { | ||||||||||||||||||||||||||||||||||||||
| return this.current() | ||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+343
to
351
|
||||||||||||||||||||||||||||||||||||||
| return resolveModelVariant({ | |
| variants: this.list(), | |
| selected: this.selected(), | |
| configured: this.configured(), | |
| }) | |
| }, | |
| effective() { | |
| return this.current() | |
| }, | |
| return this.selected() | |
| }, | |
| effective() { | |
| return resolveModelVariant({ | |
| variants: this.list(), | |
| selected: this.selected(), | |
| configured: this.configured(), | |
| }) | |
| }, |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| type AgentModel = { | ||
| providerID: string | ||
| modelID: string | ||
| } | ||
|
|
||
| type Agent = { | ||
| model?: AgentModel | ||
| variant?: string | ||
| } | ||
|
|
||
| type Model = AgentModel & { | ||
| variants?: Record<string, unknown> | ||
| } | ||
|
|
||
| type VariantInput = { | ||
| variants: string[] | ||
| selected: string | undefined | ||
| configured: string | undefined | ||
| } | ||
|
|
||
| export function getConfiguredAgentVariant(input: { agent: Agent | undefined; model: Model | undefined }) { | ||
| if (!input.agent?.variant) return undefined | ||
| if (!input.agent.model) return undefined | ||
| if (!input.model?.variants) return undefined | ||
| if (input.agent.model.providerID !== input.model.providerID) return undefined | ||
| if (input.agent.model.modelID !== input.model.modelID) return undefined | ||
| if (!(input.agent.variant in input.model.variants)) return undefined | ||
| return input.agent.variant | ||
| } | ||
|
|
||
| export function resolveModelVariant(input: VariantInput) { | ||
| if (input.selected === "default") return undefined | ||
| if (input.selected && input.variants.includes(input.selected)) return input.selected | ||
| if (input.configured && input.variants.includes(input.configured)) return input.configured | ||
| return undefined | ||
| } | ||
|
|
||
| export function cycleModelVariant(input: VariantInput) { | ||
| if (input.variants.length === 0) return undefined | ||
| if (input.selected === "default") return input.variants[0] | ||
| if (input.selected && input.variants.includes(input.selected)) { | ||
| const index = input.variants.indexOf(input.selected) | ||
| if (index === input.variants.length - 1) return undefined | ||
| return input.variants[index + 1] | ||
| } | ||
| return input.variants[0] | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| import { describe, expect, test } from "bun:test" | ||
| import { | ||
| cycleModelVariant, | ||
| getConfiguredAgentVariant, | ||
| resolveModelVariant, | ||
| } from "../../../src/cli/cmd/tui/context/model-variant" | ||
|
|
||
| describe("tui model variant", () => { | ||
| test("resolves configured variant when active model matches", () => { | ||
| const value = getConfiguredAgentVariant({ | ||
| agent: { model: { providerID: "openai", modelID: "gpt-5.2" }, variant: "high" }, | ||
| model: { providerID: "openai", modelID: "gpt-5.2", variants: { low: {}, high: {} } }, | ||
| }) | ||
|
|
||
| expect(value).toBe("high") | ||
| }) | ||
|
|
||
| test("ignores configured variant when active model does not match", () => { | ||
| const value = getConfiguredAgentVariant({ | ||
| agent: { model: { providerID: "openai", modelID: "gpt-5.2" }, variant: "high" }, | ||
| model: { providerID: "anthropic", modelID: "claude-sonnet-4", variants: { low: {}, high: {} } }, | ||
| }) | ||
|
|
||
| expect(value).toBeUndefined() | ||
| }) | ||
|
|
||
| test("prefers selected variant over configured variant", () => { | ||
| const value = resolveModelVariant({ | ||
| variants: ["low", "high", "xhigh"], | ||
| selected: "xhigh", | ||
| configured: "high", | ||
| }) | ||
|
|
||
| expect(value).toBe("xhigh") | ||
| }) | ||
|
|
||
| test("treats default sentinel as explicit default", () => { | ||
| const value = resolveModelVariant({ | ||
| variants: ["low", "high", "xhigh"], | ||
| selected: "default", | ||
| configured: "high", | ||
| }) | ||
|
|
||
| expect(value).toBeUndefined() | ||
| }) | ||
|
|
||
| test("cycles from default sentinel to first variant", () => { | ||
| const value = cycleModelVariant({ | ||
| variants: ["low", "high", "xhigh"], | ||
| selected: "default", | ||
| configured: "high", | ||
| }) | ||
|
|
||
| expect(value).toBe("low") | ||
| }) | ||
|
|
||
| test("cycles through all variants from explicit selection", () => { | ||
| const variants = ["low", "high", "xhigh"] | ||
| const first = cycleModelVariant({ variants, selected: undefined, configured: "high" }) | ||
| const second = cycleModelVariant({ variants, selected: first, configured: "high" }) | ||
| const third = cycleModelVariant({ variants, selected: second, configured: "high" }) | ||
| const fourth = cycleModelVariant({ variants, selected: third, configured: "high" }) | ||
|
|
||
| expect(first).toBe("low") | ||
| expect(second).toBe("high") | ||
| expect(third).toBe("xhigh") | ||
| expect(fourth).toBeUndefined() | ||
| }) | ||
| }) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block recomputes the configured agent variant by re-looking-up the agent info and provider model metadata. Since
LocalContextnow haslocal.model.variant.configured(), consider using that directly (after setting agent/model) to avoid duplicating the match/variant-availability logic in multiple places and risking future divergence.