Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 52 additions & 61 deletions runware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ async def _removeImageBackground(

# Add provider settings if provided
if removeImageBackgroundPayload.providerSettings:
self._addImageProviderSettings(task_params, removeImageBackgroundPayload)
self._addProviderSettings(task_params, removeImageBackgroundPayload)

# Add safety settings if provided
if removeImageBackgroundPayload.safety:
Expand Down Expand Up @@ -1057,7 +1057,7 @@ async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> Union[List[IIma

# Add provider settings if provided
if upscaleGanPayload.providerSettings:
self._addImageProviderSettings(task_params, upscaleGanPayload)
self._addProviderSettings(task_params, upscaleGanPayload)

# Add safety settings if provided
if upscaleGanPayload.safety:
Expand Down Expand Up @@ -1115,60 +1115,59 @@ async def _imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IIma
await self.ensureConnection()
return await self._vectorize(vectorizePayload)

async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]:
# Process the image from inputs
input_image = vectorizePayload.inputs.image

if not input_image:
raise ValueError("Image is required in inputs for vectorize task")

# Upload the image if it's a local file
image_uploaded = await self.uploadImage(input_image)

if not image_uploaded or not image_uploaded.imageUUID:
return []

taskUUID = getUUID()

# Create a dictionary with mandatory parameters
task_params = {
async def _processVectorizeInputs(self, vectorizePayload: IVectorize) -> None:
if not vectorizePayload.inputs or not vectorizePayload.inputs.image:
return
vectorizePayload.inputs.image = await process_image(vectorizePayload.inputs.image)

def _buildVectorizeRequest(self, vectorizePayload: IVectorize) -> Dict[str, Any]:
request_object = {
"taskType": ETaskType.IMAGE_VECTORIZE.value,
"taskUUID": taskUUID,
"inputs": {
"image": image_uploaded.imageUUID
}
"taskUUID": vectorizePayload.taskUUID,
}

# Add optional parameters if they are provided
if vectorizePayload.model is not None:
task_params["model"] = vectorizePayload.model
request_object["model"] = vectorizePayload.model
if vectorizePayload.outputType is not None:
task_params["outputType"] = vectorizePayload.outputType
request_object["outputType"] = vectorizePayload.outputType
if vectorizePayload.outputFormat is not None:
task_params["outputFormat"] = vectorizePayload.outputFormat
request_object["outputFormat"] = vectorizePayload.outputFormat
if vectorizePayload.includeCost:
task_params["includeCost"] = vectorizePayload.includeCost
request_object["includeCost"] = vectorizePayload.includeCost
if vectorizePayload.webhookURL:
task_params["webhookURL"] = vectorizePayload.webhookURL

request_object["webhookURL"] = vectorizePayload.webhookURL
if vectorizePayload.width is not None:
request_object["width"] = vectorizePayload.width
if vectorizePayload.height is not None:
request_object["height"] = vectorizePayload.height
if vectorizePayload.positivePrompt is not None:
request_object["positivePrompt"] = vectorizePayload.positivePrompt.strip()
self._addOptionalField(request_object, vectorizePayload.inputs)
self._addProviderSettings(request_object, vectorizePayload)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use smth like:

def build_request(mapping: dict[str, Any]) -> dict[str, Any]:
    """Build request dict, excluding None values."""
    return {k: v for k, v in mapping.items() if v is not None}

return request_object

async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]:
await self._processVectorizeInputs(vectorizePayload)
vectorizePayload.taskUUID = vectorizePayload.taskUUID or getUUID()
task_params = self._buildVectorizeRequest(vectorizePayload)

# Send the task with all applicable parameters
await self.send([task_params])

if vectorizePayload.webhookURL:
return await self._handleWebhookAcknowledgment(
task_uuid=taskUUID,
task_uuid=vectorizePayload.taskUUID,
task_type="vectorize",
debug_key="image-vectorize-webhook"
)

let_lis = await self.listenToImages(
onPartialImages=None,
taskUUID=taskUUID,
taskUUID=vectorizePayload.taskUUID,
groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES,
)

images = await self.getSimililarImage(
taskUUID=taskUUID,
taskUUID=vectorizePayload.taskUUID,
numberOfImages=1,
shouldThrowError=True,
lis=let_lis,
Expand Down Expand Up @@ -2098,7 +2097,7 @@ def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]:
request_object["stopSequences"] = requestText.stopSequences
if requestText.includeCost is not None:
request_object["includeCost"] = requestText.includeCost
self._addTextProviderSettings(request_object, requestText)
self._addProviderSettings(request_object, requestText)
return request_object

async def _requestText(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]:
Expand Down Expand Up @@ -2204,7 +2203,7 @@ def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str
self._addOptionalImageFields(request_object, requestImage)
self._addImageSpecialFields(request_object, requestImage, control_net_data_dicts, instant_id_data, ip_adapters_data, ace_plus_plus_data, pulid_data)
self._addOptionalField(request_object, requestImage.inputs)
self._addImageProviderSettings(request_object, requestImage)
self._addProviderSettings(request_object, requestImage)
self._addOptionalField(request_object, requestImage.ultralytics)
self._addOptionalField(request_object, requestImage.safety)
self._addOptionalField(request_object, requestImage.settings)
Expand Down Expand Up @@ -2316,17 +2315,23 @@ def _addSafetySettings(self, request_object: Dict[str, Any], safety: ISafety) ->
if safety_dict:
request_object["safety"] = safety_dict

def _addImageProviderSettings(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None:
if not requestImage.providerSettings:
return
provider_dict = requestImage.providerSettings.to_request_dict()
if provider_dict:
request_object["providerSettings"] = provider_dict

def _addProviderSettings(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
if not requestVideo.providerSettings:
def _addProviderSettings(
self,
request_object: Dict[str, Any],
payload: Union[
IImageInference,
IImageBackgroundRemoval,
IImageUpscale,
IVectorize,
IVideoInference,
IAudioInference,
ITextInference,
],
) -> None:
providerSettings = getattr(payload, "providerSettings", None)
if not providerSettings:
return
provider_dict = requestVideo.providerSettings.to_request_dict()
provider_dict = providerSettings.to_request_dict()
if provider_dict:
request_object["providerSettings"] = provider_dict

Expand Down Expand Up @@ -2691,7 +2696,7 @@ def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]:
self._addOptionalField(request_object, requestAudio.speech)
self._addOptionalField(request_object, requestAudio.audioSettings)
self._addOptionalField(request_object, requestAudio.settings)
self._addAudioProviderSettings(request_object, requestAudio)
self._addProviderSettings(request_object, requestAudio)
self._addOptionalField(request_object, requestAudio.inputs)
self._addOptionalField(request_object, requestAudio.settings)

Expand All @@ -2709,20 +2714,6 @@ def _addOptionalAudioFields(self, request_object: Dict[str, Any], requestAudio:
request_object[field] = value


def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None:
if not requestAudio.providerSettings:
return
provider_dict = requestAudio.providerSettings.to_request_dict()
if provider_dict:
request_object["providerSettings"] = provider_dict

def _addTextProviderSettings(self, request_object: Dict[str, Any], requestText: ITextInference) -> None:
if not requestText.providerSettings:
return
provider_dict = requestText.providerSettings.to_request_dict()
if provider_dict:
request_object["providerSettings"] = provider_dict

async def _handleInitialAudioResponse(
self,
task_uuid: str,
Expand Down
15 changes: 10 additions & 5 deletions runware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,8 @@ def request_key(self) -> str:
| IRecraftProviderSettings
)

VectorizeProviderSettings = IRecraftProviderSettings

@dataclass
class ISafety(SerializableMixin):
tolerance: Optional[bool] = None
Expand Down Expand Up @@ -1000,14 +1002,17 @@ class IImageBackgroundRemoval(IImageCaption):

@dataclass
class IVectorize:

inputs: IInputs = None
inputs: Optional[IInputs] = None
includeCost: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only one field is not optional?

taskUUID: Optional[str] = None
model: Optional[str] = None
outputType: Optional[IOutputType] = "URL"
outputFormat: Optional[IOutputFormat] = "SVG"
model: Optional[str] = None
outputType: Optional[IOutputType] = "URL"
outputFormat: Optional[IOutputFormat] = "SVG"
webhookURL: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
positivePrompt: Optional[str] = None
providerSettings: Optional[VectorizeProviderSettings] = None


@dataclass
Expand Down