From 41365d36d8f2f39ff114fef9482cf0985da10d72 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Tue, 14 Apr 2026 13:50:57 -0400 Subject: [PATCH 01/11] normalization in send_prompt_async --- .../azure_blob_storage_target.py | 6 +- pyrit/prompt_target/azure_ml_chat_target.py | 57 ++- pyrit/prompt_target/common/prompt_target.py | 65 ++- pyrit/prompt_target/gandalf_target.py | 6 +- .../prompt_target/http_target/http_target.py | 6 +- .../http_target/httpx_api_target.py | 5 +- .../hugging_face/hugging_face_chat_target.py | 9 +- .../hugging_face_endpoint_target.py | 6 +- .../openai/openai_chat_target.py | 13 +- .../openai/openai_completion_target.py | 6 +- .../openai/openai_image_target.py | 5 +- .../openai/openai_realtime_target.py | 7 +- .../openai/openai_response_target.py | 21 +- .../prompt_target/openai/openai_tts_target.py | 6 +- .../openai/openai_video_target.py | 6 +- .../playwright_copilot_target.py | 6 +- pyrit/prompt_target/playwright_target.py | 6 +- pyrit/prompt_target/prompt_shield_target.py | 5 +- pyrit/prompt_target/text_target.py | 6 +- .../prompt_target/websocket_copilot_target.py | 6 +- tests/integration/mocks.py | 4 +- .../test_prompt_target_contract.py | 4 +- tests/unit/mocks.py | 4 +- .../target/test_huggingface_chat_target.py | 2 + .../test_normalize_async_integration.py | 395 ++++++++++++++++++ .../target/test_openai_chat_target.py | 2 + .../target/test_openai_response_target.py | 2 + .../target/test_openai_target_auth.py | 2 +- .../target/test_target_capabilities.py | 4 +- .../prompt_target/test_prompt_chat_target.py | 4 +- tests/unit/registry/test_target_registry.py | 6 +- 31 files changed, 605 insertions(+), 77 deletions(-) create mode 100644 tests/unit/prompt_target/target/test_normalize_async_integration.py diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index dcabad099c..7142c4b9ca 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -196,18 +196,20 @@ def _parse_url(self) -> tuple[str, str]: return container_url, blob_prefix @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ (Async) Sends prompt to target, which creates a file and uploads it as a blob to the provided storage container. Args: message (Message): A Message to be sent to the target. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response with the Blob URL. """ - self._validate_request(message=message) request = message.message_pieces[0] # default file name is .txt, but can be overridden by prompt metadata diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index cd7083c61a..193783cb37 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import logging +import warnings from typing import Any, Optional from httpx import HTTPStatusError @@ -14,13 +15,18 @@ pyrit_target_retry, ) from pyrit.identifiers import ComponentIdentifier -from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer +from pyrit.message_normalizer import ChatMessageNormalizer, GenericSystemSquashNormalizer, MessageListNormalizer from pyrit.models import ( Message, construct_response_from_request, ) from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p @@ -77,9 +83,10 @@ def __init__( Defaults to the value of the `AZURE_ML_KEY` environment variable. model_name (str, Optional): The name of the model being used (e.g., "Llama-3.2-3B-Instruct"). Used for identification purposes. Defaults to empty string. - message_normalizer (MessageListNormalizer, Optional): The message normalizer. - For models that do not allow system prompts such as mistralai-Mixtral-8x7B-Instruct-v01, - GenericSystemSquashNormalizer() can be passed in. Defaults to ChatMessageNormalizer(). + message_normalizer (MessageListNormalizer, Optional): **Deprecated.** Use + ``custom_configuration`` with ``CapabilityHandlingPolicy`` instead. Previously used for + models that do not allow system prompts. Defaults to ChatMessageNormalizer(). + Will be removed in v0.14.0. max_new_tokens (int, Optional): The maximum number of tokens to generate in the response. Defaults to 400. temperature (float, Optional): The temperature for generating diverse responses. 1.0 is most random, @@ -105,6 +112,34 @@ def __init__( endpoint_value = default_values.get_required_value( env_var_name=self.endpoint_uri_environment_variable, passed_value=endpoint ) + + # Translate legacy message_normalizer into TargetConfiguration + if message_normalizer is not None and isinstance(message_normalizer, GenericSystemSquashNormalizer): + warnings.warn( + "Passing GenericSystemSquashNormalizer as message_normalizer is deprecated. " + "Use custom_configuration=TargetConfiguration(capabilities=TargetCapabilities(" + "supports_system_prompt=False), policy=CapabilityHandlingPolicy(behaviors={" + "CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT})) instead. " + "Will be removed in v0.14.0.", + DeprecationWarning, + stacklevel=2, + ) + if custom_configuration is None: + custom_configuration = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_message_pieces=True, + supports_editable_history=True, + supports_multi_turn=True, + supports_system_prompt=False, + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ) + PromptChatTarget.__init__( self, max_requests_per_minute=max_requests_per_minute, @@ -164,12 +199,15 @@ def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously send a message to the Azure ML chat target. Args: message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. @@ -179,18 +217,13 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: RateLimitException: If the target rate limit is exceeded. HTTPStatusError: For any other HTTP errors during the process. """ - self._validate_request(message=message) request = message.message_pieces[0] - # Get chat messages from memory and append the current message - messages = list(self._memory.get_conversation(conversation_id=request.conversation_id)) - messages.append(message) - logger.info(f"Sending the following prompt to the prompt target: {request}") try: resp_text = await self._complete_chat_async( - messages=messages, + messages=normalized_conversation, ) if not resp_text: diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 35aad03183..b0eb346b8c 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -108,14 +108,51 @@ def __init__( if self._verbose: logging.basicConfig(level=logging.INFO) - @abc.abstractmethod async def send_prompt_async(self, *, message: Message) -> list[Message]: """ - Send a normalized prompt async to the prompt target. + Validate, normalize, and send a prompt to the target. + + This is the public entry point called by the prompt normalizer. It: + + 1. Validates the request against the target's capabilities. + 2. Fetches the conversation from memory, appends ``message``, and runs + the normalization pipeline (system‑squash, history‑squash, etc.). + 3. Delegates to :meth:`_send_prompt_target_async` with both the original + message and the normalized conversation. + + Subclasses MUST NOT override this method. Override + :meth:`_send_prompt_target_async` instead. + + Args: + message (Message): The message to send. + + Returns: + list[Message]: Response messages from the target. + """ + self._validate_request(message=message) + normalized_conversation = await self._get_normalized_conversation_async(message=message) + return await self._send_prompt_target_async( + message=message, normalized_conversation=normalized_conversation + ) + + @abc.abstractmethod + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: + """ + Target-specific send logic. + + Called by :meth:`send_prompt_async` after validation and normalization. + + Args: + message (Message): The original message (unmodified). + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. Single-turn targets may ignore this and use only + ``message``. Returns: - list[Message]: A list of message responses. Most targets return a single message, - but some (like response target with tool calls) may return multiple messages. + list[Message]: Response messages from the target. """ def _validate_request(self, *, message: Message) -> None: @@ -163,6 +200,26 @@ def _validate_request(self, *, message: Message) -> None: f"This target only supports a single turn conversation. {custom_configuration_message}" ) + async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: + """ + Fetch the conversation from memory, append the current message, and run the + normalization pipeline. + + The original conversation in memory is never mutated. The returned list is an + ephemeral copy intended only for building the API request body. + + Args: + message (Message): The current message to append. + + Returns: + list[Message]: The normalized conversation (possibly with system prompt squashed, + history squashed, etc.). + """ + conversation_id = message.message_pieces[0].conversation_id + conversation = self._memory.get_conversation(conversation_id=conversation_id) + conversation.append(message) + return await self.configuration.normalize_async(messages=list(conversation)) + def set_model_name(self, *, model_name: str) -> None: """ Set the model name for this target. diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index f92326e66b..2db5bbf8a7 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -85,17 +85,19 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously send a message to the Gandalf target. Args: message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. """ - self._validate_request(message=message) request = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {request}") diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index da805d50fe..d354833f26 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -155,17 +155,19 @@ def _inject_prompt_into_request(self, request: MessagePiece) -> str: return http_request_w_prompt @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously send a message to the HTTP target. Args: message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. """ - self._validate_request(message=message) request = message.message_pieces[0] http_request_w_prompt = self._inject_prompt_into_request(request) diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index 7555038b64..4590321ba0 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -108,7 +108,9 @@ def __init__( raise ValueError(f"File uploads are not allowed with HTTP method: {self.method}") @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Override the parent's method to skip raw http_request usage, and do a standard "API mode" approach. @@ -125,7 +127,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: httpx.RequestError: If the request fails. FileNotFoundError: If the specified file to upload is not found. """ - self._validate_request(message=message) message_piece: MessagePiece = message.message_pieces[0] # If user didn't set file_path, see if the PDF path is in converted_value diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index ae02026004..2d343c70d3 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -295,10 +295,16 @@ async def load_model_and_tokenizer(self) -> None: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Send a normalized prompt asynchronously to the HuggingFace model. + Args: + message (Message): The message to send. + normalized_conversation (list[Message]): The normalized conversation history. + Returns: list[Message]: A list containing the response object with generated text pieces. @@ -309,7 +315,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Load the model and tokenizer using the encapsulated method await self.load_model_and_tokenizer_task - self._validate_request(message=message) request = message.message_pieces[0] prompt_template = request.converted_value diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index 5c4e3d9371..bdaced7098 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -87,13 +87,16 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Send a normalized prompt asynchronously to a cloud-based HuggingFace model endpoint. Args: message (Message): The message containing the input data and associated details such as conversation ID and role. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response object with generated text pieces. @@ -102,7 +105,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: ValueError: If the response from the Hugging Face API is not successful. Exception: If an error occurs during the HTTP request to the Hugging Face endpoint. """ - self._validate_request(message=message) request = message.message_pieces[0] headers = {"Authorization": f"Bearer {self.hf_token}"} payload: dict[str, object] = { diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 277ca4c08c..db2ac82f3d 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -241,28 +241,25 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously sends a message and handles the response within a managed conversation context. Args: message (Message): The message object. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. """ - self._validate_request(message=message) - message_piece: MessagePiece = message.message_pieces[0] json_config = self._get_json_response_config(message_piece=message_piece) - # Get conversation from memory and append the current message - conversation = self._memory.get_conversation(conversation_id=message_piece.conversation_id) - conversation.append(message) - logger.info(f"Sending the following prompt to the prompt target: {message}") - body = await self._construct_request_body(conversation=conversation, json_config=json_config) + body = await self._construct_request_body(conversation=normalized_conversation, json_config=json_config) # Use unified error handling - automatically detects ChatCompletion and validates response = await self._handle_openai_request( diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index b1ee2efe25..95653a5062 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -122,17 +122,19 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously send a message to the OpenAI completion target. Args: message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. """ - self._validate_request(message=message) message_piece = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {message_piece}") diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index a99e01dd60..a9cd002ed9 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -148,10 +148,11 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async( + async def _send_prompt_target_async( self, *, message: Message, + normalized_conversation: list[Message], ) -> list[Message]: """ Send a prompt to the OpenAI image target and return the response. @@ -159,11 +160,11 @@ async def send_prompt_async( Args: message (Message): The message to send. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the image target. """ - self._validate_request(message=message) logger.info(f"Sending the following prompt to the prompt target: {message}") diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index b14a6e823e..8587bb938f 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -336,12 +336,15 @@ def _get_system_prompt_from_conversation(self, *, conversation_id: str) -> str: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously send a message to the OpenAI realtime target. Args: message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. @@ -359,8 +362,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Give the server a moment to process the session update await asyncio.sleep(0.5) - self._validate_request(message=message) - request = message.message_pieces[0] response_type = request.converted_value_data_type diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 8e02ce7b42..4ada55bac2 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -518,7 +518,9 @@ async def _construct_message_from_response(self, response: Any, request: Message @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Send prompt, handle agentic tool calls (function_call), return all messages. @@ -528,24 +530,19 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Args: message: The initial prompt from the user. + normalized_conversation: The normalized conversation history. Returns: List of messages generated during the interaction (assistant responses and tool messages). The normalizer will persist all of these to memory. """ - self._validate_request(message=message) - message_piece: MessagePiece = message.message_pieces[0] json_config = _JsonResponseConfig(enabled=False) if message.message_pieces: last_piece = message.message_pieces[-1] json_config = self._get_json_response_config(message_piece=last_piece) - # Get full conversation history from memory and append the current message - conversation: MutableSequence[Message] = self._memory.get_conversation( - conversation_id=message_piece.conversation_id - ) - conversation.append(message) + working_conversation: MutableSequence[Message] = list(normalized_conversation) # Track all responses generated during this interaction responses_to_return: list[Message] = [] @@ -554,9 +551,9 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: tool_call_section: Optional[dict[str, Any]] = None while True: - logger.info(f"Sending conversation with {len(conversation)} messages to the prompt target") + logger.info(f"Sending conversation with {len(working_conversation)} messages to the prompt target") - body = await self._construct_request_body(conversation=conversation, json_config=json_config) + body = await self._construct_request_body(conversation=working_conversation, json_config=json_config) # Use unified error handling - automatically detects Response and validates result = await self._handle_openai_request( @@ -565,7 +562,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: ) # Add result to conversation and responses list - conversation.append(result) + working_conversation.append(result) responses_to_return.append(result) # Extract tool call if present @@ -583,7 +580,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: tool_message = Message(message_pieces=[tool_piece], skip_validation=True) # Add tool output message to conversation and responses list - conversation.append(tool_message) + working_conversation.append(tool_message) responses_to_return.append(tool_message) # Continue loop to send tool result and get next response diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index 1ec69d3070..fa2d353351 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -114,17 +114,19 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously send a message to the OpenAI TTS target. Args: message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the audio response from the prompt target. """ - self._validate_request(message=message) message_piece = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {message_piece}") diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 8a23e08565..572e8de651 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -182,7 +182,9 @@ def _validate_duration(self) -> None: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously sends a message and generates a video using the OpenAI SDK. @@ -197,6 +199,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Args: message: The message object containing the prompt. + normalized_conversation: The normalized conversation history. Returns: A list containing the response with the generated video path. @@ -205,7 +208,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: RateLimitException: If the rate limit is exceeded. ValueError: If the request is invalid. """ - self._validate_request(message=message) text_piece = message.get_piece_by_type(data_type="text") diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index de317ad7b9..66687cd229 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -205,13 +205,16 @@ def _get_selectors(self) -> CopilotSelectors: file_picker_selector='span.fui-MenuItem__content:has-text("Upload images and files")', ) - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Send a message to Microsoft Copilot and return the response. Args: message (Message): The message to send. Can contain multiple pieces of type 'text' or 'image_path'. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from Copilot. @@ -219,7 +222,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Raises: RuntimeError: If an error occurs during interaction. """ - self._validate_request(message=message) try: response_content = await self._interact_with_copilot_async(message) diff --git a/pyrit/prompt_target/playwright_target.py b/pyrit/prompt_target/playwright_target.py index c97f5d803b..dbbb8a9518 100644 --- a/pyrit/prompt_target/playwright_target.py +++ b/pyrit/prompt_target/playwright_target.py @@ -100,12 +100,15 @@ def __init__( self._page = page @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously send a message to the Playwright target. Args: message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. @@ -113,7 +116,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Raises: RuntimeError: If the Playwright page is not initialized or if an error occurs during interaction. """ - self._validate_request(message=message) if not self._page: raise RuntimeError( "Playwright page is not initialized. Please pass a Page object when initializing PlaywrightTarget." diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 8bb249f350..416639a8c6 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -123,7 +123,9 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Parse the text in message to separate the userPrompt and documents contents, then send an HTTP request to the endpoint and obtain a response in JSON. For more info, visit @@ -132,7 +134,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Returns: list[Message]: A list containing the response object with generated text pieces. """ - self._validate_request(message=message) request = message.message_pieces[0] diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 4c736daf2d..0ad126fb1a 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -42,17 +42,19 @@ def __init__( super().__init__(custom_configuration=custom_configuration, custom_capabilities=custom_capabilities) self._text_stream = text_stream - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously write a message to the text stream. Args: message (Message): The message object to write to the stream. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: An empty list (no response expected). """ - self._validate_request(message=message) self._text_stream.write(f"{str(message)}\n") self._text_stream.flush() diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 9b56078272..cceb243437 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -644,7 +644,9 @@ def _generate_consistent_copilot_ids(self, *, pyrit_conversation_id: str) -> tup @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: """ Asynchronously send a message to Microsoft Copilot using WebSocket. @@ -654,6 +656,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Args: message (Message): A message to be sent to the target. + normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from Copilot. @@ -663,7 +666,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: InvalidStatus: If the WebSocket handshake fails with an HTTP status error. RuntimeError: If any other error occurs during WebSocket communication. """ - self._validate_request(message=message) pyrit_conversation_id = message.message_pieces[0].conversation_id is_start_of_session = self._is_start_of_session(conversation_id=pyrit_conversation_id) diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index 5b872eb014..643bc3a66b 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -66,7 +66,9 @@ def set_system_prompt( ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: self.prompt_sent.append(message.get_value()) return [ diff --git a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py index 68f69f5c8d..ae8050672c 100644 --- a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py +++ b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py @@ -23,7 +23,9 @@ class _MinimalTarget(PromptTarget): """Minimal concrete PromptTarget for contract testing.""" - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: return [] def _validate_request(self, *, message) -> None: diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 0bfa55f609..254132f1d0 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -149,7 +149,9 @@ def set_system_prompt( ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: self.prompt_sent.append(message.get_value()) return [ diff --git a/tests/unit/prompt_target/target/test_huggingface_chat_target.py b/tests/unit/prompt_target/target/test_huggingface_chat_target.py index 0de8ddbbf3..6ce2a388ec 100644 --- a/tests/unit/prompt_target/target/test_huggingface_chat_target.py +++ b/tests/unit/prompt_target/target/test_huggingface_chat_target.py @@ -163,6 +163,7 @@ async def test_load_model_and_tokenizer(): @pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") @pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") async def test_send_prompt_async(): hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) await hf_chat.load_model_and_tokenizer() @@ -185,6 +186,7 @@ async def test_send_prompt_async(): @pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") @pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") async def test_missing_chat_template_error(): hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) await hf_chat.load_model_and_tokenizer() diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py new file mode 100644 index 0000000000..647287e367 --- /dev/null +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -0,0 +1,395 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import warnings +from collections.abc import MutableSequence +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from openai.types.chat import ChatCompletion + +from pyrit.identifiers import ComponentIdentifier +from pyrit.memory.memory_interface import MemoryInterface +from pyrit.message_normalizer import GenericSystemSquashNormalizer +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import AzureMLChatTarget, OpenAIChatTarget +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.prompt_target.openai.openai_response_target import OpenAIResponseTarget + + +def _make_message_piece(*, role: str, content: str, conversation_id: str = "conv1") -> MessagePiece: + return MessagePiece( + role=role, + conversation_id=conversation_id, + original_value=content, + converted_value=content, + original_value_data_type="text", + converted_value_data_type="text", + prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), + attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), + ) + + +def _make_message(*, role: str, content: str, conversation_id: str = "conv1") -> Message: + return Message(message_pieces=[_make_message_piece(role=role, content=content, conversation_id=conversation_id)]) + + +def _create_mock_chat_completion(content: str = "hi") -> MagicMock: + mock = MagicMock(spec=ChatCompletion) + mock.choices = [MagicMock()] + mock.choices[0].finish_reason = "stop" + mock.choices[0].message.content = content + mock.choices[0].message.audio = None + mock.choices[0].message.tool_calls = None + mock.model_dump_json.return_value = json.dumps( + {"choices": [{"finish_reason": "stop", "message": {"content": content}}]} + ) + return mock + + +# --------------------------------------------------------------------------- +# OpenAIChatTarget — normalize_async is called +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_openai_chat_target_calls_normalize_async(): + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + mock_completion = _create_mock_chat_completion("world") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + + with patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize: + mock_normalize.return_value = [user_msg] + await target.send_prompt_async(message=user_msg) + + mock_normalize.assert_called_once() + call_messages = mock_normalize.call_args.kwargs["messages"] + assert len(call_messages) == 1 + assert call_messages[0].get_value() == "hello" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_openai_chat_target_sends_normalized_to_construct_request(): + """Verify that the normalized (not original) conversation is used for the API body.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + user_msg = _make_message(role="user", content="original") + adapted_msg = _make_message(role="user", content="adapted") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + mock_completion = _create_mock_chat_completion("response") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + + with patch.object( + target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg] + ), patch.object(target, "_construct_request_body", new_callable=AsyncMock, return_value={"model": "gpt-4o", "messages": []}) as mock_construct: + await target.send_prompt_async(message=user_msg) + + # _construct_request_body should receive the adapted message, not the original + call_conv = mock_construct.call_args.kwargs["conversation"] + assert len(call_conv) == 1 + assert call_conv[0].get_value() == "adapted" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_openai_chat_target_memory_not_mutated(): + """Memory-backed conversation must not be altered by normalize_async.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_message(role="system", content="be nice") + user_msg = _make_message(role="user", content="hello") + + # Memory returns a conversation with a system message + memory_conversation: MutableSequence[Message] = [system_msg] + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = memory_conversation + target._memory = mock_memory + + mock_completion = _create_mock_chat_completion("response") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + + await target.send_prompt_async(message=user_msg) + + # Memory-backed conversation should still contain the system message + assert len(memory_conversation) == 2 # system + appended user + assert memory_conversation[0].get_piece().api_role == "system" + + +# --------------------------------------------------------------------------- +# OpenAIResponseTarget — normalize_async is called +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_openai_response_target_calls_normalize_async(): + target = OpenAIResponseTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + # Mock the API to return a simple response (no tool calls) + mock_response = MagicMock() + mock_response.error = None + mock_response.status = "completed" + mock_response.output = [MagicMock()] + mock_response.output[0].type = "message" + mock_response.output[0].content = [MagicMock()] + mock_response.output[0].content[0].type = "output_text" + mock_response.output[0].content[0].text = "world" + mock_response.model_dump_json.return_value = json.dumps({"output": [{"type": "message", "content": [{"type": "output_text", "text": "world"}]}]}) + target._async_client.responses.create = AsyncMock(return_value=mock_response) + + with patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize: + mock_normalize.return_value = [user_msg] + await target.send_prompt_async(message=user_msg) + + mock_normalize.assert_called_once() + call_messages = mock_normalize.call_args.kwargs["messages"] + assert len(call_messages) == 1 + + +# --------------------------------------------------------------------------- +# AzureMLChatTarget — normalize_async is called +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_azure_ml_target_calls_normalize_async(): + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + ) + + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + with patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize, \ + patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"): + mock_normalize.return_value = [user_msg] + await target.send_prompt_async(message=user_msg) + + mock_normalize.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_azure_ml_target_sends_normalized_to_complete_chat(): + """Normalized (not original) messages should be passed to _complete_chat_async.""" + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + ) + + user_msg = _make_message(role="user", content="original") + adapted_msg = _make_message(role="user", content="adapted") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + with patch.object( + target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg] + ), patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat: + await target.send_prompt_async(message=user_msg) + + call_messages = mock_chat.call_args.kwargs["messages"] + assert len(call_messages) == 1 + assert call_messages[0].get_value() == "adapted" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_azure_ml_target_memory_not_mutated(): + """Memory should retain original messages after normalization.""" + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_message(role="system", content="be nice") + user_msg = _make_message(role="user", content="hello") + + memory_conversation: MutableSequence[Message] = [system_msg] + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = memory_conversation + target._memory = mock_memory + + with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"): + await target.send_prompt_async(message=user_msg) + + # Memory must still have original system message + assert len(memory_conversation) == 2 + assert memory_conversation[0].get_piece().api_role == "system" + + +# --------------------------------------------------------------------------- +# AzureMLChatTarget — message_normalizer deprecation +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +def test_azure_ml_generic_system_squash_normalizer_emits_deprecation_warning(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + message_normalizer=GenericSystemSquashNormalizer(), + ) + deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + assert "message_normalizer is deprecated" in str(deprecation_warnings[0].message) + + +@pytest.mark.usefixtures("patch_central_database") +def test_azure_ml_generic_system_squash_normalizer_creates_adapt_configuration(): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + message_normalizer=GenericSystemSquashNormalizer(), + ) + # The configuration should now have supports_system_prompt=False with ADAPT policy + assert not target.capabilities.supports_system_prompt + # Pipeline should have a system squash normalizer + assert target.configuration.includes(capability=CapabilityName.MULTI_TURN) + assert not target.configuration.includes(capability=CapabilityName.SYSTEM_PROMPT) + + +@pytest.mark.usefixtures("patch_central_database") +def test_azure_ml_generic_system_squash_normalizer_does_not_override_explicit_config(): + """If custom_configuration is already provided, message_normalizer deprecation should not override it.""" + custom_config = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=True, + supports_multi_message_pieces=True, + ) + ) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + message_normalizer=GenericSystemSquashNormalizer(), + custom_configuration=custom_config, + ) + # Explicit custom_configuration should win + assert target.capabilities.supports_system_prompt + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_azure_ml_system_squash_via_configuration_pipeline(): + """End-to-end: GenericSystemSquashNormalizer-equivalent behavior via TargetConfiguration pipeline.""" + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_message(role="system", content="be nice") + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [system_msg] + target._memory = mock_memory + + with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat: + await target.send_prompt_async(message=user_msg) + + # _complete_chat_async should receive normalized messages (system squashed into user) + call_messages = mock_chat.call_args.kwargs["messages"] + roles = [m.get_piece().api_role for m in call_messages] + assert "system" not in roles + # The squashed message should contain the system content + assert "be nice" in call_messages[0].get_value() + assert "hello" in call_messages[0].get_value() diff --git a/tests/unit/prompt_target/target/test_openai_chat_target.py b/tests/unit/prompt_target/target/test_openai_chat_target.py index 7f05f8d83f..085a15fb91 100644 --- a/tests/unit/prompt_target/target/test_openai_chat_target.py +++ b/tests/unit/prompt_target/target/test_openai_chat_target.py @@ -322,6 +322,7 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j return_value=mock_completion ) target._memory = MagicMock(MemoryInterface) + target._memory.get_conversation.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -500,6 +501,7 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di return_value=mock_completion ) target._memory = MagicMock(MemoryInterface) + target._memory.get_conversation.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) diff --git a/tests/unit/prompt_target/target/test_openai_response_target.py b/tests/unit/prompt_target/target/test_openai_response_target.py index 507f8f0935..4e0527e0e1 100644 --- a/tests/unit/prompt_target/target/test_openai_response_target.py +++ b/tests/unit/prompt_target/target/test_openai_response_target.py @@ -343,6 +343,7 @@ async def test_send_prompt_async_empty_response_adds_to_memory( ): target._async_client.responses.create = AsyncMock(return_value=mock_response) # type: ignore[method-assign] target._memory = MagicMock(MemoryInterface) + target._memory.get_conversation.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -485,6 +486,7 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di ): target._async_client.responses.create = AsyncMock(return_value=mock_response) # type: ignore[method-assign] target._memory = MagicMock(MemoryInterface) + target._memory.get_conversation.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) diff --git a/tests/unit/prompt_target/target/test_openai_target_auth.py b/tests/unit/prompt_target/target/test_openai_target_auth.py index ac8444257a..834c6b863a 100644 --- a/tests/unit/prompt_target/target/test_openai_target_auth.py +++ b/tests/unit/prompt_target/target/test_openai_target_auth.py @@ -34,7 +34,7 @@ async def _construct_message_from_response(self, response, request): def _validate_request(self, *, message) -> None: pass - async def send_prompt_async(self, *, message): + async def _send_prompt_target_async(self, *, message, normalized_conversation): raise NotImplementedError diff --git a/tests/unit/prompt_target/target/test_target_capabilities.py b/tests/unit/prompt_target/target/test_target_capabilities.py index 7cd58793a2..9d4560e569 100644 --- a/tests/unit/prompt_target/target/test_target_capabilities.py +++ b/tests/unit/prompt_target/target/test_target_capabilities.py @@ -493,7 +493,9 @@ def _make_target_class(self, *, default_config: "TargetConfiguration"): class _ConcreteTarget(PromptTarget): _DEFAULT_CONFIGURATION = default_config - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: return [] return _ConcreteTarget diff --git a/tests/unit/prompt_target/test_prompt_chat_target.py b/tests/unit/prompt_target/test_prompt_chat_target.py index 872e5dec71..7dca1c743e 100644 --- a/tests/unit/prompt_target/test_prompt_chat_target.py +++ b/tests/unit/prompt_target/test_prompt_chat_target.py @@ -105,7 +105,9 @@ def test_init_subclass_promotes_default_capabilities_with_warning(): class _LegacyTarget(PromptTarget): _DEFAULT_CAPABILITIES = TargetCapabilities(supports_multi_turn=True) - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_target_async( + self, *, message: Message, normalized_conversation: list[Message] + ) -> list[Message]: return [] deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py index 503d096a38..01dfa2c0a8 100644 --- a/tests/unit/registry/test_target_registry.py +++ b/tests/unit/registry/test_target_registry.py @@ -17,10 +17,11 @@ class MockPromptTarget(PromptTarget): def __init__(self, *, model_name: str = "mock_model") -> None: super().__init__(model_name=model_name) - async def send_prompt_async( + async def _send_prompt_target_async( self, *, message: Message, + normalized_conversation: list[Message], ) -> list[Message]: return [ MessagePiece( @@ -39,10 +40,11 @@ class MockPromptChatTarget(PromptChatTarget): def __init__(self, *, model_name: str = "mock_chat_model", endpoint: str = "http://chat-test") -> None: super().__init__(model_name=model_name, endpoint=endpoint) - async def send_prompt_async( + async def _send_prompt_target_async( self, *, message: Message, + normalized_conversation: list[Message], ) -> list[Message]: return [ MessagePiece( From 738fa99e077d3e9a945a1294dbd985eeb6d814b0 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Tue, 14 Apr 2026 17:44:16 -0400 Subject: [PATCH 02/11] replace message with normalized conversation --- pyrit/prompt_target/azure_blob_storage_target.py | 4 ++-- pyrit/prompt_target/azure_ml_chat_target.py | 4 ++-- pyrit/prompt_target/common/prompt_target.py | 12 +++++------- pyrit/prompt_target/gandalf_target.py | 4 ++-- pyrit/prompt_target/http_target/http_target.py | 4 ++-- pyrit/prompt_target/http_target/httpx_api_target.py | 3 ++- .../hugging_face/hugging_face_chat_target.py | 4 ++-- .../hugging_face/hugging_face_endpoint_target.py | 5 ++--- pyrit/prompt_target/openai/openai_chat_target.py | 4 ++-- .../prompt_target/openai/openai_completion_target.py | 4 ++-- pyrit/prompt_target/openai/openai_image_target.py | 3 +-- pyrit/prompt_target/openai/openai_realtime_target.py | 4 ++-- pyrit/prompt_target/openai/openai_response_target.py | 4 ++-- pyrit/prompt_target/openai/openai_tts_target.py | 4 ++-- pyrit/prompt_target/openai/openai_video_target.py | 4 ++-- pyrit/prompt_target/playwright_copilot_target.py | 5 ++--- pyrit/prompt_target/playwright_target.py | 4 ++-- pyrit/prompt_target/prompt_shield_target.py | 4 ++-- pyrit/prompt_target/text_target.py | 4 ++-- pyrit/prompt_target/websocket_copilot_target.py | 4 ++-- tests/integration/mocks.py | 3 ++- .../test_prompt_target_contract.py | 2 +- tests/unit/mocks.py | 3 ++- .../prompt_target/target/test_openai_target_auth.py | 2 +- .../prompt_target/target/test_target_capabilities.py | 2 +- tests/unit/prompt_target/test_prompt_chat_target.py | 2 +- tests/unit/registry/test_target_registry.py | 2 -- 27 files changed, 50 insertions(+), 54 deletions(-) diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index 7142c4b9ca..a2a2bf62af 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -197,19 +197,19 @@ def _parse_url(self) -> tuple[str, str]: @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ (Async) Sends prompt to target, which creates a file and uploads it as a blob to the provided storage container. Args: - message (Message): A Message to be sent to the target. normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response with the Blob URL. """ + message = normalized_conversation[-1] request = message.message_pieces[0] # default file name is .txt, but can be overridden by prompt metadata diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 193783cb37..a0f86354c0 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -200,13 +200,12 @@ def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously send a message to the Azure ML chat target. Args: - message (Message): The message object containing the prompt to send. normalized_conversation (list[Message]): The normalized conversation history. Returns: @@ -217,6 +216,7 @@ async def _send_prompt_target_async( RateLimitException: If the target rate limit is exceeded. HTTPStatusError: For any other HTTP errors during the process. """ + message = normalized_conversation[-1] request = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {request}") diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index b0eb346b8c..b92b808ae2 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -117,8 +117,8 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: 1. Validates the request against the target's capabilities. 2. Fetches the conversation from memory, appends ``message``, and runs the normalization pipeline (system‑squash, history‑squash, etc.). - 3. Delegates to :meth:`_send_prompt_target_async` with both the original - message and the normalized conversation. + 3. Delegates to :meth:`_send_prompt_target_async` with the normalized + conversation. Subclasses MUST NOT override this method. Override :meth:`_send_prompt_target_async` instead. @@ -132,12 +132,12 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: self._validate_request(message=message) normalized_conversation = await self._get_normalized_conversation_async(message=message) return await self._send_prompt_target_async( - message=message, normalized_conversation=normalized_conversation + normalized_conversation=normalized_conversation ) @abc.abstractmethod async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Target-specific send logic. @@ -145,11 +145,9 @@ async def _send_prompt_target_async( Called by :meth:`send_prompt_async` after validation and normalization. Args: - message (Message): The original message (unmodified). normalized_conversation (list[Message]): The full conversation (history + current message) after running the normalization - pipeline. Single-turn targets may ignore this and use only - ``message``. + pipeline. The current message is the last element. Returns: list[Message]: Response messages from the target. diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index 2db5bbf8a7..a1bc38a681 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -86,18 +86,18 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously send a message to the Gandalf target. Args: - message (Message): The message object containing the prompt to send. normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. """ + message = normalized_conversation[-1] request = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {request}") diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index d354833f26..dc9459d55e 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -156,18 +156,18 @@ def _inject_prompt_into_request(self, request: MessagePiece) -> str: @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously send a message to the HTTP target. Args: - message (Message): The message object containing the prompt to send. normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. """ + message = normalized_conversation[-1] request = message.message_pieces[0] http_request_w_prompt = self._inject_prompt_into_request(request) diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index 4590321ba0..dfa48e1236 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -109,7 +109,7 @@ def __init__( @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Override the parent's method to skip raw http_request usage, @@ -127,6 +127,7 @@ async def _send_prompt_target_async( httpx.RequestError: If the request fails. FileNotFoundError: If the specified file to upload is not found. """ + message = normalized_conversation[-1] message_piece: MessagePiece = message.message_pieces[0] # If user didn't set file_path, see if the PDF path is in converted_value diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 2d343c70d3..d7fa18f102 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -296,13 +296,12 @@ async def load_model_and_tokenizer(self) -> None: @limit_requests_per_minute @pyrit_target_retry async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Send a normalized prompt asynchronously to the HuggingFace model. Args: - message (Message): The message to send. normalized_conversation (list[Message]): The normalized conversation history. Returns: @@ -315,6 +314,7 @@ async def _send_prompt_target_async( # Load the model and tokenizer using the encapsulated method await self.load_model_and_tokenizer_task + message = normalized_conversation[-1] request = message.message_pieces[0] prompt_template = request.converted_value diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index bdaced7098..f1c96a1588 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -88,14 +88,12 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Send a normalized prompt asynchronously to a cloud-based HuggingFace model endpoint. Args: - message (Message): The message containing the input data and associated details - such as conversation ID and role. normalized_conversation (list[Message]): The normalized conversation history. Returns: @@ -105,6 +103,7 @@ async def _send_prompt_target_async( ValueError: If the response from the Hugging Face API is not successful. Exception: If an error occurs during the HTTP request to the Hugging Face endpoint. """ + message = normalized_conversation[-1] request = message.message_pieces[0] headers = {"Authorization": f"Bearer {self.hf_token}"} payload: dict[str, object] = { diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index db2ac82f3d..1ca6f12d37 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -242,18 +242,18 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously sends a message and handles the response within a managed conversation context. Args: - message (Message): The message object. normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. """ + message = normalized_conversation[-1] message_piece: MessagePiece = message.message_pieces[0] json_config = self._get_json_response_config(message_piece=message_piece) diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 95653a5062..07740ea8ae 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -123,18 +123,18 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously send a message to the OpenAI completion target. Args: - message (Message): The message object containing the prompt to send. normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the prompt target. """ + message = normalized_conversation[-1] message_piece = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {message_piece}") diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index a9cd002ed9..8ee8c8a42a 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -151,7 +151,6 @@ def _build_identifier(self) -> ComponentIdentifier: async def _send_prompt_target_async( self, *, - message: Message, normalized_conversation: list[Message], ) -> list[Message]: """ @@ -159,12 +158,12 @@ async def _send_prompt_target_async( Supports both image generation (text input) and image editing (text + images input). Args: - message (Message): The message to send. normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the response from the image target. """ + message = normalized_conversation[-1] logger.info(f"Sending the following prompt to the prompt target: {message}") diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 8587bb938f..229685f752 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -337,13 +337,12 @@ def _get_system_prompt_from_conversation(self, *, conversation_id: str) -> str: @limit_requests_per_minute @pyrit_target_retry async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously send a message to the OpenAI realtime target. Args: - message (Message): The message object containing the prompt to send. normalized_conversation (list[Message]): The normalized conversation history. Returns: @@ -352,6 +351,7 @@ async def _send_prompt_target_async( Raises: ValueError: If the message piece type is unsupported. """ + message = normalized_conversation[-1] conversation_id = message.message_pieces[0].conversation_id if conversation_id not in self._existing_conversation: connection = await self.connect(conversation_id=conversation_id) diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 4ada55bac2..a15ad53641 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -519,7 +519,7 @@ async def _construct_message_from_response(self, response: Any, request: Message @limit_requests_per_minute @pyrit_target_retry async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Send prompt, handle agentic tool calls (function_call), return all messages. @@ -529,13 +529,13 @@ async def _send_prompt_target_async( - Agentic tool-calling loops that may require multiple back-and-forth exchanges Args: - message: The initial prompt from the user. normalized_conversation: The normalized conversation history. Returns: List of messages generated during the interaction (assistant responses and tool messages). The normalizer will persist all of these to memory. """ + message = normalized_conversation[-1] message_piece: MessagePiece = message.message_pieces[0] json_config = _JsonResponseConfig(enabled=False) if message.message_pieces: diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index fa2d353351..27d3137ebe 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -115,18 +115,18 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute @pyrit_target_retry async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously send a message to the OpenAI TTS target. Args: - message (Message): The message object containing the prompt to send. normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: A list containing the audio response from the prompt target. """ + message = normalized_conversation[-1] message_piece = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {message_piece}") diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 572e8de651..68fcbbae88 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -183,7 +183,7 @@ def _validate_duration(self) -> None: @limit_requests_per_minute @pyrit_target_retry async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously sends a message and generates a video using the OpenAI SDK. @@ -198,7 +198,6 @@ async def _send_prompt_target_async( chained remixes. Args: - message: The message object containing the prompt. normalized_conversation: The normalized conversation history. Returns: @@ -208,6 +207,7 @@ async def _send_prompt_target_async( RateLimitException: If the rate limit is exceeded. ValueError: If the request is invalid. """ + message = normalized_conversation[-1] text_piece = message.get_piece_by_type(data_type="text") diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 66687cd229..2329c66b8c 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -206,14 +206,12 @@ def _get_selectors(self) -> CopilotSelectors: ) async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Send a message to Microsoft Copilot and return the response. Args: - message (Message): The message to send. Can contain multiple pieces - of type 'text' or 'image_path'. normalized_conversation (list[Message]): The normalized conversation history. Returns: @@ -222,6 +220,7 @@ async def _send_prompt_target_async( Raises: RuntimeError: If an error occurs during interaction. """ + message = normalized_conversation[-1] try: response_content = await self._interact_with_copilot_async(message) diff --git a/pyrit/prompt_target/playwright_target.py b/pyrit/prompt_target/playwright_target.py index dbbb8a9518..757f8e88ad 100644 --- a/pyrit/prompt_target/playwright_target.py +++ b/pyrit/prompt_target/playwright_target.py @@ -101,13 +101,12 @@ def __init__( @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously send a message to the Playwright target. Args: - message (Message): The message object containing the prompt to send. normalized_conversation (list[Message]): The normalized conversation history. Returns: @@ -116,6 +115,7 @@ async def _send_prompt_target_async( Raises: RuntimeError: If the Playwright page is not initialized or if an error occurs during interaction. """ + message = normalized_conversation[-1] if not self._page: raise RuntimeError( "Playwright page is not initialized. Please pass a Page object when initializing PlaywrightTarget." diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 416639a8c6..8eb65eb4a4 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -124,7 +124,7 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Parse the text in message to separate the userPrompt and documents contents, @@ -134,7 +134,7 @@ async def _send_prompt_target_async( Returns: list[Message]: A list containing the response object with generated text pieces. """ - + message = normalized_conversation[-1] request = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {request}") diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 0ad126fb1a..6c766c7cbc 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -43,18 +43,18 @@ def __init__( self._text_stream = text_stream async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously write a message to the text stream. Args: - message (Message): The message object to write to the stream. normalized_conversation (list[Message]): The normalized conversation history. Returns: list[Message]: An empty list (no response expected). """ + message = normalized_conversation[-1] self._text_stream.write(f"{str(message)}\n") self._text_stream.flush() diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index cceb243437..a11033c167 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -645,7 +645,7 @@ def _generate_consistent_copilot_ids(self, *, pyrit_conversation_id: str) -> tup @limit_requests_per_minute @pyrit_target_retry async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: """ Asynchronously send a message to Microsoft Copilot using WebSocket. @@ -655,7 +655,6 @@ async def _send_prompt_target_async( state server-side, so only the current message is sent (no explicit history required). Args: - message (Message): A message to be sent to the target. normalized_conversation (list[Message]): The normalized conversation history. Returns: @@ -666,6 +665,7 @@ async def _send_prompt_target_async( InvalidStatus: If the WebSocket handshake fails with an HTTP status error. RuntimeError: If any other error occurs during WebSocket communication. """ + message = normalized_conversation[-1] pyrit_conversation_id = message.message_pieces[0].conversation_id is_start_of_session = self._is_start_of_session(conversation_id=pyrit_conversation_id) diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index 643bc3a66b..58c4d4b5be 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -67,8 +67,9 @@ def set_system_prompt( @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: + message = normalized_conversation[-1] self.prompt_sent.append(message.get_value()) return [ diff --git a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py index ae8050672c..885ffd9004 100644 --- a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py +++ b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py @@ -24,7 +24,7 @@ class _MinimalTarget(PromptTarget): """Minimal concrete PromptTarget for contract testing.""" async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: return [] diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 254132f1d0..30ce390f10 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -150,8 +150,9 @@ def set_system_prompt( @limit_requests_per_minute async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: + message = normalized_conversation[-1] self.prompt_sent.append(message.get_value()) return [ diff --git a/tests/unit/prompt_target/target/test_openai_target_auth.py b/tests/unit/prompt_target/target/test_openai_target_auth.py index 834c6b863a..5dfade7dce 100644 --- a/tests/unit/prompt_target/target/test_openai_target_auth.py +++ b/tests/unit/prompt_target/target/test_openai_target_auth.py @@ -34,7 +34,7 @@ async def _construct_message_from_response(self, response, request): def _validate_request(self, *, message) -> None: pass - async def _send_prompt_target_async(self, *, message, normalized_conversation): + async def _send_prompt_target_async(self, *, normalized_conversation): raise NotImplementedError diff --git a/tests/unit/prompt_target/target/test_target_capabilities.py b/tests/unit/prompt_target/target/test_target_capabilities.py index 9d4560e569..2e29c42cd9 100644 --- a/tests/unit/prompt_target/target/test_target_capabilities.py +++ b/tests/unit/prompt_target/target/test_target_capabilities.py @@ -494,7 +494,7 @@ class _ConcreteTarget(PromptTarget): _DEFAULT_CONFIGURATION = default_config async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: return [] diff --git a/tests/unit/prompt_target/test_prompt_chat_target.py b/tests/unit/prompt_target/test_prompt_chat_target.py index 7dca1c743e..3170a3c9db 100644 --- a/tests/unit/prompt_target/test_prompt_chat_target.py +++ b/tests/unit/prompt_target/test_prompt_chat_target.py @@ -106,7 +106,7 @@ class _LegacyTarget(PromptTarget): _DEFAULT_CAPABILITIES = TargetCapabilities(supports_multi_turn=True) async def _send_prompt_target_async( - self, *, message: Message, normalized_conversation: list[Message] + self, *, normalized_conversation: list[Message] ) -> list[Message]: return [] diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py index 01dfa2c0a8..89c614f898 100644 --- a/tests/unit/registry/test_target_registry.py +++ b/tests/unit/registry/test_target_registry.py @@ -20,7 +20,6 @@ def __init__(self, *, model_name: str = "mock_model") -> None: async def _send_prompt_target_async( self, *, - message: Message, normalized_conversation: list[Message], ) -> list[Message]: return [ @@ -43,7 +42,6 @@ def __init__(self, *, model_name: str = "mock_chat_model", endpoint: str = "http async def _send_prompt_target_async( self, *, - message: Message, normalized_conversation: list[Message], ) -> list[Message]: return [ From 34bd6563efad90ed356543738d94f18b6ce34dfb Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Tue, 14 Apr 2026 17:59:44 -0400 Subject: [PATCH 03/11] pre-commit --- .../azure_blob_storage_target.py | 4 +-- pyrit/prompt_target/azure_ml_chat_target.py | 4 +-- pyrit/prompt_target/common/prompt_target.py | 8 ++---- pyrit/prompt_target/gandalf_target.py | 4 +-- .../prompt_target/http_target/http_target.py | 4 +-- .../http_target/httpx_api_target.py | 4 +-- .../hugging_face/hugging_face_chat_target.py | 4 +-- .../hugging_face_endpoint_target.py | 4 +-- .../openai/openai_chat_target.py | 4 +-- .../openai/openai_completion_target.py | 4 +-- .../openai/openai_realtime_target.py | 4 +-- .../openai/openai_response_target.py | 4 +-- .../prompt_target/openai/openai_tts_target.py | 4 +-- .../openai/openai_video_target.py | 4 +-- .../playwright_copilot_target.py | 4 +-- pyrit/prompt_target/playwright_target.py | 4 +-- pyrit/prompt_target/prompt_shield_target.py | 4 +-- pyrit/prompt_target/text_target.py | 4 +-- .../prompt_target/websocket_copilot_target.py | 4 +-- tests/integration/mocks.py | 4 +-- .../test_prompt_target_contract.py | 4 +-- tests/unit/mocks.py | 4 +-- .../test_normalize_async_integration.py | 26 ++++++++++++------- .../target/test_target_capabilities.py | 4 +-- .../prompt_target/test_prompt_chat_target.py | 4 +-- 25 files changed, 42 insertions(+), 84 deletions(-) diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index a2a2bf62af..5e927a72db 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -196,9 +196,7 @@ def _parse_url(self) -> tuple[str, str]: return container_url, blob_prefix @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ (Async) Sends prompt to target, which creates a file and uploads it as a blob to the provided storage container. diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index a0f86354c0..cab3698339 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -199,9 +199,7 @@ def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str ) @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the Azure ML chat target. diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index b92b808ae2..8b90ad5970 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -131,14 +131,10 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: """ self._validate_request(message=message) normalized_conversation = await self._get_normalized_conversation_async(message=message) - return await self._send_prompt_target_async( - normalized_conversation=normalized_conversation - ) + return await self._send_prompt_target_async(normalized_conversation=normalized_conversation) @abc.abstractmethod - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Target-specific send logic. diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index a1bc38a681..2b80fba859 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -85,9 +85,7 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the Gandalf target. diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index dc9459d55e..b01103cb1a 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -155,9 +155,7 @@ def _inject_prompt_into_request(self, request: MessagePiece) -> str: return http_request_w_prompt @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the HTTP target. diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index dfa48e1236..3735ac8c4a 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -108,9 +108,7 @@ def __init__( raise ValueError(f"File uploads are not allowed with HTTP method: {self.method}") @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Override the parent's method to skip raw http_request usage, and do a standard "API mode" approach. diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index d7fa18f102..d1231b54d7 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -295,9 +295,7 @@ async def load_model_and_tokenizer(self) -> None: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send a normalized prompt asynchronously to the HuggingFace model. diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index f1c96a1588..a6e433dd49 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -87,9 +87,7 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send a normalized prompt asynchronously to a cloud-based HuggingFace model endpoint. diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 1ca6f12d37..b214da1799 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -241,9 +241,7 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously sends a message and handles the response within a managed conversation context. diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 07740ea8ae..2c447e6e26 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -122,9 +122,7 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the OpenAI completion target. diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 229685f752..51cd81129d 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -336,9 +336,7 @@ def _get_system_prompt_from_conversation(self, *, conversation_id: str) -> str: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the OpenAI realtime target. diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index a15ad53641..94da695bf7 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -518,9 +518,7 @@ async def _construct_message_from_response(self, response: Any, request: Message @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send prompt, handle agentic tool calls (function_call), return all messages. diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index 27d3137ebe..1bc8e49a81 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -114,9 +114,7 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the OpenAI TTS target. diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 68fcbbae88..9fefa981e8 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -182,9 +182,7 @@ def _validate_duration(self) -> None: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously sends a message and generates a video using the OpenAI SDK. diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 2329c66b8c..cbdfac80d4 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -205,9 +205,7 @@ def _get_selectors(self) -> CopilotSelectors: file_picker_selector='span.fui-MenuItem__content:has-text("Upload images and files")', ) - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send a message to Microsoft Copilot and return the response. diff --git a/pyrit/prompt_target/playwright_target.py b/pyrit/prompt_target/playwright_target.py index 757f8e88ad..eb64f16286 100644 --- a/pyrit/prompt_target/playwright_target.py +++ b/pyrit/prompt_target/playwright_target.py @@ -100,9 +100,7 @@ def __init__( self._page = page @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the Playwright target. diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 8eb65eb4a4..f26b9f97f4 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -123,9 +123,7 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Parse the text in message to separate the userPrompt and documents contents, then send an HTTP request to the endpoint and obtain a response in JSON. For more info, visit diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 6c766c7cbc..30b7fde435 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -42,9 +42,7 @@ def __init__( super().__init__(custom_configuration=custom_configuration, custom_capabilities=custom_capabilities) self._text_stream = text_stream - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously write a message to the text stream. diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index a11033c167..b9ed993918 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -644,9 +644,7 @@ def _generate_consistent_copilot_ids(self, *, pyrit_conversation_id: str) -> tup @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to Microsoft Copilot using WebSocket. diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index 58c4d4b5be..e222956cb0 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -66,9 +66,7 @@ def set_system_prompt( ) @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: message = normalized_conversation[-1] self.prompt_sent.append(message.get_value()) diff --git a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py index 885ffd9004..c031cbbb08 100644 --- a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py +++ b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py @@ -23,9 +23,7 @@ class _MinimalTarget(PromptTarget): """Minimal concrete PromptTarget for contract testing.""" - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: return [] def _validate_request(self, *, message) -> None: diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 30ce390f10..8c787d3ce0 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -149,9 +149,7 @@ def set_system_prompt( ) @limit_requests_per_minute - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: message = normalized_conversation[-1] self.prompt_sent.append(message.get_value()) diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 647287e367..9ebaa48cbe 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -107,9 +107,12 @@ async def test_openai_chat_target_sends_normalized_to_construct_request(): mock_completion = _create_mock_chat_completion("response") target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) - with patch.object( - target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg] - ), patch.object(target, "_construct_request_body", new_callable=AsyncMock, return_value={"model": "gpt-4o", "messages": []}) as mock_construct: + with ( + patch.object(target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg]), + patch.object( + target, "_construct_request_body", new_callable=AsyncMock, return_value={"model": "gpt-4o", "messages": []} + ) as mock_construct, + ): await target.send_prompt_async(message=user_msg) # _construct_request_body should receive the adapted message, not the original @@ -191,7 +194,9 @@ async def test_openai_response_target_calls_normalize_async(): mock_response.output[0].content = [MagicMock()] mock_response.output[0].content[0].type = "output_text" mock_response.output[0].content[0].text = "world" - mock_response.model_dump_json.return_value = json.dumps({"output": [{"type": "message", "content": [{"type": "output_text", "text": "world"}]}]}) + mock_response.model_dump_json.return_value = json.dumps( + {"output": [{"type": "message", "content": [{"type": "output_text", "text": "world"}]}]} + ) target._async_client.responses.create = AsyncMock(return_value=mock_response) with patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize: @@ -222,8 +227,10 @@ async def test_azure_ml_target_calls_normalize_async(): mock_memory.get_conversation.return_value = [] target._memory = mock_memory - with patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize, \ - patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"): + with ( + patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize, + patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"), + ): mock_normalize.return_value = [user_msg] await target.send_prompt_async(message=user_msg) @@ -246,9 +253,10 @@ async def test_azure_ml_target_sends_normalized_to_complete_chat(): mock_memory.get_conversation.return_value = [] target._memory = mock_memory - with patch.object( - target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg] - ), patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat: + with ( + patch.object(target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg]), + patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat, + ): await target.send_prompt_async(message=user_msg) call_messages = mock_chat.call_args.kwargs["messages"] diff --git a/tests/unit/prompt_target/target/test_target_capabilities.py b/tests/unit/prompt_target/target/test_target_capabilities.py index 2e29c42cd9..72b2fa3a46 100644 --- a/tests/unit/prompt_target/target/test_target_capabilities.py +++ b/tests/unit/prompt_target/target/test_target_capabilities.py @@ -493,9 +493,7 @@ def _make_target_class(self, *, default_config: "TargetConfiguration"): class _ConcreteTarget(PromptTarget): _DEFAULT_CONFIGURATION = default_config - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: return [] return _ConcreteTarget diff --git a/tests/unit/prompt_target/test_prompt_chat_target.py b/tests/unit/prompt_target/test_prompt_chat_target.py index 3170a3c9db..562450b077 100644 --- a/tests/unit/prompt_target/test_prompt_chat_target.py +++ b/tests/unit/prompt_target/test_prompt_chat_target.py @@ -105,9 +105,7 @@ def test_init_subclass_promotes_default_capabilities_with_warning(): class _LegacyTarget(PromptTarget): _DEFAULT_CAPABILITIES = TargetCapabilities(supports_multi_turn=True) - async def _send_prompt_target_async( - self, *, normalized_conversation: list[Message] - ) -> list[Message]: + async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: return [] deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] From 858c94103ec8a9ee2ec0e8895a3233eb2287a397 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 15 Apr 2026 14:23:15 -0400 Subject: [PATCH 04/11] update deprecation message --- pyrit/prompt_target/azure_ml_chat_target.py | 10 ++++------ .../common/conversation_normalization_pipeline.py | 7 ++++--- pyrit/prompt_target/common/target_configuration.py | 6 ++++-- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index cab3698339..547711e1eb 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -85,7 +85,7 @@ def __init__( Used for identification purposes. Defaults to empty string. message_normalizer (MessageListNormalizer, Optional): **Deprecated.** Use ``custom_configuration`` with ``CapabilityHandlingPolicy`` instead. Previously used for - models that do not allow system prompts. Defaults to ChatMessageNormalizer(). + models that do not allow system prompts. Will be removed in v0.14.0. max_new_tokens (int, Optional): The maximum number of tokens to generate in the response. Defaults to 400. @@ -114,12 +114,9 @@ def __init__( ) # Translate legacy message_normalizer into TargetConfiguration - if message_normalizer is not None and isinstance(message_normalizer, GenericSystemSquashNormalizer): + if message_normalizer is not None: warnings.warn( - "Passing GenericSystemSquashNormalizer as message_normalizer is deprecated. " - "Use custom_configuration=TargetConfiguration(capabilities=TargetCapabilities(" - "supports_system_prompt=False), policy=CapabilityHandlingPolicy(behaviors={" - "CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT})) instead. " + "Passing message_normalizer is deprecated. Use custom_configuration with CapabilityHandlingPolicy instead. " "Will be removed in v0.14.0.", DeprecationWarning, stacklevel=2, @@ -138,6 +135,7 @@ def __init__( CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, } ), + normalizer_overrides={CapabilityName.SYSTEM_PROMPT: MessageListNormalizer([ChatMessageNormalizer()])} ) PromptChatTarget.__init__( diff --git a/pyrit/prompt_target/common/conversation_normalization_pipeline.py b/pyrit/prompt_target/common/conversation_normalization_pipeline.py index 3b6fd0f90c..6c50b78084 100644 --- a/pyrit/prompt_target/common/conversation_normalization_pipeline.py +++ b/pyrit/prompt_target/common/conversation_normalization_pipeline.py @@ -2,6 +2,8 @@ # Licensed under the MIT license. import logging +from collections.abc import Mapping +from typing import Any from pyrit.message_normalizer import ( GenericSystemSquashNormalizer, @@ -63,7 +65,7 @@ def from_capabilities( *, capabilities: TargetCapabilities, policy: CapabilityHandlingPolicy, - normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None, + normalizer_overrides: Mapping[CapabilityName, MessageListNormalizer[Any]] | None = None, ) -> "ConversationNormalizationPipeline": """ Resolve capabilities and policy into a concrete pipeline of normalizers. @@ -80,7 +82,7 @@ def from_capabilities( Args: capabilities (TargetCapabilities): The target's declared capabilities. policy (CapabilityHandlingPolicy): How to handle each missing capability. - normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None): + normalizer_overrides (Mapping[CapabilityName, MessageListNormalizer[Any]] | None): Optional overrides for specific capability normalizers. Falls back to the defaults from ``_NORMALIZER_REGISTRY``. @@ -103,7 +105,6 @@ def from_capabilities( # workflow is implemented. if behavior == UnsupportedCapabilityBehavior.ADAPT: normalizer = overrides.get(capability, default_normalizer) - normalizers.append(normalizer) return cls(normalizers=tuple(normalizers)) diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py index 0daabd7fdb..c7784996ed 100644 --- a/pyrit/prompt_target/common/target_configuration.py +++ b/pyrit/prompt_target/common/target_configuration.py @@ -3,6 +3,8 @@ import logging import warnings +from collections.abc import Mapping +from typing import Any from pyrit.message_normalizer import MessageListNormalizer from pyrit.models import Message @@ -84,7 +86,7 @@ def __init__( *, capabilities: TargetCapabilities, policy: CapabilityHandlingPolicy | None = None, - normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None, + normalizer_overrides: Mapping[CapabilityName, MessageListNormalizer[Any]] | None = None, ) -> None: """ Build a target configuration and resolve the normalization pipeline. @@ -93,7 +95,7 @@ def __init__( capabilities (TargetCapabilities): The target's declared capabilities. policy (CapabilityHandlingPolicy | None): How to handle each missing capability. Defaults to RAISE for all adaptable capabilities. - normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None): + normalizer_overrides (Mapping[CapabilityName, MessageListNormalizer[Any]] | None): Optional overrides for specific capability normalizers. """ self._capabilities = capabilities From 709b517d18ec770d35451c94f72e847d0a4f82ef Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 15 Apr 2026 18:56:22 -0400 Subject: [PATCH 05/11] pr comments --- .../azure_blob_storage_target.py | 6 +- pyrit/prompt_target/azure_ml_chat_target.py | 98 +++++----- pyrit/prompt_target/common/prompt_target.py | 87 ++++----- pyrit/prompt_target/gandalf_target.py | 6 +- .../prompt_target/http_target/http_target.py | 6 +- .../http_target/httpx_api_target.py | 2 +- .../hugging_face/hugging_face_chat_target.py | 6 +- .../hugging_face_endpoint_target.py | 11 +- .../openai/openai_chat_target.py | 6 +- .../openai/openai_completion_target.py | 6 +- .../openai/openai_image_target.py | 11 +- .../openai/openai_realtime_target.py | 6 +- .../openai/openai_response_target.py | 6 +- .../prompt_target/openai/openai_tts_target.py | 6 +- .../openai/openai_video_target.py | 13 +- .../playwright_copilot_target.py | 6 +- pyrit/prompt_target/playwright_target.py | 6 +- pyrit/prompt_target/prompt_shield_target.py | 2 +- pyrit/prompt_target/text_target.py | 8 +- .../prompt_target/websocket_copilot_target.py | 13 +- tests/integration/mocks.py | 4 +- .../test_prompt_target_contract.py | 4 +- tests/unit/mocks.py | 4 +- .../target/test_huggingface_chat_target.py | 2 +- .../prompt_target/target/test_image_target.py | 4 +- .../test_normalize_async_integration.py | 167 ++++++++++++++++-- .../target/test_openai_chat_target.py | 2 +- .../target/test_openai_response_target.py | 6 +- .../target/test_openai_target_auth.py | 4 +- .../target/test_playwright_copilot_target.py | 6 +- .../target/test_playwright_target.py | 10 +- .../target/test_realtime_target.py | 2 +- .../target/test_target_capabilities.py | 2 +- .../prompt_target/target/test_tts_target.py | 4 +- .../prompt_target/target/test_video_target.py | 29 ++- .../target/test_websocket_copilot_target.py | 8 +- .../prompt_target/test_prompt_chat_target.py | 2 +- tests/unit/registry/test_target_registry.py | 8 +- 38 files changed, 383 insertions(+), 196 deletions(-) diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index 5e927a72db..71664ed88a 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -196,13 +196,15 @@ def _parse_url(self) -> tuple[str, str]: return container_url, blob_prefix @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ (Async) Sends prompt to target, which creates a file and uploads it as a blob to the provided storage container. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response with the Blob URL. diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 547711e1eb..ba92008369 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -3,7 +3,7 @@ import logging import warnings -from typing import Any, Optional +from typing import Any from httpx import HTTPStatusError @@ -15,7 +15,7 @@ pyrit_target_retry, ) from pyrit.identifiers import ComponentIdentifier -from pyrit.message_normalizer import ChatMessageNormalizer, GenericSystemSquashNormalizer, MessageListNormalizer +from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer from pyrit.models import ( Message, construct_response_from_request, @@ -60,48 +60,48 @@ class AzureMLChatTarget(PromptChatTarget): def __init__( self, *, - endpoint: Optional[str] = None, - api_key: Optional[str] = None, + endpoint: str | None = None, + api_key: str | None = None, model_name: str = "", - message_normalizer: Optional[MessageListNormalizer[Any]] = None, + message_normalizer: MessageListNormalizer[Any] | None = None, max_new_tokens: int = 400, temperature: float = 1.0, top_p: float = 1.0, repetition_penalty: float = 1.0, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, - custom_capabilities: Optional[TargetCapabilities] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, + custom_capabilities: TargetCapabilities | None = None, **param_kwargs: Any, ) -> None: """ Initialize an instance of the AzureMLChatTarget class. Args: - endpoint (str, Optional): The endpoint URL for the deployed Azure ML model. + endpoint (str | None): The endpoint URL for the deployed Azure ML model. Defaults to the value of the AZURE_ML_MANAGED_ENDPOINT environment variable. - api_key (str, Optional): The API key for accessing the Azure ML endpoint. + api_key (str | None): The API key for accessing the Azure ML endpoint. Defaults to the value of the `AZURE_ML_KEY` environment variable. - model_name (str, Optional): The name of the model being used (e.g., "Llama-3.2-3B-Instruct"). + model_name (str): The name of the model being used (e.g., "Llama-3.2-3B-Instruct"). Used for identification purposes. Defaults to empty string. - message_normalizer (MessageListNormalizer, Optional): **Deprecated.** Use + message_normalizer (MessageListNormalizer[Any] | None): **Deprecated.** Use ``custom_configuration`` with ``CapabilityHandlingPolicy`` instead. Previously used for models that do not allow system prompts. Will be removed in v0.14.0. - max_new_tokens (int, Optional): The maximum number of tokens to generate in the response. + max_new_tokens (int): The maximum number of tokens to generate in the response. Defaults to 400. - temperature (float, Optional): The temperature for generating diverse responses. 1.0 is most random, + temperature (float): The temperature for generating diverse responses. 1.0 is most random, 0.0 is least random. Defaults to 1.0. - top_p (float, Optional): The top-p value for generating diverse responses. It represents + top_p (float): The top-p value for generating diverse responses. It represents the cumulative probability of the top tokens to keep. Defaults to 1.0. - repetition_penalty (float, Optional): The repetition penalty for generating diverse responses. + repetition_penalty (float): The repetition penalty for generating diverse responses. 1.0 means no penalty with a greater value (up to 2.0) meaning more penalty for repeating tokens. Defaults to 1.2. - max_requests_per_minute (int, Optional): Number of requests the target can handle per + max_requests_per_minute (int | None): Number of requests the target can handle per minute before hitting a rate limit. The number of requests sent to the target will be capped at the value provided. - custom_configuration (TargetConfiguration, Optional): Override the default configuration for this target + custom_configuration (TargetConfiguration | None): Override the default configuration for this target instance. Useful for targets whose capabilities depend on deployment configuration. - custom_capabilities (TargetCapabilities, Optional): **Deprecated.** Use + custom_capabilities (TargetCapabilities | None): **Deprecated.** Use ``custom_configuration`` instead. Will be removed in v0.14.0. **param_kwargs: Additional parameters to pass to the model for generating responses. Example parameters can be found here: https://huggingface.co/docs/api-inference/tasks/text-generation. @@ -115,28 +115,35 @@ def __init__( # Translate legacy message_normalizer into TargetConfiguration if message_normalizer is not None: + if custom_configuration is not None: + raise ValueError( + "Cannot specify both 'message_normalizer' and 'custom_configuration'. " + "Use 'custom_configuration' only; 'message_normalizer' is deprecated and " + "will be removed in v0.14.0." + ) warnings.warn( - "Passing message_normalizer is deprecated. Use custom_configuration with CapabilityHandlingPolicy instead. " - "Will be removed in v0.14.0.", + "Passing message_normalizer is deprecated. Use custom_configuration with " + "CapabilityHandlingPolicy instead. Will be removed in v0.14.0.", DeprecationWarning, stacklevel=2, ) - if custom_configuration is None: - custom_configuration = TargetConfiguration( - capabilities=TargetCapabilities( - supports_multi_message_pieces=True, - supports_editable_history=True, - supports_multi_turn=True, - supports_system_prompt=False, - ), - policy=CapabilityHandlingPolicy( - behaviors={ - CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, - } - ), - normalizer_overrides={CapabilityName.SYSTEM_PROMPT: MessageListNormalizer([ChatMessageNormalizer()])} - ) + # The legacy message_normalizer was primarily used to handle system prompts + # for models that don't support them (e.g. GenericSystemSquashNormalizer). + # We translate it into a TargetConfiguration that marks system_prompt as + # unsupported + ADAPT so the pipeline invokes the user's normalizer. + default_caps = self._DEFAULT_CONFIGURATION.capabilities + default_behaviors = dict(self._DEFAULT_CONFIGURATION.policy.behaviors) + default_behaviors[CapabilityName.SYSTEM_PROMPT] = UnsupportedCapabilityBehavior.ADAPT + custom_configuration = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_message_pieces=default_caps.supports_multi_message_pieces, + supports_editable_history=default_caps.supports_editable_history, + supports_multi_turn=default_caps.supports_multi_turn, + supports_system_prompt=False, + ), + policy=CapabilityHandlingPolicy(behaviors=default_behaviors), + normalizer_overrides={CapabilityName.SYSTEM_PROMPT: message_normalizer}, + ) PromptChatTarget.__init__( self, @@ -152,7 +159,7 @@ def __init__( validate_temperature(temperature) validate_top_p(top_p) - self.message_normalizer = message_normalizer if message_normalizer is not None else ChatMessageNormalizer() + self.message_normalizer = message_normalizer self._max_new_tokens = max_new_tokens self._temperature = temperature self._top_p = top_p @@ -172,11 +179,10 @@ def _build_identifier(self) -> ComponentIdentifier: "top_p": self._top_p, "max_new_tokens": self._max_new_tokens, "repetition_penalty": self._repetition_penalty, - "message_normalizer": self.message_normalizer.__class__.__name__, }, ) - def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str] = None) -> None: + def _initialize_vars(self, endpoint: str | None = None, api_key: str | None = None) -> None: """ Set the endpoint and key for accessing the Azure ML model. Use this function to manually pass in your own endpoint uri and api key. Defaults to the values in the .env file for the variables @@ -197,12 +203,14 @@ def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str ) @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the Azure ML chat target. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. @@ -288,8 +296,8 @@ async def _construct_http_body_async( Returns: dict: The constructed HTTP request body. """ - # Use the message normalizer to convert Messages to dict format - messages_dict = await self.message_normalizer.normalize_to_dicts_async(messages) + wire_format = ChatMessageNormalizer() + messages_dict = await wire_format.normalize_to_dicts_async(messages) # Parameters include additional ones passed in through **kwargs. Those not accepted by the model will # be ignored. We only include commonly supported parameters here - model-specific parameters like @@ -321,5 +329,5 @@ def _get_headers(self) -> dict[str, str]: return headers - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 015a4dd6e5..48f66437ea 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -4,7 +4,7 @@ import abc import logging import warnings -from typing import Any, Optional, Union +from typing import Any, Union, final from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.memory import CentralMemory, MemoryInterface @@ -29,7 +29,7 @@ class PromptTarget(Identifiable): # An empty list implies that the prompt target supports all converters. supported_converters: list[Any] - _identifier: Optional[ComponentIdentifier] = None + _identifier: ComponentIdentifier | None = None # Class-level default configuration for this target type. # @@ -63,30 +63,30 @@ def __init_subclass__(cls, **kwargs: Any) -> None: def __init__( self, verbose: bool = False, - max_requests_per_minute: Optional[int] = None, + max_requests_per_minute: int | None = None, endpoint: str = "", model_name: str = "", - underlying_model: Optional[str] = None, - custom_configuration: Optional[TargetConfiguration] = None, - custom_capabilities: Optional[TargetCapabilities] = None, + underlying_model: str | None = None, + custom_configuration: TargetConfiguration | None = None, + custom_capabilities: TargetCapabilities | None = None, ) -> None: """ Initialize the PromptTarget. Args: verbose (bool): Enable verbose logging. Defaults to False. - max_requests_per_minute (int, Optional): Maximum number of requests per minute. + max_requests_per_minute (int | None): Maximum number of requests per minute. endpoint (str): The endpoint URL. Defaults to empty string. model_name (str): The model name. Defaults to empty string. - underlying_model (str, Optional): The underlying model name (e.g., "gpt-4o") for + underlying_model (str | None): The underlying model name (e.g., "gpt-4o") for identification purposes. This is useful when the deployment name in Azure differs from the actual model. If not provided, ``model_name`` will be used for the identifier. Defaults to None. - custom_configuration (TargetConfiguration, Optional): Override the default configuration + custom_configuration (TargetConfiguration | None): Override the default configuration for this target instance. Useful for targets whose capabilities depend on deployment configuration (e.g., Playwright, HTTP). If None, uses the class-level ``_DEFAULT_CONFIGURATION``. Defaults to None. - custom_capabilities (TargetCapabilities, Optional): **Deprecated.** Use + custom_capabilities (TargetCapabilities | None): **Deprecated.** Use ``custom_configuration`` instead. Will be removed in v0.14.0. """ custom_configuration = resolve_configuration_compat( @@ -108,20 +108,21 @@ def __init__( if self._verbose: logging.basicConfig(level=logging.INFO) + @final async def send_prompt_async(self, *, message: Message) -> list[Message]: """ Validate, normalize, and send a prompt to the target. This is the public entry point called by the prompt normalizer. It: - 1. Validates the request against the target's capabilities. - 2. Fetches the conversation from memory, appends ``message``, and runs + 1. Fetches the conversation from memory, appends ``message``, and runs the normalization pipeline (system‑squash, history‑squash, etc.). - 3. Delegates to :meth:`_send_prompt_target_async` with the normalized + 2. Validates the normalized conversation against the target's capabilities. + 3. Delegates to :meth:`_send_prompt_to_target_async` with the normalized conversation. Subclasses MUST NOT override this method. Override - :meth:`_send_prompt_target_async` instead. + :meth:`_send_prompt_to_target_async` instead. Args: message (Message): The message to send. @@ -129,12 +130,16 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Returns: list[Message]: Response messages from the target. """ - self._validate_request(message=message) + if not message.message_pieces: + raise ValueError("Message must contain at least one message piece. Received: 0 pieces.") normalized_conversation = await self._get_normalized_conversation_async(message=message) - return await self._send_prompt_target_async(normalized_conversation=normalized_conversation) + if not normalized_conversation: + raise ValueError("Normalization pipeline returned an empty conversation. Cannot send an empty request.") + self._validate_request(normalized_conversation=normalized_conversation) + return await self._send_prompt_to_target_async(normalized_conversation=normalized_conversation) @abc.abstractmethod - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Target-specific send logic. @@ -149,22 +154,24 @@ async def _send_prompt_target_async(self, *, normalized_conversation: list[Messa list[Message]: Response messages from the target. """ - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ - Validate the provided message. + Validate the normalized conversation before sending to the target. + + Called after the normalization pipeline has run. Validates the last + message (the current request) for piece count, data types, and checks + whether the full conversation violates multi-turn constraints. Args: - message: The message to validate. + normalized_conversation: The normalized conversation to validate. + The last element is the current request message. Raises: - ValueError: if the target does not support the provided message pieces or if the message - violates any constraints based on the target's capabilities. This includes checks - for the number of message pieces, supported data types, and multi-turn conversation support. - + ValueError: if the target does not support the provided message pieces or if the + conversation violates any constraints based on the target's capabilities. """ + message = normalized_conversation[-1] n_pieces = len(message.message_pieces) - if n_pieces == 0: - raise ValueError("Message must contain at least one message piece. Received: 0 pieces.") custom_configuration_message = ( "If your target does support this, set the custom_configuration parameter accordingly." @@ -185,14 +192,10 @@ def _validate_request(self, *, message: Message) -> None: f"{custom_configuration_message}" ) - if not self.capabilities.supports_multi_turn: - request = message.message_pieces[0] - messages = self._memory.get_message_pieces(conversation_id=request.conversation_id) - - if len(messages) > 0: - raise ValueError( - f"This target only supports a single turn conversation. {custom_configuration_message}" - ) + if not self.capabilities.supports_multi_turn and len(normalized_conversation) > 1: + raise ValueError( + f"This target only supports a single turn conversation. {custom_configuration_message}" + ) async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: """ @@ -210,9 +213,9 @@ async def _get_normalized_conversation_async(self, *, message: Message) -> list[ history squashed, etc.). """ conversation_id = message.message_pieces[0].conversation_id - conversation = self._memory.get_conversation(conversation_id=conversation_id) + conversation = list(self._memory.get_conversation(conversation_id=conversation_id)) conversation.append(message) - return await self.configuration.normalize_async(messages=list(conversation)) + return await self.configuration.normalize_async(messages=conversation) def set_model_name(self, *, model_name: str) -> None: """ @@ -232,8 +235,8 @@ def dispose_db_engine(self) -> None: def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] | None = None, ) -> ComponentIdentifier: """ Construct the target identifier. @@ -246,9 +249,9 @@ def _create_identifier( to set the identifier with their specific parameters. Args: - params (Optional[Dict[str, Any]]): Additional behavioral parameters from + params (dict[str, Any] | None): Additional behavioral parameters from the subclass (e.g., temperature, top_p). Merged into the base params. - children (Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]]): + children (dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] | None): Named child component identifiers. Returns: @@ -294,7 +297,7 @@ def capabilities(self) -> TargetCapabilities: return self._configuration.capabilities @classmethod - def get_default_configuration(cls, underlying_model: Optional[str] = None) -> TargetConfiguration: + def get_default_configuration(cls, underlying_model: str | None = None) -> TargetConfiguration: """ Return the configuration for the given underlying model, falling back to the class-level ``_DEFAULT_CONFIGURATION`` when the model is not recognized. @@ -319,7 +322,7 @@ def get_default_configuration(cls, underlying_model: Optional[str] = None) -> Ta return cls._DEFAULT_CONFIGURATION @classmethod - def get_default_capabilities(cls, underlying_model: Optional[str] = None) -> TargetCapabilities: + def get_default_capabilities(cls, underlying_model: str | None = None) -> TargetCapabilities: """ Return the default capabilities for the given model. diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index 2b80fba859..99ca1c3b7d 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -85,12 +85,14 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the Gandalf target. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index b01103cb1a..2a44024c36 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -155,12 +155,14 @@ def _inject_prompt_into_request(self, request: MessagePiece) -> str: return http_request_w_prompt @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the HTTP target. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index 3735ac8c4a..ae5e8989c1 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -108,7 +108,7 @@ def __init__( raise ValueError(f"File uploads are not allowed with HTTP method: {self.method}") @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Override the parent's method to skip raw http_request usage, and do a standard "API mode" approach. diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index d1231b54d7..7583a5efe4 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -295,12 +295,14 @@ async def load_model_and_tokenizer(self) -> None: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send a normalized prompt asynchronously to the HuggingFace model. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response object with generated text pieces. diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index a6e433dd49..da47fa34a2 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -87,12 +87,14 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send a normalized prompt asynchronously to a cloud-based HuggingFace model endpoint. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response object with generated text pieces. @@ -145,16 +147,17 @@ async def _send_prompt_target_async(self, *, normalized_conversation: list[Messa logger.error(f"Error occurred during HTTP request to the Hugging Face endpoint: {e}") raise - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validate the provided message. Args: - message (Message): The message to validate. + normalized_conversation: The normalized conversation to validate. Raises: ValueError: If the request is not valid for this target. """ + message = normalized_conversation[-1] n_pieces = len(message.message_pieces) if n_pieces != 1: raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index c7f5cd00b5..021b0d2e0a 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -240,12 +240,14 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously sends a message and handles the response within a managed conversation context. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 9924486639..c7a851d729 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -121,12 +121,14 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the OpenAI completion target. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 5ad3aed9db..7fb98da604 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -147,7 +147,7 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async( + async def _send_prompt_to_target_async( self, *, normalized_conversation: list[Message], @@ -157,7 +157,9 @@ async def _send_prompt_target_async( Supports both image generation (text input) and image editing (text + images input). Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the image target. @@ -320,8 +322,9 @@ async def _get_image_bytes(self, image_data: Any) -> bytes: raise EmptyResponseException(message="The image generation returned an empty response.") - def _validate_request(self, *, message: Message) -> None: - super()._validate_request(message=message) + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + super()._validate_request(normalized_conversation=normalized_conversation) + message = normalized_conversation[-1] text_pieces = [p for p in message.message_pieces if p.converted_value_data_type == "text"] image_pieces = [p for p in message.message_pieces if p.converted_value_data_type == "image_path"] diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 96be320da6..dfe3d52b30 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -335,12 +335,14 @@ def _get_system_prompt_from_conversation(self, *, conversation_id: str) -> str: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the OpenAI realtime target. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index f02328ecad..45e1435cbe 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -517,7 +517,7 @@ async def _construct_message_from_response(self, response: Any, request: Message @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send prompt, handle agentic tool calls (function_call), return all messages. @@ -526,7 +526,9 @@ async def _send_prompt_target_async(self, *, normalized_conversation: list[Messa - Agentic tool-calling loops that may require multiple back-and-forth exchanges Args: - normalized_conversation: The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: List of messages generated during the interaction (assistant responses and tool messages). diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index 32d44a686d..3013bfb963 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -113,12 +113,14 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the OpenAI TTS target. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the audio response from the prompt target. diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 3f6a1a8eb2..7d7737d5bc 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -181,7 +181,7 @@ def _validate_duration(self) -> None: @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously sends a message and generates a video using the OpenAI SDK. @@ -195,7 +195,9 @@ async def _send_prompt_target_async(self, *, normalized_conversation: list[Messa chained remixes. Args: - normalized_conversation: The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: A list containing the response with the generated video path. @@ -458,7 +460,7 @@ async def _save_video_response( prompt_metadata=prompt_metadata, ) - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validate the request message. @@ -468,12 +470,13 @@ def _validate_request(self, *, message: Message) -> None: - Text piece + video_path piece (remix mode via history lookup) Args: - message: The message to validate. + normalized_conversation: The normalized conversation to validate. Raises: ValueError: If the request is invalid. """ - super()._validate_request(message=message) + super()._validate_request(normalized_conversation=normalized_conversation) + message = normalized_conversation[-1] text_pieces = message.get_pieces_by_type(data_type="text") image_pieces = message.get_pieces_by_type(data_type="image_path") diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index cbdfac80d4..58d30f088b 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -205,12 +205,14 @@ def _get_selectors(self) -> CopilotSelectors: file_picker_selector='span.fui-MenuItem__content:has-text("Upload images and files")', ) - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send a message to Microsoft Copilot and return the response. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from Copilot. diff --git a/pyrit/prompt_target/playwright_target.py b/pyrit/prompt_target/playwright_target.py index eb64f16286..c77bf41178 100644 --- a/pyrit/prompt_target/playwright_target.py +++ b/pyrit/prompt_target/playwright_target.py @@ -100,12 +100,14 @@ def __init__( self._page = page @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the Playwright target. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index f26b9f97f4..9fb9b7d30f 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -123,7 +123,7 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Parse the text in message to separate the userPrompt and documents contents, then send an HTTP request to the endpoint and obtain a response in JSON. For more info, visit diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 30b7fde435..0c29769219 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -42,12 +42,14 @@ def __init__( super().__init__(custom_configuration=custom_configuration, custom_capabilities=custom_capabilities) self._text_stream = text_stream - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously write a message to the text stream. Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: An empty list (no response expected). @@ -95,7 +97,7 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: self._memory.add_message_pieces_to_memory(message_pieces=message_pieces) return message_pieces - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass async def cleanup_target(self) -> None: diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index b9ed993918..67b7977107 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -584,17 +584,18 @@ async def _connect_and_send( return response - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validate that the message meets target requirements. Args: - message (Message): The message to validate. + normalized_conversation (list[Message]): The normalized conversation to validate. Raises: ValueError: If message contains unsupported data types or invalid image formats. """ - super()._validate_request(message=message) + super()._validate_request(normalized_conversation=normalized_conversation) + message = normalized_conversation[-1] for piece in message.message_pieces: piece_type = piece.converted_value_data_type @@ -644,7 +645,7 @@ def _generate_consistent_copilot_ids(self, *, pyrit_conversation_id: str) -> tup @limit_requests_per_minute @pyrit_target_retry - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to Microsoft Copilot using WebSocket. @@ -653,7 +654,9 @@ async def _send_prompt_target_async(self, *, normalized_conversation: list[Messa state server-side, so only the current message is sent (no explicit history required). Args: - normalized_conversation (list[Message]): The normalized conversation history. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from Copilot. diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index e222956cb0..dd0e187bac 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -66,7 +66,7 @@ def set_system_prompt( ) @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: message = normalized_conversation[-1] self.prompt_sent.append(message.get_value()) @@ -80,7 +80,7 @@ async def _send_prompt_target_async(self, *, normalized_conversation: list[Messa ).to_message() ] - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validates the provided message """ diff --git a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py index c031cbbb08..07e0393556 100644 --- a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py +++ b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py @@ -23,10 +23,10 @@ class _MinimalTarget(PromptTarget): """Minimal concrete PromptTarget for contract testing.""" - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: return [] - def _validate_request(self, *, message) -> None: + def _validate_request(self, *, normalized_conversation) -> None: pass diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 8c787d3ce0..84850dfc5d 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -149,7 +149,7 @@ def set_system_prompt( ) @limit_requests_per_minute - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: message = normalized_conversation[-1] self.prompt_sent.append(message.get_value()) @@ -163,7 +163,7 @@ async def _send_prompt_target_async(self, *, normalized_conversation: list[Messa ).to_message() ] - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validates the provided message """ diff --git a/tests/unit/prompt_target/target/test_huggingface_chat_target.py b/tests/unit/prompt_target/target/test_huggingface_chat_target.py index 6ce2a388ec..8148c278dd 100644 --- a/tests/unit/prompt_target/target/test_huggingface_chat_target.py +++ b/tests/unit/prompt_target/target/test_huggingface_chat_target.py @@ -232,7 +232,7 @@ async def test_invalid_prompt_request_validation(): message = Message(message_pieces=[message_piece1, message_piece2]) with pytest.raises(ValueError) as excinfo: - hf_chat._validate_request(message=message) + hf_chat._validate_request(normalized_conversation=[message]) assert "This target only supports a single message piece." in str(excinfo.value) diff --git a/tests/unit/prompt_target/target/test_image_target.py b/tests/unit/prompt_target/target/test_image_target.py index 0d495535c3..8f31b515ce 100644 --- a/tests/unit/prompt_target/target/test_image_target.py +++ b/tests/unit/prompt_target/target/test_image_target.py @@ -518,8 +518,10 @@ async def test_validate_previous_conversations( ): message_piece = sample_conversations[0] + prior_message = Message(message_pieces=[message_piece]) + mock_memory = MagicMock() - mock_memory.get_message_pieces.return_value = sample_conversations + mock_memory.get_conversation.return_value = [prior_message] mock_memory.add_message_to_memory = AsyncMock() image_target._memory = mock_memory diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 9ebaa48cbe..14265c24e2 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -160,8 +160,8 @@ async def test_openai_chat_target_memory_not_mutated(): await target.send_prompt_async(message=user_msg) - # Memory-backed conversation should still contain the system message - assert len(memory_conversation) == 2 # system + appended user + # Memory-backed conversation must not be mutated by send_prompt_async + assert len(memory_conversation) == 1 assert memory_conversation[0].get_piece().api_role == "system" @@ -299,8 +299,8 @@ async def test_azure_ml_target_memory_not_mutated(): with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"): await target.send_prompt_async(message=user_msg) - # Memory must still have original system message - assert len(memory_conversation) == 2 + # Memory must still have original system message only (not mutated) + assert len(memory_conversation) == 1 assert memory_conversation[0].get_piece().api_role == "system" @@ -325,6 +325,7 @@ def test_azure_ml_generic_system_squash_normalizer_emits_deprecation_warning(): @pytest.mark.usefixtures("patch_central_database") def test_azure_ml_generic_system_squash_normalizer_creates_adapt_configuration(): + """Legacy message_normalizer should be translated into a TargetConfiguration with ADAPT policy.""" with warnings.catch_warnings(record=True): warnings.simplefilter("always") target = AzureMLChatTarget( @@ -332,16 +333,15 @@ def test_azure_ml_generic_system_squash_normalizer_creates_adapt_configuration() api_key="valid_api_key", message_normalizer=GenericSystemSquashNormalizer(), ) - # The configuration should now have supports_system_prompt=False with ADAPT policy + # The shim should create a config with supports_system_prompt=False assert not target.capabilities.supports_system_prompt - # Pipeline should have a system squash normalizer assert target.configuration.includes(capability=CapabilityName.MULTI_TURN) assert not target.configuration.includes(capability=CapabilityName.SYSTEM_PROMPT) @pytest.mark.usefixtures("patch_central_database") -def test_azure_ml_generic_system_squash_normalizer_does_not_override_explicit_config(): - """If custom_configuration is already provided, message_normalizer deprecation should not override it.""" +def test_azure_ml_message_normalizer_and_custom_config_raises(): + """Passing both message_normalizer and custom_configuration should raise ValueError.""" custom_config = TargetConfiguration( capabilities=TargetCapabilities( supports_multi_turn=True, @@ -349,16 +349,13 @@ def test_azure_ml_generic_system_squash_normalizer_does_not_override_explicit_co supports_multi_message_pieces=True, ) ) - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - target = AzureMLChatTarget( + with pytest.raises(ValueError, match="Cannot specify both"): + AzureMLChatTarget( endpoint="http://aml-test-endpoint.com", api_key="valid_api_key", message_normalizer=GenericSystemSquashNormalizer(), custom_configuration=custom_config, ) - # Explicit custom_configuration should win - assert target.capabilities.supports_system_prompt @pytest.mark.asyncio @@ -401,3 +398,147 @@ async def test_azure_ml_system_squash_via_configuration_pipeline(): # The squashed message should contain the system content assert "be nice" in call_messages[0].get_value() assert "hello" in call_messages[0].get_value() + + +# --------------------------------------------------------------------------- +# _get_normalized_conversation_async — unit tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_fetches_history_and_appends_message(): + """The method should fetch history from memory, append the current message, and return them.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + history_msg = _make_message(role="assistant", content="previous answer") + user_msg = _make_message(role="user", content="new question") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [history_msg] + target._memory = mock_memory + + result = await target._get_normalized_conversation_async(message=user_msg) + + mock_memory.get_conversation.assert_called_once_with(conversation_id="conv1") + assert len(result) == 2 + assert result[0].get_value() == "previous answer" + assert result[1].get_value() == "new question" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_empty_history(): + """When memory has no history, the result should contain only the current message.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + result = await target._get_normalized_conversation_async(message=user_msg) + + assert len(result) == 1 + assert result[0].get_value() == "hello" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_does_not_mutate_memory(): + """The original memory-backed list must not be modified by the method.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + history_msg = _make_message(role="assistant", content="old") + user_msg = _make_message(role="user", content="new") + + memory_list: MutableSequence[Message] = [history_msg] + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = memory_list + target._memory = mock_memory + + await target._get_normalized_conversation_async(message=user_msg) + + # Memory list must still have only the original message + assert len(memory_list) == 1 + assert memory_list[0].get_value() == "old" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_runs_pipeline(): + """The method should invoke the normalization pipeline on the assembled conversation.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_message(role="system", content="be helpful") + user_msg = _make_message(role="user", content="hi") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [system_msg] + target._memory = mock_memory + + result = await target._get_normalized_conversation_async(message=user_msg) + + # System-squash normalizer should merge system into user + assert len(result) == 1 + assert "be helpful" in result[0].get_value() + assert "hi" in result[0].get_value() + roles = [m.get_piece().api_role for m in result] + assert "system" not in roles + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_passthrough_when_no_adaptation_needed(): + """When the target supports all capabilities, the pipeline should pass messages through unchanged.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + system_msg = _make_message(role="system", content="be nice") + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [system_msg] + target._memory = mock_memory + + result = await target._get_normalized_conversation_async(message=user_msg) + + # No adaptation — messages pass through as-is + assert len(result) == 2 + assert result[0].get_piece().api_role == "system" + assert result[0].get_value() == "be nice" + assert result[1].get_piece().api_role == "user" + assert result[1].get_value() == "hello" diff --git a/tests/unit/prompt_target/target/test_openai_chat_target.py b/tests/unit/prompt_target/target/test_openai_chat_target.py index 3ab7b0d7ad..1b3c7da16e 100644 --- a/tests/unit/prompt_target/target/test_openai_chat_target.py +++ b/tests/unit/prompt_target/target/test_openai_chat_target.py @@ -583,7 +583,7 @@ def test_validate_request_unsupported_data_types(target: OpenAIChatTarget): ) with pytest.raises(ValueError) as excinfo: - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) assert "This target supports only the following data types" in str(excinfo.value), ( "Error not raised for unsupported data types" diff --git a/tests/unit/prompt_target/target/test_openai_response_target.py b/tests/unit/prompt_target/target/test_openai_response_target.py index c3b73d01aa..0bd517ae4d 100644 --- a/tests/unit/prompt_target/target/test_openai_response_target.py +++ b/tests/unit/prompt_target/target/test_openai_response_target.py @@ -586,7 +586,7 @@ def test_validate_request_unsupported_data_types(target: OpenAIResponseTarget): ) with pytest.raises(ValueError) as excinfo: - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) assert "This target supports only the following data types" in str(excinfo.value), ( "Error not raised for unsupported data types" @@ -659,7 +659,7 @@ def test_validate_request_allows_text_and_image(target: OpenAIResponseTarget): ), ] ) - target._validate_request(message=req) + target._validate_request(normalized_conversation=[req]) def test_validate_request_raises_for_invalid_type(target: OpenAIResponseTarget): @@ -669,7 +669,7 @@ def test_validate_request_raises_for_invalid_type(target: OpenAIResponseTarget): ] ) with pytest.raises(ValueError) as excinfo: - target._validate_request(message=req) + target._validate_request(normalized_conversation=[req]) assert "This target supports only the following data types" in str(excinfo.value) diff --git a/tests/unit/prompt_target/target/test_openai_target_auth.py b/tests/unit/prompt_target/target/test_openai_target_auth.py index 5dfade7dce..38102323df 100644 --- a/tests/unit/prompt_target/target/test_openai_target_auth.py +++ b/tests/unit/prompt_target/target/test_openai_target_auth.py @@ -31,10 +31,10 @@ def _get_provider_examples(self) -> dict[str, str]: async def _construct_message_from_response(self, response, request): raise NotImplementedError - def _validate_request(self, *, message) -> None: + def _validate_request(self, *, normalized_conversation) -> None: pass - async def _send_prompt_target_async(self, *, normalized_conversation): + async def _send_prompt_to_target_async(self, *, normalized_conversation): raise NotImplementedError diff --git a/tests/unit/prompt_target/target/test_playwright_copilot_target.py b/tests/unit/prompt_target/target/test_playwright_copilot_target.py index 8a5aaf1d9c..7e19cdc9ad 100644 --- a/tests/unit/prompt_target/target/test_playwright_copilot_target.py +++ b/tests/unit/prompt_target/target/test_playwright_copilot_target.py @@ -155,7 +155,7 @@ def test_validate_request_unsupported_type(self, mock_page): match=r"This target supports only the following data types.*If your target does support this, set the" r" custom_configuration parameter accordingly", ): - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_valid_text(self, mock_page, text_request_piece): """Test validation with valid text request.""" @@ -163,14 +163,14 @@ def test_validate_request_valid_text(self, mock_page, text_request_piece): request = Message(message_pieces=[text_request_piece]) # Should not raise any exception - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_valid_multimodal(self, mock_page, multimodal_request): """Test validation with valid multimodal request.""" target = PlaywrightCopilotTarget(page=mock_page) # Should not raise any exception - target._validate_request(message=multimodal_request) + target._validate_request(normalized_conversation=[multimodal_request]) @pytest.mark.asyncio async def test_send_text_async(self, mock_page): diff --git a/tests/unit/prompt_target/target/test_playwright_target.py b/tests/unit/prompt_target/target/test_playwright_target.py index 4df935f09e..f91021ccbc 100644 --- a/tests/unit/prompt_target/target/test_playwright_target.py +++ b/tests/unit/prompt_target/target/test_playwright_target.py @@ -129,7 +129,7 @@ def test_validate_request_unsupported_type(self, mock_interaction_func, mock_pag match=r"This target supports only the following data types.*If your target does support this, set the" r" custom_configuration parameter accordingly", ): - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_valid_text(self, mock_interaction_func, mock_page, text_message_piece): """Test validation with valid text request.""" @@ -137,7 +137,7 @@ def test_validate_request_valid_text(self, mock_interaction_func, mock_page, tex request = Message(message_pieces=[text_message_piece]) # Should not raise any exception - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_valid_image(self, mock_interaction_func, mock_page, image_message_piece): """Test validation with valid image request.""" @@ -145,7 +145,7 @@ def test_validate_request_valid_image(self, mock_interaction_func, mock_page, im request = Message(message_pieces=[image_message_piece]) # Should not raise any exception - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_mixed_valid_types( self, mock_interaction_func, mock_page, text_message_piece, image_message_piece @@ -155,7 +155,7 @@ def test_validate_request_mixed_valid_types( request = Message(message_pieces=[text_message_piece, image_message_piece]) # Should not raise any exception - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) @pytest.mark.asyncio async def test_send_prompt_async_single_text(self, mock_interaction_func, mock_page, text_message_piece): @@ -345,7 +345,7 @@ def test_validate_request_multiple_unsupported_types(self, mock_interaction_func match=r"This target supports only the following data types.*If your target does support this, set the" r" custom_configuration parameter accordingly", ): - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) @pytest.mark.asyncio async def test_interaction_function_with_complex_response(self, mock_page): diff --git a/tests/unit/prompt_target/target/test_realtime_target.py b/tests/unit/prompt_target/target/test_realtime_target.py index 88f047722f..9dcb5c577f 100644 --- a/tests/unit/prompt_target/target/test_realtime_target.py +++ b/tests/unit/prompt_target/target/test_realtime_target.py @@ -185,7 +185,7 @@ async def test_send_prompt_async_invalid_request(target): ) message = Message(message_pieces=[message_piece]) with pytest.raises(ValueError) as excinfo: - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) assert "This target supports only the following data types" in str(excinfo.value) assert "image_path" in str(excinfo.value) diff --git a/tests/unit/prompt_target/target/test_target_capabilities.py b/tests/unit/prompt_target/target/test_target_capabilities.py index 72b2fa3a46..5c34221fce 100644 --- a/tests/unit/prompt_target/target/test_target_capabilities.py +++ b/tests/unit/prompt_target/target/test_target_capabilities.py @@ -493,7 +493,7 @@ def _make_target_class(self, *, default_config: "TargetConfiguration"): class _ConcreteTarget(PromptTarget): _DEFAULT_CONFIGURATION = default_config - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: return [] return _ConcreteTarget diff --git a/tests/unit/prompt_target/target/test_tts_target.py b/tests/unit/prompt_target/target/test_tts_target.py index ec595b55b4..a9b1f5832a 100644 --- a/tests/unit/prompt_target/target/test_tts_target.py +++ b/tests/unit/prompt_target/target/test_tts_target.py @@ -90,8 +90,10 @@ async def test_tts_validate_previous_conversations( ): message_piece = sample_conversations[0] + prior_message = Message(message_pieces=[message_piece]) + mock_memory = MagicMock() - mock_memory.get_message_pieces.return_value = sample_conversations + mock_memory.get_conversation.return_value = [prior_message] mock_memory.add_message_to_memory = AsyncMock() tts_target._memory = mock_memory diff --git a/tests/unit/prompt_target/target/test_video_target.py b/tests/unit/prompt_target/target/test_video_target.py index faba7830a0..250f6d5d48 100644 --- a/tests/unit/prompt_target/target/test_video_target.py +++ b/tests/unit/prompt_target/target/test_video_target.py @@ -81,7 +81,7 @@ def test_video_validate_request_multiple_text_pieces(video_target: OpenAIVideoTa msg2 = MessagePiece( role="user", original_value="test2", converted_value="test2", conversation_id=conversation_id ) - video_target._validate_request(message=Message([msg1, msg2])) + video_target._validate_request(normalized_conversation=[Message([msg1, msg2])]) def test_video_validate_prompt_type_image_only(video_target: OpenAIVideoTarget): @@ -90,7 +90,7 @@ def test_video_validate_prompt_type_image_only(video_target: OpenAIVideoTarget): msg = MessagePiece( role="user", original_value="test", converted_value="test", converted_value_data_type="image_path" ) - video_target._validate_request(message=Message([msg])) + video_target._validate_request(normalized_conversation=[Message([msg])]) @pytest.mark.asyncio @@ -374,7 +374,7 @@ def test_validate_accepts_text_only(self, video_target: OpenAIVideoTarget): """Test validation accepts single text piece (text-to-video mode).""" msg = MessagePiece(role="user", original_value="test prompt", converted_value="test prompt") # Should not raise - video_target._validate_request(message=Message([msg])) + video_target._validate_request(normalized_conversation=[Message([msg])]) def test_validate_accepts_text_and_image(self, video_target: OpenAIVideoTarget): """Test validation accepts text + image (image-to-video mode).""" @@ -393,7 +393,7 @@ def test_validate_accepts_text_and_image(self, video_target: OpenAIVideoTarget): conversation_id=conversation_id, ) # Should not raise - video_target._validate_request(message=Message([msg_text, msg_image])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_image])]) def test_validate_rejects_multiple_images(self, video_target: OpenAIVideoTarget): """Test validation rejects multiple image pieces.""" @@ -419,7 +419,7 @@ def test_validate_rejects_multiple_images(self, video_target: OpenAIVideoTarget) conversation_id=conversation_id, ) with pytest.raises(ValueError, match="at most 1 image piece"): - video_target._validate_request(message=Message([msg_text, msg_img1, msg_img2])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_img1, msg_img2])]) def test_validate_rejects_unsupported_types(self, video_target: OpenAIVideoTarget): """Test validation rejects unsupported data types.""" @@ -442,7 +442,7 @@ def test_validate_rejects_unsupported_types(self, video_target: OpenAIVideoTarge match="This target supports only the following data types.*If your target does support this, set the" " custom_configuration parameter accordingly", ): - video_target._validate_request(message=Message([msg_text, msg_audio])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_audio])]) def test_validate_rejects_remix_with_image(self, video_target: OpenAIVideoTarget): """Test validation rejects remix mode combined with image input.""" @@ -462,7 +462,7 @@ def test_validate_rejects_remix_with_image(self, video_target: OpenAIVideoTarget conversation_id=conversation_id, ) with pytest.raises(ValueError, match="Cannot use image input in remix mode"): - video_target._validate_request(message=Message([msg_text, msg_image])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_image])]) @pytest.mark.usefixtures("patch_central_database") @@ -789,7 +789,7 @@ def test_validate_rejects_no_text_piece(self, video_target: OpenAIVideoTarget): converted_value_data_type="image_path", ) with pytest.raises(ValueError, match="Expected exactly 1 text piece"): - video_target._validate_request(message=Message([msg])) + video_target._validate_request(normalized_conversation=[Message([msg])]) @pytest.mark.asyncio async def test_image_to_video_with_jpeg(self, video_target: OpenAIVideoTarget): @@ -993,12 +993,7 @@ def test_video_validate_previous_conversations( ): message_piece = sample_conversations[0] - mock_memory = MagicMock() - mock_memory.get_message_pieces.return_value = sample_conversations - mock_memory.add_message_to_memory = AsyncMock() - - video_target._memory = mock_memory - + prior_message = Message(message_pieces=[message_piece]) request = Message(message_pieces=[message_piece]) with pytest.raises( @@ -1006,7 +1001,7 @@ def test_video_validate_previous_conversations( match="This target only supports a single turn conversation.*If your target does support this, set the" " custom_configuration parameter accordingly", ): - video_target._validate_request(message=request) + video_target._validate_request(normalized_conversation=[prior_message, request]) @pytest.mark.usefixtures("patch_central_database") @@ -1039,7 +1034,7 @@ def test_validate_accepts_text_and_video_path(self, video_target: OpenAIVideoTar conversation_id=conversation_id, ) # Should not raise - video_target._validate_request(message=Message([msg_text, msg_video])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_video])]) def test_validate_rejects_video_path_and_image_path(self, video_target: OpenAIVideoTarget) -> None: """Test validation rejects combining video_path and image_path.""" @@ -1065,7 +1060,7 @@ def test_validate_rejects_video_path_and_image_path(self, video_target: OpenAIVi conversation_id=conversation_id, ) with pytest.raises(ValueError, match="Cannot combine video_path and image_path"): - video_target._validate_request(message=Message([msg_text, msg_video, msg_image])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_video, msg_image])]) def test_remix_keeps_video_path_pieces_when_ids_match(self, video_target: OpenAIVideoTarget) -> None: """Test that video_path pieces are preserved after validation so normalizer stores them.""" diff --git a/tests/unit/prompt_target/target/test_websocket_copilot_target.py b/tests/unit/prompt_target/target/test_websocket_copilot_target.py index c6b6ce6c01..227c113529 100644 --- a/tests/unit/prompt_target/target/test_websocket_copilot_target.py +++ b/tests/unit/prompt_target/target/test_websocket_copilot_target.py @@ -743,19 +743,19 @@ def test_validate_request_data_types(self, mock_authenticator, make_message_piec message = Message(message_pieces=[message_piece]) if should_pass: - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) else: with pytest.raises( ValueError, match=f"This target supports only the following data types: image_path, text. Received: {data_type}.", ): - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) def test_validate_request_with_multiple_text_pieces(self, mock_authenticator, make_message_piece): target = WebSocketCopilotTarget(authenticator=mock_authenticator) message_pieces = [make_message_piece(f"test{i}", conversation_id="123") for i in range(3)] message = Message(message_pieces=message_pieces) - target._validate_request(message=message) # should not raise + target._validate_request(normalized_conversation=[message]) # should not raise def test_validate_request_with_mixed_valid_content(self, mock_authenticator, make_message_piece): target = WebSocketCopilotTarget(authenticator=mock_authenticator) @@ -764,7 +764,7 @@ def test_validate_request_with_mixed_valid_content(self, mock_authenticator, mak make_message_piece("/path/to/image.png", data_type="image_path", conversation_id="123"), ] message = Message(message_pieces=message_pieces) - target._validate_request(message=message) # should not raise + target._validate_request(normalized_conversation=[message]) # should not raise @pytest.mark.usefixtures("patch_central_database") diff --git a/tests/unit/prompt_target/test_prompt_chat_target.py b/tests/unit/prompt_target/test_prompt_chat_target.py index 562450b077..3ffcb6341a 100644 --- a/tests/unit/prompt_target/test_prompt_chat_target.py +++ b/tests/unit/prompt_target/test_prompt_chat_target.py @@ -105,7 +105,7 @@ def test_init_subclass_promotes_default_capabilities_with_warning(): class _LegacyTarget(PromptTarget): _DEFAULT_CAPABILITIES = TargetCapabilities(supports_multi_turn=True) - async def _send_prompt_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: return [] deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py index 89c614f898..815a893884 100644 --- a/tests/unit/registry/test_target_registry.py +++ b/tests/unit/registry/test_target_registry.py @@ -17,7 +17,7 @@ class MockPromptTarget(PromptTarget): def __init__(self, *, model_name: str = "mock_model") -> None: super().__init__(model_name=model_name) - async def _send_prompt_target_async( + async def _send_prompt_to_target_async( self, *, normalized_conversation: list[Message], @@ -29,7 +29,7 @@ async def _send_prompt_target_async( ).to_message() ] - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass @@ -39,7 +39,7 @@ class MockPromptChatTarget(PromptChatTarget): def __init__(self, *, model_name: str = "mock_chat_model", endpoint: str = "http://chat-test") -> None: super().__init__(model_name=model_name, endpoint=endpoint) - async def _send_prompt_target_async( + async def _send_prompt_to_target_async( self, *, normalized_conversation: list[Message], @@ -51,7 +51,7 @@ async def _send_prompt_target_async( ).to_message() ] - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass From dbd8207a5254b201dccaf969b06c7d4310150b92 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 15 Apr 2026 19:07:52 -0400 Subject: [PATCH 06/11] add normalized conversation and pre-commit --- pyrit/prompt_target/common/prompt_target.py | 4 +-- .../openai/openai_realtime_target.py | 33 +++++++++++-------- .../target/test_realtime_target.py | 24 ++++++-------- 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 48f66437ea..ec9e826bd0 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -193,9 +193,7 @@ def _validate_request(self, *, normalized_conversation: list[Message]) -> None: ) if not self.capabilities.supports_multi_turn and len(normalized_conversation) > 1: - raise ValueError( - f"This target only supports a single turn conversation. {custom_configuration_message}" - ) + raise ValueError(f"This target only supports a single turn conversation. {custom_configuration_message}") async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: """ diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index dfe3d52b30..b9ad968bef 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -297,33 +297,32 @@ def _set_system_prompt_and_config_vars(self, system_prompt: str) -> dict[str, An return session_config - async def send_config(self, conversation_id: str) -> None: + async def send_config(self, *, conversation_id: str, conversation: list[Message]) -> None: """ Send the session configuration using OpenAI client. Args: conversation_id (str): Conversation ID + conversation (list[Message]): The conversation history to extract the system prompt from. """ # Extract system prompt from conversation history - system_prompt = self._get_system_prompt_from_conversation(conversation_id=conversation_id) + system_prompt = self._get_system_prompt_from_conversation(conversation=conversation) config_variables = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) connection = self._get_connection(conversation_id=conversation_id) await connection.session.update(session=config_variables) logger.info("Session configuration sent") - def _get_system_prompt_from_conversation(self, *, conversation_id: str) -> str: + def _get_system_prompt_from_conversation(self, *, conversation: list[Message]) -> str: """ Retrieve the system prompt from conversation history. Args: - conversation_id (str): The conversation ID + conversation (list[Message]): The conversation messages to search. Returns: str: The system prompt from conversation history, or a default if none found """ - conversation = self._memory.get_conversation(conversation_id=conversation_id) - # Look for a system message at the beginning of the conversation if conversation and len(conversation) > 0: first_message = conversation[0] @@ -357,7 +356,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me self._existing_conversation[conversation_id] = connection # Only send config when creating a new connection - await self.send_config(conversation_id=conversation_id) + await self.send_config(conversation_id=conversation_id, conversation=normalized_conversation) # Give the server a moment to process the session update await asyncio.sleep(0.5) @@ -367,12 +366,14 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me # Order of messages sent varies based on the data format of the prompt if response_type == "audio_path": output_audio_path, result = await self.send_audio_async( - filename=request.converted_value, conversation_id=conversation_id + filename=request.converted_value, conversation_id=conversation_id, + conversation=normalized_conversation, ) elif response_type == "text": output_audio_path, result = await self.send_text_async( - text=request.converted_value, conversation_id=conversation_id + text=request.converted_value, conversation_id=conversation_id, + conversation=normalized_conversation, ) else: raise ValueError(f"Unsupported response type: {response_type}") @@ -669,13 +670,16 @@ def _extract_error_details(*, response: Any) -> str: return f"[{error_type}] {error_message}" return "Unknown error occurred" - async def send_text_async(self, text: str, conversation_id: str) -> tuple[str, RealtimeTargetResult]: + async def send_text_async( + self, text: str, conversation_id: str, conversation: list[Message], + ) -> tuple[str, RealtimeTargetResult]: """ Send text prompt using OpenAI Realtime API client. Args: text: prompt to send. conversation_id: conversation ID + conversation: The normalized conversation history. Returns: Tuple[str, RealtimeTargetResult]: Path to saved audio file and the RealtimeTargetResult @@ -715,7 +719,7 @@ async def send_text_async(self, text: str, conversation_id: str) -> tuple[str, R self._existing_conversation[conversation_id] = new_connection # Send session configuration to new connection - system_prompt = self._get_system_prompt_from_conversation(conversation_id=conversation_id) + system_prompt = self._get_system_prompt_from_conversation(conversation=conversation) session_config = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) await new_connection.session.update(session=session_config) @@ -723,13 +727,16 @@ async def send_text_async(self, text: str, conversation_id: str) -> tuple[str, R output_audio_path = await self.save_audio(audio_bytes=result.audio_bytes, sample_rate=24000) return output_audio_path, result - async def send_audio_async(self, filename: str, conversation_id: str) -> tuple[str, RealtimeTargetResult]: + async def send_audio_async( + self, filename: str, conversation_id: str, conversation: list[Message], + ) -> tuple[str, RealtimeTargetResult]: """ Send an audio message using OpenAI Realtime API client. Args: filename (str): The path to the audio file. conversation_id (str): Conversation ID + conversation (list[Message]): The normalized conversation history. Returns: Tuple[str, RealtimeTargetResult]: Path to saved audio file and the RealtimeTargetResult @@ -784,7 +791,7 @@ async def send_audio_async(self, filename: str, conversation_id: str) -> tuple[s self._existing_conversation[conversation_id] = new_connection # Send session configuration to new connection - system_prompt = self._get_system_prompt_from_conversation(conversation_id=conversation_id) + system_prompt = self._get_system_prompt_from_conversation(conversation=conversation) session_config = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) await new_connection.session.update(session=session_config) diff --git a/tests/unit/prompt_target/target/test_realtime_target.py b/tests/unit/prompt_target/target/test_realtime_target.py index 9dcb5c577f..bc66b773fb 100644 --- a/tests/unit/prompt_target/target/test_realtime_target.py +++ b/tests/unit/prompt_target/target/test_realtime_target.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest @@ -64,6 +64,7 @@ async def test_send_prompt_async(target): target.send_text_async.assert_called_once_with( text="Hello", conversation_id="test_conversation_id", + conversation=ANY, ) assert response[0].get_value() == "hello" assert response[0].get_value(1) == "output.wav" @@ -75,23 +76,21 @@ async def test_send_prompt_async(target): @pytest.mark.asyncio async def test_get_system_prompt_from_conversation_with_system_message(target): """Test that system prompt is extracted from conversation history when present.""" - conversation_id = "test_conversation_with_system" - # Add a system message to memory + # Create a system message system_message = Message( message_pieces=[ MessagePiece( role="system", original_value="You are a helpful assistant specialized in security.", converted_value="You are a helpful assistant specialized in security.", - conversation_id=conversation_id, + conversation_id="test_conversation_with_system", ) ] ) - target._memory.add_message_to_memory(request=system_message) # Get the system prompt - system_prompt = target._get_system_prompt_from_conversation(conversation_id=conversation_id) + system_prompt = target._get_system_prompt_from_conversation(conversation=[system_message]) assert system_prompt == "You are a helpful assistant specialized in security." @@ -99,23 +98,21 @@ async def test_get_system_prompt_from_conversation_with_system_message(target): @pytest.mark.asyncio async def test_get_system_prompt_from_conversation_default(target): """Test that default system prompt is returned when no system message in conversation.""" - conversation_id = "test_conversation_no_system" - # Add a user message (no system message) + # Create a user message (no system message) user_message = Message( message_pieces=[ MessagePiece( role="user", original_value="Hello", converted_value="Hello", - conversation_id=conversation_id, + conversation_id="test_conversation_no_system", ) ] ) - target._memory.add_message_to_memory(request=user_message) # Get the system prompt - system_prompt = target._get_system_prompt_from_conversation(conversation_id=conversation_id) + system_prompt = target._get_system_prompt_from_conversation(conversation=[user_message]) assert system_prompt == "You are a helpful AI assistant" @@ -123,10 +120,9 @@ async def test_get_system_prompt_from_conversation_default(target): @pytest.mark.asyncio async def test_get_system_prompt_empty_conversation(target): """Test that default system prompt is returned for empty conversation.""" - conversation_id = "test_empty_conversation" - # Get the system prompt without adding any messages - system_prompt = target._get_system_prompt_from_conversation(conversation_id=conversation_id) + # Get the system prompt without any messages + system_prompt = target._get_system_prompt_from_conversation(conversation=[]) assert system_prompt == "You are a helpful AI assistant" From 9f7ed82010051eb08eaa4f32834b24919a9500ad Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Thu, 16 Apr 2026 11:07:18 -0400 Subject: [PATCH 07/11] pr comments & pre commit --- pyrit/prompt_target/azure_ml_chat_target.py | 5 ++++- pyrit/prompt_target/common/prompt_target.py | 3 +++ .../openai/openai_realtime_target.py | 18 ++++++++++++++---- .../openai/openai_response_target.py | 6 ++---- .../target/test_normalize_async_integration.py | 7 ++++++- .../target/test_openai_target_auth.py | 1 - 6 files changed, 29 insertions(+), 11 deletions(-) diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index ba92008369..fee748556c 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -108,6 +108,10 @@ def __init__( Note that the link above may not be comprehensive, and specific acceptable parameters may be model-dependent. If a model does not accept a certain parameter that is passed in, it will be skipped without throwing an error. + + Raises: + ValueError: If both `message_normalizer` and `custom_configuration` are provided, + since `message_normalizer` is deprecated and the two configurations may conflict. """ endpoint_value = default_values.get_required_value( env_var_name=self.endpoint_uri_environment_variable, passed_value=endpoint @@ -159,7 +163,6 @@ def __init__( validate_temperature(temperature) validate_top_p(top_p) - self.message_normalizer = message_normalizer self._max_new_tokens = max_new_tokens self._temperature = temperature self._top_p = top_p diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index ec9e826bd0..2fa280eabc 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -129,6 +129,9 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Returns: list[Message]: Response messages from the target. + + Raises: + ValueError: If the message or normalized conversation are empty. """ if not message.message_pieces: raise ValueError("Message must contain at least one message piece. Received: 0 pieces.") diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index b9ad968bef..f298def2aa 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -366,13 +366,15 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me # Order of messages sent varies based on the data format of the prompt if response_type == "audio_path": output_audio_path, result = await self.send_audio_async( - filename=request.converted_value, conversation_id=conversation_id, + filename=request.converted_value, + conversation_id=conversation_id, conversation=normalized_conversation, ) elif response_type == "text": output_audio_path, result = await self.send_text_async( - text=request.converted_value, conversation_id=conversation_id, + text=request.converted_value, + conversation_id=conversation_id, conversation=normalized_conversation, ) else: @@ -671,7 +673,11 @@ def _extract_error_details(*, response: Any) -> str: return "Unknown error occurred" async def send_text_async( - self, text: str, conversation_id: str, conversation: list[Message], + self, + *, + text: str, + conversation_id: str, + conversation: list[Message], ) -> tuple[str, RealtimeTargetResult]: """ Send text prompt using OpenAI Realtime API client. @@ -728,7 +734,11 @@ async def send_text_async( return output_audio_path, result async def send_audio_async( - self, filename: str, conversation_id: str, conversation: list[Message], + self, + *, + filename: str, + conversation_id: str, + conversation: list[Message], ) -> tuple[str, RealtimeTargetResult]: """ Send an audio message using OpenAI Realtime API client. diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 45e1435cbe..3e6e8c1bd6 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -536,10 +536,8 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me """ message = normalized_conversation[-1] message_piece: MessagePiece = message.message_pieces[0] - json_config = _JsonResponseConfig(enabled=False) - if message.message_pieces: - last_piece = message.message_pieces[-1] - json_config = self._get_json_response_config(message_piece=last_piece) + last_piece = message.message_pieces[-1] + json_config = self._get_json_response_config(message_piece=last_piece) working_conversation: MutableSequence[Message] = list(normalized_conversation) diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 14265c24e2..4014739303 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -1,11 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import json import warnings -from collections.abc import MutableSequence +from typing import TYPE_CHECKING from unittest.mock import AsyncMock, MagicMock, patch +if TYPE_CHECKING: + from collections.abc import MutableSequence + import pytest from openai.types.chat import ChatCompletion diff --git a/tests/unit/prompt_target/target/test_openai_target_auth.py b/tests/unit/prompt_target/target/test_openai_target_auth.py index 38102323df..18c8037d63 100644 --- a/tests/unit/prompt_target/target/test_openai_target_auth.py +++ b/tests/unit/prompt_target/target/test_openai_target_auth.py @@ -20,7 +20,6 @@ def _set_openai_env_configuration_vars(self) -> None: self.model_name_environment_variable = "TEST_MODEL" self.endpoint_environment_variable = "TEST_ENDPOINT" self.api_key_environment_variable = "TEST_API_KEY" - self.underlying_model_environment_variable = "TEST_UNDERLYING_MODEL" def _get_target_api_paths(self) -> list[str]: return [] From 961ed68f56b94289e0f53f8f74249adacbf9cebd Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Thu, 16 Apr 2026 16:12:19 -0400 Subject: [PATCH 08/11] propagate lineage and PR comments --- pyrit/prompt_target/azure_ml_chat_target.py | 4 +- pyrit/prompt_target/common/prompt_target.py | 45 +++- .../openai/openai_chat_target.py | 3 + .../openai/openai_realtime_target.py | 19 +- .../target/test_prompt_target.py | 250 ++++++++++++++++++ 5 files changed, 309 insertions(+), 12 deletions(-) diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index fee748556c..8a4068e60c 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -86,7 +86,7 @@ def __init__( message_normalizer (MessageListNormalizer[Any] | None): **Deprecated.** Use ``custom_configuration`` with ``CapabilityHandlingPolicy`` instead. Previously used for models that do not allow system prompts. - Will be removed in v0.14.0. + Will be removed in v0.15.0. max_new_tokens (int): The maximum number of tokens to generate in the response. Defaults to 400. temperature (float): The temperature for generating diverse responses. 1.0 is most random, @@ -123,7 +123,7 @@ def __init__( raise ValueError( "Cannot specify both 'message_normalizer' and 'custom_configuration'. " "Use 'custom_configuration' only; 'message_normalizer' is deprecated and " - "will be removed in v0.14.0." + "will be removed in v0.15.0." ) warnings.warn( "Passing message_normalizer is deprecated. Use custom_configuration with " diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 2fa280eabc..257d5d8952 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -115,7 +115,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: This is the public entry point called by the prompt normalizer. It: - 1. Fetches the conversation from memory, appends ``message``, and runs + 1. Validates the message, fetches the conversation from memory, appends ``message``, and runs the normalization pipeline (system‑squash, history‑squash, etc.). 2. Validates the normalized conversation against the target's capabilities. 3. Delegates to :meth:`_send_prompt_to_target_async` with the normalized @@ -133,8 +133,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Raises: ValueError: If the message or normalized conversation are empty. """ - if not message.message_pieces: - raise ValueError("Message must contain at least one message piece. Received: 0 pieces.") + message.validate() normalized_conversation = await self._get_normalized_conversation_async(message=message) if not normalized_conversation: raise ValueError("Normalization pipeline returned an empty conversation. Cannot send an empty request.") @@ -206,6 +205,11 @@ async def _get_normalized_conversation_async(self, *, message: Message) -> list[ The original conversation in memory is never mutated. The returned list is an ephemeral copy intended only for building the API request body. + After normalization, the metadata from the original ``message`` is copied + onto the last normalized message so that downstream code (e.g. + ``construct_response_from_request``) propagates the correct + ``conversation_id``, ``labels``, ``attack_identifier``, etc. to the response. + Args: message (Message): The current message to append. @@ -216,7 +220,40 @@ async def _get_normalized_conversation_async(self, *, message: Message) -> list[ conversation_id = message.message_pieces[0].conversation_id conversation = list(self._memory.get_conversation(conversation_id=conversation_id)) conversation.append(message) - return await self.configuration.normalize_async(messages=conversation) + normalized = await self.configuration.normalize_async(messages=conversation) + if normalized: + self._propagate_lineage(source=message, target_message=normalized[-1]) + return normalized + + @staticmethod + def _propagate_lineage(*, source: Message, target_message: Message) -> None: + """ + Copy request-lineage metadata from ``source`` onto every piece in ``target_message``. + + Normalizers may create brand-new ``Message`` objects (e.g. ``HistorySquashNormalizer`` + uses ``Message.from_prompt``) that carry fresh random ``conversation_id`` values and + lack ``labels``, ``attack_identifier``, etc. This method restores the original + metadata so that the response built from the normalized message stays part of the + correct conversation and retains traceability. + + Note: + Only the **last** message in the normalized list is stamped (the caller passes + ``normalized[-1]``). This is intentional: earlier messages in the list are + history entries fetched from memory that already carry correct metadata. + Only the final message — which corresponds to the current user turn and may + have been rebuilt by a normalizer — needs its lineage restored. + + Args: + source: The original (pre-normalization) message whose metadata is authoritative. + target_message: The normalized message whose pieces will be updated in place. + """ + source_piece = source.message_pieces[0] + for piece in target_message.message_pieces: + piece.conversation_id = source_piece.conversation_id + piece.labels = source_piece.labels + piece.attack_identifier = source_piece.attack_identifier + piece.prompt_target_identifier = source_piece.prompt_target_identifier + piece.prompt_metadata = source_piece.prompt_metadata def set_model_name(self, *, model_name: str) -> None: """ diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index d246a4ba01..da45377a87 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -72,6 +72,9 @@ class OpenAIChatTarget(OpenAITarget, PromptChatTarget): supports_json_output=True, supports_multi_message_pieces=True, supports_system_prompt=True, + input_modalities=frozenset( + {frozenset({"text"}), frozenset({"image_path"}), frozenset({"text", "image_path"})} + ), ) ) diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index f298def2aa..2b6d4d032d 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -297,16 +297,23 @@ def _set_system_prompt_and_config_vars(self, system_prompt: str) -> dict[str, An return session_config - async def send_config(self, *, conversation_id: str, conversation: list[Message]) -> None: + async def send_config(self, *, conversation_id: str, normalized_conversation: list[Message] | None = None) -> None: """ Send the session configuration using OpenAI client. Args: conversation_id (str): Conversation ID - conversation (list[Message]): The conversation history to extract the system prompt from. - """ - # Extract system prompt from conversation history - system_prompt = self._get_system_prompt_from_conversation(conversation=conversation) + conversation (list[Message] | None): The conversation history to extract the system prompt + from. If None, the conversation is fetched from memory. Defaults to None. + """ + # Extract system prompt from conversation history. Use the conversation passed in if available, + # otherwise fetch from memory. + resolved_conversation = ( + normalized_conversation + if normalized_conversation is not None + else list(self._memory.get_conversation(conversation_id=conversation_id)) + ) + system_prompt = self._get_system_prompt_from_conversation(conversation=resolved_conversation) config_variables = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) connection = self._get_connection(conversation_id=conversation_id) @@ -356,7 +363,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me self._existing_conversation[conversation_id] = connection # Only send config when creating a new connection - await self.send_config(conversation_id=conversation_id, conversation=normalized_conversation) + await self.send_config(conversation_id=conversation_id, normalized_conversation=normalized_conversation) # Give the server a moment to process the session update await asyncio.sleep(0.5) diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index 08ce7f217d..e73e8740d3 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -1,16 +1,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import json from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest +from openai.types.chat import ChatCompletion from unit.mocks import get_sample_conversations, openai_chat_response_json_dict from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.identifiers import ComponentIdentifier +from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import Message, MessagePiece from pyrit.prompt_target import OpenAIChatTarget +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration @pytest.fixture @@ -150,3 +160,243 @@ async def test_send_prompt_async_with_delay( mock_create.assert_called_once() mock_sleep.assert_called_once_with(6) # 60/max_requests_per_minute + + +# --------------------------------------------------------------------------- +# _propagate_lineage — metadata preservation after normalization +# --------------------------------------------------------------------------- + +_LINEAGE_CONVERSATION_ID = "original-conv-id-12345" +_LINEAGE_LABELS = {"op_name": "test_op", "user_id": "user42"} +_LINEAGE_ATTACK_IDENTIFIER = ComponentIdentifier(class_name="TestAttack", class_module="tests.attacks") +_LINEAGE_PROMPT_TARGET_IDENTIFIER = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit") +_LINEAGE_PROMPT_METADATA = {"scenario": "test_scenario", "turn": 3} + + +def _make_lineage_piece(*, role: str, content: str) -> MessagePiece: + return MessagePiece( + role=role, + conversation_id=_LINEAGE_CONVERSATION_ID, + original_value=content, + converted_value=content, + original_value_data_type="text", + converted_value_data_type="text", + labels=dict(_LINEAGE_LABELS), + prompt_target_identifier=_LINEAGE_PROMPT_TARGET_IDENTIFIER, + attack_identifier=_LINEAGE_ATTACK_IDENTIFIER, + prompt_metadata=dict(_LINEAGE_PROMPT_METADATA), + ) + + +def _make_lineage_message(*, role: str, content: str) -> Message: + return Message(message_pieces=[_make_lineage_piece(role=role, content=content)]) + + +def _make_mock_chat_completion(content: str = "response") -> MagicMock: + mock = MagicMock(spec=ChatCompletion) + mock.choices = [MagicMock()] + mock.choices[0].finish_reason = "stop" + mock.choices[0].message.content = content + mock.choices[0].message.audio = None + mock.choices[0].message.tool_calls = None + mock.model_dump_json.return_value = json.dumps( + {"choices": [{"finish_reason": "stop", "message": {"content": content}}]} + ) + return mock + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_history_squash_preserves_metadata_on_normalized_message(): + """ + After history squash, _propagate_lineage should restore the original request's + metadata (conversation_id, labels, attack_identifier) onto the squashed message. + """ + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=True, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + ), + ), + ) + + history_msg = _make_lineage_message(role="assistant", content="previous answer") + user_msg = _make_lineage_message(role="user", content="follow-up question") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [history_msg] + target._memory = mock_memory + + normalized = await target._get_normalized_conversation_async(message=user_msg) + + assert len(normalized) == 1 + + normalized_piece = normalized[0].message_pieces[0] + + assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID + assert normalized_piece.labels == _LINEAGE_LABELS + assert normalized_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER + assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER + assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_response_preserves_metadata_after_history_squash(): + """ + End-to-end: after history squash the response must carry the original + request's conversation_id, labels, and attack_identifier — not the + random values created by the normalizer. + """ + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=True, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + ), + ), + ) + + history_msg = _make_lineage_message(role="assistant", content="previous answer") + user_msg = _make_lineage_message(role="user", content="follow-up question") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [history_msg] + target._memory = mock_memory + + mock_completion = _make_mock_chat_completion("target response") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + + response_messages = await target.send_prompt_async(message=user_msg) + + assert len(response_messages) == 1 + response_piece = response_messages[0].message_pieces[0] + + assert response_piece.conversation_id == _LINEAGE_CONVERSATION_ID + assert response_piece.labels == _LINEAGE_LABELS + assert response_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER + assert response_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER + assert response_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_system_squash_preserves_metadata(): + """ + GenericSystemSquashNormalizer also creates messages via Message.from_prompt. + _propagate_lineage should restore the original metadata after system squash too. + """ + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_lineage_message(role="system", content="be helpful") + user_msg = _make_lineage_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [system_msg] + target._memory = mock_memory + + normalized = await target._get_normalized_conversation_async(message=user_msg) + + assert len(normalized) == 1 + assert "be helpful" in normalized[0].get_value() + + normalized_piece = normalized[0].message_pieces[0] + + assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID + assert normalized_piece.labels == _LINEAGE_LABELS + assert normalized_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER + assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER + assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_history_squash_propagates_lineage_to_all_pieces(): + """ + When the squashed message contains multiple pieces, _propagate_lineage + must stamp every piece — not just the first one. + """ + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=True, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + ), + ), + ) + + history_msg = _make_lineage_message(role="assistant", content="previous answer") + # Build a user message with two pieces to exercise multi-piece stamping. + user_msg = Message( + message_pieces=[ + _make_lineage_piece(role="user", content="first part"), + _make_lineage_piece(role="user", content="second part"), + ] + ) + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [history_msg] + target._memory = mock_memory + + normalized = await target._get_normalized_conversation_async(message=user_msg) + + assert len(normalized) == 1 + + for piece in normalized[0].message_pieces: + assert piece.conversation_id == _LINEAGE_CONVERSATION_ID + assert piece.labels == _LINEAGE_LABELS + assert piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER + assert piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER + assert piece.prompt_metadata == _LINEAGE_PROMPT_METADATA From 8973981b5d4aaac6e283abbbe6ba66e26a68fe0d Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Thu, 16 Apr 2026 16:57:31 -0400 Subject: [PATCH 09/11] fix tests --- .../prompt_target/target/test_http_api_target.py | 15 ++++----------- .../unit/prompt_target/target/test_http_target.py | 11 ++++------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/unit/prompt_target/target/test_http_api_target.py b/tests/unit/prompt_target/target/test_http_api_target.py index 3b14d2e704..16613ec024 100644 --- a/tests/unit/prompt_target/target/test_http_api_target.py +++ b/tests/unit/prompt_target/target/test_http_api_target.py @@ -148,14 +148,7 @@ async def test_send_prompt_async_missing_explicit_file_path_raises(mock_request, @pytest.mark.asyncio -@patch("httpx.AsyncClient.request") -async def test_send_prompt_async_validation(mock_request, patch_central_database): - # Create an invalid message (empty message_pieces) - message = MagicMock() - message.message_pieces = [] - target = HTTPXAPITarget(http_url="http://example.com/validate/", method="POST", timeout=180) - - with pytest.raises(ValueError) as excinfo: - await target.send_prompt_async(message=message) - - assert "Message must contain at least one message piece. Received: 0 pieces." in str(excinfo.value) +async def test_send_prompt_async_validation(patch_central_database): + # Creating a Message with no pieces raises immediately + with pytest.raises(ValueError, match="must have at least one message piece"): + Message(message_pieces=[]) diff --git a/tests/unit/prompt_target/target/test_http_target.py b/tests/unit/prompt_target/target/test_http_target.py index ee9ec61c3d..0e6dbc2f19 100644 --- a/tests/unit/prompt_target/target/test_http_target.py +++ b/tests/unit/prompt_target/target/test_http_target.py @@ -7,6 +7,7 @@ import httpx import pytest +from pyrit.models import Message from pyrit.prompt_target.http_target.http_target import HTTPTarget from pyrit.prompt_target.http_target.http_target_callback_functions import ( get_http_target_json_response_callback_function, @@ -145,13 +146,9 @@ async def test_send_prompt_async_client_kwargs(patch_central_database): @pytest.mark.asyncio async def test_send_prompt_async_validation(mock_http_target): - # Create an invalid message (missing message_pieces) - invalid_message = MagicMock() - invalid_message.message_pieces = [] - with pytest.raises(ValueError) as value_error: - await mock_http_target.send_prompt_async(message=invalid_message) - - assert str(value_error.value) == "Message must contain at least one message piece. Received: 0 pieces." + # Creating a Message with no pieces raises immediately + with pytest.raises(ValueError, match="must have at least one message piece"): + Message(message_pieces=[]) @pytest.mark.asyncio From 0f64a5eb960af5b90e10f32a6fac44f317b4f4d7 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Thu, 16 Apr 2026 17:07:03 -0400 Subject: [PATCH 10/11] fix docstring --- pyrit/prompt_target/openai/openai_realtime_target.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 2b6d4d032d..feeeeed5b8 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -303,8 +303,8 @@ async def send_config(self, *, conversation_id: str, normalized_conversation: li Args: conversation_id (str): Conversation ID - conversation (list[Message] | None): The conversation history to extract the system prompt - from. If None, the conversation is fetched from memory. Defaults to None. + normalized_conversation (list[Message] | None): The normalized_conversation history to extract the system + prompt from. If None, the conversation is fetched from memory. Defaults to None. """ # Extract system prompt from conversation history. Use the conversation passed in if available, # otherwise fetch from memory. From 54e0a25627bb869de62037828fde3067a0c1519d Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Thu, 16 Apr 2026 17:23:06 -0400 Subject: [PATCH 11/11] whitespace :| --- pyrit/prompt_target/openai/openai_realtime_target.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index feeeeed5b8..2b8380e666 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -303,7 +303,7 @@ async def send_config(self, *, conversation_id: str, normalized_conversation: li Args: conversation_id (str): Conversation ID - normalized_conversation (list[Message] | None): The normalized_conversation history to extract the system + normalized_conversation (list[Message] | None): The normalized_conversation history to extract the system prompt from. If None, the conversation is fetched from memory. Defaults to None. """ # Extract system prompt from conversation history. Use the conversation passed in if available,