diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index dcabad099..71664ed88 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -196,18 +196,20 @@ def _parse_url(self) -> tuple[str, str]: return container_url, blob_prefix @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ (Async) Sends prompt to target, which creates a file and uploads it as a blob to the provided storage container. Args: - message (Message): A Message to be sent to the target. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response with the Blob URL. """ - self._validate_request(message=message) + message = normalized_conversation[-1] request = message.message_pieces[0] # default file name is .txt, but can be overridden by prompt metadata diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index cd7083c61..8a4068e60 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -2,7 +2,8 @@ # Licensed under the MIT license. 
import logging -from typing import Any, Optional +import warnings +from typing import Any from httpx import HTTPStatusError @@ -20,7 +21,12 @@ construct_response_from_request, ) from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p @@ -54,57 +60,95 @@ class AzureMLChatTarget(PromptChatTarget): def __init__( self, *, - endpoint: Optional[str] = None, - api_key: Optional[str] = None, + endpoint: str | None = None, + api_key: str | None = None, model_name: str = "", - message_normalizer: Optional[MessageListNormalizer[Any]] = None, + message_normalizer: MessageListNormalizer[Any] | None = None, max_new_tokens: int = 400, temperature: float = 1.0, top_p: float = 1.0, repetition_penalty: float = 1.0, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, - custom_capabilities: Optional[TargetCapabilities] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, + custom_capabilities: TargetCapabilities | None = None, **param_kwargs: Any, ) -> None: """ Initialize an instance of the AzureMLChatTarget class. Args: - endpoint (str, Optional): The endpoint URL for the deployed Azure ML model. + endpoint (str | None): The endpoint URL for the deployed Azure ML model. Defaults to the value of the AZURE_ML_MANAGED_ENDPOINT environment variable. - api_key (str, Optional): The API key for accessing the Azure ML endpoint. + api_key (str | None): The API key for accessing the Azure ML endpoint. 
Defaults to the value of the `AZURE_ML_KEY` environment variable. - model_name (str, Optional): The name of the model being used (e.g., "Llama-3.2-3B-Instruct"). + model_name (str): The name of the model being used (e.g., "Llama-3.2-3B-Instruct"). Used for identification purposes. Defaults to empty string. - message_normalizer (MessageListNormalizer, Optional): The message normalizer. - For models that do not allow system prompts such as mistralai-Mixtral-8x7B-Instruct-v01, - GenericSystemSquashNormalizer() can be passed in. Defaults to ChatMessageNormalizer(). - max_new_tokens (int, Optional): The maximum number of tokens to generate in the response. + message_normalizer (MessageListNormalizer[Any] | None): **Deprecated.** Use + ``custom_configuration`` with ``CapabilityHandlingPolicy`` instead. Previously used for + models that do not allow system prompts. + Will be removed in v0.15.0. + max_new_tokens (int): The maximum number of tokens to generate in the response. Defaults to 400. - temperature (float, Optional): The temperature for generating diverse responses. 1.0 is most random, + temperature (float): The temperature for generating diverse responses. 1.0 is most random, 0.0 is least random. Defaults to 1.0. - top_p (float, Optional): The top-p value for generating diverse responses. It represents + top_p (float): The top-p value for generating diverse responses. It represents the cumulative probability of the top tokens to keep. Defaults to 1.0. - repetition_penalty (float, Optional): The repetition penalty for generating diverse responses. + repetition_penalty (float): The repetition penalty for generating diverse responses. 1.0 means no penalty with a greater value (up to 2.0) meaning more penalty for repeating tokens. Defaults to 1.2. - max_requests_per_minute (int, Optional): Number of requests the target can handle per + max_requests_per_minute (int | None): Number of requests the target can handle per minute before hitting a rate limit. 
The number of requests sent to the target will be capped at the value provided. - custom_configuration (TargetConfiguration, Optional): Override the default configuration for this target + custom_configuration (TargetConfiguration | None): Override the default configuration for this target instance. Useful for targets whose capabilities depend on deployment configuration. - custom_capabilities (TargetCapabilities, Optional): **Deprecated.** Use + custom_capabilities (TargetCapabilities | None): **Deprecated.** Use ``custom_configuration`` instead. Will be removed in v0.14.0. **param_kwargs: Additional parameters to pass to the model for generating responses. Example parameters can be found here: https://huggingface.co/docs/api-inference/tasks/text-generation. Note that the link above may not be comprehensive, and specific acceptable parameters may be model-dependent. If a model does not accept a certain parameter that is passed in, it will be skipped without throwing an error. + + Raises: + ValueError: If both `message_normalizer` and `custom_configuration` are provided, + since `message_normalizer` is deprecated and the two configurations may conflict. """ endpoint_value = default_values.get_required_value( env_var_name=self.endpoint_uri_environment_variable, passed_value=endpoint ) + + # Translate legacy message_normalizer into TargetConfiguration + if message_normalizer is not None: + if custom_configuration is not None: + raise ValueError( + "Cannot specify both 'message_normalizer' and 'custom_configuration'. " + "Use 'custom_configuration' only; 'message_normalizer' is deprecated and " + "will be removed in v0.15.0." + ) + warnings.warn( + "Passing message_normalizer is deprecated. Use custom_configuration with " + "CapabilityHandlingPolicy instead. Will be removed in v0.15.0.", + DeprecationWarning, + stacklevel=2, + ) + # The legacy message_normalizer was primarily used to handle system prompts + # for models that don't support them (e.g. 
GenericSystemSquashNormalizer). + # We translate it into a TargetConfiguration that marks system_prompt as + # unsupported + ADAPT so the pipeline invokes the user's normalizer. + default_caps = self._DEFAULT_CONFIGURATION.capabilities + default_behaviors = dict(self._DEFAULT_CONFIGURATION.policy.behaviors) + default_behaviors[CapabilityName.SYSTEM_PROMPT] = UnsupportedCapabilityBehavior.ADAPT + custom_configuration = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_message_pieces=default_caps.supports_multi_message_pieces, + supports_editable_history=default_caps.supports_editable_history, + supports_multi_turn=default_caps.supports_multi_turn, + supports_system_prompt=False, + ), + policy=CapabilityHandlingPolicy(behaviors=default_behaviors), + normalizer_overrides={CapabilityName.SYSTEM_PROMPT: message_normalizer}, + ) + PromptChatTarget.__init__( self, max_requests_per_minute=max_requests_per_minute, @@ -119,7 +163,6 @@ def __init__( validate_temperature(temperature) validate_top_p(top_p) - self.message_normalizer = message_normalizer if message_normalizer is not None else ChatMessageNormalizer() self._max_new_tokens = max_new_tokens self._temperature = temperature self._top_p = top_p @@ -139,11 +182,10 @@ def _build_identifier(self) -> ComponentIdentifier: "top_p": self._top_p, "max_new_tokens": self._max_new_tokens, "repetition_penalty": self._repetition_penalty, - "message_normalizer": self.message_normalizer.__class__.__name__, }, ) - def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str] = None) -> None: + def _initialize_vars(self, endpoint: str | None = None, api_key: str | None = None) -> None: """ Set the endpoint and key for accessing the Azure ML model. Use this function to manually pass in your own endpoint uri and api key. 
Defaults to the values in the .env file for the variables @@ -164,12 +206,14 @@ def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the Azure ML chat target. Args: - message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. @@ -179,18 +223,14 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: RateLimitException: If the target rate limit is exceeded. HTTPStatusError: For any other HTTP errors during the process. """ - self._validate_request(message=message) + message = normalized_conversation[-1] request = message.message_pieces[0] - # Get chat messages from memory and append the current message - messages = list(self._memory.get_conversation(conversation_id=request.conversation_id)) - messages.append(message) - logger.info(f"Sending the following prompt to the prompt target: {request}") try: resp_text = await self._complete_chat_async( - messages=messages, + messages=normalized_conversation, ) if not resp_text: @@ -259,8 +299,8 @@ async def _construct_http_body_async( Returns: dict: The constructed HTTP request body. """ - # Use the message normalizer to convert Messages to dict format - messages_dict = await self.message_normalizer.normalize_to_dicts_async(messages) + wire_format = ChatMessageNormalizer() + messages_dict = await wire_format.normalize_to_dicts_async(messages) # Parameters include additional ones passed in through **kwargs. Those not accepted by the model will # be ignored. 
We only include commonly supported parameters here - model-specific parameters like @@ -292,5 +332,5 @@ def _get_headers(self) -> dict[str, str]: return headers - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass diff --git a/pyrit/prompt_target/common/conversation_normalization_pipeline.py b/pyrit/prompt_target/common/conversation_normalization_pipeline.py index 3b6fd0f90..6c50b7808 100644 --- a/pyrit/prompt_target/common/conversation_normalization_pipeline.py +++ b/pyrit/prompt_target/common/conversation_normalization_pipeline.py @@ -2,6 +2,8 @@ # Licensed under the MIT license. import logging +from collections.abc import Mapping +from typing import Any from pyrit.message_normalizer import ( GenericSystemSquashNormalizer, @@ -63,7 +65,7 @@ def from_capabilities( *, capabilities: TargetCapabilities, policy: CapabilityHandlingPolicy, - normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None, + normalizer_overrides: Mapping[CapabilityName, MessageListNormalizer[Any]] | None = None, ) -> "ConversationNormalizationPipeline": """ Resolve capabilities and policy into a concrete pipeline of normalizers. @@ -80,7 +82,7 @@ def from_capabilities( Args: capabilities (TargetCapabilities): The target's declared capabilities. policy (CapabilityHandlingPolicy): How to handle each missing capability. - normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None): + normalizer_overrides (Mapping[CapabilityName, MessageListNormalizer[Any]] | None): Optional overrides for specific capability normalizers. Falls back to the defaults from ``_NORMALIZER_REGISTRY``. @@ -103,7 +105,6 @@ def from_capabilities( # workflow is implemented. 
if behavior == UnsupportedCapabilityBehavior.ADAPT: normalizer = overrides.get(capability, default_normalizer) - normalizers.append(normalizer) return cls(normalizers=tuple(normalizers)) diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index df3990e06..257d5d895 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -4,7 +4,7 @@ import abc import logging import warnings -from typing import Any, Optional, Union +from typing import Any, Union, final from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.memory import CentralMemory, MemoryInterface @@ -29,7 +29,7 @@ class PromptTarget(Identifiable): # An empty list implies that the prompt target supports all converters. supported_converters: list[Any] - _identifier: Optional[ComponentIdentifier] = None + _identifier: ComponentIdentifier | None = None # Class-level default configuration for this target type. # @@ -63,30 +63,30 @@ def __init_subclass__(cls, **kwargs: Any) -> None: def __init__( self, verbose: bool = False, - max_requests_per_minute: Optional[int] = None, + max_requests_per_minute: int | None = None, endpoint: str = "", model_name: str = "", - underlying_model: Optional[str] = None, - custom_configuration: Optional[TargetConfiguration] = None, - custom_capabilities: Optional[TargetCapabilities] = None, + underlying_model: str | None = None, + custom_configuration: TargetConfiguration | None = None, + custom_capabilities: TargetCapabilities | None = None, ) -> None: """ Initialize the PromptTarget. Args: verbose (bool): Enable verbose logging. Defaults to False. - max_requests_per_minute (int, Optional): Maximum number of requests per minute. + max_requests_per_minute (int | None): Maximum number of requests per minute. endpoint (str): The endpoint URL. Defaults to empty string. model_name (str): The model name. Defaults to empty string. 
- underlying_model (str, Optional): The underlying model name (e.g., "gpt-4o") for + underlying_model (str | None): The underlying model name (e.g., "gpt-4o") for identification purposes. This is useful when the deployment name in Azure differs from the actual model. If not provided, ``model_name`` will be used for the identifier. Defaults to None. - custom_configuration (TargetConfiguration, Optional): Override the default configuration + custom_configuration (TargetConfiguration | None): Override the default configuration for this target instance. Useful for targets whose capabilities depend on deployment configuration (e.g., Playwright, HTTP). If None, uses the class-level ``_DEFAULT_CONFIGURATION``. Defaults to None. - custom_capabilities (TargetCapabilities, Optional): **Deprecated.** Use + custom_capabilities (TargetCapabilities | None): **Deprecated.** Use ``custom_configuration`` instead. Will be removed in v0.14.0. """ custom_configuration = resolve_configuration_compat( @@ -108,32 +108,72 @@ def __init__( if self._verbose: logging.basicConfig(level=logging.INFO) - @abc.abstractmethod + @final async def send_prompt_async(self, *, message: Message) -> list[Message]: """ - Send a normalized prompt async to the prompt target. + Validate, normalize, and send a prompt to the target. + + This is the public entry point called by the prompt normalizer. It: + + 1. Validates the message, fetches the conversation from memory, appends ``message``, and runs + the normalization pipeline (system‑squash, history‑squash, etc.). + 2. Validates the normalized conversation against the target's capabilities. + 3. Delegates to :meth:`_send_prompt_to_target_async` with the normalized + conversation. + + Subclasses MUST NOT override this method. Override + :meth:`_send_prompt_to_target_async` instead. + + Args: + message (Message): The message to send. Returns: - list[Message]: A list of message responses. 
Most targets return a single message, - but some (like response target with tool calls) may return multiple messages. + list[Message]: Response messages from the target. + + Raises: + ValueError: If the message or normalized conversation are empty. """ + message.validate() + normalized_conversation = await self._get_normalized_conversation_async(message=message) + if not normalized_conversation: + raise ValueError("Normalization pipeline returned an empty conversation. Cannot send an empty request.") + self._validate_request(normalized_conversation=normalized_conversation) + return await self._send_prompt_to_target_async(normalized_conversation=normalized_conversation) - def _validate_request(self, *, message: Message) -> None: + @abc.abstractmethod + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ - Validate the provided message. + Target-specific send logic. + + Called by :meth:`send_prompt_async` after validation and normalization. Args: - message: The message to validate. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. - Raises: - ValueError: if the target does not support the provided message pieces or if the message - violates any constraints based on the target's capabilities. This includes checks - for the number of message pieces, supported data types, and multi-turn conversation support. + Returns: + list[Message]: Response messages from the target. + """ + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ + Validate the normalized conversation before sending to the target. + + Called after the normalization pipeline has run. Validates the last + message (the current request) for piece count, data types, and checks + whether the full conversation violates multi-turn constraints. 
+ + Args: + normalized_conversation: The normalized conversation to validate. + The last element is the current request message. + + Raises: + ValueError: if the target does not support the provided message pieces or if the + conversation violates any constraints based on the target's capabilities. + """ + message = normalized_conversation[-1] n_pieces = len(message.message_pieces) - if n_pieces == 0: - raise ValueError("Message must contain at least one message piece. Received: 0 pieces.") custom_configuration_message = ( "If your target does support this, set the custom_configuration parameter accordingly." @@ -154,14 +194,66 @@ def _validate_request(self, *, message: Message) -> None: f"{custom_configuration_message}" ) - if not self.capabilities.supports_multi_turn: - request = message.message_pieces[0] - messages = self._memory.get_message_pieces(conversation_id=request.conversation_id) + if not self.capabilities.supports_multi_turn and len(normalized_conversation) > 1: + raise ValueError(f"This target only supports a single turn conversation. {custom_configuration_message}") - if len(messages) > 0: - raise ValueError( - f"This target only supports a single turn conversation. {custom_configuration_message}" - ) + async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: + """ + Fetch the conversation from memory, append the current message, and run the + normalization pipeline. + + The original conversation in memory is never mutated. The returned list is an + ephemeral copy intended only for building the API request body. + + After normalization, the metadata from the original ``message`` is copied + onto the last normalized message so that downstream code (e.g. + ``construct_response_from_request``) propagates the correct + ``conversation_id``, ``labels``, ``attack_identifier``, etc. to the response. + + Args: + message (Message): The current message to append. 
+ + Returns: + list[Message]: The normalized conversation (possibly with system prompt squashed, + history squashed, etc.). + """ + conversation_id = message.message_pieces[0].conversation_id + conversation = list(self._memory.get_conversation(conversation_id=conversation_id)) + conversation.append(message) + normalized = await self.configuration.normalize_async(messages=conversation) + if normalized: + self._propagate_lineage(source=message, target_message=normalized[-1]) + return normalized + + @staticmethod + def _propagate_lineage(*, source: Message, target_message: Message) -> None: + """ + Copy request-lineage metadata from ``source`` onto every piece in ``target_message``. + + Normalizers may create brand-new ``Message`` objects (e.g. ``HistorySquashNormalizer`` + uses ``Message.from_prompt``) that carry fresh random ``conversation_id`` values and + lack ``labels``, ``attack_identifier``, etc. This method restores the original + metadata so that the response built from the normalized message stays part of the + correct conversation and retains traceability. + + Note: + Only the **last** message in the normalized list is stamped (the caller passes + ``normalized[-1]``). This is intentional: earlier messages in the list are + history entries fetched from memory that already carry correct metadata. + Only the final message — which corresponds to the current user turn and may + have been rebuilt by a normalizer — needs its lineage restored. + + Args: + source: The original (pre-normalization) message whose metadata is authoritative. + target_message: The normalized message whose pieces will be updated in place. 
+ """ + source_piece = source.message_pieces[0] + for piece in target_message.message_pieces: + piece.conversation_id = source_piece.conversation_id + piece.labels = source_piece.labels + piece.attack_identifier = source_piece.attack_identifier + piece.prompt_target_identifier = source_piece.prompt_target_identifier + piece.prompt_metadata = source_piece.prompt_metadata def set_model_name(self, *, model_name: str) -> None: """ @@ -181,8 +273,8 @@ def dispose_db_engine(self) -> None: def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] | None = None, ) -> ComponentIdentifier: """ Construct the target identifier. @@ -195,9 +287,9 @@ def _create_identifier( to set the identifier with their specific parameters. Args: - params (Optional[Dict[str, Any]]): Additional behavioral parameters from + params (dict[str, Any] | None): Additional behavioral parameters from the subclass (e.g., temperature, top_p). Merged into the base params. - children (Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]]): + children (dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] | None): Named child component identifiers. Returns: @@ -243,7 +335,7 @@ def capabilities(self) -> TargetCapabilities: return self._configuration.capabilities @classmethod - def get_default_configuration(cls, underlying_model: Optional[str] = None) -> TargetConfiguration: + def get_default_configuration(cls, underlying_model: str | None = None) -> TargetConfiguration: """ Return the configuration for the given underlying model, falling back to the class-level ``_DEFAULT_CONFIGURATION`` when the model is not recognized. 
@@ -268,7 +360,7 @@ def get_default_configuration(cls, underlying_model: Optional[str] = None) -> Ta return cls._DEFAULT_CONFIGURATION @classmethod - def get_default_capabilities(cls, underlying_model: Optional[str] = None) -> TargetCapabilities: + def get_default_capabilities(cls, underlying_model: str | None = None) -> TargetCapabilities: """ Return the default capabilities for the given model. diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py index 0daabd7fd..c7784996e 100644 --- a/pyrit/prompt_target/common/target_configuration.py +++ b/pyrit/prompt_target/common/target_configuration.py @@ -3,6 +3,8 @@ import logging import warnings +from collections.abc import Mapping +from typing import Any from pyrit.message_normalizer import MessageListNormalizer from pyrit.models import Message @@ -84,7 +86,7 @@ def __init__( *, capabilities: TargetCapabilities, policy: CapabilityHandlingPolicy | None = None, - normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None, + normalizer_overrides: Mapping[CapabilityName, MessageListNormalizer[Any]] | None = None, ) -> None: """ Build a target configuration and resolve the normalization pipeline. @@ -93,7 +95,7 @@ def __init__( capabilities (TargetCapabilities): The target's declared capabilities. policy (CapabilityHandlingPolicy | None): How to handle each missing capability. Defaults to RAISE for all adaptable capabilities. - normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None): + normalizer_overrides (Mapping[CapabilityName, MessageListNormalizer[Any]] | None): Optional overrides for specific capability normalizers. 
""" self._capabilities = capabilities diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index f92326e66..99ca1c3b7 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -85,17 +85,19 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the Gandalf target. Args: - message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. """ - self._validate_request(message=message) + message = normalized_conversation[-1] request = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {request}") diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index da805d50f..2a44024c3 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -155,17 +155,19 @@ def _inject_prompt_into_request(self, request: MessagePiece) -> str: return http_request_w_prompt @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the HTTP target. Args: - message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. 
The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. """ - self._validate_request(message=message) + message = normalized_conversation[-1] request = message.message_pieces[0] http_request_w_prompt = self._inject_prompt_into_request(request) diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index 7555038b6..ae5e8989c 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -108,7 +108,7 @@ def __init__( raise ValueError(f"File uploads are not allowed with HTTP method: {self.method}") @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Override the parent's method to skip raw http_request usage, and do a standard "API mode" approach. @@ -125,7 +125,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: httpx.RequestError: If the request fails. FileNotFoundError: If the specified file to upload is not found. 
""" - self._validate_request(message=message) + message = normalized_conversation[-1] message_piece: MessagePiece = message.message_pieces[0] # If user didn't set file_path, see if the PDF path is in converted_value diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index ae0202600..7583a5efe 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -295,10 +295,15 @@ async def load_model_and_tokenizer(self) -> None: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send a normalized prompt asynchronously to the HuggingFace model. + Args: + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. + Returns: list[Message]: A list containing the response object with generated text pieces. 
@@ -309,7 +314,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Load the model and tokenizer using the encapsulated method await self.load_model_and_tokenizer_task - self._validate_request(message=message) + message = normalized_conversation[-1] request = message.message_pieces[0] prompt_template = request.converted_value diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index 5c4e3d937..da47fa34a 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -87,13 +87,14 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send a normalized prompt asynchronously to a cloud-based HuggingFace model endpoint. Args: - message (Message): The message containing the input data and associated details - such as conversation ID and role. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response object with generated text pieces. @@ -102,7 +103,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: ValueError: If the response from the Hugging Face API is not successful. Exception: If an error occurs during the HTTP request to the Hugging Face endpoint. 
""" - self._validate_request(message=message) + message = normalized_conversation[-1] request = message.message_pieces[0] headers = {"Authorization": f"Bearer {self.hf_token}"} payload: dict[str, object] = { @@ -146,16 +147,17 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: logger.error(f"Error occurred during HTTP request to the Hugging Face endpoint: {e}") raise - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validate the provided message. Args: - message (Message): The message to validate. + normalized_conversation: The normalized conversation to validate. Raises: ValueError: If the request is not valid for this target. """ + message = normalized_conversation[-1] n_pieces = len(message.message_pieces) if n_pieces != 1: raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 514f0a7c2..da45377a8 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -72,6 +72,9 @@ class OpenAIChatTarget(OpenAITarget, PromptChatTarget): supports_json_output=True, supports_multi_message_pieces=True, supports_system_prompt=True, + input_modalities=frozenset( + {frozenset({"text"}), frozenset({"image_path"}), frozenset({"text", "image_path"})} + ), ) ) @@ -240,28 +243,25 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously sends a message and handles the response within a managed conversation context. Args: - message (Message): The message object. 
+ normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. """ - self._validate_request(message=message) - + message = normalized_conversation[-1] message_piece: MessagePiece = message.message_pieces[0] json_config = self._get_json_response_config(message_piece=message_piece) - # Get conversation from memory and append the current message - conversation = self._memory.get_conversation(conversation_id=message_piece.conversation_id) - conversation.append(message) - logger.info(f"Sending the following prompt to the prompt target: {message}") - body = await self._construct_request_body(conversation=conversation, json_config=json_config) + body = await self._construct_request_body(conversation=normalized_conversation, json_config=json_config) # Use unified error handling - automatically detects ChatCompletion and validates response = await self._handle_openai_request( diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 14a39f29a..c7a851d72 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -121,17 +121,19 @@ def _get_provider_examples(self) -> dict[str, str]: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the OpenAI completion target. Args: - message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. 
Returns: list[Message]: A list containing the response from the prompt target. """ - self._validate_request(message=message) + message = normalized_conversation[-1] message_piece = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {message_piece}") diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 2009c5ef7..7fb98da60 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -147,22 +147,24 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async( + async def _send_prompt_to_target_async( self, *, - message: Message, + normalized_conversation: list[Message], ) -> list[Message]: """ Send a prompt to the OpenAI image target and return the response. Supports both image generation (text input) and image editing (text + images input). Args: - message (Message): The message to send. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the image target. 
""" - self._validate_request(message=message) + message = normalized_conversation[-1] logger.info(f"Sending the following prompt to the prompt target: {message}") @@ -320,8 +322,9 @@ async def _get_image_bytes(self, image_data: Any) -> bytes: raise EmptyResponseException(message="The image generation returned an empty response.") - def _validate_request(self, *, message: Message) -> None: - super()._validate_request(message=message) + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + super()._validate_request(normalized_conversation=normalized_conversation) + message = normalized_conversation[-1] text_pieces = [p for p in message.message_pieces if p.converted_value_data_type == "text"] image_pieces = [p for p in message.message_pieces if p.converted_value_data_type == "image_path"] diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index cf947d7ca..2b8380e66 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -297,33 +297,39 @@ def _set_system_prompt_and_config_vars(self, system_prompt: str) -> dict[str, An return session_config - async def send_config(self, conversation_id: str) -> None: + async def send_config(self, *, conversation_id: str, normalized_conversation: list[Message] | None = None) -> None: """ Send the session configuration using OpenAI client. Args: conversation_id (str): Conversation ID - """ - # Extract system prompt from conversation history - system_prompt = self._get_system_prompt_from_conversation(conversation_id=conversation_id) + normalized_conversation (list[Message] | None): The normalized_conversation history to extract the system + prompt from. If None, the conversation is fetched from memory. Defaults to None. + """ + # Extract system prompt from conversation history. Use the conversation passed in if available, + # otherwise fetch from memory. 
+ resolved_conversation = ( + normalized_conversation + if normalized_conversation is not None + else list(self._memory.get_conversation(conversation_id=conversation_id)) + ) + system_prompt = self._get_system_prompt_from_conversation(conversation=resolved_conversation) config_variables = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) connection = self._get_connection(conversation_id=conversation_id) await connection.session.update(session=config_variables) logger.info("Session configuration sent") - def _get_system_prompt_from_conversation(self, *, conversation_id: str) -> str: + def _get_system_prompt_from_conversation(self, *, conversation: list[Message]) -> str: """ Retrieve the system prompt from conversation history. Args: - conversation_id (str): The conversation ID + conversation (list[Message]): The conversation messages to search. Returns: str: The system prompt from conversation history, or a default if none found """ - conversation = self._memory.get_conversation(conversation_id=conversation_id) - # Look for a system message at the beginning of the conversation if conversation and len(conversation) > 0: first_message = conversation[0] @@ -335,12 +341,14 @@ def _get_system_prompt_from_conversation(self, *, conversation_id: str) -> str: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the OpenAI realtime target. Args: - message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. 
@@ -348,30 +356,33 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Raises: ValueError: If the message piece type is unsupported. """ + message = normalized_conversation[-1] conversation_id = message.message_pieces[0].conversation_id if conversation_id not in self._existing_conversation: connection = await self.connect(conversation_id=conversation_id) self._existing_conversation[conversation_id] = connection # Only send config when creating a new connection - await self.send_config(conversation_id=conversation_id) + await self.send_config(conversation_id=conversation_id, normalized_conversation=normalized_conversation) # Give the server a moment to process the session update await asyncio.sleep(0.5) - self._validate_request(message=message) - request = message.message_pieces[0] response_type = request.converted_value_data_type # Order of messages sent varies based on the data format of the prompt if response_type == "audio_path": output_audio_path, result = await self.send_audio_async( - filename=request.converted_value, conversation_id=conversation_id + filename=request.converted_value, + conversation_id=conversation_id, + conversation=normalized_conversation, ) elif response_type == "text": output_audio_path, result = await self.send_text_async( - text=request.converted_value, conversation_id=conversation_id + text=request.converted_value, + conversation_id=conversation_id, + conversation=normalized_conversation, ) else: raise ValueError(f"Unsupported response type: {response_type}") @@ -668,13 +679,20 @@ def _extract_error_details(*, response: Any) -> str: return f"[{error_type}] {error_message}" return "Unknown error occurred" - async def send_text_async(self, text: str, conversation_id: str) -> tuple[str, RealtimeTargetResult]: + async def send_text_async( + self, + *, + text: str, + conversation_id: str, + conversation: list[Message], + ) -> tuple[str, RealtimeTargetResult]: """ Send text prompt using OpenAI Realtime API client. 
Args: text: prompt to send. conversation_id: conversation ID + conversation: The normalized conversation history. Returns: Tuple[str, RealtimeTargetResult]: Path to saved audio file and the RealtimeTargetResult @@ -714,7 +732,7 @@ async def send_text_async(self, text: str, conversation_id: str) -> tuple[str, R self._existing_conversation[conversation_id] = new_connection # Send session configuration to new connection - system_prompt = self._get_system_prompt_from_conversation(conversation_id=conversation_id) + system_prompt = self._get_system_prompt_from_conversation(conversation=conversation) session_config = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) await new_connection.session.update(session=session_config) @@ -722,13 +740,20 @@ async def send_text_async(self, text: str, conversation_id: str) -> tuple[str, R output_audio_path = await self.save_audio(audio_bytes=result.audio_bytes, sample_rate=24000) return output_audio_path, result - async def send_audio_async(self, filename: str, conversation_id: str) -> tuple[str, RealtimeTargetResult]: + async def send_audio_async( + self, + *, + filename: str, + conversation_id: str, + conversation: list[Message], + ) -> tuple[str, RealtimeTargetResult]: """ Send an audio message using OpenAI Realtime API client. Args: filename (str): The path to the audio file. conversation_id (str): Conversation ID + conversation (list[Message]): The normalized conversation history. 
Returns: Tuple[str, RealtimeTargetResult]: Path to saved audio file and the RealtimeTargetResult @@ -783,7 +808,7 @@ async def send_audio_async(self, filename: str, conversation_id: str) -> tuple[s self._existing_conversation[conversation_id] = new_connection # Send session configuration to new connection - system_prompt = self._get_system_prompt_from_conversation(conversation_id=conversation_id) + system_prompt = self._get_system_prompt_from_conversation(conversation=conversation) session_config = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) await new_connection.session.update(session=session_config) diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 2948988c6..3e6e8c1bd 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -517,7 +517,7 @@ async def _construct_message_from_response(self, response: Any, request: Message @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send prompt, handle agentic tool calls (function_call), return all messages. @@ -526,25 +526,20 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: - Agentic tool-calling loops that may require multiple back-and-forth exchanges Args: - message: The initial prompt from the user. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: List of messages generated during the interaction (assistant responses and tool messages). The normalizer will persist all of these to memory. 
""" - self._validate_request(message=message) - + message = normalized_conversation[-1] message_piece: MessagePiece = message.message_pieces[0] - json_config = _JsonResponseConfig(enabled=False) - if message.message_pieces: - last_piece = message.message_pieces[-1] - json_config = self._get_json_response_config(message_piece=last_piece) - - # Get full conversation history from memory and append the current message - conversation: MutableSequence[Message] = self._memory.get_conversation( - conversation_id=message_piece.conversation_id - ) - conversation.append(message) + last_piece = message.message_pieces[-1] + json_config = self._get_json_response_config(message_piece=last_piece) + + working_conversation: MutableSequence[Message] = list(normalized_conversation) # Track all responses generated during this interaction responses_to_return: list[Message] = [] @@ -553,9 +548,9 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: tool_call_section: Optional[dict[str, Any]] = None while True: - logger.info(f"Sending conversation with {len(conversation)} messages to the prompt target") + logger.info(f"Sending conversation with {len(working_conversation)} messages to the prompt target") - body = await self._construct_request_body(conversation=conversation, json_config=json_config) + body = await self._construct_request_body(conversation=working_conversation, json_config=json_config) # Use unified error handling - automatically detects Response and validates result = await self._handle_openai_request( @@ -564,7 +559,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: ) # Add result to conversation and responses list - conversation.append(result) + working_conversation.append(result) responses_to_return.append(result) # Extract tool call if present @@ -582,7 +577,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: tool_message = Message(message_pieces=[tool_piece], skip_validation=True) # Add tool 
output message to conversation and responses list - conversation.append(tool_message) + working_conversation.append(tool_message) responses_to_return.append(tool_message) # Continue loop to send tool result and get next response diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index bde1e2df9..3013bfb96 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -113,17 +113,19 @@ def _build_identifier(self) -> ComponentIdentifier: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the OpenAI TTS target. Args: - message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the audio response from the prompt target. 
""" - self._validate_request(message=message) + message = normalized_conversation[-1] message_piece = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {message_piece}") diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index ebbee5eeb..7d7737d5b 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -181,7 +181,7 @@ def _validate_duration(self) -> None: @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously sends a message and generates a video using the OpenAI SDK. @@ -195,7 +195,9 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: chained remixes. Args: - message: The message object containing the prompt. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: A list containing the response with the generated video path. @@ -204,7 +206,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: RateLimitException: If the rate limit is exceeded. ValueError: If the request is invalid. """ - self._validate_request(message=message) + message = normalized_conversation[-1] text_piece = message.get_piece_by_type(data_type="text") @@ -458,7 +460,7 @@ async def _save_video_response( prompt_metadata=prompt_metadata, ) - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validate the request message. 
@@ -468,12 +470,13 @@ def _validate_request(self, *, message: Message) -> None: - Text piece + video_path piece (remix mode via history lookup) Args: - message: The message to validate. + normalized_conversation: The normalized conversation to validate. Raises: ValueError: If the request is invalid. """ - super()._validate_request(message=message) + super()._validate_request(normalized_conversation=normalized_conversation) + message = normalized_conversation[-1] text_pieces = message.get_pieces_by_type(data_type="text") image_pieces = message.get_pieces_by_type(data_type="image_path") diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index de317ad7b..58d30f088 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -205,13 +205,14 @@ def _get_selectors(self) -> CopilotSelectors: file_picker_selector='span.fui-MenuItem__content:has-text("Upload images and files")', ) - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Send a message to Microsoft Copilot and return the response. Args: - message (Message): The message to send. Can contain multiple pieces - of type 'text' or 'image_path'. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from Copilot. @@ -219,7 +220,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Raises: RuntimeError: If an error occurs during interaction. 
""" - self._validate_request(message=message) + message = normalized_conversation[-1] try: response_content = await self._interact_with_copilot_async(message) diff --git a/pyrit/prompt_target/playwright_target.py b/pyrit/prompt_target/playwright_target.py index c97f5d803..c77bf4117 100644 --- a/pyrit/prompt_target/playwright_target.py +++ b/pyrit/prompt_target/playwright_target.py @@ -100,12 +100,14 @@ def __init__( self._page = page @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to the Playwright target. Args: - message (Message): The message object containing the prompt to send. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from the prompt target. @@ -113,7 +115,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Raises: RuntimeError: If the Playwright page is not initialized or if an error occurs during interaction. """ - self._validate_request(message=message) + message = normalized_conversation[-1] if not self._page: raise RuntimeError( "Playwright page is not initialized. Please pass a Page object when initializing PlaywrightTarget." 
diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 8bb249f35..9fb9b7d30 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -123,7 +123,7 @@ def _build_identifier(self) -> ComponentIdentifier: ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Parse the text in message to separate the userPrompt and documents contents, then send an HTTP request to the endpoint and obtain a response in JSON. For more info, visit @@ -132,8 +132,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Returns: list[Message]: A list containing the response object with generated text pieces. """ - self._validate_request(message=message) - + message = normalized_conversation[-1] request = message.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {request}") diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 4c736daf2..0c2976921 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -42,17 +42,19 @@ def __init__( super().__init__(custom_configuration=custom_configuration, custom_capabilities=custom_capabilities) self._text_stream = text_stream - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously write a message to the text stream. Args: - message (Message): The message object to write to the stream. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: An empty list (no response expected). 
""" - self._validate_request(message=message) + message = normalized_conversation[-1] self._text_stream.write(f"{str(message)}\n") self._text_stream.flush() @@ -95,7 +97,7 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: self._memory.add_message_pieces_to_memory(message_pieces=message_pieces) return message_pieces - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass async def cleanup_target(self) -> None: diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 9b5607827..67b797710 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -584,17 +584,18 @@ async def _connect_and_send( return response - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validate that the message meets target requirements. Args: - message (Message): The message to validate. + normalized_conversation (list[Message]): The normalized conversation to validate. Raises: ValueError: If message contains unsupported data types or invalid image formats. """ - super()._validate_request(message=message) + super()._validate_request(normalized_conversation=normalized_conversation) + message = normalized_conversation[-1] for piece in message.message_pieces: piece_type = piece.converted_value_data_type @@ -644,7 +645,7 @@ def _generate_consistent_copilot_ids(self, *, pyrit_conversation_id: str) -> tup @limit_requests_per_minute @pyrit_target_retry - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ Asynchronously send a message to Microsoft Copilot using WebSocket. 
@@ -653,7 +654,9 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: state server-side, so only the current message is sent (no explicit history required). Args: - message (Message): A message to be sent to the target. + normalized_conversation (list[Message]): The full conversation + (history + current message) after running the normalization + pipeline. The current message is the last element. Returns: list[Message]: A list containing the response from Copilot. @@ -663,7 +666,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: InvalidStatus: If the WebSocket handshake fails with an HTTP status error. RuntimeError: If any other error occurs during WebSocket communication. """ - self._validate_request(message=message) + message = normalized_conversation[-1] pyrit_conversation_id = message.message_pieces[0].conversation_id is_start_of_session = self._is_start_of_session(conversation_id=pyrit_conversation_id) diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index 5b872eb01..dd0e187ba 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -66,7 +66,8 @@ def set_system_prompt( ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + message = normalized_conversation[-1] self.prompt_sent.append(message.get_value()) return [ @@ -79,7 +80,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: ).to_message() ] - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validates the provided message """ diff --git a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py index 68f69f5c8..07e039355 100644 --- 
a/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py +++ b/tests/partner_integration/azure_ai_evaluation/test_prompt_target_contract.py @@ -23,10 +23,10 @@ class _MinimalTarget(PromptTarget): """Minimal concrete PromptTarget for contract testing.""" - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: return [] - def _validate_request(self, *, message) -> None: + def _validate_request(self, *, normalized_conversation) -> None: pass diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 0bfa55f60..84850dfc5 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -149,7 +149,8 @@ def set_system_prompt( ) @limit_requests_per_minute - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + message = normalized_conversation[-1] self.prompt_sent.append(message.get_value()) return [ @@ -162,7 +163,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: ).to_message() ] - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validates the provided message """ diff --git a/tests/unit/prompt_target/target/test_http_api_target.py b/tests/unit/prompt_target/target/test_http_api_target.py index 3b14d2e70..16613ec02 100644 --- a/tests/unit/prompt_target/target/test_http_api_target.py +++ b/tests/unit/prompt_target/target/test_http_api_target.py @@ -148,14 +148,7 @@ async def test_send_prompt_async_missing_explicit_file_path_raises(mock_request, @pytest.mark.asyncio -@patch("httpx.AsyncClient.request") -async def test_send_prompt_async_validation(mock_request, patch_central_database): - # Create an invalid message (empty message_pieces) - message = MagicMock() - 
message.message_pieces = [] - target = HTTPXAPITarget(http_url="http://example.com/validate/", method="POST", timeout=180) - - with pytest.raises(ValueError) as excinfo: - await target.send_prompt_async(message=message) - - assert "Message must contain at least one message piece. Received: 0 pieces." in str(excinfo.value) +async def test_send_prompt_async_validation(patch_central_database): + # Creating a Message with no pieces raises immediately + with pytest.raises(ValueError, match="must have at least one message piece"): + Message(message_pieces=[]) diff --git a/tests/unit/prompt_target/target/test_http_target.py b/tests/unit/prompt_target/target/test_http_target.py index ee9ec61c3..0e6dbc2f1 100644 --- a/tests/unit/prompt_target/target/test_http_target.py +++ b/tests/unit/prompt_target/target/test_http_target.py @@ -7,6 +7,7 @@ import httpx import pytest +from pyrit.models import Message from pyrit.prompt_target.http_target.http_target import HTTPTarget from pyrit.prompt_target.http_target.http_target_callback_functions import ( get_http_target_json_response_callback_function, @@ -145,13 +146,9 @@ async def test_send_prompt_async_client_kwargs(patch_central_database): @pytest.mark.asyncio async def test_send_prompt_async_validation(mock_http_target): - # Create an invalid message (missing message_pieces) - invalid_message = MagicMock() - invalid_message.message_pieces = [] - with pytest.raises(ValueError) as value_error: - await mock_http_target.send_prompt_async(message=invalid_message) - - assert str(value_error.value) == "Message must contain at least one message piece. Received: 0 pieces." 
+ # Creating a Message with no pieces raises immediately + with pytest.raises(ValueError, match="must have at least one message piece"): + Message(message_pieces=[]) @pytest.mark.asyncio diff --git a/tests/unit/prompt_target/target/test_huggingface_chat_target.py b/tests/unit/prompt_target/target/test_huggingface_chat_target.py index 0de8ddbbf..8148c278d 100644 --- a/tests/unit/prompt_target/target/test_huggingface_chat_target.py +++ b/tests/unit/prompt_target/target/test_huggingface_chat_target.py @@ -163,6 +163,7 @@ async def test_load_model_and_tokenizer(): @pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") @pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") async def test_send_prompt_async(): hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) await hf_chat.load_model_and_tokenizer() @@ -185,6 +186,7 @@ async def test_send_prompt_async(): @pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") @pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") async def test_missing_chat_template_error(): hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) await hf_chat.load_model_and_tokenizer() @@ -230,7 +232,7 @@ async def test_invalid_prompt_request_validation(): message = Message(message_pieces=[message_piece1, message_piece2]) with pytest.raises(ValueError) as excinfo: - hf_chat._validate_request(message=message) + hf_chat._validate_request(normalized_conversation=[message]) assert "This target only supports a single message piece." 
in str(excinfo.value) diff --git a/tests/unit/prompt_target/target/test_image_target.py b/tests/unit/prompt_target/target/test_image_target.py index 0d495535c..8f31b515c 100644 --- a/tests/unit/prompt_target/target/test_image_target.py +++ b/tests/unit/prompt_target/target/test_image_target.py @@ -518,8 +518,10 @@ async def test_validate_previous_conversations( ): message_piece = sample_conversations[0] + prior_message = Message(message_pieces=[message_piece]) + mock_memory = MagicMock() - mock_memory.get_message_pieces.return_value = sample_conversations + mock_memory.get_conversation.return_value = [prior_message] mock_memory.add_message_to_memory = AsyncMock() image_target._memory = mock_memory diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py new file mode 100644 index 000000000..401473930 --- /dev/null +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -0,0 +1,549 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from __future__ import annotations + +import json +import warnings +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, MagicMock, patch + +if TYPE_CHECKING: + from collections.abc import MutableSequence + +import pytest +from openai.types.chat import ChatCompletion + +from pyrit.identifiers import ComponentIdentifier +from pyrit.memory.memory_interface import MemoryInterface +from pyrit.message_normalizer import GenericSystemSquashNormalizer +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import AzureMLChatTarget, OpenAIChatTarget +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.prompt_target.openai.openai_response_target import OpenAIResponseTarget + + +def _make_message_piece(*, role: str, content: str, conversation_id: str = "conv1") -> MessagePiece: + return MessagePiece( + role=role, + conversation_id=conversation_id, + original_value=content, + converted_value=content, + original_value_data_type="text", + converted_value_data_type="text", + prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), + attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), + ) + + +def _make_message(*, role: str, content: str, conversation_id: str = "conv1") -> Message: + return Message(message_pieces=[_make_message_piece(role=role, content=content, conversation_id=conversation_id)]) + + +def _create_mock_chat_completion(content: str = "hi") -> MagicMock: + mock = MagicMock(spec=ChatCompletion) + mock.choices = [MagicMock()] + mock.choices[0].finish_reason = "stop" + mock.choices[0].message.content = content + mock.choices[0].message.audio = None + mock.choices[0].message.tool_calls = None + mock.model_dump_json.return_value = json.dumps( + {"choices": [{"finish_reason": 
"stop", "message": {"content": content}}]} + ) + return mock + + +# --------------------------------------------------------------------------- +# OpenAIChatTarget — normalize_async is called +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_openai_chat_target_calls_normalize_async(): + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + mock_completion = _create_mock_chat_completion("world") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + + with patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize: + mock_normalize.return_value = [user_msg] + await target.send_prompt_async(message=user_msg) + + mock_normalize.assert_called_once() + call_messages = mock_normalize.call_args.kwargs["messages"] + assert len(call_messages) == 1 + assert call_messages[0].get_value() == "hello" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_openai_chat_target_sends_normalized_to_construct_request(): + """Verify that the normalized (not original) conversation is used for the API body.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + user_msg = _make_message(role="user", content="original") + adapted_msg = _make_message(role="user", content="adapted") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + mock_completion = _create_mock_chat_completion("response") + target._async_client.chat.completions.create = 
AsyncMock(return_value=mock_completion) + + with ( + patch.object(target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg]), + patch.object( + target, "_construct_request_body", new_callable=AsyncMock, return_value={"model": "gpt-4o", "messages": []} + ) as mock_construct, + ): + await target.send_prompt_async(message=user_msg) + + # _construct_request_body should receive the adapted message, not the original + call_conv = mock_construct.call_args.kwargs["conversation"] + assert len(call_conv) == 1 + assert call_conv[0].get_value() == "adapted" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_openai_chat_target_memory_not_mutated(): + """Memory-backed conversation must not be altered by normalize_async.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_message(role="system", content="be nice") + user_msg = _make_message(role="user", content="hello") + + # Memory returns a conversation with a system message + memory_conversation: MutableSequence[Message] = [system_msg] + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = memory_conversation + target._memory = mock_memory + + mock_completion = _create_mock_chat_completion("response") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + + await target.send_prompt_async(message=user_msg) + + # Memory-backed conversation must not be mutated by 
send_prompt_async + assert len(memory_conversation) == 1 + assert memory_conversation[0].get_piece().api_role == "system" + + +# --------------------------------------------------------------------------- +# OpenAIResponseTarget — normalize_async is called +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_openai_response_target_calls_normalize_async(): + target = OpenAIResponseTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + # Mock the API to return a simple response (no tool calls) + mock_response = MagicMock() + mock_response.error = None + mock_response.status = "completed" + mock_response.output = [MagicMock()] + mock_response.output[0].type = "message" + mock_response.output[0].content = [MagicMock()] + mock_response.output[0].content[0].type = "output_text" + mock_response.output[0].content[0].text = "world" + mock_response.model_dump_json.return_value = json.dumps( + {"output": [{"type": "message", "content": [{"type": "output_text", "text": "world"}]}]} + ) + target._async_client.responses.create = AsyncMock(return_value=mock_response) + + with patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize: + mock_normalize.return_value = [user_msg] + await target.send_prompt_async(message=user_msg) + + mock_normalize.assert_called_once() + call_messages = mock_normalize.call_args.kwargs["messages"] + assert len(call_messages) == 1 + + +# --------------------------------------------------------------------------- +# AzureMLChatTarget — normalize_async is called +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio 
+@pytest.mark.usefixtures("patch_central_database") +async def test_azure_ml_target_calls_normalize_async(): + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + ) + + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + with ( + patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize, + patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"), + ): + mock_normalize.return_value = [user_msg] + await target.send_prompt_async(message=user_msg) + + mock_normalize.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_azure_ml_target_sends_normalized_to_complete_chat(): + """Normalized (not original) messages should be passed to _complete_chat_async.""" + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + ) + + user_msg = _make_message(role="user", content="original") + adapted_msg = _make_message(role="user", content="adapted") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + with ( + patch.object(target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg]), + patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat, + ): + await target.send_prompt_async(message=user_msg) + + call_messages = mock_chat.call_args.kwargs["messages"] + assert len(call_messages) == 1 + assert call_messages[0].get_value() == "adapted" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_azure_ml_target_memory_not_mutated(): + """Memory should retain original messages after normalization.""" + target = AzureMLChatTarget( + 
endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_message(role="system", content="be nice") + user_msg = _make_message(role="user", content="hello") + + memory_conversation: MutableSequence[Message] = [system_msg] + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = memory_conversation + target._memory = mock_memory + + with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"): + await target.send_prompt_async(message=user_msg) + + # Memory must still have original system message only (not mutated) + assert len(memory_conversation) == 1 + assert memory_conversation[0].get_piece().api_role == "system" + + +# --------------------------------------------------------------------------- +# AzureMLChatTarget — message_normalizer deprecation +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +def test_azure_ml_generic_system_squash_normalizer_emits_deprecation_warning(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + message_normalizer=GenericSystemSquashNormalizer(), + ) + deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + assert "message_normalizer is deprecated" in str(deprecation_warnings[0].message) + + 
+@pytest.mark.usefixtures("patch_central_database") +def test_azure_ml_generic_system_squash_normalizer_creates_adapt_configuration(): + """Legacy message_normalizer should be translated into a TargetConfiguration with ADAPT policy.""" + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + message_normalizer=GenericSystemSquashNormalizer(), + ) + # The shim should create a config with supports_system_prompt=False + assert not target.capabilities.supports_system_prompt + assert target.configuration.includes(capability=CapabilityName.MULTI_TURN) + assert not target.configuration.includes(capability=CapabilityName.SYSTEM_PROMPT) + + +@pytest.mark.usefixtures("patch_central_database") +def test_azure_ml_message_normalizer_and_custom_config_raises(): + """Passing both message_normalizer and custom_configuration should raise ValueError.""" + custom_config = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=True, + supports_multi_message_pieces=True, + ) + ) + with pytest.raises(ValueError, match="Cannot specify both"): + AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + message_normalizer=GenericSystemSquashNormalizer(), + custom_configuration=custom_config, + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_azure_ml_system_squash_via_configuration_pipeline(): + """End-to-end: GenericSystemSquashNormalizer-equivalent behavior via TargetConfiguration pipeline.""" + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="valid_api_key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + 
policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_message(role="system", content="be nice") + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [system_msg] + target._memory = mock_memory + + with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat: + await target.send_prompt_async(message=user_msg) + + # _complete_chat_async should receive normalized messages (system squashed into user) + call_messages = mock_chat.call_args.kwargs["messages"] + roles = [m.get_piece().api_role for m in call_messages] + assert "system" not in roles + # The squashed message should contain the system content + assert "be nice" in call_messages[0].get_value() + assert "hello" in call_messages[0].get_value() + + +# --------------------------------------------------------------------------- +# _get_normalized_conversation_async — unit tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_fetches_history_and_appends_message(): + """The method should fetch history from memory, append the current message, and return them.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + history_msg = _make_message(role="assistant", content="previous answer") + user_msg = _make_message(role="user", content="new question") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [history_msg] + target._memory = mock_memory + + result = await target._get_normalized_conversation_async(message=user_msg) + + 
mock_memory.get_conversation.assert_called_once_with(conversation_id="conv1") + assert len(result) == 2 + assert result[0].get_value() == "previous answer" + assert result[1].get_value() == "new question" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_empty_history(): + """When memory has no history, the result should contain only the current message.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + result = await target._get_normalized_conversation_async(message=user_msg) + + assert len(result) == 1 + assert result[0].get_value() == "hello" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_does_not_mutate_memory(): + """The original memory-backed list must not be modified by the method.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + history_msg = _make_message(role="assistant", content="old") + user_msg = _make_message(role="user", content="new") + + memory_list: MutableSequence[Message] = [history_msg] + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = memory_list + target._memory = mock_memory + + await target._get_normalized_conversation_async(message=user_msg) + + # Memory list must still have only the original message + assert len(memory_list) == 1 + assert memory_list[0].get_value() == "old" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_runs_pipeline(): + """The method should invoke the normalization pipeline on the assembled conversation.""" + target = 
OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_message(role="system", content="be helpful") + user_msg = _make_message(role="user", content="hi") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [system_msg] + target._memory = mock_memory + + result = await target._get_normalized_conversation_async(message=user_msg) + + # System-squash normalizer should merge system into user + assert len(result) == 1 + assert "be helpful" in result[0].get_value() + assert "hi" in result[0].get_value() + roles = [m.get_piece().api_role for m in result] + assert "system" not in roles + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_get_normalized_conversation_passthrough_when_no_adaptation_needed(): + """When the target supports all capabilities, the pipeline should pass messages through unchanged.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + system_msg = _make_message(role="system", content="be nice") + user_msg = _make_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [system_msg] + target._memory = mock_memory + + result = await target._get_normalized_conversation_async(message=user_msg) + + # No adaptation — messages pass through as-is + assert len(result) == 2 + assert result[0].get_piece().api_role == "system" + assert result[0].get_value() == "be nice" + assert 
result[1].get_piece().api_role == "user" + assert result[1].get_value() == "hello" diff --git a/tests/unit/prompt_target/target/test_openai_chat_target.py b/tests/unit/prompt_target/target/test_openai_chat_target.py index 1683dd58c..1b3c7da16 100644 --- a/tests/unit/prompt_target/target/test_openai_chat_target.py +++ b/tests/unit/prompt_target/target/test_openai_chat_target.py @@ -322,6 +322,7 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j return_value=mock_completion ) target._memory = MagicMock(MemoryInterface) + target._memory.get_conversation.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -500,6 +501,7 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di return_value=mock_completion ) target._memory = MagicMock(MemoryInterface) + target._memory.get_conversation.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -581,7 +583,7 @@ def test_validate_request_unsupported_data_types(target: OpenAIChatTarget): ) with pytest.raises(ValueError) as excinfo: - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) assert "This target supports only the following data types" in str(excinfo.value), ( "Error not raised for unsupported data types" diff --git a/tests/unit/prompt_target/target/test_openai_response_target.py b/tests/unit/prompt_target/target/test_openai_response_target.py index cb39b64e7..0bd517ae4 100644 --- a/tests/unit/prompt_target/target/test_openai_response_target.py +++ b/tests/unit/prompt_target/target/test_openai_response_target.py @@ -343,6 +343,7 @@ async def test_send_prompt_async_empty_response_adds_to_memory( ): target._async_client.responses.create = AsyncMock(return_value=mock_response) # type: ignore[method-assign] target._memory = MagicMock(MemoryInterface) + target._memory.get_conversation.return_value = 
[] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -485,6 +486,7 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di ): target._async_client.responses.create = AsyncMock(return_value=mock_response) # type: ignore[method-assign] target._memory = MagicMock(MemoryInterface) + target._memory.get_conversation.return_value = [] with pytest.raises(EmptyResponseException): await target.send_prompt_async(message=message) @@ -584,7 +586,7 @@ def test_validate_request_unsupported_data_types(target: OpenAIResponseTarget): ) with pytest.raises(ValueError) as excinfo: - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) assert "This target supports only the following data types" in str(excinfo.value), ( "Error not raised for unsupported data types" @@ -657,7 +659,7 @@ def test_validate_request_allows_text_and_image(target: OpenAIResponseTarget): ), ] ) - target._validate_request(message=req) + target._validate_request(normalized_conversation=[req]) def test_validate_request_raises_for_invalid_type(target: OpenAIResponseTarget): @@ -667,7 +669,7 @@ def test_validate_request_raises_for_invalid_type(target: OpenAIResponseTarget): ] ) with pytest.raises(ValueError) as excinfo: - target._validate_request(message=req) + target._validate_request(normalized_conversation=[req]) assert "This target supports only the following data types" in str(excinfo.value) diff --git a/tests/unit/prompt_target/target/test_openai_target_auth.py b/tests/unit/prompt_target/target/test_openai_target_auth.py index ac8444257..18c8037d6 100644 --- a/tests/unit/prompt_target/target/test_openai_target_auth.py +++ b/tests/unit/prompt_target/target/test_openai_target_auth.py @@ -20,7 +20,6 @@ def _set_openai_env_configuration_vars(self) -> None: self.model_name_environment_variable = "TEST_MODEL" self.endpoint_environment_variable = "TEST_ENDPOINT" self.api_key_environment_variable = 
"TEST_API_KEY" - self.underlying_model_environment_variable = "TEST_UNDERLYING_MODEL" def _get_target_api_paths(self) -> list[str]: return [] @@ -31,10 +30,10 @@ def _get_provider_examples(self) -> dict[str, str]: async def _construct_message_from_response(self, response, request): raise NotImplementedError - def _validate_request(self, *, message) -> None: + def _validate_request(self, *, normalized_conversation) -> None: pass - async def send_prompt_async(self, *, message): + async def _send_prompt_to_target_async(self, *, normalized_conversation): raise NotImplementedError diff --git a/tests/unit/prompt_target/target/test_playwright_copilot_target.py b/tests/unit/prompt_target/target/test_playwright_copilot_target.py index 8a5aaf1d9..7e19cdc9a 100644 --- a/tests/unit/prompt_target/target/test_playwright_copilot_target.py +++ b/tests/unit/prompt_target/target/test_playwright_copilot_target.py @@ -155,7 +155,7 @@ def test_validate_request_unsupported_type(self, mock_page): match=r"This target supports only the following data types.*If your target does support this, set the" r" custom_configuration parameter accordingly", ): - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_valid_text(self, mock_page, text_request_piece): """Test validation with valid text request.""" @@ -163,14 +163,14 @@ def test_validate_request_valid_text(self, mock_page, text_request_piece): request = Message(message_pieces=[text_request_piece]) # Should not raise any exception - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_valid_multimodal(self, mock_page, multimodal_request): """Test validation with valid multimodal request.""" target = PlaywrightCopilotTarget(page=mock_page) # Should not raise any exception - target._validate_request(message=multimodal_request) + 
target._validate_request(normalized_conversation=[multimodal_request]) @pytest.mark.asyncio async def test_send_text_async(self, mock_page): diff --git a/tests/unit/prompt_target/target/test_playwright_target.py b/tests/unit/prompt_target/target/test_playwright_target.py index 4df935f09..f91021ccb 100644 --- a/tests/unit/prompt_target/target/test_playwright_target.py +++ b/tests/unit/prompt_target/target/test_playwright_target.py @@ -129,7 +129,7 @@ def test_validate_request_unsupported_type(self, mock_interaction_func, mock_pag match=r"This target supports only the following data types.*If your target does support this, set the" r" custom_configuration parameter accordingly", ): - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_valid_text(self, mock_interaction_func, mock_page, text_message_piece): """Test validation with valid text request.""" @@ -137,7 +137,7 @@ def test_validate_request_valid_text(self, mock_interaction_func, mock_page, tex request = Message(message_pieces=[text_message_piece]) # Should not raise any exception - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_valid_image(self, mock_interaction_func, mock_page, image_message_piece): """Test validation with valid image request.""" @@ -145,7 +145,7 @@ def test_validate_request_valid_image(self, mock_interaction_func, mock_page, im request = Message(message_pieces=[image_message_piece]) # Should not raise any exception - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) def test_validate_request_mixed_valid_types( self, mock_interaction_func, mock_page, text_message_piece, image_message_piece @@ -155,7 +155,7 @@ def test_validate_request_mixed_valid_types( request = Message(message_pieces=[text_message_piece, image_message_piece]) # Should not raise any exception - 
target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) @pytest.mark.asyncio async def test_send_prompt_async_single_text(self, mock_interaction_func, mock_page, text_message_piece): @@ -345,7 +345,7 @@ def test_validate_request_multiple_unsupported_types(self, mock_interaction_func match=r"This target supports only the following data types.*If your target does support this, set the" r" custom_configuration parameter accordingly", ): - target._validate_request(message=request) + target._validate_request(normalized_conversation=[request]) @pytest.mark.asyncio async def test_interaction_function_with_complex_response(self, mock_page): diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index 08ce7f217..e73e8740d 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -1,16 +1,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+import json from collections.abc import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch import pytest +from openai.types.chat import ChatCompletion from unit.mocks import get_sample_conversations, openai_chat_response_json_dict from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.identifiers import ComponentIdentifier +from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import Message, MessagePiece from pyrit.prompt_target import OpenAIChatTarget +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration @pytest.fixture @@ -150,3 +160,243 @@ async def test_send_prompt_async_with_delay( mock_create.assert_called_once() mock_sleep.assert_called_once_with(6) # 60/max_requests_per_minute + + +# --------------------------------------------------------------------------- +# _propagate_lineage — metadata preservation after normalization +# --------------------------------------------------------------------------- + +_LINEAGE_CONVERSATION_ID = "original-conv-id-12345" +_LINEAGE_LABELS = {"op_name": "test_op", "user_id": "user42"} +_LINEAGE_ATTACK_IDENTIFIER = ComponentIdentifier(class_name="TestAttack", class_module="tests.attacks") +_LINEAGE_PROMPT_TARGET_IDENTIFIER = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit") +_LINEAGE_PROMPT_METADATA = {"scenario": "test_scenario", "turn": 3} + + +def _make_lineage_piece(*, role: str, content: str) -> MessagePiece: + return MessagePiece( + role=role, + conversation_id=_LINEAGE_CONVERSATION_ID, + original_value=content, + converted_value=content, + original_value_data_type="text", + converted_value_data_type="text", + labels=dict(_LINEAGE_LABELS), + prompt_target_identifier=_LINEAGE_PROMPT_TARGET_IDENTIFIER, + 
attack_identifier=_LINEAGE_ATTACK_IDENTIFIER, + prompt_metadata=dict(_LINEAGE_PROMPT_METADATA), + ) + + +def _make_lineage_message(*, role: str, content: str) -> Message: + return Message(message_pieces=[_make_lineage_piece(role=role, content=content)]) + + +def _make_mock_chat_completion(content: str = "response") -> MagicMock: + mock = MagicMock(spec=ChatCompletion) + mock.choices = [MagicMock()] + mock.choices[0].finish_reason = "stop" + mock.choices[0].message.content = content + mock.choices[0].message.audio = None + mock.choices[0].message.tool_calls = None + mock.model_dump_json.return_value = json.dumps( + {"choices": [{"finish_reason": "stop", "message": {"content": content}}]} + ) + return mock + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_history_squash_preserves_metadata_on_normalized_message(): + """ + After history squash, _propagate_lineage should restore the original request's + metadata (conversation_id, labels, attack_identifier) onto the squashed message. 
+ """ + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=True, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + ), + ), + ) + + history_msg = _make_lineage_message(role="assistant", content="previous answer") + user_msg = _make_lineage_message(role="user", content="follow-up question") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [history_msg] + target._memory = mock_memory + + normalized = await target._get_normalized_conversation_async(message=user_msg) + + assert len(normalized) == 1 + + normalized_piece = normalized[0].message_pieces[0] + + assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID + assert normalized_piece.labels == _LINEAGE_LABELS + assert normalized_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER + assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER + assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_response_preserves_metadata_after_history_squash(): + """ + End-to-end: after history squash the response must carry the original + request's conversation_id, labels, and attack_identifier — not the + random values created by the normalizer. 
+ """ + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=True, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + ), + ), + ) + + history_msg = _make_lineage_message(role="assistant", content="previous answer") + user_msg = _make_lineage_message(role="user", content="follow-up question") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [history_msg] + target._memory = mock_memory + + mock_completion = _make_mock_chat_completion("target response") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + + response_messages = await target.send_prompt_async(message=user_msg) + + assert len(response_messages) == 1 + response_piece = response_messages[0].message_pieces[0] + + assert response_piece.conversation_id == _LINEAGE_CONVERSATION_ID + assert response_piece.labels == _LINEAGE_LABELS + assert response_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER + assert response_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER + assert response_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_system_squash_preserves_metadata(): + """ + GenericSystemSquashNormalizer also creates messages via Message.from_prompt. + _propagate_lineage should restore the original metadata after system squash too. 
+ """ + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + } + ), + ), + ) + + system_msg = _make_lineage_message(role="system", content="be helpful") + user_msg = _make_lineage_message(role="user", content="hello") + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [system_msg] + target._memory = mock_memory + + normalized = await target._get_normalized_conversation_async(message=user_msg) + + assert len(normalized) == 1 + assert "be helpful" in normalized[0].get_value() + + normalized_piece = normalized[0].message_pieces[0] + + assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID + assert normalized_piece.labels == _LINEAGE_LABELS + assert normalized_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER + assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER + assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_history_squash_propagates_lineage_to_all_pieces(): + """ + When the squashed message contains multiple pieces, _propagate_lineage + must stamp every piece — not just the first one. 
+ """ + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=True, + supports_multi_message_pieces=True, + input_modalities=frozenset({frozenset(["text"])}), + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + ), + ), + ) + + history_msg = _make_lineage_message(role="assistant", content="previous answer") + # Build a user message with two pieces to exercise multi-piece stamping. + user_msg = Message( + message_pieces=[ + _make_lineage_piece(role="user", content="first part"), + _make_lineage_piece(role="user", content="second part"), + ] + ) + + mock_memory = MagicMock(spec=MemoryInterface) + mock_memory.get_conversation.return_value = [history_msg] + target._memory = mock_memory + + normalized = await target._get_normalized_conversation_async(message=user_msg) + + assert len(normalized) == 1 + + for piece in normalized[0].message_pieces: + assert piece.conversation_id == _LINEAGE_CONVERSATION_ID + assert piece.labels == _LINEAGE_LABELS + assert piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER + assert piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER + assert piece.prompt_metadata == _LINEAGE_PROMPT_METADATA diff --git a/tests/unit/prompt_target/target/test_realtime_target.py b/tests/unit/prompt_target/target/test_realtime_target.py index 88f047722..bc66b773f 100644 --- a/tests/unit/prompt_target/target/test_realtime_target.py +++ b/tests/unit/prompt_target/target/test_realtime_target.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest @@ -64,6 +64,7 @@ async def test_send_prompt_async(target): target.send_text_async.assert_called_once_with( text="Hello", conversation_id="test_conversation_id", + conversation=ANY, ) assert response[0].get_value() == "hello" assert response[0].get_value(1) == "output.wav" @@ -75,23 +76,21 @@ async def test_send_prompt_async(target): @pytest.mark.asyncio async def test_get_system_prompt_from_conversation_with_system_message(target): """Test that system prompt is extracted from conversation history when present.""" - conversation_id = "test_conversation_with_system" - # Add a system message to memory + # Create a system message system_message = Message( message_pieces=[ MessagePiece( role="system", original_value="You are a helpful assistant specialized in security.", converted_value="You are a helpful assistant specialized in security.", - conversation_id=conversation_id, + conversation_id="test_conversation_with_system", ) ] ) - target._memory.add_message_to_memory(request=system_message) # Get the system prompt - system_prompt = target._get_system_prompt_from_conversation(conversation_id=conversation_id) + system_prompt = target._get_system_prompt_from_conversation(conversation=[system_message]) assert system_prompt == "You are a helpful assistant specialized in security." 
@@ -99,23 +98,21 @@ async def test_get_system_prompt_from_conversation_with_system_message(target): @pytest.mark.asyncio async def test_get_system_prompt_from_conversation_default(target): """Test that default system prompt is returned when no system message in conversation.""" - conversation_id = "test_conversation_no_system" - # Add a user message (no system message) + # Create a user message (no system message) user_message = Message( message_pieces=[ MessagePiece( role="user", original_value="Hello", converted_value="Hello", - conversation_id=conversation_id, + conversation_id="test_conversation_no_system", ) ] ) - target._memory.add_message_to_memory(request=user_message) # Get the system prompt - system_prompt = target._get_system_prompt_from_conversation(conversation_id=conversation_id) + system_prompt = target._get_system_prompt_from_conversation(conversation=[user_message]) assert system_prompt == "You are a helpful AI assistant" @@ -123,10 +120,9 @@ async def test_get_system_prompt_from_conversation_default(target): @pytest.mark.asyncio async def test_get_system_prompt_empty_conversation(target): """Test that default system prompt is returned for empty conversation.""" - conversation_id = "test_empty_conversation" - # Get the system prompt without adding any messages - system_prompt = target._get_system_prompt_from_conversation(conversation_id=conversation_id) + # Get the system prompt without any messages + system_prompt = target._get_system_prompt_from_conversation(conversation=[]) assert system_prompt == "You are a helpful AI assistant" @@ -185,7 +181,7 @@ async def test_send_prompt_async_invalid_request(target): ) message = Message(message_pieces=[message_piece]) with pytest.raises(ValueError) as excinfo: - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) assert "This target supports only the following data types" in str(excinfo.value) assert "image_path" in str(excinfo.value) diff --git 
a/tests/unit/prompt_target/target/test_target_capabilities.py b/tests/unit/prompt_target/target/test_target_capabilities.py index 7cd58793a..5c34221fc 100644 --- a/tests/unit/prompt_target/target/test_target_capabilities.py +++ b/tests/unit/prompt_target/target/test_target_capabilities.py @@ -493,7 +493,7 @@ def _make_target_class(self, *, default_config: "TargetConfiguration"): class _ConcreteTarget(PromptTarget): _DEFAULT_CONFIGURATION = default_config - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: return [] return _ConcreteTarget diff --git a/tests/unit/prompt_target/target/test_tts_target.py b/tests/unit/prompt_target/target/test_tts_target.py index ec595b55b..a9b1f5832 100644 --- a/tests/unit/prompt_target/target/test_tts_target.py +++ b/tests/unit/prompt_target/target/test_tts_target.py @@ -90,8 +90,10 @@ async def test_tts_validate_previous_conversations( ): message_piece = sample_conversations[0] + prior_message = Message(message_pieces=[message_piece]) + mock_memory = MagicMock() - mock_memory.get_message_pieces.return_value = sample_conversations + mock_memory.get_conversation.return_value = [prior_message] mock_memory.add_message_to_memory = AsyncMock() tts_target._memory = mock_memory diff --git a/tests/unit/prompt_target/target/test_video_target.py b/tests/unit/prompt_target/target/test_video_target.py index faba7830a..250f6d5d4 100644 --- a/tests/unit/prompt_target/target/test_video_target.py +++ b/tests/unit/prompt_target/target/test_video_target.py @@ -81,7 +81,7 @@ def test_video_validate_request_multiple_text_pieces(video_target: OpenAIVideoTa msg2 = MessagePiece( role="user", original_value="test2", converted_value="test2", conversation_id=conversation_id ) - video_target._validate_request(message=Message([msg1, msg2])) + video_target._validate_request(normalized_conversation=[Message([msg1, msg2])]) def 
test_video_validate_prompt_type_image_only(video_target: OpenAIVideoTarget): @@ -90,7 +90,7 @@ def test_video_validate_prompt_type_image_only(video_target: OpenAIVideoTarget): msg = MessagePiece( role="user", original_value="test", converted_value="test", converted_value_data_type="image_path" ) - video_target._validate_request(message=Message([msg])) + video_target._validate_request(normalized_conversation=[Message([msg])]) @pytest.mark.asyncio @@ -374,7 +374,7 @@ def test_validate_accepts_text_only(self, video_target: OpenAIVideoTarget): """Test validation accepts single text piece (text-to-video mode).""" msg = MessagePiece(role="user", original_value="test prompt", converted_value="test prompt") # Should not raise - video_target._validate_request(message=Message([msg])) + video_target._validate_request(normalized_conversation=[Message([msg])]) def test_validate_accepts_text_and_image(self, video_target: OpenAIVideoTarget): """Test validation accepts text + image (image-to-video mode).""" @@ -393,7 +393,7 @@ def test_validate_accepts_text_and_image(self, video_target: OpenAIVideoTarget): conversation_id=conversation_id, ) # Should not raise - video_target._validate_request(message=Message([msg_text, msg_image])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_image])]) def test_validate_rejects_multiple_images(self, video_target: OpenAIVideoTarget): """Test validation rejects multiple image pieces.""" @@ -419,7 +419,7 @@ def test_validate_rejects_multiple_images(self, video_target: OpenAIVideoTarget) conversation_id=conversation_id, ) with pytest.raises(ValueError, match="at most 1 image piece"): - video_target._validate_request(message=Message([msg_text, msg_img1, msg_img2])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_img1, msg_img2])]) def test_validate_rejects_unsupported_types(self, video_target: OpenAIVideoTarget): """Test validation rejects unsupported data types.""" @@ -442,7 +442,7 @@ 
def test_validate_rejects_unsupported_types(self, video_target: OpenAIVideoTarge match="This target supports only the following data types.*If your target does support this, set the" " custom_configuration parameter accordingly", ): - video_target._validate_request(message=Message([msg_text, msg_audio])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_audio])]) def test_validate_rejects_remix_with_image(self, video_target: OpenAIVideoTarget): """Test validation rejects remix mode combined with image input.""" @@ -462,7 +462,7 @@ def test_validate_rejects_remix_with_image(self, video_target: OpenAIVideoTarget conversation_id=conversation_id, ) with pytest.raises(ValueError, match="Cannot use image input in remix mode"): - video_target._validate_request(message=Message([msg_text, msg_image])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_image])]) @pytest.mark.usefixtures("patch_central_database") @@ -789,7 +789,7 @@ def test_validate_rejects_no_text_piece(self, video_target: OpenAIVideoTarget): converted_value_data_type="image_path", ) with pytest.raises(ValueError, match="Expected exactly 1 text piece"): - video_target._validate_request(message=Message([msg])) + video_target._validate_request(normalized_conversation=[Message([msg])]) @pytest.mark.asyncio async def test_image_to_video_with_jpeg(self, video_target: OpenAIVideoTarget): @@ -993,12 +993,7 @@ def test_video_validate_previous_conversations( ): message_piece = sample_conversations[0] - mock_memory = MagicMock() - mock_memory.get_message_pieces.return_value = sample_conversations - mock_memory.add_message_to_memory = AsyncMock() - - video_target._memory = mock_memory - + prior_message = Message(message_pieces=[message_piece]) request = Message(message_pieces=[message_piece]) with pytest.raises( @@ -1006,7 +1001,7 @@ def test_video_validate_previous_conversations( match="This target only supports a single turn conversation.*If your target 
does support this, set the" " custom_configuration parameter accordingly", ): - video_target._validate_request(message=request) + video_target._validate_request(normalized_conversation=[prior_message, request]) @pytest.mark.usefixtures("patch_central_database") @@ -1039,7 +1034,7 @@ def test_validate_accepts_text_and_video_path(self, video_target: OpenAIVideoTar conversation_id=conversation_id, ) # Should not raise - video_target._validate_request(message=Message([msg_text, msg_video])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_video])]) def test_validate_rejects_video_path_and_image_path(self, video_target: OpenAIVideoTarget) -> None: """Test validation rejects combining video_path and image_path.""" @@ -1065,7 +1060,7 @@ def test_validate_rejects_video_path_and_image_path(self, video_target: OpenAIVi conversation_id=conversation_id, ) with pytest.raises(ValueError, match="Cannot combine video_path and image_path"): - video_target._validate_request(message=Message([msg_text, msg_video, msg_image])) + video_target._validate_request(normalized_conversation=[Message([msg_text, msg_video, msg_image])]) def test_remix_keeps_video_path_pieces_when_ids_match(self, video_target: OpenAIVideoTarget) -> None: """Test that video_path pieces are preserved after validation so normalizer stores them.""" diff --git a/tests/unit/prompt_target/target/test_websocket_copilot_target.py b/tests/unit/prompt_target/target/test_websocket_copilot_target.py index c6b6ce6c0..227c11352 100644 --- a/tests/unit/prompt_target/target/test_websocket_copilot_target.py +++ b/tests/unit/prompt_target/target/test_websocket_copilot_target.py @@ -743,19 +743,19 @@ def test_validate_request_data_types(self, mock_authenticator, make_message_piec message = Message(message_pieces=[message_piece]) if should_pass: - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) else: with pytest.raises( ValueError, match=f"This 
target supports only the following data types: image_path, text. Received: {data_type}.", ): - target._validate_request(message=message) + target._validate_request(normalized_conversation=[message]) def test_validate_request_with_multiple_text_pieces(self, mock_authenticator, make_message_piece): target = WebSocketCopilotTarget(authenticator=mock_authenticator) message_pieces = [make_message_piece(f"test{i}", conversation_id="123") for i in range(3)] message = Message(message_pieces=message_pieces) - target._validate_request(message=message) # should not raise + target._validate_request(normalized_conversation=[message]) # should not raise def test_validate_request_with_mixed_valid_content(self, mock_authenticator, make_message_piece): target = WebSocketCopilotTarget(authenticator=mock_authenticator) @@ -764,7 +764,7 @@ def test_validate_request_with_mixed_valid_content(self, mock_authenticator, mak make_message_piece("/path/to/image.png", data_type="image_path", conversation_id="123"), ] message = Message(message_pieces=message_pieces) - target._validate_request(message=message) # should not raise + target._validate_request(normalized_conversation=[message]) # should not raise @pytest.mark.usefixtures("patch_central_database") diff --git a/tests/unit/prompt_target/test_prompt_chat_target.py b/tests/unit/prompt_target/test_prompt_chat_target.py index 872e5dec7..3ffcb6341 100644 --- a/tests/unit/prompt_target/test_prompt_chat_target.py +++ b/tests/unit/prompt_target/test_prompt_chat_target.py @@ -105,7 +105,7 @@ def test_init_subclass_promotes_default_capabilities_with_warning(): class _LegacyTarget(PromptTarget): _DEFAULT_CAPABILITIES = TargetCapabilities(supports_multi_turn=True) - async def send_prompt_async(self, *, message: Message) -> list[Message]: + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: return [] deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] diff 
--git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py index 3f05624e9..137c138f4 100644 --- a/tests/unit/registry/test_target_registry.py +++ b/tests/unit/registry/test_target_registry.py @@ -17,10 +17,10 @@ class MockPromptTarget(PromptTarget): def __init__(self, *, model_name: str = "mock_model") -> None: super().__init__(model_name=model_name) - async def send_prompt_async( + async def _send_prompt_to_target_async( self, *, - message: Message, + normalized_conversation: list[Message], ) -> list[Message]: return [ MessagePiece( @@ -29,7 +29,7 @@ async def send_prompt_async( ).to_message() ] - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass @@ -39,10 +39,10 @@ class MockPromptChatTarget(PromptChatTarget): def __init__(self, *, model_name: str = "mock_chat_model", endpoint: str = "http://chat-test") -> None: super().__init__(model_name=model_name, endpoint=endpoint) - async def send_prompt_async( + async def _send_prompt_to_target_async( self, *, - message: Message, + normalized_conversation: list[Message], ) -> list[Message]: return [ MessagePiece( @@ -51,7 +51,7 @@ async def send_prompt_async( ).to_message() ] - def _validate_request(self, *, message: Message) -> None: + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: pass