diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 5af8566be9..faf7e1d73f 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -40,7 +40,7 @@ from pyrit.models.data_type_serializer import data_serializer_factory from pyrit.prompt_converter import PromptConverter from pyrit.prompt_target import PromptChatTarget -from pyrit.registry.instance_registries import ConverterRegistry +from pyrit.registry.object_registries import ConverterRegistry _DATA_TYPE_EXTENSION: dict[str, str] = { "image_path": ".png", diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 7959a80408..26d66c8fa1 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -24,7 +24,7 @@ TargetListResponse, ) from pyrit.prompt_target import PromptTarget -from pyrit.registry.instance_registries import TargetRegistry +from pyrit.registry.object_registries import TargetRegistry def _build_target_class_registry() -> dict[str, type]: diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 76c162d342..4f8290e993 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Registry module for PyRIT class and instance registries.""" +"""Registry module for PyRIT class and object registries.""" from pyrit.registry.base import RegistryProtocol from pyrit.registry.class_registries import ( @@ -17,16 +17,22 @@ discover_in_package, discover_subclasses_in_loaded_modules, ) -from pyrit.registry.instance_registries import ( +from pyrit.registry.object_registries import ( + AttackTechniqueRegistry, BaseInstanceRegistry, + ConverterRegistry, RegistryEntry, + RetrievableInstanceRegistry, ScorerRegistry, TargetRegistry, ) __all__ = [ + "AttackTechniqueRegistry", "BaseClassRegistry", "BaseInstanceRegistry", + "ConverterRegistry", + "RetrievableInstanceRegistry", "ClassEntry", "discover_in_directory", "discover_in_package", diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index 02a973a869..766c554fc6 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -5,12 +5,17 @@ Shared base types for PyRIT registries. This module contains types shared between class registries (which store Type[T]) -and instance registries (which store T instances). +and object registries (which store T instances). """ -from collections.abc import Iterator +from __future__ import annotations + from dataclasses import dataclass -from typing import Any, Optional, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Optional, Protocol, TypeVar, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Self # Type variable for metadata (invariant for Protocol compatibility) MetadataT = TypeVar("MetadataT") @@ -43,7 +48,7 @@ class RegistryProtocol(Protocol[MetadataT]): """ Protocol defining the common interface for all registries. - Both class registries (BaseClassRegistry) and instance registries + Both class registries (BaseClassRegistry) and object registries (BaseInstanceRegistry) implement this interface, enabling code that works with either registry type. @@ -52,7 +57,7 @@ class RegistryProtocol(Protocol[MetadataT]): """ @classmethod - def get_registry_singleton(cls) -> "RegistryProtocol[MetadataT]": + def get_registry_singleton(cls) -> Self: """Get the singleton instance of this registry.""" ... diff --git a/pyrit/registry/class_registries/__init__.py b/pyrit/registry/class_registries/__init__.py index b29a6a6eed..c003ccc59a 100644 --- a/pyrit/registry/class_registries/__init__.py +++ b/pyrit/registry/class_registries/__init__.py @@ -7,7 +7,7 @@ This package contains registries that store classes (Type[T]) which can be instantiated on demand. Examples include ScenarioRegistry and InitializerRegistry. -For registries that store pre-configured instances, see instance_registries/. +For registries that store pre-configured instances, see object_registries/. """ from pyrit.registry.class_registries.base_class_registry import ( diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index ba735d8b2d..6b44c6e832 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -7,7 +7,7 @@ This module provides the abstract base class for registries that store classes (Type[T]). These registries allow on-demand instantiation of registered classes. -For registries that store pre-configured instances, see instance_registries/. +For registries that store pre-configured instances, see object_registries/. Terminology: - **Metadata**: A TypedDict describing a registered class (e.g., ScenarioMetadata) @@ -16,9 +16,14 @@ - **ClassEntry**: Internal wrapper holding a class plus optional factory/defaults """ +from __future__ import annotations + from abc import ABC, abstractmethod -from collections.abc import Callable, Iterator -from typing import Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Generic, Optional, TypeVar + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + from typing import Self from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.registry.base import RegistryProtocol @@ -107,7 +112,7 @@ class BaseClassRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]) """ # Class-level singleton instances, keyed by registry class - _instances: dict[type, "BaseClassRegistry[object, object]"] = {} + _instances: dict[type, BaseClassRegistry[object, object]] = {} def __init__(self, *, lazy_discovery: bool = True) -> None: """ @@ -128,7 +133,7 @@ def __init__(self, *, lazy_discovery: bool = True) -> None: self._discovered = True @classmethod - def get_registry_singleton(cls) -> "BaseClassRegistry[T, MetadataT]": + def get_registry_singleton(cls) -> Self: """ Get the singleton instance of this registry. diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 4a23320535..4007c58d66 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -66,16 +66,6 @@ class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetad The directory structure is used for organization but not exposed to users. """ - @classmethod - def get_registry_singleton(cls) -> InitializerRegistry: - """ - Get the singleton instance of the InitializerRegistry. - - Returns: - The singleton InitializerRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def __init__(self, *, discovery_path: Optional[Path] = None, lazy_discovery: bool = False) -> None: """ Initialize the initializer registry. diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 6f6d949adf..f8b0e3e87f 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -67,16 +67,6 @@ class ScenarioRegistry(BaseClassRegistry["Scenario", ScenarioMetadata]): Scenarios are identified by their dotted name (e.g., "garak.encoding", "foundry.red_team_agent"). """ - @classmethod - def get_registry_singleton(cls) -> ScenarioRegistry: - """ - Get the singleton instance of the ScenarioRegistry. - - Returns: - The singleton ScenarioRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def __init__(self, *, lazy_discovery: bool = True) -> None: """ Initialize the scenario registry. diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/object_registries/__init__.py similarity index 52% rename from pyrit/registry/instance_registries/__init__.py rename to pyrit/registry/object_registries/__init__.py index d635813936..0a43a5af2f 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/object_registries/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. """ -Instance registries package. +Object registries package. This package contains registries that store pre-configured instances (not classes). Examples include ScorerRegistry which stores Scorer instances that have been @@ -11,25 +11,33 @@ For registries that store classes (Type[T]), see class_registries/. """ -from pyrit.registry.instance_registries.base_instance_registry import ( +from pyrit.registry.object_registries.attack_technique_registry import ( + AttackTechniqueRegistry, +) +from pyrit.registry.object_registries.base_instance_registry import ( BaseInstanceRegistry, RegistryEntry, ) -from pyrit.registry.instance_registries.converter_registry import ( +from pyrit.registry.object_registries.converter_registry import ( ConverterRegistry, ) -from pyrit.registry.instance_registries.scorer_registry import ( +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, +) +from pyrit.registry.object_registries.scorer_registry import ( ScorerRegistry, ) -from pyrit.registry.instance_registries.target_registry import ( +from pyrit.registry.object_registries.target_registry import ( TargetRegistry, ) __all__ = [ - # Base class + # Base classes "BaseInstanceRegistry", + "RetrievableInstanceRegistry", "RegistryEntry", # Concrete registries + "AttackTechniqueRegistry", "ConverterRegistry", "ScorerRegistry", "TargetRegistry", diff --git a/pyrit/registry/object_registries/attack_technique_registry.py b/pyrit/registry/object_registries/attack_technique_registry.py new file mode 100644 index 0000000000..2b68ffd651 --- /dev/null +++ b/pyrit/registry/object_registries/attack_technique_registry.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +AttackTechniqueRegistry — Singleton registry of reusable attack technique factories. + +Scenarios and initializers register technique factories (capturing technique-specific +config). Scenarios retrieve them via ``create_technique()``, which calls the factory +with the scenario's objective target and scorer. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from pyrit.registry.object_registries.base_instance_registry import ( + BaseInstanceRegistry, +) + +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_config import ( + AttackAdversarialConfig, + AttackConverterConfig, + AttackScoringConfig, + ) + from pyrit.prompt_target import PromptTarget + from pyrit.scenario.core.attack_technique import AttackTechnique + from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + +logger = logging.getLogger(__name__) + + +class AttackTechniqueRegistry(BaseInstanceRegistry["AttackTechniqueFactory"]): + """ + Singleton registry of reusable attack technique factories. + + Scenarios and initializers register technique factories (capturing + technique-specific config). Scenarios retrieve them via ``create_technique()``, + which calls the factory with the scenario's objective target and scorer. + """ + + def register_technique( + self, + *, + name: str, + factory: AttackTechniqueFactory, + tags: dict[str, str] | list[str] | None = None, + ) -> None: + """ + Register an attack technique factory. + + Args: + name: The registry name for this technique. + factory: The factory that produces attack techniques. + tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` + or a ``list[str]`` (each string becomes a key with value ``""``). + """ + self.register(factory, name=name, tags=tags) + logger.debug(f"Registered attack technique factory: {name} ({factory.attack_class.__name__})") + + def create_technique( + self, + name: str, + *, + objective_target: PromptTarget, + attack_scoring_config: AttackScoringConfig, + attack_adversarial_config: AttackAdversarialConfig | None = None, + attack_converter_config: AttackConverterConfig | None = None, + ) -> AttackTechnique: + """ + Retrieve a factory by name and produce a fresh attack technique. + + Args: + name: The registry name of the technique. + objective_target: The target to attack. + attack_scoring_config: Scoring configuration for the attack. + attack_adversarial_config: Optional adversarial configuration override. + attack_converter_config: Optional converter configuration override. + + Returns: + A fresh AttackTechnique with a newly-constructed attack strategy. + + Raises: + KeyError: If no technique is registered with the given name. + """ + entry = self._registry_items.get(name) + if entry is None: + raise KeyError(f"No technique registered with name '{name}'") + return entry.instance.create( + objective_target=objective_target, + attack_scoring_config=attack_scoring_config, + attack_adversarial_config=attack_adversarial_config, + attack_converter_config=attack_converter_config, + ) diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/object_registries/base_instance_registry.py similarity index 68% rename from pyrit/registry/instance_registries/base_instance_registry.py rename to pyrit/registry/object_registries/base_instance_registry.py index b29ea6aa08..1d60417b9b 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/object_registries/base_instance_registry.py @@ -4,35 +4,38 @@ """ Base instance registry for PyRIT. -This module provides the abstract base class for registries that store -pre-configured instances (not classes). Unlike class registries which -store Type[T] and create instances on demand, instance registries store -already-instantiated objects. +This module provides ``BaseInstanceRegistry``, the shared infrastructure for +registries that store ``Identifiable`` objects (not classes): singleton +lifecycle, registration, tags, metadata, container protocol. -Examples include: -- ScorerRegistry: stores Scorer instances configured with their chat_target +Subclass directly for registries that store factories or other +non-retrievable items (e.g., ``AttackTechniqueRegistry``). For registries +where callers retrieve stored objects directly, subclass +``RetrievableInstanceRegistry`` instead. + +For registries that store classes (Type[T]), see ``class_registries/``. """ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, TypeVar -from pyrit.identifiers import ComponentIdentifier +from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.registry.base import RegistryProtocol if TYPE_CHECKING: from collections.abc import Iterator + from typing import Self -T = TypeVar("T") # The type of instances stored -MetadataT = TypeVar("MetadataT", bound=ComponentIdentifier) +T = TypeVar("T", bound=Identifiable) # The type of items stored @dataclass class RegistryEntry(Generic[T]): """ - A wrapper around a registered instance, holding its name, tags, and the instance itself. + A wrapper around a registered item, holding its name, tags, and the item itself. Tags are always stored as ``dict[str, str]``. When callers pass a plain ``list[str]``, each string is normalized to a key with an empty-string value. @@ -48,27 +51,30 @@ class RegistryEntry(Generic[T]): tags: dict[str, str] = field(default_factory=dict) -class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): +class BaseInstanceRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): """ - Abstract base class for registries that store pre-configured instances. + Abstract base class providing shared registry infrastructure. - This class implements RegistryProtocol. Unlike BaseClassRegistry which stores - Type[T] and supports lazy discovery, instance registries store already-instantiated - objects that are registered explicitly (typically during initialization). + Provides singleton lifecycle, registration, tag-based lookup, metadata + filtering, and the standard container protocol (``__contains__``, + ``__len__``, ``__iter__``). - Type Parameters: - T: The type of instances stored in the registry. - MetadataT: A TypedDict subclass for instance metadata. + Subclass directly when stored items should not be retrievable via + ``get()`` (e.g., factory registries). For registries that expose + direct item retrieval, subclass ``RetrievableInstanceRegistry`` instead. - Subclasses must implement: - - _build_metadata(): Convert an instance to its metadata representation + All stored items must implement ``Identifiable``, which provides + ``get_identifier()`` for metadata generation. + + Type Parameters: + T: The type of items stored in the registry (must be Identifiable). """ # Class-level singleton instances, keyed by registry class - _instances: dict[type, BaseInstanceRegistry[Any, Any]] = {} + _instances: dict[type, BaseInstanceRegistry[Any]] = {} @classmethod - def get_registry_singleton(cls) -> BaseInstanceRegistry[T, MetadataT]: + def get_registry_singleton(cls) -> Self: """ Get the singleton instance of this registry. @@ -79,7 +85,7 @@ def get_registry_singleton(cls) -> BaseInstanceRegistry[T, MetadataT]: """ if cls not in cls._instances: cls._instances[cls] = cls() - return cls._instances[cls] + return cls._instances[cls] # type: ignore[return-value] @classmethod def reset_instance(cls) -> None: @@ -92,7 +98,7 @@ def reset_instance(cls) -> None: del cls._instances[cls] @staticmethod - def _normalize_tags(tags: Optional[Union[dict[str, str], list[str]]] = None) -> dict[str, str]: + def _normalize_tags(tags: dict[str, str] | list[str] | None = None) -> dict[str, str]: """ Normalize tags into a ``dict[str, str]``. @@ -110,24 +116,24 @@ def _normalize_tags(tags: Optional[Union[dict[str, str], list[str]]] = None) -> return dict(tags) def __init__(self) -> None: - """Initialize the instance registry.""" + """Initialize the registry.""" # Maps registry names to registry entries self._registry_items: dict[str, RegistryEntry[T]] = {} - self._metadata_cache: Optional[list[MetadataT]] = None + self._metadata_cache: list[ComponentIdentifier] | None = None def register( self, instance: T, *, name: str, - tags: Optional[Union[dict[str, str], list[str]]] = None, + tags: dict[str, str] | list[str] | None = None, ) -> None: """ - Register an instance. + Register an item. Args: - instance: The pre-configured instance to register. - name: The registry name for this instance. + instance: The item to register. + name: The registry name for this item. tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` or a ``list[str]`` (each string becomes a key with value ``""``). """ @@ -135,33 +141,6 @@ def register( self._registry_items[name] = RegistryEntry(name=name, instance=instance, tags=normalized) self._metadata_cache = None - def get(self, name: str) -> Optional[T]: - """ - Get a registered instance by name. - - Args: - name: The registry name of the instance. - - Returns: - The instance, or None if not found. - """ - entry = self._registry_items.get(name) - if entry is None: - return None - return entry.instance - - def get_entry(self, name: str) -> Optional[RegistryEntry[T]]: - """ - Get a full registry entry by name, including tags. - - Args: - name: The registry name of the entry. - - Returns: - The RegistryEntry, or None if not found. - """ - return self._registry_items.get(name) - def get_names(self) -> list[str]: """ Get a sorted list of all registered names. @@ -171,20 +150,11 @@ def get_names(self) -> list[str]: """ return sorted(self._registry_items.keys()) - def get_all_instances(self) -> list[RegistryEntry[T]]: - """ - Get all registered entries sorted by name. - - Returns: - List of RegistryEntry objects sorted by name. - """ - return [self._registry_items[name] for name in sorted(self._registry_items.keys())] - def get_by_tag( self, *, tag: str, - value: Optional[str] = None, + value: str | None = None, ) -> list[RegistryEntry[T]]: """ Get all entries that have a given tag, optionally matching a specific value. @@ -208,7 +178,7 @@ def add_tags( self, *, name: str, - tags: Union[dict[str, str], list[str]], + tags: dict[str, str] | list[str], ) -> None: """ Add tags to an existing registry entry. @@ -275,11 +245,11 @@ def find_dependents_of_tag(self, *, tag: str) -> list[RegistryEntry[T]]: def list_metadata( self, *, - include_filters: Optional[dict[str, object]] = None, - exclude_filters: Optional[dict[str, object]] = None, - ) -> list[MetadataT]: + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, + ) -> list[ComponentIdentifier]: """ - List metadata for all registered instances, optionally filtered. + List metadata for all registered items, optionally filtered. Supports filtering on any metadata property: - Simple types (str, int, bool): exact match @@ -294,7 +264,7 @@ def list_metadata( Any matching filter excludes the item. Returns: - List of metadata dictionaries describing each registered instance. + List of ComponentIdentifier metadata for each registered item. """ from pyrit.registry.base import _matches_filters @@ -314,19 +284,18 @@ def list_metadata( if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) ] - @abstractmethod - def _build_metadata(self, name: str, instance: T) -> MetadataT: + def _build_metadata(self, name: str, instance: T) -> ComponentIdentifier: """ - Build metadata for an instance. + Build metadata for an item via its ``Identifiable`` interface. Args: - name: The registry name of the instance. - instance: The instance. + name: The registry name of the item. + instance: The item. Returns: - A metadata dictionary describing the instance. + The item's ComponentIdentifier. """ - ... + return instance.get_identifier() def __contains__(self, name: str) -> bool: """ @@ -339,10 +308,10 @@ def __contains__(self, name: str) -> bool: def __len__(self) -> int: """ - Get the count of registered instances. + Get the count of registered items. Returns: - The number of registered instances. + The number of registered items. """ return len(self._registry_items) diff --git a/pyrit/registry/instance_registries/converter_registry.py b/pyrit/registry/object_registries/converter_registry.py similarity index 68% rename from pyrit/registry/instance_registries/converter_registry.py rename to pyrit/registry/object_registries/converter_registry.py index 987c94fa0f..4d83c9e1fd 100644 --- a/pyrit/registry/instance_registries/converter_registry.py +++ b/pyrit/registry/object_registries/converter_registry.py @@ -14,9 +14,8 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.identifiers import ComponentIdentifier -from pyrit.registry.instance_registries.base_instance_registry import ( - BaseInstanceRegistry, +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, ) if TYPE_CHECKING: @@ -25,7 +24,7 @@ logger = logging.getLogger(__name__) -class ConverterRegistry(BaseInstanceRegistry["PromptConverter", ComponentIdentifier]): +class ConverterRegistry(RetrievableInstanceRegistry["PromptConverter"]): """ Registry for managing available converter instances. @@ -34,16 +33,6 @@ class ConverterRegistry(BaseInstanceRegistry["PromptConverter", ComponentIdentif with their required parameters. """ - @classmethod - def get_registry_singleton(cls) -> ConverterRegistry: - """ - Get the singleton instance of the ConverterRegistry. - - Returns: - The singleton ConverterRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def register_instance( self, converter: PromptConverter, @@ -78,16 +67,3 @@ def get_instance_by_name(self, name: str) -> Optional[PromptConverter]: The converter instance, or None if not found. """ return self.get(name) - - def _build_metadata(self, name: str, instance: PromptConverter) -> ComponentIdentifier: - """ - Build metadata for a converter instance. - - Args: - name: The registry name of the converter. - instance: The converter instance. - - Returns: - ComponentIdentifier: The converter's identifier. - """ - return instance.get_identifier() diff --git a/pyrit/registry/object_registries/retrievable_instance_registry.py b/pyrit/registry/object_registries/retrievable_instance_registry.py new file mode 100644 index 0000000000..b5bc4fdfec --- /dev/null +++ b/pyrit/registry/object_registries/retrievable_instance_registry.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Retrievable instance registry for PyRIT. + +This module provides ``RetrievableInstanceRegistry``, which extends +``BaseInstanceRegistry`` with ``get()``, ``get_entry()``, and +``get_all_instances()`` for registries where callers retrieve stored +objects directly (e.g., ``ScorerRegistry``, ``ConverterRegistry``, +``TargetRegistry``). + +For the shared base class, see ``base_instance_registry``. +For registries that store classes (Type[T]), see ``class_registries/``. +""" + +from __future__ import annotations + +from pyrit.registry.object_registries.base_instance_registry import ( + BaseInstanceRegistry, + RegistryEntry, + T, +) + +# Re-export so existing ``from retrievable_instance_registry import ...`` still works +__all__ = ["RetrievableInstanceRegistry", "BaseInstanceRegistry", "RegistryEntry"] + + +class RetrievableInstanceRegistry(BaseInstanceRegistry[T]): + """ + Base class for registries that store directly-retrievable instances. + + Extends ``BaseInstanceRegistry`` with ``get()``, ``get_entry()``, and + ``get_all_instances()`` for registries where callers retrieve the + stored objects directly (e.g., scorers, converters, targets). + + For registries that store factories or other non-retrievable items, + subclass ``BaseInstanceRegistry`` directly instead. + + Type Parameters: + T: The type of instances stored in the registry (must be Identifiable). + """ + + def get(self, name: str) -> T | None: + """ + Get a registered instance by name. + + Args: + name: The registry name of the instance. + + Returns: + The instance, or None if not found. + """ + entry = self._registry_items.get(name) + if entry is None: + return None + return entry.instance + + def get_entry(self, name: str) -> RegistryEntry[T] | None: + """ + Get a full registry entry by name, including tags. + + Args: + name: The registry name of the entry. + + Returns: + The RegistryEntry, or None if not found. + """ + return self._registry_items.get(name) + + def get_all_instances(self) -> list[RegistryEntry[T]]: + """ + Get all registered entries sorted by name. + + Returns: + List of RegistryEntry objects sorted by name. + """ + return [self._registry_items[name] for name in sorted(self._registry_items.keys())] diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/object_registries/scorer_registry.py similarity index 71% rename from pyrit/registry/instance_registries/scorer_registry.py rename to pyrit/registry/object_registries/scorer_registry.py index e1fb7e1e9c..af5c59946f 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/object_registries/scorer_registry.py @@ -12,9 +12,8 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.identifiers import ComponentIdentifier -from pyrit.registry.instance_registries.base_instance_registry import ( - BaseInstanceRegistry, +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, ) if TYPE_CHECKING: @@ -23,7 +22,7 @@ logger = logging.getLogger(__name__) -class ScorerRegistry(BaseInstanceRegistry["Scorer", ComponentIdentifier]): +class ScorerRegistry(RetrievableInstanceRegistry["Scorer"]): """ Registry for managing available scorer instances. @@ -35,16 +34,6 @@ class ScorerRegistry(BaseInstanceRegistry["Scorer", ComponentIdentifier]): or a custom name provided during registration. """ - @classmethod - def get_registry_singleton(cls) -> ScorerRegistry: - """ - Get the singleton instance of the ScorerRegistry. - - Returns: - The singleton ScorerRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def register_instance( self, scorer: Scorer, @@ -84,16 +73,3 @@ def get_instance_by_name(self, name: str) -> Optional[Scorer]: The scorer instance, or None if not found. """ return self.get(name) - - def _build_metadata(self, name: str, instance: Scorer) -> ComponentIdentifier: - """ - Build metadata for a scorer instance. - - Args: - name: The registry name of the scorer. - instance: The scorer instance. - - Returns: - ComponentIdentifier: The scorer's identifier - """ - return instance.get_identifier() diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/object_registries/target_registry.py similarity index 71% rename from pyrit/registry/instance_registries/target_registry.py rename to pyrit/registry/object_registries/target_registry.py index 763d4a9e31..c6fefd3926 100644 --- a/pyrit/registry/instance_registries/target_registry.py +++ b/pyrit/registry/object_registries/target_registry.py @@ -12,9 +12,8 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.identifiers import ComponentIdentifier -from pyrit.registry.instance_registries.base_instance_registry import ( - BaseInstanceRegistry, +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, ) if TYPE_CHECKING: @@ -23,7 +22,7 @@ logger = logging.getLogger(__name__) -class TargetRegistry(BaseInstanceRegistry["PromptTarget", ComponentIdentifier]): +class TargetRegistry(RetrievableInstanceRegistry["PromptTarget"]): """ Registry for managing available prompt target instances. @@ -35,16 +34,6 @@ class TargetRegistry(BaseInstanceRegistry["PromptTarget", ComponentIdentifier]): or a custom name provided during registration. """ - @classmethod - def get_registry_singleton(cls) -> TargetRegistry: - """ - Get the singleton instance of the TargetRegistry. - - Returns: - The singleton TargetRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def register_instance( self, target: PromptTarget, @@ -85,16 +74,3 @@ def get_instance_by_name(self, name: str) -> Optional[PromptTarget]: The target instance, or None if not found. """ return self.get(name) - - def _build_metadata(self, name: str, instance: PromptTarget) -> ComponentIdentifier: - """ - Build metadata for a target instance. - - Args: - name: The registry name of the target. - instance: The target instance. - - Returns: - ComponentIdentifier: The target's identifier. - """ - return instance.get_identifier() diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index 53f865305a..e8ebfb2946 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -19,6 +19,7 @@ from pyrit.scenario.core import ( AtomicAttack, AttackTechnique, + AttackTechniqueFactory, DatasetConfiguration, Scenario, ScenarioCompositeStrategy, @@ -44,6 +45,7 @@ __all__ = [ "AtomicAttack", "AttackTechnique", + "AttackTechniqueFactory", "DatasetConfiguration", "Scenario", "ScenarioCompositeStrategy", diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 55c4517aca..8f40282bef 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -5,6 +5,7 @@ from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy, ScenarioStrategy @@ -12,6 +13,7 @@ __all__ = [ "AtomicAttack", "AttackTechnique", + "AttackTechniqueFactory", "DatasetConfiguration", "EXPLICIT_SEED_GROUPS_KEY", "Scenario", diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py new file mode 100644 index 0000000000..fac94e4932 --- /dev/null +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -0,0 +1,203 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +AttackTechniqueFactory — Deferred construction of AttackTechnique instances. + +Captures technique-specific configuration at registration time and produces +fresh, fully-constructed attacks when scenario-specific params (objective target, +scorer) become available. +""" + +from __future__ import annotations + +import copy +import inspect +from typing import TYPE_CHECKING, Any + +from pyrit.identifiers import ComponentIdentifier, Identifiable +from pyrit.scenario.core.attack_technique import AttackTechnique + +if TYPE_CHECKING: + from pyrit.executor.attack import AttackStrategy + from pyrit.executor.attack.core.attack_config import ( + AttackAdversarialConfig, + AttackConverterConfig, + AttackScoringConfig, + ) + from pyrit.models import SeedAttackTechniqueGroup + from pyrit.prompt_target import PromptTarget + + +class AttackTechniqueFactory(Identifiable): + """ + A factory that produces AttackTechnique instances on demand. + + Captures technique-specific configuration (converters, adversarial config, + tree depth, etc.) at registration time. Produces fresh, fully-constructed + attacks by calling the real constructor with the captured params plus + scenario-specific objective_target and scoring config. + + Validates kwargs against the attack class constructor signature at + construction time, catching typos and incompatible parameter names early. + """ + + def __init__( + self, + *, + attack_class: type[AttackStrategy[Any, Any]], + attack_kwargs: dict[str, Any] | None = None, + seed_technique: SeedAttackTechniqueGroup | None = None, + ) -> None: + """ + Initialize the factory with a technique-specific configuration. + + Args: + attack_class: The AttackStrategy subclass to instantiate. + attack_kwargs: Keyword arguments to pass to the attack constructor. + Must not include ``objective_target`` (provided at create time). + seed_technique: Optional technique seed group to attach to created techniques. + + Raises: + TypeError: If any kwarg name is not a valid constructor parameter, + or if the attack class constructor uses ``**kwargs``. + ValueError: If ``objective_target`` is included in attack_kwargs. + """ + self._attack_class = attack_class + self._attack_kwargs = copy.deepcopy(attack_kwargs) if attack_kwargs else {} + self._seed_technique = seed_technique + + self._validate_kwargs() + + def _validate_kwargs(self) -> None: + """ + Validate that all kwargs are valid parameters for the attack class constructor. + + Uses ``inspect.signature`` on the attack class ``__init__``, which works through + the ``@apply_defaults`` decorator (it uses ``functools.wraps``). + + Raises: + TypeError: If any kwarg name is not a valid constructor parameter, + or if the constructor uses ``**kwargs`` (all parameters must be + explicitly named). + ValueError: If ``objective_target`` is included in attack_kwargs. + """ + if "objective_target" in self._attack_kwargs: + raise ValueError("objective_target must not be in attack_kwargs — it is provided at create() time.") + + sig = inspect.signature(self._attack_class.__init__) + + # Reject constructors that accept **kwargs — we require explicitly named + # parameters so that validation is meaningful. + has_var_keyword = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values()) + if has_var_keyword: + raise TypeError( + f"{self._attack_class.__name__}.__init__ accepts **kwargs, which prevents " + f"parameter validation. All attack constructor parameters must be explicitly named." + ) + + valid_params = { + name + for name, param in sig.parameters.items() + if name != "self" + and param.kind + in ( + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + } + + invalid = set(self._attack_kwargs) - valid_params + if invalid: + raise TypeError( + f"Invalid kwargs for {self._attack_class.__name__}: {sorted(invalid)}. " + f"Valid parameters: {sorted(valid_params)}" + ) + + @property + def attack_class(self) -> type[AttackStrategy[Any, Any]]: + """The attack strategy class this factory produces.""" + return self._attack_class + + @property + def seed_technique(self) -> SeedAttackTechniqueGroup | None: + """The optional technique seed group.""" + return self._seed_technique + + def create( + self, + *, + objective_target: PromptTarget, + attack_scoring_config: AttackScoringConfig, + attack_adversarial_config: AttackAdversarialConfig | None = None, + attack_converter_config: AttackConverterConfig | None = None, + ) -> AttackTechnique: + """ + Create a fresh AttackTechnique bound to the given target and scorer. + + Each call produces a fully independent attack instance by calling the + real constructor. Config objects are deep-copied to prevent shared + mutable state between instances. + + Args: + objective_target: The target to attack. + attack_scoring_config: Scoring configuration for the attack. + attack_adversarial_config: Optional adversarial configuration. + Overrides any adversarial config in the frozen kwargs. + attack_converter_config: Optional converter configuration. + Overrides any converter config in the frozen kwargs. + + Returns: + A fresh AttackTechnique with a newly-constructed attack strategy. + """ + kwargs = copy.deepcopy(self._attack_kwargs) + kwargs["objective_target"] = objective_target + kwargs["attack_scoring_config"] = attack_scoring_config + if attack_adversarial_config is not None: + kwargs["attack_adversarial_config"] = attack_adversarial_config + if attack_converter_config is not None: + kwargs["attack_converter_config"] = attack_converter_config + + attack = self._attack_class(**kwargs) + return AttackTechnique(attack=attack, seed_technique=self._seed_technique) + + @staticmethod + def _serialize_value(value: Any) -> Any: + """ + Convert a value to a JSON-safe representation for identifier hashing. + + Primitives are included directly. Identifiable objects contribute their + hash. Collections are serialized recursively. Other types fall back to + their qualified class name. + + Returns: + Any: A JSON-serializable representation of the value. + """ + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + return [AttackTechniqueFactory._serialize_value(v) for v in value] + if isinstance(value, dict): + return {str(k): AttackTechniqueFactory._serialize_value(v) for k, v in sorted(value.items())} + if isinstance(value, Identifiable): + return value.get_identifier().hash + return f"<{type(value).__qualname__}>" + + def _build_identifier(self) -> ComponentIdentifier: + """ + Build the behavioral identity for this factory. + + Includes the attack class name and kwargs with their serialized values + so that factories with different configurations produce different hashes. + + Returns: + ComponentIdentifier: The frozen identity snapshot. + """ + kwargs_for_id = {k: self._serialize_value(v) for k, v in sorted(self._attack_kwargs.items())} + return ComponentIdentifier.of( + self, + params={ + "attack_class": self._attack_class.__name__, + "kwargs": kwargs_for_id, + }, + ) diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 852ca1712f..f67bdcf3f0 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -23,7 +23,7 @@ SuffixAppendConverter, ) from pyrit.prompt_converter.prompt_converter import get_converter_modalities -from pyrit.registry.instance_registries import ConverterRegistry +from pyrit.registry.object_registries import ConverterRegistry @pytest.fixture(autouse=True) diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index 06b3c5713c..70f61acafa 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -13,7 +13,7 @@ from pyrit.backend.models.targets import CreateTargetRequest from pyrit.backend.services.target_service import TargetService, get_target_service from pyrit.identifiers import ComponentIdentifier -from pyrit.registry.instance_registries import TargetRegistry +from pyrit.registry.object_registries import TargetRegistry @pytest.fixture(autouse=True) diff --git a/tests/unit/registry/test_attack_technique_registry.py b/tests/unit/registry/test_attack_technique_registry.py new file mode 100644 index 0000000000..e0d7463b51 --- /dev/null +++ b/tests/unit/registry/test_attack_technique_registry.py @@ -0,0 +1,254 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the AttackTechniqueRegistry class.""" + +from unittest.mock import MagicMock + +import pytest + +from pyrit.executor.attack.core.attack_config import AttackScoringConfig +from pyrit.identifiers import ComponentIdentifier +from pyrit.prompt_target import PromptTarget +from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry +from pyrit.scenario.core.attack_technique import AttackTechnique +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + + +class _StubAttack: + """Minimal stub for testing the registry without real AttackStrategy weight.""" + + def __init__(self, *, objective_target, attack_scoring_config=None, max_turns: int = 5): + self.objective_target = objective_target + self.attack_scoring_config = attack_scoring_config + self.max_turns = max_turns + + def get_identifier(self): + return ComponentIdentifier( + class_name="_StubAttack", + class_module="tests.unit.registry.test_attack_technique_registry", + params={"max_turns": self.max_turns}, + ) + + +class TestAttackTechniqueRegistrySingleton: + """Tests for the singleton pattern.""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_get_registry_singleton_returns_same_instance(self): + instance1 = AttackTechniqueRegistry.get_registry_singleton() + instance2 = AttackTechniqueRegistry.get_registry_singleton() + + assert instance1 is instance2 + + def test_get_registry_singleton_returns_correct_type(self): + instance = AttackTechniqueRegistry.get_registry_singleton() + + assert isinstance(instance, AttackTechniqueRegistry) + + def test_reset_instance_clears_singleton(self): + instance1 = AttackTechniqueRegistry.get_registry_singleton() + AttackTechniqueRegistry.reset_instance() + instance2 = AttackTechniqueRegistry.get_registry_singleton() + + assert instance1 is not instance2 + + +class TestAttackTechniqueRegistryRegister: + """Tests for registering technique factories.""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + self.registry = AttackTechniqueRegistry.get_registry_singleton() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_register_technique_stores_factory(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + self.registry.register_technique(name="stub_attack", factory=factory) + + assert "stub_attack" in self.registry + assert self.registry._registry_items["stub_attack"].instance is factory + + def test_register_technique_with_tags(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + self.registry.register_technique( + name="stub_attack", + factory=factory, + tags=["single_turn", "encoding"], + ) + + entries = self.registry.get_by_tag(tag="single_turn") + assert len(entries) == 1 + assert entries[0].name == "stub_attack" + + def test_register_multiple_techniques(self): + factory1 = AttackTechniqueFactory(attack_class=_StubAttack) + factory2 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 20}, + ) + + self.registry.register_technique(name="stub_5", factory=factory1) + self.registry.register_technique(name="stub_20", factory=factory2) + + assert len(self.registry) == 2 + assert self.registry.get_names() == ["stub_20", "stub_5"] + + +class TestAttackTechniqueRegistryCreateTechnique: + """Tests for create_technique().""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + self.registry = AttackTechniqueRegistry.get_registry_singleton() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_create_technique_returns_attack_technique(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="stub", factory=factory) + target = MagicMock(spec=PromptTarget) + scoring = MagicMock(spec=AttackScoringConfig) + + technique = self.registry.create_technique("stub", objective_target=target, attack_scoring_config=scoring) + + assert isinstance(technique, AttackTechnique) + assert isinstance(technique.attack, _StubAttack) + assert technique.attack.objective_target is target + + def test_create_technique_passes_scoring_config(self): + class _ScoringStub: + def __init__(self, *, objective_target, attack_scoring_config=None): + self.objective_target = objective_target + self.attack_scoring_config = attack_scoring_config + + def get_identifier(self): + return ComponentIdentifier(class_name="_ScoringStub", class_module="test") + + factory = AttackTechniqueFactory(attack_class=_ScoringStub) + self.registry.register_technique(name="scoring_stub", factory=factory) + target = MagicMock(spec=PromptTarget) + scoring = MagicMock(spec=AttackScoringConfig) + + technique = self.registry.create_technique( + "scoring_stub", objective_target=target, attack_scoring_config=scoring + ) + + assert technique.attack.attack_scoring_config is scoring + + def test_create_technique_raises_on_missing_name(self): + with pytest.raises(KeyError, match="No technique registered with name 'nonexistent'"): + self.registry.create_technique( + "nonexistent", + objective_target=MagicMock(spec=PromptTarget), + attack_scoring_config=MagicMock(spec=AttackScoringConfig), + ) + + def test_create_technique_preserves_frozen_kwargs(self): + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 42}, + ) + self.registry.register_technique(name="custom", factory=factory) + target = MagicMock(spec=PromptTarget) + + technique = self.registry.create_technique( + "custom", objective_target=target, attack_scoring_config=MagicMock(spec=AttackScoringConfig) + ) + + assert technique.attack.max_turns == 42 + + +class TestAttackTechniqueRegistryMetadata: + """Tests for metadata / list_metadata on the registry.""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + self.registry = AttackTechniqueRegistry.get_registry_singleton() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_build_metadata_returns_component_identifier(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="stub", factory=factory) + + metadata = self.registry.list_metadata() + + assert len(metadata) == 1 + assert isinstance(metadata[0], ComponentIdentifier) + assert metadata[0].class_name == "AttackTechniqueFactory" + + def test_metadata_matches_factory_identifier(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="stub", factory=factory) + + metadata = self.registry.list_metadata() + + assert metadata[0] == factory.get_identifier() + + +class TestAttackTechniqueRegistryInherited: + """Tests for inherited BaseInstanceRegistry methods.""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + self.registry = AttackTechniqueRegistry.get_registry_singleton() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_contains(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="exists", factory=factory) + + assert "exists" in self.registry + assert "missing" not in self.registry + + def test_len(self): + assert len(self.registry) == 0 + + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="a", factory=factory) + + assert len(self.registry) == 1 + + def test_get_names_returns_sorted(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="zeta", factory=factory) + self.registry.register_technique(name="alpha", factory=factory) + self.registry.register_technique(name="beta", factory=factory) + + assert self.registry.get_names() == ["alpha", "beta", "zeta"] + + def test_tag_based_queries(self): + factory1 = AttackTechniqueFactory(attack_class=_StubAttack) + factory2 = AttackTechniqueFactory(attack_class=_StubAttack, attack_kwargs={"max_turns": 20}) + + self.registry.register_technique(name="f1", factory=factory1, tags=["multi_turn"]) + self.registry.register_technique(name="f2", factory=factory2, tags=["single_turn"]) + + multi = self.registry.get_by_tag(tag="multi_turn") + assert len(multi) == 1 + assert multi[0].name == "f1" + + single = self.registry.get_by_tag(tag="single_turn") + assert len(single) == 1 + assert single[0].name == "f2" + + def test_iter_yields_sorted_names(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="b", factory=factory) + self.registry.register_technique(name="a", factory=factory) + + assert list(self.registry) == ["a", "b"] diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index e61300cb0e..a0b5a75913 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -3,24 +3,54 @@ import pytest -from pyrit.identifiers import ComponentIdentifier -from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry, RegistryEntry +from pyrit.identifiers import ComponentIdentifier, Identifiable +from pyrit.registry.object_registries.base_instance_registry import ( + BaseInstanceRegistry, + RegistryEntry, +) +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, +) -class ConcreteTestRegistry(BaseInstanceRegistry[str, ComponentIdentifier]): - """Concrete implementation of BaseInstanceRegistry for testing.""" +class _TestItem(Identifiable): + """Minimal Identifiable stub wrapping a string value for testing.""" - def _build_metadata(self, name: str, instance: str) -> ComponentIdentifier: - """Build test metadata from a string instance.""" + def __init__(self, value: str) -> None: + self.value = value + + def _build_identifier(self) -> ComponentIdentifier: return ComponentIdentifier( - class_name="str", - class_module="builtins", - params={"category": "test" if "test" in instance.lower() else "other"}, + class_name="_TestItem", + class_module="test", + params={"category": "test" if "test" in self.value.lower() else "other"}, ) + def __eq__(self, other: object) -> bool: + if isinstance(other, _TestItem): + return self.value == other.value + if isinstance(other, str): + return self.value == other + return NotImplemented + + def __hash__(self) -> int: + return hash(self.value) + + def __repr__(self) -> str: + return f"_TestItem({self.value!r})" + + +def _item(value: str) -> _TestItem: + """Shorthand factory for _TestItem.""" + return _TestItem(value) -class TestBaseInstanceRegistrySingleton: - """Tests for the singleton pattern in BaseInstanceRegistry.""" + +class ConcreteTestRegistry(RetrievableInstanceRegistry["_TestItem"]): + """Concrete implementation of RetrievableInstanceRegistry for testing.""" + + +class TestRetrievableInstanceRegistrySingleton: + """Tests for the singleton pattern in RetrievableInstanceRegistry.""" def setup_method(self): """Reset the singleton before each test.""" @@ -52,8 +82,8 @@ def test_reset_instance_when_not_exists_does_not_raise(self): ConcreteTestRegistry.reset_instance() -class TestBaseInstanceRegistryRegistration: - """Tests for registration functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryRegistration: + """Tests for registration functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -66,16 +96,16 @@ def teardown_method(self): def test_register_adds_instance(self): """Test that register adds an instance to the registry.""" - self.registry.register("test_value", name="test_name") + self.registry.register(_item("test_value"), name="test_name") assert "test_name" in self.registry assert self.registry.get("test_name") == "test_value" def test_register_multiple_instances(self): """Test registering multiple instances.""" - self.registry.register("value1", name="name1") - self.registry.register("value2", name="name2") - self.registry.register("value3", name="name3") + self.registry.register(_item("value1"), name="name1") + self.registry.register(_item("value2"), name="name2") + self.registry.register(_item("value3"), name="name3") assert len(self.registry) == 3 assert self.registry.get("name1") == "value1" @@ -84,34 +114,34 @@ def test_register_multiple_instances(self): def test_register_overwrites_existing(self): """Test that registering with the same name overwrites the existing instance.""" - self.registry.register("original", name="name") - self.registry.register("updated", name="name") + self.registry.register(_item("original"), name="name") + self.registry.register(_item("updated"), name="name") assert len(self.registry) == 1 assert self.registry.get("name") == "updated" def test_register_invalidates_metadata_cache(self): """Test that registering a new instance invalidates the metadata cache.""" - self.registry.register("value1", name="name1") + self.registry.register(_item("value1"), name="name1") # Build cache by calling list_metadata metadata1 = self.registry.list_metadata() assert len(metadata1) == 1 # Register new instance - should invalidate cache - self.registry.register("value2", name="name2") + self.registry.register(_item("value2"), name="name2") metadata2 = self.registry.list_metadata() assert len(metadata2) == 2 -class TestBaseInstanceRegistryGet: - """Tests for get functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryGet: + """Tests for get functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" ConcreteTestRegistry.reset_instance() self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register("test_value", name="test_name") + self.registry.register(_item("test_value"), name="test_name") def teardown_method(self): """Reset the singleton after each test.""" @@ -128,14 +158,14 @@ def test_get_nonexistent_returns_none(self): assert result is None -class TestBaseInstanceRegistryGetEntry: - """Tests for get_entry functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryGetEntry: + """Tests for get_entry functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" ConcreteTestRegistry.reset_instance() self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register("test_value", name="test_name", tags={"role": "scorer"}) + self.registry.register(_item("test_value"), name="test_name", tags={"role": "scorer"}) def teardown_method(self): """Reset the singleton after each test.""" @@ -156,8 +186,8 @@ def test_get_entry_nonexistent_returns_none(self): assert result is None -class TestBaseInstanceRegistryGetNames: - """Tests for get_names functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryGetNames: + """Tests for get_names functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -175,16 +205,16 @@ def test_get_names_empty_registry(self): def test_get_names_returns_sorted_list(self): """Test that get_names returns a sorted list of names.""" - self.registry.register("value3", name="zeta") - self.registry.register("value1", name="alpha") - self.registry.register("value2", name="beta") + self.registry.register(_item("value3"), name="zeta") + self.registry.register(_item("value1"), name="alpha") + self.registry.register(_item("value2"), name="beta") names = self.registry.get_names() assert names == ["alpha", "beta", "zeta"] -class TestBaseInstanceRegistryGetAllInstances: - """Tests for get_all_instances functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryGetAllInstances: + """Tests for get_all_instances functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -197,8 +227,8 @@ def teardown_method(self): def test_get_all_instances_returns_list_of_registry_entries(self): """Test that get_all_instances returns a list of RegistryEntry objects.""" - self.registry.register("value1", name="name1") - self.registry.register("value2", name="name2") + self.registry.register(_item("value1"), name="name1") + self.registry.register(_item("value2"), name="name2") result = self.registry.get_all_instances() assert isinstance(result, list) @@ -207,17 +237,17 @@ def test_get_all_instances_returns_list_of_registry_entries(self): def test_get_all_instances_sorted_by_name(self): """Test that get_all_instances returns entries sorted by name.""" - self.registry.register("value_z", name="zeta") - self.registry.register("value_a", name="alpha") - self.registry.register("value_b", name="beta") + self.registry.register(_item("value_z"), name="zeta") + self.registry.register(_item("value_a"), name="alpha") + self.registry.register(_item("value_b"), name="beta") result = self.registry.get_all_instances() assert [e.name for e in result] == ["alpha", "beta", "zeta"] def test_get_all_instances_preserves_tags(self): """Test that get_all_instances preserves tags on entries.""" - self.registry.register("value1", name="name1", tags={"role": "scorer"}) - self.registry.register("value2", name="name2", tags=["fast"]) + self.registry.register(_item("value1"), name="name1", tags={"role": "scorer"}) + self.registry.register(_item("value2"), name="name2", tags=["fast"]) result = self.registry.get_all_instances() entry_map = {e.name: e for e in result} @@ -230,16 +260,16 @@ def test_get_all_instances_empty_registry(self): assert result == [] -class TestBaseInstanceRegistryListMetadata: - """Tests for list_metadata functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryListMetadata: + """Tests for list_metadata functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" ConcreteTestRegistry.reset_instance() self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register("test_item_1", name="item1") - self.registry.register("other_item_2", name="item2") - self.registry.register("test_item_3", name="item3") + self.registry.register(_item("test_item_1"), name="item1") + self.registry.register(_item("other_item_2"), name="item2") + self.registry.register(_item("test_item_3"), name="item3") def teardown_method(self): """Reset the singleton after each test.""" @@ -256,9 +286,9 @@ def test_list_metadata_sorted_by_name(self): # Since unique_name is auto-computed, we verify we get 3 items in order # The actual unique_name field is auto-computed from class_name::hash assert len(metadata) == 3 - # All should have "str" in the unique_name since class_name is "str" + # All should have "_TestItem" in the unique_name since class_name is "_TestItem" for m in metadata: - assert "str" in m.unique_name + assert "_TestItem" in m.unique_name def test_list_metadata_with_filter(self): """Test filtering metadata by a field.""" @@ -280,7 +310,7 @@ def test_list_metadata_with_exclude_filter(self): def test_list_metadata_combined_include_and_exclude(self): """Test combined include and exclude filters.""" # Add another test item to have more variety - self.registry.register("another_test_item", name="item4") + self.registry.register(_item("another_test_item"), name="item4") # Get items with category "test" but exclude by class_name "str" # Since all have class_name="str", excluding by class_name would exclude all @@ -301,8 +331,8 @@ def test_list_metadata_caching(self): assert len(metadata1) == 3 -class TestBaseInstanceRegistryTags: - """Tests for tag registration and retrieval in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryTags: + """Tests for tag registration and retrieval in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -315,7 +345,7 @@ def teardown_method(self): def test_register_with_dict_tags(self): """Test that dict tags are stored correctly.""" - self.registry.register("value", name="name1", tags={"role": "scorer", "provider": "azure"}) + self.registry.register(_item("value"), name="name1", tags={"role": "scorer", "provider": "azure"}) entry = self.registry.get_entry("name1") assert entry is not None @@ -323,7 +353,7 @@ def test_register_with_dict_tags(self): def test_register_with_list_tags(self): """Test that list tags are normalized to dict with empty string values.""" - self.registry.register("value", name="name1", tags=["fast", "default"]) + self.registry.register(_item("value"), name="name1", tags=["fast", "default"]) entry = self.registry.get_entry("name1") assert entry is not None @@ -331,7 +361,7 @@ def test_register_with_list_tags(self): def test_register_without_tags(self): """Test that registering without tags defaults to empty dict.""" - self.registry.register("value", name="name1") + self.registry.register(_item("value"), name="name1") entry = self.registry.get_entry("name1") assert entry is not None @@ -339,9 +369,9 @@ def test_register_without_tags(self): def test_get_by_tag_key_only(self): """Test get_by_tag matching by key only (any value).""" - self.registry.register("v1", name="n1", tags={"role": "scorer"}) - self.registry.register("v2", name="n2", tags={"role": "target"}) - self.registry.register("v3", name="n3", tags={"provider": "azure"}) + self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) + self.registry.register(_item("v2"), name="n2", tags={"role": "target"}) + self.registry.register(_item("v3"), name="n3", tags={"provider": "azure"}) results = self.registry.get_by_tag(tag="role") assert len(results) == 2 @@ -349,9 +379,9 @@ def test_get_by_tag_key_only(self): def test_get_by_tag_key_and_value(self): """Test get_by_tag matching by key and specific value.""" - self.registry.register("v1", name="n1", tags={"role": "scorer"}) - self.registry.register("v2", name="n2", tags={"role": "target"}) - self.registry.register("v3", name="n3", tags={"role": "scorer"}) + self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) + self.registry.register(_item("v2"), name="n2", tags={"role": "target"}) + self.registry.register(_item("v3"), name="n3", tags={"role": "scorer"}) results = self.registry.get_by_tag(tag="role", value="scorer") assert len(results) == 2 @@ -359,31 +389,31 @@ def test_get_by_tag_key_and_value(self): def test_get_by_tag_no_match(self): """Test get_by_tag returns empty list when no entries match.""" - self.registry.register("v1", name="n1", tags={"role": "scorer"}) + self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) results = self.registry.get_by_tag(tag="nonexistent") assert results == [] def test_get_by_tag_value_no_match(self): """Test get_by_tag returns empty when key exists but value does not match.""" - self.registry.register("v1", name="n1", tags={"role": "scorer"}) + self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) results = self.registry.get_by_tag(tag="role", value="nonexistent") assert results == [] def test_get_by_tag_returns_sorted_by_name(self): """Test that get_by_tag results are sorted by name.""" - self.registry.register("v3", name="zeta", tags=["shared"]) - self.registry.register("v1", name="alpha", tags=["shared"]) - self.registry.register("v2", name="beta", tags=["shared"]) + self.registry.register(_item("v3"), name="zeta", tags=["shared"]) + self.registry.register(_item("v1"), name="alpha", tags=["shared"]) + self.registry.register(_item("v2"), name="beta", tags=["shared"]) results = self.registry.get_by_tag(tag="shared") assert [e.name for e in results] == ["alpha", "beta", "zeta"] def test_get_by_tag_with_list_tags(self): """Test get_by_tag works with list-style tags (normalized to empty string values).""" - self.registry.register("v1", name="n1", tags=["fast", "default"]) - self.registry.register("v2", name="n2", tags=["slow"]) + self.registry.register(_item("v1"), name="n1", tags=["fast", "default"]) + self.registry.register(_item("v2"), name="n2", tags=["slow"]) results = self.registry.get_by_tag(tag="fast") assert len(results) == 1 @@ -391,7 +421,7 @@ def test_get_by_tag_with_list_tags(self): def test_get_by_tag_with_list_tags_value_empty_string(self): """Test get_by_tag with explicit empty string value matches list-style tags.""" - self.registry.register("v1", name="n1", tags=["fast"]) + self.registry.register(_item("v1"), name="n1", tags=["fast"]) results = self.registry.get_by_tag(tag="fast", value="") assert len(results) == 1 @@ -399,11 +429,11 @@ def test_get_by_tag_with_list_tags_value_empty_string(self): def test_normalize_tags_none(self): """Test _normalize_tags returns empty dict for None.""" - assert BaseInstanceRegistry._normalize_tags(None) == {} + assert RetrievableInstanceRegistry._normalize_tags(None) == {} def test_normalize_tags_list(self): """Test _normalize_tags converts list to dict with empty values.""" - assert BaseInstanceRegistry._normalize_tags(["a", "b"]) == {"a": "", "b": ""} + assert RetrievableInstanceRegistry._normalize_tags(["a", "b"]) == {"a": "", "b": ""} def test_normalize_tags_dict(self): """Test _normalize_tags returns a copy of the dict.""" @@ -413,15 +443,15 @@ def test_normalize_tags_dict(self): assert result is not original -class TestBaseInstanceRegistryDunderMethods: - """Tests for dunder methods (__contains__, __len__, __iter__) in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryDunderMethods: + """Tests for dunder methods (__contains__, __len__, __iter__) in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" ConcreteTestRegistry.reset_instance() self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register("value1", name="name1") - self.registry.register("value2", name="name2") + self.registry.register(_item("value1"), name="name1") + self.registry.register(_item("value2"), name="name2") def teardown_method(self): """Reset the singleton after each test.""" @@ -457,8 +487,54 @@ def test_iter_allows_for_loop(self): assert collected == ["name1", "name2"] -class TestBaseInstanceRegistryAddTags: - """Tests for add_tags functionality in BaseInstanceRegistry.""" +class _ItemOnlyRegistry(BaseInstanceRegistry["_TestItem"]): + """Concrete BaseInstanceRegistry subclass — should NOT have get/get_entry/get_all_instances.""" + + +class TestBaseInstanceRegistryDoesNotExposeInstanceMethods: + """Verify that BaseInstanceRegistry subclasses lack instance-retrieval methods.""" + + def test_item_registry_has_no_get(self): + """BaseInstanceRegistry subclasses should not have a get() method.""" + assert not hasattr(_ItemOnlyRegistry, "get") + + def test_item_registry_has_no_get_entry(self): + """BaseInstanceRegistry subclasses should not have a get_entry() method.""" + assert not hasattr(_ItemOnlyRegistry, "get_entry") + + def test_item_registry_has_no_get_all_instances(self): + """BaseInstanceRegistry subclasses should not have a get_all_instances() method.""" + assert not hasattr(_ItemOnlyRegistry, "get_all_instances") + + def test_instance_registry_has_get(self): + """RetrievableInstanceRegistry subclasses should have get().""" + assert hasattr(ConcreteTestRegistry, "get") + + def test_instance_registry_has_get_entry(self): + """RetrievableInstanceRegistry subclasses should have get_entry().""" + assert hasattr(ConcreteTestRegistry, "get_entry") + + def test_instance_registry_has_get_all_instances(self): + """RetrievableInstanceRegistry subclasses should have get_all_instances().""" + assert hasattr(ConcreteTestRegistry, "get_all_instances") + + def test_item_registry_shares_common_methods(self): + """BaseInstanceRegistry subclasses should have shared registry methods.""" + for method in ( + "register", + "get_names", + "get_by_tag", + "add_tags", + "list_metadata", + "find_dependents_of_tag", + "get_registry_singleton", + "reset_instance", + ): + assert hasattr(_ItemOnlyRegistry, method), f"Missing method: {method}" + + +class TestRetrievableInstanceRegistryAddTags: + """Tests for add_tags functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -471,7 +547,7 @@ def teardown_method(self): def test_add_tags_with_list(self): """Test adding list-style tags to an existing entry.""" - self.registry.register("value", name="entry1") + self.registry.register(_item("value"), name="entry1") self.registry.add_tags(name="entry1", tags=["fast", "default"]) entry = self.registry.get_entry("entry1") @@ -480,7 +556,7 @@ def test_add_tags_with_list(self): def test_add_tags_with_dict(self): """Test adding dict-style tags to an existing entry.""" - self.registry.register("value", name="entry1") + self.registry.register(_item("value"), name="entry1") self.registry.add_tags(name="entry1", tags={"role": "scorer"}) entry = self.registry.get_entry("entry1") @@ -489,7 +565,7 @@ def test_add_tags_with_dict(self): def test_add_tags_merges_with_existing(self): """Test that add_tags merges new tags with existing ones.""" - self.registry.register("value", name="entry1", tags={"existing": "yes"}) + self.registry.register(_item("value"), name="entry1", tags={"existing": "yes"}) self.registry.add_tags(name="entry1", tags=["new_tag"]) entry = self.registry.get_entry("entry1") @@ -503,7 +579,7 @@ def test_add_tags_raises_for_missing_entry(self): def test_add_tags_invalidates_metadata_cache(self): """Test that add_tags invalidates the metadata cache.""" - self.registry.register("value", name="entry1") + self.registry.register(_item("value"), name="entry1") self.registry.list_metadata() # Build cache self.registry.add_tags(name="entry1", tags=["new"]) @@ -513,7 +589,7 @@ def test_add_tags_invalidates_metadata_cache(self): def test_add_tags_entries_findable_by_get_by_tag(self): """Test that entries are findable via get_by_tag after add_tags.""" - self.registry.register("value", name="entry1") + self.registry.register(_item("value"), name="entry1") self.registry.add_tags(name="entry1", tags=["best_scorer"]) results = self.registry.get_by_tag(tag="best_scorer") @@ -521,25 +597,22 @@ def test_add_tags_entries_findable_by_get_by_tag(self): assert results[0].name == "entry1" -class _IdentifiableStub: +class _IdentifiableStub(Identifiable): """A minimal stub that holds a ComponentIdentifier for dependency tests.""" def __init__(self, identifier: ComponentIdentifier) -> None: - self.identifier = identifier + self._stored_identifier = identifier - def get_identifier(self) -> ComponentIdentifier: - return self.identifier + def _build_identifier(self) -> ComponentIdentifier: + return self._stored_identifier -class IdentifierTestRegistry(BaseInstanceRegistry["_IdentifiableStub", ComponentIdentifier]): +class IdentifierTestRegistry(BaseInstanceRegistry["_IdentifiableStub"]): """Registry for testing dependency-related functionality with ComponentIdentifier trees.""" - def _build_metadata(self, name: str, instance: "_IdentifiableStub") -> ComponentIdentifier: - return instance.get_identifier() - class TestFindDependentsOfTag: - """Tests for BaseInstanceRegistry.find_dependents_of_tag.""" + """Tests for RetrievableInstanceRegistry.find_dependents_of_tag.""" def setup_method(self) -> None: IdentifierTestRegistry.reset_instance() diff --git a/tests/unit/registry/test_converter_registry.py b/tests/unit/registry/test_converter_registry.py index b62fedb6da..871fb85964 100644 --- a/tests/unit/registry/test_converter_registry.py +++ b/tests/unit/registry/test_converter_registry.py @@ -4,7 +4,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import PromptDataType from pyrit.prompt_converter import ConverterResult, PromptConverter -from pyrit.registry.instance_registries.converter_registry import ConverterRegistry +from pyrit.registry.object_registries.converter_registry import ConverterRegistry class MockTextConverter(PromptConverter): @@ -300,7 +300,7 @@ def test_list_metadata_combined_include_and_exclude(self): class TestConverterRegistryInheritedMethods: - """Tests for inherited methods from BaseInstanceRegistry.""" + """Tests for inherited methods from RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry.""" diff --git a/tests/unit/registry/test_scorer_registry.py b/tests/unit/registry/test_scorer_registry.py index 9d521b75fc..078d530967 100644 --- a/tests/unit/registry/test_scorer_registry.py +++ b/tests/unit/registry/test_scorer_registry.py @@ -5,7 +5,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, MessagePiece, Score -from pyrit.registry.instance_registries.scorer_registry import ScorerRegistry +from pyrit.registry.object_registries.scorer_registry import ScorerRegistry from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer import Scorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -328,7 +328,7 @@ def test_list_metadata_combined_include_and_exclude(self): class TestScorerRegistryInheritedMethods: - """Tests for inherited methods from BaseInstanceRegistry.""" + """Tests for inherited methods from RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry.""" diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py index 503d096a38..3f05624e9b 100644 --- a/tests/unit/registry/test_target_registry.py +++ b/tests/unit/registry/test_target_registry.py @@ -8,7 +8,7 @@ from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget -from pyrit.registry.instance_registries.target_registry import TargetRegistry +from pyrit.registry.object_registries.target_registry import TargetRegistry class MockPromptTarget(PromptTarget): diff --git a/tests/unit/scenario/test_attack_technique_factory.py b/tests/unit/scenario/test_attack_technique_factory.py new file mode 100644 index 0000000000..00734eb009 --- /dev/null +++ b/tests/unit/scenario/test_attack_technique_factory.py @@ -0,0 +1,368 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the AttackTechniqueFactory class.""" + +from unittest.mock import MagicMock + +import pytest + +from pyrit.executor.attack.core.attack_config import AttackConverterConfig, AttackScoringConfig +from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack +from pyrit.identifiers import ComponentIdentifier, Identifiable +from pyrit.models import SeedAttackTechniqueGroup, SeedPrompt +from pyrit.prompt_target import PromptTarget +from pyrit.scenario.core.attack_technique import AttackTechnique +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + + +def _make_seed_technique() -> SeedAttackTechniqueGroup: + return SeedAttackTechniqueGroup( + seeds=[ + SeedPrompt(value="technique1", data_type="text", is_general_technique=True), + ] + ) + + +class _StubAttack: + """ + Minimal stub that mimics an AttackStrategy constructor signature. + + We use a plain class rather than a real AttackStrategy subclass to keep + the unit tests fast and free of heavyweight base-class initialization. + ``inspect.signature`` sees the same keyword-only parameters that the + factory's ``_validate_kwargs`` expects. + """ + + def __init__( + self, + *, + objective_target: PromptTarget, + attack_scoring_config: AttackScoringConfig | None = None, + attack_converter_config: AttackConverterConfig | None = None, + max_turns: int = 5, + ) -> None: + self.objective_target = objective_target + self.attack_scoring_config = attack_scoring_config + self.attack_converter_config = attack_converter_config + self.max_turns = max_turns + + def get_identifier(self) -> ComponentIdentifier: + return ComponentIdentifier( + class_name="_StubAttack", + class_module="tests.unit.scenario.test_attack_technique_factory", + ) + + +class TestFactoryInit: + """Tests for AttackTechniqueFactory construction and validation.""" + + def test_init_defaults(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + assert factory.attack_class is _StubAttack + assert factory.seed_technique is None + + def test_init_stores_seed_technique(self): + seeds = _make_seed_technique() + factory = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seeds) + + assert factory.seed_technique is seeds + + def test_validate_kwargs_accepts_valid_params(self): + """All valid kwarg names should pass without error.""" + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, + ) + assert factory.attack_class is _StubAttack + + def test_validate_kwargs_rejects_unknown_params(self): + """Typo or nonexistent kwarg should raise TypeError immediately.""" + with pytest.raises(TypeError, match="Invalid kwargs.*max_turn"): + AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turn": 10}, # typo: should be max_turns + ) + + def test_validate_kwargs_rejects_objective_target(self): + """objective_target must not be in attack_kwargs.""" + target = MagicMock(spec=PromptTarget) + with pytest.raises(ValueError, match="objective_target must not be in attack_kwargs"): + AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"objective_target": target}, + ) + + def test_validate_kwargs_rejects_multiple_invalid(self): + """Multiple bad kwargs should all be reported.""" + with pytest.raises(TypeError, match="Invalid kwargs"): + AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"bad_param_1": 1, "bad_param_2": 2}, + ) + + def test_validate_kwargs_rejects_var_keyword_constructor(self): + """Constructors with **kwargs prevent parameter validation and should be rejected.""" + + class _KwargsAttack: + def __init__(self, **kwargs): + pass + + with pytest.raises(TypeError, match="accepts \\*\\*kwargs.*parameter validation"): + AttackTechniqueFactory(attack_class=_KwargsAttack) + + def test_validate_kwargs_rejects_var_keyword_even_with_named_params(self): + """Mixed named params + **kwargs should still be rejected.""" + + class _MixedAttack: + def __init__(self, *, objective_target, max_turns: int = 5, **extra): + pass + + with pytest.raises(TypeError, match="accepts \\*\\*kwargs"): + AttackTechniqueFactory( + attack_class=_MixedAttack, + attack_kwargs={"max_turns": 10}, + ) + + def test_validate_kwargs_works_with_real_attack_class(self): + """ + Validate that inspect.signature correctly sees through @apply_defaults + and functools.wraps on a real AttackStrategy subclass. + """ + # PromptSendingAttack uses @apply_defaults — factory should see its real params + factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) + assert factory.attack_class is PromptSendingAttack + + def test_validate_kwargs_rejects_invalid_param_on_real_attack_class(self): + """A typo kwarg should be caught even through @apply_defaults.""" + with pytest.raises(TypeError, match="Invalid kwargs.*nonexistent_param"): + AttackTechniqueFactory( + attack_class=PromptSendingAttack, + attack_kwargs={"nonexistent_param": 42}, + ) + + +class TestFactoryCreate: + """Tests for AttackTechniqueFactory.create().""" + + def _scoring(self) -> AttackScoringConfig: + return MagicMock(spec=AttackScoringConfig) + + def test_create_produces_attack_technique(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + target = MagicMock(spec=PromptTarget) + + technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + + assert isinstance(technique, AttackTechnique) + assert isinstance(technique.attack, _StubAttack) + assert technique.attack.objective_target is target + + def test_create_passes_frozen_kwargs(self): + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 42}, + ) + target = MagicMock(spec=PromptTarget) + + technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + + assert technique.attack.max_turns == 42 + + def test_create_passes_scoring_config(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + target = MagicMock(spec=PromptTarget) + scoring = MagicMock(spec=AttackScoringConfig) + + technique = factory.create(objective_target=target, attack_scoring_config=scoring) + + assert technique.attack.attack_scoring_config is scoring + + def test_create_overrides_frozen_scoring_config(self): + """Create-time scoring config should override the frozen one.""" + frozen_scoring = MagicMock(spec=AttackScoringConfig) + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"attack_scoring_config": frozen_scoring}, + ) + target = MagicMock(spec=PromptTarget) + override_scoring = MagicMock(spec=AttackScoringConfig) + + technique = factory.create(objective_target=target, attack_scoring_config=override_scoring) + + assert technique.attack.attack_scoring_config is override_scoring + assert technique.attack.attack_scoring_config is not frozen_scoring + + def test_create_preserves_seed_technique(self): + seeds = _make_seed_technique() + factory = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seeds) + target = MagicMock(spec=PromptTarget) + + technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + + assert technique.seed_technique is seeds + + def test_create_produces_independent_instances(self): + """Two create() calls should produce fully independent attack instances.""" + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10}, + ) + target1 = MagicMock(spec=PromptTarget) + target2 = MagicMock(spec=PromptTarget) + scoring = self._scoring() + + technique1 = factory.create(objective_target=target1, attack_scoring_config=scoring) + technique2 = factory.create(objective_target=target2, attack_scoring_config=scoring) + + assert technique1.attack is not technique2.attack + assert technique1.attack.objective_target is target1 + assert technique2.attack.objective_target is target2 + + def test_create_deepcopies_kwargs(self): + """Mutating the original kwargs dict should not affect future creates.""" + mutable_list = [1, 2, 3] + + class _ListAttack: + def __init__(self, *, objective_target, attack_scoring_config=None, items: list | None = None): + self.objective_target = objective_target + self.items = items + + def get_identifier(self): + return ComponentIdentifier(class_name="_ListAttack", class_module="test") + + factory = AttackTechniqueFactory( + attack_class=_ListAttack, + attack_kwargs={"items": mutable_list}, + ) + target = MagicMock(spec=PromptTarget) + + technique1 = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + # Mutate the source list + mutable_list.append(999) + + technique2 = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + + # First create should have the original snapshot + assert technique1.attack.items == [1, 2, 3] + # Second create should also have the original (from deepcopy of stored kwargs) + assert technique2.attack.items == [1, 2, 3] + + def test_create_without_optional_configs_omits_them(self): + """When optional configs are None, adversarial and converter should not be passed.""" + unset = object() + + class _SentinelAttack: + def __init__( + self, + *, + objective_target, + attack_scoring_config, + attack_adversarial_config=unset, + attack_converter_config=unset, + ): + self.objective_target = objective_target + self.adversarial_was_passed = attack_adversarial_config is not unset + self.converter_was_passed = attack_converter_config is not unset + + def get_identifier(self): + return ComponentIdentifier(class_name="_SentinelAttack", class_module="test") + + factory = AttackTechniqueFactory(attack_class=_SentinelAttack) + target = MagicMock(spec=PromptTarget) + technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + + assert not technique.attack.adversarial_was_passed + assert not technique.attack.converter_was_passed + + +class TestFactoryIdentifier: + """Tests for AttackTechniqueFactory._build_identifier().""" + + def test_identifier_includes_attack_class_name(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + identifier = factory.get_identifier() + + assert isinstance(identifier, ComponentIdentifier) + assert identifier.class_name == "AttackTechniqueFactory" + assert identifier.params["attack_class"] == "_StubAttack" + + def test_identifier_includes_kwargs_with_values(self): + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, + ) + + identifier = factory.get_identifier() + + assert identifier.params["kwargs"] == {"attack_scoring_config": None, "max_turns": 10} + + def test_identifier_empty_kwargs(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + identifier = factory.get_identifier() + + assert identifier.params["kwargs"] == {} + + def test_same_keys_different_values_produce_different_hashes(self): + """Two factories with max_turns=5 vs max_turns=50 must have different hashes.""" + factory1 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 5}, + ) + factory2 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 50}, + ) + + assert factory1.get_identifier().hash != factory2.get_identifier().hash + + def test_different_kwargs_keys_produce_different_hashes(self): + factory1 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10}, + ) + factory2 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, + ) + + assert factory1.get_identifier().hash != factory2.get_identifier().hash + + def test_identifier_serializes_identifiable_values(self): + """Identifiable objects in kwargs should contribute their hash to the identifier.""" + expected_id = ComponentIdentifier( + class_name="MockConfig", + class_module="test", + params={"key": "value"}, + ) + mock_identifiable = MagicMock(spec=Identifiable) + mock_identifiable.get_identifier.return_value = expected_id + + class _IdentifiableParamAttack: + def __init__(self, *, objective_target, config=None): + pass + + def get_identifier(self): + return ComponentIdentifier(class_name="_IdentifiableParamAttack", class_module="test") + + factory = AttackTechniqueFactory( + attack_class=_IdentifiableParamAttack, + attack_kwargs={"config": mock_identifiable}, + ) + + identifier = factory.get_identifier() + config_value = identifier.params["kwargs"]["config"] + # Should be the hash string from the identifiable, not the object itself + assert isinstance(config_value, str) + assert config_value == expected_id.hash + + def test_identifier_is_cached(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + first = factory.get_identifier() + second = factory.get_identifier() + + assert first is second