diff --git a/cecli/__init__.py b/cecli/__init__.py index 85f102b2b96..49025eeddbe 100644 --- a/cecli/__init__.py +++ b/cecli/__init__.py @@ -1,6 +1,6 @@ from packaging import version -__version__ = "0.99.9.dev" +__version__ = "0.99.10.dev" safe_version = __version__ try: diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index d261e772222..07aa4304642 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -31,6 +31,8 @@ from .base_coder import Coder +from cecli.helpers.coroutines import interruptible # isort:skip + class AgentCoder(Coder): """Mode where the LLM autonomously manages which files are in context.""" @@ -42,6 +44,9 @@ class AgentCoder(Coder): stop_on_empty = False def __init__(self, *args, **kwargs): + if kwargs.get("uuid", None): + self.uuid = kwargs.get("uuid") + self.recently_removed = {} self.tool_usage_history = [] self.loaded_custom_tools = [] @@ -55,7 +60,7 @@ def __init__(self, *args, **kwargs): "commandinteractive", "explorecode", "ls", - "getlines", + "readrange", "grep", "thinking", "updatetodolist", @@ -301,8 +306,23 @@ async def _execute_local_tool_calls(self, tool_calls_list): else: all_results_content.append(f"Error: Unknown tool name '{tool_name}'") if tasks: - task_results = await asyncio.gather(*tasks) - all_results_content.extend(str(res) for res in task_results) + + async def gather_and_await(): + return await asyncio.gather(*tasks, return_exceptions=True) + + task_results, interrupted = await interruptible( + gather_and_await(), self.interrupt_event + ) + + if interrupted: + self.io.tool_warning("Tool execution interrupted.") + all_results_content.append("Tool execution interrupted by user.") + elif task_results: + for res in task_results: + if isinstance(res, Exception): + all_results_content.append(f"Error in tool execution: {res}") + else: + all_results_content.append(str(res)) if not await HookIntegration.call_post_tool_hooks( self, tool_name, args_string, 
"\n\n".join(all_results_content) @@ -393,7 +413,11 @@ async def _exec_async(): """) return f"Error executing tool call {tool_name}: {e}" - return await _exec_async() + result, interrupted = await interruptible(_exec_async(), self.interrupt_event) + + if interrupted: + return "Tool execution interrupted by user." + return result def _calculate_context_block_tokens(self, force=False): """ @@ -995,7 +1019,7 @@ def _generate_tool_context(self, repetitive_tools): context_parts.append("\n\n") context_parts.append("## File Editing Tools Disabled") context_parts.append( - "File editing tools are currently disabled.Use `GetLines` to determine the" + "File editing tools are currently disabled.Use `ReadRange` to determine the" " current hashline prefixes needed to perform an edit and activate them when you" " are ready to edit a file." ) diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 319e7b640bf..c0354c9cb6f 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -139,6 +139,12 @@ class Coder: partial_response_tool_calls = [] commit_before_message = [] message_cost = 0.0 + total_tokens_sent = 0 + total_tokens_received = 0 + total_cached_tokens = 0 + message_tokens_sent = 0 + message_tokens_received = 0 + message_cached_tokens = 0 add_cache_headers = False cache_warming_thread = None num_cache_warming_pings = 0 @@ -227,6 +233,7 @@ async def create( ignore_mentions=from_coder.ignore_mentions, total_tokens_sent=from_coder.total_tokens_sent, total_tokens_received=from_coder.total_tokens_received, + total_cached_tokens=from_coder.total_cached_tokens, file_watcher=from_coder.file_watcher, mcp_manager=from_coder.mcp_manager, uuid=from_coder.uuid, @@ -316,6 +323,7 @@ def __init__( ignore_mentions=None, total_tokens_sent=0, total_tokens_received=0, + total_cached_tokens=0, file_watcher=None, auto_copy_context=False, auto_accept_architect=True, @@ -331,6 +339,7 @@ def __init__( ): # initialize from args.map_cache_dir self.interrupt_event 
= asyncio.Event() + self.coroutines = coroutines self.uuid = generate_unique_id() if uuid: self.uuid = uuid @@ -388,8 +397,10 @@ def __init__( self.total_cost = total_cost self.total_tokens_sent = total_tokens_sent self.total_tokens_received = total_tokens_received + self.total_cached_tokens = total_cached_tokens self.message_tokens_sent = 0 self.message_tokens_received = 0 + self.message_cached_tokens = 0 self.token_profiler = TokenProfiler( enable_printing=nested.getter(self.args, "show_speed", False) @@ -1370,11 +1381,6 @@ async def _run_parallel(self, with_message=None, preproc=True): except (SwitchCoderSignal, SystemExit): # Re-raise SwitchCoder to be handled by outer try block raise - except KeyboardInterrupt: - # Handle keyboard interrupt gracefully - self.io.set_placeholder("") - self.io.stop_spinner() - self.keyboard_interrupt() finally: # Signal tasks to stop self.input_running = False @@ -1454,10 +1460,6 @@ async def input_task(self, preproc): await asyncio.sleep(0.1) # Small yield to prevent tight loop - except KeyboardInterrupt: - self.io.set_placeholder("") - self.keyboard_interrupt() - await self.io.stop_task_streams() except (SwitchCoderSignal, SystemExit): raise except Exception as e: @@ -1738,8 +1740,7 @@ def keyboard_interrupt(self): # Ensure cursor is visible on exit Console().show_cursor(True) - self.io.tool_warning("\n\n^C KeyboardInterrupt") - + self.io.tool_warning("^C KeyboardInterrupt") self.interrupt_event.set() self.last_keyboard_interrupt = time.time() @@ -2262,9 +2263,16 @@ async def send_message(self, inp): self.io.tool_error(err_msg) self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...") - await asyncio.sleep(retry_delay) + + _res, interrupted_sleep = await coroutines.interruptible( + asyncio.sleep(retry_delay), self.interrupt_event + ) + if interrupted_sleep: + interrupted = True + break + continue - except KeyboardInterrupt: + except (KeyboardInterrupt, asyncio.CancelledError): interrupted = True break except 
FinishReasonLength: @@ -2629,11 +2637,19 @@ async def _execute_mcp_tools(self, server, tool_calls): all_results_content.append("Tool Request Aborted.") continue - call_result = await experimental_mcp_client.call_openai_tool( - session=session, - openai_tool=new_tool_call, + async def do_tool_call(): + return await experimental_mcp_client.call_openai_tool( + session=session, + openai_tool=new_tool_call, + ) + + call_result, interrupted = await coroutines.interruptible( + do_tool_call(), self.interrupt_event ) + if interrupted: + raise KeyboardInterrupt("Tool call interrupted") + content_parts = [] if call_result.content: for item in call_result.content: @@ -2678,6 +2694,9 @@ async def _execute_mcp_tools(self, server, tool_calls): } ) + except KeyboardInterrupt: + self.io.tool_warning(f"Tool call {tool_call.function.name} interrupted.") + raise except Exception as e: tool_error = f"Error executing tool call {tool_call.function.name}: \n{e}" self.io.tool_warning( @@ -2694,6 +2713,9 @@ async def _execute_mcp_tools(self, server, tool_calls): tool_responses.append( {"role": "tool", "tool_call_id": tool_call.id, "content": connection_error} ) + except asyncio.CancelledError: + # Re-raise CancelledError to ensure the task cancellation propagates + raise except Exception as e: connection_error = f"Could not connect to server {server.name}\n{e}" self.io.tool_warning(connection_error) @@ -2728,7 +2750,15 @@ async def process_tool_calls(self, tool_call_response): return False # 5. Execute tools - tool_responses_by_server = await self._execute_tool_groups(tool_groups) + self.interrupt_event.clear() + + tool_responses_by_server, interrupted = await coroutines.interruptible( + self._execute_tool_groups(tool_groups), self.interrupt_event + ) + + if interrupted: + self.io.tool_warning("Tool execution interrupted.") + return False # 6. 
Add responses to conversation (re-prefixing if necessary) tool_responses = [] @@ -3040,33 +3070,22 @@ async def send(self, messages, model=None, functions=None, tools=None): self.token_profiler.start() try: - completion_task = asyncio.create_task( - model.send_completion( - messages, - functions, - self.stream, - self.temperature, - # This could include any tools, but for now it is just MCP tools - tools=tools, - override_kwargs=self.model_kwargs.copy(), - ) + completion_coro = model.send_completion( + messages, + functions, + self.stream, + self.temperature, + # This could include any tools, but for now it is just MCP tools + tools=tools, + override_kwargs=self.model_kwargs.copy(), + interrupt_event=self.interrupt_event, ) - interrupt_task = asyncio.create_task(self.interrupt_event.wait()) - done, pending = await asyncio.wait( - {completion_task, interrupt_task}, - return_when=asyncio.FIRST_COMPLETED, + (hash_object, completion), interrupted = await coroutines.interruptible( + completion_coro, self.interrupt_event ) - - if interrupt_task in done: - completion_task.cancel() - try: - await completion_task - except asyncio.CancelledError: - pass + if interrupted: raise KeyboardInterrupt - - hash_object, completion = completion_task.result() self.chat_completion_call_hashes.append(hash_object.hexdigest()) if not isinstance(completion, ModelResponse): @@ -3089,7 +3108,7 @@ async def send(self, messages, model=None, functions=None, tools=None): self.token_profiler.on_error() self.calculate_and_show_tokens_and_cost(messages, completion) raise - except KeyboardInterrupt as kbi: + except (KeyboardInterrupt, asyncio.CancelledError) as kbi: self.keyboard_interrupt() raise kbi finally: @@ -3498,10 +3517,13 @@ def calculate_and_show_tokens_and_cost(self, messages, completion=None): if completion and hasattr(completion, "usage") and completion.usage is not None: prompt_tokens = completion.usage.prompt_tokens completion_tokens = completion.usage.completion_tokens - 
cache_hit_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr( - completion.usage, "cache_read_input_tokens", 0 + cache_hit_tokens = ( + getattr(completion.usage, "prompt_cache_hit_tokens", 0) + or getattr(completion.usage, "cache_read_input_tokens", 0) + or 0 ) - cache_write_tokens = getattr(completion.usage, "cache_creation_input_tokens", 0) + cache_write_tokens = getattr(completion.usage, "cache_creation_input_tokens", 0) or 0 + self.message_cached_tokens += cache_hit_tokens if hasattr(completion.usage, "cache_read_input_tokens") or hasattr( completion.usage, "cache_creation_input_tokens" @@ -3534,8 +3556,22 @@ def calculate_and_show_tokens_and_cost(self, messages, completion=None): tokens_report, self.message_tokens_sent, self.message_tokens_received ) + total_combined_tokens = ( + self.total_tokens_sent + + self.total_tokens_received + + self.message_tokens_sent + + self.message_tokens_received + ) + total_combined_cached = self.total_cached_tokens + self.message_cached_tokens + + total_stats = f"{format_tokens(total_combined_tokens)}" + if total_combined_cached: + total_stats += f"/{format_tokens(total_combined_cached)}" + + total_stats += " ↑↓" + if not self.get_active_model().info.get("input_cost_per_token"): - self.usage_report = tokens_report + self.usage_report = tokens_report + "\n" + total_stats return try: @@ -3552,11 +3588,8 @@ def calculate_and_show_tokens_and_cost(self, messages, completion=None): self.total_cost += cost self.message_cost += cost - total_combined_tokens = ( - self.total_tokens_sent + self.total_tokens_received + prompt_tokens + completion_tokens - ) cost_report = ( - f"${self.format_cost(self.message_cost)} • {format_tokens(total_combined_tokens)} ↑↓" + f"${self.format_cost(self.message_cost)} • {total_stats}" f" ${self.format_cost(self.total_cost)}" ) @@ -3614,6 +3647,7 @@ def show_usage_report(self): self.total_tokens_sent += self.message_tokens_sent self.total_tokens_received += 
self.message_tokens_received + self.total_cached_tokens += self.message_cached_tokens if self.tui and self.tui(): self.tui().update_cost(self.usage_report.replace("\n", " ")) @@ -3624,6 +3658,7 @@ def show_usage_report(self): self.message_cost = 0.0 self.message_tokens_sent = 0 self.message_tokens_received = 0 + self.message_cached_tokens = 0 def get_multi_response_content_in_progress(self, final=False): cur = self.multi_response_content or "" diff --git a/cecli/commands/__init__.py b/cecli/commands/__init__.py index d608ffbd37f..81e4d4c9d4a 100644 --- a/cecli/commands/__init__.py +++ b/cecli/commands/__init__.py @@ -25,14 +25,17 @@ from .drop import DropCommand from .editor import EditCommand, EditorCommand from .editor_model import EditorModelCommand +from .exclude_skill import ExcludeSkillCommand from .exit import ExitCommand from .git import GitCommand from .hashline import HashlineCommand from .help import HelpCommand from .history_search import HistorySearchCommand from .hooks import HooksCommand +from .include_skill import IncludeSkillCommand from .lint import LintCommand from .list_sessions import ListSessionsCommand +from .list_skills import ListSkillsCommand from .load import LoadCommand from .load_hook import LoadHookCommand from .load_mcp import LoadMcpCommand @@ -102,14 +105,17 @@ CommandRegistry.register(EditCommand) CommandRegistry.register(EditorCommand) CommandRegistry.register(EditorModelCommand) +CommandRegistry.register(ExcludeSkillCommand) CommandRegistry.register(ExitCommand) CommandRegistry.register(GitCommand) CommandRegistry.register(HashlineCommand) CommandRegistry.register(HelpCommand) CommandRegistry.register(HistorySearchCommand) CommandRegistry.register(HooksCommand) +CommandRegistry.register(IncludeSkillCommand) CommandRegistry.register(LintCommand) CommandRegistry.register(ListSessionsCommand) +CommandRegistry.register(ListSkillsCommand) CommandRegistry.register(LoadCommand) CommandRegistry.register(LoadHookCommand) 
CommandRegistry.register(LoadMcpCommand) @@ -172,6 +178,7 @@ "EditCommand", "EditorCommand", "EditorModelCommand", + "ExcludeSkillCommand", "ExitCommand", "expand_subdir", "format_command_result", @@ -182,8 +189,10 @@ "HelpCommand", "HistorySearchCommand", "HookCommand", + "IncludeSkillCommand", "LintCommand", "ListSessionsCommand", + "ListSkillsCommand", "LoadCommand", "LoadHookCommand", "LoadMcpCommand", diff --git a/cecli/commands/exclude_skill.py b/cecli/commands/exclude_skill.py new file mode 100644 index 00000000000..8f8197eba5b --- /dev/null +++ b/cecli/commands/exclude_skill.py @@ -0,0 +1,72 @@ +from typing import List + +from cecli.commands.utils.base_command import BaseCommand +from cecli.commands.utils.helpers import format_command_result + + +class ExcludeSkillCommand(BaseCommand): + NORM_NAME = "exclude-skill" + DESCRIPTION = "Exclude a skill by name (agent mode only)" + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the exclude-skill command with given parameters.""" + if not args.strip(): + io.tool_output("Usage: /exclude-skill ") + return format_command_result(io, "exclude-skill", "Usage: /exclude-skill ") + + skill_names = args.strip().split() + + # Check if we're in agent mode + if not hasattr(coder, "edit_format") or coder.edit_format != "agent": + io.tool_output("Skill exclusion is only available in agent mode.") + return format_command_result( + io, "exclude-skill", "Skill exclusion is only available in agent mode" + ) + + # Check if skills_manager is available + if not hasattr(coder, "skills_manager") or coder.skills_manager is None: + io.tool_output("Skills manager is not initialized. Skills may not be configured.") + # Check if skills directories are configured + if hasattr(coder, "skills_directory_paths") and not coder.skills_directory_paths: + io.tool_output( + "No skills directories configured. Use --skills-paths to configure skill" + " directories." 
+ ) + return format_command_result(io, "exclude-skill", "Skills manager is not initialized") + + results = [] + for skill_name in skill_names: + # Use the instance method on skills_manager + result = coder.skills_manager.exclude_skill(skill_name) + results.append(result) + + return format_command_result(io, "exclude-skill", "\n".join(results)) + + @classmethod + def get_completions(cls, io, coder, args) -> List[str]: + """Get completion options for exclude-skill command.""" + if not hasattr(coder, "skills_manager") or coder.skills_manager is None: + return [] + + try: + skills = coder.skills_manager.find_skills() + return [skill.name for skill in skills] + except Exception: + return [] + + @classmethod + def get_help(cls) -> str: + """Get help text for the exclude-skill command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /exclude-skill ... # Exclude one or more skills by name\n" + help_text += "\nExamples:\n" + help_text += " /exclude-skill pdf # Exclude (blacklist) the PDF skill\n" + help_text += " /exclude-skill web pdf # Exclude both web and PDF skills\n" + help_text += ( + "\nThis command excludes one or more skills by name, adding them to the blacklist. 
" + "Skills are only available in agent mode.\n" + ) + help_text += "Excluded skills will be hidden from discovery and unavailable for loading.\n" + return help_text diff --git a/cecli/commands/include_skill.py b/cecli/commands/include_skill.py new file mode 100644 index 00000000000..79f31fde15d --- /dev/null +++ b/cecli/commands/include_skill.py @@ -0,0 +1,72 @@ +from typing import List + +from cecli.commands.utils.base_command import BaseCommand +from cecli.commands.utils.helpers import format_command_result + + +class IncludeSkillCommand(BaseCommand): + NORM_NAME = "include-skill" + DESCRIPTION = "Include a skill by name (agent mode only)" + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the include-skill command with given parameters.""" + if not args.strip(): + io.tool_output("Usage: /include-skill ") + return format_command_result(io, "include-skill", "Usage: /include-skill ") + + skill_names = args.strip().split() + + # Check if we're in agent mode + if not hasattr(coder, "edit_format") or coder.edit_format != "agent": + io.tool_output("Skill inclusion is only available in agent mode.") + return format_command_result( + io, "include-skill", "Skill inclusion is only available in agent mode" + ) + + # Check if skills_manager is available + if not hasattr(coder, "skills_manager") or coder.skills_manager is None: + io.tool_output("Skills manager is not initialized. Skills may not be configured.") + # Check if skills directories are configured + if hasattr(coder, "skills_directory_paths") and not coder.skills_directory_paths: + io.tool_output( + "No skills directories configured. Use --skills-paths to configure skill" + " directories." 
+ ) + return format_command_result(io, "include-skill", "Skills manager is not initialized") + + results = [] + for skill_name in skill_names: + # Use the instance method on skills_manager + result = coder.skills_manager.include_skill(skill_name) + results.append(result) + + return format_command_result(io, "include-skill", "\n".join(results)) + + @classmethod + def get_completions(cls, io, coder, args) -> List[str]: + """Get completion options for include-skill command.""" + if not hasattr(coder, "skills_manager") or coder.skills_manager is None: + return [] + + try: + skills = coder.skills_manager.find_skills() + return [skill.name for skill in skills] + except Exception: + return [] + + @classmethod + def get_help(cls) -> str: + """Get help text for the include-skill command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /include-skill ... # Include one or more skills by name\n" + help_text += "\nExamples:\n" + help_text += " /include-skill pdf # Include (whitelist) the PDF skill\n" + help_text += " /include-skill web pdf # Include both web and PDF skills\n" + help_text += ( + "\nThis command includes one or more skills by name, adding them to the whitelist. 
" + "Skills are only available in agent mode.\n" + ) + help_text += "When a skill is included, only whitelisted skills will be discoverable.\n" + return help_text diff --git a/cecli/commands/list_skills.py b/cecli/commands/list_skills.py new file mode 100644 index 00000000000..752dc885c06 --- /dev/null +++ b/cecli/commands/list_skills.py @@ -0,0 +1,51 @@ +from typing import List + +from cecli.commands.utils.base_command import BaseCommand +from cecli.commands.utils.helpers import format_command_result + + +class ListSkillsCommand(BaseCommand): + NORM_NAME = "list-skills" + DESCRIPTION = "List all available skills with their states and file paths" + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the list-skills command with given parameters.""" + # Check if skills_manager is available + if not hasattr(coder, "skills_manager") or coder.skills_manager is None: + io.tool_output("Skills manager is not initialized. Skills may not be configured.") + if hasattr(coder, "skills_directory_paths") and not coder.skills_directory_paths: + io.tool_output( + "No skills directories configured. Use --skills-paths to configure skill" + " directories." 
+ ) + return format_command_result(io, "list-skills", "Skills manager is not initialized") + + try: + formatted = coder.skills_manager.get_skills_list_formatted() + return format_command_result(io, "list-skills", formatted) + except Exception as e: + error_msg = f"Error listing skills: {e}" + return format_command_result(io, "list-skills", error_msg) + + @classmethod + def get_completions(cls, io, coder, args) -> List[str]: + """Get completion options for list-skills command.""" + return [] + + @classmethod + def get_help(cls) -> str: + """Get help text for the list-skills command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /list-skills # List all available skills with their states and paths\n" + help_text += "\nExamples:\n" + help_text += ( + " /list-skills # Shows a table of all skills, their include/exclude/visible status,\n" + ) + help_text += " # whether they are loaded, and their directory paths\n" + help_text += "\n" + help_text += "This command lists all skills found in the configured skill directories,\n" + help_text += "displaying their current status (included/excluded/visible),\n" + help_text += "whether they are loaded into context, and their file system paths.\n" + return help_text diff --git a/cecli/commands/load_mcp.py b/cecli/commands/load_mcp.py index eb1e6d2e402..302d568640f 100644 --- a/cecli/commands/load_mcp.py +++ b/cecli/commands/load_mcp.py @@ -20,48 +20,68 @@ async def execute(cls, io, coder, args, **kwargs): ) server_names = args.strip().split() + results = [] + servers_to_load = [] + # Handle '*' wildcard to load all servers enabled by default if server_names == ["*"]: for server in coder.mcp_manager.servers: if server in coder.mcp_manager.connected_servers: results.append(f"Server already loaded: {server.name}") continue + auto_connect = server.config.get("enabled", True) if not auto_connect: results.append(f"Skipping server (not enabled by default): {server.name}") continue - did_connect = await 
coder.mcp_manager.connect_server(server.name) - if did_connect: - results.append(f"Loaded server: {server.name}") - else: - results.append(f"Unable to load server: {server.name}") + + servers_to_load.append(server) else: for server_name in server_names: server = coder.mcp_manager.get_server(server_name) if server is None: + io.tool_error(f"MCP server {server_name} does not exist.") results.append(f"MCP server {server_name} does not exist.") - continue - - did_connect = await coder.mcp_manager.connect_server(server.name) - if did_connect: - results.append(f"Loaded server: {server_name}") else: - results.append(f"Unable to load server: {server_name}") + servers_to_load.append(server) - try: - return format_command_result(io, cls.NORM_NAME, "\n".join(results)) - finally: - from . import SwitchCoderSignal - - raise SwitchCoderSignal( - edit_format=coder.edit_format, - summarize_from_coder=False, - from_coder=coder, - show_announcements=True, + # Early exit if nothing valid to process + if not servers_to_load and results: + return format_command_result(io, cls.NORM_NAME, "", "\n".join(results)) + + # Process connections with interrupt support + for server in servers_to_load: + server_name = server.name + coder.interrupt_event.clear() + + did_connect, interrupted = await coder.coroutines.interruptible( + coder.mcp_manager.connect_server(server_name), + coder.interrupt_event, ) + if interrupted: + io.tool_warning(f"MCP connection interrupted: {server_name}") + results.append(f"Interrupted: {server_name}") + continue + + if did_connect: + results.append(f"Loaded server: {server_name}") + else: + results.append(f"Unable to load server: {server_name}") + + io.tool_output("\n".join(results)) + + from . 
import SwitchCoderSignal + + raise SwitchCoderSignal( + edit_format=coder.edit_format, + summarize_from_coder=False, + from_coder=coder, + show_announcements=True, + ) + @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for load-mcp command.""" diff --git a/cecli/commands/load_skill.py b/cecli/commands/load_skill.py index ccb90ef8353..3750a3f2645 100644 --- a/cecli/commands/load_skill.py +++ b/cecli/commands/load_skill.py @@ -39,7 +39,6 @@ async def execute(cls, io, coder, args, **kwargs): for skill_name in skill_names: # Use the instance method on skills_manager result = coder.skills_manager.load_skill(skill_name) - io.tool_output(result) results.append(result) return format_command_result(io, "load-skill", "\n".join(results)) diff --git a/cecli/commands/remove_mcp.py b/cecli/commands/remove_mcp.py index 2239d7ba883..ad212da4051 100644 --- a/cecli/commands/remove_mcp.py +++ b/cecli/commands/remove_mcp.py @@ -20,38 +20,59 @@ async def execute(cls, io, coder, args, **kwargs): ) server_names = args.strip().split() + results = [] + servers_to_disconnect = [] # Handle '*' wildcard to disconnect all servers if server_names == ["*"]: connected = [s for s in coder.mcp_manager.servers if s.is_connected] + if not connected: results.append("No MCP servers connected, nothing to remove.") else: - for server in connected: - await coder.mcp_manager.disconnect_server(server.name) - results.append(f"Removed server: {server.name}") + servers_to_disconnect.extend(connected) else: for server_name in server_names: - was_disconnected = await coder.mcp_manager.disconnect_server(server_name) - if was_disconnected: - results.append(f"Removed server: {server_name}") - else: - results.append(f"Unable to remove server: {server_name}") + servers_to_disconnect.append(server_name) - try: - return format_command_result(io, cls.NORM_NAME, "\n".join(results)) - finally: - from . 
import SwitchCoderSignal - - raise SwitchCoderSignal( - edit_format=coder.edit_format, - summarize_from_coder=False, - from_coder=coder, - show_announcements=True, - mcp_manager=coder.mcp_manager, + # Early exit if nothing to process + if not servers_to_disconnect and results: + return format_command_result(io, cls.NORM_NAME, "", "\n".join(results)) + + # Process disconnections with interrupt support + for item in servers_to_disconnect: + server_name = item.name if hasattr(item, "name") else item + + coder.interrupt_event.clear() + + was_disconnected, interrupted = await coder.coroutines.interruptible( + coder.mcp_manager.disconnect_server(server_name), + coder.interrupt_event, ) + if interrupted: + io.tool_warning(f"MCP disconnection interrupted: {server_name}") + results.append(f"Interrupted: {server_name}") + continue + + if was_disconnected: + results.append(f"Removed server: {server_name}") + else: + results.append(f"Unable to remove server: {server_name}") + + io.tool_output("\n".join(results)) + + from . 
import SwitchCoderSignal + + raise SwitchCoderSignal( + edit_format=coder.edit_format, + summarize_from_coder=False, + from_coder=coder, + show_announcements=True, + mcp_manager=coder.mcp_manager, + ) + @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for remove-mcp command.""" diff --git a/cecli/commands/remove_skill.py b/cecli/commands/remove_skill.py index 4d665453a73..97b3628ccd9 100644 --- a/cecli/commands/remove_skill.py +++ b/cecli/commands/remove_skill.py @@ -39,7 +39,6 @@ async def execute(cls, io, coder, args, **kwargs): for skill_name in skill_names: # Use the instance method on skills_manager result = coder.skills_manager.remove_skill(skill_name) - io.tool_output(result) results.append(result) return format_command_result(io, "remove-skill", "\n".join(results)) diff --git a/cecli/commands/utils/helpers.py b/cecli/commands/utils/helpers.py index 475de317874..2217278d2cd 100644 --- a/cecli/commands/utils/helpers.py +++ b/cecli/commands/utils/helpers.py @@ -253,10 +253,10 @@ def format_command_result( Formatted result string """ if error: - io.tool_error(f"\nError in {command_name}: {str(error)}") + io.tool_error(f"Error in {command_name}: {str(error)}") return f"Error: {str(error)}" else: - io.tool_output(f"\n✅ {success_message}") + io.tool_output(f"✅ {success_message}") return f"Successfully executed {command_name}." 
diff --git a/cecli/helpers/conversation/files.py b/cecli/helpers/conversation/files.py index 13b69641c5a..11ae831608e 100644 --- a/cecli/helpers/conversation/files.py +++ b/cecli/helpers/conversation/files.py @@ -436,16 +436,18 @@ def get_file_context(self, file_path: str) -> str: # Generate hashline representations for each range context_parts = [] + content_lines = content.splitlines() + for i, (start_line, end_line) in enumerate(ranges): # Note: hashline uses 1-based line numbers, so no conversion needed start_line_adj = max(1, start_line) - end_line_adj = min(len(content.splitlines()), end_line) + end_line_adj = min(len(content_lines), end_line) if start_line_adj > end_line_adj: continue # Extract lines for this range (0-based indexing for list) - lines = content.splitlines()[start_line_adj - 1 : end_line_adj] + lines = content_lines[start_line_adj - 1 : end_line_adj] # Generate hashline representation using the hashline() function # Join lines back with newlines for hashline() @@ -469,14 +471,16 @@ def remove_file_context(self, file_path: str) -> None: # Remove from numbered contexts self._numbered_contexts.pop(abs_fname, None) - # Remove using hash key (file_context, abs_fname) + # Remove using hash key pattern matching for file_context messages coder = self.get_coder() if coder: - ConversationService.get_manager(coder).remove_message_by_hash_key( - ("file_context_user", abs_fname) - ) - ConversationService.get_manager(coder).remove_message_by_hash_key( - ("file_context_assistant", abs_fname) + ConversationService.get_manager(coder).remove_messages_by_hash_key_pattern( + lambda hash_key: ( + isinstance(hash_key, tuple) + and len(hash_key) in (2, 3) + and hash_key[0] in ("file_context_user", "file_context_assistant") + and hash_key[1] == abs_fname + ) ) def remove_file_messages(self, file_path: str) -> None: @@ -488,14 +492,16 @@ def remove_file_messages(self, file_path: str) -> None: """ abs_fname = os.path.abspath(file_path) - # Remove using hash key 
(file_context, abs_fname) + # Remove using hash key pattern matching for file_context messages coder = self.get_coder() if coder: - ConversationService.get_manager(coder).remove_message_by_hash_key( - ("file_context_user", abs_fname) - ) - ConversationService.get_manager(coder).remove_message_by_hash_key( - ("file_context_assistant", abs_fname) + ConversationService.get_manager(coder).remove_messages_by_hash_key_pattern( + lambda hash_key: ( + isinstance(hash_key, tuple) + and len(hash_key) in (2, 3) + and hash_key[0] in ("file_context_user", "file_context_assistant") + and hash_key[1] == abs_fname + ) ) def clear_all_numbered_contexts(self) -> None: diff --git a/cecli/helpers/conversation/integration.py b/cecli/helpers/conversation/integration.py index 00ee834e004..e2a0a083ba8 100644 --- a/cecli/helpers/conversation/integration.py +++ b/cecli/helpers/conversation/integration.py @@ -18,6 +18,7 @@ class ConversationChunks: def __init__(self, coder): self.coder = weakref.ref(coder) self.uuid = coder.uuid + self._last_clear_count = 0 @classmethod def get_instance(cls, coder) -> "ConversationChunks": @@ -212,7 +213,11 @@ def cleanup_files(self) -> None: if diff_count > 0 and other_count > 0 and diff_count / other_count > 20: should_clear = True - if should_clear: + self._last_clear_count += 1 + + if should_clear and self._last_clear_count >= 10: + self._last_clear_count = 0 + # Clear all diff messages ConversationService.get_manager(coder).clear_tag(MessageTag.DIFFS) ConversationService.get_manager(coder).clear_tag(MessageTag.FILE_CONTEXTS) @@ -263,6 +268,43 @@ def cleanup_files(self) -> None: image_assistant_hash_key ) + # Clean up stale file_context messages + # If a file has 3 or more file_context_user messages, remove all but the most recent + # (and their corresponding assistant messages) to prevent excessive stale context + file_context_messages = ConversationService.get_manager(coder).get_tag_messages( + MessageTag.FILE_CONTEXTS + ) + + # Group user file_context 
messages by file path + user_msgs_by_file: Dict[str, List[int]] = {} + user_msg_indices: List[int] = [] + for msg_idx, msg in enumerate(file_context_messages): + if msg.hash_key and len(msg.hash_key) == 3 and msg.hash_key[0] == "file_context_user": + file_path = msg.hash_key[1] + if file_path not in user_msgs_by_file: + user_msgs_by_file[file_path] = [] + user_msgs_by_file[file_path].append(msg_idx) + user_msg_indices.append(msg_idx) + + # For files with 3+ user messages, keep only the last one + hash_keys_to_remove: set = set() + for file_path, indices in user_msgs_by_file.items(): + if len(indices) >= 3: + # Keep the last one (most recent in sorted order) + older_indices = indices[:-1] + for old_idx in older_indices: + old_msg = file_context_messages[old_idx] + content_hash = old_msg.hash_key[2] + # Mark the user message for removal + hash_keys_to_remove.add(("file_context_user", file_path, content_hash)) + # Mark the corresponding assistant message for removal + hash_keys_to_remove.add(("file_context_assistant", file_path, content_hash)) + + if hash_keys_to_remove: + ConversationService.get_manager(coder).remove_messages_by_hash_key_pattern( + lambda hash_key: hash_key in hash_keys_to_remove + ) + ConversationService.get_manager(coder).clear_tag(MessageTag.RULES) def add_file_list_reminder(self) -> None: @@ -587,7 +629,7 @@ def add_readonly_files_messages(self) -> List[Dict[str, Any]]: # Add assistant message with file path as hash_key assistant_msg = { "role": "assistant", - "content": "I understand, thank you for sharing the file contents.", + "content": f"Thank you for sharing the file contents for {rel_fname}.", } ConversationService.get_manager(coder).add_message( message_dict=assistant_msg, @@ -687,7 +729,7 @@ def add_chat_files_messages(self) -> Dict[str, Any]: # Create assistant message assistant_msg = { "role": "assistant", - "content": "I understand, thank you for sharing the file contents.", + "content": f"Thank you for sharing the file contents for 
{rel_fname}.", } # Determine tag based on editability @@ -777,22 +819,21 @@ def add_file_context_messages(self, promote_messages=True) -> None: assistant_msg = { "role": "assistant", - "content": "I understand, thank you for sharing the prefixed file contents.", + "content": f"Thank you for sharing the prefixed file contents for {rel_fname}.", } # Add to conversation manager + content_hash = xxhash.xxh3_128_hexdigest(context_content.encode("utf-8")) ConversationService.get_manager(coder).queue_message( message_dict=user_msg, tag=MessageTag.FILE_CONTEXTS, - hash_key=("file_context_user", file_path), - force=True, + hash_key=("file_context_user", file_path, content_hash), ) ConversationService.get_manager(coder).queue_message( message_dict=assistant_msg, tag=MessageTag.FILE_CONTEXTS, - hash_key=("file_context_assistant", file_path), - force=True, + hash_key=("file_context_assistant", file_path, content_hash), ) def reset(self) -> None: diff --git a/cecli/helpers/coroutines.py b/cecli/helpers/coroutines.py index 77cee82b162..07f1a669d5a 100644 --- a/cecli/helpers/coroutines.py +++ b/cecli/helpers/coroutines.py @@ -1,8 +1,45 @@ -import asyncio # noqa: F401 +import asyncio -def is_active(coroutine): - if not coroutine or coroutine.done() or coroutine.cancelled(): +def is_active(task): + if not task or task.done() or task.cancelled(): return False return True + + +async def interruptible(coroutine, interrupt_event): + """ + Runs a coroutine and allows it to be interrupted by an asyncio.Event. + + Args: + coroutine: The coroutine to run. + interrupt_event: The asyncio.Event that signals an interruption. + + Returns: + A tuple of (result, interrupted). 
+ - If not interrupted: (coroutine_result, False) + - If interrupted: (None, True) + """ + main_task = asyncio.create_task(coroutine) + interrupt_task = asyncio.create_task(interrupt_event.wait()) + + done, pending = await asyncio.wait( + {main_task, interrupt_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass # Expected + + if interrupt_task in done: + return None, True + + try: + return main_task.result(), False + except asyncio.CancelledError: + return None, True diff --git a/cecli/helpers/hashpos/hashpos.py b/cecli/helpers/hashpos/hashpos.py index f4d16ba1879..e694fa9e7d0 100644 --- a/cecli/helpers/hashpos/hashpos.py +++ b/cecli/helpers/hashpos/hashpos.py @@ -26,6 +26,54 @@ def _get_anchor_bits(self, line_idx: int) -> int: a2 = (line_idx * 59 + 31) % 63 return (a1 << 6) | a2 + def _spread_bits(self, x: int) -> int: + """ + Spreads 12 bits of x into 24 bits by inserting a 0 between each bit. + Input: 000000000000abcdefghijkl (12 bits) + Output: 0a0b0c0d0e0f0g0h0i0j0k0l (24 bits) + """ + x &= 0xFFF # Ensure we only have 12 bits + # Shift bits by 8, mask keeps the blocks separated + # x starts: 000000000000 abcdefgh ijkl + x = (x | (x << 8)) & 0x00FF00FF # 0000abcd efgh0000 00000000 ijkl... + # Shift by 4, then 2, then 1 to create 1-bit gaps + x = (x | (x << 4)) & 0x0F0F0F0F + x = (x | (x << 2)) & 0x33333333 + x = (x | (x << 1)) & 0x55555555 # Result: 0a0b0c0d0e0f0g0h0i0j0k0l + return x + + def _compact_bits(self, x: int) -> int: + """ + The inverse of spread: pulls every other bit back together. 
+ Input: 0a0b0c0d0e0f0g0h0i0j0k0l (24 bits) + Output: 000000000000abcdefghijkl (12 bits) + """ + x &= 0x55555555 # Mask to ensure we only look at the "active" bits + x = (x | (x >> 1)) & 0x33333333 + x = (x | (x >> 2)) & 0x0F0F0F0F + x = (x | (x >> 4)) & 0x00FF00FF + x = (x | (x >> 8)) & 0x0000FFFF # Result: abcdefghijkl + return x + + def _interleave(self, content: int, anchor: int) -> int: + """ + Weaves content and anchor bits together. + Content bits occupy the 'odd' positions, Anchor bits occupy the 'even'. + """ + # Spread content bits and shift by 1 to put them in positions 1, 3, 5... + # Spread anchor bits and leave them in positions 0, 2, 4... + return (self._spread_bits(content) << 1) | self._spread_bits(anchor) + + def _deinterleave(self, mixed: int) -> tuple[int, int]: + """ + Extracts content and anchor bits from a 24-bit interleaved integer. + """ + # To get content: shift right by 1, then compact + content = self._compact_bits(mixed >> 1) + # To get anchor: just compact (the mask inside _compact_bits handles the rest) + anchor = self._compact_bits(mixed) + return content, anchor + def generate_private_id(self, text: str) -> str: bits = self._get_content_bits(text) return f"{bits:03x}" @@ -33,9 +81,7 @@ def generate_private_id(self, text: str) -> str: def generate_public_id(self, text: str, line_idx: int) -> str: content_bits = self._get_content_bits(text) anchor_bits = self._get_anchor_bits(line_idx) - # Apply modular offset to content bits using anchor bits - offset_content = (content_bits + anchor_bits) & 0xFFF - packed = (offset_content << 12) | anchor_bits + packed = self._interleave(content_bits, anchor_bits) res = "" for _ in range(4): @@ -48,11 +94,7 @@ def unpack_public_id(self, public_id: str) -> tuple[int, int]: for i, char in enumerate(public_id): packed |= self.B64.index(char) << (6 * i) - offset_content = (packed >> 12) & 0xFFF - anchor_bits = packed & 0xFFF - # Reverse the modular offset to recover original content bits - content_bits = 
(offset_content - anchor_bits) & 0xFFF - return content_bits, anchor_bits + return self._deinterleave(packed) def format_content(self, use_private_ids: bool = False, start_line: int = 1) -> str: formatted_lines = [] diff --git a/cecli/helpers/skills.py b/cecli/helpers/skills.py index a239aebf95f..a209122d39b 100644 --- a/cecli/helpers/skills.py +++ b/cecli/helpers/skills.py @@ -12,6 +12,10 @@ import yaml +# Global state store for sticky include/exclude lists, keyed by coder.uuid +# This ensures skill state survives SkillsManager re-creation within the same coder session +_skill_state_store: Dict[str, Dict[str, Any]] = {} + @dataclass class SkillMetadata: @@ -73,6 +77,49 @@ def __init__( # Track which skills have been loaded via load_skill() self._loaded_skills: set[str] = set() + # Restore state from global store (sticky across SkillsManager recreation) + if not self._restore_state(): + # First time initialization - save initial state from config + self._save_state() + + def _save_state(self): + """Save current mutable state to the global skill state store. + + This allows state to persist across SkillsManager re-creation + within the same coder session. + """ + if not self.coder or not getattr(self.coder, "uuid", None): + return + + _skill_state_store[self.coder.uuid] = { + "include_list": self.include_list.copy() if self.include_list is not None else None, + "exclude_list": self.exclude_list.copy(), + "loaded_skills": self._loaded_skills.copy(), + } + + def _restore_state(self) -> bool: + """Restore mutable state from the global skill state store if available. + + Returns: + True if state was restored, False otherwise. 
+ """ + + if not self.coder or not getattr(self.coder, "uuid", None): + return False + + state = _skill_state_store.get(self.coder.uuid) + + if state is None: + return False + + self.include_list = ( + state["include_list"].copy() if state["include_list"] is not None else None + ) + self.exclude_list = state["exclude_list"].copy() + self._loaded_skills = state["loaded_skills"].copy() + + return True + def find_skills(self, reload: bool = False) -> List[SkillMetadata]: """ Find all skills in the configured directory paths. @@ -397,6 +444,9 @@ def load_skill(self, skill_name: str) -> str: # Add to loaded skills set self._loaded_skills.add(skill_name) + # Persist state to global store + self._save_state() + result = f"Skill '{skill_name}' loaded successfully." # Show skill summary @@ -434,8 +484,239 @@ def remove_skill(self, skill_name: str) -> str: # Remove from loaded skills set self._loaded_skills.remove(skill_name) + # Persist state to global store + self._save_state() + return f"Skill '{skill_name}' removed successfully." + def include_skill(self, skill_name: str) -> str: + """ + Add a skill to the include list (whitelist), making only this skill visible. + This method controls which skills are discoverable via find_skills(). + + Args: + skill_name: Name of the skill to include + + Returns: + Success or error message + """ + if not skill_name: + return "Error: Skill name is required." + + # Check if coder is available + if not self.coder: + return "Error: Skills manager not connected to a coder instance." + + # Check if we're in agent mode + if not hasattr(self.coder, "edit_format") or self.coder.edit_format != "agent": + return "Error: Skill inclusion is only available in agent mode." + + # Find the skill to verify it exists + skills = self.find_skills(reload=True) + skill_found = any(skill.name == skill_name for skill in skills) + + if not skill_found: + # The skill might already be filtered out by the include/exclude lists. 
+ # Check if it exists in any directory by scanning without filters. + original_include = self.include_list + original_exclude = self.exclude_list + self.include_list = None + self.exclude_list = set() + all_skills = self.find_skills(reload=True) + self.include_list = original_include + self.exclude_list = original_exclude + skill_found = any(skill.name == skill_name for skill in all_skills) + + if not skill_found: + return f"Error: Skill '{skill_name}' not found in configured directories." + + # Ensure include_list is initialized + if self.include_list is None: + self.include_list = set() + self.include_list.add(skill_name) + + # Also remove from exclude_list if present + if skill_name in self.exclude_list: + self.exclude_list.discard(skill_name) + + # Persist state to global store + self._save_state() + + # Clear caches so find_skills reflects the change + self.hot_reload() + + return f"Skill '{skill_name}' has been included (whitelisted)." + + def exclude_skill(self, skill_name: str) -> str: + """ + Add a skill to the exclude list (blacklist), hiding it from discovery. + This method controls which skills are hidden via find_skills(). + + Args: + skill_name: Name of the skill to exclude + + Returns: + Success or error message + """ + if not skill_name: + return "Error: Skill name is required." + + # Check if coder is available + if not self.coder: + return "Error: Skills manager not connected to a coder instance." + + # Check if we're in agent mode + if not hasattr(self.coder, "edit_format") or self.coder.edit_format != "agent": + return "Error: Skill exclusion is only available in agent mode." + + # Find the skill to verify it exists + skills = self.find_skills(reload=True) + skill_found = any(skill.name == skill_name for skill in skills) + + if not skill_found: + # The skill might already be filtered out by include/exclude lists. + # Check if it exists in any directory by scanning without filters. 
+ original_include = self.include_list + original_exclude = self.exclude_list + self.include_list = None + self.exclude_list = set() + all_skills = self.find_skills(reload=True) + self.include_list = original_include + self.exclude_list = original_exclude + skill_found = any(skill.name == skill_name for skill in all_skills) + + if not skill_found: + return f"Error: Skill '{skill_name}' not found in configured directories." + + # Add to exclude_list + self.exclude_list.add(skill_name) + + # Also remove from include_list if present + if self.include_list and skill_name in self.include_list: + self.include_list.discard(skill_name) + # If include_list is now empty, reset to None (no whitelist filtering) + if not self.include_list: + self.include_list = None + + # Also remove from loaded_skills if present, since it won't be visible + if skill_name in self._loaded_skills: + self._loaded_skills.discard(skill_name) + + # Persist state to global store + self._save_state() + + # Clear caches so find_skills reflects the change + self.hot_reload() + + return f"Skill '{skill_name}' has been excluded (blacklisted)." + + def get_all_skills_info(self) -> List[Dict[str, Any]]: + """ + Get detailed information about all skills across all directories, + including their current state (included, excluded, loaded) and file paths. + + This bypasses include/exclude filters to give a complete picture. 
+ + Returns: + List of dicts with keys: name, description, path, license, allowed_tools, + status ("included", "excluded", "visible"), loaded, has_references, + has_scripts, has_assets, has_evals + """ + # Save current filter state + original_include = self.include_list + original_exclude = self.exclude_list + + # Scan without filters to find all skills + self.include_list = None + self.exclude_list = set() + all_skills = self.find_skills(reload=True) + + # Restore original filter state + self.include_list = original_include + self.exclude_list = original_exclude + + # Also restore the cache to reflect the actual filters + self.hot_reload() + + result = [] + for meta in all_skills: + skill_name = meta.name + + # Determine status + if original_include is not None and skill_name in original_include: + status = "included" + elif skill_name in original_exclude: + status = "excluded" + else: + status = "visible" + + # Check if loaded + is_loaded = skill_name in self._loaded_skills + + skill_content = self._skills_cache.get(skill_name) + has_references = bool(skill_content and skill_content.references) + has_scripts = bool(skill_content and skill_content.scripts) + has_assets = bool(skill_content and skill_content.assets) + has_evals = bool(skill_content and skill_content.evals) + + info = { + "name": skill_name, + "description": meta.description, + "path": str(meta.path), + "license": meta.license, + "allowed_tools": meta.allowed_tools, + "status": status, + "loaded": is_loaded, + "has_references": has_references, + "has_scripts": has_scripts, + "has_assets": has_assets, + "has_evals": has_evals, + } + result.append(info) + + return result + + def get_skills_list_formatted(self) -> str: + """ + Get a human-readable table of all skills with their states and paths. + + Returns: + Formatted string listing all skills with state and path info + """ + all_skills = self.get_all_skills_info() + + if not all_skills: + return "No skills found in the configured directories." 
+ + # Calculate column widths + name_width = max(len(s["name"]) for s in all_skills) + name_width = max(name_width, len("Skill Name")) + + status_width = max(len(s["status"]) for s in all_skills) + status_width = max(status_width, len("Status")) + + result = f"Found {len(all_skills)} skill(s) in configured directories:\n\n" + + # Header + header = f" {'Skill Name'.ljust(name_width)} {'Status'.ljust(status_width)} Loaded Path" + result += header + "\n" + result += "-" * len(header) + "\n" + + for skill in all_skills: + name = skill["name"].ljust(name_width) + status = skill["status"].ljust(status_width) + loaded = "Yes" if skill["loaded"] else "No" + path = skill["path"] + result += f" {name} {status} {loaded:<5} {path}\n" + + result += "\n" + result += "Status meanings:\n" + result += " included - Skill is whitelisted (skill available for discovery/loading)\n" + result += " excluded - Skill is blacklisted (hidden from discovery)\n" + result += " loaded - Whether the skill content has been loaded via load_skill\n" + + return result + @classmethod def skill_summary_loader( cls, diff --git a/cecli/io.py b/cecli/io.py index 4f50b9f6a02..c3f207bade8 100644 --- a/cecli/io.py +++ b/cecli/io.py @@ -762,6 +762,8 @@ def interrupt_input(self): coder = self.coder() if coder and hasattr(coder, "interrupt_event"): coder.interrupt_event.set() + if self.output_task and not self.output_task.done(): + self.output_task.cancel() if self.prompt_session and self.prompt_session.app: # Store any partial input before interrupting diff --git a/cecli/main.py b/cecli/main.py index bfebaffc6d1..0549ea78b55 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -1247,6 +1247,9 @@ def get_io(pretty): if switch.kwargs.get("show_announcements") is False: coder.suppress_announcements_for_next_prompt = True + except KeyboardInterrupt: + coder.keyboard_interrupt() + continue except SystemExit: sys.settrace(None) await coder.auto_save_session(force=True) diff --git a/cecli/models.py b/cecli/models.py index 
04e47c7d4ac..495895bda12 100644 --- a/cecli/models.py +++ b/cecli/models.py @@ -19,7 +19,7 @@ from cecli import __version__ from cecli.dump import dump from cecli.exceptions import LiteLLMExceptions -from cecli.helpers import nested +from cecli.helpers import coroutines, nested from cecli.helpers.file_searcher import generate_search_path_list, handle_core_files from cecli.helpers.model_providers import ModelProviderManager from cecli.helpers.nested import deep_merge @@ -1132,6 +1132,7 @@ async def send_completion( min_wait=0, max_wait=2, override_kwargs={}, + interrupt_event=None, ): if os.environ.get("CECLI_SANITY_CHECK_TURNS"): sanity_check_messages(messages) @@ -1290,7 +1291,14 @@ async def send_completion( return hash_object, self.model_error_response() print(f"Retrying in {retry_delay:.1f} seconds...") - await asyncio.sleep(retry_delay) + if interrupt_event: + _res, interrupted = await coroutines.interruptible( + asyncio.sleep(retry_delay), interrupt_event + ) + if interrupted: + raise KeyboardInterrupt("Interrupted during retry sleep") + else: + await asyncio.sleep(retry_delay) continue async def simple_send_with_retries( diff --git a/cecli/sessions.py b/cecli/sessions.py index c1e9fbdc5f3..5384ae4e0a2 100644 --- a/cecli/sessions.py +++ b/cecli/sessions.py @@ -214,6 +214,12 @@ def _build_session_data(self, session_name) -> Dict: "mcps": connected_mcps, "skills": skills_data, "tools": agent_config_data, + "usage": { + "total_tokens_sent": self.coder.total_tokens_sent, + "total_tokens_received": self.coder.total_tokens_received, + "total_cached_tokens": self.coder.total_cached_tokens, + "total_cost": self.coder.total_cost, + }, } def _find_session_file(self, session_identifier: str) -> Optional[Path]: @@ -271,6 +277,12 @@ async def _apply_session_data(self, session_data: Dict, session_file: Path) -> b else: self.io.tool_warning(f"File not found, skipping: {rel_fname}") + # Load usage stats + usage = session_data.get("usage", {}) + self.coder.total_tokens_sent = 
usage.get("total_tokens_sent", 0) + self.coder.total_tokens_received = usage.get("total_tokens_received", 0) + self.coder.total_cached_tokens = usage.get("total_cached_tokens", 0) + self.coder.total_cost = usage.get("total_cost", 0.0) if session_data.get("model"): self.coder.main_model = models.Model( session_data.get("model", self.coder.args.model), diff --git a/cecli/tools/__init__.py b/cecli/tools/__init__.py index 39639f525e3..9cc334f0894 100644 --- a/cecli/tools/__init__.py +++ b/cecli/tools/__init__.py @@ -9,7 +9,6 @@ edit_text, explore_code, finished, - get_lines, git_branch, git_diff, git_log, @@ -19,6 +18,7 @@ grep, load_skill, ls, + read_range, remove_skill, thinking, undo_change, @@ -33,7 +33,6 @@ edit_text, explore_code, finished, - get_lines, git_branch, git_diff, git_log, @@ -43,6 +42,7 @@ grep, load_skill, ls, + read_range, remove_skill, thinking, undo_change, diff --git a/cecli/tools/command.py b/cecli/tools/command.py index 28c1bec9ba6..4bf1ec941c4 100644 --- a/cecli/tools/command.py +++ b/cecli/tools/command.py @@ -228,6 +228,15 @@ async def _execute_with_timeout(cls, coder, command_string, timeout, use_pty=Fal start_time = time.time() while True: + if coder.interrupt_event.is_set(): + process.terminate() + try: + process.wait(timeout=1) + except subprocess.TimeoutExpired: + process.kill() + BackgroundCommandManager.stop_background_command(command_key) + return "Command execution interrupted by user." + # Check if process has completed exit_code = process.poll() if exit_code is not None: diff --git a/cecli/tools/context_manager.py b/cecli/tools/context_manager.py index 6a0ed86808a..7a27a4e60bb 100644 --- a/cecli/tools/context_manager.py +++ b/cecli/tools/context_manager.py @@ -15,30 +15,25 @@ class Tool(BaseTool): "function": { "name": "ContextManager", "description": ( - "Manage multiple files in the chat context: remove, editable, view, and create." + "Manage multiple files in the chat context: add, read_only, create, and remove." 
" Accepts arrays of file paths for each operation." ), "parameters": { "type": "object", "properties": { - "remove": { - "type": "array", - "items": {"type": "string"}, - "description": "List of file paths to remove from context.", - }, - "editable": { + "add": { "type": "array", "items": {"type": "string"}, "description": ( - "List of file paths to make editable. Limit to at most 2 at a time." + "List of file paths to add to context. Limit to at most 2 at a time." ), }, - "view": { + "read_only": { "type": "array", "items": {"type": "string"}, "description": ( - "List of file paths to view (add as read-only). Limit to at most 2 at a" - " time." + "List of file paths to add as read-only. " + "Limit to at most 2 at a time." ), }, "create": { @@ -46,6 +41,11 @@ class Tool(BaseTool): "items": {"type": "string"}, "description": "List of file paths to create.", }, + "remove": { + "type": "array", + "items": {"type": "string"}, + "description": "List of file paths to remove from context.", + }, }, "additionalProperties": False, "required": [], @@ -54,7 +54,7 @@ class Tool(BaseTool): } @classmethod - def execute(cls, coder, remove=None, editable=None, view=None, create=None, **kwargs): + def execute(cls, coder, remove=None, add=None, read_only=None, create=None, **kwargs): """Perform batch operations on the coder's context. Parameters @@ -63,7 +63,7 @@ def execute(cls, coder, remove=None, editable=None, view=None, create=None, **kw The active coder handling file context. remove: list[str] | None Files to remove from the context. - editable: list[str] | None + add: list[str] | None Files to promote to editable status. view: list[str] | None Files to add as read-only view. @@ -71,8 +71,8 @@ def execute(cls, coder, remove=None, editable=None, view=None, create=None, **kw Files to create and make editable. 
""" remove_files = sorted(parse_arg_as_list(remove), key=cls._natural_sort_key) - editable_files = sorted(parse_arg_as_list(editable), key=cls._natural_sort_key) - view_files = sorted(parse_arg_as_list(view), key=cls._natural_sort_key) + editable_files = sorted(parse_arg_as_list(add), key=cls._natural_sort_key) + view_files = sorted(parse_arg_as_list(read_only), key=cls._natural_sort_key) create_files = sorted(parse_arg_as_list(create), key=cls._natural_sort_key) if not remove_files and not editable_files and not view_files and not create_files: diff --git a/cecli/tools/edit_text.py b/cecli/tools/edit_text.py index 5b4d64f7c3c..c3ce8cfac5f 100644 --- a/cecli/tools/edit_text.py +++ b/cecli/tools/edit_text.py @@ -111,7 +111,7 @@ def execute( """ if not coder.edit_allowed: raise ToolError( - "Please call `GetLines` first to make sure edits are appropriately scoped" + "Please call `ReadRange` first to make sure edits are appropriately scoped" ) tool_name = "EditText" diff --git a/cecli/tools/grep.py b/cecli/tools/grep.py index 68ca5a103b2..03f51d57275 100644 --- a/cecli/tools/grep.py +++ b/cecli/tools/grep.py @@ -4,6 +4,7 @@ import oslex +from cecli.helpers.hashline import strip_hashline from cecli.run_cmd import run_cmd_subprocess from cecli.tools.utils.base_tool import BaseTool from cecli.tools.utils.output import color_markers, tool_footer, tool_header @@ -108,7 +109,7 @@ def execute( all_results = [] for search_op in searches: - pattern = search_op.get("pattern") + pattern = strip_hashline(search_op.get("pattern")) file_pattern = search_op.get("file_pattern", "*") directory = search_op.get("directory", search_op.get("path", ".")) use_regex = search_op.get("use_regex", False) diff --git a/cecli/tools/get_lines.py b/cecli/tools/read_range.py similarity index 92% rename from cecli/tools/get_lines.py rename to cecli/tools/read_range.py index 274db580d95..c845ff9d687 100644 --- a/cecli/tools/get_lines.py +++ b/cecli/tools/read_range.py @@ -1,5 +1,6 @@ import json import 
os +from typing import Dict from cecli.helpers.hashline import hashline, strip_hashline from cecli.tools.utils.base_tool import BaseTool @@ -13,11 +14,11 @@ class Tool(BaseTool): - NORM_NAME = "getlines" + NORM_NAME = "readrange" SCHEMA = { "type": "function", "function": { - "name": "GetLines", + "name": "ReadRange", "description": ( "Get hashline prefixes of content between start and end patterns in files." " Accepts an array of `show` objects, each with file_path, start_text," @@ -78,6 +79,7 @@ class Tool(BaseTool): } _last_invocation = {} # file_path -> {start_idx, end_idx} + _last_read_turn: Dict[str, int] = {} # abs_path -> turn_count when last read @classmethod def execute(cls, coder, show, **kwargs): @@ -87,8 +89,8 @@ def execute(cls, coder, show, **kwargs): Accepts an array of show operations to perform. Uses utility functions for path resolution and error handling. """ - tool_name = "GetLines" - already_up_to_date = False + tool_name = "ReadRange" + already_up_to_date = None try: # 1. 
Validate show parameter @@ -277,11 +279,25 @@ def execute(cls, coder, show, **kwargs): abs_path ) - if original_context_content and original_context_content == new_context_content: + if ( + original_context_content + and original_context_content == new_context_content + and already_up_to_date is not False + ): already_up_to_date = True else: + already_up_to_date = False + + # Conditionally remove old file context messages + # If the file was last read >= 10 turns ago, keep old messages (allow coexistence) + # Otherwise, remove them to avoid duplicates + last_turn = cls._last_read_turn.get(abs_path) + if last_turn is None or coder.turn_count - last_turn < 10: ConversationService.get_files(coder).remove_file_messages(abs_path) + # Update the last read turn for this file + cls._last_read_turn[abs_path] = coder.turn_count + ConversationService.get_chunks(coder).add_file_context_messages() # Log success and return the formatted context directly @@ -290,8 +306,8 @@ def execute(cls, coder, show, **kwargs): if already_up_to_date: coder.io.tool_output("File contents already up to date") return ( - "File contents already up to date." - " Do not call `GetLines` again with these parameters until you edit the file." + "Lines already up to date in context for these files." + " Do not call `ReadRange` with these parameters again unless you edit the relevant files."
) else: coder.io.tool_output(f"✅ Successfully retrieved context for {len(show)} file(s)") @@ -306,7 +322,7 @@ def execute(cls, coder, show, **kwargs): @classmethod def format_output(cls, coder, mcp_server, tool_response): - """Format output for GetLines tool.""" + """Format output for ReadRange tool.""" color_start, color_end = color_markers(coder) try: diff --git a/cecli/tools/utils/base_tool.py b/cecli/tools/utils/base_tool.py index f35f2ebf3ef..f31f8037bae 100644 --- a/cecli/tools/utils/base_tool.py +++ b/cecli/tools/utils/base_tool.py @@ -120,7 +120,8 @@ def process_response(cls, coder, params): ) # Add current invocation to history (keeping only last 3) - cls._invocations[tool_name].append((current_params_tuple, params)) + if params: + cls._invocations[tool_name].append((current_params_tuple, params)) if len(cls._invocations[tool_name]) > 3: cls._invocations[tool_name] = cls._invocations[tool_name][-3:] diff --git a/cecli/tui/app.py b/cecli/tui/app.py index fc87bd7211b..6a0e9db6f72 100644 --- a/cecli/tui/app.py +++ b/cecli/tui/app.py @@ -106,7 +106,10 @@ def __init__(self, coder_worker, output_queue, input_queue, args): show=True, ) self.bind( - self._encode_keys(self.get_keys_for("cancel")), "noop", description="Cancel", show=True + self._encode_keys(self.get_keys_for("cancel")), + "interrupt", + description="Cancel", + show=True, ) self.bind( self._encode_keys(self.get_keys_for("editor")), diff --git a/tests/tools/test_get_lines.py b/tests/tools/test_get_lines.py index b70e6b12e96..50a1108237f 100644 --- a/tests/tools/test_get_lines.py +++ b/tests/tools/test_get_lines.py @@ -4,7 +4,7 @@ import pytest -from cecli.tools import get_lines +from cecli.tools import read_range class DummyIO: @@ -29,6 +29,8 @@ def __init__(self, root): self.uuid = str(uuid.uuid4()) # Generate unique UUID for each instance + self.turn_count = 0 + def abs_root_path(self, file_path): path = Path(file_path) if path.is_absolute(): @@ -50,7 +52,7 @@ def coder_with_file(tmp_path): def 
test_pattern_with_zero_line_number_is_allowed(coder_with_file): coder, file_path = coder_with_file - result = get_lines.Tool.execute( + result = read_range.Tool.execute( coder, show=[ { @@ -70,7 +72,7 @@ def test_pattern_with_zero_line_number_is_allowed(coder_with_file): def test_empty_pattern_uses_line_number(coder_with_file): coder, file_path = coder_with_file - result = get_lines.Tool.execute( + result = read_range.Tool.execute( coder, show=[ { @@ -91,7 +93,7 @@ def test_conflicting_pattern_and_line_number_raise(coder_with_file): coder, file_path = coder_with_file # Test that missing start_text raises an error - result = get_lines.Tool.execute( + result = read_range.Tool.execute( coder, show=[ { @@ -126,7 +128,7 @@ def test_multiline_pattern_search(coder_with_file): coder, file_path = coder_with_file # file_path contains "alpha\nbeta\ngamma\n" - result = get_lines.Tool.execute( + result = read_range.Tool.execute( coder, show=[ {