diff --git a/README.md b/README.md index 56406e3..fa5faf3 100644 --- a/README.md +++ b/README.md @@ -1081,10 +1081,14 @@ The Runware SDK provides configurable timeout settings for different operations Set environment variables to customize timeout behavior: ```bash +# Concurrency +RUNWARE_MAX_CONCURRENT_REQUESTS=15 # Max concurrent API requests per client (default: 15) + # Image Operations (milliseconds) RUNWARE_IMAGE_INFERENCE_TIMEOUT=300000 # Image generation (default: 5 min) RUNWARE_IMAGE_OPERATION_TIMEOUT=120000 # Caption, upscale, background removal (default: 2 min) RUNWARE_IMAGE_UPLOAD_TIMEOUT=60000 # Image upload (default: 1 min) +RUNWARE_MAX_POLLS_IMAGE_GENERATION=480 # Max polling attempts for async image / getResponse (default: 480) # Model Operations (milliseconds) RUNWARE_MODEL_UPLOAD_TIMEOUT=900000 # Model upload (default: 15 min) @@ -1092,12 +1096,15 @@ RUNWARE_MODEL_UPLOAD_TIMEOUT=900000 # Model upload (default: 15 min) # Video Operations (milliseconds) RUNWARE_VIDEO_INITIAL_TIMEOUT=30000 # Initial response wait (default: 30 sec) RUNWARE_VIDEO_POLLING_DELAY=3000 # Delay between status checks (default: 3 sec) -RUNWARE_MAX_POLLS_VIDEO_GENERATION=480 # Max polling attempts (default: 480, ~24 min total) +RUNWARE_MAX_POLLS_VIDEO_GENERATION=480 # Max polling attempts for video / caption / upscale / bg removal (default: 480, ~24 min total) + +# 3D Operations +RUNWARE_MAX_POLLS_3D_GENERATION=480 # Max polling attempts for 3D inference / getResponse (default: 480) # Audio Operations (milliseconds) RUNWARE_AUDIO_INFERENCE_TIMEOUT=300000 # Audio generation (default: 5 min) RUNWARE_AUDIO_POLLING_DELAY=1000 # Delay between status checks (default: 1 sec) -RUNWARE_MAX_POLLS_AUDIO_GENERATION=240 # Max polling attempts (default: 240, ~4 min total) +RUNWARE_MAX_POLLS_AUDIO_GENERATION=240 # Max polling attempts for audio inference (default: 240, ~4 min total) # Other Operations (milliseconds) RUNWARE_PROMPT_ENHANCE_TIMEOUT=60000 # Prompt enhancement (default: 1 min) diff --git a/runware/base.py b/runware/base.py index 721c295..6b9fa61 100644 --- a/runware/base.py +++ b/runware/base.py @@ -3,10 +3,10 @@ import logging import os import re -import uuid from asyncio import gather from dataclasses import asdict -from typing import List, Optional, Union, Callable, Any, Dict +from random import uniform +from typing import List, Optional, Union, Callable, Any, Dict, Tuple from websockets.protocol import State @@ -50,8 +50,12 @@ IFrameImage, IAsyncTaskResponse, IVectorize, + OperationState, I3dInference, I3d, + IGetResponseRequest, + IUploadImageRequest, + IUploadMediaRequest, ITextInference, IText, ) @@ -62,7 +66,6 @@ fileToBase64, createImageFromResponse, createImageToTextFromResponse, - createVideoToTextFromResponse, createEnhancedPromptsFromResponse, instantiateDataclassList, RunwareAPIError, @@ -89,12 +92,14 @@ IMAGE_POLLING_DELAY, TEXT_POLLING_DELAY, AUDIO_INITIAL_TIMEOUT, - AUDIO_INFERENCE_TIMEOUT, AUDIO_POLLING_DELAY, MAX_POLLS, MAX_POLLS_VIDEO_GENERATION, MAX_POLLS_AUDIO_GENERATION, + MAX_POLLS_3D_GENERATION, + MAX_POLLS_IMAGE_GENERATION, MAX_RETRY_ATTEMPTS, + MAX_CONCURRENT_REQUESTS, ) # Configure logging @@ -135,87 +140,336 @@ def __init__( self._listener_tasks = set() self._reconnection_manager = ReconnectionManager(logger=self.logger) self._reconnect_lock = asyncio.Lock() + self._pending_operations: Dict[str, Dict[str, Any]] = {} + self._operations_lock = asyncio.Lock() + if MAX_CONCURRENT_REQUESTS <= 0: + raise ValueError( + "RUNWARE_MAX_CONCURRENT_REQUESTS must be greater than 0 " + f"(got {MAX_CONCURRENT_REQUESTS}). A value of 0 would cause asyncio.Semaphore to deadlock." + ) + self._request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + async def _register_pending_operation( + self, + task_uuid: str, + expected_results: int = 1, + on_partial: "Optional[Callable[[List[Any], Optional[IError]], None]]" = None, + complete_predicate: "Optional[Callable[[Dict[str, Any]], bool]]" = None, + result_filter: "Optional[Callable[[Dict[str, Any]], bool]]" = None + ) -> "Tuple[asyncio.Future, bool]": + async with self._operations_lock: + if task_uuid in self._pending_operations: + existing_op = self._pending_operations[task_uuid] + existing_state = existing_op.get("state") + existing_future = existing_op.get("future") + + if existing_state == OperationState.REGISTERED: + # Share the same in-flight request: second caller gets same future, does not send again. + # Ensures one charge and both callers get the result when it arrives. + self.logger.debug( + f"Sharing pending operation | TaskUUID: {task_uuid} | " + f"Partial results: {len(existing_op.get('results', []))}" + ) + return existing_future, False - async def _retry_with_reconnect(self, func, *args, **kwargs): + if existing_state == OperationState.SENT: + # Share the same in-flight operation so both callers get the result. + self.logger.debug( + f"Sharing in-flight operation | TaskUUID: {task_uuid} | " + f"Partial results: {len(existing_op.get('results', []))}" + ) + return existing_future, False - last_error = None - - for attempt in range(MAX_RETRY_ATTEMPTS): + if existing_state == OperationState.DISCONNECTED: + self.logger.info( + f"Resending after disconnect | TaskUUID: {task_uuid} | " + f"Clearing {len(existing_op.get('results', []))} partial results" + ) + loop = asyncio.get_running_loop() + new_future = loop.create_future() + existing_op["future"] = new_future + existing_op["results"] = [] + existing_op["state"] = OperationState.SENT + if on_partial is not None: + existing_op["on_partial"] = on_partial + if complete_predicate is not None: + existing_op["complete_predicate"] = complete_predicate + if result_filter is not None: + existing_op["result_filter"] = result_filter + return new_future, True + + loop = asyncio.get_running_loop() + future = loop.create_future() + self._pending_operations[task_uuid] = { + "future": future, + "expected": expected_results, + "results": [], + "state": OperationState.REGISTERED, + "on_partial": on_partial, + "complete_predicate": complete_predicate, + "result_filter": result_filter + } + return future, True + + async def _mark_operation_sent(self, task_uuid: str) -> None: + async with self._operations_lock: + op = self._pending_operations.get(task_uuid) + if op and op.get("state") == OperationState.REGISTERED: + op["state"] = OperationState.SENT + + async def _unregister_pending_operation(self, task_uuid: str, force: bool = False) -> "Optional[List[Dict[str, Any]]]": + async with self._operations_lock: + op = self._pending_operations.get(task_uuid) + if not op: + return None + + if not force and op.get("state") == OperationState.DISCONNECTED: + return op.get("results", []) + + return self._pending_operations.pop(task_uuid, {}).get("results") + + async def _handle_pending_operation_message(self, item: "Dict[str, Any]") -> bool: + task_uuid = item.get("taskUUID") + if not task_uuid: + return False + + on_partial_callback = None + async with self._operations_lock: + op = self._pending_operations.get(task_uuid) + if op is None: + return False + + future = op["future"] + + if future.done(): + return True + + if self._is_error_response(item): + if not future.done(): + future.set_exception(RunwareAPIError(item)) + return True + + result_filter = op.get("result_filter") + if result_filter is not None: + if not result_filter(item): + return True + + op["results"].append(item) + + if op["on_partial"]: + on_partial_callback = op["on_partial"] + + if op["complete_predicate"]: + is_complete = op["complete_predicate"](item) + else: + is_complete = len(op["results"]) >= op["expected"] + + if is_complete and not future.done(): + logger.debug(f"Completing pending operation: {task_uuid}, results: {len(op['results'])}") + future.set_result(op["results"]) + + if on_partial_callback: + try: + if item.get("imageUUID"): + partial_images = [createImageFromResponse(item)] + on_partial_callback(partial_images, None) + elif item.get("videoUUID") or item.get("mediaUUID"): + on_partial_callback([item], None) + elif item.get("audioUUID"): + on_partial_callback([item], None) + except Exception as e: + logger.error(f"Error in on_partial callback: {e}") + + return True + + async def _handle_pending_operation_error(self, error: "Dict[str, Any]") -> bool: + task_uuid = error.get("taskUUID") + if not task_uuid: + return False + + on_partial_callback = None + error_obj = None + async with self._operations_lock: + op = self._pending_operations.get(task_uuid) + if op is None: + return False + + future = op["future"] + + if future.done(): + return True + + if op["on_partial"]: + on_partial_callback = op["on_partial"] + error_obj = IError( + error=True, + error_message=error.get("message", "Unknown error"), + task_uuid=task_uuid, + error_code=error.get("code"), + error_type=error.get("type"), + parameter=error.get("parameter"), + documentation=error.get("documentation"), + ) + + if not future.done(): + future.set_exception(RunwareAPIError(error)) + + if on_partial_callback and error_obj is not None: try: - result = await func(*args, **kwargs) - self._reconnection_manager.on_connection_success() - return result - + on_partial_callback([], error_obj) except Exception as e: - last_error = e - - # When conflictTaskUUID: raise on first attempt, return async response on retry - if isinstance(e, RunwareAPIError) and e.code == "conflictTaskUUID": - if attempt == 0: + logger.error(f"Error in on_partial error callback: {e}") + + return True + + async def _do_reconnect(self, last_error: Optional[Exception] = None) -> bool: + async with self._reconnect_lock: + if self.connected() and self._connectionSessionUUID is not None: + self.logger.info("Connection already re-established by another task") + return True + + should_open_circuit = self._reconnection_manager.on_auth_failure() + + if should_open_circuit: + self.logger.error("Authentication circuit breaker opened due to repeated failures") + raise ConnectionError( + f"Authentication circuit breaker opened due to repeated failures. " + f"Last error: {str(last_error)}" + ) + + try: + self.logger.info(f"Reconnecting after error: {str(last_error)}") + + self._invalidAPIkey = None + self._connectionSessionUUID = None + + await self.connect() + + if not self.connected(): + raise ConnectionError("Reconnection failed; WebSocket is not open") + + self.logger.info("Reconnection successful") + return True + + except Exception as reconnect_error: + self.logger.error(f"Error while reconnecting: {reconnect_error}", exc_info=True) + return False + + async def _retry_with_reconnect(self, func, request_model, **func_kwargs): + task_uuid = getattr(request_model, 'taskUUID', None) + last_error = None + + try: + for attempt in range(MAX_RETRY_ATTEMPTS): + try: + result = await func(request_model, **func_kwargs) + self._reconnection_manager.on_connection_success() + return result + + except Exception as e: + last_error = e + + if isinstance(e, RunwareAPIError) and e.code == "conflictTaskUUID": + if attempt == 0: + raise + conflict_task_uuid = e.error_data.get("taskUUID") or getattr(request_model, 'taskUUID', None) + if conflict_task_uuid: + number_results = getattr(request_model, 'numberResults', 1) or 1 + return await self._pollResults( + task_uuid=conflict_task_uuid, + number_results=number_results + ) + raise + + if not isinstance(e, ConnectionError): raise + + if attempt >= MAX_RETRY_ATTEMPTS - 1: + self.logger.error(f"Max retry attempts ({MAX_RETRY_ATTEMPTS}) exceeded") + raise ConnectionError( + f"Failed after {MAX_RETRY_ATTEMPTS} attempts. " + f"Last error: {last_error}" + ) + + reconnected = await self._do_reconnect(last_error) + + if reconnected: + jitter = uniform(0.1, 0.5) * (attempt + 1) + await asyncio.sleep(jitter) else: - # - context = e.error_data.get("context", {}) - task_type = context.get("taskType") - task_uuid = context.get("taskUUID") or e.error_data.get("taskUUID") - delivery_method_raw = context.get("deliveryMethod") - delivery_method_enum = EDeliveryMethod(delivery_method_raw) if isinstance(delivery_method_raw, str) else delivery_method_raw if delivery_method_raw else None - + delay = self._reconnection_manager.calculate_delay() + await asyncio.sleep(delay) + finally: + if task_uuid: + await self._unregister_pending_operation(task_uuid, force=True) + + async def _retry_async_with_reconnect(self, func, request_model, task_type: str): + task_uuid = getattr(request_model, 'taskUUID', None) + delivery_method = getattr(request_model, 'deliveryMethod', None) + + try: + for attempt in range(MAX_RETRY_ATTEMPTS): + try: + result = await func(request_model) + self._reconnection_manager.on_connection_success() + return result + + except Exception as e: + last_error = e + + if isinstance(e, RunwareAPIError) and e.code == "conflictTaskUUID": + if attempt == 0: + raise + + delivery_method_enum = None + if delivery_method is not None: + delivery_method_enum = ( + EDeliveryMethod(delivery_method) + if isinstance(delivery_method, str) + else delivery_method + ) + if task_type and task_uuid and delivery_method_enum is EDeliveryMethod.ASYNC: return createAsyncTaskResponse({ "taskType": task_type, "taskUUID": task_uuid }) - + + conflict_task_uuid = e.error_data.get("taskUUID") or task_uuid + if conflict_task_uuid: + number_results = getattr(request_model, 'numberResults', 1) or 1 + return await self._pollResults( + task_uuid=conflict_task_uuid, + number_results=number_results + ) + raise RunwareAPIError({ "code": "conflictTaskUUIDDuringRetries", "message": "Lost connection during request submission", "taskUUID": task_uuid }) - - if not isinstance(e, ConnectionError): - raise - - if attempt >= MAX_RETRY_ATTEMPTS - 1: - self.logger.error(f"Max authentication retry attempts ({MAX_RETRY_ATTEMPTS}) exceeded") - raise ConnectionError( - f"Failed to authenticate after {MAX_RETRY_ATTEMPTS} attempts. " - f"Last error: {last_error}" - ) - - async with self._reconnect_lock: - # Check if already connected (another concurrent task may have reconnected) - if self.connected() and self._connectionSessionUUID is not None: - self.logger.info("Connection already re-established by another task, retrying request") - continue - - should_open_circuit = self._reconnection_manager.on_auth_failure() - - if should_open_circuit: - self.logger.error("Authentication circuit breaker opened due to repeated failures") + + if not isinstance(e, ConnectionError): + raise + + if attempt >= MAX_RETRY_ATTEMPTS - 1: + self.logger.error(f"Max retry attempts ({MAX_RETRY_ATTEMPTS}) exceeded") raise ConnectionError( - f"Authentication circuit breaker opened due to repeated failures. " - f"Last error: {str(last_error)}" + f"Failed after {MAX_RETRY_ATTEMPTS} attempts. " + f"Last error: {last_error}" ) - try: - self.logger.info(f"Reconnecting after auth error: {str(e)}") - - self._invalidAPIkey = None - self._connectionSessionUUID = None - - await self.connect() - - - if not self.connected(): - raise ConnectionError("Reconnection failed; WebSocket is not open") - - self.logger.info("Reconnection successful, retrying request") - - except Exception as reconnect_error: - self.logger.error(f"Error while reconnecting: {reconnect_error}", exc_info=True) + + reconnected = await self._do_reconnect(last_error) + + if reconnected: + jitter = uniform(0.1, 0.5) * (attempt + 1) + await asyncio.sleep(jitter) + else: delay = self._reconnection_manager.calculate_delay() await asyncio.sleep(delay) + finally: + if task_uuid: + await self._unregister_pending_operation(task_uuid, force=True) def _handle_error_response(self, response: Dict[str, Any]) -> None: """ @@ -224,20 +478,21 @@ def _handle_error_response(self, response: Dict[str, Any]) -> None: """ if not self._is_error_response(response): return - + # If an authentication error, raise ConnectionError to trigger retry if response.get("taskType") == "authentication" or response.get("code") == "missingApiKey": error_message = response.get("message", "Authentication error") self.logger.warning(f"Authentication error detected: {error_message}") raise ConnectionError(error_message) - - # For all other errors + + # For all other errors raise RunwareAPIError(response) - + def _create_safe_async_listener(self, async_func): def wrapper(m): task = asyncio.create_task(async_func(m)) self._listener_tasks.add(task) + def handle_task_exception(t): self._listener_tasks.discard(t) if not t.cancelled(): @@ -327,316 +582,282 @@ def handle_connection_response(self, m): else: self._invalidAPIkey = "Error connection" return - + self._connectionSessionUUID = m.get("newConnectionSessionUUID", {}).get( "connectionSessionUUID" ) self._invalidAPIkey = None - async def photoMaker(self, requestPhotoMaker: IPhotoMaker) -> Union[List[IImage], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._photoMaker, requestPhotoMaker) + async def photoMaker(self, requestPhotoMaker: "IPhotoMaker") -> "Union[List[IImage], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_with_reconnect(self._photoMaker, requestPhotoMaker) - async def _photoMaker(self, requestPhotoMaker: IPhotoMaker) -> Union[List[IImage], IAsyncTaskResponse]: - retry_count = 0 - - try: - await self.ensureConnection() + async def _photoMaker(self, requestPhotoMaker: "IPhotoMaker") -> "Union[List[IImage], IAsyncTaskResponse]": + await self.ensureConnection() - task_uuid = requestPhotoMaker.taskUUID or getUUID() - requestPhotoMaker.taskUUID = task_uuid + task_uuid = requestPhotoMaker.taskUUID or getUUID() + requestPhotoMaker.taskUUID = task_uuid - for i, image in enumerate(requestPhotoMaker.inputImages): - if isLocalFile(image) and not str(image).startswith("http"): - requestPhotoMaker.inputImages[i] = await fileToBase64(image) + for i, image in enumerate(requestPhotoMaker.inputImages): + if isLocalFile(image) and not str(image).startswith("http"): + requestPhotoMaker.inputImages[i] = await fileToBase64(image) - prompt = f"{requestPhotoMaker.positivePrompt}".strip() - request_object = { - "taskUUID": requestPhotoMaker.taskUUID, - "model": requestPhotoMaker.model, - "positivePrompt": prompt, - "numberResults": requestPhotoMaker.numberResults, - "height": requestPhotoMaker.height, - "width": requestPhotoMaker.width, - "taskType": ETaskType.PHOTO_MAKER.value, - "style": requestPhotoMaker.style, - "strength": requestPhotoMaker.strength, - **( - {"inputImages": requestPhotoMaker.inputImages} - if requestPhotoMaker.inputImages - else {} - ), - **( - {"steps": requestPhotoMaker.steps} - if requestPhotoMaker.steps - else {} - ), - } + prompt = f"{requestPhotoMaker.positivePrompt}".strip() + request_object = { + "taskUUID": requestPhotoMaker.taskUUID, + "model": requestPhotoMaker.model, + "positivePrompt": prompt, + "numberResults": requestPhotoMaker.numberResults, + "height": requestPhotoMaker.height, + "width": requestPhotoMaker.width, + "taskType": ETaskType.PHOTO_MAKER.value, + "style": requestPhotoMaker.style, + "strength": requestPhotoMaker.strength, + **( + {"inputImages": requestPhotoMaker.inputImages} + if requestPhotoMaker.inputImages + else {} + ), + **( + {"steps": requestPhotoMaker.steps} + if requestPhotoMaker.steps + else {} + ), + } - if requestPhotoMaker.outputFormat is not None: - request_object["outputFormat"] = requestPhotoMaker.outputFormat - if requestPhotoMaker.includeCost: - request_object["includeCost"] = requestPhotoMaker.includeCost - if requestPhotoMaker.outputType: - request_object["outputType"] = requestPhotoMaker.outputType - if requestPhotoMaker.webhookURL: - request_object["webhookURL"] = requestPhotoMaker.webhookURL + if requestPhotoMaker.outputFormat is not None: + request_object["outputFormat"] = requestPhotoMaker.outputFormat + if requestPhotoMaker.includeCost: + request_object["includeCost"] = requestPhotoMaker.includeCost + if requestPhotoMaker.outputType: + request_object["outputType"] = requestPhotoMaker.outputType + if requestPhotoMaker.webhookURL: + request_object["webhookURL"] = requestPhotoMaker.webhookURL + return await self._handleWebhookRequest( + request_object=request_object, + task_uuid=task_uuid, + task_type="photoMaker", + debug_key="photo-maker-webhook" + ) - await self.send([request_object]) + numberOfResults = requestPhotoMaker.numberResults - if requestPhotoMaker.webhookURL: - return await self._handleWebhookAcknowledgment( - task_uuid=task_uuid, - task_type="photoMaker", - debug_key="photo-maker-webhook" - ) + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=numberOfResults, + complete_predicate=None, + result_filter=lambda r: r.get("imageUUID") is not None + ) - lis = self.globalListener( - taskUUID=task_uuid, + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=IMAGE_INFERENCE_TIMEOUT / 1000) + + unique_results = {} + for made_photo in results: + image_uuid = made_photo.get("imageUUID") + if image_uuid and image_uuid not in unique_results: + unique_results[image_uuid] = made_photo + + if not unique_results: + raise Exception(f"No valid photoMaker results received | TaskUUID: {task_uuid}") + + return instantiateDataclassList(IImage, list(unique_results.values())) + + except asyncio.TimeoutError: + op = self._pending_operations.get(task_uuid) + partial_count = len(op["results"]) if op else 0 + raise Exception( + f"Timeout waiting for photoMaker | TaskUUID: {task_uuid} | " + f"Expected: {numberOfResults} | Received: {partial_count} | " + f"Timeout: {IMAGE_INFERENCE_TIMEOUT}ms" ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(task_uuid) - numberOfResults = requestPhotoMaker.numberResults - - async def check(resolve: callable, reject: callable, *args: Any) -> bool: - async with self._messages_lock: - photo_maker_list = self._globalMessages.get(task_uuid, []) - unique_results = {} + async def imageInference( + self, requestImage: "IImageInference" + ) -> "Union[List[IImage], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_async_with_reconnect( + self._imageInference, + requestImage, + task_type=ETaskType.IMAGE_INFERENCE.value + ) - for made_photo in photo_maker_list: - if made_photo.get("code"): - raise RunwareAPIError(made_photo) + async def _imageInference( + self, requestImage: "IImageInference" + ) -> "Union[List[IImage], IAsyncTaskResponse]": + await self.ensureConnection() - if made_photo.get("taskType") != "photoMaker": - continue + control_net_data: "List[IControlNet]" = [] + requestImage.taskUUID = requestImage.taskUUID or getUUID() + requestImage.maskImage = await process_image(requestImage.maskImage) + requestImage.seedImage = await process_image(requestImage.seedImage) - image_uuid = made_photo.get("imageUUID") - if image_uuid not in unique_results: - unique_results[image_uuid] = made_photo + if requestImage.referenceImages: + requestImage.referenceImages = await process_image(requestImage.referenceImages) + + if requestImage.controlNet: + for control_data in requestImage.controlNet: + image_uploaded = await self.uploadImage(control_data.guideImage) + if not image_uploaded: + return [] + if hasattr(control_data, "preprocessor"): + control_data.preprocessor = control_data.preprocessor.value + control_data.guideImage = image_uploaded.imageUUID + control_net_data.append(control_data) + + prompt = requestImage.positivePrompt.strip() if requestImage.positivePrompt else None + control_net_data_dicts = [asdict(item) for item in control_net_data] + + instant_id_data = {} + if requestImage.instantID: + instant_id_data = { + k: v + for k, v in vars(requestImage.instantID).items() + if v is not None + } + if "inputImage" in instant_id_data: + instant_id_data["inputImage"] = await process_image(instant_id_data["inputImage"]) + if "poseImage" in instant_id_data: + instant_id_data["poseImage"] = await process_image(instant_id_data["poseImage"]) + + ip_adapters_data = [] + if requestImage.ipAdapters: + for ip_adapter in requestImage.ipAdapters: + ip_adapter_data = { + k: v for k, v in vars(ip_adapter).items() if v is not None + } + if "guideImage" in ip_adapter_data: + ip_adapter_data["guideImage"] = await process_image(ip_adapter_data["guideImage"]) + ip_adapters_data.append(ip_adapter_data) + + ace_plus_plus_data = {} + if requestImage.acePlusPlus: + ace_plus_plus_data = { + "inputImages": [], + "repaintingScale": requestImage.acePlusPlus.repaintingScale, + "type": requestImage.acePlusPlus.taskType, + } + if requestImage.acePlusPlus.inputImages: + ace_plus_plus_data["inputImages"] = await process_image(requestImage.acePlusPlus.inputImages) + if requestImage.acePlusPlus.inputMasks: + ace_plus_plus_data["inputMasks"] = await process_image(requestImage.acePlusPlus.inputMasks) + + pulid_data = {} + if requestImage.puLID: + pulid_data = { + "inputImages": [], + "idWeight": requestImage.puLID.idWeight, + "trueCFGScale": requestImage.puLID.trueCFGScale, + "CFGStartStep": requestImage.puLID.CFGStartStep, + "CFGStartStepPercentage": requestImage.puLID.CFGStartStepPercentage, + } + if requestImage.puLID.inputImages: + pulid_data["inputImages"] = await process_image(requestImage.puLID.inputImages) - if 0 < numberOfResults <= len(unique_results): - del self._globalMessages[task_uuid] - resolve(list(unique_results.values())) - return True + request_object = self._buildImageRequest( + requestImage, prompt, control_net_data_dicts, + instant_id_data, ip_adapters_data, ace_plus_plus_data, pulid_data + ) - return False + delivery_method_enum = EDeliveryMethod(requestImage.deliveryMethod) if isinstance(requestImage.deliveryMethod, + str) else requestImage.deliveryMethod + task_uuid = requestImage.taskUUID + number_results = requestImage.numberResults or 1 - response = await getIntervalWithPromise(check, debugKey="photo-maker", timeOutDuration=IMAGE_INFERENCE_TIMEOUT) + if delivery_method_enum is EDeliveryMethod.ASYNC: + if requestImage.webhookURL: + request_object["webhookURL"] = requestImage.webhookURL + return await self._handleWebhookRequest( + request_object=request_object, + task_uuid=task_uuid, + task_type="imageInference", + debug_key="image-inference-webhook" + ) - lis["destroy"]() + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=lambda r: True + ) - if isinstance(response, dict): + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=IMAGE_INITIAL_TIMEOUT / 1000) + response = results[0] self._handle_error_response(response) - if response: - if not isinstance(response, list): - response = [response] - - return instantiateDataclassList(IImage, response) - - except Exception as e: - if retry_count >= 2: - logger.error(f"Error in photoMaker request:", exc_info=e) - raise RunwareAPIError({"message": f"PhotoMaker failed after retries: {str(e)}"}) - else: - raise e - - async def imageInference( - self, requestImage: IImageInference - ) -> Union[List[IImage], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._imageInference, requestImage) + if response.get("status") == "success" or response.get("imageUUID") is not None: + return instantiateDataclassList(IImage, results) - async def _imageInference( - self, requestImage: IImageInference - ) -> Union[List[IImage], IAsyncTaskResponse]: - let_lis: Optional[Any] = None - request_object: Optional[Dict[str, Any]] = None - task_uuids: List[str] = [] - retry_count = 0 - try: - await self.ensureConnection() - control_net_data: List[IControlNet] = [] - requestImage.taskUUID = requestImage.taskUUID or getUUID() - requestImage.maskImage = await process_image(requestImage.maskImage) - requestImage.seedImage = await process_image(requestImage.seedImage) - if requestImage.referenceImages: - requestImage.referenceImages = await process_image( - requestImage.referenceImages + return createAsyncTaskResponse(response) + except asyncio.TimeoutError: + raise ConnectionError( + f"Timeout waiting for async acknowledgment | TaskUUID: {task_uuid} | " + f"Timeout: {IMAGE_INITIAL_TIMEOUT}ms" ) - if requestImage.controlNet: - for control_data in requestImage.controlNet: - image_uploaded = await self.uploadImage(control_data.guideImage) - if not image_uploaded: - return [] - if hasattr(control_data, "preprocessor"): - control_data.preprocessor = control_data.preprocessor.value - control_data.guideImage = image_uploaded.imageUUID - control_net_data.append(control_data) - prompt = requestImage.positivePrompt.strip() if requestImage.positivePrompt else None - - control_net_data_dicts = [asdict(item) for item in control_net_data] - - instant_id_data = {} - if requestImage.instantID: - instant_id_data = { - k: v - for k, v in vars(requestImage.instantID).items() - if v is not None - } - - if "inputImage" in instant_id_data: - instant_id_data["inputImage"] = await process_image( - instant_id_data["inputImage"] - ) - - if "poseImage" in instant_id_data: - instant_id_data["poseImage"] = await process_image( - instant_id_data["poseImage"] - ) - - ip_adapters_data = [] - if requestImage.ipAdapters: - for ip_adapter in requestImage.ipAdapters: - ip_adapter_data = { - k: v for k, v in vars(ip_adapter).items() if v is not None - } - if "guideImage" in ip_adapter_data: - ip_adapter_data["guideImage"] = await process_image( - ip_adapter_data["guideImage"] - ) - - ip_adapters_data.append(ip_adapter_data) - - ace_plus_plus_data = {} - if requestImage.acePlusPlus: - ace_plus_plus_data = { - "inputImages": [], - "repaintingScale": requestImage.acePlusPlus.repaintingScale, - "type": requestImage.acePlusPlus.taskType, - } - if requestImage.acePlusPlus.inputImages: - ace_plus_plus_data["inputImages"] = await process_image( - requestImage.acePlusPlus.inputImages - ) - if requestImage.acePlusPlus.inputMasks: - ace_plus_plus_data["inputMasks"] = await process_image( - requestImage.acePlusPlus.inputMasks - ) - - pulid_data = {} - if requestImage.puLID: - pulid_data = { - "inputImages": [], - "idWeight": requestImage.puLID.idWeight, - "trueCFGScale": requestImage.puLID.trueCFGScale, - "CFGStartStep": requestImage.puLID.CFGStartStep, - "CFGStartStepPercentage": requestImage.puLID.CFGStartStepPercentage, - } - if requestImage.puLID.inputImages: - pulid_data["inputImages"] = await process_image( - requestImage.puLID.inputImages - ) + finally: + await self._unregister_pending_operation(task_uuid) + + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=number_results, + on_partial=requestImage.onPartialImages, + complete_predicate=None, + result_filter=lambda r: r.get("imageUUID") is not None + ) - request_object = self._buildImageRequest(requestImage, prompt, control_net_data_dicts, instant_id_data, ip_adapters_data, ace_plus_plus_data, pulid_data) - - delivery_method_enum = EDeliveryMethod(requestImage.deliveryMethod) if isinstance(requestImage.deliveryMethod, str) else requestImage.deliveryMethod - - if delivery_method_enum is EDeliveryMethod.ASYNC: - if requestImage.webhookURL: - request_object["webhookURL"] = requestImage.webhookURL - + try: + if should_send: await self.send([request_object]) - - return await self._handleInitialImageResponse( - requestImage.taskUUID, - requestImage.numberResults, - requestImage.deliveryMethod, - request_object.get("webhookURL"), - "image-inference-initial" - ) - - return await self._requestImages( - request_object=request_object, - task_uuids=task_uuids, - let_lis=let_lis, - retry_count=retry_count, - number_of_images=requestImage.numberResults, - on_partial_images=requestImage.onPartialImages, - ) - except Exception as e: - if let_lis: - let_lis["destroy"]() - raise e - - async def _requestImages( - self, - request_object: Dict[str, Any], - task_uuids: List[str], - let_lis: Optional[Any], - retry_count: int, - number_of_images: int, - on_partial_images: Optional[Callable[[List[IImage], Optional[IError]], None]], - ) -> Union[List[IImage], IAsyncTaskResponse]: - retry_count += 1 - if let_lis: - let_lis["destroy"]() - images_with_similar_task = [ - img for img in self._globalImages if img.get("taskUUID") in task_uuids - ] + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=IMAGE_INFERENCE_TIMEOUT / 1000) - task_uuid = request_object.get("taskUUID") - if task_uuid is None: - task_uuid = getUUID() + if not results: + raise Exception(f"No results received | TaskUUID: {task_uuid}") - task_uuids.append(task_uuid) + return instantiateDataclassList(IImage, results) - image_remaining = number_of_images - len(images_with_similar_task) - new_request_object = { - **request_object, - "taskUUID": task_uuid, - "numberResults": image_remaining, - } + except asyncio.TimeoutError: + op = self._pending_operations.get(task_uuid) + partial_count = len(op["results"]) if op else 0 - await self.send([new_request_object]) + if op and op["results"]: + self.logger.warning( + f"Timeout but returning {partial_count} partial results | " + f"TaskUUID: {task_uuid} | Expected: {number_results}" + ) + return instantiateDataclassList(IImage, op["results"]) - if new_request_object.get("webhookURL"): - return await self._handleWebhookAcknowledgment( - task_uuid=task_uuid, - task_type="imageInference", - debug_key="image-inference-webhook" + raise Exception( + f"Timeout waiting for image inference | TaskUUID: {task_uuid} | " + f"Expected: {number_results} | Received: {partial_count} | " + f"Timeout: {IMAGE_INFERENCE_TIMEOUT}ms" ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(task_uuid) - let_lis = await self.listenToImages( - onPartialImages=on_partial_images, - taskUUID=task_uuid, - groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES, - ) - images = await self.getSimililarImage( - taskUUID=task_uuids, - numberOfImages=number_of_images, - shouldThrowError=True, - lis=let_lis, - ) - - let_lis["destroy"]() - # TODO: NameError("name 'image_path' is not defined"). I think I remove the images when I have onPartialImages - if images: - if "code" in images: - # This indicates an error response - raise RunwareAPIError(images) - - return instantiateDataclassList(IImage, images) - - # return images - - async def imageCaption(self, requestImageToText: IImageCaption) -> Union[IImageToText, IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._imageCaption, requestImageToText) + async def imageCaption(self, requestImageToText: "IImageCaption") -> "Union[IImageToText, IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_with_reconnect(self._imageCaption, requestImageToText) async def _imageCaption(self, requestImageToText: IImageCaption) -> Union[IImageToText, IAsyncTaskResponse]: await self.ensureConnection() return await self._requestImageToText(requestImageToText) async def _requestImageToText( - self, requestImageToText: IImageCaption - ) -> Union[IImageToText, IAsyncTaskResponse]: + self, requestImageToText: "IImageCaption" + ) -> "Union[IImageToText, IAsyncTaskResponse]": # Prepare image list - inputImages is primary, inputImage is convenience if requestImageToText.inputImages is not None: images_to_process = requestImageToText.inputImages @@ -659,6 +880,7 @@ async def _requestImageToText( uploaded_images.append(image_uploaded.imageUUID) taskUUID = getUUID() + requestImageToText.taskUUID = taskUUID # Create a dictionary with mandatory parameters task_params = { @@ -690,63 +912,51 @@ async def _requestImageToText( task_params["includeCost"] = requestImageToText.includeCost if requestImageToText.webhookURL: task_params["webhookURL"] = requestImageToText.webhookURL - - await self.send([task_params]) - - if requestImageToText.webhookURL: - return await self._handleWebhookAcknowledgment( + return await self._handleWebhookRequest( + request_object=task_params, task_uuid=taskUUID, task_type="imageCaption", debug_key="image-caption-webhook" ) - lis = self.globalListener( - taskUUID=taskUUID, - ) - - async def check(resolve: callable, reject: callable, *args: Any) -> bool: - async with self._messages_lock: - response = self._globalMessages.get(taskUUID) - if response: - image_to_text = response[0] - else: - image_to_text = response - if image_to_text and image_to_text.get("error"): - reject(image_to_text) - return True - - if image_to_text: - del self._globalMessages[taskUUID] - resolve(image_to_text) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="image-to-text", timeOutDuration=IMAGE_OPERATION_TIMEOUT + future, should_send = await self._register_pending_operation( + taskUUID, + expected_results=1, + complete_predicate=lambda r: True ) - - lis["destroy"]() - - self._handle_error_response(response) - - if response: + try: + if should_send: + await self.send([task_params]) + await self._mark_operation_sent(taskUUID) + results = await asyncio.wait_for(future, timeout=IMAGE_OPERATION_TIMEOUT / 1000) + response = results[0] + self._handle_error_response(response) return createImageToTextFromResponse(response) - else: - return None - - async def videoCaption(self, requestVideoCaption: IVideoCaption) -> Union[List[IVideoToText], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._videoCaption, requestVideoCaption) - - async def _videoCaption(self, requestVideoCaption: IVideoCaption) -> Union[List[IVideoToText], IAsyncTaskResponse]: - await self.ensureConnection() - return await self._requestVideoCaption(requestVideoCaption) + except asyncio.TimeoutError: + raise Exception( + f"Timeout waiting for image caption | TaskUUID: {taskUUID} | " + f"Timeout: {IMAGE_OPERATION_TIMEOUT}ms" + ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(taskUUID) + + async def videoCaption(self, requestVideoCaption: "IVideoCaption") -> "Union[List[IVideoToText], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_async_with_reconnect( + self._requestVideoCaption, + requestVideoCaption, + task_type=ETaskType.VIDEO_CAPTION.value + ) async def _requestVideoCaption( - self, requestVideoCaption: IVideoCaption - ) -> Union[List[IVideoToText], IAsyncTaskResponse]: + self, requestVideoCaption: "IVideoCaption" + ) -> "Union[List[IVideoToText], IAsyncTaskResponse]": + await self.ensureConnection() taskUUID = requestVideoCaption.taskUUID or getUUID() + requestVideoCaption.taskUUID = taskUUID # Create the request object task_params = { @@ -764,32 +974,35 @@ async def _requestVideoCaption( task_params["includeCost"] = requestVideoCaption.includeCost if requestVideoCaption.webhookURL: task_params["webhookURL"] = requestVideoCaption.webhookURL - - await self.send([task_params]) - - if requestVideoCaption.webhookURL: - return await self._handleWebhookAcknowledgment( + return await self._handleWebhookRequest( + request_object=task_params, task_uuid=taskUUID, task_type="caption", debug_key="video-caption-webhook" ) + await self.send([task_params]) + return await self._pollResults( task_uuid=taskUUID, number_results=1, ) - async def videoBackgroundRemoval(self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval) -> Union[List[IVideo], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._videoBackgroundRemoval, requestVideoBackgroundRemoval) - - async def _videoBackgroundRemoval(self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval) -> Union[List[IVideo], IAsyncTaskResponse]: - await self.ensureConnection() - return await self._requestVideoBackgroundRemoval(requestVideoBackgroundRemoval) + async def videoBackgroundRemoval(self, + requestVideoBackgroundRemoval: "IVideoBackgroundRemoval") -> "Union[List[IVideo], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_async_with_reconnect( + self._requestVideoBackgroundRemoval, + requestVideoBackgroundRemoval, + task_type=ETaskType.VIDEO_BACKGROUND_REMOVAL.value + ) async def _requestVideoBackgroundRemoval( - self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval - ) -> Union[List[IVideo], IAsyncTaskResponse]: + self, requestVideoBackgroundRemoval: "IVideoBackgroundRemoval" + ) -> "Union[List[IVideo], IAsyncTaskResponse]": + await self.ensureConnection() taskUUID = requestVideoBackgroundRemoval.taskUUID or getUUID() + requestVideoBackgroundRemoval.taskUUID = taskUUID # Create the request object task_params = { @@ -818,34 +1031,35 @@ async def _requestVideoBackgroundRemoval( } task_params["settings"] = settings_dict - await self.send([task_params]) - if requestVideoBackgroundRemoval.webhookURL: - return await self._handleWebhookAcknowledgment( + return await self._handleWebhookRequest( + request_object=task_params, task_uuid=taskUUID, task_type="removeBackground", debug_key="video-background-removal-webhook" ) return await self._handleInitialVideoResponse( - taskUUID, - 1, - requestVideoBackgroundRemoval.deliveryMethod, - task_params.get("webhookURL"), - "video-background-removal-initial" + request_object=task_params, + task_uuid=taskUUID, + number_results=1, + delivery_method=requestVideoBackgroundRemoval.deliveryMethod, + webhook_url=None, + debug_key="video-background-removal-initial" ) - async def videoUpscale(self, requestVideoUpscale: IVideoUpscale) -> Union[List[IVideo], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._videoUpscale, requestVideoUpscale) + async def videoUpscale(self, requestVideoUpscale: "IVideoUpscale") -> "Union[List[IVideo], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_async_with_reconnect( + self._requestVideoUpscale, + requestVideoUpscale, + task_type=ETaskType.VIDEO_UPSCALE.value + ) - async def _videoUpscale(self, requestVideoUpscale: IVideoUpscale) -> Union[List[IVideo], IAsyncTaskResponse]: + async def _requestVideoUpscale(self, requestVideoUpscale: "IVideoUpscale") -> "Union[List[IVideo], IAsyncTaskResponse]": await self.ensureConnection() - return await self._requestVideoUpscale(requestVideoUpscale) - - async def _requestVideoUpscale( - self, requestVideoUpscale: IVideoUpscale - ) -> Union[List[IVideo], IAsyncTaskResponse]: taskUUID = requestVideoUpscale.taskUUID or getUUID() + requestVideoUpscale.taskUUID = taskUUID # Create the request object task_params = { @@ -870,47 +1084,45 @@ async def _requestVideoUpscale( if requestVideoUpscale.webhookURL: task_params["webhookURL"] = requestVideoUpscale.webhookURL - await self.send([task_params]) - if requestVideoUpscale.webhookURL: - return await self._handleWebhookAcknowledgment( + return await self._handleWebhookRequest( + request_object=task_params, task_uuid=taskUUID, task_type="upscale", debug_key="video-upscale-webhook" ) return await self._handleInitialVideoResponse( - taskUUID, - 1, - requestVideoUpscale.deliveryMethod, - task_params.get("webhookURL"), - "video-upscale-initial" + request_object=task_params, + task_uuid=taskUUID, + number_results=1, + delivery_method=requestVideoUpscale.deliveryMethod, + webhook_url=None, + debug_key="video-upscale-initial" ) async def imageBackgroundRemoval( - self, removeImageBackgroundPayload: IImageBackgroundRemoval - ) -> Union[List[IImage], IAsyncTaskResponse]: - try: - await self.ensureConnection() - return await asyncRetry( - lambda: self._removeImageBackground(removeImageBackgroundPayload) - ) - except Exception as e: - raise e + self, removeImageBackgroundPayload: "IImageBackgroundRemoval" + ) -> "Union[List[IImage], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_with_reconnect(self._removeImageBackground, removeImageBackgroundPayload) async def _removeImageBackground( - self, removeImageBackgroundPayload: IImageBackgroundRemoval - ) -> Union[List[IImage], IAsyncTaskResponse]: + self, removeImageBackgroundPayload: "IImageBackgroundRemoval" + ) -> "Union[List[IImage], IAsyncTaskResponse]": + await self.ensureConnection() inputImage = removeImageBackgroundPayload.inputImage image_uploaded = await self.uploadImage(inputImage) if not image_uploaded or not image_uploaded.imageUUID: return [] + if removeImageBackgroundPayload.taskUUID is not None: taskUUID = removeImageBackgroundPayload.taskUUID else: taskUUID = getUUID() + removeImageBackgroundPayload.taskUUID = taskUUID # Create a dictionary with mandatory parameters task_params = { @@ -941,77 +1153,65 @@ async def _removeImageBackground( if v is not None } task_params.update(settings_dict) - + # Add provider settings if provided if removeImageBackgroundPayload.providerSettings: self._addImageProviderSettings(task_params, removeImageBackgroundPayload) - + # Add safety settings if provided if removeImageBackgroundPayload.safety: self._addSafetySettings(task_params, removeImageBackgroundPayload.safety) - # Send the task with all applicable parameters - await self.send([task_params]) - if removeImageBackgroundPayload.webhookURL: - return await self._handleWebhookAcknowledgment( + return await self._handleWebhookRequest( + request_object=task_params, task_uuid=taskUUID, task_type="imageBackgroundRemoval", debug_key="image-background-removal-webhook" ) - lis = self.globalListener( - taskUUID=taskUUID, - ) - - async def check(resolve: callable, reject: callable, *args: Any) -> bool: - async with self._messages_lock: - response = self._globalMessages.get(taskUUID) - if response: - new_remove_background = response[0] - else: - new_remove_background = response - if new_remove_background and new_remove_background.get("error"): - reject(new_remove_background) - return True - - if new_remove_background: - del self._globalMessages[taskUUID] - resolve(new_remove_background) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="remove-image-background", timeOutDuration=IMAGE_OPERATION_TIMEOUT + future, should_send = await self._register_pending_operation( + taskUUID, + expected_results=1, + complete_predicate=None, + result_filter=lambda r: r.get("imageUUID") is not None ) - lis["destroy"]() - - self._handle_error_response(response) - - image = createImageFromResponse(response) - image_list: List[IImage] = [image] - - return image_list + try: + if should_send: + await self.send([task_params]) + await self._mark_operation_sent(taskUUID) + results = await asyncio.wait_for(future, timeout=IMAGE_OPERATION_TIMEOUT / 1000) + response = results[0] + self._handle_error_response(response) + image = createImageFromResponse(response) + return [image] + except asyncio.TimeoutError: + raise Exception( + f"Timeout waiting for background removal | TaskUUID: {taskUUID} | " + f"Timeout: {IMAGE_OPERATION_TIMEOUT}ms" + ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(taskUUID) - async def imageUpscale(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._imageUpscale, upscaleGanPayload) + async def imageUpscale(self, upscaleGanPayload: "IImageUpscale") -> "Union[List[IImage], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_with_reconnect(self._imageUpscale, upscaleGanPayload) async def _imageUpscale(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]: await self.ensureConnection() return await self._upscaleGan(upscaleGanPayload) - async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]: + async def _upscaleGan(self, upscaleGanPayload: "IImageUpscale") -> "Union[List[IImage], IAsyncTaskResponse]": # Support both inputImage (legacy) and inputs.image (new format) inputImage = upscaleGanPayload.inputImage if not inputImage and upscaleGanPayload.inputs and upscaleGanPayload.inputs.image: inputImage = upscaleGanPayload.inputs.image - + if not inputImage: raise ValueError("Either inputImage or inputs.image must be provided") - - upscaleFactor = upscaleGanPayload.upscaleFactor image_uploaded = await self.uploadImage(inputImage) @@ -1019,6 +1219,7 @@ async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> Union[List[IIma return [] taskUUID = getUUID() + upscaleGanPayload.taskUUID = taskUUID # Create a dictionary with mandatory parameters task_params = { @@ -1026,7 +1227,7 @@ async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> Union[List[IIma "taskUUID": taskUUID, "upscaleFactor": upscaleGanPayload.upscaleFactor, } - + # Use inputs.image format if inputs is provided, otherwise use inputImage (legacy) if upscaleGanPayload.inputs and upscaleGanPayload.inputs.image: task_params["inputs"] = {"image": image_uploaded.imageUUID} @@ -1054,82 +1255,69 @@ async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> Union[List[IIma task_params["includeCost"] = upscaleGanPayload.includeCost if upscaleGanPayload.webhookURL: task_params["webhookURL"] = upscaleGanPayload.webhookURL - + # Add provider settings if provided if upscaleGanPayload.providerSettings: self._addImageProviderSettings(task_params, upscaleGanPayload) - + # Add safety settings if provided if upscaleGanPayload.safety: self._addSafetySettings(task_params, upscaleGanPayload.safety) - # Send the task with all applicable parameters - - await self.send([task_params]) - if upscaleGanPayload.webhookURL: - return await self._handleWebhookAcknowledgment( + return await self._handleWebhookRequest( + request_object=task_params, task_uuid=taskUUID, task_type="imageUpscale", debug_key="image-upscale-webhook" ) - lis = self.globalListener( - taskUUID=taskUUID, - ) - - async def check(resolve: callable, reject: callable, *args: Any) -> bool: - async with self._messages_lock: - response = self._globalMessages.get(taskUUID) - if response: - upscaled_image = response[0] - else: - upscaled_image = response - if upscaled_image and upscaled_image.get("error"): - reject(upscaled_image) - return True - - if upscaled_image: - del self._globalMessages[taskUUID] - resolve(upscaled_image) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="upscale-gan", timeOutDuration=IMAGE_OPERATION_TIMEOUT + future, should_send = await self._register_pending_operation( + taskUUID, + expected_results=1, + complete_predicate=None, + result_filter=lambda r: r.get("imageUUID") is not None ) - lis["destroy"]() - - self._handle_error_response(response) - - image = createImageFromResponse(response) - image_list: List[IImage] = [image] - return image_list + try: + if should_send: + await self.send([task_params]) + await self._mark_operation_sent(taskUUID) + results = await asyncio.wait_for(future, timeout=IMAGE_OPERATION_TIMEOUT / 1000) + response = results[0] + self._handle_error_response(response) + image = createImageFromResponse(response) + return [image] + except asyncio.TimeoutError: + raise Exception( + f"Timeout waiting for image upscale | TaskUUID: {taskUUID} | " + f"Timeout: {IMAGE_OPERATION_TIMEOUT}ms" + ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(taskUUID) - async def imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._imageVectorize, vectorizePayload) + async def imageVectorize(self, vectorizePayload: "IVectorize") -> "Union[List[IImage], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_with_reconnect(self._vectorize, vectorizePayload) - async def _imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]: + async def _vectorize(self, vectorizePayload: "IVectorize") -> Union[List["IImage"], "IAsyncTaskResponse"]: await self.ensureConnection() - return await self._vectorize(vectorizePayload) - - async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]: # Process the image from inputs input_image = vectorizePayload.inputs.image - + if not input_image: raise ValueError("Image is required in inputs for vectorize task") - + # Upload the image if it's a local file image_uploaded = await self.uploadImage(input_image) - + if not image_uploaded or not image_uploaded.imageUUID: return [] - + taskUUID = getUUID() - + # Create a dictionary with mandatory parameters task_params = { "taskType": ETaskType.IMAGE_VECTORIZE.value, @@ -1138,7 +1326,7 @@ async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], "image": image_uploaded.imageUUID } } - + # Add optional parameters if they are provided if vectorizePayload.model is not None: task_params["model"] = vectorizePayload.model @@ -1150,41 +1338,49 @@ async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], task_params["includeCost"] = vectorizePayload.includeCost if vectorizePayload.webhookURL: task_params["webhookURL"] = vectorizePayload.webhookURL - - # Send the task with all applicable parameters - await self.send([task_params]) - - if vectorizePayload.webhookURL: - return await self._handleWebhookAcknowledgment( + return await self._handleWebhookRequest( + request_object=task_params, task_uuid=taskUUID, task_type="vectorize", debug_key="image-vectorize-webhook" ) - - let_lis = await self.listenToImages( - onPartialImages=None, - taskUUID=taskUUID, - groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES, - ) - - images = await self.getSimililarImage( - taskUUID=taskUUID, - numberOfImages=1, - shouldThrowError=True, - lis=let_lis, + + future, should_send = await self._register_pending_operation( + taskUUID, + expected_results=1, + complete_predicate=None, + result_filter=lambda r: r.get("imageUUID") is not None ) - - let_lis["destroy"]() - - if "code" in images or "errors" in images: - # This indicates an error response - raise RunwareAPIError(images) - - return instantiateDataclassList(IImage, images) + + try: + + if should_send: + await self.send([task_params]) + await self._mark_operation_sent(taskUUID) + results = await asyncio.wait_for(future, timeout=IMAGE_OPERATION_TIMEOUT / 1000) + + if not results: + raise Exception(f"No results received | TaskUUID: {taskUUID}") + + for result in results: + if "code" in result or "errors" in result: + raise RunwareAPIError(result) + + return instantiateDataclassList(IImage, results) + + except asyncio.TimeoutError: + raise Exception( + f"Timeout waiting for vectorize | TaskUUID: {taskUUID} | " + f"Timeout: {IMAGE_OPERATION_TIMEOUT}ms" + ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(taskUUID) async def promptEnhance( - self, promptEnhancer: IPromptEnhance - ) -> Union[List[IEnhancedPrompt], IAsyncTaskResponse]: + self, promptEnhancer: "IPromptEnhance" + ) -> "Union[List[IEnhancedPrompt], IAsyncTaskResponse]": """ Enhance the given prompt by generating multiple versions of it. @@ -1192,27 +1388,20 @@ async def promptEnhance( :return: A list of IEnhancedPrompt objects representing the enhanced versions of the prompt. :raises: Any error that occurs during the enhancement process. """ - try: - await self.ensureConnection() - return await asyncRetry(lambda: self._enhancePrompt(promptEnhancer)) - except Exception as e: - raise e + async with self._request_semaphore: + return await self._retry_with_reconnect(self._enhancePrompt, promptEnhancer) async def _enhancePrompt( - self, promptEnhancer: IPromptEnhance - ) -> Union[List[IEnhancedPrompt], IAsyncTaskResponse]: - """ - Internal method to perform the actual prompt enhancement. - - :param promptEnhancer: An IPromptEnhancer object containing the prompt details. - :return: A list of IEnhancedPrompt objects representing the enhanced versions of the prompt. - """ + self, promptEnhancer: "IPromptEnhance" + ) -> "Union[List[IEnhancedPrompt], IAsyncTaskResponse]": + self.ensureConnection() prompt = promptEnhancer.prompt promptMaxLength = getattr(promptEnhancer, "promptMaxLength", 380) promptVersions = promptEnhancer.promptVersions or 1 taskUUID = getUUID() + promptEnhancer.taskUUID = taskUUID # Create a dictionary with mandatory parameters task_params = { @@ -1227,61 +1416,63 @@ async def _enhancePrompt( if promptEnhancer.includeCost: task_params["includeCost"] = promptEnhancer.includeCost - has_webhook = promptEnhancer.webhookURL - if has_webhook: + if promptEnhancer.webhookURL: task_params["webhookURL"] = promptEnhancer.webhookURL - - # Send the task with all applicable parameters - await self.send([task_params]) - - if has_webhook: - return await self._handleWebhookAcknowledgment( + return await self._handleWebhookRequest( + request_object=task_params, task_uuid=taskUUID, task_type="promptEnhance", debug_key="prompt-enhance-webhook" ) - lis = self.globalListener( - taskUUID=taskUUID, - ) - - async def check(resolve: Any, reject: Any, *args: Any) -> bool: - async with self._messages_lock: - response = self._globalMessages.get(taskUUID) - if isinstance(response, dict) and response.get("error"): - reject(response) - return True - if response: - del self._globalMessages[taskUUID] - resolve(response) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="enhance-prompt", timeOutDuration=PROMPT_ENHANCE_TIMEOUT + future, should_send = await self._register_pending_operation( + taskUUID, + expected_results=promptVersions, + complete_predicate=None, + result_filter=lambda r: r.get("text") is not None ) - lis["destroy"]() - - if "code" in response[0]: - # This indicates an error response - raise RunwareAPIError(response[0]) - - # Transform the response to a list of IEnhancedPrompt objects - enhanced_prompts = createEnhancedPromptsFromResponse(response) - - return list(set(enhanced_prompts)) + try: + if should_send: + await self.send([task_params]) + await self._mark_operation_sent(taskUUID) + results = await asyncio.wait_for(future, timeout=PROMPT_ENHANCE_TIMEOUT / 1000) + + if not results: + raise Exception(f"No results received | TaskUUID: {taskUUID}") + + for result in results: + if "code" in result: + raise RunwareAPIError(result) + + # Transform the response to a list of IEnhancedPrompt objects + enhanced_prompts = createEnhancedPromptsFromResponse(results) + return list(set(enhanced_prompts)) + + except asyncio.TimeoutError: + op = self._pending_operations.get(taskUUID) + partial_count = len(op["results"]) if op else 0 + raise Exception( + f"Timeout waiting for prompt enhance | TaskUUID: {taskUUID} | " + f"Expected: {promptVersions} | Received: {partial_count} | " + f"Timeout: {PROMPT_ENHANCE_TIMEOUT}ms" + ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(taskUUID) async def uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]: - return await self._retry_with_reconnect(self._uploadImage, file) + request = IUploadImageRequest(file=file, taskUUID=getUUID()) + return await self._retry_with_reconnect(self._uploadImage, request) - async def _uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]: + async def _uploadImage(self, request: IUploadImageRequest) -> "Optional[UploadImageType]": await self.ensureConnection() - - task_uuid = getUUID() + + file = request.file + task_uuid = request.taskUUID local_file = True - + if isinstance(file, str): if os.path.exists(file): local_file = True @@ -1294,6 +1485,7 @@ async def _uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType ): # Assume it's a base64 string (with or without data URI prefix) local_file = False + if not local_file: return UploadImageType( imageUUID=file, @@ -1304,61 +1496,53 @@ async def _uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType # Convert file to base64 (handles both File objects and string paths) file_data = await fileToBase64(file) - await self.send( - [ - { - "taskType": ETaskType.IMAGE_UPLOAD.value, - "taskUUID": task_uuid, - "image": file_data, - } - ] - ) - - lis = self.globalListener(taskUUID=task_uuid) - - async def check(resolve: callable, reject: callable, *args: Any) -> bool: - async with self._messages_lock: - uploaded_image_list = self._globalMessages.get(task_uuid) - uploaded_image = uploaded_image_list[0] if uploaded_image_list else None - - if uploaded_image and uploaded_image.get("error"): - reject(uploaded_image) - return True - - if uploaded_image: - del self._globalMessages[task_uuid] - resolve(uploaded_image) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="upload-image", timeOutDuration=IMAGE_UPLOAD_TIMEOUT + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=lambda r: True ) - lis["destroy"]() + try: + if should_send: + await self.send([ + { + "taskType": ETaskType.IMAGE_UPLOAD.value, + "taskUUID": task_uuid, + "image": file_data, + } + ]) + await self._mark_operation_sent(task_uuid) - self._handle_error_response(response) + results = await asyncio.wait_for(future, timeout=IMAGE_UPLOAD_TIMEOUT / 1000) + response = results[0] + self._handle_error_response(response) - if response: - image = UploadImageType( + return UploadImageType( imageUUID=response["imageUUID"], imageURL=response["imageURL"], taskUUID=response["taskUUID"], ) - else: - image = None - return image + except asyncio.TimeoutError: + raise Exception( + f"Timeout waiting for image upload | TaskUUID: {task_uuid} | " + f"Timeout: {IMAGE_UPLOAD_TIMEOUT}ms" + ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(task_uuid) - async def uploadMedia(self, media_url: str) -> Optional[MediaStorageType]: - return await self._retry_with_reconnect(self._uploadMedia, media_url) + async def uploadMedia(self, media_url: str) -> "Optional[MediaStorageType]": + request = IUploadMediaRequest(media_url=media_url, taskUUID=getUUID()) + return await self._retry_with_reconnect(self._uploadMedia, request) - async def _uploadMedia(self, media_url: str) -> Optional[MediaStorageType]: + async def _uploadMedia(self, request: IUploadMediaRequest) -> "Optional[MediaStorageType]": await self.ensureConnection() - - task_uuid = getUUID() + + media_url = request.media_url + task_uuid = request.taskUUID media_data = media_url - + if isinstance(media_url, str): if os.path.exists(media_url): # Local file - convert to base64 @@ -1367,70 +1551,48 @@ async def _uploadMedia(self, media_url: str) -> Optional[MediaStorageType]: if media_data.startswith("data:"): media_data = media_data.split(",", 1)[1] # For URLs and base64 strings, send them directly to the API - - await self.send( - [ - { - "taskType": ETaskType.MEDIA_STORAGE.value, - "taskUUID": task_uuid, - "operation": "upload", - "media": media_data, - } - ] - ) - - lis = self.globalListener(taskUUID=task_uuid) - def check(resolve: callable, reject: callable, *args: Any) -> bool: - uploaded_media_list = self._globalMessages.get(task_uuid) - uploaded_media = uploaded_media_list[0] if uploaded_media_list else None - - if uploaded_media and uploaded_media.get("error"): - reject(uploaded_media) - return True - - if uploaded_media: - del self._globalMessages[task_uuid] - resolve(uploaded_media) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="upload-media", timeOutDuration=self._timeout + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=lambda r: r.get("mediaUUID") is not None ) - lis["destroy"]() + try: + if should_send: + await self.send( + [ + { + "taskType": ETaskType.MEDIA_STORAGE.value, + "taskUUID": task_uuid, + "operation": "upload", + "media": media_data, + } + ] + ) + await self._mark_operation_sent(task_uuid) - self._handle_error_response(response) + results = await asyncio.wait_for(future, timeout=self._timeout / 1000) + response = results[0] - if response: - media = MediaStorageType( - mediaUUID=response["mediaUUID"], - taskUUID=response["taskUUID"], - ) - else: - media = None - return media + self._handle_error_response(response) - async def uploadUnprocessedImage( - self, - file: Union[File, str], - preProcessorType: EPreProcessorGroup, - width: int = None, - height: int = None, - lowThresholdCanny: int = None, - highThresholdCanny: int = None, - includeHandsAndFaceOpenPose: bool = True, - ) -> Optional[UploadImageType]: - # Create a dummy UploadImageType object - uploaded_unprocessed_image = UploadImageType( - imageUUID=str(uuid.uuid4()), - imageURL="https://example.com/uploaded_unprocessed_image.jpg", - taskUUID=str(uuid.uuid4()), - ) + if response: + return MediaStorageType( + mediaUUID=response["mediaUUID"], + taskUUID=response["taskUUID"], + ) + return None - return uploaded_unprocessed_image + except asyncio.TimeoutError: + raise Exception( + f"Timeout waiting for media upload | TaskUUID: {task_uuid} | " + f"Timeout: {self._timeout}ms" + ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(task_uuid) async def listenToImages( self, @@ -1562,29 +1724,6 @@ def global_check(m): return temp_listener - async def handleIncompleteImages( - self, taskUUIDs: List[str], error: Any - ) -> Optional[List[IImage]]: - """ - Handle scenarios where the requested number of images is not fully received. - - :param taskUUIDs: A list of task UUIDs to filter the images. - :param error: The error object to raise if there are no or only one image. - :return: A list of available images if there are more than one, otherwise None. - :raises: The provided error if there are no or only one image. - """ - async with self._images_lock: - imagesWithSimilarTask = [ - img for img in self._globalImages if img["taskUUID"] in taskUUIDs - ] - if len(imagesWithSimilarTask) > 1: - self._globalImages = [ - img for img in self._globalImages if img["taskUUID"] not in taskUUIDs - ] - return imagesWithSimilarTask - else: - raise error - async def ensureConnection(self) -> None: """ Ensure that a connection is established with the server. @@ -1648,7 +1787,7 @@ async def check( f"TaskUUIDs: {taskUUIDs}" )) return True - + async with self._images_lock: logger.debug(f"Check # Global images: {self._globalImages}") imagesWithSimilarTask = [ @@ -1702,10 +1841,10 @@ async def check( raise Exception(error_msg) from e async def _modelUpload( - self, requestModel: IUploadModelBaseType - ) -> Optional[IUploadModelResponse]: + self, requestModel: "IUploadModelBaseType" + ) -> "Optional[IUploadModelResponse]": await self.ensureConnection() - + task_uuid = requestModel.taskUUID or getUUID() base_fields = { "taskType": ETaskType.MODEL_UPLOAD.value, @@ -1747,57 +1886,40 @@ async def _modelUpload( }, } - await self.send([request_object]) + def is_ready(item: "Dict[str, Any]") -> bool: + return item.get("status") == "ready" - lis = self.globalListener( - taskUUID=task_uuid, + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=is_ready, + result_filter=lambda r: r.get("status") is not None ) - async def check(resolve: callable, reject: callable, *args: Any) -> bool: - async with self._messages_lock: - uploaded_model_list = self._globalMessages.get(task_uuid, []) - unique_statuses = set() - all_models = [] - - for uploaded_model in uploaded_model_list: - if uploaded_model.get("code"): - self._handle_error_response(uploaded_model) - - status = uploaded_model.get("status") - - if status not in unique_statuses: - all_models.append(uploaded_model) - unique_statuses.add(status) - - if status is not None and "error" in status: - self._handle_error_response(uploaded_model) - - if status == "ready": - uploaded_model_list.remove(uploaded_model) - if not uploaded_model_list: - del self._globalMessages[task_uuid] - else: - self._globalMessages[task_uuid] = uploaded_model_list - resolve(all_models) - return True + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=MODEL_UPLOAD_TIMEOUT / 1000) - return False + unique_statuses = set() + all_models = [] - response = await getIntervalWithPromise( - check, debugKey="upload-model", timeOutDuration=MODEL_UPLOAD_TIMEOUT - ) + for uploaded_model in results: + if uploaded_model.get("code"): + self._handle_error_response(uploaded_model) - lis["destroy"]() + status = uploaded_model.get("status") - if isinstance(response, dict): - self._handle_error_response(response) + if status is not None and "error" in status: + self._handle_error_response(uploaded_model) - if response: - if not isinstance(response, list): - response = [response] + if status not in unique_statuses: + all_models.append(uploaded_model) + unique_statuses.add(status) models = [] - for item in response: + for item in all_models: models.append( { "taskType": item.get("taskType"), @@ -1807,19 +1929,33 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool: "air": item.get("air"), } ) - else: - models = None - return models + + return models + + except asyncio.TimeoutError: + op = self._pending_operations.get(task_uuid) + partial_count = len(op["results"]) if op else 0 + raise Exception( + f"Timeout waiting for model upload | TaskUUID: {task_uuid} | " + f"Received: {partial_count} status updates | " + f"Timeout: {MODEL_UPLOAD_TIMEOUT}ms" + ) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(task_uuid) async def modelUpload( - self, requestModel: IUploadModelBaseType - ) -> Optional[IUploadModelResponse]: - return await self._retry_with_reconnect(self._modelUpload, requestModel) + self, requestModel: "IUploadModelBaseType" + ) -> "Optional[IUploadModelResponse]": + async with self._request_semaphore: + return await self._retry_with_reconnect(self._modelUpload, requestModel) - async def modelSearch(self, payload: IModelSearch) -> IModelSearchResponse: - return await self._retry_with_reconnect(self._modelSearch, payload) + async def modelSearch(self, payload: "IModelSearch") -> "IModelSearchResponse": + async with self._request_semaphore: + return await self._retry_with_reconnect(self._modelSearch, payload) - async def _modelSearch(self, payload: IModelSearch) -> IModelSearchResponse: + async def _modelSearch(self, payload: "IModelSearch") -> "IModelSearchResponse": try: await self.ensureConnection() task_uuid = getUUID() @@ -1838,101 +1974,112 @@ async def _modelSearch(self, payload: IModelSearch) -> IModelSearchResponse: } ) - await self.send([request_object]) - - listener = self.globalListener(taskUUID=task_uuid) - - async def check(resolve: Callable, reject: Callable, *args: Any) -> bool: - async with self._messages_lock: - response = self._globalMessages.get(task_uuid) - if response: - if response[0].get("error"): - reject(response[0]) - return True - del self._globalMessages[task_uuid] - resolve(response[0]) - return True - return False - - response = await getIntervalWithPromise( - check, debugKey="model-search", timeOutDuration=self._timeout + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=lambda r: True ) - listener["destroy"]() - - self._handle_error_response(response) - - return instantiateDataclass(IModelSearchResponse, response) + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=self._timeout / 1000) + response = results[0] + self._handle_error_response(response) + return instantiateDataclass(IModelSearchResponse, response) + except asyncio.TimeoutError: + raise Exception( + f"Timeout waiting for model search | TaskUUID: {task_uuid} | " + f"Timeout: {self._timeout}ms" + ) + finally: + await self._unregister_pending_operation(task_uuid) + except RunwareAPIError: + raise except Exception as e: if isinstance(e, RunwareAPIError): raise - raise RunwareAPIError({"message": str(e)}) - async def videoInference(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._videoInference, requestVideo) + async def videoInference(self, requestVideo: "IVideoInference") -> "Union[List[IVideo], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_async_with_reconnect( + self._videoInference, + requestVideo, + task_type=ETaskType.VIDEO_INFERENCE.value + ) async def _videoInference(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]: await self.ensureConnection() return await self._requestVideo(requestVideo) async def inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._inference3d, request3d) + async with self._request_semaphore: + return await self._retry_async_with_reconnect( + self._inference3d, + request3d, + task_type=ETaskType.INFERENCE_3D.value, + ) async def _inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTaskResponse]: await self.ensureConnection() return await self._request3d(request3d) async def textInference(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._textInference, requestText) - - async def _textInference(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]: - await self.ensureConnection() - return await self._requestText(requestText) + async with self._request_semaphore: + return await self._retry_async_with_reconnect( + self._requestText, + requestText, + task_type=ETaskType.TEXT_INFERENCE.value, + ) async def getResponse( self, taskUUID: str, numberResults: Optional[int] = 1, ) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText]]: - return await self._retry_with_reconnect(self._getResponse, taskUUID, numberResults) + async with self._request_semaphore: + request = IGetResponseRequest( + taskUUID=taskUUID, + numberResults=numberResults or 1, + ) + return await self._retry_with_reconnect(self._getResponse, request) async def _getResponse( self, - taskUUID: str, - numberResults: Optional[int] = 1, + request_model: IGetResponseRequest, ) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText]]: await self.ensureConnection() return await self._pollResults( - task_uuid=taskUUID, - number_results=numberResults, + task_uuid=request_model.taskUUID, + number_results=request_model.numberResults, ) - async def _requestVideo(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]: + async def _requestVideo(self, requestVideo: "IVideoInference") -> "Union[List[IVideo], IAsyncTaskResponse]": await self._processVideoImages(requestVideo) requestVideo.taskUUID = requestVideo.taskUUID or getUUID() request_object = self._buildVideoRequest(requestVideo) - if requestVideo.webhookURL: request_object["webhookURL"] = requestVideo.webhookURL - await self.send([request_object]) - if requestVideo.skipResponse: + await self.send([request_object]) return IAsyncTaskResponse( taskType=ETaskType.VIDEO_INFERENCE.value, taskUUID=requestVideo.taskUUID ) return await self._handleInitialVideoResponse( - requestVideo.taskUUID, - requestVideo.numberResults, - requestVideo.deliveryMethod, - request_object.get("webhookURL"), - "video-inference-initial" + request_object=request_object, + task_uuid=requestVideo.taskUUID, + number_results=requestVideo.numberResults, + delivery_method=requestVideo.deliveryMethod, + webhook_url=request_object.get("webhookURL"), + debug_key="video-inference-initial" ) async def _processVideoImages(self, requestVideo: IVideoInference) -> None: @@ -1975,11 +2122,11 @@ def _buildVideoRequest(self, requestVideo: IVideoInference) -> Dict[str, Any]: "taskUUID": requestVideo.taskUUID, "model": requestVideo.model, } - + # Only add numberResults if it's not None if requestVideo.numberResults is not None: request_object["numberResults"] = requestVideo.numberResults - + # Only add positivePrompt if it's not None if requestVideo.positivePrompt is not None: request_object["positivePrompt"] = requestVideo.positivePrompt.strip() @@ -2016,7 +2163,7 @@ def _addVideoImages(self, request_object: Dict[str, Any], requestVideo: IVideoIn if requestVideo.referenceImages: request_object["referenceImages"] = requestVideo.referenceImages - + # Add lora if present if requestVideo.lora: request_object["lora"] = [ @@ -2067,13 +2214,14 @@ async def _request3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTa await self._process3dInputs(request3d) request3d.taskUUID = request3d.taskUUID or getUUID() request_object = self._build3dRequest(request3d) - await self.send([request_object]) + return await self._handleInitial3dResponse( - request3d.taskUUID, - request3d.numberResults or 1, - request3d.deliveryMethod, - request_object.get("webhookURL"), - "3d-inference-initial", + request_object=request_object, + task_uuid=request3d.taskUUID, + number_results=request3d.numberResults or 1, + delivery_method=request3d.deliveryMethod, + webhook_url=request_object.get("webhookURL"), + debug_key="3d-inference-initial", ) def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]: @@ -2098,98 +2246,103 @@ def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]: request_object["stopSequences"] = requestText.stopSequences if requestText.includeCost is not None: request_object["includeCost"] = requestText.includeCost + if requestText.numberResults is not None: + request_object["numberResults"] = requestText.numberResults self._addTextProviderSettings(request_object, requestText) return request_object async def _requestText(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]: + await self.ensureConnection() requestText.taskUUID = requestText.taskUUID or getUUID() request_object = self._buildTextRequest(requestText) - await self.send([request_object]) + + if requestText.webhookURL: + request_object["webhookURL"] = requestText.webhookURL + return await self._handleInitialTextResponse( - requestText.taskUUID, - requestText.deliveryMethod, - "text-inference-initial", + request_object=request_object, + task_uuid=requestText.taskUUID, + number_results=requestText.numberResults or 1, + delivery_method=requestText.deliveryMethod, + webhook_url=request_object.get("webhookURL"), + debug_key="text-inference-initial", ) async def _handleInitialTextResponse( self, + request_object: "Dict[str, Any]", task_uuid: str, - delivery_method: Union[str, EDeliveryMethod] = EDeliveryMethod.SYNC, + number_results: int, + delivery_method: "Union[str, EDeliveryMethod]" = EDeliveryMethod.SYNC, + webhook_url: "Optional[str]" = None, debug_key: str = "text-inference-initial", ) -> Union[List[IText], IAsyncTaskResponse]: - lis = self.globalListener(taskUUID=task_uuid) - delivery_method_enum = delivery_method if isinstance(delivery_method, EDeliveryMethod) else EDeliveryMethod(delivery_method) + if delivery_method is None: + delivery_method = EDeliveryMethod.SYNC + delivery_method_enum = delivery_method if isinstance(delivery_method, EDeliveryMethod) else EDeliveryMethod( + delivery_method, + ) - async def check_initial_response(resolve: callable, reject: callable, *args: Any) -> bool: - if not self.connected() or not self.isWebsocketReadyState(): - reject(ConnectionError( - f"Connection lost while waiting for text response | " - f"TaskUUID: {task_uuid} | " - f"Delivery method: {delivery_method_enum}" - )) + def is_text_complete(r: "Dict[str, Any]") -> bool: + if r.get("status") == "success": + return True + if r.get("text") is not None: + return True + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: return True + return False - async with self._messages_lock: - response_list = self._globalMessages.get(task_uuid, []) + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=is_text_complete, + ) - if not response_list: - return False + timeout = TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else TEXT_INITIAL_TIMEOUT - response = response_list[0] + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=timeout / 1000) - if self._is_error_response(response): - del self._globalMessages[task_uuid] - raise RunwareAPIError(response) + if not results: + raise ConnectionError( + f"No initial response received for text inference | " + f"delivery_method={delivery_method_enum} | taskUUID={task_uuid}" + ) - if response.get("status") == "success" or response.get("text") is not None: - del self._globalMessages[task_uuid] - resolve([response]) - return True + response = results[0] + self._handle_error_response(response) - if delivery_method_enum is EDeliveryMethod.ASYNC: - del self._globalMessages[task_uuid] - async_response = createAsyncTaskResponse(response) - resolve([async_response]) - return True + if response.get("status") == "success" or response.get("text") is not None: + return instantiateDataclassList(IText, results) - return False + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: + return createAsyncTaskResponse(response) - try: - initial_response = await getIntervalWithPromise( - check_initial_response, - debugKey=debug_key, - timeOutDuration=TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else TEXT_INITIAL_TIMEOUT, - ) - except RunwareAPIError: - raise - except Exception as e: + return instantiateDataclassList(IText, results) + + except asyncio.TimeoutError: if not self.connected() or not self.isWebsocketReadyState(): raise ConnectionError( f"Connection lost while waiting for text response | " - f"TaskUUID: {task_uuid} | " - f"Delivery method: {delivery_method_enum}" + f"TaskUUID: {task_uuid} | Delivery method: {delivery_method_enum}" ) + if delivery_method_enum is EDeliveryMethod.SYNC: - error_msg = ( - f"Timeout waiting for text generation | " - f"TaskUUID: {task_uuid} | " - f"Timeout: {TIMEOUT_DURATION}ms | " - f"Original error: {str(e)}" + raise ConnectionError( + f"Timeout waiting for text generation | TaskUUID: {task_uuid} | " + f"Timeout: {timeout}ms" ) - raise ConnectionError(error_msg) - initial_response = None - finally: - lis["destroy"]() - - if not initial_response or len(initial_response) == 0: raise ConnectionError( - f"No initial response received for text generation | delivery_method={delivery_method_enum} | taskUUID={task_uuid}" + f"Timeout waiting for text generation | TaskUUID: {task_uuid} | " + f"Timeout: {timeout}ms" ) - - if isinstance(initial_response[0], IAsyncTaskResponse): - return initial_response[0] - - return instantiateDataclassList(IText, initial_response) + except RunwareAPIError: + raise + finally: + await self._unregister_pending_operation(task_uuid) def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str], control_net_data_dicts: List[Dict], instant_id_data: Optional[Dict], ip_adapters_data: Optional[List[Dict]], ace_plus_plus_data: Optional[Dict], pulid_data: Optional[Dict]) -> Dict[str, Any]: request_object = { @@ -2200,7 +2353,7 @@ def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str } if prompt: request_object["positivePrompt"] = prompt - + self._addOptionalImageFields(request_object, requestImage) self._addImageSpecialFields(request_object, requestImage, control_net_data_dicts, instant_id_data, ip_adapters_data, ace_plus_plus_data, pulid_data) self._addOptionalField(request_object, requestImage.inputs) @@ -2208,7 +2361,7 @@ def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str self._addOptionalField(request_object, requestImage.ultralytics) self._addOptionalField(request_object, requestImage.safety) self._addOptionalField(request_object, requestImage.settings) - + return request_object @@ -2220,7 +2373,7 @@ def _addOptionalImageFields(self, request_object: Dict[str, Any], requestImage: "clipSkip", "promptWeighting", "maskMargin", "vae", "webhookURL", "acceleration", "useCache", "ttl", "resolution" ] - + for field in optional_fields: value = getattr(requestImage, field, None) if value is not None: @@ -2234,28 +2387,28 @@ def _addImageSpecialFields(self, request_object: Dict[str, Any], requestImage: I # Add controlNet if present if control_net_data_dicts: request_object["controlNet"] = control_net_data_dicts - + # Add lora if present if requestImage.lora: request_object["lora"] = [ {"model": lora.model, "weight": lora.weight} for lora in requestImage.lora ] - + # Add lycoris if present if requestImage.lycoris: request_object["lycoris"] = [ {"model": lycoris.model, "weight": lycoris.weight} for lycoris in requestImage.lycoris ] - + # Add embeddings if present if requestImage.embeddings: request_object["embeddings"] = [ {"model": embedding.model} for embedding in requestImage.embeddings ] - + # Add refiner if present if requestImage.refiner: refiner_dict = {"model": requestImage.refiner.model} @@ -2264,11 +2417,11 @@ def _addImageSpecialFields(self, request_object: Dict[str, Any], requestImage: I if requestImage.refiner.startStepPercentage is not None: refiner_dict["startStepPercentage"] = requestImage.refiner.startStepPercentage request_object["refiner"] = refiner_dict - + # Add instantID if present if instant_id_data: request_object["instantID"] = instant_id_data - + # Add outpaint if present if requestImage.outpaint: outpaint_dict = { @@ -2277,26 +2430,26 @@ def _addImageSpecialFields(self, request_object: Dict[str, Any], requestImage: I if v is not None } request_object["outpaint"] = outpaint_dict - + # Add ipAdapters if present if ip_adapters_data: request_object["ipAdapters"] = ip_adapters_data - + # Add acePlusPlus if present if ace_plus_plus_data: request_object["acePlusPlus"] = ace_plus_plus_data - + # Add puLID if present if pulid_data: request_object["puLID"] = pulid_data - + # Add referenceImages if present if requestImage.referenceImages: request_object["referenceImages"] = requestImage.referenceImages - + # Add acceleratorOptions if present self._addOptionalField(request_object, requestImage.acceleratorOptions) - + # Add advancedFeatures if present if requestImage.advancedFeatures: pipeline_options_dict = { @@ -2305,7 +2458,7 @@ def _addImageSpecialFields(self, request_object: Dict[str, Any], requestImage: I if v is not None } request_object["advancedFeatures"] = pipeline_options_dict - + # Add extraArgs if present if hasattr(requestImage, "extraArgs") and isinstance(requestImage.extraArgs, dict): request_object.update(requestImage.extraArgs) @@ -2337,203 +2490,184 @@ def _addOptionalField(self, request_object: Dict[str, Any], obj: Any) -> None: if obj_dict: request_object.update(obj_dict) - async def _handleWebhookAcknowledgment( + async def _handleWebhookRequest( self, + request_object: "Dict[str, Any]", task_uuid: str, task_type: str, debug_key: str, - ) -> IAsyncTaskResponse: - lis = self.globalListener(taskUUID=task_uuid) - - async def check_webhook_ack(resolve: callable, reject: callable, *args: Any) -> bool: - async with self._messages_lock: - response_list = self._globalMessages.get(task_uuid, []) - - if not response_list: - return False - - response = response_list[0] if isinstance(response_list, list) else response_list - - self._handle_error_response(response) - - if isinstance(response, dict) and response.get("error"): - reject(response) - return True - - if response.get("taskType") == task_type: - del self._globalMessages[task_uuid] - async_response = createAsyncTaskResponse(response) - resolve(async_response) - return True - - return False + ) -> "IAsyncTaskResponse": + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=lambda r: r.get("taskType") == task_type or r.get("taskUUID") == task_uuid + ) try: - response = await getIntervalWithPromise( - check_webhook_ack, debugKey=debug_key, timeOutDuration=WEBHOOK_TIMEOUT + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=WEBHOOK_TIMEOUT / 1000) + response = results[0] + self._handle_error_response(response) + return createAsyncTaskResponse(response) + except asyncio.TimeoutError: + raise Exception( + f"Timeout waiting for webhook acknowledgment | TaskUUID: {task_uuid} | " + f"TaskType: {task_type} | Timeout: {WEBHOOK_TIMEOUT}ms" ) + except RunwareAPIError: + raise finally: - lis["destroy"]() + await self._unregister_pending_operation(task_uuid) - if isinstance(response, dict): - self._handle_error_response(response) + async def _handleInitialVideoResponse( + self, + request_object: "Dict[str, Any]", + task_uuid: str, + number_results: int, + delivery_method: "Union[str, EDeliveryMethod]" = None, + webhook_url: "Optional[str]" = None, + debug_key: str = "video-inference-initial" + ) -> "Union[List[IVideo], IAsyncTaskResponse]": + if delivery_method is None: + delivery_method = EDeliveryMethod.ASYNC + delivery_method_enum = delivery_method if isinstance(delivery_method, EDeliveryMethod) else EDeliveryMethod( + delivery_method) + + def is_video_complete(r: "Dict[str, Any]") -> bool: + if r.get("status") == "success": + return True + if r.get("videoUUID") is not None or r.get("mediaUUID") is not None: + return True + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: + return True + return False - return response + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=is_video_complete + ) - async def _handleInitialVideoResponse(self, task_uuid: str, number_results: int, delivery_method: Union[str, EDeliveryMethod] = EDeliveryMethod.ASYNC, webhook_url: Optional[str] = None, debug_key: str = "video-inference-initial") -> Union[List[IVideo], IAsyncTaskResponse]: - lis = self.globalListener(taskUUID=task_uuid) - delivery_method_enum = delivery_method if isinstance(delivery_method, EDeliveryMethod) else EDeliveryMethod(delivery_method) + timeout = TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else VIDEO_INITIAL_TIMEOUT - async def check_initial_response(resolve: callable, reject: callable, *args: Any) -> bool: - # Check if connection was lost during the wait - if not self.connected() or not self.isWebsocketReadyState(): - reject(ConnectionError( - f"Connection lost while waiting for video response | " - f"TaskUUID: {task_uuid} | " - f"Delivery method: {delivery_method_enum}" - )) - return True - - async with self._messages_lock: - response_list = self._globalMessages.get(task_uuid, []) - - if not response_list: - return False + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=timeout / 1000) - response = response_list[0] + if not results: + raise ConnectionError( + f"No initial response received for video generation | " + f"delivery_method={delivery_method_enum} | taskUUID={task_uuid}" + ) - if self._is_error_response(response): - del self._globalMessages[task_uuid] - raise RunwareAPIError(response) + response = results[0] + self._handle_error_response(response) - if response.get("status") == "success" or response.get("videoUUID") is not None or response.get("mediaUUID") is not None: - del self._globalMessages[task_uuid] - resolve([response]) - return True + if response.get("status") == "success" or response.get("videoUUID") is not None or response.get( + "mediaUUID") is not None: + return instantiateDataclassList(IVideo, results) - if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: - del self._globalMessages[task_uuid] - async_response = createAsyncTaskResponse(response) - resolve([async_response]) - return True + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: + return createAsyncTaskResponse(response) - return False + return instantiateDataclassList(IVideo, results) - try: - initial_response = await getIntervalWithPromise( - check_initial_response, - debugKey=debug_key, - timeOutDuration=TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else VIDEO_INITIAL_TIMEOUT - ) - except RunwareAPIError: - raise - except Exception as e: - # Check if connection was lost during the wait + except asyncio.TimeoutError: if not self.connected() or not self.isWebsocketReadyState(): raise ConnectionError( f"Connection lost while waiting for video response | " - f"TaskUUID: {task_uuid} | " - f"Delivery method: {delivery_method_enum}" + f"TaskUUID: {task_uuid} | Delivery method: {delivery_method_enum}" ) + if delivery_method_enum is EDeliveryMethod.SYNC: - # Raise ConnectionError so _retry_with_reconnect can retry the request - error_msg = ( - f"Timeout waiting for video generation | " - f"TaskUUID: {task_uuid} | " - f"Timeout: {TIMEOUT_DURATION}ms | " - f"Original error: {str(e)}" + raise ConnectionError( + f"Timeout waiting for video generation | TaskUUID: {task_uuid} | " + f"Timeout: {timeout}ms" ) - raise ConnectionError(error_msg) - initial_response = None - finally: - lis["destroy"]() - - if not initial_response or len(initial_response) == 0: - # No response from server means connection was lost during request raise ConnectionError( - f"No initial response received for video generation | delivery_method={delivery_method_enum} | taskUUID={task_uuid}" + f"Timeout waiting for video generation | TaskUUID: {task_uuid} | " + f"Timeout: {timeout}ms" ) - - if isinstance(initial_response[0], IAsyncTaskResponse): - return initial_response[0] - - return instantiateDataclassList(IVideo, initial_response) + finally: + await self._unregister_pending_operation(task_uuid) - async def _handleInitial3dResponse(self, task_uuid: str, number_results: int, delivery_method: Union[str, EDeliveryMethod] = EDeliveryMethod.ASYNC, webhook_url: Optional[str] = None, debug_key: str = "3d-inference-initial") -> Union[List[I3d], IAsyncTaskResponse]: - lis = self.globalListener(taskUUID=task_uuid) + async def _handleInitial3dResponse( + self, + request_object: Dict[str, Any], + task_uuid: str, + number_results: int, + delivery_method: Union[str, EDeliveryMethod] = EDeliveryMethod.ASYNC, + webhook_url: Optional[str] = None, + debug_key: str = "3d-inference-initial", + ) -> Union[List[I3d], IAsyncTaskResponse]: delivery_method_enum = delivery_method if isinstance(delivery_method, EDeliveryMethod) else EDeliveryMethod(delivery_method) - async def check_initial_response(resolve: callable, reject: callable, *args: Any) -> bool: - if not self.connected() or not self.isWebsocketReadyState(): - reject(ConnectionError( - f"Connection lost while waiting for 3d response | " - f"TaskUUID: {task_uuid} | " - f"Delivery method: {delivery_method_enum}" - )) + def is_3d_complete(r: Dict[str, Any]) -> bool: + if r.get("status") == "success": return True + outputs = r.get("outputs") + if outputs is not None and outputs.get("files") is not None: + return True + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: + return True + return False - async with self._messages_lock: - response_list = self._globalMessages.get(task_uuid, []) + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=is_3d_complete, + ) - if not response_list: - return False + timeout = TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else VIDEO_INITIAL_TIMEOUT - response = response_list[0] + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=timeout / 1000) - if self._is_error_response(response): - del self._globalMessages[task_uuid] - raise RunwareAPIError(response) + if not results: + raise ConnectionError( + f"No initial response received for 3d generation | " + f"delivery_method={delivery_method_enum} | taskUUID={task_uuid}" + ) - outputs = response.get("outputs") - if response.get("status") == "success" or (outputs is not None and outputs.get("files") is not None): - del self._globalMessages[task_uuid] - resolve([response]) - return True + response = results[0] + self._handle_error_response(response) - if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: - del self._globalMessages[task_uuid] - async_response = createAsyncTaskResponse(response) - resolve([async_response]) - return True + if response.get("status") == "success": + return instantiateDataclassList(I3d, results) + outputs = response.get("outputs") + if outputs is not None and outputs.get("files") is not None: + return instantiateDataclassList(I3d, results) - return False + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: + return createAsyncTaskResponse(response) - try: - initial_response = await getIntervalWithPromise( - check_initial_response, - debugKey=debug_key, - timeOutDuration=TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else VIDEO_INITIAL_TIMEOUT - ) - except RunwareAPIError: - raise - except Exception as e: + return instantiateDataclassList(I3d, results) + + except asyncio.TimeoutError: if not self.connected() or not self.isWebsocketReadyState(): raise ConnectionError( f"Connection lost while waiting for 3d response | " - f"TaskUUID: {task_uuid} | " - f"Delivery method: {delivery_method_enum}" + f"TaskUUID: {task_uuid} | Delivery method: {delivery_method_enum}" ) + if delivery_method_enum is EDeliveryMethod.SYNC: - error_msg = ( - f"Timeout waiting for 3d generation | " - f"TaskUUID: {task_uuid} | " - f"Timeout: {TIMEOUT_DURATION}ms | " - f"Original error: {str(e)}" + raise ConnectionError( + f"Timeout waiting for 3d generation | TaskUUID: {task_uuid} | " + f"Timeout: {timeout}ms" ) - raise ConnectionError(error_msg) - initial_response = None - finally: - lis["destroy"]() - - if not initial_response or len(initial_response) == 0: raise ConnectionError( - f"No initial response received for 3d generation | delivery_method={delivery_method_enum} | taskUUID={task_uuid}" + f"Timeout waiting for 3d generation | TaskUUID: {task_uuid} | " + f"Timeout: {timeout}ms" ) - - if isinstance(initial_response[0], IAsyncTaskResponse): - return initial_response[0] - - return instantiateDataclassList(I3d, initial_response) + finally: + await self._unregister_pending_operation(task_uuid) async def _handleInitialImageResponse( self, @@ -2590,7 +2724,6 @@ async def check_initial_response(resolve: callable, reject: callable, *args: Any except RunwareAPIError: raise except Exception as e: - # Check if connection was lost during the wait if not self.connected() or not self.isWebsocketReadyState(): raise ConnectionError( f"Connection lost while waiting for image response | " @@ -2598,7 +2731,6 @@ async def check_initial_response(resolve: callable, reject: callable, *args: Any f"Delivery method: {delivery_method_enum}" ) if delivery_method_enum is EDeliveryMethod.SYNC: - # Raise ConnectionError so _retry_with_reconnect can retry the request error_msg = ( f"Timeout waiting for image generation | " f"TaskUUID: {task_uuid} | " @@ -2611,63 +2743,66 @@ async def check_initial_response(resolve: callable, reject: callable, *args: Any lis["destroy"]() if not initial_response or len(initial_response) == 0: - # No response from server means connection was lost during request raise ConnectionError( - f"No initial response received for image inference | delivery_method={delivery_method_enum} | taskUUID={task_uuid}" + f"No initial response received for image generation | delivery_method={delivery_method_enum} | taskUUID={task_uuid}" ) - + if isinstance(initial_response[0], IAsyncTaskResponse): return initial_response[0] - - return instantiateDataclassList(IImage, initial_response) - async def _sendPollRequest(self, task_uuid: str, poll_count: int) -> List[Dict[str, Any]]: - lis = self.globalListener(taskUUID=task_uuid) + return instantiateDataclassList(IImage, initial_response) + async def _sendPollRequest(self, task_uuid: str, poll_count: int) -> "List[Dict[str, Any]]": + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=lambda r: True + ) try: - await self.send([{ - "taskType": ETaskType.GET_RESPONSE.value, - "taskUUID": task_uuid - }]) - - async def check_poll_response(resolve: callable, reject: callable, *args: Any) -> bool: - async with self._messages_lock: - response_list = self._globalMessages.get(task_uuid, []) - if response_list: - del self._globalMessages[task_uuid] - resolve(response_list) - return True - return False - - return await getIntervalWithPromise( - check_poll_response, - debugKey=f"poll-{poll_count}", - timeOutDuration=VIDEO_INITIAL_TIMEOUT + if should_send: + await self.send([{ + "taskType": ETaskType.GET_RESPONSE.value, + "taskUUID": task_uuid + }]) + await self._mark_operation_sent(task_uuid) + + results = await asyncio.wait_for(future, timeout=VIDEO_INITIAL_TIMEOUT / 1000) + return results + + except asyncio.TimeoutError: + op = self._pending_operations.get(task_uuid) + if op and op["results"]: + return op["results"] + raise Exception( + f"Timeout waiting for poll response | TaskUUID: {task_uuid} | " + f"Poll: {poll_count} | Timeout: {VIDEO_INITIAL_TIMEOUT}ms" ) finally: - lis["destroy"]() + await self._unregister_pending_operation(task_uuid) def _hasPendingResults(self, responses: List[Dict[str, Any]]) -> bool: return any(response.get("status") == "processing" for response in responses) - async def audioInference(self, requestAudio: IAudioInference) -> Union[List[IAudio], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._audioInference, requestAudio) + async def audioInference(self, requestAudio: "IAudioInference") -> "Union[List[IAudio], IAsyncTaskResponse]": + async with self._request_semaphore: + return await self._retry_async_with_reconnect( + self._requestAudio, + requestAudio, + task_type=ETaskType.AUDIO_INFERENCE.value + ) - async def _audioInference(self, requestAudio: IAudioInference) -> Union[List[IAudio], IAsyncTaskResponse]: + async def _requestAudio(self, requestAudio: "IAudioInference") -> Union[List["IAudio"], "IAsyncTaskResponse"]: await self.ensureConnection() - return await self._requestAudio(requestAudio) - - async def _requestAudio(self, requestAudio: IAudioInference) -> Union[List[IAudio], IAsyncTaskResponse]: requestAudio.taskUUID = requestAudio.taskUUID or getUUID() request_object = self._buildAudioRequest(requestAudio) - await self.send([request_object]) return await self._handleInitialAudioResponse( - requestAudio.taskUUID, - requestAudio.numberResults, - requestAudio.deliveryMethod, - request_object.get("webhookURL"), - "audio-inference-initial" + request_object=request_object, + task_uuid=requestAudio.taskUUID, + number_results=requestAudio.numberResults, + delivery_method=requestAudio.deliveryMethod, + webhook_url=request_object.get("webhookURL"), + debug_key="audio-inference-initial" ) def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]: @@ -2678,15 +2813,15 @@ def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]: "model": requestAudio.model, "numberResults": requestAudio.numberResults, } - + # Only add positivePrompt if it's provided if requestAudio.positivePrompt is not None: request_object["positivePrompt"] = requestAudio.positivePrompt.strip() - + # Only add duration if it's provided and not using composition plan if requestAudio.duration is not None: - request_object["duration"] = requestAudio.duration - + request_object["duration"] = requestAudio.duration + self._addOptionalAudioFields(request_object, requestAudio) self._addOptionalField(request_object, requestAudio.speech) self._addOptionalField(request_object, requestAudio.audioSettings) @@ -2724,126 +2859,76 @@ def _addTextProviderSettings(self, request_object: Dict[str, Any], requestText: request_object["providerSettings"] = provider_dict async def _handleInitialAudioResponse( - self, - task_uuid: str, - number_results: int, - delivery_method: Union[str, EDeliveryMethod] = EDeliveryMethod.SYNC, - webhook_url: Optional[str] = None, - debug_key: str = "audio-inference-initial" - ) -> Union[List[IAudio], IAsyncTaskResponse]: - lis = self.globalListener(taskUUID=task_uuid) - delivery_method_enum = delivery_method if isinstance(delivery_method, EDeliveryMethod) else EDeliveryMethod(delivery_method) - - async def check_initial_response(resolve: callable, reject: callable, *args: Any) -> bool: - # Check if connection was lost during the wait - if not self.connected() or not self.isWebsocketReadyState(): - reject(ConnectionError( - f"Connection lost while waiting for audio response | " - f"TaskUUID: {task_uuid} | " - f"Delivery method: {delivery_method_enum}" - )) + self, + request_object: "Dict[str, Any]", + task_uuid: str, + number_results: int, + delivery_method: "Union[str, EDeliveryMethod]" = None, + webhook_url: "Optional[str]" = None, + debug_key: str = "audio-inference-initial" + ) -> "Union[List[IAudio], IAsyncTaskResponse]": + if delivery_method is None: + delivery_method = EDeliveryMethod.SYNC + delivery_method_enum = delivery_method if isinstance(delivery_method, EDeliveryMethod) else EDeliveryMethod( + delivery_method) + + def is_audio_complete(r: "Dict[str, Any]") -> bool: + if r.get("status") == "success": return True - - async with self._messages_lock: - response_list = self._globalMessages.get(task_uuid, []) + if r.get("audioUUID") is not None: + return True + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: + return True + return False - if not response_list: - return False + future, should_send = await self._register_pending_operation( + task_uuid, + expected_results=1, + complete_predicate=is_audio_complete + ) - response = response_list[0] + timeout = TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else AUDIO_INITIAL_TIMEOUT - if self._is_error_response(response): - del self._globalMessages[task_uuid] - raise RunwareAPIError(response) + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(task_uuid) + results = await asyncio.wait_for(future, timeout=timeout / 1000) - if response.get("status") == "success" or response.get("audioUUID") is not None: - - del self._globalMessages[task_uuid] - resolve([response]) - return True + if not results: + raise ConnectionError( + f"No initial response received for audio inference | " + f"delivery_method={delivery_method_enum} | taskUUID={task_uuid}" + ) - if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: - del self._globalMessages[task_uuid] - async_response = createAsyncTaskResponse(response) - resolve([async_response]) - return True + response = results[0] + self._handle_error_response(response) - return False + if response.get("status") == "success" or response.get("audioUUID") is not None: + return instantiateDataclassList(IAudio, results) - try: - initial_response = await getIntervalWithPromise( - check_initial_response, - debugKey=debug_key, - timeOutDuration=TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else AUDIO_INITIAL_TIMEOUT - ) - except RunwareAPIError: - raise - except Exception as e: - # Check if connection was lost during the wait + if webhook_url or delivery_method_enum is EDeliveryMethod.ASYNC: + return createAsyncTaskResponse(response) + + return instantiateDataclassList(IAudio, results) + + except asyncio.TimeoutError: if not self.connected() or not self.isWebsocketReadyState(): raise ConnectionError( f"Connection lost while waiting for audio response | " - f"TaskUUID: {task_uuid} | " - f"Delivery method: {delivery_method_enum}" + f"TaskUUID: {task_uuid} | Delivery method: {delivery_method_enum}" ) + if delivery_method_enum is EDeliveryMethod.SYNC: - # Raise ConnectionError so _retry_with_reconnect can retry the request - error_msg = ( - f"Timeout waiting for audio generation | " - f"TaskUUID: {task_uuid} | " - f"Timeout: {TIMEOUT_DURATION}ms | " - f"Original error: {str(e)}" + raise ConnectionError( + f"Timeout waiting for audio generation | TaskUUID: {task_uuid} | " + f"Timeout: {timeout}ms" ) - raise ConnectionError(error_msg) - initial_response = None + raise + except RunwareAPIError: + raise finally: - lis["destroy"]() - - if not initial_response or len(initial_response) == 0: - # No response from server means connection was lost during request - raise ConnectionError( - f"No initial response received for audio inference | delivery_method={delivery_method_enum} | taskUUID={task_uuid}" - ) - - if isinstance(initial_response[0], IAsyncTaskResponse): - return initial_response[0] - - return instantiateDataclassList(IAudio, initial_response) - - async def _waitForAudioCompletion(self, task_uuid: str) -> Optional[IAudio]: - lis = self.globalListener(taskUUID=task_uuid) - - async def check(resolve: Callable, reject: Callable, *args: Any) -> bool: - async with self._messages_lock: - response = self._globalMessages.get(task_uuid) - if response: - audio_response = response[0] if isinstance(response, list) else response - else: - audio_response = response - - if audio_response and audio_response.get("error"): - reject(audio_response) - return True - - if audio_response: - del self._globalMessages[task_uuid] - resolve(audio_response) - return True - - return False - - try: - response = await getIntervalWithPromise( - check, debugKey="audio-inference", timeOutDuration=AUDIO_INFERENCE_TIMEOUT - ) - lis["destroy"]() - - self._handle_error_response(response) - - return self._createAudioFromResponse(response) if response else None - except Exception as e: - lis["destroy"]() - raise e + await self._unregister_pending_operation(task_uuid) def _processPollingResponse(self, responses: List[Dict[str, Any]]) -> List[Dict[str, Any]]: completed_results: List[Dict[str, Any]] = [] @@ -2864,9 +2949,8 @@ async def _pollResults( # Default to 1 if number_results is None if number_results is None: number_results = 1 - - completed_results: List[Dict[str, Any]] = [] - lis = self.globalListener(taskUUID=task_uuid) + + completed_results: "List[Dict[str, Any]]" = [] task_type = None response_cls: Optional[Union[IVideo, IVideoToText, IAudio, IImage, I3d, IText]] = None @@ -2874,31 +2958,31 @@ async def _pollResults( polling_delay: int = VIDEO_POLLING_DELAY timeout_message: str = f"Polling timeout after {MAX_POLLS} polls" - def configure_from_task_type(task_type: Optional[str]) -> Optional[tuple]: - if not task_type: + def configure_from_task_type(task_type_val: Optional[str]): + if not task_type_val: return None - match task_type: + match task_type_val: case ETaskType.AUDIO_INFERENCE.value: return ( IAudio, - MAX_POLLS, + MAX_POLLS_AUDIO_GENERATION, AUDIO_POLLING_DELAY, - f"Audio generation timeout after {MAX_POLLS} polls" + f"Audio generation timeout after {MAX_POLLS_AUDIO_GENERATION} polls" ) case ETaskType.VIDEO_CAPTION.value: return ( IVideoToText, - MAX_POLLS, + MAX_POLLS_VIDEO_GENERATION, VIDEO_POLLING_DELAY, - f"Video caption generation timeout after {MAX_POLLS} polls" + f"Video caption generation timeout after {MAX_POLLS_VIDEO_GENERATION} polls" ) case ETaskType.IMAGE_INFERENCE.value: return ( IImage, - MAX_POLLS, + MAX_POLLS_IMAGE_GENERATION, IMAGE_POLLING_DELAY, - f"Image generation timeout after {MAX_POLLS} polls" + f"Image generation timeout after {MAX_POLLS_IMAGE_GENERATION} polls" ) case ( ETaskType.VIDEO_INFERENCE.value @@ -2907,16 +2991,16 @@ def configure_from_task_type(task_type: Optional[str]) -> Optional[tuple]: ): return ( IVideo, - MAX_POLLS, + MAX_POLLS_VIDEO_GENERATION, VIDEO_POLLING_DELAY, - f"Video generation timeout after {MAX_POLLS} polls" + f"Video generation timeout after {MAX_POLLS_VIDEO_GENERATION} polls" ) case ETaskType.INFERENCE_3D.value: return ( I3d, - MAX_POLLS, + MAX_POLLS_3D_GENERATION, VIDEO_POLLING_DELAY, - f"3d generation timeout after {MAX_POLLS} polls" + f"3d generation timeout after {MAX_POLLS_3D_GENERATION} polls" ) case ETaskType.TEXT_INFERENCE.value: return ( @@ -2926,10 +3010,17 @@ def configure_from_task_type(task_type: Optional[str]) -> Optional[tuple]: f"Text generation timeout after {MAX_POLLS} polls" ) case _: - raise ValueError(f"Unsupported task type for polling: {task_type}") - + raise ValueError(f"Unsupported task type for polling: {task_type_val}") + + max_polls_loop = max( + MAX_POLLS, + MAX_POLLS_VIDEO_GENERATION, + MAX_POLLS_AUDIO_GENERATION, + MAX_POLLS_3D_GENERATION, + MAX_POLLS_IMAGE_GENERATION, + ) try: - for poll_count in range(MAX_POLLS): + for poll_count in range(max_polls_loop): try: responses = await self._sendPollRequest(task_uuid, poll_count) @@ -2951,11 +3042,21 @@ def configure_from_task_type(task_type: Optional[str]) -> Optional[tuple]: if len(completed_results) >= number_results: return instantiateDataclassList( - response_cls, + response_cls or IVideo, completed_results[:number_results] ) - if not processed_responses and not self._hasPendingResults(responses): + has_pending = self._hasPendingResults(responses) + has_queued = any( + response.get("status") in ("queued", "pending", "scheduled", "waiting") + for response in responses + ) + + if not processed_responses and not has_pending and not has_queued: + has_task_response = any(r.get("taskUUID") == task_uuid for r in responses) + if has_task_response: + logger.warning(f"Received response for {task_uuid} but status unclear, continuing poll") + continue raise RunwareAPIError({"message": f"Unexpected polling response at poll {poll_count}"}) except RunwareAPIError: @@ -2969,11 +3070,12 @@ def configure_from_task_type(task_type: Optional[str]) -> Optional[tuple]: await asyncio.sleep(polling_delay / 1000) - finally: - lis["destroy"]() - async with self._messages_lock: - if task_uuid in self._globalMessages: - del self._globalMessages[task_uuid] + if completed_results: + return instantiateDataclassList(response_cls or IVideo, completed_results[:number_results]) + return [] + + except Exception: + raise def _createAudioFromResponse(self, response: Dict[str, Any]) -> IAudio: return IAudio( diff --git a/runware/server.py b/runware/server.py index c8d2451..426b3b1 100644 --- a/runware/server.py +++ b/runware/server.py @@ -9,7 +9,7 @@ from websockets.protocol import State from typing import Any, Dict, Optional -from .types import SdkType +from .types import SdkType, OperationState from .utils import ( BASE_RUNWARE_URLS, PING_INTERVAL, @@ -172,6 +172,8 @@ async def disconnect(self): self.logger.info("Disconnecting from Runware server") self._is_shutting_down = True + await self._cancel_pending_operations("Disconnected by user") + for task_name, task in list(self._tasks.items()): if task and not task.done(): task.cancel() @@ -214,6 +216,44 @@ async def on_message(self, ws, message): self.logger.error(f"Failed to parse JSON message:", exc_info=e) return + handled_task_uuids = set() + + if self._pending_operations: + if "data" in m and isinstance(m["data"], list): + for item in m["data"]: + task_uuid = item.get("taskUUID") + if task_uuid and await self._handle_pending_operation_message(item): + handled_task_uuids.add(task_uuid) + + if "errors" in m and isinstance(m["errors"], list): + for error in m["errors"]: + task_uuid = error.get("taskUUID") + if task_uuid and await self._handle_pending_operation_error(error): + handled_task_uuids.add(task_uuid) + + if handled_task_uuids: + remaining_data = [ + item for item in m.get("data", []) + if item.get("taskUUID") not in handled_task_uuids + ] + remaining_errors = [ + err for err in m.get("errors", []) + if err.get("taskUUID") not in handled_task_uuids + ] + + if not remaining_data and not remaining_errors: + return + + m = dict(m) + if remaining_data: + m["data"] = remaining_data + elif "data" in m: + del m["data"] + if remaining_errors: + m["errors"] = remaining_errors + elif "errors" in m: + del m["errors"] + listeners_snapshot = list(self._listeners) async_tasks = [] @@ -311,6 +351,8 @@ def _get_task_by_name(self, name): async def handleClose(self): self.logger.debug("Handling close") + await self._cancel_pending_operations("Connection lost during operation") + reconnecting_task = self._tasks.get("Task_Reconnecting") if reconnecting_task is not None: if not reconnecting_task.done() and not reconnecting_task.cancelled(): @@ -380,6 +422,31 @@ async def reconnect(): ) self._tasks["Task_Reconnecting"] = self._reconnecting_task + async def _cancel_pending_operations(self, reason: str = "Connection closed"): + async with self._operations_lock: + operations_to_remove = [] + + for task_uuid, op in self._pending_operations.items(): + future = op.get("future") + state = op.get("state", OperationState.REGISTERED) + + if state == OperationState.REGISTERED: + if future and not future.done(): + future.set_exception(ConnectionError( + f"{reason} | TaskUUID: {task_uuid} | Request was not sent" + )) + operations_to_remove.append(task_uuid) + + elif state == OperationState.SENT: + op["state"] = OperationState.DISCONNECTED + if future and not future.done(): + future.set_exception(ConnectionError( + f"{reason} | TaskUUID: {task_uuid}" + )) + + for task_uuid in operations_to_remove: + del self._pending_operations[task_uuid] + async def heartBeat(self): if self._last_pong_time == 0.0: self._last_pong_time = time.perf_counter() diff --git a/runware/types.py b/runware/types.py index 254cb63..6058a88 100644 --- a/runware/types.py +++ b/runware/types.py @@ -107,6 +107,12 @@ class EDeliveryMethod(Enum): SYNC = "sync" ASYNC = "async" +class OperationState(Enum): + """State machine for pending operations.""" + REGISTERED = "registered" # Future created, request NOT sent + SENT = "sent" # send() completed successfully + DISCONNECTED = "disconnected" # Connection lost after SENT + # Define the types using Literal IOutputType = Literal["base64Data", "dataURI", "URL"] @@ -125,6 +131,24 @@ class IAsyncTaskResponse: taskUUID: str +@dataclass +class IGetResponseRequest: + taskUUID: str + numberResults: int = 1 + + +@dataclass +class IUploadImageRequest: + file: Union[File, str] + taskUUID: str + + +@dataclass +class IUploadMediaRequest: + media_url: str + taskUUID: str + + @dataclass class RunwareBaseType: apiKey: str @@ -1572,6 +1596,7 @@ class ITextInference: messages: List[ITextInferenceMessage] taskUUID: Optional[str] = None deliveryMethod: str = "sync" + numberResults: Optional[int] = 1 maxTokens: Optional[int] = None temperature: Optional[float] = None topP: Optional[float] = None @@ -1580,6 +1605,7 @@ class ITextInference: stopSequences: Optional[List[str]] = None includeCost: Optional[bool] = None providerSettings: Optional[TextProviderSettings] = None + webhookURL: Optional[str] = None @dataclass diff --git a/runware/utils.py b/runware/utils.py index 2574acf..f590ffe 100644 --- a/runware/utils.py +++ b/runware/utils.py @@ -51,7 +51,7 @@ # WebSocket connection health check timeout (milliseconds) # Maximum time to wait for pong response after sending ping # Used in: server.heartBeat() to detect connection loss -PING_TIMEOUT_DURATION = 10000 +PING_TIMEOUT_DURATION = 30000 # WebSocket ping interval (milliseconds) # How often to send ping messages to keep connection alive @@ -189,17 +189,24 @@ 480000 )) # Maximum polling attempts for video generation -# Number of polling iterations before timing out video generation -# Used in: _pollVideoResults() for video generation status checks +# Used in: _pollResults() for video inference / video caption / video background removal / video upscale MAX_POLLS_VIDEO_GENERATION = int(os.environ.get("RUNWARE_MAX_POLLS_VIDEO_GENERATION", 480)) # Maximum polling attempts for audio generation -# Number of polling iterations before timing out audio generation -# Used in: _pollAudioResults() for audio generation status checks +# Used in: _pollResults() for audio inference task type MAX_POLLS_AUDIO_GENERATION = int(os.environ.get("RUNWARE_MAX_POLLS_AUDIO_GENERATION", 240)) +# Maximum polling attempts for 3D generation +# Used in: _pollResults() for 3d inference task type +MAX_POLLS_3D_GENERATION = int(os.environ.get("RUNWARE_MAX_POLLS_3D_GENERATION", 480)) +# Maximum polling attempts for image generation +# Used in: _pollResults() for image inference task type +MAX_POLLS_IMAGE_GENERATION = int(os.environ.get("RUNWARE_MAX_POLLS_IMAGE_GENERATION", 480)) + +# Default / fallback max polls (e.g. when task type unknown) MAX_POLLS = int(os.environ.get("RUNWARE_MAX_POLLS", 480)) +MAX_CONCURRENT_REQUESTS = int(os.environ.get("RUNWARE_MAX_CONCURRENT_REQUESTS", 15)) class LISTEN_TO_IMAGES_KEY: REQUEST_IMAGES = "REQUEST_IMAGES"