From 915f6ce7441354ce0ee59c8ef5eb73eb14f519e4 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 13 Apr 2026 16:15:49 -0700 Subject: [PATCH 01/10] Add AttackTechniqueFactory and AttackTechniqueRegistry Introduce a factory-based registry for reusable attack technique configurations. AttackTechniqueFactory captures technique-specific kwargs at registration time and calls the real attack constructor at create() time with scenario-specific objective_target and scorer. Key design decisions: - Signature validation at construction time catches kwarg typos early - Deep-copy of kwargs at both construction and create() prevents shared mutable state between factory calls - No changes to any existing attack class New files: - pyrit/scenario/core/attack_technique_factory.py - pyrit/registry/instance_registries/attack_technique_registry.py - tests/unit/scenario/test_attack_technique_factory.py (19 tests) - tests/unit/registry/test_attack_technique_registry.py (17 tests) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/registry/__init__.py | 2 + .../registry/instance_registries/__init__.py | 4 + .../attack_technique_registry.py | 119 ++++++++ pyrit/scenario/__init__.py | 2 + pyrit/scenario/core/__init__.py | 2 + .../scenario/core/attack_technique_factory.py | 173 +++++++++++ .../test_attack_technique_registry.py | 256 ++++++++++++++++ .../scenario/test_attack_technique_factory.py | 276 ++++++++++++++++++ 8 files changed, 834 insertions(+) create mode 100644 pyrit/registry/instance_registries/attack_technique_registry.py create mode 100644 pyrit/scenario/core/attack_technique_factory.py create mode 100644 tests/unit/registry/test_attack_technique_registry.py create mode 100644 tests/unit/scenario/test_attack_technique_factory.py diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 76c162d342..4dcb5bd2c7 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -18,6 +18,7 @@ discover_subclasses_in_loaded_modules, ) from pyrit.registry.instance_registries import ( + AttackTechniqueRegistry, BaseInstanceRegistry, RegistryEntry, ScorerRegistry, @@ -25,6 +26,7 @@ ) __all__ = [ + "AttackTechniqueRegistry", "BaseClassRegistry", "BaseInstanceRegistry", "ClassEntry", diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/instance_registries/__init__.py index d635813936..f53335b65c 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/instance_registries/__init__.py @@ -11,6 +11,9 @@ For registries that store classes (Type[T]), see class_registries/. """ +from pyrit.registry.instance_registries.attack_technique_registry import ( + AttackTechniqueRegistry, +) from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, RegistryEntry, @@ -30,6 +33,7 @@ "BaseInstanceRegistry", "RegistryEntry", # Concrete registries + "AttackTechniqueRegistry", "ConverterRegistry", "ScorerRegistry", "TargetRegistry", diff --git a/pyrit/registry/instance_registries/attack_technique_registry.py b/pyrit/registry/instance_registries/attack_technique_registry.py new file mode 100644 index 0000000000..a15f3864c6 --- /dev/null +++ b/pyrit/registry/instance_registries/attack_technique_registry.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +AttackTechniqueRegistry — Singleton registry of reusable attack technique factories. + +Scenarios and initializers register technique factories (capturing technique-specific +config). Scenarios retrieve them via ``create_technique()``, which calls the factory +with the scenario's objective target and scorer. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Optional, Union + +from pyrit.identifiers import ComponentIdentifier +from pyrit.registry.instance_registries.base_instance_registry import ( + BaseInstanceRegistry, +) + +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_config import ( + AttackAdversarialConfig, + AttackConverterConfig, + AttackScoringConfig, + ) + from pyrit.prompt_target import PromptTarget + from pyrit.scenario.core.attack_technique import AttackTechnique + from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + +logger = logging.getLogger(__name__) + + +class AttackTechniqueRegistry(BaseInstanceRegistry["AttackTechniqueFactory", ComponentIdentifier]): + """ + Singleton registry of reusable attack technique factories. + + Scenarios and initializers register technique factories (capturing + technique-specific config). Scenarios retrieve them via ``create_technique()``, + which calls the factory with the scenario's objective target and scorer. + """ + + @classmethod + def get_registry_singleton(cls) -> AttackTechniqueRegistry: + """ + Get the singleton instance of the AttackTechniqueRegistry. + + Returns: + The singleton AttackTechniqueRegistry instance. + """ + return super().get_registry_singleton() # type: ignore[return-value] + + def register_technique( + self, + *, + name: str, + factory: AttackTechniqueFactory, + tags: Optional[Union[dict[str, str], list[str]]] = None, + ) -> None: + """ + Register an attack technique factory. + + Args: + name: The registry name for this technique. + factory: The factory that produces attack techniques. + tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` + or a ``list[str]`` (each string becomes a key with value ``""``). + """ + self.register(factory, name=name, tags=tags) + logger.debug(f"Registered attack technique factory: {name} ({factory.attack_class.__name__})") + + def create_technique( + self, + name: str, + *, + objective_target: PromptTarget, + attack_scoring_config: AttackScoringConfig | None = None, + attack_adversarial_config: AttackAdversarialConfig | None = None, + attack_converter_config: AttackConverterConfig | None = None, + ) -> AttackTechnique: + """ + Retrieve a factory by name and produce a fresh attack technique. + + Args: + name: The registry name of the technique. + objective_target: The target to attack. + attack_scoring_config: Optional scoring configuration override. + attack_adversarial_config: Optional adversarial configuration override. + attack_converter_config: Optional converter configuration override. + + Returns: + A fresh AttackTechnique with a newly-constructed attack strategy. + + Raises: + KeyError: If no technique is registered with the given name. + """ + factory = self.get(name) + if factory is None: + raise KeyError(f"No technique registered with name '{name}'") + return factory.create( + objective_target=objective_target, + attack_scoring_config=attack_scoring_config, + attack_adversarial_config=attack_adversarial_config, + attack_converter_config=attack_converter_config, + ) + + def _build_metadata(self, name: str, instance: AttackTechniqueFactory) -> ComponentIdentifier: + """ + Build metadata for a technique factory. + + Args: + name: The registry name of the factory. + instance: The factory instance. + + Returns: + ComponentIdentifier: The factory's identifier. + """ + return instance.get_identifier() diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index 53f865305a..e8ebfb2946 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -19,6 +19,7 @@ from pyrit.scenario.core import ( AtomicAttack, AttackTechnique, + AttackTechniqueFactory, DatasetConfiguration, Scenario, ScenarioCompositeStrategy, @@ -44,6 +45,7 @@ __all__ = [ "AtomicAttack", "AttackTechnique", + "AttackTechniqueFactory", "DatasetConfiguration", "Scenario", "ScenarioCompositeStrategy", diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 55c4517aca..8f40282bef 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -5,6 +5,7 @@ from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy, ScenarioStrategy @@ -12,6 +13,7 @@ __all__ = [ "AtomicAttack", "AttackTechnique", + "AttackTechniqueFactory", "DatasetConfiguration", "EXPLICIT_SEED_GROUPS_KEY", "Scenario", diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py new file mode 100644 index 0000000000..8d6155c311 --- /dev/null +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -0,0 +1,173 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +AttackTechniqueFactory — Deferred construction of AttackTechnique instances. + +Captures technique-specific configuration at registration time and produces +fresh, fully-constructed attacks when scenario-specific params (objective target, +scorer) become available. +""" + +from __future__ import annotations + +import copy +import inspect +from typing import TYPE_CHECKING, Any + +from pyrit.identifiers import ComponentIdentifier, Identifiable +from pyrit.scenario.core.attack_technique import AttackTechnique + +if TYPE_CHECKING: + from pyrit.executor.attack import AttackStrategy + from pyrit.executor.attack.core.attack_config import ( + AttackAdversarialConfig, + AttackConverterConfig, + AttackScoringConfig, + ) + from pyrit.models import SeedAttackTechniqueGroup + from pyrit.prompt_target import PromptTarget + + +class AttackTechniqueFactory(Identifiable): + """ + A factory that produces AttackTechnique instances on demand. + + Captures technique-specific configuration (converters, adversarial config, + tree depth, etc.) at registration time. Produces fresh, fully-constructed + attacks by calling the real constructor with the captured params plus + scenario-specific objective_target and scoring config. + + Validates kwargs against the attack class constructor signature at + construction time, catching typos and incompatible parameter names early. + """ + + def __init__( + self, + *, + attack_class: type[AttackStrategy], + attack_kwargs: dict[str, Any] | None = None, + seed_technique: SeedAttackTechniqueGroup | None = None, + ) -> None: + """ + Initialize the factory with a technique-specific configuration. + + Args: + attack_class: The AttackStrategy subclass to instantiate. + attack_kwargs: Keyword arguments to pass to the attack constructor. + Must not include ``objective_target`` (provided at create time). + seed_technique: Optional technique seed group to attach to created techniques. + + Raises: + TypeError: If any kwarg name is not a valid constructor parameter. + ValueError: If ``objective_target`` is included in attack_kwargs. + """ + self._attack_class = attack_class + self._attack_kwargs = copy.deepcopy(attack_kwargs) if attack_kwargs else {} + self._seed_technique = seed_technique + + self._validate_kwargs() + + def _validate_kwargs(self) -> None: + """ + Validate that all kwargs are valid parameters for the attack class constructor. + + Uses ``inspect.signature`` on the attack class ``__init__``, which works through + the ``@apply_defaults`` decorator (it uses ``functools.wraps``). + + Raises: + TypeError: If any kwarg name is not a valid constructor parameter. + ValueError: If ``objective_target`` is included in attack_kwargs. + """ + if "objective_target" in self._attack_kwargs: + raise ValueError( + "objective_target must not be in attack_kwargs — " + "it is provided at create() time." + ) + + sig = inspect.signature(self._attack_class.__init__) + valid_params = { + name + for name, param in sig.parameters.items() + if name != "self" + and param.kind + in ( + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + } + + invalid = set(self._attack_kwargs) - valid_params + if invalid: + raise TypeError( + f"Invalid kwargs for {self._attack_class.__name__}: {sorted(invalid)}. " + f"Valid parameters: {sorted(valid_params)}" + ) + + @property + def attack_class(self) -> type[AttackStrategy]: + """The attack strategy class this factory produces.""" + return self._attack_class + + @property + def seed_technique(self) -> SeedAttackTechniqueGroup | None: + """The optional technique seed group.""" + return self._seed_technique + + def create( + self, + *, + objective_target: PromptTarget, + attack_scoring_config: AttackScoringConfig | None = None, + attack_adversarial_config: AttackAdversarialConfig | None = None, + attack_converter_config: AttackConverterConfig | None = None, + ) -> AttackTechnique: + """ + Create a fresh AttackTechnique bound to the given target and scorer. + + Each call produces a fully independent attack instance by calling the + real constructor. Config objects are deep-copied to prevent shared + mutable state between instances. + + Args: + objective_target: The target to attack. + attack_scoring_config: Optional scoring configuration. + Overrides any scoring config in the frozen kwargs. + attack_adversarial_config: Optional adversarial configuration. + Overrides any adversarial config in the frozen kwargs. + attack_converter_config: Optional converter configuration. + Overrides any converter config in the frozen kwargs. + + Returns: + A fresh AttackTechnique with a newly-constructed attack strategy. + """ + kwargs = copy.deepcopy(self._attack_kwargs) + kwargs["objective_target"] = objective_target + if attack_scoring_config is not None: + kwargs["attack_scoring_config"] = attack_scoring_config + if attack_adversarial_config is not None: + kwargs["attack_adversarial_config"] = attack_adversarial_config + if attack_converter_config is not None: + kwargs["attack_converter_config"] = attack_converter_config + + attack = self._attack_class(**kwargs) + return AttackTechnique(attack=attack, seed_technique=self._seed_technique) + + def _build_identifier(self) -> ComponentIdentifier: + """ + Build the behavioral identity for this factory. + + Includes the attack class name and the sorted kwarg keys so that + factories with different configurations are distinguishable. + + Returns: + ComponentIdentifier: The frozen identity snapshot. + """ + kwargs_summary = sorted(self._attack_kwargs.keys()) + return ComponentIdentifier.of( + self, + params={ + "attack_class": self._attack_class.__name__, + "kwargs_keys": kwargs_summary, + }, + ) diff --git a/tests/unit/registry/test_attack_technique_registry.py b/tests/unit/registry/test_attack_technique_registry.py new file mode 100644 index 0000000000..c5e795b453 --- /dev/null +++ b/tests/unit/registry/test_attack_technique_registry.py @@ -0,0 +1,256 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the AttackTechniqueRegistry class.""" + +from unittest.mock import MagicMock + +import pytest + +from pyrit.executor.attack.core.attack_config import AttackScoringConfig +from pyrit.identifiers import ComponentIdentifier +from pyrit.prompt_target import PromptTarget +from pyrit.registry.instance_registries.attack_technique_registry import AttackTechniqueRegistry +from pyrit.scenario.core.attack_technique import AttackTechnique +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + + +class _StubAttack: + """Minimal stub for testing the registry without real AttackStrategy weight.""" + + def __init__(self, *, objective_target, max_turns: int = 5): + self.objective_target = objective_target + self.max_turns = max_turns + + def get_identifier(self): + return ComponentIdentifier( + class_name="_StubAttack", + class_module="tests.unit.registry.test_attack_technique_registry", + params={"max_turns": self.max_turns}, + ) + + +class TestAttackTechniqueRegistrySingleton: + """Tests for the singleton pattern.""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_get_registry_singleton_returns_same_instance(self): + instance1 = AttackTechniqueRegistry.get_registry_singleton() + instance2 = AttackTechniqueRegistry.get_registry_singleton() + + assert instance1 is instance2 + + def test_get_registry_singleton_returns_correct_type(self): + instance = AttackTechniqueRegistry.get_registry_singleton() + + assert isinstance(instance, AttackTechniqueRegistry) + + def test_reset_instance_clears_singleton(self): + instance1 = AttackTechniqueRegistry.get_registry_singleton() + AttackTechniqueRegistry.reset_instance() + instance2 = AttackTechniqueRegistry.get_registry_singleton() + + assert instance1 is not instance2 + + +class TestAttackTechniqueRegistryRegister: + """Tests for registering technique factories.""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + self.registry = AttackTechniqueRegistry.get_registry_singleton() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_register_technique_stores_factory(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + self.registry.register_technique(name="stub_attack", factory=factory) + + assert "stub_attack" in self.registry + assert self.registry.get("stub_attack") is factory + + def test_register_technique_with_tags(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + self.registry.register_technique( + name="stub_attack", + factory=factory, + tags=["single_turn", "encoding"], + ) + + entries = self.registry.get_by_tag(tag="single_turn") + assert len(entries) == 1 + assert entries[0].name == "stub_attack" + + def test_register_multiple_techniques(self): + factory1 = AttackTechniqueFactory(attack_class=_StubAttack) + factory2 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 20}, + ) + + self.registry.register_technique(name="stub_5", factory=factory1) + self.registry.register_technique(name="stub_20", factory=factory2) + + assert len(self.registry) == 2 + assert self.registry.get_names() == ["stub_20", "stub_5"] + + +class TestAttackTechniqueRegistryCreateTechnique: + """Tests for create_technique().""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + self.registry = AttackTechniqueRegistry.get_registry_singleton() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_create_technique_returns_attack_technique(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="stub", factory=factory) + target = MagicMock(spec=PromptTarget) + + technique = self.registry.create_technique("stub", objective_target=target) + + assert isinstance(technique, AttackTechnique) + assert isinstance(technique.attack, _StubAttack) + assert technique.attack.objective_target is target + + def test_create_technique_passes_scoring_config(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="stub", factory=factory) + target = MagicMock(spec=PromptTarget) + scoring = MagicMock(spec=AttackScoringConfig) + + # _StubAttack doesn't have scoring config, but factory passes it through + # The real constructor would use it; here we verify the flow doesn't error + # when using a stub that doesn't accept it. Let's use a factory that does. + + class _ScoringStub: + def __init__(self, *, objective_target, attack_scoring_config=None): + self.objective_target = objective_target + self.attack_scoring_config = attack_scoring_config + + def get_identifier(self): + return ComponentIdentifier(class_name="_ScoringStub", class_module="test") + + factory2 = AttackTechniqueFactory(attack_class=_ScoringStub) + self.registry.register_technique(name="scoring_stub", factory=factory2) + + technique = self.registry.create_technique( + "scoring_stub", objective_target=target, attack_scoring_config=scoring + ) + + assert technique.attack.attack_scoring_config is scoring + + def test_create_technique_raises_on_missing_name(self): + with pytest.raises(KeyError, match="No technique registered with name 'nonexistent'"): + self.registry.create_technique( + "nonexistent", + objective_target=MagicMock(spec=PromptTarget), + ) + + def test_create_technique_preserves_frozen_kwargs(self): + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 42}, + ) + self.registry.register_technique(name="custom", factory=factory) + target = MagicMock(spec=PromptTarget) + + technique = self.registry.create_technique("custom", objective_target=target) + + assert technique.attack.max_turns == 42 + + +class TestAttackTechniqueRegistryMetadata: + """Tests for metadata / list_metadata on the registry.""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + self.registry = AttackTechniqueRegistry.get_registry_singleton() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_build_metadata_returns_component_identifier(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="stub", factory=factory) + + metadata = self.registry.list_metadata() + + assert len(metadata) == 1 + assert isinstance(metadata[0], ComponentIdentifier) + assert metadata[0].class_name == "AttackTechniqueFactory" + + def test_metadata_matches_factory_identifier(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="stub", factory=factory) + + metadata = self.registry.list_metadata() + + assert metadata[0] == factory.get_identifier() + + +class TestAttackTechniqueRegistryInherited: + """Tests for inherited BaseInstanceRegistry methods.""" + + def setup_method(self): + AttackTechniqueRegistry.reset_instance() + self.registry = AttackTechniqueRegistry.get_registry_singleton() + + def teardown_method(self): + AttackTechniqueRegistry.reset_instance() + + def test_contains(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="exists", factory=factory) + + assert "exists" in self.registry + assert "missing" not in self.registry + + def test_len(self): + assert len(self.registry) == 0 + + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="a", factory=factory) + + assert len(self.registry) == 1 + + def test_get_names_returns_sorted(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="zeta", factory=factory) + self.registry.register_technique(name="alpha", factory=factory) + self.registry.register_technique(name="beta", factory=factory) + + assert self.registry.get_names() == ["alpha", "beta", "zeta"] + + def test_tag_based_queries(self): + factory1 = AttackTechniqueFactory(attack_class=_StubAttack) + factory2 = AttackTechniqueFactory(attack_class=_StubAttack, attack_kwargs={"max_turns": 20}) + + self.registry.register_technique(name="f1", factory=factory1, tags=["multi_turn"]) + self.registry.register_technique(name="f2", factory=factory2, tags=["single_turn"]) + + multi = self.registry.get_by_tag(tag="multi_turn") + assert len(multi) == 1 + assert multi[0].name == "f1" + + single = self.registry.get_by_tag(tag="single_turn") + assert len(single) == 1 + assert single[0].name == "f2" + + def test_iter_yields_sorted_names(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + self.registry.register_technique(name="b", factory=factory) + self.registry.register_technique(name="a", factory=factory) + + assert list(self.registry) == ["a", "b"] diff --git a/tests/unit/scenario/test_attack_technique_factory.py b/tests/unit/scenario/test_attack_technique_factory.py new file mode 100644 index 0000000000..d7214cf880 --- /dev/null +++ b/tests/unit/scenario/test_attack_technique_factory.py @@ -0,0 +1,276 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the AttackTechniqueFactory class.""" + +from unittest.mock import MagicMock + +import pytest + +from pyrit.executor.attack.core.attack_config import AttackConverterConfig, AttackScoringConfig +from pyrit.identifiers import ComponentIdentifier +from pyrit.models import SeedAttackTechniqueGroup, SeedPrompt +from pyrit.prompt_target import PromptTarget +from pyrit.scenario.core.attack_technique import AttackTechnique +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + + +def _make_seed_technique() -> SeedAttackTechniqueGroup: + return SeedAttackTechniqueGroup( + seeds=[ + SeedPrompt(value="technique1", data_type="text", is_general_technique=True), + ] + ) + + +class _StubAttack: + """ + Minimal stub that mimics an AttackStrategy constructor signature. + + We use a plain class rather than a real AttackStrategy subclass to keep + the unit tests fast and free of heavyweight base-class initialization. + ``inspect.signature`` sees the same keyword-only parameters that the + factory's ``_validate_kwargs`` expects. + """ + + def __init__( + self, + *, + objective_target: PromptTarget, + attack_scoring_config: AttackScoringConfig | None = None, + attack_converter_config: AttackConverterConfig | None = None, + max_turns: int = 5, + ) -> None: + self.objective_target = objective_target + self.attack_scoring_config = attack_scoring_config + self.attack_converter_config = attack_converter_config + self.max_turns = max_turns + + def get_identifier(self) -> ComponentIdentifier: + return ComponentIdentifier( + class_name="_StubAttack", + class_module="tests.unit.scenario.test_attack_technique_factory", + ) + + +class TestFactoryInit: + """Tests for AttackTechniqueFactory construction and validation.""" + + def test_init_defaults(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + assert factory.attack_class is _StubAttack + assert factory.seed_technique is None + + def test_init_stores_seed_technique(self): + seeds = _make_seed_technique() + factory = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seeds) + + assert factory.seed_technique is seeds + + def test_validate_kwargs_accepts_valid_params(self): + """All valid kwarg names should pass without error.""" + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, + ) + assert factory.attack_class is _StubAttack + + def test_validate_kwargs_rejects_unknown_params(self): + """Typo or nonexistent kwarg should raise TypeError immediately.""" + with pytest.raises(TypeError, match="Invalid kwargs.*max_turn"): + AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turn": 10}, # typo: should be max_turns + ) + + def test_validate_kwargs_rejects_objective_target(self): + """objective_target must not be in attack_kwargs.""" + target = MagicMock(spec=PromptTarget) + with pytest.raises(ValueError, match="objective_target must not be in attack_kwargs"): + AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"objective_target": target}, + ) + + def test_validate_kwargs_rejects_multiple_invalid(self): + """Multiple bad kwargs should all be reported.""" + with pytest.raises(TypeError, match="Invalid kwargs"): + AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"bad_param_1": 1, "bad_param_2": 2}, + ) + + +class TestFactoryCreate: + """Tests for AttackTechniqueFactory.create().""" + + def test_create_produces_attack_technique(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + target = MagicMock(spec=PromptTarget) + + technique = factory.create(objective_target=target) + + assert isinstance(technique, AttackTechnique) + assert isinstance(technique.attack, _StubAttack) + assert technique.attack.objective_target is target + + def test_create_passes_frozen_kwargs(self): + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 42}, + ) + target = MagicMock(spec=PromptTarget) + + technique = factory.create(objective_target=target) + + assert technique.attack.max_turns == 42 + + def test_create_passes_scoring_config(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + target = MagicMock(spec=PromptTarget) + scoring = MagicMock(spec=AttackScoringConfig) + + technique = factory.create(objective_target=target, attack_scoring_config=scoring) + + assert technique.attack.attack_scoring_config is scoring + + def test_create_overrides_frozen_scoring_config(self): + """Create-time scoring config should override the frozen one.""" + frozen_scoring = MagicMock(spec=AttackScoringConfig) + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"attack_scoring_config": frozen_scoring}, + ) + target = MagicMock(spec=PromptTarget) + override_scoring = MagicMock(spec=AttackScoringConfig) + + technique = factory.create(objective_target=target, attack_scoring_config=override_scoring) + + assert technique.attack.attack_scoring_config is override_scoring + assert technique.attack.attack_scoring_config is not frozen_scoring + + def test_create_preserves_seed_technique(self): + seeds = _make_seed_technique() + factory = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seeds) + target = MagicMock(spec=PromptTarget) + + technique = factory.create(objective_target=target) + + assert technique.seed_technique is seeds + + def test_create_produces_independent_instances(self): + """Two create() calls should produce fully independent attack instances.""" + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10}, + ) + target1 = MagicMock(spec=PromptTarget) + target2 = MagicMock(spec=PromptTarget) + + technique1 = factory.create(objective_target=target1) + technique2 = factory.create(objective_target=target2) + + assert technique1.attack is not technique2.attack + assert technique1.attack.objective_target is target1 + assert technique2.attack.objective_target is target2 + + def test_create_deepcopies_kwargs(self): + """Mutating the original kwargs dict should not affect future creates.""" + mutable_list = [1, 2, 3] + + # Use a class that accepts a list param to test deepcopy + class _ListAttack: + def __init__(self, *, objective_target, items: list | None = None): + self.objective_target = objective_target + self.items = items + + def get_identifier(self): + return ComponentIdentifier(class_name="_ListAttack", class_module="test") + + factory = AttackTechniqueFactory( + attack_class=_ListAttack, + attack_kwargs={"items": mutable_list}, + ) + target = MagicMock(spec=PromptTarget) + + technique1 = factory.create(objective_target=target) + # Mutate the source list + mutable_list.append(999) + + technique2 = factory.create(objective_target=target) + + # First create should have the original snapshot + assert technique1.attack.items == [1, 2, 3] + # Second create should also have the original (from deepcopy of stored kwargs) + assert technique2.attack.items == [1, 2, 3] + + def test_create_without_optional_configs_omits_them(self): + """When optional configs are None, they should not be in kwargs.""" + call_kwargs: dict = {} + + class _SpyAttack: + def __init__(self, **kwargs): + call_kwargs.update(kwargs) + + def get_identifier(self): + return ComponentIdentifier(class_name="_SpyAttack", class_module="test") + + factory = AttackTechniqueFactory(attack_class=_SpyAttack) + target = MagicMock(spec=PromptTarget) + factory.create(objective_target=target) + + assert "objective_target" in call_kwargs + assert "attack_scoring_config" not in call_kwargs + assert "attack_adversarial_config" not in call_kwargs + assert "attack_converter_config" not in call_kwargs + + +class TestFactoryIdentifier: + """Tests for AttackTechniqueFactory._build_identifier().""" + + def test_identifier_includes_attack_class_name(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + identifier = factory.get_identifier() + + assert isinstance(identifier, ComponentIdentifier) + assert identifier.class_name == "AttackTechniqueFactory" + assert identifier.params["attack_class"] == "_StubAttack" + + def test_identifier_includes_kwargs_keys(self): + factory = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, + ) + + identifier = factory.get_identifier() + + assert sorted(identifier.params["kwargs_keys"]) == ["attack_scoring_config", "max_turns"] + + def test_identifier_empty_kwargs_keys(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + identifier = factory.get_identifier() + + assert identifier.params["kwargs_keys"] == [] + + def test_different_kwargs_produce_different_hashes(self): + factory1 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10}, + ) + factory2 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, + ) + + assert factory1.get_identifier().hash != factory2.get_identifier().hash + + def test_identifier_is_cached(self): + factory = AttackTechniqueFactory(attack_class=_StubAttack) + + first = factory.get_identifier() + second = factory.get_identifier() + + assert first is second From dac2099da656107c569d2b385c0f521bd87a0d3e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 13 Apr 2026 16:28:25 -0700 Subject: [PATCH 02/10] Address code review: reject **kwargs, hash values in identifier, add tests - _validate_kwargs now rejects constructors with **kwargs (VAR_KEYWORD) since parameter validation requires explicitly named params - _build_identifier includes serialized kwarg values (not just keys) so factories with different values produce different hashes - _serialize_value handles primitives, Identifiable objects, and collections recursively; falls back to type name for others - Added tests: VAR_KEYWORD rejection, real PromptSendingAttack validation through @apply_defaults, Identifiable value serialization, same-keys-different-values hash divergence - Rewrote test_create_without_optional_configs to use sentinel pattern instead of **kwargs spy (which is now rejected) - Cleaned up dead code in registry test_create_technique_passes_scoring_config Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenario/core/attack_technique_factory.py | 49 ++++++- .../test_attack_technique_registry.py | 15 +- .../scenario/test_attack_technique_factory.py | 129 +++++++++++++++--- 3 files changed, 157 insertions(+), 36 deletions(-) diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index 8d6155c311..d44c0b3697 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -59,7 +59,8 @@ def __init__( seed_technique: Optional technique seed group to attach to created techniques. Raises: - TypeError: If any kwarg name is not a valid constructor parameter. + TypeError: If any kwarg name is not a valid constructor parameter, + or if the attack class constructor uses ``**kwargs``. ValueError: If ``objective_target`` is included in attack_kwargs. """ self._attack_class = attack_class @@ -76,7 +77,9 @@ def _validate_kwargs(self) -> None: the ``@apply_defaults`` decorator (it uses ``functools.wraps``). Raises: - TypeError: If any kwarg name is not a valid constructor parameter. + TypeError: If any kwarg name is not a valid constructor parameter, + or if the constructor uses ``**kwargs`` (all parameters must be + explicitly named). ValueError: If ``objective_target`` is included in attack_kwargs. """ if "objective_target" in self._attack_kwargs: @@ -86,6 +89,18 @@ def _validate_kwargs(self) -> None: ) sig = inspect.signature(self._attack_class.__init__) + + # Reject constructors that accept **kwargs — we require explicitly named + # parameters so that validation is meaningful. + has_var_keyword = any( + param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values() + ) + if has_var_keyword: + raise TypeError( + f"{self._attack_class.__name__}.__init__ accepts **kwargs, which prevents " + f"parameter validation. All attack constructor parameters must be explicitly named." + ) + valid_params = { name for name, param in sig.parameters.items() @@ -153,21 +168,43 @@ def create( attack = self._attack_class(**kwargs) return AttackTechnique(attack=attack, seed_technique=self._seed_technique) + @staticmethod + def _serialize_value(value: Any) -> Any: + """ + Convert a value to a JSON-safe representation for identifier hashing. + + Primitives are included directly. Identifiable objects contribute their + hash. Collections are serialized recursively. Other types fall back to + their qualified class name. + + Returns: + Any: A JSON-serializable representation of the value. + """ + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + return [AttackTechniqueFactory._serialize_value(v) for v in value] + if isinstance(value, dict): + return {str(k): AttackTechniqueFactory._serialize_value(v) for k, v in sorted(value.items())} + if isinstance(value, Identifiable): + return value.get_identifier().hash + return f"<{type(value).__qualname__}>" + def _build_identifier(self) -> ComponentIdentifier: """ Build the behavioral identity for this factory. - Includes the attack class name and the sorted kwarg keys so that - factories with different configurations are distinguishable. + Includes the attack class name and kwargs with their serialized values + so that factories with different configurations produce different hashes. Returns: ComponentIdentifier: The frozen identity snapshot. """ - kwargs_summary = sorted(self._attack_kwargs.keys()) + kwargs_for_id = {k: self._serialize_value(v) for k, v in sorted(self._attack_kwargs.items())} return ComponentIdentifier.of( self, params={ "attack_class": self._attack_class.__name__, - "kwargs_keys": kwargs_summary, + "kwargs": kwargs_for_id, }, ) diff --git a/tests/unit/registry/test_attack_technique_registry.py b/tests/unit/registry/test_attack_technique_registry.py index c5e795b453..3a9c6de4ce 100644 --- a/tests/unit/registry/test_attack_technique_registry.py +++ b/tests/unit/registry/test_attack_technique_registry.py @@ -125,15 +125,6 @@ def test_create_technique_returns_attack_technique(self): assert technique.attack.objective_target is target def test_create_technique_passes_scoring_config(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) - self.registry.register_technique(name="stub", factory=factory) - target = MagicMock(spec=PromptTarget) - scoring = MagicMock(spec=AttackScoringConfig) - - # _StubAttack doesn't have scoring config, but factory passes it through - # The real constructor would use it; here we verify the flow doesn't error - # when using a stub that doesn't accept it. Let's use a factory that does. - class _ScoringStub: def __init__(self, *, objective_target, attack_scoring_config=None): self.objective_target = objective_target @@ -142,8 +133,10 @@ def __init__(self, *, objective_target, attack_scoring_config=None): def get_identifier(self): return ComponentIdentifier(class_name="_ScoringStub", class_module="test") - factory2 = AttackTechniqueFactory(attack_class=_ScoringStub) - self.registry.register_technique(name="scoring_stub", factory=factory2) + factory = AttackTechniqueFactory(attack_class=_ScoringStub) + self.registry.register_technique(name="scoring_stub", factory=factory) + target = MagicMock(spec=PromptTarget) + scoring = MagicMock(spec=AttackScoringConfig) technique = self.registry.create_technique( "scoring_stub", objective_target=target, attack_scoring_config=scoring diff --git a/tests/unit/scenario/test_attack_technique_factory.py b/tests/unit/scenario/test_attack_technique_factory.py index d7214cf880..8555a29ea8 100644 --- a/tests/unit/scenario/test_attack_technique_factory.py +++ b/tests/unit/scenario/test_attack_technique_factory.py @@ -8,7 +8,8 @@ import pytest from pyrit.executor.attack.core.attack_config import AttackConverterConfig, AttackScoringConfig -from pyrit.identifiers import ComponentIdentifier +from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack +from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.models import SeedAttackTechniqueGroup, SeedPrompt from pyrit.prompt_target import PromptTarget from pyrit.scenario.core.attack_technique import AttackTechnique @@ -101,6 +102,46 @@ def test_validate_kwargs_rejects_multiple_invalid(self): attack_kwargs={"bad_param_1": 1, "bad_param_2": 2}, ) + def test_validate_kwargs_rejects_var_keyword_constructor(self): + """Constructors with **kwargs prevent parameter validation and should be rejected.""" + + class _KwargsAttack: + def __init__(self, **kwargs): + pass + + with pytest.raises(TypeError, match="accepts \\*\\*kwargs.*parameter validation"): + AttackTechniqueFactory(attack_class=_KwargsAttack) + + def test_validate_kwargs_rejects_var_keyword_even_with_named_params(self): + """Mixed named params + **kwargs should still be rejected.""" + + class _MixedAttack: + def __init__(self, *, objective_target, max_turns: int = 5, **extra): + pass + + with pytest.raises(TypeError, match="accepts \\*\\*kwargs"): + AttackTechniqueFactory( + attack_class=_MixedAttack, + attack_kwargs={"max_turns": 10}, + ) + + def test_validate_kwargs_works_with_real_attack_class(self): + """ + Validate that inspect.signature correctly sees through @apply_defaults + and functools.wraps on a real AttackStrategy subclass. + """ + # PromptSendingAttack uses @apply_defaults — factory should see its real params + factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) + assert factory.attack_class is PromptSendingAttack + + def test_validate_kwargs_rejects_invalid_param_on_real_attack_class(self): + """A typo kwarg should be caught even through @apply_defaults.""" + with pytest.raises(TypeError, match="Invalid kwargs.*nonexistent_param"): + AttackTechniqueFactory( + attack_class=PromptSendingAttack, + attack_kwargs={"nonexistent_param": 42}, + ) + class TestFactoryCreate: """Tests for AttackTechniqueFactory.create().""" @@ -206,24 +247,33 @@ def get_identifier(self): assert technique2.attack.items == [1, 2, 3] def test_create_without_optional_configs_omits_them(self): - """When optional configs are None, they should not be in kwargs.""" - call_kwargs: dict = {} - - class _SpyAttack: - def __init__(self, **kwargs): - call_kwargs.update(kwargs) + """When optional configs are None, they should not be passed to the constructor.""" + unset = object() + + class _SentinelAttack: + def __init__( + self, + *, + objective_target, + attack_scoring_config=unset, + attack_adversarial_config=unset, + attack_converter_config=unset, + ): + self.objective_target = objective_target + self.scoring_was_passed = attack_scoring_config is not unset + self.adversarial_was_passed = attack_adversarial_config is not unset + self.converter_was_passed = attack_converter_config is not unset def get_identifier(self): - return ComponentIdentifier(class_name="_SpyAttack", class_module="test") + return ComponentIdentifier(class_name="_SentinelAttack", class_module="test") - factory = AttackTechniqueFactory(attack_class=_SpyAttack) + factory = AttackTechniqueFactory(attack_class=_SentinelAttack) target = MagicMock(spec=PromptTarget) - factory.create(objective_target=target) + technique = factory.create(objective_target=target) - assert "objective_target" in call_kwargs - assert "attack_scoring_config" not in call_kwargs - assert "attack_adversarial_config" not in call_kwargs - assert "attack_converter_config" not in call_kwargs + assert not technique.attack.scoring_was_passed + assert not technique.attack.adversarial_was_passed + assert not technique.attack.converter_was_passed class TestFactoryIdentifier: @@ -238,7 +288,7 @@ def test_identifier_includes_attack_class_name(self): assert identifier.class_name == "AttackTechniqueFactory" assert identifier.params["attack_class"] == "_StubAttack" - def test_identifier_includes_kwargs_keys(self): + def test_identifier_includes_kwargs_with_values(self): factory = AttackTechniqueFactory( attack_class=_StubAttack, attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, @@ -246,16 +296,29 @@ def test_identifier_includes_kwargs_keys(self): identifier = factory.get_identifier() - assert sorted(identifier.params["kwargs_keys"]) == ["attack_scoring_config", "max_turns"] + assert identifier.params["kwargs"] == {"attack_scoring_config": None, "max_turns": 10} - def test_identifier_empty_kwargs_keys(self): + def test_identifier_empty_kwargs(self): factory = AttackTechniqueFactory(attack_class=_StubAttack) identifier = factory.get_identifier() - assert identifier.params["kwargs_keys"] == [] + assert identifier.params["kwargs"] == {} - def test_different_kwargs_produce_different_hashes(self): + def test_same_keys_different_values_produce_different_hashes(self): + """Two factories with max_turns=5 vs max_turns=50 must have different hashes.""" + factory1 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 5}, + ) + factory2 = AttackTechniqueFactory( + attack_class=_StubAttack, + attack_kwargs={"max_turns": 50}, + ) + + assert factory1.get_identifier().hash != factory2.get_identifier().hash + + def test_different_kwargs_keys_produce_different_hashes(self): factory1 = AttackTechniqueFactory( attack_class=_StubAttack, attack_kwargs={"max_turns": 10}, @@ -267,6 +330,34 @@ def test_different_kwargs_produce_different_hashes(self): assert factory1.get_identifier().hash != factory2.get_identifier().hash + def test_identifier_serializes_identifiable_values(self): + """Identifiable objects in kwargs should contribute their hash to the identifier.""" + expected_id = ComponentIdentifier( + class_name="MockConfig", + class_module="test", + params={"key": "value"}, + ) + mock_identifiable = MagicMock(spec=Identifiable) + mock_identifiable.get_identifier.return_value = expected_id + + class _IdentifiableParamAttack: + def __init__(self, *, objective_target, config=None): + pass + + def get_identifier(self): + return ComponentIdentifier(class_name="_IdentifiableParamAttack", class_module="test") + + factory = AttackTechniqueFactory( + attack_class=_IdentifiableParamAttack, + attack_kwargs={"config": mock_identifiable}, + ) + + identifier = factory.get_identifier() + config_value = identifier.params["kwargs"]["config"] + # Should be the hash string from the identifiable, not the object itself + assert isinstance(config_value, str) + assert config_value == expected_id.hash + def test_identifier_is_cached(self): factory = AttackTechniqueFactory(attack_class=_StubAttack) From 4690e00a740c559a8ff6ec98729de6a49356aaaa Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 13 Apr 2026 16:44:30 -0700 Subject: [PATCH 03/10] Use typing.Self for registry singletons, remove 6 boilerplate overrides All registry subclasses had identical get_registry_singleton() overrides that existed solely for return type narrowing. Using Self in the base classes (RegistryProtocol, BaseClassRegistry, BaseInstanceRegistry) gives the same type safety and eliminates 69 lines of boilerplate. Self is imported under TYPE_CHECKING for 3.10 compatibility. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/registry/base.py | 11 ++++++++--- .../class_registries/base_class_registry.py | 13 +++++++++---- .../class_registries/initializer_registry.py | 10 ---------- .../registry/class_registries/scenario_registry.py | 10 ---------- .../attack_technique_registry.py | 10 ---------- .../instance_registries/base_instance_registry.py | 5 +++-- .../instance_registries/converter_registry.py | 10 ---------- .../registry/instance_registries/scorer_registry.py | 10 ---------- .../registry/instance_registries/target_registry.py | 10 ---------- 9 files changed, 20 insertions(+), 69 deletions(-) diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index 02a973a869..1bf8a4a298 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -8,9 +8,14 @@ and instance registries (which store T instances). """ -from collections.abc import Iterator +from __future__ import annotations + from dataclasses import dataclass -from typing import Any, Optional, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Optional, Protocol, TypeVar, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Self # Type variable for metadata (invariant for Protocol compatibility) MetadataT = TypeVar("MetadataT") @@ -52,7 +57,7 @@ class RegistryProtocol(Protocol[MetadataT]): """ @classmethod - def get_registry_singleton(cls) -> "RegistryProtocol[MetadataT]": + def get_registry_singleton(cls) -> Self: """Get the singleton instance of this registry.""" ... diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index ba735d8b2d..9439c24332 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -16,9 +16,14 @@ - **ClassEntry**: Internal wrapper holding a class plus optional factory/defaults """ +from __future__ import annotations + from abc import ABC, abstractmethod -from collections.abc import Callable, Iterator -from typing import Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Generic, Optional, TypeVar + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + from typing import Self from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.registry.base import RegistryProtocol @@ -107,7 +112,7 @@ class BaseClassRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]) """ # Class-level singleton instances, keyed by registry class - _instances: dict[type, "BaseClassRegistry[object, object]"] = {} + _instances: dict[type, BaseClassRegistry[object, object]] = {} def __init__(self, *, lazy_discovery: bool = True) -> None: """ @@ -128,7 +133,7 @@ def __init__(self, *, lazy_discovery: bool = True) -> None: self._discovered = True @classmethod - def get_registry_singleton(cls) -> "BaseClassRegistry[T, MetadataT]": + def get_registry_singleton(cls) -> Self: """ Get the singleton instance of this registry. diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 4a23320535..4007c58d66 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -66,16 +66,6 @@ class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetad The directory structure is used for organization but not exposed to users. """ - @classmethod - def get_registry_singleton(cls) -> InitializerRegistry: - """ - Get the singleton instance of the InitializerRegistry. - - Returns: - The singleton InitializerRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def __init__(self, *, discovery_path: Optional[Path] = None, lazy_discovery: bool = False) -> None: """ Initialize the initializer registry. diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 6f6d949adf..f8b0e3e87f 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -67,16 +67,6 @@ class ScenarioRegistry(BaseClassRegistry["Scenario", ScenarioMetadata]): Scenarios are identified by their dotted name (e.g., "garak.encoding", "foundry.red_team_agent"). """ - @classmethod - def get_registry_singleton(cls) -> ScenarioRegistry: - """ - Get the singleton instance of the ScenarioRegistry. - - Returns: - The singleton ScenarioRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def __init__(self, *, lazy_discovery: bool = True) -> None: """ Initialize the scenario registry. diff --git a/pyrit/registry/instance_registries/attack_technique_registry.py b/pyrit/registry/instance_registries/attack_technique_registry.py index a15f3864c6..9241f390d0 100644 --- a/pyrit/registry/instance_registries/attack_technique_registry.py +++ b/pyrit/registry/instance_registries/attack_technique_registry.py @@ -41,16 +41,6 @@ class AttackTechniqueRegistry(BaseInstanceRegistry["AttackTechniqueFactory", Com which calls the factory with the scenario's objective target and scorer. """ - @classmethod - def get_registry_singleton(cls) -> AttackTechniqueRegistry: - """ - Get the singleton instance of the AttackTechniqueRegistry. - - Returns: - The singleton AttackTechniqueRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def register_technique( self, *, diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index b29ea6aa08..de344cc782 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: from collections.abc import Iterator + from typing import Self T = TypeVar("T") # The type of instances stored MetadataT = TypeVar("MetadataT", bound=ComponentIdentifier) @@ -68,7 +69,7 @@ class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, Metadata _instances: dict[type, BaseInstanceRegistry[Any, Any]] = {} @classmethod - def get_registry_singleton(cls) -> BaseInstanceRegistry[T, MetadataT]: + def get_registry_singleton(cls) -> Self: """ Get the singleton instance of this registry. @@ -79,7 +80,7 @@ def get_registry_singleton(cls) -> BaseInstanceRegistry[T, MetadataT]: """ if cls not in cls._instances: cls._instances[cls] = cls() - return cls._instances[cls] + return cls._instances[cls] # type: ignore[return-value] @classmethod def reset_instance(cls) -> None: diff --git a/pyrit/registry/instance_registries/converter_registry.py b/pyrit/registry/instance_registries/converter_registry.py index 987c94fa0f..eacb8db9f5 100644 --- a/pyrit/registry/instance_registries/converter_registry.py +++ b/pyrit/registry/instance_registries/converter_registry.py @@ -34,16 +34,6 @@ class ConverterRegistry(BaseInstanceRegistry["PromptConverter", ComponentIdentif with their required parameters. """ - @classmethod - def get_registry_singleton(cls) -> ConverterRegistry: - """ - Get the singleton instance of the ConverterRegistry. - - Returns: - The singleton ConverterRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def register_instance( self, converter: PromptConverter, diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index e1fb7e1e9c..814356ef35 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -35,16 +35,6 @@ class ScorerRegistry(BaseInstanceRegistry["Scorer", ComponentIdentifier]): or a custom name provided during registration. """ - @classmethod - def get_registry_singleton(cls) -> ScorerRegistry: - """ - Get the singleton instance of the ScorerRegistry. - - Returns: - The singleton ScorerRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def register_instance( self, scorer: Scorer, diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py index 763d4a9e31..1f47ccde98 100644 --- a/pyrit/registry/instance_registries/target_registry.py +++ b/pyrit/registry/instance_registries/target_registry.py @@ -35,16 +35,6 @@ class TargetRegistry(BaseInstanceRegistry["PromptTarget", ComponentIdentifier]): or a custom name provided during registration. """ - @classmethod - def get_registry_singleton(cls) -> TargetRegistry: - """ - Get the singleton instance of the TargetRegistry. - - Returns: - The singleton TargetRegistry instance. - """ - return super().get_registry_singleton() # type: ignore[return-value] - def register_instance( self, target: PromptTarget, From 23482c9025655b06f33f50c2f22180777a052163 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 13 Apr 2026 16:48:59 -0700 Subject: [PATCH 04/10] Make attack_scoring_config required in create() and create_technique() Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../attack_technique_registry.py | 4 +-- .../scenario/core/attack_technique_factory.py | 8 ++--- .../test_attack_technique_registry.py | 11 +++++-- .../scenario/test_attack_technique_factory.py | 29 ++++++++++--------- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/pyrit/registry/instance_registries/attack_technique_registry.py b/pyrit/registry/instance_registries/attack_technique_registry.py index 9241f390d0..86da4491b1 100644 --- a/pyrit/registry/instance_registries/attack_technique_registry.py +++ b/pyrit/registry/instance_registries/attack_technique_registry.py @@ -65,7 +65,7 @@ def create_technique( name: str, *, objective_target: PromptTarget, - attack_scoring_config: AttackScoringConfig | None = None, + attack_scoring_config: AttackScoringConfig, attack_adversarial_config: AttackAdversarialConfig | None = None, attack_converter_config: AttackConverterConfig | None = None, ) -> AttackTechnique: @@ -75,7 +75,7 @@ def create_technique( Args: name: The registry name of the technique. objective_target: The target to attack. - attack_scoring_config: Optional scoring configuration override. + attack_scoring_config: Scoring configuration for the attack. attack_adversarial_config: Optional adversarial configuration override. attack_converter_config: Optional converter configuration override. diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index d44c0b3697..40d7bec579 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -133,7 +133,7 @@ def create( self, *, objective_target: PromptTarget, - attack_scoring_config: AttackScoringConfig | None = None, + attack_scoring_config: AttackScoringConfig, attack_adversarial_config: AttackAdversarialConfig | None = None, attack_converter_config: AttackConverterConfig | None = None, ) -> AttackTechnique: @@ -146,8 +146,7 @@ def create( Args: objective_target: The target to attack. - attack_scoring_config: Optional scoring configuration. - Overrides any scoring config in the frozen kwargs. + attack_scoring_config: Scoring configuration for the attack. attack_adversarial_config: Optional adversarial configuration. Overrides any adversarial config in the frozen kwargs. attack_converter_config: Optional converter configuration. @@ -158,8 +157,7 @@ def create( """ kwargs = copy.deepcopy(self._attack_kwargs) kwargs["objective_target"] = objective_target - if attack_scoring_config is not None: - kwargs["attack_scoring_config"] = attack_scoring_config + kwargs["attack_scoring_config"] = attack_scoring_config if attack_adversarial_config is not None: kwargs["attack_adversarial_config"] = attack_adversarial_config if attack_converter_config is not None: diff --git a/tests/unit/registry/test_attack_technique_registry.py b/tests/unit/registry/test_attack_technique_registry.py index 3a9c6de4ce..809472c87d 100644 --- a/tests/unit/registry/test_attack_technique_registry.py +++ b/tests/unit/registry/test_attack_technique_registry.py @@ -18,8 +18,9 @@ class _StubAttack: """Minimal stub for testing the registry without real AttackStrategy weight.""" - def __init__(self, *, objective_target, max_turns: int = 5): + def __init__(self, *, objective_target, attack_scoring_config=None, max_turns: int = 5): self.objective_target = objective_target + self.attack_scoring_config = attack_scoring_config self.max_turns = max_turns def get_identifier(self): @@ -117,8 +118,9 @@ def test_create_technique_returns_attack_technique(self): factory = AttackTechniqueFactory(attack_class=_StubAttack) self.registry.register_technique(name="stub", factory=factory) target = MagicMock(spec=PromptTarget) + scoring = MagicMock(spec=AttackScoringConfig) - technique = self.registry.create_technique("stub", objective_target=target) + technique = self.registry.create_technique("stub", objective_target=target, attack_scoring_config=scoring) assert isinstance(technique, AttackTechnique) assert isinstance(technique.attack, _StubAttack) @@ -149,6 +151,7 @@ def test_create_technique_raises_on_missing_name(self): self.registry.create_technique( "nonexistent", objective_target=MagicMock(spec=PromptTarget), + attack_scoring_config=MagicMock(spec=AttackScoringConfig), ) def test_create_technique_preserves_frozen_kwargs(self): @@ -159,7 +162,9 @@ def test_create_technique_preserves_frozen_kwargs(self): self.registry.register_technique(name="custom", factory=factory) target = MagicMock(spec=PromptTarget) - technique = self.registry.create_technique("custom", objective_target=target) + technique = self.registry.create_technique( + "custom", objective_target=target, attack_scoring_config=MagicMock(spec=AttackScoringConfig) + ) assert technique.attack.max_turns == 42 diff --git a/tests/unit/scenario/test_attack_technique_factory.py b/tests/unit/scenario/test_attack_technique_factory.py index 8555a29ea8..00734eb009 100644 --- a/tests/unit/scenario/test_attack_technique_factory.py +++ b/tests/unit/scenario/test_attack_technique_factory.py @@ -146,11 +146,14 @@ def test_validate_kwargs_rejects_invalid_param_on_real_attack_class(self): class TestFactoryCreate: """Tests for AttackTechniqueFactory.create().""" + def _scoring(self) -> AttackScoringConfig: + return MagicMock(spec=AttackScoringConfig) + def test_create_produces_attack_technique(self): factory = AttackTechniqueFactory(attack_class=_StubAttack) target = MagicMock(spec=PromptTarget) - technique = factory.create(objective_target=target) + technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) assert isinstance(technique, AttackTechnique) assert isinstance(technique.attack, _StubAttack) @@ -163,7 +166,7 @@ def test_create_passes_frozen_kwargs(self): ) target = MagicMock(spec=PromptTarget) - technique = factory.create(objective_target=target) + technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) assert technique.attack.max_turns == 42 @@ -196,7 +199,7 @@ def test_create_preserves_seed_technique(self): factory = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seeds) target = MagicMock(spec=PromptTarget) - technique = factory.create(objective_target=target) + technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) assert technique.seed_technique is seeds @@ -208,9 +211,10 @@ def test_create_produces_independent_instances(self): ) target1 = MagicMock(spec=PromptTarget) target2 = MagicMock(spec=PromptTarget) + scoring = self._scoring() - technique1 = factory.create(objective_target=target1) - technique2 = factory.create(objective_target=target2) + technique1 = factory.create(objective_target=target1, attack_scoring_config=scoring) + technique2 = factory.create(objective_target=target2, attack_scoring_config=scoring) assert technique1.attack is not technique2.attack assert technique1.attack.objective_target is target1 @@ -220,9 +224,8 @@ def test_create_deepcopies_kwargs(self): """Mutating the original kwargs dict should not affect future creates.""" mutable_list = [1, 2, 3] - # Use a class that accepts a list param to test deepcopy class _ListAttack: - def __init__(self, *, objective_target, items: list | None = None): + def __init__(self, *, objective_target, attack_scoring_config=None, items: list | None = None): self.objective_target = objective_target self.items = items @@ -235,11 +238,11 @@ def get_identifier(self): ) target = MagicMock(spec=PromptTarget) - technique1 = factory.create(objective_target=target) + technique1 = factory.create(objective_target=target, attack_scoring_config=self._scoring()) # Mutate the source list mutable_list.append(999) - technique2 = factory.create(objective_target=target) + technique2 = factory.create(objective_target=target, attack_scoring_config=self._scoring()) # First create should have the original snapshot assert technique1.attack.items == [1, 2, 3] @@ -247,7 +250,7 @@ def get_identifier(self): assert technique2.attack.items == [1, 2, 3] def test_create_without_optional_configs_omits_them(self): - """When optional configs are None, they should not be passed to the constructor.""" + """When optional configs are None, adversarial and converter should not be passed.""" unset = object() class _SentinelAttack: @@ -255,12 +258,11 @@ def __init__( self, *, objective_target, - attack_scoring_config=unset, + attack_scoring_config, attack_adversarial_config=unset, attack_converter_config=unset, ): self.objective_target = objective_target - self.scoring_was_passed = attack_scoring_config is not unset self.adversarial_was_passed = attack_adversarial_config is not unset self.converter_was_passed = attack_converter_config is not unset @@ -269,9 +271,8 @@ def get_identifier(self): factory = AttackTechniqueFactory(attack_class=_SentinelAttack) target = MagicMock(spec=PromptTarget) - technique = factory.create(objective_target=target) + technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) - assert not technique.attack.scoring_was_passed assert not technique.attack.adversarial_was_passed assert not technique.attack.converter_was_passed From be4e5febd58bf984060f9747af6752fa07b510fa Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 13 Apr 2026 17:20:29 -0700 Subject: [PATCH 05/10] Consolidate _build_metadata into BaseInstanceRegistry base class Remove identical _build_metadata overrides from all 4 instance registry subclasses (converter, scorer, target, attack_technique) by providing a concrete default in the base class. Drop unused MetadataT TypeVar and bound T to Identifiable so the default can call get_identifier() without type: ignore. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../attack_technique_registry.py | 16 +- .../base_instance_registry.py | 33 ++-- .../instance_registries/converter_registry.py | 16 +- .../instance_registries/scorer_registry.py | 16 +- .../instance_registries/target_registry.py | 16 +- .../registry/test_base_instance_registry.py | 153 ++++++++++-------- 6 files changed, 106 insertions(+), 144 deletions(-) diff --git a/pyrit/registry/instance_registries/attack_technique_registry.py b/pyrit/registry/instance_registries/attack_technique_registry.py index 86da4491b1..219f6253f4 100644 --- a/pyrit/registry/instance_registries/attack_technique_registry.py +++ b/pyrit/registry/instance_registries/attack_technique_registry.py @@ -14,7 +14,6 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.identifiers import ComponentIdentifier from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) @@ -32,7 +31,7 @@ logger = logging.getLogger(__name__) -class AttackTechniqueRegistry(BaseInstanceRegistry["AttackTechniqueFactory", ComponentIdentifier]): +class AttackTechniqueRegistry(BaseInstanceRegistry["AttackTechniqueFactory"]): """ Singleton registry of reusable attack technique factories. @@ -94,16 +93,3 @@ def create_technique( attack_adversarial_config=attack_adversarial_config, attack_converter_config=attack_converter_config, ) - - def _build_metadata(self, name: str, instance: AttackTechniqueFactory) -> ComponentIdentifier: - """ - Build metadata for a technique factory. - - Args: - name: The registry name of the factory. - instance: The factory instance. - - Returns: - ComponentIdentifier: The factory's identifier. - """ - return instance.get_identifier() diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index de344cc782..e256097e1b 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -15,19 +15,18 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import ABC from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union -from pyrit.identifiers import ComponentIdentifier +from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.registry.base import RegistryProtocol if TYPE_CHECKING: from collections.abc import Iterator from typing import Self -T = TypeVar("T") # The type of instances stored -MetadataT = TypeVar("MetadataT", bound=ComponentIdentifier) +T = TypeVar("T", bound=Identifiable) # The type of instances stored @dataclass @@ -49,7 +48,7 @@ class RegistryEntry(Generic[T]): tags: dict[str, str] = field(default_factory=dict) -class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): +class BaseInstanceRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): """ Abstract base class for registries that store pre-configured instances. @@ -57,16 +56,15 @@ class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, Metadata Type[T] and supports lazy discovery, instance registries store already-instantiated objects that are registered explicitly (typically during initialization). - Type Parameters: - T: The type of instances stored in the registry. - MetadataT: A TypedDict subclass for instance metadata. + All stored instances must implement ``Identifiable``, which provides + ``get_identifier()`` for metadata generation. - Subclasses must implement: - - _build_metadata(): Convert an instance to its metadata representation + Type Parameters: + T: The type of instances stored in the registry (must be Identifiable). """ # Class-level singleton instances, keyed by registry class - _instances: dict[type, BaseInstanceRegistry[Any, Any]] = {} + _instances: dict[type, BaseInstanceRegistry[Any]] = {} @classmethod def get_registry_singleton(cls) -> Self: @@ -114,7 +112,7 @@ def __init__(self) -> None: """Initialize the instance registry.""" # Maps registry names to registry entries self._registry_items: dict[str, RegistryEntry[T]] = {} - self._metadata_cache: Optional[list[MetadataT]] = None + self._metadata_cache: Optional[list[ComponentIdentifier]] = None def register( self, @@ -278,7 +276,7 @@ def list_metadata( *, include_filters: Optional[dict[str, object]] = None, exclude_filters: Optional[dict[str, object]] = None, - ) -> list[MetadataT]: + ) -> list[ComponentIdentifier]: """ List metadata for all registered instances, optionally filtered. @@ -315,19 +313,18 @@ def list_metadata( if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) ] - @abstractmethod - def _build_metadata(self, name: str, instance: T) -> MetadataT: + def _build_metadata(self, name: str, instance: T) -> ComponentIdentifier: """ - Build metadata for an instance. + Build metadata for an instance via its ``Identifiable`` interface. Args: name: The registry name of the instance. instance: The instance. Returns: - A metadata dictionary describing the instance. + The instance's ComponentIdentifier. """ - ... + return instance.get_identifier() def __contains__(self, name: str) -> bool: """ diff --git a/pyrit/registry/instance_registries/converter_registry.py b/pyrit/registry/instance_registries/converter_registry.py index eacb8db9f5..19f8d03108 100644 --- a/pyrit/registry/instance_registries/converter_registry.py +++ b/pyrit/registry/instance_registries/converter_registry.py @@ -14,7 +14,6 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.identifiers import ComponentIdentifier from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) @@ -25,7 +24,7 @@ logger = logging.getLogger(__name__) -class ConverterRegistry(BaseInstanceRegistry["PromptConverter", ComponentIdentifier]): +class ConverterRegistry(BaseInstanceRegistry["PromptConverter"]): """ Registry for managing available converter instances. @@ -68,16 +67,3 @@ def get_instance_by_name(self, name: str) -> Optional[PromptConverter]: The converter instance, or None if not found. """ return self.get(name) - - def _build_metadata(self, name: str, instance: PromptConverter) -> ComponentIdentifier: - """ - Build metadata for a converter instance. - - Args: - name: The registry name of the converter. - instance: The converter instance. - - Returns: - ComponentIdentifier: The converter's identifier. - """ - return instance.get_identifier() diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index 814356ef35..f645309508 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -12,7 +12,6 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.identifiers import ComponentIdentifier from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) @@ -23,7 +22,7 @@ logger = logging.getLogger(__name__) -class ScorerRegistry(BaseInstanceRegistry["Scorer", ComponentIdentifier]): +class ScorerRegistry(BaseInstanceRegistry["Scorer"]): """ Registry for managing available scorer instances. @@ -74,16 +73,3 @@ def get_instance_by_name(self, name: str) -> Optional[Scorer]: The scorer instance, or None if not found. """ return self.get(name) - - def _build_metadata(self, name: str, instance: Scorer) -> ComponentIdentifier: - """ - Build metadata for a scorer instance. - - Args: - name: The registry name of the scorer. - instance: The scorer instance. - - Returns: - ComponentIdentifier: The scorer's identifier - """ - return instance.get_identifier() diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py index 1f47ccde98..88ae19f49c 100644 --- a/pyrit/registry/instance_registries/target_registry.py +++ b/pyrit/registry/instance_registries/target_registry.py @@ -12,7 +12,6 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.identifiers import ComponentIdentifier from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) @@ -23,7 +22,7 @@ logger = logging.getLogger(__name__) -class TargetRegistry(BaseInstanceRegistry["PromptTarget", ComponentIdentifier]): +class TargetRegistry(BaseInstanceRegistry["PromptTarget"]): """ Registry for managing available prompt target instances. @@ -75,16 +74,3 @@ def get_instance_by_name(self, name: str) -> Optional[PromptTarget]: The target instance, or None if not found. """ return self.get(name) - - def _build_metadata(self, name: str, instance: PromptTarget) -> ComponentIdentifier: - """ - Build metadata for a target instance. - - Args: - name: The registry name of the target. - instance: The target instance. - - Returns: - ComponentIdentifier: The target's identifier. - """ - return instance.get_identifier() diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index e61300cb0e..f3b698041e 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -3,21 +3,45 @@ import pytest -from pyrit.identifiers import ComponentIdentifier +from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry, RegistryEntry -class ConcreteTestRegistry(BaseInstanceRegistry[str, ComponentIdentifier]): - """Concrete implementation of BaseInstanceRegistry for testing.""" +class _TestItem(Identifiable): + """Minimal Identifiable stub wrapping a string value for testing.""" + + def __init__(self, value: str) -> None: + self.value = value - def _build_metadata(self, name: str, instance: str) -> ComponentIdentifier: - """Build test metadata from a string instance.""" + def _build_identifier(self) -> ComponentIdentifier: return ComponentIdentifier( - class_name="str", - class_module="builtins", - params={"category": "test" if "test" in instance.lower() else "other"}, + class_name="_TestItem", + class_module="test", + params={"category": "test" if "test" in self.value.lower() else "other"}, ) + def __eq__(self, other: object) -> bool: + if isinstance(other, _TestItem): + return self.value == other.value + if isinstance(other, str): + return self.value == other + return NotImplemented + + def __hash__(self) -> int: + return hash(self.value) + + def __repr__(self) -> str: + return f"_TestItem({self.value!r})" + + +def _item(value: str) -> _TestItem: + """Shorthand factory for _TestItem.""" + return _TestItem(value) + + +class ConcreteTestRegistry(BaseInstanceRegistry["_TestItem"]): + """Concrete implementation of BaseInstanceRegistry for testing.""" + class TestBaseInstanceRegistrySingleton: """Tests for the singleton pattern in BaseInstanceRegistry.""" @@ -66,16 +90,16 @@ def teardown_method(self): def test_register_adds_instance(self): """Test that register adds an instance to the registry.""" - self.registry.register("test_value", name="test_name") + self.registry.register(_item("test_value"), name="test_name") assert "test_name" in self.registry assert self.registry.get("test_name") == "test_value" def test_register_multiple_instances(self): """Test registering multiple instances.""" - self.registry.register("value1", name="name1") - self.registry.register("value2", name="name2") - self.registry.register("value3", name="name3") + self.registry.register(_item("value1"), name="name1") + self.registry.register(_item("value2"), name="name2") + self.registry.register(_item("value3"), name="name3") assert len(self.registry) == 3 assert self.registry.get("name1") == "value1" @@ -84,21 +108,21 @@ def test_register_multiple_instances(self): def test_register_overwrites_existing(self): """Test that registering with the same name overwrites the existing instance.""" - self.registry.register("original", name="name") - self.registry.register("updated", name="name") + self.registry.register(_item("original"), name="name") + self.registry.register(_item("updated"), name="name") assert len(self.registry) == 1 assert self.registry.get("name") == "updated" def test_register_invalidates_metadata_cache(self): """Test that registering a new instance invalidates the metadata cache.""" - self.registry.register("value1", name="name1") + self.registry.register(_item("value1"), name="name1") # Build cache by calling list_metadata metadata1 = self.registry.list_metadata() assert len(metadata1) == 1 # Register new instance - should invalidate cache - self.registry.register("value2", name="name2") + self.registry.register(_item("value2"), name="name2") metadata2 = self.registry.list_metadata() assert len(metadata2) == 2 @@ -111,7 +135,7 @@ def setup_method(self): """Reset and get a fresh registry for each test.""" ConcreteTestRegistry.reset_instance() self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register("test_value", name="test_name") + self.registry.register(_item("test_value"), name="test_name") def teardown_method(self): """Reset the singleton after each test.""" @@ -135,7 +159,7 @@ def setup_method(self): """Reset and get a fresh registry for each test.""" ConcreteTestRegistry.reset_instance() self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register("test_value", name="test_name", tags={"role": "scorer"}) + self.registry.register(_item("test_value"), name="test_name", tags={"role": "scorer"}) def teardown_method(self): """Reset the singleton after each test.""" @@ -175,9 +199,9 @@ def test_get_names_empty_registry(self): def test_get_names_returns_sorted_list(self): """Test that get_names returns a sorted list of names.""" - self.registry.register("value3", name="zeta") - self.registry.register("value1", name="alpha") - self.registry.register("value2", name="beta") + self.registry.register(_item("value3"), name="zeta") + self.registry.register(_item("value1"), name="alpha") + self.registry.register(_item("value2"), name="beta") names = self.registry.get_names() assert names == ["alpha", "beta", "zeta"] @@ -197,8 +221,8 @@ def teardown_method(self): def test_get_all_instances_returns_list_of_registry_entries(self): """Test that get_all_instances returns a list of RegistryEntry objects.""" - self.registry.register("value1", name="name1") - self.registry.register("value2", name="name2") + self.registry.register(_item("value1"), name="name1") + self.registry.register(_item("value2"), name="name2") result = self.registry.get_all_instances() assert isinstance(result, list) @@ -207,17 +231,17 @@ def test_get_all_instances_returns_list_of_registry_entries(self): def test_get_all_instances_sorted_by_name(self): """Test that get_all_instances returns entries sorted by name.""" - self.registry.register("value_z", name="zeta") - self.registry.register("value_a", name="alpha") - self.registry.register("value_b", name="beta") + self.registry.register(_item("value_z"), name="zeta") + self.registry.register(_item("value_a"), name="alpha") + self.registry.register(_item("value_b"), name="beta") result = self.registry.get_all_instances() assert [e.name for e in result] == ["alpha", "beta", "zeta"] def test_get_all_instances_preserves_tags(self): """Test that get_all_instances preserves tags on entries.""" - self.registry.register("value1", name="name1", tags={"role": "scorer"}) - self.registry.register("value2", name="name2", tags=["fast"]) + self.registry.register(_item("value1"), name="name1", tags={"role": "scorer"}) + self.registry.register(_item("value2"), name="name2", tags=["fast"]) result = self.registry.get_all_instances() entry_map = {e.name: e for e in result} @@ -237,9 +261,9 @@ def setup_method(self): """Reset and get a fresh registry for each test.""" ConcreteTestRegistry.reset_instance() self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register("test_item_1", name="item1") - self.registry.register("other_item_2", name="item2") - self.registry.register("test_item_3", name="item3") + self.registry.register(_item("test_item_1"), name="item1") + self.registry.register(_item("other_item_2"), name="item2") + self.registry.register(_item("test_item_3"), name="item3") def teardown_method(self): """Reset the singleton after each test.""" @@ -256,9 +280,9 @@ def test_list_metadata_sorted_by_name(self): # Since unique_name is auto-computed, we verify we get 3 items in order # The actual unique_name field is auto-computed from class_name::hash assert len(metadata) == 3 - # All should have "str" in the unique_name since class_name is "str" + # All should have "_TestItem" in the unique_name since class_name is "_TestItem" for m in metadata: - assert "str" in m.unique_name + assert "_TestItem" in m.unique_name def test_list_metadata_with_filter(self): """Test filtering metadata by a field.""" @@ -280,7 +304,7 @@ def test_list_metadata_with_exclude_filter(self): def test_list_metadata_combined_include_and_exclude(self): """Test combined include and exclude filters.""" # Add another test item to have more variety - self.registry.register("another_test_item", name="item4") + self.registry.register(_item("another_test_item"), name="item4") # Get items with category "test" but exclude by class_name "str" # Since all have class_name="str", excluding by class_name would exclude all @@ -315,7 +339,7 @@ def teardown_method(self): def test_register_with_dict_tags(self): """Test that dict tags are stored correctly.""" - self.registry.register("value", name="name1", tags={"role": "scorer", "provider": "azure"}) + self.registry.register(_item("value"), name="name1", tags={"role": "scorer", "provider": "azure"}) entry = self.registry.get_entry("name1") assert entry is not None @@ -323,7 +347,7 @@ def test_register_with_dict_tags(self): def test_register_with_list_tags(self): """Test that list tags are normalized to dict with empty string values.""" - self.registry.register("value", name="name1", tags=["fast", "default"]) + self.registry.register(_item("value"), name="name1", tags=["fast", "default"]) entry = self.registry.get_entry("name1") assert entry is not None @@ -331,7 +355,7 @@ def test_register_with_list_tags(self): def test_register_without_tags(self): """Test that registering without tags defaults to empty dict.""" - self.registry.register("value", name="name1") + self.registry.register(_item("value"), name="name1") entry = self.registry.get_entry("name1") assert entry is not None @@ -339,9 +363,9 @@ def test_register_without_tags(self): def test_get_by_tag_key_only(self): """Test get_by_tag matching by key only (any value).""" - self.registry.register("v1", name="n1", tags={"role": "scorer"}) - self.registry.register("v2", name="n2", tags={"role": "target"}) - self.registry.register("v3", name="n3", tags={"provider": "azure"}) + self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) + self.registry.register(_item("v2"), name="n2", tags={"role": "target"}) + self.registry.register(_item("v3"), name="n3", tags={"provider": "azure"}) results = self.registry.get_by_tag(tag="role") assert len(results) == 2 @@ -349,9 +373,9 @@ def test_get_by_tag_key_only(self): def test_get_by_tag_key_and_value(self): """Test get_by_tag matching by key and specific value.""" - self.registry.register("v1", name="n1", tags={"role": "scorer"}) - self.registry.register("v2", name="n2", tags={"role": "target"}) - self.registry.register("v3", name="n3", tags={"role": "scorer"}) + self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) + self.registry.register(_item("v2"), name="n2", tags={"role": "target"}) + self.registry.register(_item("v3"), name="n3", tags={"role": "scorer"}) results = self.registry.get_by_tag(tag="role", value="scorer") assert len(results) == 2 @@ -359,31 +383,31 @@ def test_get_by_tag_key_and_value(self): def test_get_by_tag_no_match(self): """Test get_by_tag returns empty list when no entries match.""" - self.registry.register("v1", name="n1", tags={"role": "scorer"}) + self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) results = self.registry.get_by_tag(tag="nonexistent") assert results == [] def test_get_by_tag_value_no_match(self): """Test get_by_tag returns empty when key exists but value does not match.""" - self.registry.register("v1", name="n1", tags={"role": "scorer"}) + self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) results = self.registry.get_by_tag(tag="role", value="nonexistent") assert results == [] def test_get_by_tag_returns_sorted_by_name(self): """Test that get_by_tag results are sorted by name.""" - self.registry.register("v3", name="zeta", tags=["shared"]) - self.registry.register("v1", name="alpha", tags=["shared"]) - self.registry.register("v2", name="beta", tags=["shared"]) + self.registry.register(_item("v3"), name="zeta", tags=["shared"]) + self.registry.register(_item("v1"), name="alpha", tags=["shared"]) + self.registry.register(_item("v2"), name="beta", tags=["shared"]) results = self.registry.get_by_tag(tag="shared") assert [e.name for e in results] == ["alpha", "beta", "zeta"] def test_get_by_tag_with_list_tags(self): """Test get_by_tag works with list-style tags (normalized to empty string values).""" - self.registry.register("v1", name="n1", tags=["fast", "default"]) - self.registry.register("v2", name="n2", tags=["slow"]) + self.registry.register(_item("v1"), name="n1", tags=["fast", "default"]) + self.registry.register(_item("v2"), name="n2", tags=["slow"]) results = self.registry.get_by_tag(tag="fast") assert len(results) == 1 @@ -391,7 +415,7 @@ def test_get_by_tag_with_list_tags(self): def test_get_by_tag_with_list_tags_value_empty_string(self): """Test get_by_tag with explicit empty string value matches list-style tags.""" - self.registry.register("v1", name="n1", tags=["fast"]) + self.registry.register(_item("v1"), name="n1", tags=["fast"]) results = self.registry.get_by_tag(tag="fast", value="") assert len(results) == 1 @@ -420,8 +444,8 @@ def setup_method(self): """Reset and get a fresh registry for each test.""" ConcreteTestRegistry.reset_instance() self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register("value1", name="name1") - self.registry.register("value2", name="name2") + self.registry.register(_item("value1"), name="name1") + self.registry.register(_item("value2"), name="name2") def teardown_method(self): """Reset the singleton after each test.""" @@ -471,7 +495,7 @@ def teardown_method(self): def test_add_tags_with_list(self): """Test adding list-style tags to an existing entry.""" - self.registry.register("value", name="entry1") + self.registry.register(_item("value"), name="entry1") self.registry.add_tags(name="entry1", tags=["fast", "default"]) entry = self.registry.get_entry("entry1") @@ -480,7 +504,7 @@ def test_add_tags_with_list(self): def test_add_tags_with_dict(self): """Test adding dict-style tags to an existing entry.""" - self.registry.register("value", name="entry1") + self.registry.register(_item("value"), name="entry1") self.registry.add_tags(name="entry1", tags={"role": "scorer"}) entry = self.registry.get_entry("entry1") @@ -489,7 +513,7 @@ def test_add_tags_with_dict(self): def test_add_tags_merges_with_existing(self): """Test that add_tags merges new tags with existing ones.""" - self.registry.register("value", name="entry1", tags={"existing": "yes"}) + self.registry.register(_item("value"), name="entry1", tags={"existing": "yes"}) self.registry.add_tags(name="entry1", tags=["new_tag"]) entry = self.registry.get_entry("entry1") @@ -503,7 +527,7 @@ def test_add_tags_raises_for_missing_entry(self): def test_add_tags_invalidates_metadata_cache(self): """Test that add_tags invalidates the metadata cache.""" - self.registry.register("value", name="entry1") + self.registry.register(_item("value"), name="entry1") self.registry.list_metadata() # Build cache self.registry.add_tags(name="entry1", tags=["new"]) @@ -513,7 +537,7 @@ def test_add_tags_invalidates_metadata_cache(self): def test_add_tags_entries_findable_by_get_by_tag(self): """Test that entries are findable via get_by_tag after add_tags.""" - self.registry.register("value", name="entry1") + self.registry.register(_item("value"), name="entry1") self.registry.add_tags(name="entry1", tags=["best_scorer"]) results = self.registry.get_by_tag(tag="best_scorer") @@ -521,22 +545,19 @@ def test_add_tags_entries_findable_by_get_by_tag(self): assert results[0].name == "entry1" -class _IdentifiableStub: +class _IdentifiableStub(Identifiable): """A minimal stub that holds a ComponentIdentifier for dependency tests.""" def __init__(self, identifier: ComponentIdentifier) -> None: - self.identifier = identifier + self._stored_identifier = identifier - def get_identifier(self) -> ComponentIdentifier: - return self.identifier + def _build_identifier(self) -> ComponentIdentifier: + return self._stored_identifier -class IdentifierTestRegistry(BaseInstanceRegistry["_IdentifiableStub", ComponentIdentifier]): +class IdentifierTestRegistry(BaseInstanceRegistry["_IdentifiableStub"]): """Registry for testing dependency-related functionality with ComponentIdentifier trees.""" - def _build_metadata(self, name: str, instance: "_IdentifiableStub") -> ComponentIdentifier: - return instance.get_identifier() - class TestFindDependentsOfTag: """Tests for BaseInstanceRegistry.find_dependents_of_tag.""" From 77aebc388d3b11174b8b64eb2215884233230647 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 14 Apr 2026 12:13:01 -0700 Subject: [PATCH 06/10] Split BaseItemRegistry from BaseInstanceRegistry for clean hierarchy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract shared infrastructure (singleton, registration, tags, metadata, container protocol) into BaseItemRegistry. BaseInstanceRegistry extends it with get(), get_entry(), get_all_instances() for direct-retrieval registries. AttackTechniqueRegistry now extends BaseItemRegistry directly, so factory registries don't inherit misleading instance-retrieval methods. Hierarchy: BaseItemRegistry (shared core) ├── BaseInstanceRegistry (+ get/get_entry/get_all_instances) │ ├── ConverterRegistry │ ├── ScorerRegistry │ └── TargetRegistry └── AttackTechniqueRegistry (factory — no get()) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/registry/__init__.py | 2 + .../registry/instance_registries/__init__.py | 4 +- .../attack_technique_registry.py | 10 +- .../base_instance_registry.py | 157 ++++++++++-------- .../test_attack_technique_registry.py | 2 +- .../registry/test_base_instance_registry.py | 46 ++++- 6 files changed, 147 insertions(+), 74 deletions(-) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 4dcb5bd2c7..d32ae963fa 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -20,6 +20,7 @@ from pyrit.registry.instance_registries import ( AttackTechniqueRegistry, BaseInstanceRegistry, + BaseItemRegistry, RegistryEntry, ScorerRegistry, TargetRegistry, @@ -29,6 +30,7 @@ "AttackTechniqueRegistry", "BaseClassRegistry", "BaseInstanceRegistry", + "BaseItemRegistry", "ClassEntry", "discover_in_directory", "discover_in_package", diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/instance_registries/__init__.py index f53335b65c..2055b0a5a3 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/instance_registries/__init__.py @@ -16,6 +16,7 @@ ) from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, + BaseItemRegistry, RegistryEntry, ) from pyrit.registry.instance_registries.converter_registry import ( @@ -29,8 +30,9 @@ ) __all__ = [ - # Base class + # Base classes "BaseInstanceRegistry", + "BaseItemRegistry", "RegistryEntry", # Concrete registries "AttackTechniqueRegistry", diff --git a/pyrit/registry/instance_registries/attack_technique_registry.py b/pyrit/registry/instance_registries/attack_technique_registry.py index 219f6253f4..611906becc 100644 --- a/pyrit/registry/instance_registries/attack_technique_registry.py +++ b/pyrit/registry/instance_registries/attack_technique_registry.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Optional, Union from pyrit.registry.instance_registries.base_instance_registry import ( - BaseInstanceRegistry, + BaseItemRegistry, ) if TYPE_CHECKING: @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -class AttackTechniqueRegistry(BaseInstanceRegistry["AttackTechniqueFactory"]): +class AttackTechniqueRegistry(BaseItemRegistry["AttackTechniqueFactory"]): """ Singleton registry of reusable attack technique factories. @@ -84,10 +84,10 @@ def create_technique( Raises: KeyError: If no technique is registered with the given name. """ - factory = self.get(name) - if factory is None: + entry = self._registry_items.get(name) + if entry is None: raise KeyError(f"No technique registered with name '{name}'") - return factory.create( + return entry.instance.create( objective_target=objective_target, attack_scoring_config=attack_scoring_config, attack_adversarial_config=attack_adversarial_config, diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index e256097e1b..e71bc6bfb4 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -2,15 +2,22 @@ # Licensed under the MIT license. """ -Base instance registry for PyRIT. +Base item and instance registries for PyRIT. -This module provides the abstract base class for registries that store -pre-configured instances (not classes). Unlike class registries which -store Type[T] and create instances on demand, instance registries store -already-instantiated objects. +This module provides two base classes for registries that store +``Identifiable`` objects (not classes): -Examples include: -- ScorerRegistry: stores Scorer instances configured with their chat_target +- ``BaseItemRegistry``: Shared infrastructure — singleton lifecycle, + registration, tags, metadata, container protocol. Use this directly + for registries that store factories or other non-retrievable items + (e.g., ``AttackTechniqueRegistry``). + +- ``BaseInstanceRegistry(BaseItemRegistry)``: Adds ``get()``, + ``get_entry()``, and ``get_all_instances()`` for registries where + callers retrieve stored objects directly (e.g., ``ScorerRegistry``, + ``ConverterRegistry``, ``TargetRegistry``). + +For registries that store classes (Type[T]), see ``class_registries/``. """ from __future__ import annotations @@ -26,13 +33,13 @@ from collections.abc import Iterator from typing import Self -T = TypeVar("T", bound=Identifiable) # The type of instances stored +T = TypeVar("T", bound=Identifiable) # The type of items stored @dataclass class RegistryEntry(Generic[T]): """ - A wrapper around a registered instance, holding its name, tags, and the instance itself. + A wrapper around a registered item, holding its name, tags, and the item itself. Tags are always stored as ``dict[str, str]``. When callers pass a plain ``list[str]``, each string is normalized to a key with an empty-string value. @@ -48,23 +55,27 @@ class RegistryEntry(Generic[T]): tags: dict[str, str] = field(default_factory=dict) -class BaseInstanceRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): +class BaseItemRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): """ - Abstract base class for registries that store pre-configured instances. + Abstract base class providing shared registry infrastructure. + + Provides singleton lifecycle, registration, tag-based lookup, metadata + filtering, and the standard container protocol (``__contains__``, + ``__len__``, ``__iter__``). - This class implements RegistryProtocol. Unlike BaseClassRegistry which stores - Type[T] and supports lazy discovery, instance registries store already-instantiated - objects that are registered explicitly (typically during initialization). + Subclass directly when stored items should not be retrievable via + ``get()`` (e.g., factory registries). For registries that expose + direct item retrieval, subclass ``BaseInstanceRegistry`` instead. - All stored instances must implement ``Identifiable``, which provides + All stored items must implement ``Identifiable``, which provides ``get_identifier()`` for metadata generation. Type Parameters: - T: The type of instances stored in the registry (must be Identifiable). + T: The type of items stored in the registry (must be Identifiable). """ # Class-level singleton instances, keyed by registry class - _instances: dict[type, BaseInstanceRegistry[Any]] = {} + _instances: dict[type, BaseItemRegistry[Any]] = {} @classmethod def get_registry_singleton(cls) -> Self: @@ -109,7 +120,7 @@ def _normalize_tags(tags: Optional[Union[dict[str, str], list[str]]] = None) -> return dict(tags) def __init__(self) -> None: - """Initialize the instance registry.""" + """Initialize the registry.""" # Maps registry names to registry entries self._registry_items: dict[str, RegistryEntry[T]] = {} self._metadata_cache: Optional[list[ComponentIdentifier]] = None @@ -122,11 +133,11 @@ def register( tags: Optional[Union[dict[str, str], list[str]]] = None, ) -> None: """ - Register an instance. + Register an item. Args: - instance: The pre-configured instance to register. - name: The registry name for this instance. + instance: The item to register. + name: The registry name for this item. tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` or a ``list[str]`` (each string becomes a key with value ``""``). """ @@ -134,33 +145,6 @@ def register( self._registry_items[name] = RegistryEntry(name=name, instance=instance, tags=normalized) self._metadata_cache = None - def get(self, name: str) -> Optional[T]: - """ - Get a registered instance by name. - - Args: - name: The registry name of the instance. - - Returns: - The instance, or None if not found. - """ - entry = self._registry_items.get(name) - if entry is None: - return None - return entry.instance - - def get_entry(self, name: str) -> Optional[RegistryEntry[T]]: - """ - Get a full registry entry by name, including tags. - - Args: - name: The registry name of the entry. - - Returns: - The RegistryEntry, or None if not found. - """ - return self._registry_items.get(name) - def get_names(self) -> list[str]: """ Get a sorted list of all registered names. @@ -170,15 +154,6 @@ def get_names(self) -> list[str]: """ return sorted(self._registry_items.keys()) - def get_all_instances(self) -> list[RegistryEntry[T]]: - """ - Get all registered entries sorted by name. - - Returns: - List of RegistryEntry objects sorted by name. - """ - return [self._registry_items[name] for name in sorted(self._registry_items.keys())] - def get_by_tag( self, *, @@ -278,7 +253,7 @@ def list_metadata( exclude_filters: Optional[dict[str, object]] = None, ) -> list[ComponentIdentifier]: """ - List metadata for all registered instances, optionally filtered. + List metadata for all registered items, optionally filtered. Supports filtering on any metadata property: - Simple types (str, int, bool): exact match @@ -293,7 +268,7 @@ def list_metadata( Any matching filter excludes the item. Returns: - List of metadata dictionaries describing each registered instance. + List of ComponentIdentifier metadata for each registered item. """ from pyrit.registry.base import _matches_filters @@ -315,14 +290,14 @@ def list_metadata( def _build_metadata(self, name: str, instance: T) -> ComponentIdentifier: """ - Build metadata for an instance via its ``Identifiable`` interface. + Build metadata for an item via its ``Identifiable`` interface. Args: - name: The registry name of the instance. - instance: The instance. + name: The registry name of the item. + instance: The item. Returns: - The instance's ComponentIdentifier. + The item's ComponentIdentifier. """ return instance.get_identifier() @@ -337,10 +312,10 @@ def __contains__(self, name: str) -> bool: def __len__(self) -> int: """ - Get the count of registered instances. + Get the count of registered items. Returns: - The number of registered instances. + The number of registered items. """ return len(self._registry_items) @@ -352,3 +327,55 @@ def __iter__(self) -> Iterator[str]: An iterator over sorted registered names. """ return iter(sorted(self._registry_items.keys())) + + +class BaseInstanceRegistry(BaseItemRegistry[T]): + """ + Base class for registries that store directly-retrievable instances. + + Extends ``BaseItemRegistry`` with ``get()``, ``get_entry()``, and + ``get_all_instances()`` for registries where callers retrieve the + stored objects directly (e.g., scorers, converters, targets). + + For registries that store factories or other non-retrievable items, + subclass ``BaseItemRegistry`` directly instead. + + Type Parameters: + T: The type of instances stored in the registry (must be Identifiable). + """ + + def get(self, name: str) -> Optional[T]: + """ + Get a registered instance by name. + + Args: + name: The registry name of the instance. + + Returns: + The instance, or None if not found. + """ + entry = self._registry_items.get(name) + if entry is None: + return None + return entry.instance + + def get_entry(self, name: str) -> Optional[RegistryEntry[T]]: + """ + Get a full registry entry by name, including tags. + + Args: + name: The registry name of the entry. + + Returns: + The RegistryEntry, or None if not found. + """ + return self._registry_items.get(name) + + def get_all_instances(self) -> list[RegistryEntry[T]]: + """ + Get all registered entries sorted by name. + + Returns: + List of RegistryEntry objects sorted by name. + """ + return [self._registry_items[name] for name in sorted(self._registry_items.keys())] diff --git a/tests/unit/registry/test_attack_technique_registry.py b/tests/unit/registry/test_attack_technique_registry.py index 809472c87d..4bf1827112 100644 --- a/tests/unit/registry/test_attack_technique_registry.py +++ b/tests/unit/registry/test_attack_technique_registry.py @@ -75,7 +75,7 @@ def test_register_technique_stores_factory(self): self.registry.register_technique(name="stub_attack", factory=factory) assert "stub_attack" in self.registry - assert self.registry.get("stub_attack") is factory + assert self.registry._registry_items["stub_attack"].instance is factory def test_register_technique_with_tags(self): factory = AttackTechniqueFactory(attack_class=_StubAttack) diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index f3b698041e..71a47fdc41 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -4,7 +4,11 @@ import pytest from pyrit.identifiers import ComponentIdentifier, Identifiable -from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry, RegistryEntry +from pyrit.registry.instance_registries.base_instance_registry import ( + BaseInstanceRegistry, + BaseItemRegistry, + RegistryEntry, +) class _TestItem(Identifiable): @@ -481,6 +485,44 @@ def test_iter_allows_for_loop(self): assert collected == ["name1", "name2"] +class _ItemOnlyRegistry(BaseItemRegistry["_TestItem"]): + """Concrete BaseItemRegistry subclass — should NOT have get/get_entry/get_all_instances.""" + + +class TestBaseItemRegistryDoesNotExposeInstanceMethods: + """Verify that BaseItemRegistry subclasses lack instance-retrieval methods.""" + + def test_item_registry_has_no_get(self): + """BaseItemRegistry subclasses should not have a get() method.""" + assert not hasattr(_ItemOnlyRegistry, "get") + + def test_item_registry_has_no_get_entry(self): + """BaseItemRegistry subclasses should not have a get_entry() method.""" + assert not hasattr(_ItemOnlyRegistry, "get_entry") + + def test_item_registry_has_no_get_all_instances(self): + """BaseItemRegistry subclasses should not have a get_all_instances() method.""" + assert not hasattr(_ItemOnlyRegistry, "get_all_instances") + + def test_instance_registry_has_get(self): + """BaseInstanceRegistry subclasses should have get().""" + assert hasattr(ConcreteTestRegistry, "get") + + def test_instance_registry_has_get_entry(self): + """BaseInstanceRegistry subclasses should have get_entry().""" + assert hasattr(ConcreteTestRegistry, "get_entry") + + def test_instance_registry_has_get_all_instances(self): + """BaseInstanceRegistry subclasses should have get_all_instances().""" + assert hasattr(ConcreteTestRegistry, "get_all_instances") + + def test_item_registry_shares_common_methods(self): + """BaseItemRegistry subclasses should have shared registry methods.""" + for method in ("register", "get_names", "get_by_tag", "add_tags", "list_metadata", + "find_dependents_of_tag", "get_registry_singleton", "reset_instance"): + assert hasattr(_ItemOnlyRegistry, method), f"Missing method: {method}" + + class TestBaseInstanceRegistryAddTags: """Tests for add_tags functionality in BaseInstanceRegistry.""" @@ -555,7 +597,7 @@ def _build_identifier(self) -> ComponentIdentifier: return self._stored_identifier -class IdentifierTestRegistry(BaseInstanceRegistry["_IdentifiableStub"]): +class IdentifierTestRegistry(BaseItemRegistry["_IdentifiableStub"]): """Registry for testing dependency-related functionality with ComponentIdentifier trees.""" From 6e977bac196a97df626de4b5499fe190d3f3a5e7 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 14 Apr 2026 16:38:03 -0700 Subject: [PATCH 07/10] Extract BaseItemRegistry into its own base_item_registry.py file Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../registry/instance_registries/__init__.py | 2 + .../attack_technique_registry.py | 6 +- .../base_instance_registry.py | 335 +----------------- .../instance_registries/base_item_registry.py | 325 +++++++++++++++++ .../scenario/core/attack_technique_factory.py | 9 +- .../registry/test_base_instance_registry.py | 14 +- 6 files changed, 360 insertions(+), 331 deletions(-) create mode 100644 pyrit/registry/instance_registries/base_item_registry.py diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/instance_registries/__init__.py index 2055b0a5a3..c9fe2e3642 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/instance_registries/__init__.py @@ -16,6 +16,8 @@ ) from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, +) +from pyrit.registry.instance_registries.base_item_registry import ( BaseItemRegistry, RegistryEntry, ) diff --git a/pyrit/registry/instance_registries/attack_technique_registry.py b/pyrit/registry/instance_registries/attack_technique_registry.py index 611906becc..c1476895b0 100644 --- a/pyrit/registry/instance_registries/attack_technique_registry.py +++ b/pyrit/registry/instance_registries/attack_technique_registry.py @@ -12,9 +12,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING -from pyrit.registry.instance_registries.base_instance_registry import ( +from pyrit.registry.instance_registries.base_item_registry import ( BaseItemRegistry, ) @@ -45,7 +45,7 @@ def register_technique( *, name: str, factory: AttackTechniqueFactory, - tags: Optional[Union[dict[str, str], list[str]]] = None, + tags: dict[str, str] | list[str] | None = None, ) -> None: """ Register an attack technique factory. diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index e71bc6bfb4..11e501d6f3 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -2,331 +2,28 @@ # Licensed under the MIT license. """ -Base item and instance registries for PyRIT. +Instance registry for PyRIT. -This module provides two base classes for registries that store -``Identifiable`` objects (not classes): - -- ``BaseItemRegistry``: Shared infrastructure — singleton lifecycle, - registration, tags, metadata, container protocol. Use this directly - for registries that store factories or other non-retrievable items - (e.g., ``AttackTechniqueRegistry``). - -- ``BaseInstanceRegistry(BaseItemRegistry)``: Adds ``get()``, - ``get_entry()``, and ``get_all_instances()`` for registries where - callers retrieve stored objects directly (e.g., ``ScorerRegistry``, - ``ConverterRegistry``, ``TargetRegistry``). +This module provides ``BaseInstanceRegistry``, which extends +``BaseItemRegistry`` with ``get()``, ``get_entry()``, and +``get_all_instances()`` for registries where callers retrieve stored +objects directly (e.g., ``ScorerRegistry``, ``ConverterRegistry``, +``TargetRegistry``). +For the shared base class, see ``base_item_registry``. For registries that store classes (Type[T]), see ``class_registries/``. """ from __future__ import annotations -from abc import ABC -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union - -from pyrit.identifiers import ComponentIdentifier, Identifiable -from pyrit.registry.base import RegistryProtocol - -if TYPE_CHECKING: - from collections.abc import Iterator - from typing import Self - -T = TypeVar("T", bound=Identifiable) # The type of items stored - - -@dataclass -class RegistryEntry(Generic[T]): - """ - A wrapper around a registered item, holding its name, tags, and the item itself. - - Tags are always stored as ``dict[str, str]``. When callers pass a plain - ``list[str]``, each string is normalized to a key with an empty-string value. - - Attributes: - name: The registry name for this entry. - instance: The registered object. - tags: Key-value tags for categorization and filtering. - """ - - name: str - instance: T - tags: dict[str, str] = field(default_factory=dict) - - -class BaseItemRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): - """ - Abstract base class providing shared registry infrastructure. - - Provides singleton lifecycle, registration, tag-based lookup, metadata - filtering, and the standard container protocol (``__contains__``, - ``__len__``, ``__iter__``). - - Subclass directly when stored items should not be retrievable via - ``get()`` (e.g., factory registries). For registries that expose - direct item retrieval, subclass ``BaseInstanceRegistry`` instead. - - All stored items must implement ``Identifiable``, which provides - ``get_identifier()`` for metadata generation. - - Type Parameters: - T: The type of items stored in the registry (must be Identifiable). - """ - - # Class-level singleton instances, keyed by registry class - _instances: dict[type, BaseItemRegistry[Any]] = {} - - @classmethod - def get_registry_singleton(cls) -> Self: - """ - Get the singleton instance of this registry. - - Creates the instance on first call with default parameters. - - Returns: - The singleton instance of this registry class. - """ - if cls not in cls._instances: - cls._instances[cls] = cls() - return cls._instances[cls] # type: ignore[return-value] - - @classmethod - def reset_instance(cls) -> None: - """ - Reset the singleton instance. - - Useful for testing or reinitializing the registry. - """ - if cls in cls._instances: - del cls._instances[cls] - - @staticmethod - def _normalize_tags(tags: Optional[Union[dict[str, str], list[str]]] = None) -> dict[str, str]: - """ - Normalize tags into a ``dict[str, str]``. - - Args: - tags: Tags as a dict, a list of string keys (values default to ``""``), - or ``None`` (returns empty dict). - - Returns: - A ``dict[str, str]`` of normalised tags. - """ - if tags is None: - return {} - if isinstance(tags, list): - return dict.fromkeys(tags, "") - return dict(tags) - - def __init__(self) -> None: - """Initialize the registry.""" - # Maps registry names to registry entries - self._registry_items: dict[str, RegistryEntry[T]] = {} - self._metadata_cache: Optional[list[ComponentIdentifier]] = None - - def register( - self, - instance: T, - *, - name: str, - tags: Optional[Union[dict[str, str], list[str]]] = None, - ) -> None: - """ - Register an item. - - Args: - instance: The item to register. - name: The registry name for this item. - tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` - or a ``list[str]`` (each string becomes a key with value ``""``). - """ - normalized = self._normalize_tags(tags) - self._registry_items[name] = RegistryEntry(name=name, instance=instance, tags=normalized) - self._metadata_cache = None - - def get_names(self) -> list[str]: - """ - Get a sorted list of all registered names. - - Returns: - Sorted list of registry names (keys). - """ - return sorted(self._registry_items.keys()) - - def get_by_tag( - self, - *, - tag: str, - value: Optional[str] = None, - ) -> list[RegistryEntry[T]]: - """ - Get all entries that have a given tag, optionally matching a specific value. - - Args: - tag: The tag key to match. - value: If provided, only entries whose tag value equals this are returned. - If ``None``, any entry that has the tag key is returned regardless of value. - - Returns: - List of matching RegistryEntry objects sorted by name. - """ - results: list[RegistryEntry[T]] = [] - for name in sorted(self._registry_items.keys()): - entry = self._registry_items[name] - if tag in entry.tags and (value is None or entry.tags[tag] == value): - results.append(entry) - return results - - def add_tags( - self, - *, - name: str, - tags: Union[dict[str, str], list[str]], - ) -> None: - """ - Add tags to an existing registry entry. - - Args: - name: The registry name of the entry to tag. - tags: Tags to add. Accepts a ``dict[str, str]`` - or a ``list[str]`` (each string becomes a key with value ``""``). - - Raises: - KeyError: If no entry with the given name exists. - """ - entry = self._registry_items.get(name) - if entry is None: - raise KeyError(f"No entry named '{name}' in registry.") - entry.tags.update(self._normalize_tags(tags)) - self._metadata_cache = None - - def find_dependents_of_tag(self, *, tag: str) -> list[RegistryEntry[T]]: - """ - Find entries whose children depend on entries with the given tag. - - Scans each registry entry's ``ComponentIdentifier`` tree and checks - whether any child's ``eval_hash`` matches the ``eval_hash`` of an - entry that carries *tag*. Entries that themselves carry *tag* are - excluded from the results. - - This enables automatic dependency detection: for example, tagging - base refusal scorers with ``"refusal"`` lets you discover all - wrapper scorers (inverters, composites) that embed a refusal scorer - without any explicit ``depends_on`` declaration. - - Args: - tag: The tag key that identifies the "base" entries. - - Returns: - List of ``RegistryEntry`` objects that depend on tagged entries, - sorted by name. - """ - # Collect eval_hashes of all tagged entries - tagged_hashes: set[str] = set() - tagged_names: set[str] = set() - for entry in self.get_by_tag(tag=tag): - tagged_names.add(entry.name) - identifier = self._build_metadata(entry.name, entry.instance) - if identifier.eval_hash: - tagged_hashes.add(identifier.eval_hash) - - if not tagged_hashes: - return [] - - # Find non-tagged entries whose children reference a tagged eval_hash - dependents: list[RegistryEntry[T]] = [] - for name in sorted(self._registry_items.keys()): - if name in tagged_names: - continue - entry = self._registry_items[name] - identifier = self._build_metadata(name, entry.instance) - child_hashes = identifier._collect_child_eval_hashes() - if child_hashes & tagged_hashes: - dependents.append(entry) - return dependents - - def list_metadata( - self, - *, - include_filters: Optional[dict[str, object]] = None, - exclude_filters: Optional[dict[str, object]] = None, - ) -> list[ComponentIdentifier]: - """ - List metadata for all registered items, optionally filtered. - - Supports filtering on any metadata property: - - Simple types (str, int, bool): exact match - - List types: checks if filter value is in the list - - Args: - include_filters: Optional dict of filters that items must match. - Keys are metadata property names, values are the filter criteria. - All filters must match (AND logic). - exclude_filters: Optional dict of filters that items must NOT match. - Keys are metadata property names, values are the filter criteria. - Any matching filter excludes the item. - - Returns: - List of ComponentIdentifier metadata for each registered item. - """ - from pyrit.registry.base import _matches_filters - - if self._metadata_cache is None: - items = [] - for name in sorted(self._registry_items.keys()): - entry = self._registry_items[name] - items.append(self._build_metadata(name, entry.instance)) - self._metadata_cache = items - - if not include_filters and not exclude_filters: - return self._metadata_cache - - return [ - m - for m in self._metadata_cache - if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) - ] - - def _build_metadata(self, name: str, instance: T) -> ComponentIdentifier: - """ - Build metadata for an item via its ``Identifiable`` interface. - - Args: - name: The registry name of the item. - instance: The item. +from pyrit.registry.instance_registries.base_item_registry import ( + BaseItemRegistry, + RegistryEntry, + T, +) - Returns: - The item's ComponentIdentifier. - """ - return instance.get_identifier() - - def __contains__(self, name: str) -> bool: - """ - Check if a name is registered. - - Returns: - True if the name is registered, False otherwise. - """ - return name in self._registry_items - - def __len__(self) -> int: - """ - Get the count of registered items. - - Returns: - The number of registered items. - """ - return len(self._registry_items) - - def __iter__(self) -> Iterator[str]: - """ - Iterate over registered names. - - Returns: - An iterator over sorted registered names. - """ - return iter(sorted(self._registry_items.keys())) +# Re-export so existing ``from base_instance_registry import ...`` still works +__all__ = ["BaseInstanceRegistry", "BaseItemRegistry", "RegistryEntry"] class BaseInstanceRegistry(BaseItemRegistry[T]): @@ -344,7 +41,7 @@ class BaseInstanceRegistry(BaseItemRegistry[T]): T: The type of instances stored in the registry (must be Identifiable). """ - def get(self, name: str) -> Optional[T]: + def get(self, name: str) -> T | None: """ Get a registered instance by name. @@ -359,7 +56,7 @@ def get(self, name: str) -> Optional[T]: return None return entry.instance - def get_entry(self, name: str) -> Optional[RegistryEntry[T]]: + def get_entry(self, name: str) -> RegistryEntry[T] | None: """ Get a full registry entry by name, including tags. diff --git a/pyrit/registry/instance_registries/base_item_registry.py b/pyrit/registry/instance_registries/base_item_registry.py new file mode 100644 index 0000000000..db5da489b3 --- /dev/null +++ b/pyrit/registry/instance_registries/base_item_registry.py @@ -0,0 +1,325 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Base item registry for PyRIT. + +This module provides ``BaseItemRegistry``, the shared infrastructure for +registries that store ``Identifiable`` objects (not classes): singleton +lifecycle, registration, tags, metadata, container protocol. + +Subclass directly for registries that store factories or other +non-retrievable items (e.g., ``AttackTechniqueRegistry``). For registries +where callers retrieve stored objects directly, subclass +``BaseInstanceRegistry`` instead. + +For registries that store classes (Type[T]), see ``class_registries/``. +""" + +from __future__ import annotations + +from abc import ABC +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from pyrit.identifiers import ComponentIdentifier, Identifiable +from pyrit.registry.base import RegistryProtocol + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Self + +T = TypeVar("T", bound=Identifiable) # The type of items stored + + +@dataclass +class RegistryEntry(Generic[T]): + """ + A wrapper around a registered item, holding its name, tags, and the item itself. + + Tags are always stored as ``dict[str, str]``. When callers pass a plain + ``list[str]``, each string is normalized to a key with an empty-string value. + + Attributes: + name: The registry name for this entry. + instance: The registered object. + tags: Key-value tags for categorization and filtering. + """ + + name: str + instance: T + tags: dict[str, str] = field(default_factory=dict) + + +class BaseItemRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): + """ + Abstract base class providing shared registry infrastructure. + + Provides singleton lifecycle, registration, tag-based lookup, metadata + filtering, and the standard container protocol (``__contains__``, + ``__len__``, ``__iter__``). + + Subclass directly when stored items should not be retrievable via + ``get()`` (e.g., factory registries). For registries that expose + direct item retrieval, subclass ``BaseInstanceRegistry`` instead. + + All stored items must implement ``Identifiable``, which provides + ``get_identifier()`` for metadata generation. + + Type Parameters: + T: The type of items stored in the registry (must be Identifiable). + """ + + # Class-level singleton instances, keyed by registry class + _instances: dict[type, BaseItemRegistry[Any]] = {} + + @classmethod + def get_registry_singleton(cls) -> Self: + """ + Get the singleton instance of this registry. + + Creates the instance on first call with default parameters. + + Returns: + The singleton instance of this registry class. + """ + if cls not in cls._instances: + cls._instances[cls] = cls() + return cls._instances[cls] # type: ignore[return-value] + + @classmethod + def reset_instance(cls) -> None: + """ + Reset the singleton instance. + + Useful for testing or reinitializing the registry. + """ + if cls in cls._instances: + del cls._instances[cls] + + @staticmethod + def _normalize_tags(tags: dict[str, str] | list[str] | None = None) -> dict[str, str]: + """ + Normalize tags into a ``dict[str, str]``. + + Args: + tags: Tags as a dict, a list of string keys (values default to ``""``), + or ``None`` (returns empty dict). + + Returns: + A ``dict[str, str]`` of normalised tags. + """ + if tags is None: + return {} + if isinstance(tags, list): + return dict.fromkeys(tags, "") + return dict(tags) + + def __init__(self) -> None: + """Initialize the registry.""" + # Maps registry names to registry entries + self._registry_items: dict[str, RegistryEntry[T]] = {} + self._metadata_cache: list[ComponentIdentifier] | None = None + + def register( + self, + instance: T, + *, + name: str, + tags: dict[str, str] | list[str] | None = None, + ) -> None: + """ + Register an item. + + Args: + instance: The item to register. + name: The registry name for this item. + tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` + or a ``list[str]`` (each string becomes a key with value ``""``). + """ + normalized = self._normalize_tags(tags) + self._registry_items[name] = RegistryEntry(name=name, instance=instance, tags=normalized) + self._metadata_cache = None + + def get_names(self) -> list[str]: + """ + Get a sorted list of all registered names. + + Returns: + Sorted list of registry names (keys). + """ + return sorted(self._registry_items.keys()) + + def get_by_tag( + self, + *, + tag: str, + value: str | None = None, + ) -> list[RegistryEntry[T]]: + """ + Get all entries that have a given tag, optionally matching a specific value. + + Args: + tag: The tag key to match. + value: If provided, only entries whose tag value equals this are returned. + If ``None``, any entry that has the tag key is returned regardless of value. + + Returns: + List of matching RegistryEntry objects sorted by name. + """ + results: list[RegistryEntry[T]] = [] + for name in sorted(self._registry_items.keys()): + entry = self._registry_items[name] + if tag in entry.tags and (value is None or entry.tags[tag] == value): + results.append(entry) + return results + + def add_tags( + self, + *, + name: str, + tags: dict[str, str] | list[str], + ) -> None: + """ + Add tags to an existing registry entry. + + Args: + name: The registry name of the entry to tag. + tags: Tags to add. Accepts a ``dict[str, str]`` + or a ``list[str]`` (each string becomes a key with value ``""``). + + Raises: + KeyError: If no entry with the given name exists. + """ + entry = self._registry_items.get(name) + if entry is None: + raise KeyError(f"No entry named '{name}' in registry.") + entry.tags.update(self._normalize_tags(tags)) + self._metadata_cache = None + + def find_dependents_of_tag(self, *, tag: str) -> list[RegistryEntry[T]]: + """ + Find entries whose children depend on entries with the given tag. + + Scans each registry entry's ``ComponentIdentifier`` tree and checks + whether any child's ``eval_hash`` matches the ``eval_hash`` of an + entry that carries *tag*. Entries that themselves carry *tag* are + excluded from the results. + + This enables automatic dependency detection: for example, tagging + base refusal scorers with ``"refusal"`` lets you discover all + wrapper scorers (inverters, composites) that embed a refusal scorer + without any explicit ``depends_on`` declaration. + + Args: + tag: The tag key that identifies the "base" entries. + + Returns: + List of ``RegistryEntry`` objects that depend on tagged entries, + sorted by name. + """ + # Collect eval_hashes of all tagged entries + tagged_hashes: set[str] = set() + tagged_names: set[str] = set() + for entry in self.get_by_tag(tag=tag): + tagged_names.add(entry.name) + identifier = self._build_metadata(entry.name, entry.instance) + if identifier.eval_hash: + tagged_hashes.add(identifier.eval_hash) + + if not tagged_hashes: + return [] + + # Find non-tagged entries whose children reference a tagged eval_hash + dependents: list[RegistryEntry[T]] = [] + for name in sorted(self._registry_items.keys()): + if name in tagged_names: + continue + entry = self._registry_items[name] + identifier = self._build_metadata(name, entry.instance) + child_hashes = identifier._collect_child_eval_hashes() + if child_hashes & tagged_hashes: + dependents.append(entry) + return dependents + + def list_metadata( + self, + *, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, + ) -> list[ComponentIdentifier]: + """ + List metadata for all registered items, optionally filtered. + + Supports filtering on any metadata property: + - Simple types (str, int, bool): exact match + - List types: checks if filter value is in the list + + Args: + include_filters: Optional dict of filters that items must match. + Keys are metadata property names, values are the filter criteria. + All filters must match (AND logic). + exclude_filters: Optional dict of filters that items must NOT match. + Keys are metadata property names, values are the filter criteria. + Any matching filter excludes the item. + + Returns: + List of ComponentIdentifier metadata for each registered item. + """ + from pyrit.registry.base import _matches_filters + + if self._metadata_cache is None: + items = [] + for name in sorted(self._registry_items.keys()): + entry = self._registry_items[name] + items.append(self._build_metadata(name, entry.instance)) + self._metadata_cache = items + + if not include_filters and not exclude_filters: + return self._metadata_cache + + return [ + m + for m in self._metadata_cache + if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) + ] + + def _build_metadata(self, name: str, instance: T) -> ComponentIdentifier: + """ + Build metadata for an item via its ``Identifiable`` interface. + + Args: + name: The registry name of the item. + instance: The item. + + Returns: + The item's ComponentIdentifier. + """ + return instance.get_identifier() + + def __contains__(self, name: str) -> bool: + """ + Check if a name is registered. + + Returns: + True if the name is registered, False otherwise. + """ + return name in self._registry_items + + def __len__(self) -> int: + """ + Get the count of registered items. + + Returns: + The number of registered items. + """ + return len(self._registry_items) + + def __iter__(self) -> Iterator[str]: + """ + Iterate over registered names. + + Returns: + An iterator over sorted registered names. + """ + return iter(sorted(self._registry_items.keys())) diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index 40d7bec579..3b0900c527 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -83,18 +83,13 @@ def _validate_kwargs(self) -> None: ValueError: If ``objective_target`` is included in attack_kwargs. """ if "objective_target" in self._attack_kwargs: - raise ValueError( - "objective_target must not be in attack_kwargs — " - "it is provided at create() time." - ) + raise ValueError("objective_target must not be in attack_kwargs — it is provided at create() time.") sig = inspect.signature(self._attack_class.__init__) # Reject constructors that accept **kwargs — we require explicitly named # parameters so that validation is meaningful. - has_var_keyword = any( - param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values() - ) + has_var_keyword = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values()) if has_var_keyword: raise TypeError( f"{self._attack_class.__name__}.__init__ accepts **kwargs, which prevents " diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index 71a47fdc41..5a88d5057c 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -6,6 +6,8 @@ from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, +) +from pyrit.registry.instance_registries.base_item_registry import ( BaseItemRegistry, RegistryEntry, ) @@ -518,8 +520,16 @@ def test_instance_registry_has_get_all_instances(self): def test_item_registry_shares_common_methods(self): """BaseItemRegistry subclasses should have shared registry methods.""" - for method in ("register", "get_names", "get_by_tag", "add_tags", "list_metadata", - "find_dependents_of_tag", "get_registry_singleton", "reset_instance"): + for method in ( + "register", + "get_names", + "get_by_tag", + "add_tags", + "list_metadata", + "find_dependents_of_tag", + "get_registry_singleton", + "reset_instance", + ): assert hasattr(_ItemOnlyRegistry, method), f"Missing method: {method}" From 2aa793fffefac95953b553544adeacd41b75df65 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 14 Apr 2026 21:42:59 -0700 Subject: [PATCH 08/10] Fix mypy: add type parameters to AttackStrategy generic usage Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/core/attack_technique_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index 3b0900c527..fac94e4932 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -45,7 +45,7 @@ class AttackTechniqueFactory(Identifiable): def __init__( self, *, - attack_class: type[AttackStrategy], + attack_class: type[AttackStrategy[Any, Any]], attack_kwargs: dict[str, Any] | None = None, seed_technique: SeedAttackTechniqueGroup | None = None, ) -> None: @@ -115,7 +115,7 @@ def _validate_kwargs(self) -> None: ) @property - def attack_class(self) -> type[AttackStrategy]: + def attack_class(self) -> type[AttackStrategy[Any, Any]]: """The attack strategy class this factory produces.""" return self._attack_class From ae0cd0d6a4a43215d3e289744df6ffeac8808bc5 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 15 Apr 2026 13:21:47 -0700 Subject: [PATCH 09/10] =?UTF-8?q?Rename=20BaseItemRegistry=20=E2=86=92=20B?= =?UTF-8?q?aseInstanceRegistry,=20BaseInstanceRegistry=20=E2=86=92=20Retri?= =?UTF-8?q?evableInstanceRegistry?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improve naming clarity for the registry hierarchy: - BaseItemRegistry → BaseInstanceRegistry (shared base for all object registries) - BaseInstanceRegistry → RetrievableInstanceRegistry (adds get/get_entry/get_all_instances) - instance_registries/ → object_registries/ (parallel with class_registries/) - base_item_registry.py → base_instance_registry.py - base_instance_registry.py → retrievable_instance_registry.py - Add ConverterRegistry to pyrit.registry __init__ exports Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/services/converter_service.py | 2 +- pyrit/backend/services/target_service.py | 2 +- pyrit/registry/__init__.py | 10 ++- pyrit/registry/base.py | 4 +- pyrit/registry/class_registries/__init__.py | 2 +- .../class_registries/base_class_registry.py | 2 +- .../__init__.py | 20 ++--- .../attack_technique_registry.py | 6 +- .../base_instance_registry.py} | 12 +-- .../converter_registry.py | 6 +- .../retrievable_instance_registry.py} | 22 ++--- .../scorer_registry.py | 6 +- .../target_registry.py | 6 +- tests/unit/backend/test_converter_service.py | 2 +- tests/unit/backend/test_target_service.py | 2 +- .../test_attack_technique_registry.py | 2 +- .../registry/test_base_instance_registry.py | 82 +++++++++---------- .../unit/registry/test_converter_registry.py | 4 +- tests/unit/registry/test_scorer_registry.py | 4 +- tests/unit/registry/test_target_registry.py | 2 +- 20 files changed, 100 insertions(+), 98 deletions(-) rename pyrit/registry/{instance_registries => object_registries}/__init__.py (59%) rename pyrit/registry/{instance_registries => object_registries}/attack_technique_registry.py (94%) rename pyrit/registry/{instance_registries/base_item_registry.py => object_registries/base_instance_registry.py} (96%) rename pyrit/registry/{instance_registries => object_registries}/converter_registry.py (91%) rename pyrit/registry/{instance_registries/base_instance_registry.py => object_registries/retrievable_instance_registry.py} (72%) rename pyrit/registry/{instance_registries => object_registries}/scorer_registry.py (92%) rename pyrit/registry/{instance_registries => object_registries}/target_registry.py (92%) diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 5af8566be9..faf7e1d73f 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -40,7 +40,7 @@ from pyrit.models.data_type_serializer import data_serializer_factory from pyrit.prompt_converter import PromptConverter from pyrit.prompt_target import PromptChatTarget -from pyrit.registry.instance_registries import ConverterRegistry +from pyrit.registry.object_registries import ConverterRegistry _DATA_TYPE_EXTENSION: dict[str, str] = { "image_path": ".png", diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 7959a80408..26d66c8fa1 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -24,7 +24,7 @@ TargetListResponse, ) from pyrit.prompt_target import PromptTarget -from pyrit.registry.instance_registries import TargetRegistry +from pyrit.registry.object_registries import TargetRegistry def _build_target_class_registry() -> dict[str, type]: diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index d32ae963fa..c595c0d08c 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Registry module for PyRIT class and instance registries.""" +"""Registry module for PyRIT class and object registries.""" from pyrit.registry.base import RegistryProtocol from pyrit.registry.class_registries import ( @@ -17,10 +17,11 @@ discover_in_package, discover_subclasses_in_loaded_modules, ) -from pyrit.registry.instance_registries import ( +from pyrit.registry.object_registries import ( AttackTechniqueRegistry, BaseInstanceRegistry, - BaseItemRegistry, + ConverterRegistry, + RetrievableInstanceRegistry, RegistryEntry, ScorerRegistry, TargetRegistry, @@ -30,7 +31,8 @@ "AttackTechniqueRegistry", "BaseClassRegistry", "BaseInstanceRegistry", - "BaseItemRegistry", + "ConverterRegistry", + "RetrievableInstanceRegistry", "ClassEntry", "discover_in_directory", "discover_in_package", diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index 1bf8a4a298..766c554fc6 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -5,7 +5,7 @@ Shared base types for PyRIT registries. This module contains types shared between class registries (which store Type[T]) -and instance registries (which store T instances). +and object registries (which store T instances). """ from __future__ import annotations @@ -48,7 +48,7 @@ class RegistryProtocol(Protocol[MetadataT]): """ Protocol defining the common interface for all registries. - Both class registries (BaseClassRegistry) and instance registries + Both class registries (BaseClassRegistry) and object registries (BaseInstanceRegistry) implement this interface, enabling code that works with either registry type. diff --git a/pyrit/registry/class_registries/__init__.py b/pyrit/registry/class_registries/__init__.py index b29a6a6eed..c003ccc59a 100644 --- a/pyrit/registry/class_registries/__init__.py +++ b/pyrit/registry/class_registries/__init__.py @@ -7,7 +7,7 @@ This package contains registries that store classes (Type[T]) which can be instantiated on demand. Examples include ScenarioRegistry and InitializerRegistry. -For registries that store pre-configured instances, see instance_registries/. +For registries that store pre-configured instances, see object_registries/. """ from pyrit.registry.class_registries.base_class_registry import ( diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index 9439c24332..6b44c6e832 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -7,7 +7,7 @@ This module provides the abstract base class for registries that store classes (Type[T]). These registries allow on-demand instantiation of registered classes. -For registries that store pre-configured instances, see instance_registries/. +For registries that store pre-configured instances, see object_registries/. Terminology: - **Metadata**: A TypedDict describing a registered class (e.g., ScenarioMetadata) diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/object_registries/__init__.py similarity index 59% rename from pyrit/registry/instance_registries/__init__.py rename to pyrit/registry/object_registries/__init__.py index c9fe2e3642..0a43a5af2f 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/object_registries/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. """ -Instance registries package. +Object registries package. This package contains registries that store pre-configured instances (not classes). Examples include ScorerRegistry which stores Scorer instances that have been @@ -11,30 +11,30 @@ For registries that store classes (Type[T]), see class_registries/. """ -from pyrit.registry.instance_registries.attack_technique_registry import ( +from pyrit.registry.object_registries.attack_technique_registry import ( AttackTechniqueRegistry, ) -from pyrit.registry.instance_registries.base_instance_registry import ( +from pyrit.registry.object_registries.base_instance_registry import ( BaseInstanceRegistry, -) -from pyrit.registry.instance_registries.base_item_registry import ( - BaseItemRegistry, RegistryEntry, ) -from pyrit.registry.instance_registries.converter_registry import ( +from pyrit.registry.object_registries.converter_registry import ( ConverterRegistry, ) -from pyrit.registry.instance_registries.scorer_registry import ( +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, +) +from pyrit.registry.object_registries.scorer_registry import ( ScorerRegistry, ) -from pyrit.registry.instance_registries.target_registry import ( +from pyrit.registry.object_registries.target_registry import ( TargetRegistry, ) __all__ = [ # Base classes "BaseInstanceRegistry", - "BaseItemRegistry", + "RetrievableInstanceRegistry", "RegistryEntry", # Concrete registries "AttackTechniqueRegistry", diff --git a/pyrit/registry/instance_registries/attack_technique_registry.py b/pyrit/registry/object_registries/attack_technique_registry.py similarity index 94% rename from pyrit/registry/instance_registries/attack_technique_registry.py rename to pyrit/registry/object_registries/attack_technique_registry.py index c1476895b0..2b68ffd651 100644 --- a/pyrit/registry/instance_registries/attack_technique_registry.py +++ b/pyrit/registry/object_registries/attack_technique_registry.py @@ -14,8 +14,8 @@ import logging from typing import TYPE_CHECKING -from pyrit.registry.instance_registries.base_item_registry import ( - BaseItemRegistry, +from pyrit.registry.object_registries.base_instance_registry import ( + BaseInstanceRegistry, ) if TYPE_CHECKING: @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -class AttackTechniqueRegistry(BaseItemRegistry["AttackTechniqueFactory"]): +class AttackTechniqueRegistry(BaseInstanceRegistry["AttackTechniqueFactory"]): """ Singleton registry of reusable attack technique factories. diff --git a/pyrit/registry/instance_registries/base_item_registry.py b/pyrit/registry/object_registries/base_instance_registry.py similarity index 96% rename from pyrit/registry/instance_registries/base_item_registry.py rename to pyrit/registry/object_registries/base_instance_registry.py index db5da489b3..1d60417b9b 100644 --- a/pyrit/registry/instance_registries/base_item_registry.py +++ b/pyrit/registry/object_registries/base_instance_registry.py @@ -2,16 +2,16 @@ # Licensed under the MIT license. """ -Base item registry for PyRIT. +Base instance registry for PyRIT. -This module provides ``BaseItemRegistry``, the shared infrastructure for +This module provides ``BaseInstanceRegistry``, the shared infrastructure for registries that store ``Identifiable`` objects (not classes): singleton lifecycle, registration, tags, metadata, container protocol. Subclass directly for registries that store factories or other non-retrievable items (e.g., ``AttackTechniqueRegistry``). For registries where callers retrieve stored objects directly, subclass -``BaseInstanceRegistry`` instead. +``RetrievableInstanceRegistry`` instead. For registries that store classes (Type[T]), see ``class_registries/``. """ @@ -51,7 +51,7 @@ class RegistryEntry(Generic[T]): tags: dict[str, str] = field(default_factory=dict) -class BaseItemRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): +class BaseInstanceRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): """ Abstract base class providing shared registry infrastructure. @@ -61,7 +61,7 @@ class BaseItemRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): Subclass directly when stored items should not be retrievable via ``get()`` (e.g., factory registries). For registries that expose - direct item retrieval, subclass ``BaseInstanceRegistry`` instead. + direct item retrieval, subclass ``RetrievableInstanceRegistry`` instead. All stored items must implement ``Identifiable``, which provides ``get_identifier()`` for metadata generation. @@ -71,7 +71,7 @@ class BaseItemRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): """ # Class-level singleton instances, keyed by registry class - _instances: dict[type, BaseItemRegistry[Any]] = {} + _instances: dict[type, BaseInstanceRegistry[Any]] = {} @classmethod def get_registry_singleton(cls) -> Self: diff --git a/pyrit/registry/instance_registries/converter_registry.py b/pyrit/registry/object_registries/converter_registry.py similarity index 91% rename from pyrit/registry/instance_registries/converter_registry.py rename to pyrit/registry/object_registries/converter_registry.py index 19f8d03108..4d83c9e1fd 100644 --- a/pyrit/registry/instance_registries/converter_registry.py +++ b/pyrit/registry/object_registries/converter_registry.py @@ -14,8 +14,8 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.registry.instance_registries.base_instance_registry import ( - BaseInstanceRegistry, +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, ) if TYPE_CHECKING: @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -class ConverterRegistry(BaseInstanceRegistry["PromptConverter"]): +class ConverterRegistry(RetrievableInstanceRegistry["PromptConverter"]): """ Registry for managing available converter instances. diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/object_registries/retrievable_instance_registry.py similarity index 72% rename from pyrit/registry/instance_registries/base_instance_registry.py rename to pyrit/registry/object_registries/retrievable_instance_registry.py index 11e501d6f3..b5bc4fdfec 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/object_registries/retrievable_instance_registry.py @@ -2,40 +2,40 @@ # Licensed under the MIT license. """ -Instance registry for PyRIT. +Retrievable instance registry for PyRIT. -This module provides ``BaseInstanceRegistry``, which extends -``BaseItemRegistry`` with ``get()``, ``get_entry()``, and +This module provides ``RetrievableInstanceRegistry``, which extends +``BaseInstanceRegistry`` with ``get()``, ``get_entry()``, and ``get_all_instances()`` for registries where callers retrieve stored objects directly (e.g., ``ScorerRegistry``, ``ConverterRegistry``, ``TargetRegistry``). -For the shared base class, see ``base_item_registry``. +For the shared base class, see ``base_instance_registry``. For registries that store classes (Type[T]), see ``class_registries/``. """ from __future__ import annotations -from pyrit.registry.instance_registries.base_item_registry import ( - BaseItemRegistry, +from pyrit.registry.object_registries.base_instance_registry import ( + BaseInstanceRegistry, RegistryEntry, T, ) -# Re-export so existing ``from base_instance_registry import ...`` still works -__all__ = ["BaseInstanceRegistry", "BaseItemRegistry", "RegistryEntry"] +# Re-export so existing ``from retrievable_instance_registry import ...`` still works +__all__ = ["RetrievableInstanceRegistry", "BaseInstanceRegistry", "RegistryEntry"] -class BaseInstanceRegistry(BaseItemRegistry[T]): +class RetrievableInstanceRegistry(BaseInstanceRegistry[T]): """ Base class for registries that store directly-retrievable instances. - Extends ``BaseItemRegistry`` with ``get()``, ``get_entry()``, and + Extends ``BaseInstanceRegistry`` with ``get()``, ``get_entry()``, and ``get_all_instances()`` for registries where callers retrieve the stored objects directly (e.g., scorers, converters, targets). For registries that store factories or other non-retrievable items, - subclass ``BaseItemRegistry`` directly instead. + subclass ``BaseInstanceRegistry`` directly instead. Type Parameters: T: The type of instances stored in the registry (must be Identifiable). diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/object_registries/scorer_registry.py similarity index 92% rename from pyrit/registry/instance_registries/scorer_registry.py rename to pyrit/registry/object_registries/scorer_registry.py index f645309508..af5c59946f 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/object_registries/scorer_registry.py @@ -12,8 +12,8 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.registry.instance_registries.base_instance_registry import ( - BaseInstanceRegistry, +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, ) if TYPE_CHECKING: @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -class ScorerRegistry(BaseInstanceRegistry["Scorer"]): +class ScorerRegistry(RetrievableInstanceRegistry["Scorer"]): """ Registry for managing available scorer instances. diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/object_registries/target_registry.py similarity index 92% rename from pyrit/registry/instance_registries/target_registry.py rename to pyrit/registry/object_registries/target_registry.py index 88ae19f49c..c6fefd3926 100644 --- a/pyrit/registry/instance_registries/target_registry.py +++ b/pyrit/registry/object_registries/target_registry.py @@ -12,8 +12,8 @@ import logging from typing import TYPE_CHECKING, Optional, Union -from pyrit.registry.instance_registries.base_instance_registry import ( - BaseInstanceRegistry, +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, ) if TYPE_CHECKING: @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -class TargetRegistry(BaseInstanceRegistry["PromptTarget"]): +class TargetRegistry(RetrievableInstanceRegistry["PromptTarget"]): """ Registry for managing available prompt target instances. diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 852ca1712f..f67bdcf3f0 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -23,7 +23,7 @@ SuffixAppendConverter, ) from pyrit.prompt_converter.prompt_converter import get_converter_modalities -from pyrit.registry.instance_registries import ConverterRegistry +from pyrit.registry.object_registries import ConverterRegistry @pytest.fixture(autouse=True) diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index 06b3c5713c..70f61acafa 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -13,7 +13,7 @@ from pyrit.backend.models.targets import CreateTargetRequest from pyrit.backend.services.target_service import TargetService, get_target_service from pyrit.identifiers import ComponentIdentifier -from pyrit.registry.instance_registries import TargetRegistry +from pyrit.registry.object_registries import TargetRegistry @pytest.fixture(autouse=True) diff --git a/tests/unit/registry/test_attack_technique_registry.py b/tests/unit/registry/test_attack_technique_registry.py index 4bf1827112..e0d7463b51 100644 --- a/tests/unit/registry/test_attack_technique_registry.py +++ b/tests/unit/registry/test_attack_technique_registry.py @@ -10,7 +10,7 @@ from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.identifiers import ComponentIdentifier from pyrit.prompt_target import PromptTarget -from pyrit.registry.instance_registries.attack_technique_registry import AttackTechniqueRegistry +from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index 5a88d5057c..6ddfafc30d 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -4,11 +4,11 @@ import pytest from pyrit.identifiers import ComponentIdentifier, Identifiable -from pyrit.registry.instance_registries.base_instance_registry import ( - BaseInstanceRegistry, +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, ) -from pyrit.registry.instance_registries.base_item_registry import ( - BaseItemRegistry, +from pyrit.registry.object_registries.base_instance_registry import ( + BaseInstanceRegistry, RegistryEntry, ) @@ -45,12 +45,12 @@ def _item(value: str) -> _TestItem: return _TestItem(value) -class ConcreteTestRegistry(BaseInstanceRegistry["_TestItem"]): - """Concrete implementation of BaseInstanceRegistry for testing.""" +class ConcreteTestRegistry(RetrievableInstanceRegistry["_TestItem"]): + """Concrete implementation of RetrievableInstanceRegistry for testing.""" -class TestBaseInstanceRegistrySingleton: - """Tests for the singleton pattern in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistrySingleton: + """Tests for the singleton pattern in RetrievableInstanceRegistry.""" def setup_method(self): """Reset the singleton before each test.""" @@ -82,8 +82,8 @@ def test_reset_instance_when_not_exists_does_not_raise(self): ConcreteTestRegistry.reset_instance() -class TestBaseInstanceRegistryRegistration: - """Tests for registration functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryRegistration: + """Tests for registration functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -134,8 +134,8 @@ def test_register_invalidates_metadata_cache(self): assert len(metadata2) == 2 -class TestBaseInstanceRegistryGet: - """Tests for get functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryGet: + """Tests for get functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -158,8 +158,8 @@ def test_get_nonexistent_returns_none(self): assert result is None -class TestBaseInstanceRegistryGetEntry: - """Tests for get_entry functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryGetEntry: + """Tests for get_entry functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -186,8 +186,8 @@ def test_get_entry_nonexistent_returns_none(self): assert result is None -class TestBaseInstanceRegistryGetNames: - """Tests for get_names functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryGetNames: + """Tests for get_names functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -213,8 +213,8 @@ def test_get_names_returns_sorted_list(self): assert names == ["alpha", "beta", "zeta"] -class TestBaseInstanceRegistryGetAllInstances: - """Tests for get_all_instances functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryGetAllInstances: + """Tests for get_all_instances functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -260,8 +260,8 @@ def test_get_all_instances_empty_registry(self): assert result == [] -class TestBaseInstanceRegistryListMetadata: - """Tests for list_metadata functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryListMetadata: + """Tests for list_metadata functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -331,8 +331,8 @@ def test_list_metadata_caching(self): assert len(metadata1) == 3 -class TestBaseInstanceRegistryTags: - """Tests for tag registration and retrieval in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryTags: + """Tests for tag registration and retrieval in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -429,11 +429,11 @@ def test_get_by_tag_with_list_tags_value_empty_string(self): def test_normalize_tags_none(self): """Test _normalize_tags returns empty dict for None.""" - assert BaseInstanceRegistry._normalize_tags(None) == {} + assert RetrievableInstanceRegistry._normalize_tags(None) == {} def test_normalize_tags_list(self): """Test _normalize_tags converts list to dict with empty values.""" - assert BaseInstanceRegistry._normalize_tags(["a", "b"]) == {"a": "", "b": ""} + assert RetrievableInstanceRegistry._normalize_tags(["a", "b"]) == {"a": "", "b": ""} def test_normalize_tags_dict(self): """Test _normalize_tags returns a copy of the dict.""" @@ -443,8 +443,8 @@ def test_normalize_tags_dict(self): assert result is not original -class TestBaseInstanceRegistryDunderMethods: - """Tests for dunder methods (__contains__, __len__, __iter__) in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryDunderMethods: + """Tests for dunder methods (__contains__, __len__, __iter__) in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -487,39 +487,39 @@ def test_iter_allows_for_loop(self): assert collected == ["name1", "name2"] -class _ItemOnlyRegistry(BaseItemRegistry["_TestItem"]): - """Concrete BaseItemRegistry subclass — should NOT have get/get_entry/get_all_instances.""" +class _ItemOnlyRegistry(BaseInstanceRegistry["_TestItem"]): + """Concrete BaseInstanceRegistry subclass — should NOT have get/get_entry/get_all_instances.""" -class TestBaseItemRegistryDoesNotExposeInstanceMethods: - """Verify that BaseItemRegistry subclasses lack instance-retrieval methods.""" +class TestBaseInstanceRegistryDoesNotExposeInstanceMethods: + """Verify that BaseInstanceRegistry subclasses lack instance-retrieval methods.""" def test_item_registry_has_no_get(self): - """BaseItemRegistry subclasses should not have a get() method.""" + """BaseInstanceRegistry subclasses should not have a get() method.""" assert not hasattr(_ItemOnlyRegistry, "get") def test_item_registry_has_no_get_entry(self): - """BaseItemRegistry subclasses should not have a get_entry() method.""" + """BaseInstanceRegistry subclasses should not have a get_entry() method.""" assert not hasattr(_ItemOnlyRegistry, "get_entry") def test_item_registry_has_no_get_all_instances(self): - """BaseItemRegistry subclasses should not have a get_all_instances() method.""" + """BaseInstanceRegistry subclasses should not have a get_all_instances() method.""" assert not hasattr(_ItemOnlyRegistry, "get_all_instances") def test_instance_registry_has_get(self): - """BaseInstanceRegistry subclasses should have get().""" + """RetrievableInstanceRegistry subclasses should have get().""" assert hasattr(ConcreteTestRegistry, "get") def test_instance_registry_has_get_entry(self): - """BaseInstanceRegistry subclasses should have get_entry().""" + """RetrievableInstanceRegistry subclasses should have get_entry().""" assert hasattr(ConcreteTestRegistry, "get_entry") def test_instance_registry_has_get_all_instances(self): - """BaseInstanceRegistry subclasses should have get_all_instances().""" + """RetrievableInstanceRegistry subclasses should have get_all_instances().""" assert hasattr(ConcreteTestRegistry, "get_all_instances") def test_item_registry_shares_common_methods(self): - """BaseItemRegistry subclasses should have shared registry methods.""" + """BaseInstanceRegistry subclasses should have shared registry methods.""" for method in ( "register", "get_names", @@ -533,8 +533,8 @@ def test_item_registry_shares_common_methods(self): assert hasattr(_ItemOnlyRegistry, method), f"Missing method: {method}" -class TestBaseInstanceRegistryAddTags: - """Tests for add_tags functionality in BaseInstanceRegistry.""" +class TestRetrievableInstanceRegistryAddTags: + """Tests for add_tags functionality in RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry for each test.""" @@ -607,12 +607,12 @@ def _build_identifier(self) -> ComponentIdentifier: return self._stored_identifier -class IdentifierTestRegistry(BaseItemRegistry["_IdentifiableStub"]): +class IdentifierTestRegistry(BaseInstanceRegistry["_IdentifiableStub"]): """Registry for testing dependency-related functionality with ComponentIdentifier trees.""" class TestFindDependentsOfTag: - """Tests for BaseInstanceRegistry.find_dependents_of_tag.""" + """Tests for RetrievableInstanceRegistry.find_dependents_of_tag.""" def setup_method(self) -> None: IdentifierTestRegistry.reset_instance() diff --git a/tests/unit/registry/test_converter_registry.py b/tests/unit/registry/test_converter_registry.py index b62fedb6da..871fb85964 100644 --- a/tests/unit/registry/test_converter_registry.py +++ b/tests/unit/registry/test_converter_registry.py @@ -4,7 +4,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import PromptDataType from pyrit.prompt_converter import ConverterResult, PromptConverter -from pyrit.registry.instance_registries.converter_registry import ConverterRegistry +from pyrit.registry.object_registries.converter_registry import ConverterRegistry class MockTextConverter(PromptConverter): @@ -300,7 +300,7 @@ def test_list_metadata_combined_include_and_exclude(self): class TestConverterRegistryInheritedMethods: - """Tests for inherited methods from BaseInstanceRegistry.""" + """Tests for inherited methods from RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry.""" diff --git a/tests/unit/registry/test_scorer_registry.py b/tests/unit/registry/test_scorer_registry.py index 9d521b75fc..078d530967 100644 --- a/tests/unit/registry/test_scorer_registry.py +++ b/tests/unit/registry/test_scorer_registry.py @@ -5,7 +5,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, MessagePiece, Score -from pyrit.registry.instance_registries.scorer_registry import ScorerRegistry +from pyrit.registry.object_registries.scorer_registry import ScorerRegistry from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer import Scorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -328,7 +328,7 @@ def test_list_metadata_combined_include_and_exclude(self): class TestScorerRegistryInheritedMethods: - """Tests for inherited methods from BaseInstanceRegistry.""" + """Tests for inherited methods from RetrievableInstanceRegistry.""" def setup_method(self): """Reset and get a fresh registry.""" diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py index 503d096a38..3f05624e9b 100644 --- a/tests/unit/registry/test_target_registry.py +++ b/tests/unit/registry/test_target_registry.py @@ -8,7 +8,7 @@ from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget -from pyrit.registry.instance_registries.target_registry import TargetRegistry +from pyrit.registry.object_registries.target_registry import TargetRegistry class MockPromptTarget(PromptTarget): From 2627fbd9b8fc2d0c3d1d8c5e68b1b0b75dc242dd Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 15 Apr 2026 14:21:32 -0700 Subject: [PATCH 10/10] pre-commit --- pyrit/registry/__init__.py | 2 +- tests/unit/registry/test_base_instance_registry.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index c595c0d08c..4f8290e993 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -21,8 +21,8 @@ AttackTechniqueRegistry, BaseInstanceRegistry, ConverterRegistry, - RetrievableInstanceRegistry, RegistryEntry, + RetrievableInstanceRegistry, ScorerRegistry, TargetRegistry, ) diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index 6ddfafc30d..a0b5a75913 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -4,13 +4,13 @@ import pytest from pyrit.identifiers import ComponentIdentifier, Identifiable -from pyrit.registry.object_registries.retrievable_instance_registry import ( - RetrievableInstanceRegistry, -) from pyrit.registry.object_registries.base_instance_registry import ( BaseInstanceRegistry, RegistryEntry, ) +from pyrit.registry.object_registries.retrievable_instance_registry import ( + RetrievableInstanceRegistry, +) class _TestItem(Identifiable):