diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index c37dd77fd9..499fdc391d 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -412,7 +412,6 @@ def request_piece_to_pyrit_message_piece( role: ChatMessageRole, conversation_id: str, sequence: int, - labels: Optional[dict[str, str]] = None, ) -> PyritMessagePiece: """ Convert a single request piece DTO to a PyRIT MessagePiece domain object. @@ -422,7 +421,6 @@ def request_piece_to_pyrit_message_piece( role: The message role. conversation_id: The conversation/attack ID. sequence: The message sequence number. - labels: Optional labels to attach to the piece. Returns: PyritMessagePiece domain object. @@ -442,7 +440,6 @@ def request_piece_to_pyrit_message_piece( conversation_id=conversation_id, sequence=sequence, prompt_metadata=metadata, - labels=labels or {}, original_prompt_id=original_prompt_id, ) @@ -452,7 +449,6 @@ def request_to_pyrit_message( request: AddMessageRequest, conversation_id: str, sequence: int, - labels: Optional[dict[str, str]] = None, ) -> PyritMessage: """ Build a PyRIT Message from an AddMessageRequest DTO. @@ -461,18 +457,13 @@ def request_to_pyrit_message( request: The inbound API request. conversation_id: The conversation/attack ID. sequence: The message sequence number. - labels: Optional labels to attach to each piece. Returns: PyritMessage ready to send to the target. """ pieces = [ request_piece_to_pyrit_message_piece( - piece=p, - role=request.role, - conversation_id=conversation_id, - sequence=sequence, - labels=labels, + piece=p, role=request.role, conversation_id=conversation_id, sequence=sequence ) for p in request.pieces ] diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 4105bb68aa..0c8438d4eb 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -286,7 +286,6 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt conversation_id = self._duplicate_conversation_up_to( source_conversation_id=request.source_conversation_id, cutoff_index=request.cutoff_index, - labels_override=labels, remap_assistant_to_simulated=True, ) else: @@ -319,7 +318,6 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt await self._store_prepended_messages( conversation_id=conversation_id, prepended=request.prepended_conversation, - labels=labels, ) return CreateAttackResponse( @@ -802,7 +800,6 @@ def _duplicate_conversation_up_to( *, source_conversation_id: str, cutoff_index: int, - labels_override: Optional[dict[str, str]] = None, remap_assistant_to_simulated: bool = False, ) -> str: """ @@ -815,9 +812,6 @@ def _duplicate_conversation_up_to( Args: source_conversation_id: The conversation to copy from. cutoff_index: Include messages with sequence <= cutoff_index. - labels_override: When provided, the duplicated pieces' labels are - replaced with these values. Used when branching into a new - attack that belongs to a different operator. remap_assistant_to_simulated: When True, pieces with role ``assistant`` are changed to ``simulated_assistant`` so the branched context is inert and won't confuse the target. @@ -832,8 +826,6 @@ def _duplicate_conversation_up_to( # Apply optional overrides to the fresh pieces before persisting for piece in all_pieces: - if labels_override is not None: - piece.labels = dict(labels_override) if remap_assistant_to_simulated and piece.api_role == "assistant": piece._role = "simulated_assistant" @@ -924,7 +916,6 @@ async def _store_prepended_messages( self, conversation_id: str, prepended: list[Any], - labels: Optional[dict[str, str]] = None, ) -> None: """Store prepended conversation messages in memory.""" for seq, msg in enumerate(prepended): @@ -934,7 +925,6 @@ async def _store_prepended_messages( role=msg.role, conversation_id=conversation_id, sequence=seq, - labels=labels, ) self._memory.add_message_pieces_to_memory(message_pieces=[piece]) @@ -960,7 +950,6 @@ async def _send_and_store_message_async( request=request, conversation_id=conversation_id, sequence=sequence, - labels=labels, ) converter_configs = self._get_converter_configs(request) @@ -971,7 +960,6 @@ async def _send_and_store_message_async( target=target_obj, conversation_id=conversation_id, request_converter_configurations=converter_configs, - labels=labels, ) # PromptNormalizer stores both request and response in memory automatically @@ -991,7 +979,6 @@ async def _store_message_only_async( role=request.role, conversation_id=conversation_id, sequence=sequence, - labels=labels, ) self._memory.add_message_pieces_to_memory(message_pieces=[piece]) diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 7a27cb5666..beb4faa6cf 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -110,7 +110,6 @@ def get_adversarial_chat_messages( conversation_id=adversarial_chat_conversation_id, attack_identifier=attack_identifier, prompt_target_identifier=adversarial_chat_target_identifier, - labels=labels, ) result.append(adversarial_piece.to_message()) @@ -260,7 +259,6 @@ def set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, attack_identifier=self._attack_identifier, - labels=labels, ) async def initialize_context_async( diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index ed95c5d226..1aae795d47 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -288,7 +288,6 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac conversation_id=context.session.conversation_id, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, - labels=context.memory_labels, attack_identifier=self.get_identifier(), ) diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index f137b322f3..6ab7db4618 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -311,7 +311,6 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None: system_prompt=system_prompt, conversation_id=context.session.adversarial_chat_conversation_id, attack_identifier=self.get_identifier(), - labels=context.memory_labels, ) # Initialize backtrack count in context @@ -534,7 +533,6 @@ async def _send_prompt_to_adversarial_chat_async( conversation_id=context.session.adversarial_chat_conversation_id, target=self._adversarial_chat, attack_identifier=self.get_identifier(), - labels=context.memory_labels, ) if not response: @@ -620,7 +618,6 @@ async def _send_prompt_to_objective_target_async( request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, attack_identifier=self.get_identifier(), - labels=context.memory_labels, ) if not response: diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index 8447737578..ee672bf353 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -366,7 +366,6 @@ async def _send_prompt_to_objective_target_async( conversation_id=context.session.conversation_id, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, - labels=context.memory_labels, # combined with strategy labels at _setup() attack_identifier=self.get_identifier(), ) diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 1feec20586..b330ad21d7 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -256,7 +256,6 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: system_prompt=adversarial_system_prompt, conversation_id=context.session.adversarial_chat_conversation_id, attack_identifier=self.get_identifier(), - labels=context.memory_labels, ) async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> AttackResult: @@ -379,7 +378,6 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any] conversation_id=context.session.adversarial_chat_conversation_id, target=self._adversarial_chat, attack_identifier=self.get_identifier(), - labels=context.memory_labels, ) # Check if the response is valid @@ -543,7 +541,6 @@ async def _send_prompt_to_objective_target_async( request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, target=self._objective_target, - labels=context.memory_labels, attack_identifier=self.get_identifier(), ) diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index f6ccc4ed64..a183df8f55 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -540,7 +540,6 @@ async def _send_prompt_to_target_async(self, prompt: str) -> Message: response_converter_configurations=self._response_converters, conversation_id=self.objective_target_conversation_id, target=self._objective_target, - labels=self._memory_labels, attack_identifier=self._attack_id, ) @@ -596,7 +595,6 @@ async def _send_initial_prompt_to_target_async(self) -> Message: response_converter_configurations=self._response_converters, conversation_id=self.objective_target_conversation_id, target=self._objective_target, - labels=self._memory_labels, attack_identifier=self._attack_id, ) @@ -982,7 +980,6 @@ async def _generate_first_turn_prompt_async(self, objective: str) -> str: system_prompt=system_prompt, conversation_id=self.adversarial_chat_conversation_id, attack_identifier=self._attack_id, - labels=self._memory_labels, ) logger.debug(f"Node {self.node_id}: Using initial seed prompt for first turn") @@ -1107,7 +1104,6 @@ async def _send_to_adversarial_chat_async(self, prompt_text: str) -> str: message=message, conversation_id=self.adversarial_chat_conversation_id, target=self._adversarial_chat, - labels=self._memory_labels, attack_identifier=self._attack_id, ) diff --git a/pyrit/executor/attack/single_turn/context_compliance.py b/pyrit/executor/attack/single_turn/context_compliance.py index d03ab2a41f..e03daee3db 100644 --- a/pyrit/executor/attack/single_turn/context_compliance.py +++ b/pyrit/executor/attack/single_turn/context_compliance.py @@ -235,7 +235,6 @@ async def _get_objective_as_benign_question_async( message=message, target=self._adversarial_chat, attack_identifier=self.get_identifier(), - labels=context.memory_labels, ) return response.get_value() @@ -262,7 +261,6 @@ async def _get_benign_question_answer_async( message=message, target=self._adversarial_chat, attack_identifier=self.get_identifier(), - labels=context.memory_labels, ) return response.get_value() @@ -287,7 +285,6 @@ async def _get_objective_as_question_async(self, *, objective: str, context: Sin message=message, target=self._adversarial_chat, attack_identifier=self.get_identifier(), - labels=context.memory_labels, ) return response.get_value() diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index cdb2d4b619..e19b78e107 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -325,7 +325,6 @@ async def _send_prompt_to_objective_target_async( conversation_id=context.conversation_id, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, - labels=context.memory_labels, # combined with strategy labels at _setup() attack_identifier=self.get_identifier(), ) diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 208c4040d7..5ab0abf9bd 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -213,7 +213,6 @@ async def _setup_async(self, *, context: AnecdoctorContext) -> None: system_prompt=system_prompt, conversation_id=context.conversation_id, attack_identifier=self.get_identifier(), - labels=context.memory_labels, ) async def _perform_async(self, *, context: AnecdoctorContext) -> AnecdoctorResult: @@ -305,7 +304,6 @@ async def _send_examples_to_target_async( conversation_id=context.conversation_id, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, - labels=context.memory_labels, attack_identifier=self.get_identifier(), ) @@ -374,7 +372,6 @@ async def _extract_knowledge_graph_async(self, *, context: AnecdoctorContext) -> system_prompt=kg_system_prompt, conversation_id=kg_conversation_id, attack_identifier=self.get_identifier(), - labels=self._memory_labels, ) # Format examples for knowledge graph extraction using few-shot format @@ -390,7 +387,6 @@ async def _extract_knowledge_graph_async(self, *, context: AnecdoctorContext) -> conversation_id=kg_conversation_id, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, - labels=self._memory_labels, attack_identifier=self.get_identifier(), ) diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index 2dc021b497..e8d1cc8a99 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -334,7 +334,6 @@ async def _setup_attack_async(self, *, context: XPIAContext) -> str: request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, target=self._attack_setup_target, - labels=context.memory_labels, attack_identifier=self.get_identifier(), conversation_id=context.attack_setup_target_conversation_id, ) @@ -566,7 +565,6 @@ async def process_async() -> str: target=self._processing_target, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, - labels=context.memory_labels, attack_identifier=self.get_identifier(), conversation_id=context.processing_conversation_id, ) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index e7d3097615..02b4bdbac1 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union -from sqlalchemy import and_, create_engine, event, exists, or_, text +from sqlalchemy import and_, create_engine, event, exists, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -449,7 +449,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: Get the SQL Azure implementation for filtering AttackResults by labels. Matches if the labels are found on the AttackResultEntry directly - OR on an associated PromptMemoryEntry (via conversation_id). Uses JSON_VALUE() function specific to SQL Azure with parameterized queries. @@ -470,30 +469,11 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ar_bindparams[value_param] = str(value) ar_combined = " AND ".join(ar_label_conditions) - direct_condition = and_( + return and_( AttackResultEntry.labels.isnot(None), text(f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}').bindparams(**ar_bindparams), ) - # --- Conversation-level match on PromptMemoryEntry.labels --- - pme_label_conditions = [] - pme_bindparams: dict[str, str] = {} - for key, value in labels.items(): - param_name = f"pme_label_{key}" - pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") - pme_bindparams[param_name] = str(value) - - pme_combined = " AND ".join(pme_label_conditions) - conversation_condition = exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.labels.isnot(None), - text(f"ISJSON(labels) = 1 AND {pme_combined}").bindparams(**pme_bindparams), - ) - ) - - return or_(direct_condition, conversation_condition) - def get_unique_attack_class_names(self) -> list[str]: """ Azure SQL implementation: extract unique class_name values from diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 6e7e1a1b15..cb2ea8619a 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1569,11 +1569,9 @@ def get_unique_attack_labels(self) -> dict[str, list[str]]: """ Return all unique label key-value pairs across attack results. - Labels live on ``PromptMemoryEntry.labels`` (the established SDK - path). This method JOINs with ``AttackResultEntry`` to scope the - query to conversations that belong to an attack, applies DISTINCT - to reduce duplicate label dicts, then aggregates unique key-value - pairs in Python. + Labels live directly on ``AttackResultEntry.labels``. This method + queries distinct label dicts from attack results, then aggregates + unique key-value pairs in Python. Returns: dict[str, list[str]]: Mapping of label keys to sorted lists of @@ -1582,16 +1580,7 @@ def get_unique_attack_labels(self) -> dict[str, list[str]]: label_values: dict[str, set[str]] = {} with closing(self.get_session()) as session: - rows = ( - session.query(PromptMemoryEntry.labels) - .join( - AttackResultEntry, - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - ) - .filter(PromptMemoryEntry.labels.isnot(None)) - .distinct() - .all() - ) + rows = session.query(AttackResultEntry.labels).filter(AttackResultEntry.labels.isnot(None)).distinct().all() for (labels,) in rows: if not isinstance(labels, dict): diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 60fb1dca64..4f5cd5767d 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -155,7 +155,6 @@ class PromptMemoryEntry(Base): sequence (int): The order of the conversation within a conversation_id. Can be the same number for multi-part requests or multi-part responses. timestamp (DateTime): The timestamp of the memory entry. - labels (Dict[str, str]): The labels associated with the memory entry. Several can be standardized. targeted_harm_categories (List[str]): The targeted harm categories for the memory entry. prompt_metadata (JSON): The metadata associated with the prompt. This can be specific to any scenarios. Because memory is how components talk with each other, this can be component specific. @@ -186,7 +185,6 @@ class PromptMemoryEntry(Base): conversation_id = mapped_column(String, nullable=False) sequence = mapped_column(INTEGER, nullable=False) timestamp = mapped_column(DateTime, nullable=False) - labels: Mapped[dict[str, str]] = mapped_column(JSON) prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON) targeted_harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON) converter_identifiers: Mapped[Optional[list[dict[str, str]]]] = mapped_column(JSON) @@ -221,6 +219,42 @@ class PromptMemoryEntry(Base): foreign_keys="ScoreEntry.prompt_request_response_id", ) + def get_attack_result_entries(self) -> list["AttackResultEntry"]: + """ + Return all AttackResultEntry rows whose attack owns this conversation. + + Matches on the main conversation_id as well as any related conversation + (pruned, adversarial, etc.) without hard-coding specific JSON column names. + The SQL query fetches candidates, then ``AttackResult.includes_conversation`` + is used as the authoritative Python-level filter so that new conversation + types are automatically covered. + + Returns: + list of AttackResultEntry rows whose attack owns this conversation. + + """ + from sqlalchemy import inspect as sa_inspect + + session = sa_inspect(self).session + if session is None: + return [] + + # Fetch candidate rows: direct conversation_id match, or any row + # whose full text representation contains this conversation_id. + candidates = ( + session.query(AttackResultEntry).filter(AttackResultEntry.conversation_id == self.conversation_id).all() + ) + + if not candidates: + # Broader search: all rows, filtered by the model's own logic. + candidates = [ + e + for e in session.query(AttackResultEntry).all() + if e.get_attack_result().includes_conversation(self.conversation_id) + ] + + return candidates + def __init__(self, *, entry: MessagePiece): """ Initialize a PromptMemoryEntry from a MessagePiece. @@ -233,7 +267,6 @@ def __init__(self, *, entry: MessagePiece): self.conversation_id = entry.conversation_id self.sequence = entry.sequence self.timestamp = entry.timestamp - self.labels = entry.labels self.prompt_metadata = entry.prompt_metadata self.targeted_harm_categories = entry.targeted_harm_categories self.converter_identifiers = [ @@ -301,7 +334,6 @@ def get_message_piece(self) -> MessagePiece: id=self.id, conversation_id=self.conversation_id, sequence=self.sequence, - labels=self.labels, prompt_metadata=self.prompt_metadata, targeted_harm_categories=self.targeted_harm_categories, converter_identifiers=converter_ids, @@ -314,6 +346,10 @@ def get_message_piece(self) -> MessagePiece: timestamp=_ensure_utc(self.timestamp), ) message_piece.scores = [score.get_score() for score in self.scores] + attack_entries = self.get_attack_result_entries() + + # message_piece._set_labels([e.get_attack_result().labels for e in attack_entries]) # noqa: ERA001 + return message_piece def __str__(self) -> str: diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index ac8b2319eb..0496e9be51 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Optional, TypeVar, Union, cast -from sqlalchemy import and_, create_engine, exists, func, or_, text +from sqlalchemy import and_, create_engine, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -616,26 +616,15 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: SQLite implementation for filtering AttackResults by labels. Matches if the labels are found on the AttackResultEntry directly - OR on an associated PromptMemoryEntry (via conversation_id). Returns: Any: A SQLAlchemy condition for filtering by labels. """ - direct_condition = and_( + return and_( AttackResultEntry.labels.isnot(None), *[func.json_extract(AttackResultEntry.labels, f"$.{key}") == value for key, value in labels.items()], ) - conversation_condition = exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.labels.isnot(None), - *[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()], - ) - ) - - return or_(direct_condition, conversation_condition) - def get_unique_attack_class_names(self) -> list[str]: """ SQLite implementation: extract unique class_name values from diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 21e785fdc2..556c7ee6a6 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -555,7 +555,6 @@ def construct_response_from_request( role="assistant", original_value=resp_text, conversation_id=request.conversation_id, - labels=request.labels, prompt_target_identifier=request.prompt_target_identifier, attack_identifier=request.attack_identifier, original_value_data_type=response_type, diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 91d01032bf..cfa26ad2de 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -37,7 +37,6 @@ def __init__( id: Optional[uuid.UUID | str] = None, # noqa: A002 conversation_id: Optional[str] = None, sequence: int = -1, - labels: Optional[dict[str, str]] = None, prompt_metadata: Optional[dict[str, Union[str, int]]] = None, converter_identifiers: Optional[list[Union[ComponentIdentifier, dict[str, str]]]] = None, prompt_target_identifier: Optional[Union[ComponentIdentifier, dict[str, Any]]] = None, @@ -65,7 +64,6 @@ def __init__( conversation_id: The identifier for the conversation which is associated with a single target. Defaults to None. sequence: The order of the conversation within a conversation_id. Defaults to -1. - labels: The labels associated with the memory entry. Several can be standardized. Defaults to None. prompt_metadata: The metadata associated with the prompt. This can be specific to any scenarios. Because memory is how components talk with each other, this can be component specific. e.g. the URI from a file uploaded to a blob store, or a document type you want to upload. @@ -114,7 +112,6 @@ def __init__( self.timestamp = timestamp.replace(tzinfo=timezone.utc) else: self.timestamp = timestamp - self.labels = labels or {} self.prompt_metadata = prompt_metadata or {} # Handle converter_identifiers: normalize to ComponentIdentifier (handles dict with deprecation warning) @@ -169,6 +166,16 @@ def __init__( self.scores = scores if scores else [] self.targeted_harm_categories = targeted_harm_categories if targeted_harm_categories else [] + self._labels: dict[str, str] = {} + + @property + def labels(self) -> dict[str, str]: + """Labels associated with this message piece, hydrated by the memory layer.""" + return self._labels + + def _set_labels(self, labels: dict[str, str]) -> None: + self._labels = labels + async def set_sha256_values_async(self) -> None: """ Compute SHA256 hash values for original and converted payloads. diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index b730a58669..d24119833f 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -53,7 +53,6 @@ async def send_prompt_async( conversation_id: Optional[str] = None, request_converter_configurations: list[PromptConverterConfiguration] | None = None, response_converter_configurations: list[PromptConverterConfiguration] | None = None, - labels: Optional[dict[str, str]] = None, attack_identifier: Optional[ComponentIdentifier] = None, ) -> Message: """ @@ -67,7 +66,6 @@ async def send_prompt_async( converting the request. Defaults to an empty list. response_converter_configurations (list[PromptConverterConfiguration], optional): Configurations for converting the response. Defaults to an empty list. - labels (Optional[dict[str, str]], optional): Labels associated with the request. Defaults to None. attack_identifier (Optional[ComponentIdentifier], optional): Identifier for the attack. Defaults to None. @@ -90,8 +88,6 @@ async def send_prompt_async( for piece in request.message_pieces: piece.conversation_id = conversation_id - if labels: - piece.labels = labels piece.prompt_target_identifier = target.get_identifier() if attack_identifier: piece.attack_identifier = attack_identifier diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index ce1f254678..2464144d6c 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -70,7 +70,6 @@ def set_system_prompt( system_prompt: str, conversation_id: str, attack_identifier: Optional[ComponentIdentifier] = None, - labels: Optional[dict[str, str]] = None, ) -> None: """ Set the system prompt for the prompt target. May be overridden by subclasses. @@ -91,7 +90,6 @@ def set_system_prompt( converted_value=system_prompt, prompt_target_identifier=self.get_identifier(), attack_identifier=attack_identifier, - labels=labels, ).to_message() ) diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 2948988c63..fbd3c29bf4 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -683,7 +683,6 @@ def _parse_response_output_section( role="assistant", original_value=piece_value, conversation_id=message_piece.conversation_id, - labels=message_piece.labels, prompt_target_identifier=message_piece.prompt_target_identifier, attack_identifier=message_piece.attack_identifier, original_value_data_type=piece_type, @@ -791,7 +790,6 @@ def _make_tool_piece(self, output: dict[str, Any], call_id: str, *, reference_pi ), original_value_data_type="function_call_output", conversation_id=reference_piece.conversation_id, - labels={"call_id": call_id}, prompt_target_identifier=reference_piece.prompt_target_identifier, attack_identifier=reference_piece.attack_identifier, ) diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index de317ad7b9..510991d2c4 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -242,7 +242,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: role="assistant", original_value=piece_data, conversation_id=request_piece.conversation_id, - labels=request_piece.labels, prompt_target_identifier=request_piece.prompt_target_identifier, attack_identifier=request_piece.attack_identifier, original_value_data_type=piece_type, diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 4c736daf2d..64b3ba2bec 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -85,7 +85,6 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: original_value_data_type=row.get("data_type", None), # type: ignore[arg-type] conversation_id=row.get("conversation_id", None), sequence=int(sequence_str) if sequence_str else None, - labels=labels, response_error=row.get("response_error", None), # type: ignore[arg-type] prompt_target_identifier=self.get_identifier(), ) diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index e2b7a5ce13..d5e0d33b42 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -84,7 +84,6 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non converted_value=conversation_text, id=original_piece.id, conversation_id=original_piece.conversation_id, - labels=original_piece.labels, prompt_target_identifier=original_piece.prompt_target_identifier, attack_identifier=original_piece.attack_identifier, original_value_data_type=original_piece.original_value_data_type, diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index a4d87fd6d3..77148f4351 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -24,7 +24,7 @@ from collections.abc import Sequence -def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_categories=None, labels=None): +def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_categories=None): """Helper function to create MessagePiece with optional targeted harm categories and labels.""" return MessagePiece( role="user", @@ -32,7 +32,6 @@ def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_ca converted_value=f"Test prompt {prompt_num}", conversation_id=conversation_id, targeted_harm_categories=targeted_harm_categories, - labels=labels, ) @@ -981,27 +980,6 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me assert results[0].conversation_id == "conv_1" -def test_get_attack_results_by_labels_falls_back_to_conversation_labels(sqlite_instance: MemoryInterface): - """Test that label filtering matches via PromptMemoryEntry when AttackResult has no labels.""" - - # Attack result with NO labels - attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={}) - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) - - # Conversation message carries the labels instead - message_piece = create_message_piece("conv_1", 1, labels={"operation": "legacy_op"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) - - # Should still find the attack result via the PME fallback path - results = sqlite_instance.get_attack_results(labels={"operation": "legacy_op"}) - assert len(results) == 1 - assert results[0].conversation_id == "conv_1" - - # Non-matching label should return nothing - results = sqlite_instance.get_attack_results(labels={"operation": "missing"}) - assert len(results) == 0 - - # --------------------------------------------------------------------------- # get_unique_attack_labels tests # --------------------------------------------------------------------------- @@ -1014,11 +992,8 @@ def test_get_unique_attack_labels_empty(sqlite_instance: MemoryInterface): def test_get_unique_attack_labels_single(sqlite_instance: MemoryInterface): - """Returns labels from a single attack result's message pieces.""" - message = create_message_piece("conv_1", 1, labels={"env": "prod", "team": "red"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message]) - - ar = create_attack_result("conv_1", 1) + """Returns labels from a single attack result.""" + ar = create_attack_result("conv_1", 1, labels={"env": "prod", "team": "red"}) sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) result = sqlite_instance.get_unique_attack_labels() @@ -1027,20 +1002,16 @@ def test_get_unique_attack_labels_single(sqlite_instance: MemoryInterface): def test_get_unique_attack_labels_multiple_attacks_merges_values(sqlite_instance: MemoryInterface): """Values from different attacks are merged and sorted.""" - msg1 = create_message_piece("conv_1", 1, labels={"env": "prod", "team": "red"}) - msg2 = create_message_piece("conv_2", 2, labels={"env": "staging", "team": "red"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg1, msg2]) - - ar1 = create_attack_result("conv_1", 1) - ar2 = create_attack_result("conv_2", 2) + ar1 = create_attack_result("conv_1", 1, labels={"env": "prod", "team": "red"}) + ar2 = create_attack_result("conv_2", 2, labels={"env": "staging", "team": "red"}) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) result = sqlite_instance.get_unique_attack_labels() assert result == {"env": ["prod", "staging"], "team": ["red"]} -def test_get_unique_attack_labels_no_pieces(sqlite_instance: MemoryInterface): - """Attack results without any message pieces return empty dict.""" +def test_get_unique_attack_labels_no_labels(sqlite_instance: MemoryInterface): + """Attack results with empty labels return empty dict.""" ar = create_attack_result("conv_1", 1) sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) @@ -1048,34 +1019,18 @@ def test_get_unique_attack_labels_no_pieces(sqlite_instance: MemoryInterface): assert result == {} -def test_get_unique_attack_labels_pieces_without_labels(sqlite_instance: MemoryInterface): - """Message pieces with no labels are skipped.""" - msg = create_message_piece("conv_1", 1) # labels=None - sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg]) - - ar = create_attack_result("conv_1", 1) +def test_get_unique_attack_labels_null_labels_skipped(sqlite_instance: MemoryInterface): + """Attack results with null labels are skipped.""" + ar = create_attack_result("conv_1", 1, labels=None) sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) result = sqlite_instance.get_unique_attack_labels() assert result == {} -def test_get_unique_attack_labels_ignores_non_attack_pieces(sqlite_instance: MemoryInterface): - """Labels on pieces not linked to any attack are excluded.""" - msg = create_message_piece("conv_no_attack", 1, labels={"env": "prod"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg]) - - # No AttackResult for "conv_no_attack" - result = sqlite_instance.get_unique_attack_labels() - assert result == {} - - def test_get_unique_attack_labels_non_string_values_skipped(sqlite_instance: MemoryInterface): """Non-string label values are ignored.""" - msg = create_message_piece("conv_1", 1, labels={"env": "prod", "count": 42}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg]) - - ar = create_attack_result("conv_1", 1) + ar = create_attack_result("conv_1", 1, labels={"env": "prod", "count": 42}) sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) result = sqlite_instance.get_unique_attack_labels() @@ -1084,12 +1039,8 @@ def test_get_unique_attack_labels_non_string_values_skipped(sqlite_instance: Mem def test_get_unique_attack_labels_keys_sorted(sqlite_instance: MemoryInterface): """Returned keys and values are sorted alphabetically.""" - msg1 = create_message_piece("conv_1", 1, labels={"zoo": "z_val", "alpha": "a"}) - msg2 = create_message_piece("conv_2", 2, labels={"alpha": "b"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg1, msg2]) - - ar1 = create_attack_result("conv_1", 1) - ar2 = create_attack_result("conv_2", 2) + ar1 = create_attack_result("conv_1", 1, labels={"zoo": "z_val", "alpha": "a"}) + ar2 = create_attack_result("conv_2", 2, labels={"alpha": "b"}) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) result = sqlite_instance.get_unique_attack_labels() @@ -1104,20 +1055,16 @@ def test_get_unique_attack_labels_non_dict_labels_skipped(sqlite_instance: Memor from sqlalchemy import text - # Insert a real attack + piece with normal labels first - msg1 = create_message_piece("conv_1", 1, labels={"env": "prod"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg1]) - ar1 = create_attack_result("conv_1", 1) + # Insert attack with normal labels + ar1 = create_attack_result("conv_1", 1, labels={"env": "prod"}) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) # Insert a second attack and use raw SQL to set labels to a JSON string - msg2 = create_message_piece("conv_2", 2, labels={"placeholder": "x"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg2]) ar2 = create_attack_result("conv_2", 2) sqlite_instance.add_attack_results_to_memory(attack_results=[ar2]) with closing(sqlite_instance.get_session()) as session: session.execute( - text('UPDATE "PromptMemoryEntries" SET labels = \'"just_a_string"\' WHERE conversation_id = :cid'), + text('UPDATE "AttackResultEntries" SET labels = \'"just_a_string"\' WHERE conversation_id = :cid'), {"cid": "conv_2"}, ) session.commit() diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index d13064a321..93b04032c5 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -668,7 +668,6 @@ def test_message_piece_to_dict(): converted_value="Hello", conversation_id="test_conversation", sequence=1, - labels={"label1": "value1"}, targeted_harm_categories=["violence", "illegal"], prompt_metadata={"key": "metadata"}, converter_identifiers=[ @@ -962,27 +961,6 @@ def test_message_piece_harm_categories_serialization(): assert result["targeted_harm_categories"] == harm_categories -def test_message_piece_harm_categories_with_labels(): - """Test that harm_categories and labels can coexist.""" - harm_categories = ["violence", "illegal"] - labels = {"operation": "test_op", "researcher": "alice"} - - entry = MessagePiece( - role="user", - original_value="Hello", - converted_value="Hello", - targeted_harm_categories=harm_categories, - labels=labels, - ) - - assert entry.targeted_harm_categories == harm_categories - assert entry.labels == labels - - result = entry.to_dict() - assert result["targeted_harm_categories"] == harm_categories - assert result["labels"] == labels - - class TestSimulatedAssistantRole: """Tests for simulated_assistant role properties."""