diff --git a/lib/crewai/README.md b/lib/crewai/README.md index 7faeae0fa5..e9d430c0e6 100644 --- a/lib/crewai/README.md +++ b/lib/crewai/README.md @@ -547,6 +547,35 @@ This example demonstrates how to: 3. Use Flow decorators to manage the sequence of operations 4. Implement conditional branching based on Crew results +### Mimir Memory Backend + +CrewAI now supports **Mimir** as a persistent memory engine, integrated seamlessly via the Model Context Protocol (MCP). + +#### Prerequisites +Ensure you have the official Mimir binary installed and accessible in your system environment. For detailed setup and installation instructions, please visit the official repository: +👉 [Perseus-Computing-LLC/mimir](https://github.com/Perseus-Computing-LLC/mimir) + +#### Setup +Since CrewAI communicates with Mimir using MCP via standard I/O subprocesses, you must ensure the `mcp` Python package is installed (automatically handled by CrewAI dependencies). + +To use `MimirStorage` as your memory backend, initialize it within your Crew setup by providing the configuration dictionary containing your custom database path: + +```python +from crewai import Crew +from crewai.memory.storage.mimir_storage import MimirStorage + +mimir_config = { + "db_path": "~/.mimir/custom_mimir.db" +} + +crew = Crew( + agents=[...], + tasks=[...], + memory=True, + storage=MimirStorage(config=mimir_config) +) + + ## Connecting Your Crew to a Model CrewAI supports using various LLMs through a variety of connection options. By default your agents will use the OpenAI API when querying the model. However, there are several other ways to allow your agents to connect to models. For example, you can configure your agents to use a local model via the Ollama tool. diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index ded6bb40ab..add9c7a2ff 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -6,7 +6,7 @@ from hashlib import md5 from pathlib import Path import re -from typing import TYPE_CHECKING, Annotated, Any, Final, Literal +from typing import TYPE_CHECKING, Annotated, Any, Final, List, Literal, Optional import uuid from pydantic import ( @@ -248,6 +248,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): Set private attributes. """ + capabilities: Optional[List[str]] = Field( + default=None, + description="List of deterministic capabilities or permissions assigned to the agent." + ) + require_approval_for: Optional[List[str]] = Field( + default=None, + description="List of tools or capabilities that require manual human approval before execution." + ) entity_type: Literal["agent"] = "agent" __hash__ = object.__hash__ diff --git a/lib/crewai/src/crewai/memory/storage/factory.py b/lib/crewai/src/crewai/memory/storage/factory.py index 3dac6dcd40..1964daf671 100644 --- a/lib/crewai/src/crewai/memory/storage/factory.py +++ b/lib/crewai/src/crewai/memory/storage/factory.py @@ -15,9 +15,8 @@ """ from __future__ import annotations - from collections.abc import Callable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: @@ -45,11 +44,27 @@ def set_memory_storage_factory(factory: MemoryStorageFactory | None) -> None: _factory = factory -def resolve_memory_storage(spec: str) -> StorageBackend | None: +def resolve_memory_storage(spec: str, config: Optional[dict] = None) -> StorageBackend | None: """Return the registered factory's backend for ``spec``, or ``None``. ``None`` means no factory is registered or it declined this spec; the caller then falls back to the built-in selection. """ + # First, respect user-registered custom factories if available factory = _factory - return factory(spec) if factory is not None else None + if factory is not None: + try: + # Try to pass config if the custom factory supports it + custom_backend = factory(spec, config=config) # type: ignore + except TypeError: + custom_backend = factory(spec) + + if custom_backend is not None: + return custom_backend + + # Built-in fallback to Mimir storage if the spec matches + if spec == "mimir": + from crewai.memory.storage.mimir_storage import MimirStorage + return MimirStorage(config=config) + + return None diff --git a/lib/crewai/src/crewai/memory/storage/mimir_storage.py b/lib/crewai/src/crewai/memory/storage/mimir_storage.py new file mode 100644 index 0000000000..dd13cbefe0 --- /dev/null +++ b/lib/crewai/src/crewai/memory/storage/mimir_storage.py @@ -0,0 +1,152 @@ +import os +import shutil +import subprocess +import json +import logging +import hashlib +import re +from typing import List, Dict, Any, Optional, Tuple + +from crewai.memory.storage.backend import StorageBackend +# CodeRabbit Fix: Direct import to fail-fast and avoid masking integration issues +from crewai.memory.storage.interface import MemoryRecord # type: ignore + +logger = logging.getLogger(__name__) + +class MimirStorage(StorageBackend): + def __init__(self, config: Optional[Dict[str, Any]] = None): + self.config = config or {} + + # Resolve db_path from config dictionary, expanding '~' to home directory + raw_db_path = self.config.get("db_path", "~/mimir_db") + self.db_path = os.path.expanduser(raw_db_path) + os.makedirs(self.db_path, exist_ok=True) + + # Verify mimir binary availability in common paths or system PATH + self.mimir_path = self._find_mimir_binary() + if not self.mimir_path: + raise FileNotFoundError( + "The 'mimir' binary could not be found. Please ensure it is installed " + "and available in PATH or at common locations (~/.cargo/bin/mimir, /usr/local/bin/mimir)." + ) + + def _find_mimir_binary(self) -> Optional[str]: + """Checks common paths and system PATH for the mimir binary.""" + path_binary = shutil.which("mimir") + if path_binary: + return path_binary + + common_paths = [ + os.path.expanduser("~/.cargo/bin/mimir"), + "/usr/local/bin/mimir", + "/usr/bin/mimir" + ] + for path in common_paths: + if os.path.isfile(path) and os.access(path, os.X_OK): + return path + return None + + def _validate_inputs(self, category: str, query: Optional[str] = None) -> None: + """Validates input arguments to safeguard against CLI/flag injection attacks.""" + if category and not re.match(r"^[A-Za-z0-9_-]+$", category): + raise ValueError(f"Malicious characters detected in scope/category: '{category}'") + if query and query.startswith("-"): + raise ValueError("Query string cannot start with a hyphen to prevent flag injection.") + + def save(self, records: List[MemoryRecord]) -> None: + """Saves a list of MemoryRecords conforming to the StorageBackend protocol.""" + for record in records: + value_str = str(record.value) + + # Generate a persistent deterministic hash key using hashlib MD5 + hash_suffix = hashlib.md5(value_str.encode('utf-8')).hexdigest()[:12] + key = f"memory_{hash_suffix}" + + # Scope memories using config metadata or default category + category = record.metadata.get("agent_id", "default") + self._validate_inputs(category) + + # Prepare payload + payload = { + "key": key, + "value": value_str, + "category": category, + "metadata": record.metadata + } + + # Call the subprocess using '--db' flag per Mimir CLI docs (with 10s timeout) + try: + cmd = [self.mimir_path, "--db", self.db_path, "store", json.dumps(payload)] + subprocess.run(cmd, check=True, capture_output=True, text=True, timeout=10) + logger.info(f"Successfully stored memory with key: {key}") + except subprocess.TimeoutExpired as te: + logger.error(f"Mimir store operation timed out: {te}") + raise te + except subprocess.CalledProcessError as e: + logger.error(f"Failed to store memory in Mimir: {e.stderr}") + raise e + + def search( + self, + query: Any, + limit: Optional[int] = None, + scope_prefix: Optional[str] = None, + categories: Optional[List[str]] = None, + min_score: Optional[float] = None, + metadata_filter: Optional[Dict[str, Any]] = None, + **kwargs + ) -> List[Tuple[MemoryRecord, float]]: + """Searches memories and returns a list of (MemoryRecord, score) tuples.""" + query_str = query if isinstance(query, str) else str(query) + + actual_limit = limit if limit is not None else 3 + + category = scope_prefix if scope_prefix else "default" + + if categories and len(categories) > 0: + category = categories[0] + + self._validate_inputs(category, query_str) + + try: + cmd = [ + self.mimir_path, + "--db", self.db_path, + "search", + query_str, + "--limit", str(actual_limit), + "--category", category + ] + result = subprocess.run(cmd, check=True, capture_output=True, text=True, timeout=10) + + raw_results = json.loads(result.stdout) + formatted_results = [] + + for res in raw_results: + content_text = res.get("value", res.get("text", "")) + score = float(res.get("score", 0.0)) + meta = res.get("metadata", {}) + + if min_score is not None and score < min_score: + continue + + if metadata_filter: + match = True + for k, v in metadata_filter.items(): + if meta.get(k) != v: + match = False + break + if not match: + continue + + # Construct official MemoryRecord instances + record = MemoryRecord(value=content_text, metadata=meta) + formatted_results.append((record, score)) + + return formatted_results + except subprocess.TimeoutExpired as te: + logger.error(f"Mimir search operation timed out: {te}") + raise te + except subprocess.CalledProcessError as e: + logger.error(f"Search failed in Mimir: {e.stderr}") + raise e \ No newline at end of file diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index c63cfe8666..826dc80fa3 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -650,7 +650,7 @@ async def _aexecute_core( raise Exception( f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical." ) - + tokens_before = self._get_agent_token_usage(agent) self.prompt_context = context tools = tools or self.tools or [] @@ -688,6 +688,9 @@ async def _aexecute_core( raw = result pydantic_output, json_output = None, None + tokens_after = self._get_agent_token_usage(agent) + token_delta = self._calculate_token_delta(tokens_before, tokens_after) + task_output = TaskOutput( name=self.name or self.description, description=self.description, @@ -698,6 +701,7 @@ async def _aexecute_core( agent=agent.role, output_format=self._get_output_format(), messages=agent.last_messages, # type: ignore[attr-defined] + token_usage=token_delta, ) if self._guardrails: @@ -775,7 +779,7 @@ def _execute_core( raise Exception( f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical." ) - + tokens_before = self._get_agent_token_usage(agent) self.prompt_context = context tools = tools or self.tools or [] @@ -813,6 +817,9 @@ def _execute_core( raw = result pydantic_output, json_output = None, None + tokens_after = self._get_agent_token_usage(agent) + token_delta = self._calculate_token_delta(tokens_before, tokens_after) + task_output = TaskOutput( name=self.name or self.description, description=self.description, @@ -823,6 +830,7 @@ def _execute_core( agent=agent.role, output_format=self._get_output_format(), messages=agent.last_messages, # type: ignore[attr-defined] + token_usage=token_delta, ) if self._guardrails: @@ -1149,7 +1157,30 @@ async def _aexport_output( pydantic_output, json_output = self._unpack_model_output(model_output) return pydantic_output, json_output + + def _get_agent_token_usage(self, agent: BaseAgent) -> Any: + """Capture the current snapshot of tokens consumed by the agent's LLM.""" + if hasattr(agent, "llm") and hasattr(agent.llm, "token_usage"): + return shallow_copy(agent.llm.token_usage) + return None + + def _calculate_token_delta(self, before: Any, after: Any) -> Any: + """Calculate the delta of token usage.""" + if isinstance(before, dict) and isinstance(after, dict): + delta_dict = {} + for key, after_val in after.items(): + before_val = before.get(key, 0) + if isinstance(after_val, (int, float)) and isinstance(before_val, (int, float)): + delta_dict[key] = after_val - before_val + else: + delta_dict[key] = after_val + return delta_dict + try: + return after - before + except Exception: + return after + @staticmethod def _unpack_model_output( model_output: dict[str, Any] | BaseModel | str, @@ -1262,6 +1293,7 @@ def _invoke_guardrail_function( max_attempts = self.guardrail_max_retries + 1 for attempt in range(max_attempts): + current_token_usage = getattr(task_output, 'token_usage', None) guardrail_result = process_guardrail( output=task_output, guardrail=guardrail, @@ -1286,7 +1318,8 @@ def _invoke_guardrail_function( task_output.json_dict = json_output elif isinstance(guardrail_result.result, TaskOutput): task_output = guardrail_result.result - + if getattr(task_output, 'token_usage', None) is None and current_token_usage is not None: + task_output.token_usage = current_token_usage return task_output if attempt >= self.guardrail_max_retries: @@ -1348,6 +1381,7 @@ def _invoke_guardrail_function( agent=agent.role, output_format=self._get_output_format(), messages=agent.last_messages, # type: ignore[attr-defined] + token_usage=current_token_usage, ) return task_output @@ -1372,6 +1406,7 @@ async def _ainvoke_guardrail_function( max_attempts = self.guardrail_max_retries + 1 for attempt in range(max_attempts): + current_token_usage = getattr(task_output, 'token_usage', None) guardrail_result = process_guardrail( output=task_output, guardrail=guardrail, @@ -1396,7 +1431,8 @@ async def _ainvoke_guardrail_function( task_output.json_dict = json_output elif isinstance(guardrail_result.result, TaskOutput): task_output = guardrail_result.result - + if getattr(task_output, 'token_usage', None) is None and current_token_usage is not None: + task_output.token_usage = current_token_usage return task_output if attempt >= self.guardrail_max_retries: @@ -1458,6 +1494,7 @@ async def _ainvoke_guardrail_function( agent=agent.role, output_format=self._get_output_format(), messages=agent.last_messages, # type: ignore[attr-defined] + token_usage=current_token_usage, ) return task_output diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index c6c3dba15b..074483d2f6 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -10,6 +10,7 @@ from typing import ( Any, Generic, + Optional, ParamSpec, TypeVar, overload, @@ -151,12 +152,23 @@ def _validate_tool(value: Any, nxt: Any) -> Any: validate_default=True, description="The schema for the arguments that the tool accepts.", ) + fix/task-token-tracking + args_schema: type[PydanticBaseModel] = Field( + default=_ArgsSchemaPlaceholder, + validate_default=True, + description="The schema for the arguments that the tool accepts.", + ) + required_capability: Optional[str] = Field( + default=None, + description="The specific capability required to execute this tool." + ) result_schema: type[PydanticBaseModel] | None = Field( default=None, validate_default=True, description="The schema for the output that the tool returns.", ) + main @field_serializer("args_schema", when_used="json") def _serialize_args_schema( self, schema: type[PydanticBaseModel] | None @@ -405,6 +417,7 @@ def to_structured_tool(self) -> CrewStructuredTool: cache_function=self.cache_function, ) structured_tool._original_tool = self + setattr(structured_tool, "required_capability", self.required_capability) return structured_tool @classmethod diff --git a/lib/crewai/src/crewai/tools/tool_usage.py b/lib/crewai/src/crewai/tools/tool_usage.py index e92ba03eed..37bf201ed1 100644 --- a/lib/crewai/src/crewai/tools/tool_usage.py +++ b/lib/crewai/src/crewai/tools/tool_usage.py @@ -863,11 +863,36 @@ def _tool_calling( ) -> ToolCalling | InstructorToolCalling | ToolUsageError: try: try: - return self._original_tool_calling(tool_string, raise_error=True) + tool_calling = self._original_tool_calling(tool_string, raise_error=True) except Exception: if self.function_calling_llm: - return self._function_calling(tool_string) - return self._original_tool_calling(tool_string) + tool_calling = self._function_calling(tool_string) + else: + tool_calling = self._original_tool_calling(tool_string) + if tool_calling and not isinstance(tool_calling, ToolUsageError): + tool = self._select_tool(tool_calling.tool_name) + + if tool and hasattr(tool, "required_capability"): + required_cap = tool.required_capability + if required_cap: + agent_caps = getattr(self.agent, "capabilities", None) or [] + # Block execution if the agent lacks the required capability + if required_cap not in agent_caps: + error_msg = f"Security Violation: Agent '{self.agent.role}' lacks the required capability '{required_cap}' to execute tool '{tool.name}'." + if self.agent and self.agent.verbose: + PRINTER.print(content=error_msg, color="red") + return ToolUsageError(error_msg) + + # Human-in-the-loop gating / Manual approval check + if tool and self.agent: + require_approval_for = getattr(self.agent, "require_approval_for", None) or [] + if tool.name in require_approval_for or (hasattr(tool, "required_capability") and tool.required_capability in require_approval_for): + if hasattr(self, "_ask_human_approval"): + approved = self._ask_human_approval(tool.name) + if not approved: + return ToolUsageError("Execution cancelled by human operator.") + return tool_calling + except Exception as e: self._run_attempts += 1 if self._run_attempts > self._max_parsing_attempts: