diff --git a/cmd/root/run.go b/cmd/root/run.go index 52eef0751..265ec6c34 100644 --- a/cmd/root/run.go +++ b/cmd/root/run.go @@ -288,13 +288,7 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s return err } - var sessStore session.Store - switch typedRt := rt.(type) { - case *runtime.LocalRuntime: - sessStore = typedRt.SessionStore() - case *runtime.PersistentRuntime: - sessStore = typedRt.SessionStore() - } + sessStore := rt.SessionStore() return runTUI(ctx, rt, sess, f.createSessionSpawner(agentSource, sessStore), initialTeamCleanup, opts...) } @@ -336,10 +330,8 @@ func (f *runExecFlags) createRemoteRuntimeAndSession(ctx context.Context, origin return remoteRt, sess, nil } -func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadResult *teamloader.LoadResult) (runtime.Runtime, *session.Session, error) { - t := loadResult.Team - - agent, err := t.Agent(f.agentName) +func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadResult *teamloader.LoadResult) (rt runtime.Runtime, _ *session.Session, retErr error) { + agent, err := loadResult.Team.Agent(f.agentName) if err != nil { return nil, nil, err } @@ -355,53 +347,26 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes return nil, nil, fmt.Errorf("creating session store: %w", err) } - // Create model switcher config for runtime model switching support - modelSwitcherCfg := &runtime.ModelSwitcherConfig{ - Models: loadResult.Models, - Providers: loadResult.Providers, - ModelsGateway: f.runConfig.ModelsGateway, - EnvProvider: f.runConfig.EnvProvider(), - AgentDefaultModels: loadResult.AgentDefaultModels, - } - - localRt, err := runtime.New(t, - runtime.WithSessionStore(sessStore), - runtime.WithCurrentAgent(f.agentName), - runtime.WithTracer(otel.Tracer(AppName)), - runtime.WithModelSwitcherConfig(modelSwitcherCfg), - ) + rt, err = f.newLocalRuntime(loadResult, &f.runConfig, sessStore) if err != nil { + if closeErr := sessStore.Close(); closeErr != nil { + slog.Error("Failed to close session store", "error", closeErr) + } return nil, nil, fmt.Errorf("creating runtime: %w", err) } + // If anything below fails, close the runtime (which also closes sessStore). + defer func() { + if retErr != nil { + rt.Close() + } + }() var sess *session.Session if f.sessionID != "" { - // Resolve relative session references (e.g., "-1" for last session) - resolvedID, err := session.ResolveSessionID(ctx, sessStore, f.sessionID) + sess, err = f.loadExistingSession(ctx, rt, sessStore) if err != nil { - return nil, nil, fmt.Errorf("resolving session %q: %w", f.sessionID, err) + return nil, nil, err } - - // Load existing session - sess, err = sessStore.GetSession(ctx, resolvedID) - if err != nil { - return nil, nil, fmt.Errorf("loading session %q: %w", resolvedID, err) - } - sess.ToolsApproved = f.autoApprove - sess.HideToolResults = f.hideToolResults - - // Apply any stored model overrides from the session - if len(sess.AgentModelOverrides) > 0 { - if modelSwitcher, ok := localRt.(runtime.ModelSwitcher); ok { - for agentName, modelRef := range sess.AgentModelOverrides { - if err := modelSwitcher.SetAgentModel(ctx, agentName, modelRef); err != nil { - slog.Warn("Failed to apply stored model override", "agent", agentName, "model", modelRef, "error", err) - } - } - } - } - - slog.Debug("Loaded existing session", "session_id", resolvedID, "session_ref", f.sessionID, "agent", f.agentName) } else { wd, _ := os.Getwd() sess = session.New(f.buildSessionOpts(agent.MaxIterations(), agent.ThinkingConfigured(), wd)...) @@ -410,7 +375,38 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes slog.Debug("Using local runtime", "agent", f.agentName, "thinking", agent.ThinkingConfigured()) } - return localRt, sess, nil + return rt, sess, nil +} + +// loadExistingSession resolves a session reference and loads the session from the store, +// reapplying any stored model overrides. +func (f *runExecFlags) loadExistingSession(ctx context.Context, rt runtime.Runtime, sessStore session.Store) (*session.Session, error) { + // Resolve relative session references (e.g., "-1" for last session) + resolvedID, err := session.ResolveSessionID(ctx, sessStore, f.sessionID) + if err != nil { + return nil, fmt.Errorf("resolving session %q: %w", f.sessionID, err) + } + + sess, err := sessStore.GetSession(ctx, resolvedID) + if err != nil { + return nil, fmt.Errorf("loading session %q: %w", resolvedID, err) + } + sess.ToolsApproved = f.autoApprove + sess.HideToolResults = f.hideToolResults + + // Apply any stored model overrides from the session + if len(sess.AgentModelOverrides) > 0 { + if modelSwitcher, ok := rt.(runtime.ModelSwitcher); ok { + for agentName, modelRef := range sess.AgentModelOverrides { + if err := modelSwitcher.SetAgentModel(ctx, agentName, modelRef); err != nil { + slog.Warn("Failed to apply stored model override", "agent", agentName, "model", modelRef, "error", err) + } + } + } + } + + slog.Debug("Loaded existing session", "session_id", resolvedID, "session_ref", f.sessionID, "agent", f.agentName) + return sess, nil } func (f *runExecFlags) handleExecMode(ctx context.Context, out *cli.Printer, rt runtime.Runtime, sess *session.Session, args []string) error { @@ -504,6 +500,22 @@ func (f *runExecFlags) buildSessionOpts(maxIterations int, thinking bool, workin } } +// newLocalRuntime creates a local runtime with model switching support from the given load result. +func (f *runExecFlags) newLocalRuntime(loadResult *teamloader.LoadResult, runConfig *config.RuntimeConfig, sessStore session.Store) (runtime.Runtime, error) { + return runtime.New(loadResult.Team, + runtime.WithSessionStore(sessStore), + runtime.WithCurrentAgent(f.agentName), + runtime.WithTracer(otel.Tracer(AppName)), + runtime.WithModelSwitcherConfig(&runtime.ModelSwitcherConfig{ + Models: loadResult.Models, + Providers: loadResult.Providers, + ModelsGateway: runConfig.ModelsGateway, + EnvProvider: runConfig.EnvProvider(), + AgentDefaultModels: loadResult.AgentDefaultModels, + }), + ) +} + // createSessionSpawner creates a function that can spawn new sessions with different working directories. func (f *runExecFlags) createSessionSpawner(agentSource config.Source, sessStore session.Store) tui.SessionSpawner { return func(spawnCtx context.Context, workingDir string) (*app.App, *session.Session, func(), error) { @@ -517,28 +529,12 @@ func (f *runExecFlags) createSessionSpawner(agentSource config.Source, sessStore return nil, nil, nil, err } - team := loadResult.Team - agent, err := team.Agent(f.agentName) + agent, err := loadResult.Team.Agent(f.agentName) if err != nil { return nil, nil, nil, err } - // Create model switcher config - modelSwitcherCfg := &runtime.ModelSwitcherConfig{ - Models: loadResult.Models, - Providers: loadResult.Providers, - ModelsGateway: runConfigCopy.ModelsGateway, - EnvProvider: runConfigCopy.EnvProvider(), - AgentDefaultModels: loadResult.AgentDefaultModels, - } - - // Create the local runtime - localRt, err := runtime.New(team, - runtime.WithSessionStore(sessStore), - runtime.WithCurrentAgent(f.agentName), - runtime.WithTracer(otel.Tracer(AppName)), - runtime.WithModelSwitcherConfig(modelSwitcherCfg), - ) + localRt, err := f.newLocalRuntime(loadResult, runConfigCopy, sessStore) if err != nil { return nil, nil, nil, err } @@ -547,6 +543,7 @@ func (f *runExecFlags) createSessionSpawner(agentSource config.Source, sessStore newSess := session.New(f.buildSessionOpts(agent.MaxIterations(), agent.ThinkingConfigured(), workingDir)...) // Create cleanup function + team := loadResult.Team cleanup := func() { stopToolSets(team) }