Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 94 additions & 45 deletions runware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
IAudioInference,
IFrameImage,
IAsyncTaskResponse,
IGetResponseType,
IUploadImageRequest,
IUploadMediaRequest,
IVectorize,
I3dInference,
I3d,
Expand Down Expand Up @@ -133,42 +136,52 @@ def __init__(
self._reconnect_lock = asyncio.Lock()


async def _retry_with_reconnect(self, func, *args, **kwargs):
async def _retry_with_reconnect(self, func, request_model, task_type: str):
task_uuid = getattr(request_model, "taskUUID", None)
delivery_method = getattr(request_model, "deliveryMethod", None)

last_error = None

for attempt in range(MAX_RETRY_ATTEMPTS):
try:
result = await func(*args, **kwargs)
result = await func(request_model)
self._reconnection_manager.on_connection_success()
return result

except Exception as e:
last_error = e

# When conflictTaskUUID: raise on first attempt, return async response on retry
if isinstance(e, RunwareAPIError) and e.code == "conflictTaskUUID":
if attempt == 0:
raise
else:
#
context = e.error_data.get("context", {})
task_type = context.get("taskType")
task_uuid = context.get("taskUUID") or e.error_data.get("taskUUID")
delivery_method_raw = context.get("deliveryMethod")
delivery_method_enum = EDeliveryMethod(delivery_method_raw) if isinstance(delivery_method_raw, str) else delivery_method_raw if delivery_method_raw else None

if task_type and task_uuid and delivery_method_enum is EDeliveryMethod.ASYNC:
return createAsyncTaskResponse({
"taskType": task_type,
"taskUUID": task_uuid
})

raise RunwareAPIError({
"code": "conflictTaskUUIDDuringRetries",
"message": "Lost connection during request submission",

delivery_method_enum = None
if delivery_method is not None:
delivery_method_enum = (
EDeliveryMethod(delivery_method)
if isinstance(delivery_method, str)
else delivery_method
)

if task_type and task_uuid and delivery_method_enum is EDeliveryMethod.ASYNC:
return createAsyncTaskResponse({
"taskType": task_type,
"taskUUID": task_uuid
})

conflict_task_uuid = e.error_data.get("taskUUID") or task_uuid
if conflict_task_uuid:
number_results = getattr(request_model, "numberResults", 1) or 1
return await self._pollResults(
task_uuid=conflict_task_uuid,
number_results=number_results
)

raise RunwareAPIError({
"code": "conflictTaskUUIDDuringRetries",
"message": "Lost connection during request submission",
Comment on lines +165 to +182
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conflict handling logic has a potentially problematic fallback behavior. When a conflictTaskUUID error occurs on retry (attempt > 0), if the delivery method is not ASYNC or the required fields are missing, the code now falls back to polling for results (lines 172-178). This assumes the task was already submitted successfully, but a conflictTaskUUID error during a retry after a connection error could indicate the task submission failed or was never completed. Polling for a task that was never successfully submitted will likely fail or timeout. Consider whether this fallback is appropriate, or if it should still raise an error when the task submission status is uncertain.

Suggested change
if task_type and task_uuid and delivery_method_enum is EDeliveryMethod.ASYNC:
return createAsyncTaskResponse({
"taskType": task_type,
"taskUUID": task_uuid
})
conflict_task_uuid = e.error_data.get("taskUUID") or task_uuid
if conflict_task_uuid:
number_results = getattr(request_model, "numberResults", 1) or 1
return await self._pollResults(
task_uuid=conflict_task_uuid,
number_results=number_results
)
raise RunwareAPIError({
"code": "conflictTaskUUIDDuringRetries",
"message": "Lost connection during request submission",
# For ASYNC delivery, return an async task handle if we have a known task UUID.
if task_type and task_uuid and delivery_method_enum is EDeliveryMethod.ASYNC:
return createAsyncTaskResponse({
"taskType": task_type,
"taskUUID": task_uuid
})
# For non-ASYNC or unknown delivery methods, only poll when the server explicitly
# reports a taskUUID in the error data. Otherwise, the submission status is
# uncertain and we should not assume the task exists.
error_data = getattr(e, "error_data", {}) or {}
conflict_task_uuid = error_data.get("taskUUID")
if conflict_task_uuid:
number_results = getattr(request_model, "numberResults", 1) or 1
return await self._pollResults(
task_uuid=conflict_task_uuid,
number_results=number_results
)
raise RunwareAPIError({
"code": "conflictTaskUUIDDuringRetries",
"message": "Lost connection during request submission; "
"unable to confirm whether the task was created.",

Copilot uses AI. Check for mistakes.
"taskUUID": task_uuid
})

if not isinstance(e, ConnectionError):
raise
Expand Down Expand Up @@ -330,7 +343,9 @@ def handle_connection_response(self, m):
self._invalidAPIkey = None

async def photoMaker(self, requestPhotoMaker: IPhotoMaker) -> Union[List[IImage], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._photoMaker, requestPhotoMaker)
return await self._retry_with_reconnect(
self._photoMaker, requestPhotoMaker, task_type=ETaskType.PHOTO_MAKER.value
)

async def _photoMaker(self, requestPhotoMaker: IPhotoMaker) -> Union[List[IImage], IAsyncTaskResponse]:
retry_count = 0
Expand Down Expand Up @@ -438,7 +453,9 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool:
async def imageInference(
self, requestImage: IImageInference
) -> Union[List[IImage], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._imageInference, requestImage)
return await self._retry_with_reconnect(
self._imageInference, requestImage, task_type=ETaskType.IMAGE_INFERENCE.value
)

async def _imageInference(
self, requestImage: IImageInference
Expand Down Expand Up @@ -624,7 +641,9 @@ async def _requestImages(
# return images

async def imageCaption(self, requestImageToText: IImageCaption) -> Union[IImageToText, IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._imageCaption, requestImageToText)
return await self._retry_with_reconnect(
self._imageCaption, requestImageToText, task_type=ETaskType.IMAGE_CAPTION.value
)

async def _imageCaption(self, requestImageToText: IImageCaption) -> Union[IImageToText, IAsyncTaskResponse]:
await self.ensureConnection()
Expand Down Expand Up @@ -733,7 +752,9 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool:
return None

async def videoCaption(self, requestVideoCaption: IVideoCaption) -> Union[List[IVideoToText], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._videoCaption, requestVideoCaption)
return await self._retry_with_reconnect(
self._videoCaption, requestVideoCaption, task_type=ETaskType.VIDEO_CAPTION.value
)

async def _videoCaption(self, requestVideoCaption: IVideoCaption) -> Union[List[IVideoToText], IAsyncTaskResponse]:
await self.ensureConnection()
Expand Down Expand Up @@ -776,7 +797,10 @@ async def _requestVideoCaption(
)

async def videoBackgroundRemoval(self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval) -> Union[List[IVideo], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._videoBackgroundRemoval, requestVideoBackgroundRemoval)
return await self._retry_with_reconnect(
self._videoBackgroundRemoval, requestVideoBackgroundRemoval,
task_type=ETaskType.VIDEO_BACKGROUND_REMOVAL.value
)

async def _videoBackgroundRemoval(self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval) -> Union[List[IVideo], IAsyncTaskResponse]:
await self.ensureConnection()
Expand Down Expand Up @@ -832,7 +856,9 @@ async def _requestVideoBackgroundRemoval(
)

async def videoUpscale(self, requestVideoUpscale: IVideoUpscale) -> Union[List[IVideo], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._videoUpscale, requestVideoUpscale)
return await self._retry_with_reconnect(
self._videoUpscale, requestVideoUpscale, task_type=ETaskType.VIDEO_UPSCALE.value
)

async def _videoUpscale(self, requestVideoUpscale: IVideoUpscale) -> Union[List[IVideo], IAsyncTaskResponse]:
await self.ensureConnection()
Expand Down Expand Up @@ -992,7 +1018,9 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool:
return image_list

async def imageUpscale(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._imageUpscale, upscaleGanPayload)
return await self._retry_with_reconnect(
self._imageUpscale, upscaleGanPayload, task_type=ETaskType.IMAGE_UPSCALE.value
)

async def _imageUpscale(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]:
await self.ensureConnection()
Expand Down Expand Up @@ -1105,7 +1133,9 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool:
return image_list

async def imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._imageVectorize, vectorizePayload)
return await self._retry_with_reconnect(
self._imageVectorize, vectorizePayload, task_type=ETaskType.IMAGE_VECTORIZE.value
)

async def _imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]:
await self.ensureConnection()
Expand Down Expand Up @@ -1270,12 +1300,16 @@ async def check(resolve: Any, reject: Any, *args: Any) -> bool:
return list(set(enhanced_prompts))

async def uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]:
return await self._retry_with_reconnect(self._uploadImage, file)
request = IUploadImageRequest(file=file, taskUUID=getUUID())
return await self._retry_with_reconnect(
self._uploadImage, request, task_type=ETaskType.IMAGE_UPLOAD.value
)

async def _uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]:
async def _uploadImage(self, request: IUploadImageRequest) -> Optional[UploadImageType]:
await self.ensureConnection()

task_uuid = getUUID()
file = request.file
task_uuid = request.taskUUID
local_file = True

if isinstance(file, str):
Expand Down Expand Up @@ -1347,12 +1381,16 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool:
return image

async def uploadMedia(self, media_url: str) -> Optional[MediaStorageType]:
return await self._retry_with_reconnect(self._uploadMedia, media_url)
request = IUploadMediaRequest(media_url=media_url, taskUUID=getUUID())
return await self._retry_with_reconnect(
self._uploadMedia, request, task_type=ETaskType.MEDIA_STORAGE.value
)

async def _uploadMedia(self, media_url: str) -> Optional[MediaStorageType]:
async def _uploadMedia(self, request: IUploadMediaRequest) -> Optional[MediaStorageType]:
await self.ensureConnection()

task_uuid = getUUID()
media_url = request.media_url
task_uuid = request.taskUUID
media_data = media_url

if isinstance(media_url, str):
Expand Down Expand Up @@ -1810,10 +1848,14 @@ async def check(resolve: callable, reject: callable, *args: Any) -> bool:
async def modelUpload(
self, requestModel: IUploadModelBaseType
) -> Optional[IUploadModelResponse]:
return await self._retry_with_reconnect(self._modelUpload, requestModel)
return await self._retry_with_reconnect(
self._modelUpload, requestModel, task_type=ETaskType.MODEL_UPLOAD.value
)

async def modelSearch(self, payload: IModelSearch) -> IModelSearchResponse:
return await self._retry_with_reconnect(self._modelSearch, payload)
return await self._retry_with_reconnect(
self._modelSearch, payload, task_type=ETaskType.MODEL_SEARCH.value
)

async def _modelSearch(self, payload: IModelSearch) -> IModelSearchResponse:
try:
Expand Down Expand Up @@ -1867,14 +1909,18 @@ async def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
raise RunwareAPIError({"message": str(e)})

async def videoInference(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._videoInference, requestVideo)
return await self._retry_with_reconnect(
self._videoInference, requestVideo, task_type=ETaskType.VIDEO_INFERENCE.value
)

async def _videoInference(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]:
await self.ensureConnection()
return await self._requestVideo(requestVideo)

async def inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._inference3d, request3d)
return await self._retry_with_reconnect(
self._inference3d, request3d, task_type=ETaskType.INFERENCE_3D.value
)

async def _inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTaskResponse]:
await self.ensureConnection()
Expand All @@ -1885,18 +1931,19 @@ async def getResponse(
taskUUID: str,
numberResults: Optional[int] = 1,
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d]]:
return await self._retry_with_reconnect(self._getResponse, taskUUID, numberResults)
request = IGetResponseType(taskUUID=taskUUID, numberResults=numberResults or 1)
return await self._retry_with_reconnect(
self._getResponse, request, task_type=ETaskType.GET_RESPONSE.value
)

async def _getResponse(
self,
taskUUID: str,
numberResults: Optional[int] = 1,
request: IGetResponseType,
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d]]:
await self.ensureConnection()

return await self._pollResults(
task_uuid=taskUUID,
number_results=numberResults,
task_uuid=request.taskUUID,
number_results=request.numberResults,
)

async def _requestVideo(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]:
Expand Down Expand Up @@ -2519,7 +2566,9 @@ def _hasPendingResults(self, responses: List[Dict[str, Any]]) -> bool:
return any(response.get("status") == "processing" for response in responses)

async def audioInference(self, requestAudio: IAudioInference) -> Union[List[IAudio], IAsyncTaskResponse]:
return await self._retry_with_reconnect(self._audioInference, requestAudio)
return await self._retry_with_reconnect(
self._audioInference, requestAudio, task_type=ETaskType.AUDIO_INFERENCE.value
)

async def _audioInference(self, requestAudio: IAudioInference) -> Union[List[IAudio], IAsyncTaskResponse]:
await self.ensureConnection()
Expand Down
18 changes: 18 additions & 0 deletions runware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,24 @@ class IAsyncTaskResponse:
taskUUID: str


@dataclass
class IGetResponseType:
taskUUID: str
numberResults: int = 1


@dataclass
class IUploadImageRequest:
file: Union[File, str]
taskUUID: str


@dataclass
class IUploadMediaRequest:
media_url: str
taskUUID: str


@dataclass
class RunwareBaseType:
apiKey: str
Expand Down
Loading