diff --git a/runware/base.py b/runware/base.py index 721c295..d5b7043 100644 --- a/runware/base.py +++ b/runware/base.py @@ -944,7 +944,7 @@ async def _removeImageBackground( # Add provider settings if provided if removeImageBackgroundPayload.providerSettings: - self._addImageProviderSettings(task_params, removeImageBackgroundPayload) + self._addProviderSettings(task_params, removeImageBackgroundPayload) # Add safety settings if provided if removeImageBackgroundPayload.safety: @@ -1057,7 +1057,7 @@ async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> Union[List[IIma # Add provider settings if provided if upscaleGanPayload.providerSettings: - self._addImageProviderSettings(task_params, upscaleGanPayload) + self._addProviderSettings(task_params, upscaleGanPayload) # Add safety settings if provided if upscaleGanPayload.safety: @@ -1115,60 +1115,59 @@ async def _imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IIma await self.ensureConnection() return await self._vectorize(vectorizePayload) - async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]: - # Process the image from inputs - input_image = vectorizePayload.inputs.image - - if not input_image: - raise ValueError("Image is required in inputs for vectorize task") - - # Upload the image if it's a local file - image_uploaded = await self.uploadImage(input_image) - - if not image_uploaded or not image_uploaded.imageUUID: - return [] - - taskUUID = getUUID() - - # Create a dictionary with mandatory parameters - task_params = { + async def _processVectorizeInputs(self, vectorizePayload: IVectorize) -> None: + if not vectorizePayload.inputs or not vectorizePayload.inputs.image: + return + vectorizePayload.inputs.image = await process_image(vectorizePayload.inputs.image) + + def _buildVectorizeRequest(self, vectorizePayload: IVectorize) -> Dict[str, Any]: + request_object = { "taskType": ETaskType.IMAGE_VECTORIZE.value, - "taskUUID": taskUUID, - "inputs": { - "image": image_uploaded.imageUUID - } + "taskUUID": vectorizePayload.taskUUID, } - - # Add optional parameters if they are provided if vectorizePayload.model is not None: - task_params["model"] = vectorizePayload.model + request_object["model"] = vectorizePayload.model if vectorizePayload.outputType is not None: - task_params["outputType"] = vectorizePayload.outputType + request_object["outputType"] = vectorizePayload.outputType if vectorizePayload.outputFormat is not None: - task_params["outputFormat"] = vectorizePayload.outputFormat + request_object["outputFormat"] = vectorizePayload.outputFormat if vectorizePayload.includeCost: - task_params["includeCost"] = vectorizePayload.includeCost + request_object["includeCost"] = vectorizePayload.includeCost if vectorizePayload.webhookURL: - task_params["webhookURL"] = vectorizePayload.webhookURL - + request_object["webhookURL"] = vectorizePayload.webhookURL + if vectorizePayload.width is not None: + request_object["width"] = vectorizePayload.width + if vectorizePayload.height is not None: + request_object["height"] = vectorizePayload.height + if vectorizePayload.positivePrompt is not None: + request_object["positivePrompt"] = vectorizePayload.positivePrompt.strip() + self._addOptionalField(request_object, vectorizePayload.inputs) + self._addProviderSettings(request_object, vectorizePayload) + return request_object + + async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]: + await self._processVectorizeInputs(vectorizePayload) + vectorizePayload.taskUUID = vectorizePayload.taskUUID or getUUID() + task_params = self._buildVectorizeRequest(vectorizePayload) + # Send the task with all applicable parameters await self.send([task_params]) if vectorizePayload.webhookURL: return await self._handleWebhookAcknowledgment( - task_uuid=taskUUID, + task_uuid=vectorizePayload.taskUUID, task_type="vectorize", debug_key="image-vectorize-webhook" ) let_lis = await self.listenToImages( onPartialImages=None, - taskUUID=taskUUID, + taskUUID=vectorizePayload.taskUUID, groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES, ) images = await self.getSimililarImage( - taskUUID=taskUUID, + taskUUID=vectorizePayload.taskUUID, numberOfImages=1, shouldThrowError=True, lis=let_lis, @@ -2098,7 +2097,7 @@ def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]: request_object["stopSequences"] = requestText.stopSequences if requestText.includeCost is not None: request_object["includeCost"] = requestText.includeCost - self._addTextProviderSettings(request_object, requestText) + self._addProviderSettings(request_object, requestText) return request_object async def _requestText(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]: @@ -2204,7 +2203,7 @@ def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str self._addOptionalImageFields(request_object, requestImage) self._addImageSpecialFields(request_object, requestImage, control_net_data_dicts, instant_id_data, ip_adapters_data, ace_plus_plus_data, pulid_data) self._addOptionalField(request_object, requestImage.inputs) - self._addImageProviderSettings(request_object, requestImage) + self._addProviderSettings(request_object, requestImage) self._addOptionalField(request_object, requestImage.ultralytics) self._addOptionalField(request_object, requestImage.safety) self._addOptionalField(request_object, requestImage.settings) @@ -2316,17 +2315,23 @@ def _addSafetySettings(self, request_object: Dict[str, Any], safety: ISafety) -> if safety_dict: request_object["safety"] = safety_dict - def _addImageProviderSettings(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None: - if not requestImage.providerSettings: - return - provider_dict = requestImage.providerSettings.to_request_dict() - if provider_dict: - request_object["providerSettings"] = provider_dict - - def _addProviderSettings(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None: - if not requestVideo.providerSettings: + def _addProviderSettings( + self, + request_object: Dict[str, Any], + payload: Union[ + IImageInference, + IImageBackgroundRemoval, + IImageUpscale, + IVectorize, + IVideoInference, + IAudioInference, + ITextInference, + ], + ) -> None: + providerSettings = getattr(payload, "providerSettings", None) + if not providerSettings: return - provider_dict = requestVideo.providerSettings.to_request_dict() + provider_dict = providerSettings.to_request_dict() if provider_dict: request_object["providerSettings"] = provider_dict @@ -2691,7 +2696,7 @@ def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]: self._addOptionalField(request_object, requestAudio.speech) self._addOptionalField(request_object, requestAudio.audioSettings) self._addOptionalField(request_object, requestAudio.settings) - self._addAudioProviderSettings(request_object, requestAudio) + self._addProviderSettings(request_object, requestAudio) self._addOptionalField(request_object, requestAudio.inputs) self._addOptionalField(request_object, requestAudio.settings) @@ -2709,20 +2714,6 @@ def _addOptionalAudioFields(self, request_object: Dict[str, Any], requestAudio: request_object[field] = value - def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None: - if not requestAudio.providerSettings: - return - provider_dict = requestAudio.providerSettings.to_request_dict() - if provider_dict: - request_object["providerSettings"] = provider_dict - - def _addTextProviderSettings(self, request_object: Dict[str, Any], requestText: ITextInference) -> None: - if not requestText.providerSettings: - return - provider_dict = requestText.providerSettings.to_request_dict() - if provider_dict: - request_object["providerSettings"] = provider_dict - async def _handleInitialAudioResponse( self, task_uuid: str, diff --git a/runware/types.py b/runware/types.py index 254cb63..2e661fd 100644 --- a/runware/types.py +++ b/runware/types.py @@ -682,6 +682,8 @@ def request_key(self) -> str: | IRecraftProviderSettings ) +VectorizeProviderSettings = IRecraftProviderSettings + @dataclass class ISafety(SerializableMixin): tolerance: Optional[bool] = None @@ -1000,14 +1002,17 @@ class IImageBackgroundRemoval(IImageCaption): @dataclass class IVectorize: - - inputs: IInputs = None + inputs: Optional[IInputs] = None includeCost: bool = False taskUUID: Optional[str] = None - model: Optional[str] = None - outputType: Optional[IOutputType] = "URL" - outputFormat: Optional[IOutputFormat] = "SVG" + model: Optional[str] = None + outputType: Optional[IOutputType] = "URL" + outputFormat: Optional[IOutputFormat] = "SVG" webhookURL: Optional[str] = None + width: Optional[int] = None + height: Optional[int] = None + positivePrompt: Optional[str] = None + providerSettings: Optional[VectorizeProviderSettings] = None @dataclass