Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions decart/realtime/client.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
from typing import Callable, Optional
from typing import Callable, Optional, Union
import asyncio
import base64
import logging
import uuid
import enum
import aiohttp
from aiortc import MediaStreamTrack

from .webrtc_manager import WebRTCManager, WebRTCConfiguration
from .messages import PromptMessage, SetAvatarImageMessage
from .messages import PromptMessage, SetAvatarImageMessage, SetParamsMessage
from .types import ConnectionState, RealtimeConnectOptions
from ..types import FileInput
from ..errors import DecartSDKError, InvalidInputError, WebRTCError
from ..process.request import file_input_to_bytes


class _Unset(enum.Enum):
"""Sentinel to distinguish 'not provided' from None (which means 'clear')."""

UNSET = "UNSET"


UNSET = _Unset.UNSET

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -184,6 +194,54 @@ async def set_image(self, image: FileInput) -> None:
finally:
self._manager.unregister_image_set_wait()

async def set(
self,
*,
prompt: Optional[str] = None,
enhance: Optional[bool] = None,
image: Union[FileInput, None, _Unset] = UNSET,
) -> None:
image_provided = image is not UNSET

if prompt is None and not image_provided:
raise InvalidInputError("At least one of 'prompt' or 'image' must be provided")

fields: dict = {"type": "set"}

if prompt is not None:
if not prompt.strip():
raise InvalidInputError("Prompt cannot be empty")
fields["prompt"] = prompt

if enhance is not None:
fields["enhance_prompt"] = enhance

if image_provided:
if image is None:
fields["image_data"] = None
else:
if not self._http_session:
raise InvalidInputError("HTTP session not available")
image_bytes, _ = await file_input_to_bytes(image, self._http_session)
fields["image_data"] = base64.b64encode(image_bytes).decode("utf-8")

message = SetParamsMessage(**fields)

event, result = self._manager.register_set_wait()

try:
await self._manager.send_message(message)

try:
await asyncio.wait_for(event.wait(), timeout=30.0)
except asyncio.TimeoutError:
raise DecartSDKError("Set acknowledgment timed out")

if not result["success"]:
raise DecartSDKError(result.get("error") or "Set failed")
finally:
self._manager.unregister_set_wait()

def is_connected(self) -> bool:
return self._manager.is_connected()

Expand Down
21 changes: 19 additions & 2 deletions decart/realtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ class IceRestartMessage(BaseModel):
turn_config: TurnConfig


class SetAckMessage(BaseModel):
type: Literal["set_ack"]
success: bool
error: Optional[str] = None


# Discriminated union for incoming messages
IncomingMessage = Annotated[
Union[
Expand All @@ -95,6 +101,7 @@ class IceRestartMessage(BaseModel):
SessionIdMessage,
PromptAckMessage,
SetImageAckMessage,
SetAckMessage,
ErrorMessage,
ReadyMessage,
IceRestartMessage,
Expand Down Expand Up @@ -131,8 +138,16 @@ class SetAvatarImageMessage(BaseModel):
image_data: str # Base64-encoded image


# Outgoing message union (no discriminator needed - we know what we're sending)
OutgoingMessage = Union[OfferMessage, IceCandidateMessage, PromptMessage, SetAvatarImageMessage]
class SetParamsMessage(BaseModel):
type: Literal["set"] = "set"
prompt: Optional[str] = None
enhance_prompt: Optional[bool] = None
image_data: Optional[str] = None


OutgoingMessage = Union[
OfferMessage, IceCandidateMessage, PromptMessage, SetAvatarImageMessage, SetParamsMessage
]


def parse_incoming_message(data: dict) -> IncomingMessage:
Expand Down Expand Up @@ -161,4 +176,6 @@ def message_to_json(message: OutgoingMessage) -> str:
Returns:
JSON string
"""
if isinstance(message, SetParamsMessage):
return message.model_dump_json(exclude_unset=True)
return message.model_dump_json()
21 changes: 21 additions & 0 deletions decart/realtime/webrtc_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
PromptAckMessage,
SetImageAckMessage,
SetAvatarImageMessage,
SetAckMessage,
ErrorMessage,
IceRestartMessage,
OutgoingMessage,
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(
self._ice_candidates_queue: list[RTCIceCandidate] = []
self._pending_prompts: dict[str, tuple[asyncio.Event, dict]] = {}
self._pending_image_set: Optional[tuple[asyncio.Event, dict]] = None
self._pending_set: Optional[tuple[asyncio.Event, dict]] = None

async def connect(
self,
Expand Down Expand Up @@ -254,6 +256,8 @@ async def _handle_message(self, data: dict) -> None:
self._handle_prompt_ack(message)
elif message.type == "set_image_ack":
self._handle_set_image_ack(message)
elif message.type == "set_ack":
self._handle_set_ack(message)
elif message.type == "error":
self._handle_error(message)
elif message.type == "ready":
Expand Down Expand Up @@ -307,6 +311,14 @@ def _handle_set_image_ack(self, message: SetImageAckMessage) -> None:
result["error"] = message.error
event.set()

def _handle_set_ack(self, message: SetAckMessage) -> None:
logger.debug(f"Received set_ack: success={message.success}, error={message.error}")
if self._pending_set:
event, result = self._pending_set
result["success"] = message.success
result["error"] = message.error
event.set()

def _handle_error(self, message: ErrorMessage) -> None:
logger.error(f"Received error from server: {message.error}")
error = WebRTCError(message.error)
Expand Down Expand Up @@ -367,6 +379,15 @@ def register_image_set_wait(self) -> tuple[asyncio.Event, dict]:
def unregister_image_set_wait(self) -> None:
self._pending_image_set = None

def register_set_wait(self) -> tuple[asyncio.Event, dict]:
event = asyncio.Event()
result: dict = {"success": False, "error": None}
self._pending_set = (event, result)
return event, result

def unregister_set_wait(self) -> None:
self._pending_set = None

def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]:
event = asyncio.Event()
result: dict = {"success": False, "error": None}
Expand Down
6 changes: 6 additions & 0 deletions decart/realtime/webrtc_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,9 @@ def register_image_set_wait(self) -> tuple[asyncio.Event, dict]:

def unregister_image_set_wait(self) -> None:
self._connection.unregister_image_set_wait()

def register_set_wait(self) -> tuple[asyncio.Event, dict]:
return self._connection.register_set_wait()

def unregister_set_wait(self) -> None:
self._connection.unregister_set_wait()
4 changes: 4 additions & 0 deletions examples/avatar_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def on_error(error):
# print("Updating avatar image...")
# await realtime_client.set_image(Path("new_avatar.png"))
# print("✓ Avatar image updated!")
#
# Or use the unified set() method:
# await realtime_client.set(image=Path("new_avatar.png"))
# await realtime_client.set(prompt="A friendly greeting", image=Path("new_avatar.png"))

try:
while True:
Expand Down
2 changes: 1 addition & 1 deletion examples/realtime_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def on_error(error):

print("\n🎨 Changing style to 'Cyberpunk city'...")
try:
await realtime_client.set_prompt("Cyberpunk city")
await realtime_client.set(prompt="Cyberpunk city")
print("✓ Prompt set successfully")
except Exception as e:
print(f"⚠️ Failed to set prompt: {e}")
Expand Down
Loading
Loading