async def set(
    self,
    *,
    prompt: Optional[str] = None,
    enhance: Optional[bool] = None,
    image: Union[FileInput, None, _Unset] = UNSET,
    timeout: float = 30.0,
) -> None:
    """Send one ``set`` message carrying a prompt and/or image and await its ack.

    Only the fields that were actually supplied are placed on the outgoing
    ``SetParamsMessage``.

    Args:
        prompt: New prompt text. ``None`` (the default) leaves ``prompt`` off
            the message. A string that is empty after ``strip()`` is rejected.
        enhance: Forwarded as ``enhance_prompt`` when it is not ``None``.
        image: Image input, converted with ``file_input_to_bytes`` and sent
            base64-encoded as ``image_data``. Passing ``None`` sends an
            explicit ``image_data`` of ``None`` (the ``UNSET`` sentinel's
            docstring describes this as "clear"); leaving the default
            ``UNSET`` omits the field.
        timeout: Seconds to wait for the acknowledgment. Defaults to 30.0,
            the value that was previously hard-coded.

    Raises:
        InvalidInputError: If neither ``prompt`` nor ``image`` is supplied, if
            ``prompt`` is blank, or if an image is supplied while no HTTP
            session is available.
        DecartSDKError: If no acknowledgment arrives within ``timeout``
            seconds, or the acknowledgment reports ``success`` as false.

    NOTE(review): the connection stores a single pending-set waiter
    (``_pending_set``), so overlapping ``set()`` calls on one client would
    replace each other's waiter. Confirm that callers issue them one at a time.
    """
    image_provided = image is not UNSET

    # Validate everything up front so no I/O happens for a request that is
    # going to be rejected anyway.
    if prompt is None and not image_provided:
        raise InvalidInputError("At least one of 'prompt' or 'image' must be provided")
    if prompt is not None and not prompt.strip():
        raise InvalidInputError("Prompt cannot be empty")
    if image_provided and image is not None and not self._http_session:
        raise InvalidInputError("HTTP session not available")

    fields: dict = {"type": "set"}
    if prompt is not None:
        fields["prompt"] = prompt
    if enhance is not None:
        fields["enhance_prompt"] = enhance
    if image_provided:
        if image is None:
            # Explicit None must still be passed to the model so the field
            # counts as "set" on the message.
            fields["image_data"] = None
        else:
            image_bytes, _ = await file_input_to_bytes(image, self._http_session)
            fields["image_data"] = base64.b64encode(image_bytes).decode("utf-8")

    message = SetParamsMessage(**fields)

    # Register the waiter before sending so an ack cannot arrive unobserved.
    event, result = self._manager.register_set_wait()
    try:
        await self._manager.send_message(message)

        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
        except asyncio.TimeoutError as exc:
            raise DecartSDKError("Set acknowledgment timed out") from exc

        if not result["success"]:
            raise DecartSDKError(result.get("error") or "Set failed")
    finally:
        # Always release the waiter slot, whether the call succeeded or not.
        self._manager.unregister_set_wait()
def _handle_set_ack(self, message: SetAckMessage) -> None:
    """Hand a ``set_ack`` message to the currently registered waiter, if any.

    The ack's ``success`` and ``error`` values are copied into the waiter's
    result dict and its event is set. When no waiter is registered the ack is
    only logged.
    """
    logger.debug(f"Received set_ack: success={message.success}, error={message.error}")
    pending = self._pending_set
    if not pending:
        return
    ack_event, outcome = pending
    outcome["success"] = message.success
    outcome["error"] = message.error
    ack_event.set()


def register_set_wait(self) -> tuple[asyncio.Event, dict]:
    """Create a new ``set_ack`` waiter, store it as the pending one, and return it.

    Returns:
        A ``(event, result)`` pair. ``result`` starts as
        ``{"success": False, "error": None}``.
    """
    ack_event = asyncio.Event()
    outcome: dict = {"success": False, "error": None}
    self._pending_set = (ack_event, outcome)
    return ack_event, outcome


def unregister_set_wait(self) -> None:
    """Clear the pending ``set_ack`` waiter."""
    self._pending_set = None
def register_set_wait(self) -> tuple[asyncio.Event, dict]:
    """Delegate to the underlying connection's ``register_set_wait``.

    Returns:
        The value returned by the connection, unchanged.
    """
    connection = self._connection
    return connection.register_set_wait()


def unregister_set_wait(self) -> None:
    """Delegate to the underlying connection's ``unregister_set_wait``."""
    connection = self._connection
    connection.unregister_set_wait()
# Tests for the unified set() method


def _wire_set_mocks(mock_manager_class, mock_session_cls, *, success=True, error=None, ack=True):
    """Configure the patched WebRTCManager and ClientSession for a set() test.

    Args:
        mock_manager_class: The patched ``decart.realtime.client.WebRTCManager``.
        mock_session_cls: The patched ``decart.realtime.client.aiohttp.ClientSession``.
        success: Value pre-loaded into the ack result's ``success`` key.
        error: Value pre-loaded into the ack result's ``error`` key.
        ack: When true, the ack event is set as soon as a message is sent,
            standing in for the server's acknowledgment. When false the event
            is never set.

    Returns:
        The mock manager instance handed out by ``mock_manager_class``.
    """
    import asyncio

    ack_event = asyncio.Event()
    ack_result = {"success": success, "error": error}

    def _send(_message):
        # Deliver the ack deterministically instead of from a sleeping
        # background task whose reference would have to be kept alive.
        if ack:
            ack_event.set()

    mock_manager = AsyncMock()
    mock_manager.connect = AsyncMock(return_value=True)
    mock_manager.send_message = AsyncMock(side_effect=_send)
    mock_manager.register_set_wait = MagicMock(return_value=(ack_event, ack_result))
    mock_manager.unregister_set_wait = MagicMock()
    mock_manager_class.return_value = mock_manager

    mock_session = MagicMock()
    mock_session.closed = False
    mock_session.close = AsyncMock()
    mock_session_cls.return_value = mock_session

    return mock_manager


async def _connect_for_set(client):
    """Connect a RealtimeClient using ``client``'s base URL and API key.

    The caller must already have WebRTCManager and ClientSession patched.
    """
    from decart.realtime.types import RealtimeConnectOptions

    return await RealtimeClient.connect(
        base_url=client.base_url,
        api_key=client.api_key,
        local_track=MagicMock(),
        options=RealtimeConnectOptions(
            model=models.realtime("lucy_2_rt"),
            on_remote_stream=lambda t: None,
        ),
    )


@pytest.mark.asyncio
async def test_set_prompt_only():
    """A prompt-only set() sends just ``prompt`` and releases the waiter."""
    client = DecartClient(api_key="test-key")

    with (
        patch("decart.realtime.client.WebRTCManager") as mock_manager_class,
        patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls,
    ):
        mock_manager = _wire_set_mocks(mock_manager_class, mock_session_cls)
        realtime_client = await _connect_for_set(client)

        await realtime_client.set(prompt="A cat")

        mock_manager.send_message.assert_called()
        message = mock_manager.send_message.call_args[0][0]
        assert message.type == "set"
        assert message.prompt == "A cat"
        assert "image_data" not in message.model_fields_set
        assert "enhance_prompt" not in message.model_fields_set
        mock_manager.unregister_set_wait.assert_called_once()


@pytest.mark.asyncio
async def test_set_prompt_with_enhance():
    """``enhance`` is forwarded as ``enhance_prompt``."""
    client = DecartClient(api_key="test-key")

    with (
        patch("decart.realtime.client.WebRTCManager") as mock_manager_class,
        patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls,
    ):
        mock_manager = _wire_set_mocks(mock_manager_class, mock_session_cls)
        realtime_client = await _connect_for_set(client)

        await realtime_client.set(prompt="A cat", enhance=True)

        message = mock_manager.send_message.call_args[0][0]
        assert message.type == "set"
        assert message.prompt == "A cat"
        assert message.enhance_prompt is True


@pytest.mark.asyncio
async def test_set_image_only():
    """An image-only set() sends string ``image_data`` and no ``prompt``."""
    client = DecartClient(api_key="test-key")

    with (
        patch("decart.realtime.client.WebRTCManager") as mock_manager_class,
        patch("decart.realtime.client.file_input_to_bytes") as mock_file_input,
        patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls,
    ):
        mock_manager = _wire_set_mocks(mock_manager_class, mock_session_cls)
        mock_file_input.return_value = (b"fake image data", "image/png")
        realtime_client = await _connect_for_set(client)

        await realtime_client.set(image=b"fake image data")

        message = mock_manager.send_message.call_args[0][0]
        assert message.type == "set"
        assert isinstance(message.image_data, str)
        assert "prompt" not in message.model_fields_set
        mock_file_input.assert_called_once()


@pytest.mark.asyncio
async def test_set_prompt_and_image():
    """Prompt and image can be sent together in one message."""
    client = DecartClient(api_key="test-key")

    with (
        patch("decart.realtime.client.WebRTCManager") as mock_manager_class,
        patch("decart.realtime.client.file_input_to_bytes") as mock_file_input,
        patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls,
    ):
        mock_manager = _wire_set_mocks(mock_manager_class, mock_session_cls)
        mock_file_input.return_value = (b"fake", "image/png")
        realtime_client = await _connect_for_set(client)

        await realtime_client.set(prompt="A cat", image=b"fake")

        message = mock_manager.send_message.call_args[0][0]
        assert message.type == "set"
        assert message.prompt == "A cat"
        assert isinstance(message.image_data, str)


@pytest.mark.asyncio
async def test_set_image_none_clears():
    """``image=None`` sends an explicit ``image_data`` of ``None``."""
    client = DecartClient(api_key="test-key")

    with (
        patch("decart.realtime.client.WebRTCManager") as mock_manager_class,
        patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls,
    ):
        mock_manager = _wire_set_mocks(mock_manager_class, mock_session_cls)
        realtime_client = await _connect_for_set(client)

        await realtime_client.set(image=None)

        message = mock_manager.send_message.call_args[0][0]
        assert message.type == "set"
        assert "image_data" in message.model_fields_set
        assert message.image_data is None


@pytest.mark.asyncio
async def test_set_rejects_empty():
    """set() with no arguments raises InvalidInputError."""
    from decart.errors import InvalidInputError

    client = DecartClient(api_key="test-key")

    with (
        patch("decart.realtime.client.WebRTCManager") as mock_manager_class,
        patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls,
    ):
        _wire_set_mocks(mock_manager_class, mock_session_cls)
        realtime_client = await _connect_for_set(client)

        with pytest.raises(InvalidInputError) as exc_info:
            await realtime_client.set()

        assert "at least one" in str(exc_info.value).lower()


@pytest.mark.asyncio
async def test_set_rejects_empty_prompt():
    """A whitespace-only prompt raises InvalidInputError."""
    from decart.errors import InvalidInputError

    client = DecartClient(api_key="test-key")

    with (
        patch("decart.realtime.client.WebRTCManager") as mock_manager_class,
        patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls,
    ):
        _wire_set_mocks(mock_manager_class, mock_session_cls)
        realtime_client = await _connect_for_set(client)

        with pytest.raises(InvalidInputError):
            await realtime_client.set(prompt=" ")


@pytest.mark.asyncio
async def test_set_timeout():
    """A missing ack surfaces as DecartSDKError and still releases the waiter."""
    import asyncio

    from decart.errors import DecartSDKError

    client = DecartClient(api_key="test-key")

    def _expire(awaitable, timeout=None):
        # Close the coroutine handed to wait_for so it is not reported as
        # "never awaited", then simulate the timeout.
        awaitable.close()
        raise asyncio.TimeoutError

    with (
        patch("decart.realtime.client.WebRTCManager") as mock_manager_class,
        patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls,
    ):
        mock_manager = _wire_set_mocks(
            mock_manager_class, mock_session_cls, success=False, ack=False
        )
        realtime_client = await _connect_for_set(client)

        with patch("asyncio.wait_for", side_effect=_expire):
            with pytest.raises(DecartSDKError) as exc_info:
                await realtime_client.set(prompt="A cat")

        assert "timed out" in str(exc_info.value).lower()
        mock_manager.unregister_set_wait.assert_called_once()


@pytest.mark.asyncio
async def test_set_server_error():
    """An unsuccessful ack raises DecartSDKError carrying the server's message."""
    from decart.errors import DecartSDKError

    client = DecartClient(api_key="test-key")

    with (
        patch("decart.realtime.client.WebRTCManager") as mock_manager_class,
        patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls,
    ):
        mock_manager = _wire_set_mocks(
            mock_manager_class, mock_session_cls, success=False, error="Server error"
        )
        realtime_client = await _connect_for_set(client)

        with pytest.raises(DecartSDKError) as exc_info:
            await realtime_client.set(prompt="A cat")

        assert "Server error" in str(exc_info.value)
        mock_manager.unregister_set_wait.assert_called_once()


def test_set_message_serialization():
    """message_to_json omits unset fields but keeps an explicit ``None``."""
    import json

    from decart.realtime.messages import SetParamsMessage, message_to_json

    # Go through message_to_json so the SetParamsMessage branch is exercised.
    raw = json.loads(message_to_json(SetParamsMessage(prompt="Hello")))
    assert "prompt" in raw
    assert raw["prompt"] == "Hello"
    assert "image_data" not in raw
    assert "enhance_prompt" not in raw

    raw2 = json.loads(message_to_json(SetParamsMessage(image_data=None)))
    assert "image_data" in raw2
    assert raw2["image_data"] is None


def test_set_ack_message_parsing():
    """``set_ack`` payloads parse into SetAckMessage for success and failure."""
    from decart.realtime.messages import SetAckMessage, parse_incoming_message

    result = parse_incoming_message({"type": "set_ack", "success": True, "error": None})
    assert isinstance(result, SetAckMessage)
    assert result.success is True
    assert result.error is None

    result2 = parse_incoming_message({"type": "set_ack", "success": False, "error": "fail"})
    assert isinstance(result2, SetAckMessage)
    assert result2.success is False
    assert result2.error == "fail"