From 6e1bb94bd214d2798707698b4795b7ea95055385 Mon Sep 17 00:00:00 2001 From: Sirshendu Ganguly Date: Thu, 5 Feb 2026 23:33:51 -0500 Subject: [PATCH 1/2] remove context dependency --- runware/base.py | 125 +++++++++++++++++++++++++++++++---------------- runware/types.py | 6 +++ 2 files changed, 90 insertions(+), 41 deletions(-) diff --git a/runware/base.py b/runware/base.py index 57a99e0..72f89d0 100644 --- a/runware/base.py +++ b/runware/base.py @@ -49,6 +49,7 @@ IAudioInference, IFrameImage, IAsyncTaskResponse, + IGetResponseType, IVectorize, I3dInference, I3d, @@ -133,42 +134,52 @@ def __init__( self._reconnect_lock = asyncio.Lock() - async def _retry_with_reconnect(self, func, *args, **kwargs): + async def _retry_with_reconnect(self, func, request_model, task_type: str): + task_uuid = getattr(request_model, "taskUUID", None) + delivery_method = getattr(request_model, "deliveryMethod", None) last_error = None for attempt in range(MAX_RETRY_ATTEMPTS): try: - result = await func(*args, **kwargs) + result = await func(request_model) self._reconnection_manager.on_connection_success() return result except Exception as e: last_error = e - # When conflictTaskUUID: raise on first attempt, return async response on retry if isinstance(e, RunwareAPIError) and e.code == "conflictTaskUUID": if attempt == 0: raise - else: - # - context = e.error_data.get("context", {}) - task_type = context.get("taskType") - task_uuid = context.get("taskUUID") or e.error_data.get("taskUUID") - delivery_method_raw = context.get("deliveryMethod") - delivery_method_enum = EDeliveryMethod(delivery_method_raw) if isinstance(delivery_method_raw, str) else delivery_method_raw if delivery_method_raw else None - - if task_type and task_uuid and delivery_method_enum is EDeliveryMethod.ASYNC: - return createAsyncTaskResponse({ - "taskType": task_type, - "taskUUID": task_uuid - }) - - raise RunwareAPIError({ - "code": "conflictTaskUUIDDuringRetries", - "message": "Lost connection during request submission", + + delivery_method_enum = None + if delivery_method is not None: + delivery_method_enum = ( + EDeliveryMethod(delivery_method) + if isinstance(delivery_method, str) + else delivery_method + ) + + if task_type and task_uuid and delivery_method_enum is EDeliveryMethod.ASYNC: + return createAsyncTaskResponse({ + "taskType": task_type, "taskUUID": task_uuid }) + + conflict_task_uuid = e.error_data.get("taskUUID") or task_uuid + if conflict_task_uuid: + number_results = getattr(request_model, "numberResults", 1) or 1 + return await self._pollResults( + task_uuid=conflict_task_uuid, + number_results=number_results + ) + + raise RunwareAPIError({ + "code": "conflictTaskUUIDDuringRetries", + "message": "Lost connection during request submission", + "taskUUID": task_uuid + }) if not isinstance(e, ConnectionError): raise @@ -330,7 +341,9 @@ def handle_connection_response(self, m): self._invalidAPIkey = None async def photoMaker(self, requestPhotoMaker: IPhotoMaker) -> Union[List[IImage], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._photoMaker, requestPhotoMaker) + return await self._retry_with_reconnect( + self._photoMaker, requestPhotoMaker, task_type=ETaskType.PHOTO_MAKER.value + ) async def _photoMaker(self, requestPhotoMaker: IPhotoMaker) -> Union[List[IImage], IAsyncTaskResponse]: retry_count = 0 @@ -438,7 +451,9 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool: async def imageInference( self, requestImage: IImageInference ) -> Union[List[IImage], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._imageInference, requestImage) + return await self._retry_with_reconnect( + self._imageInference, requestImage, task_type=ETaskType.IMAGE_INFERENCE.value + ) async def _imageInference( self, requestImage: IImageInference @@ -624,7 +639,9 @@ async def _requestImages( # return images async def imageCaption(self, requestImageToText: IImageCaption) -> Union[IImageToText, IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._imageCaption, requestImageToText) + return await self._retry_with_reconnect( + self._imageCaption, requestImageToText, task_type=ETaskType.IMAGE_CAPTION.value + ) async def _imageCaption(self, requestImageToText: IImageCaption) -> Union[IImageToText, IAsyncTaskResponse]: await self.ensureConnection() @@ -733,7 +750,9 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool: return None async def videoCaption(self, requestVideoCaption: IVideoCaption) -> Union[List[IVideoToText], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._videoCaption, requestVideoCaption) + return await self._retry_with_reconnect( + self._videoCaption, requestVideoCaption, task_type=ETaskType.VIDEO_CAPTION.value + ) async def _videoCaption(self, requestVideoCaption: IVideoCaption) -> Union[List[IVideoToText], IAsyncTaskResponse]: await self.ensureConnection() @@ -776,7 +795,10 @@ async def _requestVideoCaption( ) async def videoBackgroundRemoval(self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval) -> Union[List[IVideo], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._videoBackgroundRemoval, requestVideoBackgroundRemoval) + return await self._retry_with_reconnect( + self._videoBackgroundRemoval, requestVideoBackgroundRemoval, + task_type=ETaskType.VIDEO_BACKGROUND_REMOVAL.value + ) async def _videoBackgroundRemoval(self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval) -> Union[List[IVideo], IAsyncTaskResponse]: await self.ensureConnection() @@ -832,7 +854,9 @@ async def _requestVideoBackgroundRemoval( ) async def videoUpscale(self, requestVideoUpscale: IVideoUpscale) -> Union[List[IVideo], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._videoUpscale, requestVideoUpscale) + return await self._retry_with_reconnect( + self._videoUpscale, requestVideoUpscale, task_type=ETaskType.VIDEO_UPSCALE.value + ) async def _videoUpscale(self, requestVideoUpscale: IVideoUpscale) -> Union[List[IVideo], IAsyncTaskResponse]: await self.ensureConnection() @@ -992,7 +1016,9 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool: return image_list async def imageUpscale(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._imageUpscale, upscaleGanPayload) + return await self._retry_with_reconnect( + self._imageUpscale, upscaleGanPayload, task_type=ETaskType.IMAGE_UPSCALE.value + ) async def _imageUpscale(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]: await self.ensureConnection() @@ -1105,7 +1131,9 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool: return image_list async def imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._imageVectorize, vectorizePayload) + return await self._retry_with_reconnect( + self._imageVectorize, vectorizePayload, task_type=ETaskType.IMAGE_VECTORIZE.value + ) async def _imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]: await self.ensureConnection() @@ -1270,7 +1298,9 @@ async def check(resolve: Any, reject: Any, *args: Any) -> bool: return list(set(enhanced_prompts)) async def uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]: - return await self._retry_with_reconnect(self._uploadImage, file) + return await self._retry_with_reconnect( + self._uploadImage, file, task_type=ETaskType.IMAGE_UPLOAD.value + ) async def _uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]: await self.ensureConnection() @@ -1347,7 +1377,9 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool: return image async def uploadMedia(self, media_url: str) -> Optional[MediaStorageType]: - return await self._retry_with_reconnect(self._uploadMedia, media_url) + return await self._retry_with_reconnect( + self._uploadMedia, media_url, task_type=ETaskType.MEDIA_STORAGE.value + ) async def _uploadMedia(self, media_url: str) -> Optional[MediaStorageType]: await self.ensureConnection() @@ -1810,10 +1842,14 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool: async def modelUpload( self, requestModel: IUploadModelBaseType ) -> Optional[IUploadModelResponse]: - return await self._retry_with_reconnect(self._modelUpload, requestModel) + return await self._retry_with_reconnect( + self._modelUpload, requestModel, task_type=ETaskType.MODEL_UPLOAD.value + ) async def modelSearch(self, payload: IModelSearch) -> IModelSearchResponse: - return await self._retry_with_reconnect(self._modelSearch, payload) + return await self._retry_with_reconnect( + self._modelSearch, payload, task_type=ETaskType.MODEL_SEARCH.value + ) async def _modelSearch(self, payload: IModelSearch) -> IModelSearchResponse: try: @@ -1867,14 +1903,18 @@ async def check(resolve: Callable, reject: Callable, *args: Any) -> bool: raise RunwareAPIError({"message": str(e)}) async def videoInference(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._videoInference, requestVideo) + return await self._retry_with_reconnect( + self._videoInference, requestVideo, task_type=ETaskType.VIDEO_INFERENCE.value + ) async def _videoInference(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]: await self.ensureConnection() return await self._requestVideo(requestVideo) async def inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._inference3d, request3d) + return await self._retry_with_reconnect( + self._inference3d, request3d, task_type=ETaskType.INFERENCE_3D.value + ) async def _inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTaskResponse]: await self.ensureConnection() @@ -1885,18 +1925,19 @@ async def getResponse( taskUUID: str, numberResults: Optional[int] = 1, ) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d]]: - return await self._retry_with_reconnect(self._getResponse, taskUUID, numberResults) + request = IGetResponseType(taskUUID=taskUUID, numberResults=numberResults or 1) + return await self._retry_with_reconnect( + self._getResponse, request, task_type=ETaskType.GET_RESPONSE.value + ) async def _getResponse( self, - taskUUID: str, - numberResults: Optional[int] = 1, + request: IGetResponseType, ) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d]]: await self.ensureConnection() - return await self._pollResults( - task_uuid=taskUUID, - number_results=numberResults, + task_uuid=request.taskUUID, + number_results=request.numberResults, ) async def _requestVideo(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]: @@ -2519,7 +2560,9 @@ def _hasPendingResults(self, responses: List[Dict[str, Any]]) -> bool: return any(response.get("status") == "processing" for response in responses) async def audioInference(self, requestAudio: IAudioInference) -> Union[List[IAudio], IAsyncTaskResponse]: - return await self._retry_with_reconnect(self._audioInference, requestAudio) + return await self._retry_with_reconnect( + self._audioInference, requestAudio, task_type=ETaskType.AUDIO_INFERENCE.value + ) async def _audioInference(self, requestAudio: IAudioInference) -> Union[List[IAudio], IAsyncTaskResponse]: await self.ensureConnection() diff --git a/runware/types.py b/runware/types.py index 759b010..3731c8b 100644 --- a/runware/types.py +++ b/runware/types.py @@ -124,6 +124,12 @@ class IAsyncTaskResponse: taskUUID: str +@dataclass +class IGetResponseType: + taskUUID: str + numberResults: int = 1 + + @dataclass class RunwareBaseType: apiKey: str From 9e843b69777e28dc10e8e462c2d40eb1b23367a8 Mon Sep 17 00:00:00 2001 From: Sirshendu Ganguly Date: Thu, 5 Feb 2026 23:50:28 -0500 Subject: [PATCH 2/2] fix imageUpload and mediaUpload --- runware/base.py | 18 ++++++++++++------ runware/types.py | 12 ++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/runware/base.py b/runware/base.py index 72f89d0..68f0511 100644 --- a/runware/base.py +++ b/runware/base.py @@ -50,6 +50,8 @@ IFrameImage, IAsyncTaskResponse, IGetResponseType, + IUploadImageRequest, + IUploadMediaRequest, IVectorize, I3dInference, I3d, @@ -1298,14 +1300,16 @@ async def check(resolve: Any, reject: Any, *args: Any) -> bool: return list(set(enhanced_prompts)) async def uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]: + request = IUploadImageRequest(file=file, taskUUID=getUUID()) return await self._retry_with_reconnect( - self._uploadImage, file, task_type=ETaskType.IMAGE_UPLOAD.value + self._uploadImage, request, task_type=ETaskType.IMAGE_UPLOAD.value ) - async def _uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]: + async def _uploadImage(self, request: IUploadImageRequest) -> Optional[UploadImageType]: await self.ensureConnection() - task_uuid = getUUID() + file = request.file + task_uuid = request.taskUUID local_file = True if isinstance(file, str): @@ -1377,14 +1381,16 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool: return image async def uploadMedia(self, media_url: str) -> Optional[MediaStorageType]: + request = IUploadMediaRequest(media_url=media_url, taskUUID=getUUID()) return await self._retry_with_reconnect( - self._uploadMedia, media_url, task_type=ETaskType.MEDIA_STORAGE.value + self._uploadMedia, request, task_type=ETaskType.MEDIA_STORAGE.value ) - async def _uploadMedia(self, media_url: str) -> Optional[MediaStorageType]: + async def _uploadMedia(self, request: IUploadMediaRequest) -> Optional[MediaStorageType]: await self.ensureConnection() - task_uuid = getUUID() + media_url = request.media_url + task_uuid = request.taskUUID media_data = media_url if isinstance(media_url, str): diff --git a/runware/types.py b/runware/types.py index 3731c8b..2cd8d0e 100644 --- a/runware/types.py +++ b/runware/types.py @@ -130,6 +130,18 @@ class IGetResponseType: numberResults: int = 1 +@dataclass +class IUploadImageRequest: + file: Union[File, str] + taskUUID: str + + +@dataclass +class IUploadMediaRequest: + media_url: str + taskUUID: str + + @dataclass class RunwareBaseType: apiKey: str