diff --git a/docs/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md similarity index 100% rename from docs/CODE_OF_CONDUCT.md rename to CODE_OF_CONDUCT.md diff --git a/docs/CONTRIBUTING.md b/CONTRIBUTING.md similarity index 100% rename from docs/CONTRIBUTING.md rename to CONTRIBUTING.md diff --git a/docs/LICENSE.md b/docs/LICENSE.md deleted file mode 100644 index 6734491..0000000 --- a/docs/LICENSE.md +++ /dev/null @@ -1,21 +0,0 @@ -# MIT License - -Copyright (c) 2024 Runware Inc. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/docs/controlnets.md b/docs/controlnets.md deleted file mode 100644 index 3315d13..0000000 --- a/docs/controlnets.md +++ /dev/null @@ -1,87 +0,0 @@ -# ControlNet Guide - -ControlNet offers advanced capabilities for precise image processing through the use of guide images in specific formats, known as preprocessed images. This powerful tool enhances the control and customization of image generation, enabling users to achieve desired artistic styles and detailed adjustments effectively. - -Using ControlNet via our API simplifies the integration of guide images into your workflow. By leveraging the API, you can seamlessly incorporate preprocessed images and specify various parameters to tailor the image generation process to your exact requirements. - -## Request - -Our API always accepts an array of objects as input, where each object represents a specific task to be performed. The structure of the object varies depending on the type of the task. For this section, we will focus on the parameters related to the ControlNet preprocessing task. - -The following JSON snippet shows the basic structure of a request object. All properties are explained in detail in the next section: - -```json -[ - { - "taskType": "imageControlNetPreProcess", - "taskUUID": "3303f1be-b3dc-41a2-94df-ead00498db57", - "inputImage": "ff1d9a0b-b80f-4665-ae07-8055b99f4aea", - "preProcessorType": "canny", - "height": 512, - "width": 512 - } -] -``` - -### Parameters - -| Parameter | Type | Description | -|-------------------------------|---------------|-----------------------------------------------------------------------------------------------------------------------------------------| -| taskType | string | Must be set to "imageControlNetPreProcess" for this operation. | -| taskUUID | UUIDv4 string | Unique identifier for the task, used to match async responses. | -| inputImage | UUIDv4 string | The UUID of the image to be preprocessed. Use the Image Upload functionality to obtain this. | -| preProcessorType | string | The type of preprocessor to use. See list of available options below. | -| width | integer | Optional. Will resize the image to this width. | -| height | integer | Optional. Will resize the image to this height. | -| lowThresholdCanny | integer | Optional. Available only for 'canny' preprocessor. Defines the lower threshold. Recommended value is 100. | -| highThresholdCanny | integer | Optional. Available only for 'canny' preprocessor. Defines the high threshold. Recommended value is 200. | -| includeHandsAndFaceOpenPose | boolean | Optional. Available only for 'openpose' preprocessor. Includes hands and face in the pose outline. Defaults to false. | - -### Preprocessor Types - -Available preprocessor types are: - -```json -canny -depth -mlsd -normalbae -openpose -tile -seg -lineart -lineart_anime -shuffle -scribble -softedge -``` - -## Response - -The response to the ControlNet preprocessing request will have the following format: - -```json -{ - "data": [ - { - "taskType": "imageControlNetPreProcess", - "taskUUID": "3303f1be-b3dc-41a2-94df-ead00498db57", - "guideImageUUID": "b6a06b3b-ce32-4884-ad93-c5eca7937ba0", - "inputImageUUID": "ff1d9a0b-b80f-4665-ae07-8055b99f4aea", - "guideImageURL": "https://im.runware.ai/image/ws/0.5/ii/b6a06b3b-ce32-4884-ad93-c5eca7937ba0.jpg", - "cost": 0.0006 - } - ] -} -``` - -### Response Parameters - -| Parameter | Type | Description | -|----------------|---------------|------------------------------------------------------------------------------------------------| -| taskType | string | The type of task, in this case "imageControlNetPreProcess". | -| taskUUID | UUIDv4 string | The unique identifier matching the original request. | -| guideImageUUID | UUIDv4 string | Unique identifier for the preprocessed guide image. | -| inputImageUUID | UUIDv4 string | The UUID of the original input image. | -| guideImageURL | string | The URL of the preprocessed guide image. Can be used to visualize or display the image in UIs. | -| cost | number | The cost of the operation (if includeCost was set to true in the request). | diff --git a/docs/editing.md b/docs/editing.md deleted file mode 100644 index de58d83..0000000 --- a/docs/editing.md +++ /dev/null @@ -1,63 +0,0 @@ -# Image Upscaling - -Enhance the resolution and quality of your images using Runware's advanced upscaling API. Transform low-resolution images into sharp, high-definition visuals. -Upscaling refers to the process of enhancing the resolution and overall quality of images. This technique is particularly useful for improving the visual clarity and detail of lower-resolution images, making them suitable for various high-definition applications. - -## Request - -To upscale an image, send a request in the following format: - -```json -[ - { - "taskType": "imageUpscale", - "taskUUID": "19abad0d-6ec5-40a6-b7af-203775fa5b7f", - "inputImage": "fd613011-3872-4f37-b4aa-0d343c051a27", - "outputType": "URL", - "outputFormat": "JPG", - "upscaleFactor": 2 - } -] -``` - -### Parameters - -| Parameter | Type | Description | -|---------------|---------------|---------------------------------------------------------------------------------------------------------------| -| taskType | string | Must be set to "imageUpscale" for this operation. | -| taskUUID | UUIDv4 string | Unique identifier for the task, used to match async responses. | -| inputImage | UUIDv4 string | The UUID of the image to be upscaled. Can be from a previously uploaded or generated image. | -| upscaleFactor | integer | The level of upscaling to be performed. Can be 2, 3, or 4. Each will increase the image size by that factor. | -| outputFormat | string | Specifies the format of the output image. Supported formats are: PNG, JPG and WEBP. | -| includeCost | boolean | Optional. If set to true, the response will include the cost of the operation. | - -## Response - -Responses will be delivered in the following format: - -```json -{ - "data": [ - { - "taskType": "imageUpscale", - "taskUUID": "19abad0d-6ec5-40a6-b7af-203775fa5b7f", - "imageUUID": "e0b6ed2b-311d-4abc-aa01-8f3fdbdb8860", - "inputImageUUID": "fd613011-3872-4f37-b4aa-0d343c051a27", - "imageURL": "https://im.runware.ai/image/ws/0.5/ii/e0b6ed2b-311d-4abc-aa01-8f3fdbdb8860.jpg", - "cost": 0 - } - ] -} -``` - -### Response Parameters - -| Parameter | Type | Description | -|--------------|---------------|------------------------------------------------------------------------------------------------| -| taskType | string | The type of task, in this case "imageUpscale". | -| taskUUID | UUIDv4 string | The unique identifier matching the original request. | -| imageUUID | UUIDv4 string | The UUID of the upscaled image. | -| imageURL | string | The URL where the upscaled image can be downloaded from. | -| cost | number | The cost of the operation (included if `includeCost` was set to true). | - -Note: The NSFW filter occasionally returns false positives and very rarely false negatives. \ No newline at end of file diff --git a/docs/examples/example1.md b/docs/examples/example1.md deleted file mode 100644 index e69de29..0000000 diff --git a/docs/examples/example2.md b/docs/examples/example2.md deleted file mode 100644 index e69de29..0000000 diff --git a/docs/getting_started.md b/docs/getting_started.md deleted file mode 100644 index aec50f8..0000000 --- a/docs/getting_started.md +++ /dev/null @@ -1,113 +0,0 @@ -# WebSockets Endpoint, API Key, and Connections - -This guide explains how to authenticate, connect to, and interact with the Runware WebSocket API. - -## Authentication - -To interact with the Runware API, you need to authenticate your requests using an API key. This key is unique to your account and identifies you when making requests. - -- You can create multiple keys for different projects or environments (development, production, staging). -- Keys can have descriptions and can be revoked at any time. -- With the new teams feature, you can share keys with your team members. - -To create an API key: - -1. Sign up on Runware -2. Visit the "API Keys" page -3. Click "Create Key" -4. Fill in the details for your new key - -## WebSockets - -We currently support WebSocket connections as they are more efficient, faster, and less resource-intensive. Our WebSocket connections are designed to be easy to work with, as each response contains the request ID, allowing for easy matching of requests to responses. - -- The API uses a bidirectional protocol that encodes all messages as JSON objects. -- You can connect using one of our provided SDKs (Python, JavaScript, Go) or manually. -- If connecting manually, the endpoint URL is `wss://ws-api.runware.ai/v1`. - -## New Connections - -WebSocket connections are point-to-point, so there's no need for each request to contain an authentication header. Instead, the first request must always be an authentication request that includes the API key. - -### Authentication Request - -```json -[ - { - "taskType": "authentication", - "apiKey": "" - } -] -``` - -### Authentication Response - -On successful authentication, you'll receive a response with a `connectionSessionUUID`: - -```json -{ - "data": [ - { - "taskType": "authentication", - "connectionSessionUUID": "f40c2aeb-f8a7-4af7-a1ab-7594c9bf778f" - } - ] -} -``` - -In case of an error, you'll receive an object with an error message: - -```json -{ - "error": true, - "errorMessageContentId": 1212, - "errorId": 19, - "errorMessage": "Invalid api key" -} -``` - -## Keeping Connection Alive - -The WebSocket connection is kept open for 120 seconds from the last message exchanged. If you don't send any messages for 120 seconds, the connection will be closed automatically. - -To keep the connection active, you can send a `ping` message: - -```json -[ - { - "taskType": "ping", - "ping": true - } -] -``` - -The server will respond with a `pong`: - -```json -{ - "data": [ - { - "taskType": "ping", - "pong": true - } - ] -} -``` - -## Resuming Connections - -If any service, server, or network becomes unresponsive, all undelivered images or tasks are kept in a buffer memory for 120 seconds. You can reconnect and receive these messages by including the `connectionSessionUUID` in the authentication request: - -```json -[ - { - "taskType": "authentication", - "apiKey": "", - "connectionSessionUUID": "f40c2aeb-f8a7-4af7-a1ab-7594c9bf778f" - } -] -``` - -This means you don't need to resend the initial request; it will be delivered when reconnecting. SDK libraries handle reconnections automatically. - -After establishing a connection, you can send various tasks to the API, such as text-to-image, image-to-image, inpainting, upscaling, image-to-text, image upload, etc. \ No newline at end of file diff --git a/docs/image_to_image.md b/docs/image_to_image.md deleted file mode 100644 index e69de29..0000000 diff --git a/docs/text_to_image.md b/docs/text_to_image.md deleted file mode 100644 index 599c782..0000000 --- a/docs/text_to_image.md +++ /dev/null @@ -1,129 +0,0 @@ -# Image Inference API - -Generate images from text prompts or transform existing ones using Runware's API. This powerful feature allows you to create high-quality visuals, bringing creative ideas to life or enhancing existing images with new styles or subjects. - -## Introduction - -Image inference enables you to: - -1. **Text-to-Image**: Generate images from descriptive text prompts. -2. **Image-to-Image**: Transform existing images, controlling the strength of the transformation. -3. **Inpainting**: Replace parts of an image with new content. -4. **Outpainting**: Extend the boundaries of an image with new content. - -Advanced features include: - -- **ControlNet**: Precise control over image generation using additional input conditions. -- **LoRA**: Adapt models to specific styles or tasks. - -Our API is optimized for speed and efficiency, powered by our Sonic Inference Engine. - -## Request - -Requests are sent as an array of objects, each representing a specific task. Here's the basic structure for an image inference task: - -```json -[ - { - "taskType": "imageInference", - "taskUUID": "string", - "outputType": "string", - "outputFormat": "string", - "positivePrompt": "string", - "negativePrompt": "string", - "height": int, - "width": int, - "model": "string", - "steps": int, - "CFGScale": float, - "numberResults": int - } -] -``` - -### Parameters - -| Parameter | Type | Required | Description | -|------------------|-------------------------------|----------|------------------------------------------------------------------------------------------------| -| taskType | string | Yes | Must be set to "imageInference" for this task. | -| taskUUID | string (UUID v4) | Yes | Unique identifier for the task, used to match async responses. | -| outputType | string | No | Specifies the output format: "base64Data", "dataURI", or "URL" (default). | -| outputFormat | string | No | Specifies the image format: "JPG" (default), "PNG", or "WEBP". | -| positivePrompt | string | Yes | Text instruction guiding the image generation (4-2000 characters). | -| negativePrompt | string | No | Text instruction to avoid certain elements in the image (4-2000 characters). | -| height | integer | Yes | Height of the generated image (512-2048, must be divisible by 64). | -| width | integer | Yes | Width of the generated image (512-2048, must be divisible by 64). | -| model | string | Yes | AIR identifier of the model to use. | -| steps | integer | No | Number of inference steps (1-100, default 20). | -| CFGScale | float | No | Guidance scale for prompt adherence (0-30, default 7). | -| numberResults | integer | No | Number of images to generate (default 1). | - -Additional parameters: - -| Parameter | Type | Required | Description | -|---------------------|---------|----------|---------------------------------------------------------------------------------| -| uploadEndpoint | string | No | URL to upload the generated image using HTTP PUT. | -| checkNSFW | boolean | No | Enable NSFW content check (adds 0.1s to inference time). | -| includeCost | boolean | No | Include the cost of the operation in the response. | -| seedImage | string | No* | Image to use as a starting point (required for Image-to-Image, In/Outpainting). | -| maskImage | string | No* | Mask image for Inpainting/Outpainting (required for these operations). | -| strength | float | No | Influence of the seed image (0-1, default 0.8). | -| scheduler | string | No | Specify a different scheduler (default is model's own scheduler). | -| seed | integer | No | Seed for reproducible results (1-9223372036854776000). | -| clipSkip | integer | No | Number of CLIP layers to skip (0-2, default 0). | -| usePromptWeighting | boolean | No | Enable advanced prompt weighting (adds 0.2s to inference time). | - -### ControlNet - -To use ControlNet, include a `controlNet` array in your request with objects containing: - -| Parameter | Type | Required | Description | -|-----------|---------|----------|-----------------------------------------------------------| -| model | string | Yes | AIR identifier of the ControlNet model. | -| guideImage| string | Yes | Preprocessed guide image (UUID, data URI, base64, or URL).| -| weight | float | No | Weight of this ControlNet model (0-1). | -| startStep | integer | No | Step to start applying ControlNet. | -| endStep | integer | No | Step to stop applying ControlNet. | -| controlMode| string | No | "prompt", "controlnet", or "balanced". | - -### LoRA - -To use LoRA, include a `lora` array in your request with objects containing: - -| Parameter | Type | Required | Description | -|-----------|--------|----------|----------------------------------------| -| model | string | Yes | AIR identifier of the LoRA model. | -| weight | float | No | Weight of this LoRA model (default 1). | - -## Response - -The API returns results in the following format: - -```json -{ - "data": [ - { - "taskType": "imageInference", - "taskUUID": "a770f077-f413-47de-9dac-be0b26a35da6", - "imageUUID": "77da2d99-a6d3-44d9-b8c0-ae9fb06b6200", - "imageURL": "https://im.runware.ai/image/ws/0.5/ii/a770f077-f413-47de-9dac-be0b26a35da6.jpg", - "cost": 0.0013 - } - ] -} -``` - -### Response Parameters - -| Parameter | Type | Description | -|----------------|------------------|---------------------------------------------------------------------| -| taskType | string | Type of the task ("imageInference"). | -| taskUUID | string (UUID v4) | Unique identifier matching the original request. | -| imageUUID | string (UUID v4) | Unique identifier of the generated image. | -| imageURL | string | URL to download the image (if outputType is "URL"). | -| imageBase64Data| string | Base64-encoded image data (if outputType is "base64Data"). | -| imageDataURI | string | Data URI of the image (if outputType is "dataURI"). | -| NSFWContent | boolean | Indicates if the image was flagged as NSFW (if checkNSFW was true). | -| cost | float | Cost of the operation in USD (if includeCost was true). | - -Note: The API may return multiple images per message, as they are generated in parallel. diff --git a/docs/utilities.md b/docs/utilities.md deleted file mode 100644 index 7f712e7..0000000 --- a/docs/utilities.md +++ /dev/null @@ -1,27 +0,0 @@ -# Utilities - -The Runware SDK provides several utility functions to enhance your workflow and simplify common tasks. This section provides an overview of the available utilities. - -## Image Upload - -Images can be uploaded to be used as seed to reverse prompts and get image to text results. - -For detailed information on how to use the Image Upload utility, see the [Image Upload documentation](utilities_image_upload.md). - -## Image to Text - -Image to text allows you to upload an image and generate the prompt used to create similar images. - -To learn more about the Image to Text utility and how to use it in your code, refer to the [Image to Text documentation](utilities_image_to_text.md). - -## Prompt Enhancer - -Prompt enhancer can be used to add keywords to prompts that are meant to increase the quality or variety of results. - -For a comprehensive guide on using the Prompt Enhancer utility, see the [Prompt Enhancer documentation](utilities_prompt_enhancer.md). - ---- - -These utilities are designed to streamline your workflow and provide additional functionality to the core features of the Runware SDK. Each utility comes with its own detailed documentation, explaining how to use it effectively in your projects. - -If you have any questions or need further assistance with the utilities, please don't hesitate to [reach out](https://github.com/runware/sdk-python/issues) to our support team. diff --git a/docs/utilities_image_to_text.md b/docs/utilities_image_to_text.md deleted file mode 100644 index f84b623..0000000 --- a/docs/utilities_image_to_text.md +++ /dev/null @@ -1,52 +0,0 @@ -# Image to Text - -Image to text, also known as image captioning, allows you to obtain descriptive text prompts based on uploaded or previously generated images. This process is instrumental in generating textual descriptions that can be used to create additional images or provide detailed insights into visual content. - -## Request - -Image to text requests must have the following format: - -```json -[ - { - "taskType": "imageCaption", - "taskUUID": "f0a5574f-d653-47f1-ab42-e2c1631f1a47", - "inputImage": "5788104a-1ca7-4b7e-8a16-b27b57e86f87" - } -] -``` - -### Parameters - -| Parameter | Type | Description | -|-------------|--------------|---------------------------------------------------------------------------------------| -| taskType | string | Must be set to "imageCaption" for this operation. | -| taskUUID | UUIDv4 string | Unique identifier for the task, used to match async responses. | -| inputImage | UUIDv4 string | The UUID of the image to be analyzed. Can be from an uploaded or generated image. | -| includeCost | boolean | Optional. If set to true, the response will include the cost of the operation. | - -## Response - -Results will be delivered in the following format: - -```json -{ - "data": [ - { - "taskType": "imageCaption", - "taskUUID": "f0a5574f-d653-47f1-ab42-e2c1631f1a47", - "text": "arafed troll in the jungle with a backpack and a stick, cgi animation, cinematic movie image, gremlin, pixie character, nvidia promotional image, park background, with lots of scumbling, hollywood promotional image, on island, chesley, green fog, post-nuclear", - "cost": 0 - } - ] -} -``` - -### Response Parameters - -| Parameter | Type | Description | -|-----------|---------------|-----------------------------------------------------------------------| -| taskType | string | The type of task, in this case "imageCaption". | -| taskUUID | UUIDv4 string | The unique identifier matching the original request. | -| text | string | The resulting text prompt from analyzing the image. | -| cost | number | The cost of the operation (included if `includeCost` was set to true).| \ No newline at end of file diff --git a/docs/utilities_image_upload.md b/docs/utilities_image_upload.md deleted file mode 100644 index 630d59e..0000000 --- a/docs/utilities_image_upload.md +++ /dev/null @@ -1,51 +0,0 @@ -# Image Upload - -Image upload is necessary for using images as seeds for new image generation, or to run image-to-text operations and obtain prompts that would generate similar images. - -## Request - -Image upload requests must have the following format: - -```json -[ - { - "taskType": "imageUpload", - "taskUUID": "50836053-a0ee-4cf5-b9d6-ae7c5d140ada", - "image": "data:image/png;base64,iVBORw0KGgo..." - } -] -``` - -### Parameters - -| Parameter | Type | Description | -|-----------|--------------|---------------------------------------------------------------------------------------| -| taskType | string | Must be set to "imageUpload" for this operation. | -| taskUUID | UUIDv4 string | Unique identifier for the task, used to match async responses. | -| image | string | The image file in base64 format. Supported formats are: PNG, JPG, WEBP. | - -## Response - -The response to the image upload request will have the following format: - -```json -{ - "data": [ - { - "taskType": "imageUpload", - "taskUUID": "50836053-a0ee-4cf5-b9d6-ae7c5d140ada", - "imageUUID": "989ba605-1449-4e1e-b462-cd83ec9c1a67", - "imageURL": "https://im.runware.ai/image/ws/0.5/ii/989ba605-1449-4e1e-b462-cd83ec9c1a67.jpg" - } - ] -} -``` - -### Response Parameters - -| Parameter | Type | Description | -|-----------|---------------|------------------------------------------------------------------------------------------------| -| taskType | string | The type of task, in this case "imageUpload". | -| taskUUID | UUIDv4 string | The unique identifier matching the original request. | -| imageUUID | UUIDv4 string | Unique identifier for the uploaded image. Use this for referencing in other operations. | -| imageURL | string | The URL of the uploaded image. Can be used to visualize or display the image in UIs. | diff --git a/docs/utilities_prompt_enhancer.md b/docs/utilities_prompt_enhancer.md deleted file mode 100644 index 69bc557..0000000 --- a/docs/utilities_prompt_enhancer.md +++ /dev/null @@ -1,76 +0,0 @@ -# Prompt Enhancer (Magic Prompt) - -Prompt enhancing can be used to generate different or potentially improved results for a particular topic. It works by adding keywords to a given prompt. Note that enhancing a prompt does not always preserve the intended subject and does not guarantee improved results over the original prompt. - -## Request - -Prompt enhancing requests must have the following format: - -```json -[ - { - "taskType": "promptEnhance", - "taskUUID": "9da1a4ad-c3de-4470-905d-5be5c042f98a", - "prompt": "dog", - "promptMaxLength": 64, - "promptVersions": 4 - } -] -``` - -### Parameters - -| Parameter | Type | Description | -|-----------------|---------------|-----------------------------------------------------------------------------------------------------| -| taskType | string | Must be set to "promptEnhance" for this operation. | -| taskUUID | UUIDv4 string | Unique identifier for the task, used to match async responses. | -| prompt | string | The original prompt you want to enhance. | -| promptMaxLength | integer | Maximum length of the enhanced prompt. Value between 4 and 400. | -| promptVersions | integer | Number of enhanced prompt versions to generate. Value between 1 and 5. | -| includeCost | boolean | Optional. If set to true, the response will include the cost of the operation. | - -## Response - -Results will be delivered in the following format: - -```json -{ - "data": [ - { - "taskType": "promptEnhance", - "taskUUID": "9da1a4ad-c3de-4470-905d-5be5c042f98a", - "text": "dog, ilya kuvshinov, gaston bussiere, craig mullins, simon bisley, arthur rackham", - "cost": 0 - }, - { - "taskType": "promptEnhance", - "taskUUID": "9da1a4ad-c3de-4470-905d-5be5c042f98a", - "text": "dog, ilya kuvshinov, artgerm", - "cost": 0 - }, - { - "taskType": "promptEnhance", - "taskUUID": "9da1a4ad-c3de-4470-905d-5be5c042f98a", - "text": "dog, ilya kuvshinov, gaston bussiere, craig mullins, simon bisley", - "cost": 0 - }, - { - "taskType": "promptEnhance", - "taskUUID": "9da1a4ad-c3de-4470-905d-5be5c042f98a", - "text": "dog, ilya kuvshinov, artgerm, krenz cushart, greg rutkowski, pixiv. cinematic dramatic atmosphere, sharp focus, volumetric lighting, cinematic lighting, studio quality", - "cost": 0 - } - ] -} -``` - -The response contains an array of objects in the "data" field. The number of objects corresponds to the `promptVersions` requested. Each object represents an enhanced prompt suggestion. - -### Response Parameters - -| Parameter | Type | Description | -|-----------|---------------|-----------------------------------------------------------------------| -| taskType | string | The type of task, in this case "promptEnhance". | -| taskUUID | UUIDv4 string | The unique identifier matching the original request. | -| text | string | The enhanced prompt text. | -| cost | number | The cost of the operation (included if `includeCost` was set to true).| diff --git a/examples/01_basic_image_generation.py b/examples/01_basic_image_generation.py new file mode 100644 index 0000000..216b460 --- /dev/null +++ b/examples/01_basic_image_generation.py @@ -0,0 +1,130 @@ +""" +Basic Image Generation Example + +This example demonstrates how to generate images using the Runware SDK +with various configuration options and proper error handling. +""" + +import asyncio +import os +from typing import List + +from runware import IImage, IImageInference, Runware, RunwareError + + +async def basic_image_generation(): + """Generate a simple image with basic parameters.""" + + # Initialize client with API key from environment + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + # Connect to the Runware service + await runware.connect() + + # Create image generation request + request = IImageInference( + positivePrompt="A majestic mountain landscape at sunset with vibrant colors", + model="civitai:4384@128713", + numberResults=1, + height=1024, + width=1024, + negativePrompt="blurry, low quality, distorted", + steps=30, + CFGScale=7.5, + seed=42, + ) + + # Generate images + images: List[IImage] = await runware.imageInference(requestImage=request) + + # Process results + for i, image in enumerate(images): + print(f"Generated image {i + 1}:") + print(f" URL: {image.imageURL}") + print(f" UUID: {image.imageUUID}") + if image.seed: + print(f" Seed: {image.seed}") + if image.cost: + print(f" Cost: ${image.cost}") + + except RunwareError as e: + print(f"Runware API Error: {e}") + if hasattr(e, "code"): + print(f"Error Code: {e.code}") + except Exception as e: + print(f"Unexpected error: {e}") + finally: + # Always disconnect when done + await runware.disconnect() + + +async def batch_image_generation(): + """Generate multiple images with different configurations.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Generate multiple images with different aspect ratios + requests = [ + IImageInference( + positivePrompt="Portrait of a wise old wizard with a long beard", + model="civitai:4384@128713", + numberResults=1, + height=1024, + width=768, # Portrait orientation + negativePrompt="young, modern clothing", + ), + IImageInference( + positivePrompt="Futuristic cityscape with flying cars and neon lights", + model="civitai:4384@128713", + numberResults=1, + height=768, + width=1024, # Landscape orientation + negativePrompt="old, vintage, medieval", + ), + IImageInference( + positivePrompt="Cute cartoon animal in a magical forest", + model="civitai:4384@128713", + numberResults=2, + height=1024, + width=1024, # Square format + negativePrompt="realistic, dark, scary", + ), + ] + + # Process all requests concurrently + tasks = [runware.imageInference(requestImage=req) for req in requests] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process results + for i, result in enumerate(results): + if isinstance(result, Exception): + print(f"Request {i + 1} failed: {result}") + else: + images: List[IImage] = result + print(f"Request {i + 1} completed with {len(images)} images:") + for image in images: + print(f" Image URL: {image.imageURL}") + + except RunwareError as e: + print(f"Runware API Error: {e}") + except Exception as e: + print(f"Unexpected error: {e}") + finally: + await runware.disconnect() + + +async def main(): + """Run all basic image generation examples.""" + print("=== Basic Image Generation ===") + await basic_image_generation() + + print("\n=== Batch Image Generation ===") + await batch_image_generation() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/02_advanced_image_generation.py b/examples/02_advanced_image_generation.py new file mode 100644 index 0000000..068607c --- /dev/null +++ b/examples/02_advanced_image_generation.py @@ -0,0 +1,226 @@ +""" +Advanced Image Generation Example + +This example demonstrates advanced image generation features including: +- LoRA models for style enhancement +- ControlNet for guided generation +- Refiner models for quality improvement +- Accelerator options for faster inference +- Progress callbacks for real-time updates +""" + +import asyncio +import os +from typing import List + +from runware import ( + EControlMode, + IAcceleratorOptions, + IControlNetGeneral, + IImage, + IImageInference, + ILora, + IRefiner, + ProgressUpdate, + Runware, + RunwareError, +) +import time + + +def progress_callback(progress: ProgressUpdate): + """Handle progress updates during image generation.""" + print(f"Operation {progress.operation_id}: {progress.progress:.1%} complete") + if progress.message: + print(f" Status: {progress.message}") + if progress.partial_results: + print(f" Received {len(progress.partial_results)} partial results") + + +async def lora_enhanced_generation(): + """Generate images using LoRA models for enhanced styling.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Define LoRA models for style enhancement + lora_models = [ + ILora(model="civitai:58390@62833", weight=0.8), + ILora(model="civitai:42903@232848", weight=0.6), + ] + + request = IImageInference( + positivePrompt="masterpiece, best quality, 1girl, elegant dress, garden background, soft lighting", + model="civitai:36520@76907", + lora=lora_models, + numberResults=3, + height=1024, + width=768, + negativePrompt="worst quality, blurry, nsfw", + steps=35, + CFGScale=8.0, + outputFormat="PNG", + includeCost=True, + ) + + images: List[IImage] = await runware.imageInference( + requestImage=request, progress_callback=progress_callback + ) + + print("LoRA-enhanced images generated:") + for i, image in enumerate(images): + print(f" Image {i + 1}: {image.imageURL}") + if image.cost: + print(f" Cost: ${image.cost}") + + except RunwareError as e: + print(f"Error in LoRA generation: {e}") + finally: + await runware.disconnect() + + +async def controlnet_guided_generation(): + """Generate images using ControlNet for precise control.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Define ControlNet configuration + controlnet = IControlNetGeneral( + model="civitai:38784@44716", + guideImage="https://huggingface.co/datasets/mishig/sample_images/resolve/main/canny-edge.jpg", + weight=0.8, + startStep=0, + endStep=15, + controlMode=EControlMode.BALANCED, + ) + + request = IImageInference( + positivePrompt="beautiful anime character, detailed eyes, colorful hair", + model="civitai:4384@128713", + controlNet=[controlnet], + numberResults=1, + height=768, + width=768, + steps=30, + CFGScale=7.0, + seed=12345, + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("ControlNet-guided images:") + for image in images: + print(f" URL: {image.imageURL}") + print(f" Generated with seed: {image.seed}") + + except RunwareError as e: + print(f"Error in ControlNet generation: {e}") + finally: + await runware.disconnect() + + +async def refiner_enhanced_generation(): + """Generate images with refiner model for quality enhancement.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Configure refiner model + refiner = IRefiner( + model="civitai:101055@128080", # SDXL refiner + startStep=25, # Start refining after 25 steps + ) + + request = IImageInference( + positivePrompt="hyperrealistic portrait of a astronaut in space, detailed helmet reflection", + model="civitai:101055@128078", + refiner=refiner, + numberResults=1, + height=1024, + width=1024, + steps=40, # More steps for better quality with refiner + CFGScale=7.5, + outputFormat="PNG", + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("Refiner-enhanced images:") + for image in images: + print(f" High-quality URL: {image.imageURL}") + + except RunwareError as e: + print(f"Error in refiner generation: {e}") + finally: + await runware.disconnect() + + +async def fast_generation_with_accelerators(): + """Generate images quickly using accelerator options.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Configure accelerator options for speed + accelerator_options = IAcceleratorOptions( + teaCache=True, + teaCacheDistance=0.4, + cacheStartStep=5, + cacheStopStep=25, + ) + + request = IImageInference( + positivePrompt="vibrant fantasy landscape with magical creatures", + model="runware:100@1", # Flux model that supports acceleration + acceleratorOptions=accelerator_options, + numberResults=1, + height=1024, + width=1024, + steps=28, + CFGScale=3.5, + ) + + start_time = time.time() + + images: List[IImage] = await runware.imageInference( + requestImage=request, progress_callback=progress_callback + ) + + generation_time = time.time() - start_time + + print(f"Fast generation completed in {generation_time:.2f} seconds:") + for image in images: + print(f" URL: {image.imageURL}") + + except RunwareError as e: + print(f"Error in fast generation: {e}") + finally: + await runware.disconnect() + + +async def main(): + """Run all advanced image generation examples.""" + print("=== LoRA Enhanced Generation ===") + await lora_enhanced_generation() + + print("\n=== ControlNet Guided Generation ===") + await controlnet_guided_generation() + + print("\n=== Refiner Enhanced Generation ===") + await refiner_enhanced_generation() + + print("\n=== Fast Generation with Accelerators ===") + await fast_generation_with_accelerators() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/03_image_editing.py b/examples/03_image_editing.py new file mode 100644 index 0000000..775edbe --- /dev/null +++ b/examples/03_image_editing.py @@ -0,0 +1,259 @@ +""" +Image Editing Operations Example + +This example demonstrates various image editing capabilities: +- Image captioning for description generation +- Background removal with custom settings +- Image upscaling for resolution enhancement +- File upload and processing workflows +""" + +import asyncio +import os +from typing import List + +from runware import ( + IBackgroundRemovalSettings, + IImage, + IImageBackgroundRemoval, + IImageCaption, + IImageToText, + IImageUpscale, + Runware, + RunwareError, +) + + +async def image_captioning(): + """Generate descriptive captions for images.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Test image URL + image_url = "https://img.freepik.com/free-photo/macro-picture-red-leaf-lights-against-black-background_181624-32636.jpg" + + # Create captioning request + caption_request = IImageCaption(inputImage=image_url, includeCost=True) + + # Generate caption + result: IImageToText = await runware.imageCaption( + requestImageToText=caption_request + ) + + print("Image Caption Analysis:") + print(f" Image: {image_url}") + print(f" Description: {result.text}") + if result.cost: + print(f" Processing cost: ${result.cost}") + + except RunwareError as e: + print(f"Error in image captioning: {e}") + finally: + await runware.disconnect() + + +async def background_removal_basic(): + """Remove background from image using default settings.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Image with clear subject for background removal + image_url = "https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/headphones.jpeg" + + # Basic background removal request + removal_request = IImageBackgroundRemoval( + inputImage=image_url, outputType="URL", outputFormat="PNG", includeCost=True + ) + + # Process image + processed_images: List[IImage] = await runware.imageBackgroundRemoval( + removeImageBackgroundPayload=removal_request + ) + + print("Basic Background Removal:") + for i, image in enumerate(processed_images): + print(f" Result {i + 1}: {image.imageURL}") + if image.cost: + print(f" Cost: ${image.cost}") + + except RunwareError as e: + print(f"Error in background removal: {e}") + finally: + await runware.disconnect() + + +async def background_removal_advanced(): + """Remove background with advanced settings and custom model.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + image_url = "https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/headphones.jpeg" + + # Advanced background removal settings + advanced_settings = IBackgroundRemovalSettings( + rgba=[255, 255, 255, 0], # White transparent background + alphaMatting=True, # Better edge quality + postProcessMask=True, # Refine mask edges + returnOnlyMask=False, # Return processed image, not just mask + alphaMattingErodeSize=10, + alphaMattingForegroundThreshold=240, + alphaMattingBackgroundThreshold=10, + ) + + # Request with custom settings + removal_request = IImageBackgroundRemoval( + inputImage=image_url, + settings=advanced_settings, + outputType="URL", + outputFormat="PNG", + outputQuality=95, + includeCost=True, + ) + + processed_images: List[IImage] = await runware.imageBackgroundRemoval( + removeImageBackgroundPayload=removal_request + ) + + print("Advanced Background Removal:") + for image in processed_images: + print(f" High-quality result: {image.imageURL}") + if image.cost: + print(f" Cost: ${image.cost}") + + except RunwareError as e: + print(f"Error in advanced background removal: {e}") + finally: + await runware.disconnect() + + +async def image_upscaling(): + """Enhance image resolution using upscaling.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Lower resolution image for upscaling demonstration + image_url = "https://img.freepik.com/free-photo/macro-picture-red-leaf-lights-against-black-background_181624-32636.jpg" + + # Test different upscale factors + upscale_factors = [2, 4] + + for factor in upscale_factors: + upscale_request = IImageUpscale( + inputImage=image_url, + upscaleFactor=factor, + outputType="URL", + outputFormat="PNG", + includeCost=True, + ) + + upscaled_images: List[IImage] = await runware.imageUpscale( + upscaleGanPayload=upscale_request + ) + + print(f"Upscaling {factor}x:") + for image in upscaled_images: + print(f" Enhanced image: {image.imageURL}") + if image.cost: + print(f" Processing cost: ${image.cost}") + + except RunwareError as e: + print(f"Error in image upscaling: {e}") + finally: + await runware.disconnect() + + +async def complete_editing_workflow(): + """Demonstrate a complete image editing workflow.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + original_image = "https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/headphones.jpeg" + + print("Starting complete editing workflow...") + + # Step 1: Generate caption for the original image + print("Step 1: Analyzing image content...") + caption_request = IImageCaption(inputImage=original_image) + caption_result: IImageToText = await runware.imageCaption( + requestImageToText=caption_request + ) + print(f" Original image description: {caption_result.text}") + + # Step 2: Remove background + print("Step 2: Removing background...") + bg_removal_request = IImageBackgroundRemoval( + inputImage=original_image, outputType="URL", outputFormat="PNG" + ) + bg_removed_images: List[IImage] = await runware.imageBackgroundRemoval( + removeImageBackgroundPayload=bg_removal_request + ) + background_removed_url = bg_removed_images[0].imageURL + print(f" Background removed: {background_removed_url}") + + # Step 3: Upscale the result + print("Step 3: Enhancing resolution...") + upscale_request = IImageUpscale( + inputImage=background_removed_url, + upscaleFactor=2, + outputType="URL", + outputFormat="PNG", + ) + upscaled_images: List[IImage] = await runware.imageUpscale( + upscaleGanPayload=upscale_request + ) + final_image_url = upscaled_images[0].imageURL + print(f" Final enhanced image: {final_image_url}") + + # Step 4: Generate caption for final result + print("Step 4: Analyzing final result...") + final_caption_request = IImageCaption(inputImage=final_image_url) + final_caption: IImageToText = await runware.imageCaption( + requestImageToText=final_caption_request + ) + print(f" Final image description: {final_caption.text}") + + print("\nWorkflow completed successfully!") + print(f"Original: {original_image}") + print(f"Final: {final_image_url}") + + except RunwareError as e: + print(f"Error in editing workflow: {e}") + finally: + await runware.disconnect() + + +async def main(): + """Run all image editing examples.""" + print("=== Image Captioning ===") + await image_captioning() + + print("\n=== Basic Background Removal ===") + await background_removal_basic() + + print("\n=== Advanced Background Removal ===") + await background_removal_advanced() + + print("\n=== Image Upscaling ===") + await image_upscaling() + + print("\n=== Complete Editing Workflow ===") + await complete_editing_workflow() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/04_video_generation.py b/examples/04_video_generation.py new file mode 100644 index 0000000..e0e6a2b --- /dev/null +++ b/examples/04_video_generation.py @@ -0,0 +1,315 @@ +""" +Video Generation Example + +This example demonstrates video generation using different AI providers: +- Google Veo for high-quality cinematic videos +- Kling AI for creative video generation +- Minimax for text-to-video and image-to-video +- Bytedance for professional video content +- Pixverse for stylized video generation +- Vidu for anime-style videos +""" + +import asyncio +import os +from typing import List + +from runware import ( + IBytedanceProviderSettings, + IFrameImage, + IGoogleProviderSettings, + IMinimaxProviderSettings, + IPixverseProviderSettings, + IVideo, + IVideoInference, + IViduProviderSettings, + Runware, + RunwareError, +) + + +async def google_veo_generation(): + """Generate videos using Google's Veo model with advanced features.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Text-to-video generation + print("Generating text-to-video with Google Veo...") + text_to_video_request = IVideoInference( + positivePrompt="A majestic eagle soaring through mountain peaks at golden hour, cinematic cinematography", + model="google:3@0", + width=1280, + height=720, + duration=8, + numberResults=1, + seed=42, + includeCost=True, + providerSettings=IGoogleProviderSettings( + generateAudio=True, enhancePrompt=True + ), + ) + + videos: List[IVideo] = await runware.videoInference( + requestVideo=text_to_video_request + ) + + for video in videos: + print(f" Generated video: {video.videoURL}") + if video.cost: + print(f" Cost: ${video.cost}") + print(f" Seed used: {video.seed}") + + # Image-to-video generation + print("\nGenerating image-to-video with Google Veo...") + image_to_video_request = IVideoInference( + positivePrompt="The galaxy slowly rotates with sparkling stars", + model="google:2@0", + width=1280, + height=720, + duration=5, + numberResults=1, + frameImages=[ + IFrameImage( + inputImage="https://github.com/adilentiq/test-images/blob/main/common/image_15_mb.jpg?raw=true" + ) + ], + includeCost=True, + ) + + videos = await runware.videoInference(requestVideo=image_to_video_request) + + for video in videos: + print(f" I2V result: {video.videoURL}") + if video.cost: + print(f" Cost: ${video.cost}") + + except RunwareError as e: + print(f"Error with Google Veo: {e}") + finally: + await runware.disconnect() + + +async def kling_ai_generation(): + """Generate videos using Kling AI""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + request = IVideoInference( + positivePrompt="A beautiful woman portrait", + model="klingai:1@2", + width=1280, + height=720, + duration=5, + numberResults=1, + CFGScale=1.0, + includeCost=True, + frameImages=[ + IFrameImage( + inputImage="https://huggingface.co/ntc-ai/SDXL-LoRA-slider.Studio-Ghibli-style/resolve/main/images/Studio%20Ghibli%20style_17_-1.5.png", + frame="first", + ) + ], + ) + + videos: List[IVideo] = await runware.videoInference(requestVideo=request) + + print("Kling AI video:") + for video in videos: + print(f" Video URL: {video.videoURL}") + print(f" Status: {video.status}") + if video.cost: + print(f" Cost: ${video.cost}") + + except RunwareError as e: + print(f"Error with Kling AI: {e}") + finally: + await runware.disconnect() + + +async def minimax_generation(): + """Generate videos using Minimax with prompt optimization.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + request = IVideoInference( + positivePrompt="A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage", + model="minimax:1@1", + width=1366, + height=768, + duration=6, + numberResults=1, + seed=12345, + includeCost=True, + providerSettings=IMinimaxProviderSettings( + promptOptimizer=True # Enhance prompt automatically + ), + ) + + videos: List[IVideo] = await runware.videoInference(requestVideo=request) + + print("Minimax video generation:") + for video in videos: + print(f" Video URL: {video.videoURL}") + if video.cost: + print(f" Cost: ${video.cost}") + print(f" Generated with seed: {video.seed}") + + except RunwareError as e: + print(f"Error with Minimax: {e}") + finally: + await runware.disconnect() + + +async def bytedance_generation(): + """Generate videos using Bytedance with professional settings.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + request = IVideoInference( + positivePrompt="A couple in formal evening attire walking in heavy rain with an umbrella, cinematic lighting", + model="bytedance:1@1", + height=1504, + width=640, + duration=5, + numberResults=1, + seed=98765, + includeCost=True, + providerSettings=IBytedanceProviderSettings( + cameraFixed=False, # Allow camera movement + ), + ) + + videos: List[IVideo] = await runware.videoInference(requestVideo=request) + + print("Seedance 1.0 Lite video:") + for video in videos: + print(f" Video URL: {video.videoURL}") + if video.cost: + print(f" Cost: ${video.cost}") + + except RunwareError as e: + print(f"Error with Bytedance: {e}") + finally: + await runware.disconnect() + + +async def pixverse_stylized_generation(): + """Generate stylized videos using Pixverse with effects and styles.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + request = IVideoInference( + positivePrompt="A magical transformation sequence with sparkles and light effects", + negativePrompt="blurry, low quality", + model="pixverse:1@2", + width=1280, + height=720, + duration=5, + fps=24, + numberResults=1, + seed=55555, + includeCost=True, + frameImages=[ + IFrameImage( + inputImage="https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/man_beard.jpg" + ) + ], + providerSettings=IPixverseProviderSettings( + effect="boom drop", # Special effect + style="anime", # Anime art style + motionMode="normal", # Motion intensity + ), + ) + + videos: List[IVideo] = await runware.videoInference(requestVideo=request) + + print("Pixverse stylized video:") + for video in videos: + print(f" Anime-style video: {video.videoURL}") + print(f" Status: {video.status}") + if video.cost: + print(f" Cost: ${video.cost}") + + except RunwareError as e: + print(f"Error with Pixverse: {e}") + finally: + await runware.disconnect() + + +async def vidu_anime_generation(): + """Generate anime-style videos using Vidu.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + request = IVideoInference( + positivePrompt="A red fox moves stealthily through autumn woods, hunting for prey", + model="vidu:1@5", + width=1920, + height=1080, + duration=4, + numberResults=1, + seed=77777, + includeCost=True, + providerSettings=IViduProviderSettings( + style="anime", # Anime art style + movementAmplitude="auto", # Automatic movement detection + bgm=True, # Add background music + ), + ) + + videos: List[IVideo] = await runware.videoInference(requestVideo=request) + + print("Vidu anime-style video:") + for video in videos: + print(f" Anime video with BGM: {video.videoURL}") + print(f" Status: {video.status}") + if video.cost: + print(f" Cost: ${video.cost}") + + except RunwareError as e: + print(f"Error with Vidu: {e}") + finally: + await runware.disconnect() + + +async def main(): + """Run video generation examples for different providers.""" + print("=== Google Veo Video Generation ===") + await google_veo_generation() + + print("\n=== Kling AI ===") + await kling_ai_generation() + + print("\n=== Minimax with Prompt Optimization ===") + await minimax_generation() + + print("\n=== Bytedance Video ===") + await bytedance_generation() + + print("\n=== Pixverse Stylized Video ===") + await pixverse_stylized_generation() + # + print("\n=== Vidu Anime-Style Video ===") + await vidu_anime_generation() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/05_ace_plus_plus.py b/examples/05_ace_plus_plus.py new file mode 100644 index 0000000..62bb2f9 --- /dev/null +++ b/examples/05_ace_plus_plus.py @@ -0,0 +1,206 @@ +""" +ACE++ Advanced Character Editing Example + +ACE++ (Advanced Character Edit) enables character-consistent image generation +and editing while preserving identity. This example demonstrates: +- Portrait editing with identity preservation +- Subject integration and replacement +- Local editing with masks +- Logo and object placement +- Movie poster style editing +""" + +import asyncio +import os +from typing import List + +from runware import IAcePlusPlus, IImage, IImageInference, Runware, RunwareError + + +async def logo_placement(): + """Place logos and branding elements on products using masks.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Assets for logo placement + reference_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/logo_paste/1_ref.png" + mask_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/logo_paste/1_1_m.png" + init_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/logo_paste/1_1_edit.png" + + request = IImageInference( + positivePrompt="The logo is printed on the headphones with high quality and proper lighting.", + model="runware:102@1", + height=1024, + width=1024, + numberResults=1, + steps=28, + CFGScale=50.0, + referenceImages=[reference_image], # Logo reference + acePlusPlus=IAcePlusPlus( + inputImages=[init_image], # Product image + inputMasks=[mask_image], # Mask for logo placement area + repaintingScale=1.0, # Full prompt adherence for placement + taskType="subject", # Subject placement task + ), + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("Logo Placement:") + for image in images: + print(f" Product with logo: {image.imageURL}") + print(f" Logo professionally integrated using mask guidance") + + except RunwareError as e: + print(f"Error in logo placement: {e}") + finally: + await runware.disconnect() + + +async def local_region_editing(): + """Edit specific regions of images using local editing masks.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Assets for local editing + mask_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/local/local_1_m.webp" + init_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/local/local_1.webp" + + request = IImageInference( + positivePrompt='By referencing the mask, restore a partial image from the doodle that aligns with the textual explanation: "1 white old owl".', + model="runware:102@1", + height=1024, + width=1024, + numberResults=1, + steps=28, + CFGScale=50.0, + acePlusPlus=IAcePlusPlus( + inputImages=[init_image], # Image to edit + inputMasks=[mask_image], # Local region mask + repaintingScale=0.5, # Balanced editing + taskType="local_editing", # Local editing mode + ), + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("Local Region Editing:") + for image in images: + print(f" Locally edited image: {image.imageURL}") + print(f" Specific region refined while preserving surrounding areas") + + except RunwareError as e: + print(f"Error in local editing: {e}") + finally: + await runware.disconnect() + + +async def movie_poster_editing(): + """Create movie poster style edits with character replacement.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Movie poster editing assets + reference_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/movie_poster/1_ref.png" + mask_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/movie_poster/1_1_m.png" + init_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/movie_poster/1_1_edit.png" + + request = IImageInference( + positivePrompt="The man is facing the camera and is smiling with confidence and charisma, perfect for a movie poster.", + model="runware:102@1", + height=768, + width=1024, + numberResults=1, + steps=28, + CFGScale=50.0, + referenceImages=[reference_image], # Character reference + acePlusPlus=IAcePlusPlus( + inputImages=[init_image], # Poster template + inputMasks=[mask_image], # Character replacement area + repaintingScale=1.0, # Full creative freedom in masked area + taskType="portrait", # Portrait-aware processing + ), + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("Movie Poster Editing:") + for image in images: + print(f" Movie poster: {image.imageURL}") + print(f" Character seamlessly integrated into poster design") + + except RunwareError as e: + print(f"Error in movie poster editing: {e}") + finally: + await runware.disconnect() + + +async def photo_editing_workflow(): + """Demonstrate a complex photo editing workflow using ACE++.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Professional photo editing assets + init_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/photo_editing/1_1_edit.png" + mask_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/photo_editing/1_1_m.png" + reference_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/photo_editing/1_ref.png" + + request = IImageInference( + positivePrompt="The item is put on the ground with proper lighting and realistic shadows, professional product photography.", + model="runware:102@1", + height=1024, + width=1024, + numberResults=1, + steps=28, + CFGScale=50.0, + referenceImages=[reference_image], # Product reference + acePlusPlus=IAcePlusPlus( + inputImages=[init_image], # Scene to edit + inputMasks=[mask_image], # Product placement area + repaintingScale=1.0, # Full control over placement + taskType="subject", # Subject placement + ), + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("Professional Photo Editing:") + for image in images: + print(f" Edited photo: {image.imageURL}") + print(f" Product professionally integrated with realistic lighting") + + except RunwareError as e: + print(f"Error in photo editing: {e}") + finally: + await runware.disconnect() + + +async def main(): + """Run all ACE++ advanced editing examples.""" + print("\n=== Logo Placement ===") + await logo_placement() + + print("\n=== Local Region Editing ===") + await local_region_editing() + + print("\n=== Movie Poster Editing ===") + await movie_poster_editing() + + print("\n=== Professional Photo Editing ===") + await photo_editing_workflow() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/06_photo_maker.py b/examples/06_photo_maker.py new file mode 100644 index 0000000..559a8d4 --- /dev/null +++ b/examples/06_photo_maker.py @@ -0,0 +1,344 @@ +""" +PhotoMaker Example + +PhotoMaker enables identity-consistent photo generation by combining +multiple input photos to create new images while preserving identity. +This example demonstrates various PhotoMaker use cases and styles. +""" + +import asyncio +import os +from typing import List + +from runware import IImage, IPhotoMaker, Runware, RunwareError, UploadImageType + + +async def basic_photo_maker(): + """Generate photos using PhotoMaker with multiple input images.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Sample input images for identity reference + input_images = [ + "https://im.runware.ai/image/ws/0.5/ii/74723926-22f6-417c-befb-f2058fc88c13.webp", + "https://im.runware.ai/image/ws/0.5/ii/64acee31-100d-4aa1-a47e-6f8b432e7188.webp", + "https://im.runware.ai/image/ws/0.5/ii/1b39b0e0-6bf7-4c9a-8134-c0251b5ede01.webp", + "https://im.runware.ai/image/ws/0.5/ii/f4b4cec3-66d9-4c02-97c5-506b8813182a.webp", + ] + + request = IPhotoMaker( + model="civitai:139562@344487", # PhotoMaker compatible model + positivePrompt="img of a beautiful lady in a peaceful forest setting, natural lighting", + steps=35, + numberResults=2, + height=768, + width=512, + style="No style", # Natural style + strength=40, + outputFormat="WEBP", + includeCost=True, + inputImages=input_images, + ) + + photos: List[IImage] = await runware.photoMaker(requestPhotoMaker=request) + + print("Basic PhotoMaker Results:") + for i, photo in enumerate(photos): + print(f" Photo {i + 1}: {photo.imageURL}") + if photo.cost: + print(f" Cost: ${photo.cost}") + if photo.seed: + print(f" Seed: {photo.seed}") + + except RunwareError as e: + print(f"Error in basic PhotoMaker: {e}") + finally: + await runware.disconnect() + + +async def styled_photo_generation(): + """Generate photos with different artistic styles.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + input_images = [ + "https://im.runware.ai/image/ws/0.5/ii/74723926-22f6-417c-befb-f2058fc88c13.webp", + "https://im.runware.ai/image/ws/0.5/ii/64acee31-100d-4aa1-a47e-6f8b432e7188.webp", + ] + + # Test different artistic styles + styles = [ + ( + "Cinematic", + "img of a person in a dramatic movie scene with cinematic lighting", + ), + ( + "Digital Art", + "img of a person as a digital art character with vibrant colors", + ), + ( + "Fantasy art", + "img of a person as a fantasy character in a magical realm", + ), + ("Comic book", "img of a person in comic book art style with bold lines"), + ] + + for style_name, prompt in styles: + print(f"Generating {style_name} style...") + + request = IPhotoMaker( + model="civitai:139562@344487", + positivePrompt=prompt, + steps=30, + numberResults=1, + height=1024, + width=768, + style=style_name, + strength=50, + outputFormat="PNG", + inputImages=input_images, + ) + + photos: List[IImage] = await runware.photoMaker(requestPhotoMaker=request) + + for photo in photos: + print(f" {style_name} result: {photo.imageURL}") + + except RunwareError as e: + print(f"Error in styled photo generation: {e}") + finally: + await runware.disconnect() + + +async def professional_portraits(): + """Generate professional portrait photos with various settings.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + input_images = [ + "https://im.runware.ai/image/ws/0.5/ii/74723926-22f6-417c-befb-f2058fc88c13.webp", + "https://im.runware.ai/image/ws/0.5/ii/64acee31-100d-4aa1-a47e-6f8b432e7188.webp", + "https://im.runware.ai/image/ws/0.5/ii/1b39b0e0-6bf7-4c9a-8134-c0251b5ede01.webp", + ] + + # Professional portrait scenarios + scenarios = [ + { + "name": "Business Portrait", + "prompt": ( + "img of a professional person in business attire, office background, confident expression" + ), + "strength": 30, + }, + { + "name": "Casual Portrait", + "prompt": ( + "img of a person in casual clothing, natural outdoor setting, relaxed smile" + ), + "strength": 35, + }, + { + "name": "Artistic Portrait", + "prompt": ( + "img of a person with dramatic lighting, artistic composition, professional photography" + ), + "strength": 45, + }, + ] + + for scenario in scenarios: + print(f"Creating {scenario['name']}...") + + request = IPhotoMaker( + model="civitai:139562@344487", + positivePrompt=scenario["prompt"], + steps=40, # Higher steps for quality + numberResults=1, + height=1024, + width=768, + style="Photographic", # Realistic style + strength=scenario["strength"], + outputFormat="PNG", + includeCost=True, + inputImages=input_images, + ) + + photos: List[IImage] = await runware.photoMaker(requestPhotoMaker=request) + + for photo in photos: + print(f" {scenario['name']}: {photo.imageURL}") + if photo.cost: + print(f" Processing cost: ${photo.cost}") + + except RunwareError as e: + print(f"Error in professional portraits: {e}") + finally: + await runware.disconnect() + + +async def creative_photo_scenarios(): + """Generate creative photo scenarios with thematic elements.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + input_images = [ + "https://im.runware.ai/image/ws/0.5/ii/74723926-22f6-417c-befb-f2058fc88c13.webp", + "https://im.runware.ai/image/ws/0.5/ii/64acee31-100d-4aa1-a47e-6f8b432e7188.webp", + ] + + # Creative scenarios with specific themes + creative_themes = [ + { + "theme": "Vintage Portrait", + "prompt": ( + "img of a person in vintage 1920s clothing, sepia tones, classic photography style" + ), + "style": "Enhance", + "dimensions": (768, 1024), + }, + { + "theme": "Superhero Style", + "prompt": ( + "img of a person as a superhero character, dynamic pose, comic book style" + ), + "style": "Comic book", + "dimensions": (512, 768), + }, + { + "theme": "Fantasy Character", + "prompt": ( + "img of a person as an elegant elf character in a mystical forest" + ), + "style": "Fantasy art", + "dimensions": (768, 1024), + }, + { + "theme": "Futuristic Portrait", + "prompt": ( + "img of a person in futuristic sci-fi setting with neon lighting" + ), + "style": "Digital Art", + "dimensions": (1024, 768), + }, + ] + + for theme_config in creative_themes: + print(f"Creating {theme_config['theme']}...") + + width, height = theme_config["dimensions"] + + request = IPhotoMaker( + model="civitai:139562@344487", + positivePrompt=theme_config["prompt"], + steps=35, + numberResults=1, + height=height, + width=width, + style=theme_config["style"], + strength=50, # Higher strength for creative themes + outputFormat="WEBP", + inputImages=input_images, + ) + + photos: List[IImage] = await runware.photoMaker(requestPhotoMaker=request) + + for photo in photos: + print(f" {theme_config['theme']}: {photo.imageURL}") + print(f" Style: {theme_config['style']}") + print(f" Dimensions: {width}x{height}") + + except RunwareError as e: + print(f"Error in creative scenarios: {e}") + finally: + await runware.disconnect() + + +async def upload_and_generate(): + """Upload custom images and use them with PhotoMaker.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Example of uploading images (URLs in this case, but could be local files) + image_urls = [ + "https://im.runware.ai/image/ws/0.5/ii/74723926-22f6-417c-befb-f2058fc88c13.webp", + "https://im.runware.ai/image/ws/0.5/ii/64acee31-100d-4aa1-a47e-6f8b432e7188.webp", + ] + + # Upload images and get UUIDs + uploaded_images = [] + for i, url in enumerate(image_urls): + print(f"Uploading image {i + 1}...") + uploaded: UploadImageType = await runware.uploadImage(url) + if uploaded and uploaded.imageUUID: + uploaded_images.append(uploaded.imageUUID) + print(f" Uploaded successfully: {uploaded.imageUUID}") + else: + print(f" Failed to upload image {i + 1}") + + if len(uploaded_images) >= 2: + # Use uploaded images for PhotoMaker + request = IPhotoMaker( + model="civitai:139562@344487", + positivePrompt="img of a person in a beautiful garden setting, golden hour lighting, professional photography", + steps=32, + numberResults=1, + height=1024, + width=768, + style="Photographic", + strength=40, + outputFormat="PNG", + includeCost=True, + inputImages=uploaded_images, # Use uploaded UUIDs + ) + + photos: List[IImage] = await runware.photoMaker(requestPhotoMaker=request) + + print("PhotoMaker with uploaded images:") + for photo in photos: + print(f" Generated photo: {photo.imageURL}") + if photo.cost: + print(f" Cost: ${photo.cost}") + else: + print("Not enough images uploaded successfully") + + except RunwareError as e: + print(f"Error with upload and generate: {e}") + finally: + await runware.disconnect() + + +async def main(): + """Run all PhotoMaker examples.""" + print("=== Basic PhotoMaker ===") + await basic_photo_maker() + + print("\n=== Styled Photo Generation ===") + await styled_photo_generation() + + print("\n=== Professional Portraits ===") + await professional_portraits() + + print("\n=== Creative Photo Scenarios ===") + await creative_photo_scenarios() + + print("\n=== Upload and Generate ===") + await upload_and_generate() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/07_prompt_enhancement.py b/examples/07_prompt_enhancement.py new file mode 100644 index 0000000..7812f5a --- /dev/null +++ b/examples/07_prompt_enhancement.py @@ -0,0 +1,323 @@ +""" +Prompt Enhancement Example + +This example demonstrates how to use Runware's prompt enhancement feature +to automatically improve and expand prompts for better image generation results. +The enhanced prompts include more descriptive language, artistic terms, and +technical specifications that lead to higher quality outputs. +""" + +import asyncio +import os +from typing import List + +from runware import ( + IEnhancedPrompt, + IImage, + IImageInference, + IPromptEnhance, + Runware, + RunwareError, +) + + +async def basic_prompt_enhancement(): + """Enhance simple prompts to create more detailed descriptions.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Simple prompts to enhance + simple_prompts = [ + "a cat", + "sunset over mountains", + "beautiful woman", + "futuristic city", + ] + + for original_prompt in simple_prompts: + print(f"\nOriginal prompt: '{original_prompt}'") + + enhancer = IPromptEnhance( + prompt=original_prompt, + promptVersions=3, # Generate 3 different enhanced versions + promptMaxLength=200, # Maximum length for enhanced prompts + includeCost=True, + ) + + enhanced_prompts: List[IEnhancedPrompt] = await runware.promptEnhance( + promptEnhancer=enhancer + ) + + print("Enhanced versions:") + for i, enhanced in enumerate(enhanced_prompts, 1): + print(f" {i}. {enhanced.text}") + if enhanced.cost: + print(f" Cost: ${enhanced.cost}") + + except RunwareError as e: + print(f"Error in basic prompt enhancement: {e}") + finally: + await runware.disconnect() + + +async def artistic_prompt_enhancement(): + """Enhance prompts specifically for artistic and creative image generation.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Artistic prompts that can benefit from enhancement + artistic_prompts = [ + "abstract painting", + "portrait in renaissance style", + "cyberpunk street scene", + "magical forest", + ] + + for prompt in artistic_prompts: + print(f"\nArtistic prompt: '{prompt}'") + + enhancer = IPromptEnhance( + prompt=prompt, + promptVersions=2, + promptMaxLength=300, # Longer prompts for artistic detail + includeCost=True, + ) + + enhanced_prompts: List[IEnhancedPrompt] = await runware.promptEnhance( + promptEnhancer=enhancer + ) + + print("Enhanced artistic descriptions:") + for i, enhanced in enumerate(enhanced_prompts, 1): + print(f" Version {i}:") + print(f" {enhanced.text}") + if enhanced.cost: + print(f" Processing cost: ${enhanced.cost}") + + except RunwareError as e: + print(f"Error in artistic prompt enhancement: {e}") + finally: + await runware.disconnect() + + +async def photography_prompt_enhancement(): + """Enhance prompts for photorealistic image generation with technical terms.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Photography-focused prompts + photography_prompts = [ + "professional headshot", + "landscape photography", + "macro flower photo", + "street photography at night", + ] + + for prompt in photography_prompts: + print(f"\nPhotography prompt: '{prompt}'") + + enhancer = IPromptEnhance( + prompt=prompt, promptVersions=2, promptMaxLength=250, includeCost=True + ) + + enhanced_prompts: List[IEnhancedPrompt] = await runware.promptEnhance( + promptEnhancer=enhancer + ) + + print("Enhanced with photography terms:") + for i, enhanced in enumerate(enhanced_prompts, 1): + print(f" Version {i}: {enhanced.text}") + + except RunwareError as e: + print(f"Error in photography prompt enhancement: {e}") + finally: + await runware.disconnect() + + +async def compare_original_vs_enhanced(): + """Compare image generation results using original vs enhanced prompts.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + original_prompt = "a dragon in a castle" + print(f"Comparing results for: '{original_prompt}'") + + # First, enhance the prompt + enhancer = IPromptEnhance( + prompt=original_prompt, + promptVersions=1, # Just one enhanced version for comparison + promptMaxLength=180, + ) + + enhanced_prompts: List[IEnhancedPrompt] = await runware.promptEnhance( + promptEnhancer=enhancer + ) + + enhanced_prompt = enhanced_prompts[0].text + print(f"Enhanced to: '{enhanced_prompt}'") + + # Generate image with original prompt + print("\nGenerating with original prompt...") + original_request = IImageInference( + positivePrompt=original_prompt, + model="civitai:4384@128713", + numberResults=1, + height=768, + width=768, + seed=42, # Fixed seed for fair comparison + ) + + original_images: List[IImage] = await runware.imageInference( + requestImage=original_request + ) + + # Generate image with enhanced prompt + print("Generating with enhanced prompt...") + enhanced_request = IImageInference( + positivePrompt=enhanced_prompt, + model="civitai:4384@128713", + numberResults=1, + height=768, + width=768, + seed=42, # Same seed for comparison + ) + + enhanced_images: List[IImage] = await runware.imageInference( + requestImage=enhanced_request + ) + + # Display results + print("\nComparison Results:") + print(f"Original prompt result: {original_images[0].imageURL}") + print(f"Enhanced prompt result: {enhanced_images[0].imageURL}") + print( + "\nThe enhanced prompt should produce more detailed and visually appealing results!" + ) + + except RunwareError as e: + print(f"Error in comparison: {e}") + finally: + await runware.disconnect() + + +async def batch_prompt_enhancement(): + """Enhance multiple prompts efficiently in batch.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Batch of prompts to enhance + prompt_batch = [ + "cozy coffee shop", + "space exploration", + "underwater scene", + "medieval knight", + "tropical paradise", + ] + + print("Batch enhancing multiple prompts...") + + # Create enhancement tasks + enhancement_tasks = [] + for prompt in prompt_batch: + enhancer = IPromptEnhance( + prompt=prompt, promptVersions=1, promptMaxLength=150 + ) + task = runware.promptEnhance(promptEnhancer=enhancer) + enhancement_tasks.append((prompt, task)) + + # Execute all enhancements concurrently + results = await asyncio.gather( + *[task for _, task in enhancement_tasks], return_exceptions=True + ) + + # Process results + print("\nBatch Enhancement Results:") + for i, (original_prompt, result) in enumerate(zip(prompt_batch, results)): + print(f"\n{i + 1}. Original: '{original_prompt}'") + + if isinstance(result, Exception): + print(f" Error: {result}") + else: + enhanced_prompts: List[IEnhancedPrompt] = result + if enhanced_prompts: + print(f" Enhanced: '{enhanced_prompts[0].text}'") + + except RunwareError as e: + print(f"Error in batch enhancement: {e}") + finally: + await runware.disconnect() + + +async def length_variation_testing(): + """Test prompt enhancement with different maximum length settings.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + base_prompt = "magical wizard casting spells" + lengths = [50, 100, 200, 300] + + print(f"Testing different enhancement lengths for: '{base_prompt}'") + + for max_length in lengths: + print(f"\nMax length: {max_length} characters") + + enhancer = IPromptEnhance( + prompt=base_prompt, promptVersions=1, promptMaxLength=max_length + ) + + enhanced_prompts: List[IEnhancedPrompt] = await runware.promptEnhance( + promptEnhancer=enhancer + ) + + if enhanced_prompts: + enhanced_text = enhanced_prompts[0].text + actual_length = len(enhanced_text) + print(f" Result ({actual_length} chars): {enhanced_text}") + + except RunwareError as e: + print(f"Error in length variation testing: {e}") + finally: + await runware.disconnect() + + +async def main(): + """Run all prompt enhancement examples.""" + print("=== Basic Prompt Enhancement ===") + await basic_prompt_enhancement() + + print("\n=== Artistic Prompt Enhancement ===") + await artistic_prompt_enhancement() + + print("\n=== Photography Prompt Enhancement ===") + await photography_prompt_enhancement() + + print("\n=== Original vs Enhanced Comparison ===") + await compare_original_vs_enhanced() + + print("\n=== Batch Prompt Enhancement ===") + await batch_prompt_enhancement() + + print("\n=== Length Variation Testing ===") + await length_variation_testing() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/08_model_management.py b/examples/08_model_management.py new file mode 100644 index 0000000..caaa3db --- /dev/null +++ b/examples/08_model_management.py @@ -0,0 +1,332 @@ +""" +Model Management Example + +This example demonstrates model management capabilities including: +- Searching for models in the repository +- Uploading custom models (checkpoints, LoRA, ControlNet) +- Managing model metadata and configurations +- Working with different model architectures +""" + +import asyncio +import os + +from runware import ( + EModelArchitecture, + IModelSearch, + IModelSearchResponse, + IUploadModelCheckPoint, + IUploadModelControlNet, + IUploadModelLora, + IUploadModelResponse, + Runware, + RunwareError, +) + + +async def search_models_basic(): + """Search for models using basic criteria.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Basic model search + search_request = IModelSearch( + search="fantasy art", # Search term + limit=5, # Limit results + offset=0, # Starting position + ) + + results: IModelSearchResponse = await runware.modelSearch( + payload=search_request + ) + + print("Basic Model Search Results:") + print(f"Total models found: {results.totalResults}") + print(f"Showing first {len(results.results)} models:") + + for i, model in enumerate(results.results, 1): + print(f"\n{i}. {model.name} (v{model.version})") + print(f" Category: {model.category}") + print(f" Architecture: {model.architecture}") + print(f" AIR: {model.air}") + print(f" Tags: {', '.join(model.tags) if model.tags else 'None'}") + if model.comment: + print(f" Description: {model.comment}") + + # Access additional fields that may have been provided by API + if hasattr(model, "downloadURL"): + print(f" Download URL: {model.downloadURL}") + if hasattr(model, "shortDescription"): + print(f" Short Description: {model.shortDescription[:100]}...") + + except RunwareError as e: + print(f"Error in basic model search: {e}") + finally: + await runware.disconnect() + + +async def search_models_filtered(): + """Search for models with specific filters and categories.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Search for SDXL checkpoints + print("Searching for SDXL checkpoint models...") + checkpoint_search = IModelSearch( + category="checkpoint", + architecture=EModelArchitecture.SDXL, + visibility="public", + limit=3, + ) + + checkpoint_results: IModelSearchResponse = await runware.modelSearch( + payload=checkpoint_search + ) + + print(f"Found {checkpoint_results.totalResults} SDXL checkpoints:") + for model in checkpoint_results.results: + print(f" - {model.name}: {model.air}") + if model.defaultSteps: + print(f" Default steps: {model.defaultSteps}") + if model.defaultCFG: + print(f" Default CFG: {model.defaultCFG}") + + # Search for LoRA models + print("\nSearching for LoRA models...") + lora_search = IModelSearch(category="lora", tags=["anime", "style"], limit=3) + + lora_results: IModelSearchResponse = await runware.modelSearch( + payload=lora_search + ) + + print(f"Found {lora_results.totalResults} LoRA models:") + for model in lora_results.results: + print(f" - {model.name}: {model.air}") + if hasattr(model, "defaultWeight") and model.defaultWeight: + print(f" Default weight: {model.defaultWeight}") + if model.positiveTriggerWords: + print(f" Trigger words: {model.positiveTriggerWords}") + + # Search for ControlNet models + print("\nSearching for ControlNet models...") + controlnet_search = IModelSearch( + category="controlnet", architecture=EModelArchitecture.SDXL, limit=3 + ) + + controlnet_results: IModelSearchResponse = await runware.modelSearch( + payload=controlnet_search + ) + + print(f"Found {controlnet_results.totalResults} ControlNet models:") + for model in controlnet_results.results: + print(f" - {model.name}: {model.air}") + if hasattr(model, "conditioning") and model.conditioning: + print(f" Conditioning: {model.conditioning}") + + except RunwareError as e: + print(f"Error in filtered model search: {e}") + finally: + await runware.disconnect() + + +async def upload_checkpoint_model(): + """Upload a custom checkpoint model.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Configure checkpoint model upload + checkpoint_payload = IUploadModelCheckPoint( + air="runware:68487@0862923414", # Unique AIR identifier + name="SDXL Model", + heroImageURL="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/01.png?download=true", + downloadURL="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors?download=true", + uniqueIdentifier="unique_checkpoint_id_12345678901234567890123456789012", + version="1.0", + tags=["realistic", "portrait", "photography"], + architecture="sdxl", + type="base", # Required for checkpoints + defaultWeight=1.0, + format="safetensors", + positiveTriggerWords="masterpiece, best quality", + shortDescription="High-quality realistic portrait model", + private=False, + defaultScheduler="DPM++ 2M Karras", # Required for checkpoints + defaultSteps=30, + defaultGuidanceScale=7.5, + comment="Custom trained model for realistic portraits", + ) + + print("Uploading checkpoint model...") + upload_result: IUploadModelResponse = await runware.modelUpload( + requestModel=checkpoint_payload + ) + + if upload_result: + print(f"Checkpoint upload successful!") + print(f" AIR: {upload_result.air}") + print(f" Task UUID: {upload_result.taskUUID}") + print(f" Task Type: {upload_result.taskType}") + else: + print("Checkpoint upload failed") + + except RunwareError as e: + print(f"Error uploading checkpoint model: {e}") + finally: + await runware.disconnect() + + +async def upload_lora_model(): + """Upload a custom LoRA model.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Configure LoRA model upload + lora_payload = IUploadModelLora( + air="runware:68487@08629", + name="Ghibli lora", + heroImageURL="https://huggingface.co/ntc-ai/SDXL-LoRA-slider.Studio-Ghibli-style/resolve/main/images/Studio%20Ghibli%20style_17_3.0.png", + downloadURL="https://huggingface.co/ntc-ai/SDXL-LoRA-slider.Studio-Ghibli-style/resolve/main/Studio%20Ghibli%20style.safetensors?download=true", + uniqueIdentifier="unique_lora_id_abcdefghijklmnopqrstuvwxyz123456789012", + version="2.0", + tags=["anime", "style", "cartoon"], + architecture="sdxl", + format="safetensors", + defaultWeight=0.8, # Typical LoRA weight + positiveTriggerWords="anime style, cel shading", + shortDescription="Anime art style enhancement LoRA", + private=False, + comment="Trained on high-quality anime artwork", + ) + + print("Uploading LoRA model...") + upload_result: IUploadModelResponse = await runware.modelUpload( + requestModel=lora_payload + ) + + if upload_result: + print(f"LoRA upload successful!") + print(f" AIR: {upload_result.air}") + print(f" Task UUID: {upload_result.taskUUID}") + else: + print("LoRA upload failed") + + except RunwareError as e: + print(f"Error uploading LoRA model: {e}") + finally: + await runware.disconnect() + + +async def upload_controlnet_model(): + """Upload a custom ControlNet model.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Configure ControlNet model upload + controlnet_payload = IUploadModelControlNet( + air="runware:68487@08629112", + name="Custom Canny ControlNet", + downloadURL="https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0-small/resolve/main/diffusion_pytorch_model.safetensors?download=true", + uniqueIdentifier="unique_controlnet_id_987654321098765432109876543210", + version="1.5", + tags=["controlnet", "canny", "edge"], + architecture="sdxl", + format="safetensors", + conditioning="canny", # Required for ControlNet + shortDescription="Canny edge detection ControlNet for SDXL", + private=False, + comment="Fine-tuned for better edge detection accuracy", + ) + + print("Uploading ControlNet model...") + upload_result: IUploadModelResponse = await runware.modelUpload( + requestModel=controlnet_payload + ) + + if upload_result: + print(f"ControlNet upload successful!") + print(f" AIR: {upload_result.air}") + print(f" Task UUID: {upload_result.taskUUID}") + else: + print("ControlNet upload failed") + + except RunwareError as e: + print(f"Error uploading ControlNet model: {e}") + finally: + await runware.disconnect() + + +async def search_runware_models(): + """Search for runware's private models.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Search for runware's private models + private_search = IModelSearch( + visibility="private", limit=10 # Only private models + ) + + results: IModelSearchResponse = await runware.modelSearch( + payload=private_search + ) + + print("runware's Private Models:") + print(f"Total private models: {results.totalResults}") + + if results.results: + for i, model in enumerate(results.results, 1): + print(f"\n{i}. {model.name} (v{model.version})") + print(f" AIR: {model.air}") + print(f" Category: {model.category}") + print(f" Private: {model.private}") + if model.comment: + print(f" Notes: {model.comment}") + else: + print("No private models found") + + except RunwareError as e: + print(f"Error searching runware models: {e}") + finally: + await runware.disconnect() + + +async def main(): + """Run all model management examples.""" + print("=== Basic Model Search ===") + await search_models_basic() + + print("\n=== Filtered Model Search ===") + await search_models_filtered() + + print("\n=== Upload Checkpoint Model ===") + await upload_checkpoint_model() + + print("\n=== Upload LoRA Model ===") + await upload_lora_model() + + print("\n=== Upload ControlNet Model ===") + await upload_controlnet_model() + + print("\n=== Search runware Models ===") + await search_runware_models() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/09_image_to_image_inpainting.py b/examples/09_image_to_image_inpainting.py new file mode 100644 index 0000000..298bb2d --- /dev/null +++ b/examples/09_image_to_image_inpainting.py @@ -0,0 +1,417 @@ +""" +Image-to-Image and Inpainting Example + +This example demonstrates advanced image-to-image generation techniques: +- Seed image transformations with different strengths +- Inpainting with mask-based editing +- Outpainting for image extension +- InstantID for identity preservation +- IP Adapters for style transfer +- Reference image guidance +""" + +import asyncio +import os +from typing import List + +from runware import ( + IEmbedding, + IImage, + IImageInference, + IIpAdapter, + IOutpaint, + Runware, + RunwareError, +) + + +async def basic_image_to_image(): + """Transform existing images using seed images with different strengths.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Source image for transformation + seed_image = "https://img.freepik.com/free-photo/macro-picture-red-leaf-lights-against-black-background_181624-32636.jpg" + + # Test different transformation strengths + strengths = [0.3, 0.5, 0.7, 0.9] + + for strength in strengths: + print(f"Transforming with strength {strength}...") + + request = IImageInference( + positivePrompt="vibrant digital art, neon colors, cyberpunk aesthetic, highly detailed", + model="civitai:4384@128713", + seedImage=seed_image, + strength=strength, # How much to transform the original + numberResults=1, + height=768, + width=768, + negativePrompt="blurry, low quality, monochrome", + steps=30, + CFGScale=7.5, + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + for image in images: + print(f" Strength {strength}: {image.imageURL}") + print(f" Seed used: {image.seed}") + + except RunwareError as e: + print(f"Error in image-to-image transformation: {e}") + finally: + await runware.disconnect() + + +async def inpainting_with_masks(): + """Perform selective editing using mask-based inpainting.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Base image and mask for inpainting + base_image = "https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/background.jpg" + + # Create a simple inpainting scenario + print("Performing inpainting operation...") + + # For this example, we'll use the base image and create targeted edits + request = IImageInference( + positivePrompt="beautiful garden with colorful flowers, vibrant blooms, natural lighting", + model="civitai:4384@128713", + seedImage=base_image, + strength=0.8, # High strength for significant changes + numberResults=1, + height=1024, + width=1024, + steps=40, # More steps for better inpainting quality + CFGScale=8.0, + maskMargin=32, # Blend mask edges smoothly + negativePrompt="dead plants, withered, dark, gloomy", + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("Inpainting Results:") + for image in images: + print(f" Inpainted image: {image.imageURL}") + print(f" Original base: {base_image}") + + except RunwareError as e: + print(f"Error in inpainting: {e}") + finally: + await runware.disconnect() + + +async def outpainting_extension(): + """Extend images beyond their borders using outpainting.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Image to extend + source_image = "https://img.freepik.com/free-photo/macro-picture-red-leaf-lights-against-black-background_181624-32636.jpg" + + # Configure outpainting extension + outpaint_config = IOutpaint( + top=64, # Extend 64px upward + right=128, # Extend 128px to the right + bottom=64, # Extend 64px downward + left=128, # Extend 128px to the left + blur=8, # Blur radius for seamless blending + ) + + print("Extending image with outpainting...") + + request = IImageInference( + positivePrompt="seamless natural extension, consistent lighting and style, photorealistic", + model="civitai:4384@128713", + seedImage=source_image, + outpaint=outpaint_config, + width=1024, + height=640, + strength=0.6, + numberResults=1, + steps=35, + CFGScale=7.0, + negativePrompt="seams, inconsistent lighting, artifacts", + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("Outpainting Results:") + for image in images: + print(f" Extended image: {image.imageURL}") + print( + f" Extensions: top={outpaint_config.top}, right={outpaint_config.right}" + ) + print( + f" Extensions: bottom={outpaint_config.bottom}, left={outpaint_config.left}" + ) + + except RunwareError as e: + print(f"Error in outpainting: {e}") + finally: + await runware.disconnect() + + +async def ip_adapter_style_transfer(): + """Apply style transfer using IP Adapters.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Style reference images + style_images = [ + "https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/background.jpg", + "https://img.freepik.com/free-photo/macro-picture-red-leaf-lights-against-black-background_181624-32636.jpg", + ] + + # Configure IP Adapters + ip_adapters = [] + for i, style_img in enumerate(style_images): + ip_adapter = IIpAdapter( + model="runware:55@1", + guideImage=style_img, + weight=0.6, + ) + ip_adapters.append(ip_adapter) + + print("Applying style transfer with IP Adapters...") + + request = IImageInference( + positivePrompt="beautiful landscape painting, artistic composition, masterpiece quality", + model="civitai:288584@324619", + ipAdapters=ip_adapters, + numberResults=1, + height=1024, + width=1024, + steps=35, + CFGScale=8.0, + negativePrompt="low quality, blurry, distorted", + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("IP Adapter Style Transfer Results:") + for image in images: + print(f" Style-transferred image: {image.imageURL}") + print(f" Applied {len(ip_adapters)} style references") + + except RunwareError as e: + print(f"Error with IP Adapter: {e}") + finally: + await runware.disconnect() + + +async def reference_guided_generation(): + """Generate images guided by reference images.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Multiple reference images for guidance + reference_images = [ # right now it supports only 1 image in a list + "https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/background.jpg" + ] + print("Generating with reference image guidance...") + + request = IImageInference( + positivePrompt="epic fantasy landscape, magical atmosphere, vibrant colors, cinematic composition", + model="civitai:4384@128713", + referenceImages=reference_images, + numberResults=2, + height=1024, + width=1024, + steps=30, + CFGScale=7.5, + seed=98765, + negativePrompt="dark, gloomy, low quality, blurry", + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("Reference-guided generation results:") + for i, image in enumerate(images, 1): + print(f" Generated image {i}: {image.imageURL}") + print(f" Guided by {len(reference_images)} reference images") + print(f" Seed: {image.seed}") + + except RunwareError as e: + print(f"Error in reference-guided generation: {e}") + finally: + await runware.disconnect() + + +async def embedding_enhanced_generation(): + """Generate images using textual embeddings for enhanced control.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + # Configure textual embeddings + embeddings = [ + IEmbedding(model="civitai:7808@9208"), + IEmbedding(model="civitai:4629@5637"), + ] + + print("Generating with textual embeddings...") + + request = IImageInference( + positivePrompt="award-winning photography, professional composition, perfect lighting", + model="civitai:4384@128713", + embeddings=embeddings, + numberResults=1, + height=1024, + width=1024, + steps=35, + CFGScale=8.0, + negativePrompt="amateur, poor lighting, distorted", + ) + + images: List[IImage] = await runware.imageInference(requestImage=request) + + print("Embedding-enhanced results:") + for image in images: + print(f" Enhanced image: {image.imageURL}") + print(f" Used {len(embeddings)} textual embeddings") + + except RunwareError as e: + print(f"Error with embeddings: {e}") + finally: + await runware.disconnect() + + +async def comprehensive_editing_workflow(): + """Demonstrate a comprehensive image editing workflow combining multiple techniques.""" + + runware = Runware(api_key=os.getenv("RUNWARE_API_KEY")) + + try: + await runware.connect() + + print("=== Comprehensive Image Editing Workflow ===") + + # Starting image + original_image = "https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/background.jpg" + + # Step 1: Style transformation + print("\n1. Applying artistic style transformation...") + style_request = IImageInference( + positivePrompt="impressionist painting style, soft brushstrokes, warm colors, artistic masterpiece", + model="civitai:4384@128713", + seedImage=original_image, + strength=0.6, + numberResults=1, + height=1024, + width=1024, + steps=30, + CFGScale=7.5, + ) + + style_images: List[IImage] = await runware.imageInference( + requestImage=style_request + ) + styled_image_url = style_images[0].imageURL + print(f" Style applied: {styled_image_url}") + + # Step 2: Extend the styled image with outpainting + print("\n2. Extending image with outpainting...") + outpaint_config = IOutpaint(top=64, right=64, bottom=64, left=64, blur=8) + + extend_request = IImageInference( + positivePrompt="seamless extension, consistent artistic style, harmonious composition", + model="civitai:4384@128713", + seedImage=styled_image_url, + outpaint=outpaint_config, + width=1280, + height=640, + strength=0.5, + numberResults=1, + steps=35, + CFGScale=7.0, + ) + + extended_images: List[IImage] = await runware.imageInference( + requestImage=extend_request + ) + extended_image_url = extended_images[0].imageURL + print(f" Extended image: {extended_image_url}") + + # Step 3: Final enhancement with reference guidance + print("\n3. Final enhancement with reference guidance...") + reference_image = [ # right now it supports only 1 image in a list + "https://img.freepik.com/free-photo/macro-picture-red-leaf-lights-against-black-background_181624-32636.jpg" + ] + + enhance_request = IImageInference( + positivePrompt="masterpiece quality, enhanced details, perfect composition, museum-worthy art", + model="civitai:4384@128713", + seedImage=extended_image_url, + referenceImages=reference_image, + width=1280, + height=640, + strength=0.3, # Light enhancement + numberResults=1, + steps=40, + CFGScale=8.0, + ) + + final_images: List[IImage] = await runware.imageInference( + requestImage=enhance_request + ) + final_image_url = final_images[0].imageURL + + print("\n=== Workflow Complete ===") + print(f"Original: {original_image}") + print(f"Styled: {styled_image_url}") + print(f"Extended: {extended_image_url}") + print(f"Final: {final_image_url}") + print("\nThe image has been transformed through multiple editing stages!") + + except RunwareError as e: + print(f"Error in comprehensive workflow: {e}") + finally: + await runware.disconnect() + + +async def main(): + """Run all image-to-image and inpainting examples.""" + print("=== Basic Image-to-Image Transformation ===") + await basic_image_to_image() + + print("\n=== Inpainting with Masks ===") + await inpainting_with_masks() + + print("\n=== Outpainting Extension ===") + await outpainting_extension() + + print("\n=== IP Adapter Style Transfer ===") + await ip_adapter_style_transfer() + + print("\n=== Reference-Guided Generation ===") + await reference_guided_generation() + + print("\n=== Embedding-Enhanced Generation ===") + await embedding_enhanced_generation() + + print("\n=== Comprehensive Editing Workflow ===") + await comprehensive_editing_workflow() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/ace++/local_local_1.py b/examples/ace++/local_local_1.py deleted file mode 100644 index 5daabec..0000000 --- a/examples/ace++/local_local_1.py +++ /dev/null @@ -1,36 +0,0 @@ -import os - -from runware import Runware, IImageInference, IAcePlusPlus - - -async def main() -> None: - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - mask_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/local/local_1_m.webp" - init_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/local/local_1.webp" - request_image = IImageInference( - positivePrompt="By referencing the mask, restore a partial image from the doodle {image} that aligns with the textual explanation: \"1 white old owl\".", - model="runware:102@1", - height=1024, - width=1024, - numberResults=1, - steps=28, - CFGScale=50.0, - acePlusPlus=IAcePlusPlus( - inputImages=[init_image], - inputMasks=[mask_image], - repaintingScale=0.5, - taskType="local_editing" - ), - ) - images = await runware.imageInference(requestImage=request_image) - for image in images: - print(f"Image URL: {image.imageURL}") - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/examples/ace++/logo_paste.py b/examples/ace++/logo_paste.py deleted file mode 100644 index 7138e78..0000000 --- a/examples/ace++/logo_paste.py +++ /dev/null @@ -1,40 +0,0 @@ -import os - -from runware import Runware, IImageInference, IAcePlusPlus - - -async def main() -> None: - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - reference_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/logo_paste/1_ref.png" - mask_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/logo_paste/1_1_m.png" - init_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/logo_paste/1_1_edit.png" - request_image = IImageInference( - positivePrompt="The logo is printed on the headphones.", - model="runware:102@1", - taskUUID="68020b8f-bbcf-4779-ba51-4f3bb00aef6a", - height=1024, - width=1024, - numberResults=1, - steps=28, - CFGScale=50.0, - referenceImages=[reference_image], - acePlusPlus=IAcePlusPlus( - inputImages=[init_image], - inputMasks=[mask_image], - repaintingScale=1.0, - taskType="subject" - ), - ) - images = await runware.imageInference(requestImage=request_image) - for image in images: - print(f"Image URL: {image.imageURL}") - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/examples/ace++/movie_poster_1.py b/examples/ace++/movie_poster_1.py deleted file mode 100644 index 6515264..0000000 --- a/examples/ace++/movie_poster_1.py +++ /dev/null @@ -1,40 +0,0 @@ -import os - -from runware import Runware, IImageInference, IAcePlusPlus - - -async def main() -> None: - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - reference_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/movie_poster/1_ref.png" - mask_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/movie_poster/1_1_m.png" - init_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/movie_poster/1_1_edit.png" - request_image = IImageInference( - positivePrompt="The man is facing the camera and is smiling.", - model="runware:102@1", - taskUUID="68020b8f-bbcf-4779-ba51-4f3bb00aef6a", - height=768, - width=1024, - numberResults=1, - steps=28, - CFGScale=50.0, - referenceImages=[reference_image], - acePlusPlus=IAcePlusPlus( - inputImages=[init_image], - inputMasks=[mask_image], - repaintingScale=1.0, - taskType="portrait" - ), - ) - images = await runware.imageInference(requestImage=request_image) - for image in images: - print(f"Image URL: {image.imageURL}") - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/examples/ace++/photo_editing_1.py b/examples/ace++/photo_editing_1.py deleted file mode 100644 index 3f33943..0000000 --- a/examples/ace++/photo_editing_1.py +++ /dev/null @@ -1,39 +0,0 @@ -import os - -from runware import Runware, IImageInference, IAcePlusPlus - - -async def main() -> None: - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - init_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/photo_editing/1_1_edit.png" - mask_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/photo_editing/1_1_m.png" - reference_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/application/photo_editing/1_ref.png" - request_image = IImageInference( - positivePrompt="The item is put on the ground.", - model="runware:102@1", - height=1024, - width=1024, - numberResults=1, - steps=28, - CFGScale=50.0, - referenceImages=[reference_image], - acePlusPlus=IAcePlusPlus( - inputImages=[init_image], - inputMasks=[mask_image], - repaintingScale=1.0, - taskType="subject" - ), - ) - images = await runware.imageInference(requestImage=request_image) - for image in images: - print(f"Image URL: {image.imageURL}") - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/examples/ace++/portrait_human_1.py b/examples/ace++/portrait_human_1.py deleted file mode 100644 index b216438..0000000 --- a/examples/ace++/portrait_human_1.py +++ /dev/null @@ -1,36 +0,0 @@ -import os - -from runware import Runware, IImageInference, IAcePlusPlus - - -async def main() -> None: - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - reference_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/portrait/human_1.jpg" - request_image = IImageInference( - positivePrompt="Maintain the facial features, A girl is wearing a neat police uniform and sporting a badge. She is smiling with a friendly and confident demeanor. The background is blurred, featuring a cartoon logo.", - model="runware:102@1", - height=1024, - width=1024, - seed=4194866942, - numberResults=1, - steps=28, - CFGScale=50.0, - referenceImages=[reference_image], - acePlusPlus=IAcePlusPlus( - repaintingScale=0.5, - taskType="portrait" - ), - ) - images = await runware.imageInference(requestImage=request_image) - for image in images: - print(f"Image URL: {image.imageURL}") - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/examples/ace++/subject_subject_1.py b/examples/ace++/subject_subject_1.py deleted file mode 100644 index 930e706..0000000 --- a/examples/ace++/subject_subject_1.py +++ /dev/null @@ -1,35 +0,0 @@ -import os - -from runware import Runware, IImageInference, IAcePlusPlus - - -async def main() -> None: - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - reference_image = "https://raw.githubusercontent.com/ali-vilab/ACE_plus/refs/heads/main/assets/samples/subject/subject_1.jpg" - request_image = IImageInference( - positivePrompt="Display the logo in a minimalist style printed in white on a matte black ceramic coffee mug, alongside a steaming cup of coffee on a cozy cafe table.", - model="runware:102@1", - height=1024, - width=1024, - seed=2935362780, - numberResults=1, - steps=28, - CFGScale=50.0, - referenceImages=[reference_image], - acePlusPlus=IAcePlusPlus( - repaintingScale=1, - taskType="subject" - ), - ) - images = await runware.imageInference(requestImage=request_image) - for image in images: - print(f"Image URL: {image.imageURL}") - - -if __name__ == "__main__": - import asyncio - asyncio.run(main()) diff --git a/examples/dalmatian.jpg b/examples/dalmatian.jpg deleted file mode 100644 index d07bd21..0000000 Binary files a/examples/dalmatian.jpg and /dev/null differ diff --git a/examples/image_background_removal_withModel.py b/examples/image_background_removal_withModel.py deleted file mode 100644 index 4c0f4ac..0000000 --- a/examples/image_background_removal_withModel.py +++ /dev/null @@ -1,38 +0,0 @@ -from runware import Runware, RunwareAPIError,IImage, IImageBackgroundRemoval -import asyncio -import os -from dotenv import load_dotenv - -load_dotenv(override=True) - - -async def main() -> None: - runware = Runware( - api_key=os.environ.get("RUNWARE_API_KEY"), - ) - await runware.connect() - - request_image = IImageBackgroundRemoval( - taskUUID="abcdbb9c-3bd3-4d75-9234-bffeef994772", - model="runware:110@1", - inputImage="https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/headphones.jpeg" - ) - - print(f"Payload: {request_image}") - try: - processed_images: List[IImage] = await runware.imageBackgroundRemoval( - removeImageBackgroundPayload=request_image - ) - except RunwareAPIError as e: - print(f"API Error: {e}") - print(f"Error Code: {e.code}") - except Exception as e: - print(f"Unexpected Error: {e}") - else: - print("Processed Image with the background removed:") - print(processed_images) - for image in processed_images: - print(image.imageURL) - - -asyncio.run(main()) \ No newline at end of file diff --git a/examples/image_background_removal_with_settings.py b/examples/image_background_removal_with_settings.py deleted file mode 100644 index 2b0a003..0000000 --- a/examples/image_background_removal_with_settings.py +++ /dev/null @@ -1,51 +0,0 @@ -from runware import Runware, RunwareAPIError,IImage, IImageBackgroundRemoval, IBackgroundRemovalSettings -import asyncio -import os -from dotenv import load_dotenv - -load_dotenv(override=True) - - -async def main() -> None: - runware = Runware( - api_key=os.environ.get("RUNWARE_API_KEY") - ) - await runware.connect() - background_removal_settings = IBackgroundRemovalSettings( - rgba=[255, 255, 255, 0], - alphaMatting=True, - postProcessMask=True, - returnOnlyMask=False, - alphaMattingErodeSize=10, - alphaMattingForegroundThreshold=240, - alphaMattingBackgroundThreshold=10 - ) - - request_image = IImageBackgroundRemoval( - taskUUID="abcdbb9c-3bd3-4d75-9234-bffeef994772", - inputImage="https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/headphones.jpeg", - settings=background_removal_settings, - outputType="URL", - outputFormat="PNG", - includeCost=True, - ) - - print(f"Payload: {request_image}") - try: - processed_images: List[IImage] = await runware.imageBackgroundRemoval( - removeImageBackgroundPayload=request_image - ) - except RunwareAPIError as e: - print(f"API Error: {e}") - print(f"Error Code: {e.code}") - except Exception as e: - print(f"Unexpected Error: {e}") - else: - print("Processed Image with the background removed:") - print(processed_images) - for image in processed_images: - print(image.imageURL) - - -asyncio.run(main()) - diff --git a/examples/image_caption.py b/examples/image_caption.py deleted file mode 100644 index 240d074..0000000 --- a/examples/image_caption.py +++ /dev/null @@ -1,46 +0,0 @@ -import asyncio -import os -import logging -from dotenv import load_dotenv -from runware import Runware, IImageToText, IImageCaption, RunwareAPIError - -# Load environment variables from .env file -load_dotenv() - -RUNWARE_API_KEY = os.environ.get("RUNWARE_API_KEY") - - -async def main() -> None: - # Create an instance of RunwareServer - runware = Runware(api_key=RUNWARE_API_KEY) - - # Connect to the Runware service - await runware.connect() - - # The image requires for the seed image. It can be the UUID of previously generated image or an a file image. - image_path = "retriever.jpg" - - # With only mandatory parameters - request_image_to_text_payload = IImageCaption(inputImage=image_path) - # With all parameters - request_image_to_text_payload = IImageCaption( - inputImage=image_path, - includeCost=True, - ) - - try: - image_to_text: IImageToText = await runware.imageCaption( - requestImageToText=request_image_to_text_payload - ) - except RunwareAPIError as e: - print(f"API Error: {e}") - print(f"Error Code: {e.code}") - except Exception as e: - print(f"Unexpected Error: {e}") - else: - print("Description of the image:") - print(image_to_text.text) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/image_upscale.py b/examples/image_upscale.py deleted file mode 100644 index b8ddf5a..0000000 --- a/examples/image_upscale.py +++ /dev/null @@ -1,52 +0,0 @@ -import asyncio -import os -import logging -from typing import List, Optional -from dotenv import load_dotenv -from runware import Runware, IImage, IImageUpscale, RunwareAPIError - - -# Load environment variables from .env file -load_dotenv() - -RUNWARE_API_KEY = os.environ.get("RUNWARE_API_KEY") - - -async def main() -> None: - # Initialize the Runware client - runware = Runware(api_key=os.environ.get("RUNWARE_API_KEY")) - # Specifies the input image to be processed https://docs.runware.ai/en/image-editing/upscaling#inputimage - inputImage = "dalmatian.jpg" - inputImage = "https://img.freepik.com/free-photo/macro-picture-red-leaf-lights-against-black-background_181624-32636.jpg" - upscale_factor = 4 - - # With only mandatory parameters - upscale_gan_payload = IImageUpscale( - inputImage=inputImage, upscaleFactor=upscale_factor - ) - - # With all parameters - upscale_gan_payload = IImageUpscale( - inputImage=inputImage, - upscaleFactor=upscale_factor, - outputType="URL", - outputFormat="PNG", - includeCost=True, - ) - try: - upscaled_images: List[IImage] = await runware.imageUpscale( - upscaleGanPayload=upscale_gan_payload - ) - except RunwareAPIError as e: - print(f"API Error: {e}") - print(f"Error Code: {e.code}") - except Exception as e: - print(f"Unexpected Error: {e}") - else: - print(f"Upscaled Images ({upscale_factor}x):") - for inputImage in upscaled_images: - print(inputImage.imageURL) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/parallel_image_inference.py b/examples/parallel_image_inference.py deleted file mode 100644 index 93987c7..0000000 --- a/examples/parallel_image_inference.py +++ /dev/null @@ -1,124 +0,0 @@ -import asyncio -import os -import logging -from typing import List, Optional -from dotenv import load_dotenv -from runware import Runware, IImage, IError, IImageInference, RunwareAPIError -from runware.types import ILora - -# Load environment variables from .env file -load_dotenv() - -RUNWARE_API_KEY = os.environ.get("RUNWARE_API_KEY") - - -# By providing the `onPartialImages` callback function, you can receive and handle the generated images incrementally, -# allowing for more responsive and interactive processing of the images. -# Note that the `onPartialImages` callback function is optional. If you don't provide it, the `requestImages` -# method will still return the complete list of generated images once the entire async request is finished. -def on_partial_images(images: List[IImage], error: Optional[IError]) -> None: - if error: - print(f"API Error: {error}") - else: - print(f"Received {len(images)} partial images") - for image in images: - print(f"Partial Image URL: {image.imageURL}") - # Process or save the partial image as needed - - -# This function provides a safe way to make image requests, handling potential exceptions -# without disrupting other concurrent requests. It's particularly useful when making multiple -# image requests simultaneously. -# -# Usage: -# - Use this function with asyncio.gather for parallel processing of multiple requests. -# - Check the return value: None indicates an error occurred (details will be printed), -# while a successful result will contain the list of generated images. -async def safe_request_images(runware: Runware, request_image: IImageInference): - try: - return await runware.imageInference(requestImage=request_image) - except RunwareAPIError as e: - print(f"API Error: {e}") - print(f"Error Code: {e.code}") - return None - except Exception as e: - print(f"Unexpected Error: {e}") - return None - - -async def main() -> None: - # Create an instance of RunwareServer - runware = Runware(api_key=RUNWARE_API_KEY) - - # Connect to the Runware service - await runware.connect() - - lora_1 = [ - ILora(model="civitai:58390@62833", weight=0.4), - ILora(model="civitai:42903@232848", weight=0.3), - ILora(model="civitai:42903@222732", weight=0.3), - ] - - request_image1 = IImageInference( - positivePrompt="a beautiful sunset over the mountains", - model="civitai:36520@76907", - numberResults=2, - negativePrompt="cloudy, rainy", - onPartialImages=on_partial_images, - height=512, - width=512, - outputFormat="PNG", - ) - - request_image2 = IImageInference( - positivePrompt="a cozy hut in the woods", - model="civitai:30240@102996", - numberResults=1, - negativePrompt="modern, city", - lora=lora_1, - height=1024, - width=1024, - outputType="base64Data", - ) - - request_image3 = IImageInference( - positivePrompt="a wood workshop with tools and sawdust on the floor", - model="civitai:4384@128713", - numberResults=3, - height=1024, - width=1024, - includeCost=True, - ) - - first_images_request, second_images_request, third_image_request = ( - await asyncio.gather( - safe_request_images(runware, request_image1), - safe_request_images(runware, request_image2), - safe_request_images(runware, request_image3), - return_exceptions=True, - ) - ) - if first_images_request: - print("\nFirst Image Request Results:") - for image in first_images_request: - print(f"Image URL: {image.imageURL}") - else: - print("First Image Request Failed") - - if second_images_request: - print("\nSecond Image Request Results:") - for image in second_images_request: - print(f"imageBase64Data: {image.imageBase64Data[:100]}...") - else: - print("Second Image Request Failed") - - if third_image_request: - print("\nThird Image Request Results:") - for image in third_image_request: - print(f"Image URL: {image.imageURL}") - else: - print("Third Image Request Failed") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/prompt_enhance.py b/examples/prompt_enhance.py deleted file mode 100644 index c679252..0000000 --- a/examples/prompt_enhance.py +++ /dev/null @@ -1,53 +0,0 @@ -import asyncio -import os -import logging -from typing import List, Optional -from dotenv import load_dotenv -from runware import Runware, IPromptEnhance, IEnhancedPrompt, RunwareAPIError - -load_dotenv() - -RUNWARE_API_KEY = os.environ.get("RUNWARE_API_KEY") - - -async def main() -> None: - # Create an instance of RunwareServer - runware = Runware(api_key=RUNWARE_API_KEY) - - # Connect to the Runware service - await runware.connect() - - prompt = "A beautiful sunset over the mountains" - print(f"Original Prompt: {prompt}") - - # With only mandatory parameters - prompt_enhancer = IPromptEnhance( - prompt=prompt, - promptVersions=3, - promptMaxLength=64, - ) - # With all parameters - prompt_enhancer = IPromptEnhance( - prompt=prompt, - promptVersions=3, - promptMaxLength=300, - includeCost=True, - ) - try: - enhanced_prompts: List[IEnhancedPrompt] = await runware.promptEnhance( - promptEnhancer=prompt_enhancer - ) - except RunwareAPIError as e: - print(f"API Error: {e}") - print(f"Error Code: {e.code}") - except Exception as e: - print(f"Unexpected Error: {e}") - else: - print("Enhanced Prompts:\n") - for enhanced_prompt in enhanced_prompts: - print(enhanced_prompt.text, "\n") - # print(enhanced_prompt.cost, "\n") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/retriever.jpg b/examples/retriever.jpg deleted file mode 100644 index 235c033..0000000 Binary files a/examples/retriever.jpg and /dev/null differ diff --git a/examples/video/kling.py b/examples/video/kling.py deleted file mode 100644 index 12fb93a..0000000 --- a/examples/video/kling.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio -import os - -from runware import Runware, IVideoInference, IFrameImage, IKlingAIProviderSettings, IKlingCameraControl, IKlingCameraConfig - - -async def main(): - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - request = IVideoInference( - positivePrompt="A majestic eagle soaring through mountain peaks at golden hour, cinematic view", - model="klingai:1@1", - width=1280, - height=720, - duration=5, - numberResults=1, - includeCost=True, - CFGScale=1, - ) - - videos = await runware.videoInference(requestVideo=request) - for video in videos: - print(f"Video URL: {video.videoURL}") - print(f"Cost: {video.cost}") - print(f"Seed: {video.seed}") - print(f"Status: {video.status}") - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/examples/video/minimax.py b/examples/video/minimax.py deleted file mode 100644 index 51243d4..0000000 --- a/examples/video/minimax.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -import os - -from runware import Runware, IVideoInference, IFrameImage, IMinimaxProviderSettings - - -async def main(): - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - request = IVideoInference( - positivePrompt="[Push in, Follow] A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. [Pan left] The street opens into a small plaza where street vendors sell steaming food under colorful awnings.", - model="minimax:1@1", - width=1366, # Comment this to use i2v - height=768, # Comment this to use i2v - duration=6, - numberResults=1, - seed=10, - includeCost=True, - # frameImages=[ # Uncomment this to use t2v - # IFrameImage( - # inputImage= "https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/video_inference/woman_city.png", - # ), - # ] - providerSettings=IMinimaxProviderSettings( - promptOptimizer=True - ) - ) - - videos = await runware.videoInference(requestVideo=request) - for video in videos: - print(f"Video URL: {video.videoURL}") - print(f"Cost: {video.cost}") - print(f"Seed: {video.seed}") - print(f"Status: {video.status}") - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/examples/video/pixverse.py b/examples/video/pixverse.py deleted file mode 100644 index 7642ef0..0000000 --- a/examples/video/pixverse.py +++ /dev/null @@ -1,49 +0,0 @@ -import asyncio -import os - -from runware import Runware, IVideoInference, IFrameImage, IPixverseProviderSettings - - -async def main(): - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - request = IVideoInference( - positivePrompt="realistic video, slow motion, cinematic, high quality, 4k, 60fps", - negativePrompt="blurry", - model="pixverse:1@1", - width=1280, - height=720, - duration=5, - fps=24, - numberResults=1, - seed=10, - includeCost=True, - frameImages=[ - IFrameImage( - inputImage="https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/man_beard.jpg", - ), - IFrameImage( - inputImage="https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/abraham_lincon.png", - ), - ], - providerSettings=IPixverseProviderSettings( - effect="boom drop", - style="anime", - motionMode="normal", - watermark=False, - ) - ) - - videos = await runware.videoInference(requestVideo=request) - for video in videos: - print(f"Video URL: {video.videoURL}") - print(f"Cost: {video.cost}") - print(f"Seed: {video.seed}") - print(f"Status: {video.status}") - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/examples/video/seedance.py b/examples/video/seedance.py deleted file mode 100644 index 93636eb..0000000 --- a/examples/video/seedance.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -import os - -from runware import Runware, IVideoInference, IFrameImage, IBytedanceProviderSettings - - -async def main(): - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - request = IVideoInference( - positivePrompt=" couple in formal evening attire is caught in heavy rain on their way home, holding a black umbrella. In the flat shot, the man is wearing a black suit and the woman is wearing a white long dress. They walk slowly in the rain, and the rain drips down the umbrella. The camera moves smoothly with their steps, showing their elegant posture in the rain.", - model="bytedance:1@1", - height=1504, # Comment this to use i2v - width=640, # Comment this to use i2v - duration=5, - numberResults=1, - seed=10, - includeCost=True, - # frameImages=[ # Uncomment this to use i2v - # IFrameImage( - # inputImage="https://raw.githubusercontent.com/adilentiq/test-images/refs/heads/main/common/background.jpg", - # ), - # ], - providerSettings=IBytedanceProviderSettings( - cameraFixed=False - ) - ) - - videos = await runware.videoInference(requestVideo=request) - for video in videos: - print(f"Video URL: {video.videoURL}") - print(f"Cost: {video.cost}") - print(f"Seed: {video.seed}") - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/examples/video/veo.py b/examples/video/veo.py deleted file mode 100644 index 341c327..0000000 --- a/examples/video/veo.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -import os - -from runware import Runware, IVideoInference, IGoogleProviderSettings, IFrameImage - - -async def main(): - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - request = IVideoInference( - positivePrompt="spinning galaxy", - model="google:3@0", - width=1280, - height=720, - numberResults=1, - seed=10, - includeCost=True, - frameImages=[ # Comment this to use t2v - IFrameImage( - inputImage="https://github.com/adilentiq/test-images/blob/main/common/image_15_mb.jpg?raw=true", - ), - ], - providerSettings=IGoogleProviderSettings( # Needs only for veo3 - generateAudio=True, - enhancePrompt=True - ) - ) - videos = await runware.videoInference(requestVideo=request) - for video in videos: - print(f"Video URL: {video.videoURL}") - print(f"Cost: {video.cost}") - print(f"Seed: {video.seed}") - print(f"Status: {video.status}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/video/vidu.py b/examples/video/vidu.py deleted file mode 100644 index 5673fb5..0000000 --- a/examples/video/vidu.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio -import os - -from runware import Runware, IVideoInference, IFrameImage, IViduProviderSettings - - -async def main(): - runware = Runware( - api_key=os.getenv("RUNWARE_API_KEY"), - ) - await runware.connect() - - request = IVideoInference( - positivePrompt="A red fox moves stealthily through autumn woods, hunting for prey.", - model="vidu:1@1", - width=1920, - height=1080, - duration=5, - numberResults=1, - seed=10, - includeCost=True, - providerSettings=IViduProviderSettings( - style="anime", - movementAmplitude="auto", - ) - ) - - videos = await runware.videoInference(requestVideo=request) - for video in videos: - print(f"Video URL: {video.videoURL}") - print(f"Cost: {video.cost}") - print(f"Seed: {video.seed}") - print(f"Status: {video.status}") - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f24aef6..1877795 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ aiofiles==23.2.1 python-dotenv==1.0.1 -websockets +websockets>=12.0,<16.0 \ No newline at end of file diff --git a/runware/__init__.py b/runware/__init__.py index bb009bd..d0b2555 100644 --- a/runware/__init__.py +++ b/runware/__init__.py @@ -1,9 +1,175 @@ -from .server import RunwareServer as Runware -from .types import * -from .utils import * -from .base import * -from .logging_config import * -from .async_retry import * +from .client import Runware, RunwareClient, RunwareServer +from .core import ( + Message, + MessageType, + OperationContext, + OperationStatus, + ProgressUpdate, +) +from .exceptions import ( + RunwareAuthenticationError, + RunwareConnectionError, + RunwareError, + RunwareOperationError, + RunwareParseError, + RunwareResourceError, + RunwareServerError, + RunwareTimeoutError, + RunwareValidationError, + RunwareAPIError, +) +from .logging_config import ( + configure_component_logging, + get_logger, + get_logging_stats, + setup_logging, +) +from .messaging import MessageRouter +from .operations import ( + BaseOperation, + ImageBackgroundRemovalOperation, + ImageCaptionOperation, + ImageInferenceOperation, + ImageUpscaleOperation, + OperationManager, + PhotoMakerOperation, + PromptEnhanceOperation, + VideoInferenceOperation, +) +from .types import ( + EControlMode, + EModelArchitecture, + Environment, + EPreProcessorGroup, + File, + IAcceleratorOptions, + IAcePlusPlus, + IBackgroundRemovalSettings, + IBytedanceProviderSettings, + IControlNetGeneral, + IEmbedding, + IEnhancedPrompt, + IFrameImage, + IGoogleProviderSettings, + IImage, + IImageBackgroundRemoval, + IImageCaption, + IImageInference, + IImageToText, + IImageUpscale, + IInstantID, + IIpAdapter, + IKlingAIProviderSettings, + IKlingCameraConfig, + IKlingCameraControl, + ILora, + IMinimaxProviderSettings, + IModelSearch, + IModelSearchResponse, + IOutpaint, + IPhotoMaker, + IPixverseProviderSettings, + IPromptEnhance, + IRefiner, + IUploadModelBaseType, + IUploadModelCheckPoint, + IUploadModelControlNet, + IUploadModelLora, + IUploadModelResponse, + IVideo, + IVideoInference, + IViduProviderSettings, + UploadImageType, +) -__all__ = ["Runware"] -__version__ = "0.4.16" +__version__ = "0.5.0" + +__all__ = [ + # Core client classes + "Runware", + "RunwareClient", + "RunwareServer", + # Core types and enums + "OperationStatus", + "OperationContext", + "ProgressUpdate", + "MessageType", + "Message", + # Logging utilities + "setup_logging", + "get_logger", + "configure_component_logging", + "get_logging_stats", + # Exception classes + "RunwareError", + "RunwareTimeoutError", + "RunwareConnectionError", + "RunwareOperationError", + "RunwareValidationError", + "RunwareAuthenticationError", + "RunwareServerError", + "RunwareResourceError", + "RunwareParseError", + "RunwareAPIError", + # Operation system + "BaseOperation", + "OperationManager", + "MessageRouter", + # Concrete operation classes + "ImageInferenceOperation", + "VideoInferenceOperation", + "ImageCaptionOperation", + "ImageBackgroundRemovalOperation", + "ImageUpscaleOperation", + "PromptEnhanceOperation", + "PhotoMakerOperation", + # Request/Response types + "IImageInference", + "IImage", + "IVideoInference", + "IVideo", + "IImageCaption", + "IImageToText", + "IImageBackgroundRemoval", + "IPromptEnhance", + "IEnhancedPrompt", + "IImageUpscale", + "IPhotoMaker", + "UploadImageType", + "File", + # Environment and configuration + "Environment", + # Model and search types + "IUploadModelBaseType", + "IUploadModelResponse", + "IModelSearch", + "IModelSearchResponse", + # Preprocessing and control + "EPreProcessorGroup", + "IBackgroundRemovalSettings", + "IFrameImage", + # Provider settings + "IGoogleProviderSettings", + "IKlingAIProviderSettings", + "IKlingCameraControl", + "IKlingCameraConfig", + "IMinimaxProviderSettings", + "IPixverseProviderSettings", + "IViduProviderSettings", + "IBytedanceProviderSettings", + # Advanced features + "IAcePlusPlus", + "ILora", + "IControlNetGeneral", + "IRefiner", + "IAcceleratorOptions", + "EControlMode", + "IUploadModelCheckPoint", + "IUploadModelLora", + "IUploadModelControlNet", + "EModelArchitecture", + "IOutpaint", + "IInstantID", + "IIpAdapter", + "IEmbedding", +] diff --git a/runware/async_retry.py b/runware/async_retry.py deleted file mode 100644 index 6accc5c..0000000 --- a/runware/async_retry.py +++ /dev/null @@ -1,96 +0,0 @@ -import asyncio - - -async def asyncRetry(apiCall, options=None): - """ - Retry an asynchronous API call multiple times with configurable options. - - :param apiCall: The asynchronous function to be retried. - :param options: An optional dictionary that allows you to configure the retry behavior. - It has the following properties: - - maxRetries: The maximum number of retries before giving up (default is 1). - - delayInSeconds: The delay in seconds between each retry attempt (default is 1). - - callback: A function that will be called after each failed attempt. - :return: The result of the successful API call. - - This function retries an asynchronous API call multiple times with configurable options. - It attempts to execute the `apiCall` and returns the result if successful. If the `apiCall` - raises an exception, it calls the `callback` function (if provided), introduces a delay - before the next retry attempt, and continues retrying until the maximum number of retries - is reached. If all retry attempts are exhausted and the `apiCall` still fails, it raises - the last encountered exception. - - Example: - async def myApiCall(): - # API call logic here - ... - - result = await asyncRetry(myApiCall, options={ - 'maxRetries': 3, - 'delayInSeconds': 1, - 'callback': lambda: print('Retry attempt failed') - }) - print(result) - """ - if options is None: - options = {} - delayInSeconds = options.get("delayInSeconds", 1) - callback = options.get("callback") - maxRetries = options.get("maxRetries", 1) - - for attempt in range(maxRetries): - try: - return await apiCall() - except Exception as error: - if callback: - callback() - if attempt < maxRetries - 1: - await asyncio.sleep(delayInSeconds) - else: - raise error - - -async def asyncRetryGather(apiCalls, options=None): - """ - Retry multiple asynchronous API calls concurrently with configurable options. - - :param apiCalls: A list of asynchronous functions to be retried. - :param options: An optional dictionary that allows you to configure the retry behavior. - It has the following properties: - - maxRetries: The maximum number of retries before giving up (default is 1). - - delayInSeconds: The delay in seconds between each retry attempt (default is 1). - - callback: A function that will be called after each failed attempt. - :return: A list of results from the successful API calls. - - This function retries multiple asynchronous API calls concurrently with configurable options. - It creates tasks for each `apiCall` and executes them concurrently using `asyncio.gather()`. - Each task represents the execution of `asyncRetry` for a single `apiCall`. The results of - each successful API call are returned as a list in the same order as the input `apiCalls`. - - Example: - async def myApiCall1(): - # API call logic here - ... - - async def myApiCall2(): - # API call logic here - ... - - results = await asyncRetryGather([myApiCall1, myApiCall2], options={ - 'maxRetries': 3, - 'delayInSeconds': 1, - 'callback': lambda: print('Retry attempt failed') - }) - print(results) - """ - if options is None: - options = {} - delayInSeconds = options.get("delayInSeconds", 1) - callback = options.get("callback") - maxRetries = options.get("maxRetries", 1) - - tasks = [] - for apiCall in apiCalls: - task = asyncio.create_task(asyncRetry(apiCall, options)) - tasks.append(task) - return await asyncio.gather(*tasks) diff --git a/runware/base.py b/runware/base.py deleted file mode 100644 index e0e8fab..0000000 --- a/runware/base.py +++ /dev/null @@ -1,1564 +0,0 @@ -import inspect -import logging -import os -import re -import uuid -from asyncio import gather -from dataclasses import asdict -from typing import List, Optional, Union, Callable, Any, Dict - -from websockets.protocol import State - -from .async_retry import asyncRetry -from .types import ( - Environment, - IImageInference, - IPhotoMaker, - IImageCaption, - IImageToText, - IImageBackgroundRemoval, - IPromptEnhance, - IEnhancedPrompt, - IImageUpscale, - IUploadModelBaseType, - IUploadModelResponse, - ReconnectingWebsocketProps, - UploadImageType, - EPreProcessorGroup, - File, - ETaskType, - IModelSearch, - IModelSearchResponse, - IControlNet, - IVideo, - IVideoInference, - IGoogleProviderSettings, - IKlingAIProviderSettings, - IFrameImage, -) -from .types import IImage, IError, SdkType, ListenerType -from .utils import ( - BASE_RUNWARE_URLS, - getUUID, - fileToBase64, - createImageFromResponse, - createImageToTextFromResponse, - createEnhancedPromptsFromResponse, - instantiateDataclassList, - RunwareAPIError, - RunwareError, - instantiateDataclass, - TIMEOUT_DURATION, - accessDeepObject, - getIntervalWithPromise, - removeListener, - LISTEN_TO_IMAGES_KEY, - isLocalFile, - process_image, delay, -) - -# Configure logging -# configure_logging(log_level=logging.CRITICAL) - -logger = logging.getLogger(__name__) -MAX_POLLS_VIDEO_GENERATION = int(os.environ.get("RUNWARE_MAX_POLLS_VIDEO_GENERATION", 480)) - - -class RunwareBase: - def __init__( - self, - api_key: str, - url: str = BASE_RUNWARE_URLS[Environment.PRODUCTION], - timeout: int = TIMEOUT_DURATION, - ): - if timeout <= 0: - raise ValueError("Timeout must be greater than 0 milliseconds") - - self._ws: Optional[ReconnectingWebsocketProps] = None - self._listeners: List[ListenerType] = [] - self._apiKey: str = api_key - self._url: Optional[str] = url - self._timeout: int = timeout - self._globalMessages: Dict[str, Any] = {} - self._globalImages: List[IImage] = [] - self._globalError: Optional[IError] = None - self._connectionSessionUUID: Optional[str] = None - self._invalidAPIkey: Optional[str] = None - self._sdkType: SdkType = SdkType.SERVER - - def isWebsocketReadyState(self) -> bool: - return self._ws and self._ws.state is State.OPEN - - def isAuthenticated(self): - return self._connectionSessionUUID is not None - - def addListener( - self, - lis: Callable[[Any], Any], - check: Callable[[Any], Any], - groupKey: Optional[str] = None, - ) -> Dict[str, Callable[[], None]]: - # Get the current frame - current_frame = inspect.currentframe() - - # Get the caller's frame - caller_frame = current_frame.f_back - - # Get the caller's function name - caller_name = caller_frame.f_code.co_name - - # Get the caller's line number - caller_line_number = caller_frame.f_lineno - - debug_message = f"Listener {self.addListener.__name__} created by {caller_name} at line {caller_line_number} with listener: {lis} and check: {check}" - # logger.debug(debug_message) - - if not lis or not check: - raise ValueError("Listener and check functions are required") - - def listener(msg: Any) -> None: - if not lis or not check: - raise ValueError("Listener and check functions are required") - if msg.get("error"): - lis(msg) - elif check(msg): - lis(msg) - - groupListener: ListenerType = ListenerType( - key=getUUID(), - listener=listener, - group_key=groupKey, - debug_message=debug_message, - ) - self._listeners.append(groupListener) - - def destroy() -> None: - self._listeners = removeListener(self._listeners, groupListener) - - return {"destroy": destroy} - - def handle_connection_response(self, m): - if m.get("error"): - if m["errorId"] == 19: - self._invalidAPIkey = "Invalid API key" - else: - self._invalidAPIkey = "Error connection" - return - self._connectionSessionUUID = m.get("newConnectionSessionUUID", {}).get( - "connectionSessionUUID" - ) - self._invalidAPIkey = None - - async def photoMaker(self, requestPhotoMaker: IPhotoMaker): - retry_count = 0 - - try: - await self.ensureConnection() - - task_uuid = requestPhotoMaker.taskUUID or getUUID() - requestPhotoMaker.taskUUID = task_uuid - - for i, image in enumerate(requestPhotoMaker.inputImages): - if isLocalFile(image) and not str(image).startswith("http"): - requestPhotoMaker.inputImages[i] = await fileToBase64(image) - - prompt = f"{requestPhotoMaker.positivePrompt}".strip() - request_object = { - "taskUUID": requestPhotoMaker.taskUUID, - "model": requestPhotoMaker.model, - "positivePrompt": prompt, - "numberResults": requestPhotoMaker.numberResults, - "height": requestPhotoMaker.height, - "width": requestPhotoMaker.width, - "taskType": ETaskType.PHOTO_MAKER.value, - "style": requestPhotoMaker.style, - "strength": requestPhotoMaker.strength, - **( - {"inputImages": requestPhotoMaker.inputImages} - if requestPhotoMaker.inputImages - else {} - ), - **( - {"steps": requestPhotoMaker.steps} - if requestPhotoMaker.steps - else {} - ), - } - - if requestPhotoMaker.outputFormat is not None: - request_object["outputFormat"] = requestPhotoMaker.outputFormat - if requestPhotoMaker.includeCost: - request_object["includeCost"] = requestPhotoMaker.includeCost - if requestPhotoMaker.outputType: - request_object["outputType"] = requestPhotoMaker.outputType - - await self.send([request_object]) - - lis = self.globalListener( - taskUUID=task_uuid, - ) - - numberOfResults = requestPhotoMaker.numberResults - - def check(resolve: callable, reject: callable, *args: Any) -> bool: - photo_maker_list = self._globalMessages.get(task_uuid, []) - unique_results = {} - - for made_photo in photo_maker_list: - if made_photo.get("code"): - raise RunwareAPIError(made_photo) - - if made_photo.get("taskType") != "photoMaker": - continue - - image_uuid = made_photo.get("imageUUID") - if image_uuid not in unique_results: - unique_results[image_uuid] = made_photo - - if 0 < numberOfResults <= len(unique_results): - del self._globalMessages[task_uuid] - resolve(list(unique_results.values())) - return True - - return False - - response = await getIntervalWithPromise(check, debugKey="photo-maker") - - lis["destroy"]() - - if "code" in response: - # This indicates an error response - raise RunwareAPIError(response) - - if response: - if not isinstance(response, list): - response = [response] - - return instantiateDataclassList(IImage, response) - - except Exception as e: - if retry_count >= 2: - self.logger.error(f"Error in photoMaker request:", exc_info=e) - raise RunwareAPIError({"message": f"PhotoMaker failed after retries: {str(e)}"}) - else: - raise e - - async def imageInference( - self, requestImage: IImageInference - ) -> Union[List[IImage], None]: - let_lis: Optional[Any] = None - request_object: Optional[Dict[str, Any]] = None - task_uuids: List[str] = [] - retry_count = 0 - try: - await self.ensureConnection() - control_net_data: List[IControlNet] = [] - requestImage.maskImage = await process_image(requestImage.maskImage) - requestImage.seedImage = await process_image(requestImage.seedImage) - if requestImage.referenceImages: - requestImage.referenceImages = await process_image( - requestImage.referenceImages - ) - if requestImage.controlNet: - for control_data in requestImage.controlNet: - image_uploaded = await self.uploadImage(control_data.guideImage) - if not image_uploaded: - return [] - if hasattr(control_data, "preprocessor"): - control_data.preprocessor = control_data.preprocessor.value - control_data.guideImage = image_uploaded.imageUUID - control_net_data.append(control_data) - prompt = f"{requestImage.positivePrompt}".strip() - - control_net_data_dicts = [asdict(item) for item in control_net_data] - - instant_id_data = {} - if requestImage.instantID: - instant_id_data = { - k: v - for k, v in vars(requestImage.instantID).items() - if v is not None - } - - if "inputImage" in instant_id_data: - instant_id_data["inputImage"] = await process_image( - instant_id_data["inputImage"] - ) - - if "poseImage" in instant_id_data: - instant_id_data["poseImage"] = await process_image( - instant_id_data["poseImage"] - ) - - ip_adapters_data = [] - if requestImage.ipAdapters: - for ip_adapter in requestImage.ipAdapters: - ip_adapter_data = { - k: v for k, v in vars(ip_adapter).items() if v is not None - } - if "guideImage" in ip_adapter_data: - ip_adapter_data["guideImage"] = await process_image( - ip_adapter_data["guideImage"] - ) - - ip_adapters_data.append(ip_adapter_data) - - ace_plus_plus_data = {} - if requestImage.acePlusPlus: - ace_plus_plus_data = { - "inputImages": [], - "repaintingScale": requestImage.acePlusPlus.repaintingScale, - "type": requestImage.acePlusPlus.taskType, - } - if requestImage.acePlusPlus.inputImages: - ace_plus_plus_data["inputImages"] = await process_image( - requestImage.acePlusPlus.inputImages - ) - if requestImage.acePlusPlus.inputMasks: - ace_plus_plus_data["inputMasks"] = await process_image( - requestImage.acePlusPlus.inputMasks - ) - - request_object = { - "offset": 0, - "taskUUID": requestImage.taskUUID, - "modelId": requestImage.model, - "positivePrompt": prompt, - "numberResults": requestImage.numberResults, - "height": requestImage.height, - "width": requestImage.width, - "taskType": ETaskType.IMAGE_INFERENCE.value, - **({"steps": requestImage.steps} if requestImage.steps else {}), - **( - {"controlNet": control_net_data_dicts} - if control_net_data_dicts - else {} - ), - **( - { - "lora": [ - {"model": lora.model, "weight": lora.weight} - for lora in requestImage.lora - ] - } - if requestImage.lora - else {} - ), - **( - { - "lycoris": [ - {"model": lycoris.model, "weight": lycoris.weight} - for lycoris in requestImage.lycoris - ] - } - if requestImage.lycoris - else {} - ), - **( - { - "embeddings": [ - {"model": embedding.model} - for embedding in requestImage.embeddings - ] - } - if requestImage.embeddings - else {} - ), - **({"seed": requestImage.seed} if requestImage.seed else {}), - **( - { - "refiner": { - "model": requestImage.refiner.model, - **( - {"startStep": requestImage.refiner.startStep} - if requestImage.refiner.startStep is not None - else {} - ), - **( - { - "startStepPercentage": requestImage.refiner.startStepPercentage - } - if requestImage.refiner.startStepPercentage is not None - else {} - ), - } - } - if requestImage.refiner - else {} - ), - **({"instantID": instant_id_data} if instant_id_data else {}), - **( - { - "outpaint": { - k: v - for k, v in vars(requestImage.outpaint).items() - if v is not None - } - } - if requestImage.outpaint - else {} - ), - **({"ipAdapters": ip_adapters_data} if ip_adapters_data else {}), - **({"acePlusPlus": ace_plus_plus_data} if ace_plus_plus_data else {}), - } - - # Add optional parameters if they are provided - if requestImage.outputType is not None: - request_object["outputType"] = requestImage.outputType - if requestImage.outputFormat is not None: - request_object["outputFormat"] = requestImage.outputFormat - if requestImage.includeCost: - request_object["includeCost"] = requestImage.includeCost - if requestImage.checkNsfw: - request_object["checkNSFW"] = requestImage.checkNsfw - - if requestImage.negativePrompt: - request_object["negativePrompt"] = requestImage.negativePrompt - if requestImage.CFGScale: - request_object["CFGScale"] = requestImage.CFGScale - if requestImage.seedImage: - request_object["seedImage"] = requestImage.seedImage - if requestImage.acceleratorOptions: - pipeline_options_dict = { - k: v - for k, v in vars(requestImage.acceleratorOptions).items() - if v is not None - } - request_object.update({"acceleratorOptions": pipeline_options_dict}) - if requestImage.advancedFeatures: - pipeline_options_dict = { - k: v.__dict__ - for k, v in vars(requestImage.advancedFeatures).items() - if v is not None - } - request_object.update({"advancedFeatures": pipeline_options_dict}) - if requestImage.maskImage: - request_object["maskImage"] = requestImage.maskImage - if requestImage.referenceImages: - request_object["referenceImages"] = requestImage.referenceImages - if requestImage.strength: - request_object["strength"] = requestImage.strength - if requestImage.scheduler: - request_object["scheduler"] = requestImage.scheduler - if requestImage.vae: - request_object["vae"] = requestImage.vae - if requestImage.promptWeighting: - request_object["promptWeighting"] = requestImage.promptWeighting - if requestImage.maskMargin: - request_object["maskMargin"] = requestImage.maskMargin - if hasattr(requestImage, "extraArgs"): - # if extraArgs is present, and a dictionary, we will add its attributes to the request. - # these may contain options used for public beta testing. - if isinstance(requestImage.extraArgs, dict): - request_object.update(requestImage.extraArgs) - - if requestImage.outputQuality: - request_object["outputQuality"] = requestImage.outputQuality - return await asyncRetry( - lambda: self._requestImages( - request_object=request_object, - task_uuids=task_uuids, - let_lis=let_lis, - retry_count=retry_count, - number_of_images=requestImage.numberResults, - on_partial_images=requestImage.onPartialImages, - ) - ) - except Exception as e: - if retry_count >= 2: - self.logger.error(f"Error in requestImages:", exc_info=e) - raise RunwareAPIError({"message": f"Image inference failed after retries: {str(e)}"}) - else: - raise e - - async def _requestImages( - self, - request_object: Dict[str, Any], - task_uuids: List[str], - let_lis: Optional[Any], - retry_count: int, - number_of_images: int, - on_partial_images: Optional[Callable[[List[IImage], Optional[IError]], None]], - ) -> List[IImage]: - retry_count += 1 - if let_lis: - let_lis["destroy"]() - images_with_similar_task = [ - img for img in self._globalImages if img.get("taskUUID") in task_uuids - ] - - task_uuid = request_object.get("taskUUID") - if task_uuid is None: - task_uuid = getUUID() - - task_uuids.append(task_uuid) - - image_remaining = number_of_images - len(images_with_similar_task) - new_request_object = { - "newTask": { - **request_object, - "taskUUID": task_uuid, - "numberResults": image_remaining, - } - } - await self.send(new_request_object) - - let_lis = await self.listenToImages( - onPartialImages=on_partial_images, - taskUUID=task_uuid, - groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES, - ) - images = await self.getSimililarImage( - taskUUID=task_uuids, - numberOfImages=number_of_images, - shouldThrowError=True, - lis=let_lis, - ) - - let_lis["destroy"]() - # TODO: NameError("name 'image_path' is not defined"). I think I remove the images when I have onPartialImages - if images: - if "code" in images: - # This indicates an error response - raise RunwareAPIError(images) - - return instantiateDataclassList(IImage, images) - - # return images - - async def imageCaption(self, requestImageToText: IImageCaption) -> IImageToText: - try: - await self.ensureConnection() - return await asyncRetry( - lambda: self._requestImageToText(requestImageToText) - ) - except Exception as e: - raise e - - async def _requestImageToText( - self, requestImageToText: IImageCaption - ) -> IImageToText: - inputImage = requestImageToText.inputImage - - image_uploaded = await self.uploadImage(inputImage) - - if not image_uploaded or not image_uploaded.imageUUID: - return None - - taskUUID = getUUID() - - # Create a dictionary with mandatory parameters - task_params = { - "taskType": ETaskType.IMAGE_CAPTION.value, - "taskUUID": taskUUID, - "inputImage": image_uploaded.imageUUID, - } - - # Add optional parameters if they are provided - if requestImageToText.includeCost: - task_params["includeCost"] = requestImageToText.includeCost - - # Send the task with all applicable parameters - await self.send([task_params]) - - lis = self.globalListener( - taskUUID=taskUUID, - ) - - def check(resolve: callable, reject: callable, *args: Any) -> bool: - response = self._globalMessages.get(taskUUID) - # TODO: Check why I need a conversion here? - if response: - image_to_text = response[0] - else: - image_to_text = response - if image_to_text and image_to_text.get("error"): - reject(image_to_text) - return True - - if image_to_text: - del self._globalMessages[taskUUID] - resolve(image_to_text) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="image-to-text", timeOutDuration=self._timeout - ) - - lis["destroy"]() - - if "code" in response: - # This indicates an error response - raise RunwareAPIError(response) - - if response: - return createImageToTextFromResponse(response) - else: - return None - - async def imageBackgroundRemoval( - self, removeImageBackgroundPayload: IImageBackgroundRemoval - ) -> List[IImage]: - try: - await self.ensureConnection() - return await asyncRetry( - lambda: self._removeImageBackground(removeImageBackgroundPayload) - ) - except Exception as e: - raise e - - async def _removeImageBackground( - self, removeImageBackgroundPayload: IImageBackgroundRemoval - ) -> List[IImage]: - inputImage = removeImageBackgroundPayload.inputImage - - image_uploaded = await self.uploadImage(inputImage) - - if not image_uploaded or not image_uploaded.imageUUID: - return [] - if removeImageBackgroundPayload.taskUUID is not None: - taskUUID = removeImageBackgroundPayload.taskUUID - else: - taskUUID = getUUID() - - # Create a dictionary with mandatory parameters - task_params = { - "taskType": ETaskType.IMAGE_BACKGROUND_REMOVAL.value, - "taskUUID": taskUUID, - "inputImage": image_uploaded.imageUUID, - } - - # Add optional parameters if they are provided - if removeImageBackgroundPayload.outputType is not None: - task_params["outputType"] = removeImageBackgroundPayload.outputType - if removeImageBackgroundPayload.outputFormat is not None: - task_params["outputFormat"] = removeImageBackgroundPayload.outputFormat - if removeImageBackgroundPayload.includeCost: - task_params["includeCost"] = removeImageBackgroundPayload.includeCost - if removeImageBackgroundPayload.model: - task_params["model"] = removeImageBackgroundPayload.model - if removeImageBackgroundPayload.outputQuality: - task_params["outputQuality"] = removeImageBackgroundPayload.outputQuality - - # Handle settings if provided - convert dataclass to dictionary and add non-None values - if removeImageBackgroundPayload.settings: - settings_dict = { - k: v - for k, v in vars(removeImageBackgroundPayload.settings).items() - if v is not None - } - task_params.update(settings_dict) - - # Send the task with all applicable parameters - await self.send([task_params]) - - lis = self.globalListener( - taskUUID=taskUUID, - ) - - def check(resolve: callable, reject: callable, *args: Any) -> bool: - response = self._globalMessages.get(taskUUID) - # TODO: Check why I need a conversion here? - if response: - new_remove_background = response[0] - else: - new_remove_background = response - if new_remove_background and new_remove_background.get("error"): - reject(new_remove_background) - return True - - if new_remove_background: - del self._globalMessages[taskUUID] - resolve(new_remove_background) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="remove-image-background", timeOutDuration=self._timeout - ) - - lis["destroy"]() - - if "code" in response: - # This indicates an error response - raise RunwareAPIError(response) - - image = createImageFromResponse(response) - image_list: List[IImage] = [image] - - return image_list - - async def imageUpscale(self, upscaleGanPayload: IImageUpscale) -> List[IImage]: - try: - await self.ensureConnection() - return await asyncRetry(lambda: self._upscaleGan(upscaleGanPayload)) - except Exception as e: - raise e - - async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> List[IImage]: - inputImage = upscaleGanPayload.inputImage - upscaleFactor = upscaleGanPayload.upscaleFactor - - image_uploaded = await self.uploadImage(inputImage) - - if not image_uploaded or not image_uploaded.imageUUID: - return [] - - taskUUID = getUUID() - - # Create a dictionary with mandatory parameters - task_params = { - "taskType": ETaskType.IMAGE_UPSCALE.value, - "taskUUID": taskUUID, - "inputImage": image_uploaded.imageUUID, - "upscaleFactor": upscaleGanPayload.upscaleFactor, - } - - # Add optional parameters if they are provided - if upscaleGanPayload.outputType is not None: - task_params["outputType"] = upscaleGanPayload.outputType - if upscaleGanPayload.outputFormat is not None: - task_params["outputFormat"] = upscaleGanPayload.outputFormat - if upscaleGanPayload.includeCost: - task_params["includeCost"] = upscaleGanPayload.includeCost - - # Send the task with all applicable parameters - await self.send([task_params]) - - lis = self.globalListener( - taskUUID=taskUUID, - ) - - def check(resolve: callable, reject: callable, *args: Any) -> bool: - response = self._globalMessages.get(taskUUID) - # TODO: Check why I need a conversion here? - if response: - upscaled_image = response[0] - else: - upscaled_image = response - if upscaled_image and upscaled_image.get("error"): - reject(upscaled_image) - return True - - if upscaled_image: - del self._globalMessages[taskUUID] - resolve(upscaled_image) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="upscale-gan", timeOutDuration=self._timeout - ) - - lis["destroy"]() - - if "code" in response: - # This indicates an error response - raise RunwareAPIError(response) - - image = createImageFromResponse(response) - # TODO: The respones has an upscaleImageUUID field, should I return it as well? - image_list: List[IImage] = [image] - return image_list - - async def promptEnhance( - self, promptEnhancer: IPromptEnhance - ) -> List[IEnhancedPrompt]: - """ - Enhance the given prompt by generating multiple versions of it. - - :param promptEnhancer: An IPromptEnhancer object containing the prompt details. - :return: A list of IEnhancedPrompt objects representing the enhanced versions of the prompt. - :raises: Any error that occurs during the enhancement process. - """ - try: - await self.ensureConnection() - return await asyncRetry(lambda: self._enhancePrompt(promptEnhancer)) - except Exception as e: - raise e - - async def _enhancePrompt( - self, promptEnhancer: IPromptEnhance - ) -> List[IEnhancedPrompt]: - """ - Internal method to perform the actual prompt enhancement. - - :param promptEnhancer: An IPromptEnhancer object containing the prompt details. - :return: A list of IEnhancedPrompt objects representing the enhanced versions of the prompt. - """ - prompt = promptEnhancer.prompt - promptMaxLength = getattr(promptEnhancer, "promptMaxLength", 380) - - promptVersions = promptEnhancer.promptVersions or 1 - - taskUUID = getUUID() - - # Create a dictionary with mandatory parameters - task_params = { - "taskType": ETaskType.PROMPT_ENHANCE.value, - "taskUUID": taskUUID, - "prompt": prompt, - "promptMaxLength": promptMaxLength, - "promptVersions": promptVersions, - } - - # Add optional parameters if they are provided - if promptEnhancer.includeCost: - task_params["includeCost"] = promptEnhancer.includeCost - - # Send the task with all applicable parameters - await self.send([task_params]) - - lis = self.globalListener( - taskUUID=taskUUID, - ) - - def check(resolve: Any, reject: Any, *args: Any) -> bool: - response = self._globalMessages.get(taskUUID) - if isinstance(response, dict) and response.get("error"): - reject(response) - return True - # if response and len(response) >= promptVersions: - if response: - del self._globalMessages[taskUUID] - resolve(response) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="enhance-prompt", timeOutDuration=self._timeout - ) - - lis["destroy"]() - - if "code" in response[0]: - # This indicates an error response - raise RunwareAPIError(response[0]) - - # Transform the response to a list of IEnhancedPrompt objects - enhanced_prompts = createEnhancedPromptsFromResponse(response) - - return list(set(enhanced_prompts)) - - async def uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]: - try: - await self.ensureConnection() - return await asyncRetry(lambda: self._uploadImage(file)) - except Exception as e: - raise e - - async def _uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]: - task_uuid = getUUID() - local_file = True - if isinstance(file, str): - if os.path.exists(file): - local_file = True - else: - local_file = isLocalFile(file) - - # Check if it's a base64 string (with or without data URI prefix) - if file.startswith("data:") or re.match( - r"^[A-Za-z0-9+/]+={0,2}$", file - ): - # Assume it's a base64 string (with or without data URI prefix) - local_file = False - if not local_file: - return UploadImageType( - imageUUID=file, - imageURL=file, - taskUUID=task_uuid, - ) - - file = await fileToBase64(file) - - await self.send( - [ - { - "taskType": ETaskType.IMAGE_UPLOAD.value, - "taskUUID": task_uuid, - "image": file, - } - ] - ) - - lis = self.globalListener(taskUUID=task_uuid) - - def check(resolve: callable, reject: callable, *args: Any) -> bool: - uploaded_image_list = self._globalMessages.get(task_uuid) - # TODO: Update to support multiple images - uploaded_image = uploaded_image_list[0] if uploaded_image_list else None - - if uploaded_image and uploaded_image.get("error"): - reject(uploaded_image) - return True - - if uploaded_image: - del self._globalMessages[task_uuid] - resolve(uploaded_image) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="upload-image", timeOutDuration=self._timeout - ) - - lis["destroy"]() - - if "code" in response: - # This indicates an error response - raise RunwareAPIError(response) - - if response: - image = UploadImageType( - imageUUID=response["imageUUID"], - imageURL=response["imageURL"], - taskUUID=response["taskUUID"], - ) - else: - image = None - return image - - async def uploadUnprocessedImage( - self, - file: Union[File, str], - preProcessorType: EPreProcessorGroup, - width: int = None, - height: int = None, - lowThresholdCanny: int = None, - highThresholdCanny: int = None, - includeHandsAndFaceOpenPose: bool = True, - ) -> Optional[UploadImageType]: - # Create a dummy UploadImageType object - uploaded_unprocessed_image = UploadImageType( - imageUUID=str(uuid.uuid4()), - imageURL="https://example.com/uploaded_unprocessed_image.jpg", - taskUUID=str(uuid.uuid4()), - ) - - return uploaded_unprocessed_image - - async def listenToImages( - self, - onPartialImages: Optional[Callable[[List[IImage], Optional[IError]], None]], - taskUUID: str, - groupKey: LISTEN_TO_IMAGES_KEY, - ) -> Dict[str, Callable[[], None]]: - """ - Set up a listener to receive partial image updates for a specific task. - - :param onPartialImages: A callback function to be invoked with the filtered images and any error. - :param taskUUID: The unique identifier of the task to filter images for. - :param groupKey: The group key to categorize the listener. - :return: A dictionary containing a 'destroy' function to remove the listener. - """ - logger.debug("Setting up images listener for taskUUID: %s", taskUUID) - - def listen_to_images_lis(m: Dict[str, Any]) -> None: - # Handle successful image generation - if isinstance(m.get("data"), list): - images = [ - img - for img in m["data"] - if img.get("taskType") == "imageInference" - and img.get("taskUUID") == taskUUID - ] - - if images: - self._globalImages.extend(images) - try: - partial_images = instantiateDataclassList(IImage, images) - if onPartialImages: - onPartialImages( - partial_images, None - ) # No error in this case - except Exception as e: - print( - f"Error occurred in user on_partial_images callback function: {e}" - ) - - # Handle error messages - elif isinstance(m.get("errors"), list): - errors = [ - error for error in m["errors"] if error.get("taskUUID") == taskUUID - ] - if errors: - error = IError( - error=True, # Since this is an error message, we set this to True - error_message=errors[0].get("message", "Unknown error"), - task_uuid=errors[0].get("taskUUID", ""), - error_code=errors[0].get("code"), - error_type=errors[0].get("type"), - parameter=errors[0].get("parameter"), - documentation=errors[0].get("documentation"), - ) - self._globalError = ( - error # Store the first error related to this task - ) - if onPartialImages: - onPartialImages( - [], self._globalError - ) # Empty list for images, pass the error - - def listen_to_images_check(m): - logger.debug("Images check message: %s", m) - # Check for successful image inference messages - image_inference_check = isinstance(m.get("data"), list) and any( - item.get("taskType") == "imageInference" for item in m["data"] - ) - # Check for error messages with matching taskUUID - error_check = isinstance(m.get("errors"), list) and any( - error.get("taskUUID") == taskUUID for error in m["errors"] - ) - error_code_check = ( - True - if any([error.get("code") for error in m.get("errors", [])]) - else False - ) - if error_code_check: - self._globalError = IError( - error=True, - error_message=f"Error in image inference: {m.get('errors')}", - task_uuid=taskUUID, - ) - - response = image_inference_check or error_check - return response - - temp_listener = self.addListener( - check=listen_to_images_check, lis=listen_to_images_lis, groupKey=groupKey - ) - - logger.debug("listenToImages :: Temp listener: %s", temp_listener) - - return temp_listener - - def globalListener(self, taskUUID: str) -> Dict[str, Callable[[], None]]: - """ - Set up a global listener to capture specific messages based on the provided taskUUID. - - :param taskUUID: The unique identifier of the task associated with the listener. - :return: A dictionary containing a 'destroy' function to remove the listener. - """ - logger.debug("Setting up global listener for taskUUID: %s", taskUUID) - - def global_lis(m: Dict[str, Any]) -> None: - logger.debug("Global listener message: %s", m) - logger.debug("Global listener taskUUID: %s", taskUUID) - # logger.debug("Global listener taskKey: %s", taskKey) - - if m.get("error"): - self._globalMessages[taskUUID] = m - return - - value = accessDeepObject( - taskUUID, m - ) # I think this is the taskType now, and it returns the content of 'data' - - if isinstance(value, list): - for v in value: - self._globalMessages[v["taskUUID"]] = self._globalMessages.get( - v["taskUUID"], [] - ) + [v] - logger.debug("Global messages v: %s", v) - logger.debug( - "self._globalMessages[v[taskUUID]]: %s", - self._globalMessages[v["taskUUID"]], - ) - else: - self._globalMessages[value["taskUUID"]] = value - - def global_check(m): - logger.debug("Global check message: %s", m) - return accessDeepObject(taskUUID, m) - - logger.debug("Global Listener taskUUID: %s", taskUUID) - - temp_listener = self.addListener(check=global_check, lis=global_lis) - logger.debug("globalListener :: Temp listener: %s", temp_listener) - - return temp_listener - - def handleIncompleteImages( - self, taskUUIDs: List[str], error: Any - ) -> Optional[List[IImage]]: - """ - Handle scenarios where the requested number of images is not fully received. - - :param taskUUIDs: A list of task UUIDs to filter the images. - :param error: The error object to raise if there are no or only one image. - :return: A list of available images if there are more than one, otherwise None. - :raises: The provided error if there are no or only one image. - """ - imagesWithSimilarTask = [ - img for img in self._globalImages if img["taskUUID"] in taskUUIDs - ] - if len(imagesWithSimilarTask) > 1: - self._globalImages = [ - img for img in self._globalImages if img["taskUUID"] not in taskUUIDs - ] - return imagesWithSimilarTask - else: - raise error - - async def ensureConnection(self) -> None: - """ - Ensure that a connection is established with the server. - - This method checks if the current connection is active and, if not, initiates a new connection. - It handles authentication and retries the connection if necessary. - - :raises: An error message if the connection cannot be established due to an invalid API key or other reasons. - """ - isConnected = self.connected() and self._ws.state is State.OPEN - # print(f"Is connected: {isConnected}") - - try: - if self._invalidAPIkey: - raise ConnectionError(self._invalidAPIkey) - - if not isConnected: - await self.connect() - # await asyncio.sleep(2) - - except Exception as e: - raise ConnectionError( - self._invalidAPIkey - or "Could not connect to server. Ensure your API key is correct" - ) - - async def getSimililarImage( - self, - taskUUID: Union[str, List[str]], - numberOfImages: int, - shouldThrowError: bool = False, - lis: Optional[ListenerType] = None, - timeout: Optional[int] = None, - ) -> Union[List[IImage], IError]: - """ - Retrieve similar images based on the provided task UUID(s) and desired number of images. - - :param taskUUID: A single task UUID or a list of task UUIDs to filter images. - :param numberOfImages: The desired number of images to retrieve. - :param shouldThrowError: A flag indicating whether to throw an error if the desired number of images is not reached. - :param lis: An optional listener to handle image updates. - :param timeout: The timeout duration for the operation. - :return: A list of retrieved images or an error object if the desired number of images is not reached. - """ - taskUUIDs = taskUUID if isinstance(taskUUID, list) else [taskUUID] - - if timeout is None: - timeout = self._timeout - - def check( - resolve: Callable[[List[IImage]], None], - reject: Callable[[IError], None], - intervalId: Any, - ) -> Optional[bool]: - # print(f"Check # Task UUIDs: {taskUUIDs}") - # print(f"Check # Global images: {self._globalImages}") - # print(f"Check # reject: {reject}") - # print(f"Check # resolve: {resolve}") - logger.debug(f"Check # Global images: {self._globalImages}") - imagesWithSimilarTask = [ - img - for img in self._globalImages - if img.get("taskType") == "imageInference" - and img.get("taskUUID") in taskUUIDs - ] - # logger.debug(f"Check # imagesWithSimilarTask: {imagesWithSimilarTask}") - - if self._globalError: - logger.debug(f"Check # _globalError: {self._globalError}") - - error = self._globalError - self._globalError = None - logger.debug(f"Rejecting with error: {error}") - logger.debug(f"Rejecting function: {reject}") - - reject(RunwareError(error)) - return True - elif len(imagesWithSimilarTask) >= numberOfImages: - resolve(imagesWithSimilarTask[:numberOfImages]) - self._globalImages = [ - img - for img in self._globalImages - if img.get("taskType") == "imageInference" - and img.get("taskUUID") not in taskUUIDs - ] - return True - # return False - - return await getIntervalWithPromise( - check, - debugKey="getting images", - shouldThrowError=shouldThrowError, - timeOutDuration=timeout, - ) - - async def _modelUpload( - self, requestModel: IUploadModelBaseType - ) -> Optional[IUploadModelResponse]: - task_uuid = getUUID() - base_fields = { - "taskType": ETaskType.MODEL_UPLOAD.value, - "taskUUID": task_uuid, - "air": requestModel.air, - "name": requestModel.name, - "downloadURL": requestModel.downloadURL, - "uniqueIdentifier": requestModel.uniqueIdentifier, - "version": requestModel.version, - "format": requestModel.format, - "private": requestModel.private, - "category": requestModel.category, - "architecture": requestModel.architecture, - } - - optional_fields = [ - "retry", - "heroImageURL", - "tags", - "shortDescription", - "comment", - "positiveTriggerWords", - "type", - "negativeTriggerWords", - "defaultWeight", - "defaultStrength", - "defaultGuidanceScale", - "defaultSteps", - "defaultScheduler", - "conditioning", - ] - - request_object = { - **base_fields, - **{ - field: getattr(requestModel, field) - for field in optional_fields - if getattr(requestModel, field, None) is not None - }, - } - - await self.send([request_object]) - - lis = self.globalListener( - taskUUID=task_uuid, - ) - - def check(resolve: callable, reject: callable, *args: Any) -> bool: - uploaded_model_list = self._globalMessages.get(task_uuid, []) - unique_statuses = set() - all_models = [] - - for uploaded_model in uploaded_model_list: - if uploaded_model.get("code"): - raise RunwareAPIError(uploaded_model) - - status = uploaded_model.get("status") - - if status not in unique_statuses: - all_models.append(uploaded_model) - unique_statuses.add(status) - - if status is not None and "error" in status: - raise RunwareAPIError(uploaded_model) - - if status == "ready": - uploaded_model_list.remove(uploaded_model) - if not uploaded_model_list: - del self._globalMessages[task_uuid] - else: - self._globalMessages[task_uuid] = uploaded_model_list - resolve(all_models) - return True - - return False - - response = await getIntervalWithPromise( - check, debugKey="upload-model", timeOutDuration=self._timeout - ) - - lis["destroy"]() - - if "code" in response: - # This indicates an error response - raise RunwareAPIError(response) - - if response: - if not isinstance(response, list): - response = [response] - - models = [] - for item in response: - models.append( - { - "taskType": item.get("taskType"), - "taskUUID": item.get("taskUUID"), - "status": item.get("status"), - "message": item.get("message"), - "air": item.get("air"), - } - ) - else: - models = None - return models - - async def modelUpload( - self, requestModel: IUploadModelBaseType - ) -> Optional[IUploadModelResponse]: - try: - await self.ensureConnection() - return await asyncRetry(lambda: self._modelUpload(requestModel)) - except Exception as e: - raise e - - async def modelSearch(self, payload: IModelSearch) -> IModelSearchResponse: - try: - await self.ensureConnection() - task_uuid = getUUID() - - request_object = { - "taskUUID": task_uuid, - "taskType": ETaskType.MODEL_SEARCH.value, - **({"tags": payload.tags} if payload.tags else {}), - } - - request_object.update( - { - key: value - for key, value in vars(payload).items() - if value is not None and key != "additional_params" - } - ) - - await self.send([request_object]) - - listener = self.globalListener(taskUUID=task_uuid) - - def check(resolve: Callable, reject: Callable, *args: Any) -> bool: - response = self._globalMessages.get(task_uuid) - if response: - if response[0].get("error"): - reject(response[0]) - return True - del self._globalMessages[task_uuid] - resolve(response[0]) - return True - return False - - response = await getIntervalWithPromise( - check, debugKey="model-search", timeOutDuration=self._timeout - ) - - listener["destroy"]() - - if "code" in response: - # This indicates an error response - raise RunwareAPIError(response) - - return instantiateDataclass(IModelSearchResponse, response) - - except Exception as e: - if isinstance(e, RunwareAPIError): - raise - - raise RunwareAPIError({"message": str(e)}) - - async def videoInference(self, requestVideo: IVideoInference) -> List[IVideo]: - await self.ensureConnection() - return await asyncRetry(lambda: self._requestVideo(requestVideo)) - - async def _requestVideo(self, requestVideo: IVideoInference) -> List[IVideo]: - await self._processVideoImages(requestVideo) - requestVideo.taskUUID = requestVideo.taskUUID or getUUID() - request_object = self._buildVideoRequest(requestVideo) - await self.send([request_object]) - return await self._handleInitialVideoResponse(requestVideo.taskUUID, requestVideo.numberResults) - - async def _processVideoImages(self, requestVideo: IVideoInference) -> None: - frame_tasks = [] - reference_tasks = [] - - if requestVideo.frameImages: - frame_tasks = [ - process_image(frame_item.inputImage) - for frame_item in requestVideo.frameImages - if isinstance(frame_item, IFrameImage) - ] - - if requestVideo.referenceImages: - reference_tasks = [ - process_image(reference_item) - for reference_item in requestVideo.referenceImages - ] - - frame_results = await gather(*frame_tasks) if frame_tasks else [] - reference_results = await gather(*reference_tasks) if reference_tasks else [] - - if requestVideo.frameImages and frame_results: - processed_frame_images = [] - result_index = 0 - for frame_item in requestVideo.frameImages: - if isinstance(frame_item, IFrameImage): - frame_item.inputImages = frame_results[result_index] - result_index += 1 - processed_frame_images.append(frame_item) - requestVideo.frameImages = processed_frame_images - - if requestVideo.referenceImages and reference_results: - requestVideo.referenceImages = reference_results - - def _buildVideoRequest(self, requestVideo: IVideoInference) -> Dict[str, Any]: - request_object = { - "deliveryMethod": requestVideo.deliveryMethod, - "taskType": ETaskType.VIDEO_INFERENCE.value, - "taskUUID": requestVideo.taskUUID, - "model": requestVideo.model, - "positivePrompt": requestVideo.positivePrompt.strip(), - "numberResults": requestVideo.numberResults, - } - - self._addOptionalVideoFields(request_object, requestVideo) - self._addVideoImages(request_object, requestVideo) - self._addProviderSettings(request_object, requestVideo) - return request_object - - def _addOptionalVideoFields(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None: - optional_fields = [ - "outputType", "outputFormat", "outputQuality", "uploadEndpoint", - "includeCost", "negativePrompt", "fps", "steps", "seed", - "CFGScale", "seedImage", "duration", "width", "height", - ] - - for field in optional_fields: - value = getattr(requestVideo, field, None) - if value is not None: - request_object[field] = value - - def _addVideoImages(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None: - if requestVideo.frameImages: - frame_images_data = [] - for frame_item in requestVideo.frameImages: - frame_images_data.append({k: v for k, v in asdict(frame_item).items() if v is not None}) - request_object["frameImages"] = frame_images_data - - if requestVideo.referenceImages: - request_object["referenceImages"] = requestVideo.referenceImages - - def _addProviderSettings(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None: - if not requestVideo.providerSettings: - return - provider_dict = requestVideo.providerSettings.to_request_dict() - if provider_dict: - request_object["providerSettings"] = provider_dict - - async def _handleInitialVideoResponse(self, task_uuid: str, number_results: int) -> List[IVideo]: - lis = self.globalListener(taskUUID=task_uuid) - - def check_initial_response(resolve: callable, reject: callable, *args: Any) -> bool: - response_list = self._globalMessages.get(task_uuid, []) - - if not response_list: - return False - - response = response_list[0] - - if response.get("code"): - raise RunwareAPIError(response) - - if response.get("status") == "success": - del self._globalMessages[task_uuid] - resolve([response]) - return True - - del self._globalMessages[task_uuid] - resolve("POLL_NEEDED") - return True - - try: - initial_response = await getIntervalWithPromise( - check_initial_response, - debugKey="video-inference-initial", - timeOutDuration=30000 - ) - finally: - lis["destroy"]() - - if initial_response == "POLL_NEEDED": - return await self._pollVideoResults(task_uuid, number_results) - else: - return instantiateDataclassList(IVideo, initial_response) - - async def _pollVideoResults(self, task_uuid: str, number_results: int) -> List[IVideo]: - for poll_count in range(MAX_POLLS_VIDEO_GENERATION): - try: - responses = await self._sendPollRequest(task_uuid, poll_count) - completed_results = self._processVideoPollingResponse(responses) - - if len(completed_results) >= number_results: - return instantiateDataclassList(IVideo, completed_results[:number_results]) - - if not self._hasPendingVideos(responses) and not completed_results: - raise RunwareAPIError({"message": f"Unexpected polling response at poll {poll_count}"}) - - except Exception as e: - if poll_count >= MAX_POLLS_VIDEO_GENERATION - 1: - raise e - - await delay(3) - - raise RunwareAPIError({"message": "Video generation timed out"}) - - async def _sendPollRequest(self, task_uuid: str, poll_count: int) -> List[Dict[str, Any]]: - await self.send([{ - "taskType": ETaskType.GET_RESPONSE.value, - "taskUUID": task_uuid - }]) - - lis = self.globalListener(taskUUID=task_uuid) - - def check_poll_response(resolve: callable, reject: callable, *args: Any) -> bool: - response_list = self._globalMessages.get(task_uuid, []) - if response_list: - del self._globalMessages[task_uuid] - resolve(response_list) - return True - return False - - try: - return await getIntervalWithPromise( - check_poll_response, - debugKey=f"video-poll-{poll_count}", - timeOutDuration=10000 - ) - finally: - lis["destroy"]() - - def _processVideoPollingResponse(self, responses: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - completed_results = [] - - for response in responses: - if response.get("code"): - raise RunwareAPIError(response) - status = response.get("status") - if status == "success": - completed_results.append(response) - - return completed_results - - def _hasPendingVideos(self, responses: List[Dict[str, Any]]) -> bool: - return any(response.get("status") == "pending" for response in responses) - - def connected(self) -> bool: - """ - Check if the current WebSocket connection is active and authenticated. - - :return: True if the connection is active and authenticated, False otherwise. - """ - return self.isWebsocketReadyState() and self._connectionSessionUUID is not None diff --git a/runware/client.py b/runware/client.py new file mode 100644 index 0000000..a15ef1a --- /dev/null +++ b/runware/client.py @@ -0,0 +1,700 @@ +import asyncio +import atexit +import weakref +from typing import Any, Callable, Dict, List, Optional, Union + +from .connection.manager import ConnectionManager, ConnectionState +from .core.types import ProgressUpdate +from .exceptions import RunwareAuthenticationError +from .logging_config import get_logger, setup_logging +from .messaging.router import MessageRouter +from .operations.image_background_removal import ImageBackgroundRemovalOperation +from .operations.image_caption import ImageCaptionOperation +from .operations.image_inference import ImageInferenceOperation +from .operations.image_upscale import ImageUpscaleOperation +from .operations.manager import OperationManager +from .operations.model_search import ModelSearchOperation +from .operations.model_upload import ModelUploadOperation +from .operations.photo_maker import PhotoMakerOperation +from .operations.prompt_enhance import PromptEnhanceOperation +from .operations.video_inference import VideoInferenceOperation +from .types import ( + Environment, + File, + IEnhancedPrompt, + IImage, + IImageBackgroundRemoval, + IImageCaption, + IImageInference, + IImageToText, + IImageUpscale, + IModelSearch, + IModelSearchResponse, + IPhotoMaker, + IPromptEnhance, + IUploadModelBaseType, + IUploadModelResponse, + IVideo, + IVideoInference, + UploadImageType, +) +from .utils import BASE_RUNWARE_URLS, fileToBase64, isLocalFile + +logger = get_logger(__name__) +_active_clients = weakref.WeakSet() + + +def _cleanup_all_clients(): + """Cleanup for all active clients""" + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + return + except RuntimeError: + return + for client in list(_active_clients): + try: + client._force_sync_cleanup() + except Exception: + pass + + +atexit.register(_cleanup_all_clients) + + +class RunwareClient: + """ + Runware API client for AI image and video generation. + + This client provides a high-level interface for interacting with the Runware API, + supporting various AI operations like image generation, video generation, and more. + """ + + def __init__( + self, + api_key: str, + url: str = BASE_RUNWARE_URLS[Environment.PRODUCTION], + max_concurrent_operations: int = 50, + default_timeout: float = 300.0, + log_level: str = "CRITICAL", + auto_disconnect: bool = True, + # New parameters for operation-specific timeouts + video_timeout: Optional[float] = None, + image_timeout: Optional[float] = None, + # Legacy parameters for backward compatibility + timeout: Optional[int] = None, + ): + """ + Initialize the Runware client. + + Args: + api_key: Your Runware API key + url: The Runware API endpoint URL + max_concurrent_operations: Maximum number of concurrent operations + default_timeout: Default timeout for operations in seconds + log_level: Logging level for the SDK + auto_disconnect: Whether to auto-disconnect after operations complete + video_timeout: Specific timeout for video operations + image_timeout: Specific timeout for image operations + timeout: Legacy timeout parameter (converted from ms to seconds) + """ + self.api_key = api_key + self.url = url + + # Handle legacy timeout parameter + if timeout is not None: + self.default_timeout = timeout / 1000.0 # Convert ms to seconds + else: + self.default_timeout = default_timeout + + # Set operation-specific timeouts + self.video_timeout = video_timeout or 1800.0 # 30 minutes for video operations + self.image_timeout = image_timeout or default_timeout # Use default for images + + self.auto_disconnect = auto_disconnect + + # Setup logging + setup_logging(log_level) + self.logger = get_logger(self.__class__.__name__) + + # Initialize core components + self.message_router = MessageRouter() + self.operation_manager = OperationManager( + max_concurrent_operations=max_concurrent_operations, + operation_timeout=self.default_timeout, + ) + self.connection_manager = ConnectionManager( + api_key=api_key, url=url, message_router=self.message_router + ) + + # Client state management + self._is_started = False + self._is_disconnecting = False + self._connection_callbacks: List[Callable[[ConnectionState], None]] = [] + + # Register for cleanup + _active_clients.add(self) + + self.logger.info(f"RunwareClient initialized with URL: {url}") + self.logger.info( + f"Timeouts - Default: {self.default_timeout}s, Video: {self.video_timeout}s, Image: {self.image_timeout}s" + ) + + def __del__(self): + """Cleanup method - will attempt to disconnect synchronously if possible""" + if self._is_started and not self._is_disconnecting: + self._force_sync_cleanup() + + def _force_sync_cleanup(self): + """Force synchronous cleanup of resources""" + if self._is_disconnecting: + return + + self._is_disconnecting = True + + try: + # Try to get current event loop + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # If loop is running, schedule cleanup + loop.create_task(self._emergency_disconnect()) + return + except RuntimeError: + pass + + self._is_started = False + + except Exception: + # Silently ignore cleanup errors in destructor + pass + + async def _emergency_disconnect(self): + """Emergency async disconnect""" + try: + await asyncio.wait_for(self.disconnect(), timeout=5.0) + except Exception: + self._force_sync_cleanup() + + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.disconnect() + + # Legacy method for backward compatibility + async def ensureConnection(self): + """Legacy method - use connect() instead""" + if not self._is_started: + await self.connect() + + async def connect(self): + """Establish connection to the Runware API""" + if self._is_started: + self.logger.warning("Client already started") + return + + self.logger.info("Starting Runware client") + + try: + await self.message_router.start() + await self.operation_manager.start() + + self.connection_manager.add_connection_callback( + self._on_connection_state_change + ) + await self.connection_manager.start() + + authenticated = await self.connection_manager.wait_for_authentication( + timeout=30.0 + ) + if not authenticated: + raise RunwareAuthenticationError( + "Failed to authenticate within timeout" + ) + + self._is_started = True + self.logger.info("Runware client started successfully") + + except Exception as e: + self.logger.error("Failed to start Runware client", exc_info=e) + await self._cleanup() + raise + + async def disconnect(self): + """Disconnect from the Runware API""" + if not self._is_started or self._is_disconnecting: + return + + self._is_disconnecting = True + self.logger.info("Starting Runware client shutdown") + + try: + # Stop components in reverse order with timeouts + await asyncio.wait_for(self.operation_manager.stop(), timeout=10.0) + await asyncio.wait_for(self.message_router.stop(), timeout=10.0) + await asyncio.wait_for(self.connection_manager.stop(), timeout=10.0) + + self._is_started = False + self.logger.info("Runware client stopped successfully") + + except asyncio.TimeoutError as e: + self.logger.error("Timeout during client shutdown", exc_info=e) + except Exception as e: + self.logger.error("Error during client shutdown", exc_info=e) + finally: + self._is_disconnecting = False + + def is_connected(self) -> bool: + """Check if the client is connected and authenticated""" + return self._is_started and self.connection_manager.is_authenticated() + + # Legacy methods for backward compatibility + def connected(self) -> bool: + """Legacy method - use is_connected() instead""" + return self.is_connected() + + def isWebsocketReadyState(self) -> bool: + """Legacy method - use is_connected() instead""" + return self.is_connected() + + def isAuthenticated(self) -> bool: + """Legacy method - use is_connected() instead""" + return self.is_connected() + + def add_connection_callback(self, callback: Callable[[ConnectionState], None]): + """Add a callback for connection state changes""" + self._connection_callbacks.append(callback) + if self._is_started: + self.connection_manager.add_connection_callback(callback) + + def remove_connection_callback(self, callback: Callable[[ConnectionState], None]): + """Remove a connection state change callback""" + if callback in self._connection_callbacks: + self._connection_callbacks.remove(callback) + if self._is_started: + self.connection_manager.remove_connection_callback(callback) + + def _get_operation_timeout( + self, operation_type: str, user_timeout: Optional[float] = None + ) -> float: + """Get appropriate timeout for operation type""" + if user_timeout is not None: + return user_timeout + + if operation_type == "videoInference": + return self.video_timeout + elif operation_type in [ + "imageInference", + "photoMaker", + "imageCaption", + "imageBackgroundRemoval", + "imageUpscale", + ]: + return self.image_timeout + else: + return self.default_timeout + + async def _execute_operation(self, operation, operation_timeout: float): + """ + Common operation execution logic using Template Method pattern. + + This method encapsulates the common steps for executing any operation: + 1. Register the operation with the message router + 2. Execute the operation with the operation manager + 3. Unregister the operation + 4. Handle auto-disconnect if configured + """ + await self.message_router.register_operation(operation) + + try: + results = await self.operation_manager.execute_operation( + operation, operation_timeout + ) + return results + finally: + await self.message_router.unregister_operation(operation.operation_id) + await self._check_auto_disconnect() + + async def imageInference( + self, + requestImage: IImageInference, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[ProgressUpdate], None]] = None, + ) -> List[IImage]: + """ + Generate images using AI image inference. + + Args: + requestImage: Image inference request parameters + timeout: Optional timeout in seconds + progress_callback: Optional progress callback function + + Returns: + List of generated images + """ + await self.ensureConnection() + + operation_timeout = self._get_operation_timeout("imageInference", timeout) + self.logger.debug(f"Using timeout {operation_timeout}s for image inference") + + operation = ImageInferenceOperation(requestImage, self) + + if progress_callback: + operation.add_progress_callback(progress_callback) + + return await self._execute_operation(operation, operation_timeout) + + async def videoInference( + self, + requestVideo: IVideoInference, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[ProgressUpdate], None]] = None, + ) -> List[IVideo]: + """ + Generate videos using AI video inference. + + Args: + requestVideo: Video inference request parameters + timeout: Optional timeout in seconds + progress_callback: Optional progress callback function + + Returns: + List of generated videos + """ + await self.ensureConnection() + + operation_timeout = self._get_operation_timeout("videoInference", timeout) + self.logger.info( + f"Starting video inference with timeout {operation_timeout}s ({operation_timeout / 60:.1f} minutes)" + ) + + operation = VideoInferenceOperation(requestVideo, self) + + if progress_callback: + operation.add_progress_callback(progress_callback) + + return await self._execute_operation(operation, operation_timeout) + + async def uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]: + """ + Upload an image file or return existing image reference. + + Args: + file: File object, file path, or existing image reference + + Returns: + UploadImageType with image UUID and URL + """ + await self.ensureConnection() + + if isinstance(file, str): + if not isLocalFile(file): + return UploadImageType(imageUUID=file, imageURL=file, taskUUID="direct") + + try: + file = await fileToBase64(file) + except Exception as e: + raise + + return UploadImageType( + imageUUID=str(hash(file))[:16], imageURL="uploaded", taskUUID="upload" + ) + + async def imageCaption( + self, + requestImageToText: IImageCaption, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[ProgressUpdate], None]] = None, + ) -> IImageToText: + """ + Generate captions for images. + + Args: + requestImageToText: Image caption request parameters + timeout: Optional timeout in seconds + progress_callback: Optional progress callback function + + Returns: + List of image captions + """ + await self.ensureConnection() + + operation_timeout = self._get_operation_timeout("imageCaption", timeout) + + operation = ImageCaptionOperation(requestImageToText, self) + + if progress_callback: + operation.add_progress_callback(progress_callback) + + return await self._execute_operation(operation, operation_timeout) + + async def imageBackgroundRemoval( + self, + removeImageBackgroundPayload: IImageBackgroundRemoval, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[ProgressUpdate], None]] = None, + ) -> List[IImage]: + """ + Remove backgrounds from images. + + Args: + removeImageBackgroundPayload: Background removal request parameters + timeout: Optional timeout in seconds + progress_callback: Optional progress callback function + + Returns: + List of images with removed backgrounds + """ + await self.ensureConnection() + + operation_timeout = self._get_operation_timeout( + "imageBackgroundRemoval", timeout + ) + + operation = ImageBackgroundRemovalOperation(removeImageBackgroundPayload, self) + + if progress_callback: + operation.add_progress_callback(progress_callback) + + return await self._execute_operation(operation, operation_timeout) + + async def imageUpscale( + self, + upscaleGanPayload: IImageUpscale, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[ProgressUpdate], None]] = None, + ) -> List[IImage]: + """ + Upscale images using AI. + + Args: + upscaleGanPayload: Image upscale request parameters + timeout: Optional timeout in seconds + progress_callback: Optional progress callback function + + Returns: + List of upscaled images + """ + await self.ensureConnection() + + operation_timeout = self._get_operation_timeout("imageUpscale", timeout) + + operation = ImageUpscaleOperation(upscaleGanPayload, self) + + if progress_callback: + operation.add_progress_callback(progress_callback) + + return await self._execute_operation(operation, operation_timeout) + + async def promptEnhance( + self, + promptEnhancer: IPromptEnhance, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[ProgressUpdate], None]] = None, + ) -> List[IEnhancedPrompt]: + """ + Enhance prompts using AI. + + Args: + promptEnhancer: Prompt enhancement request parameters + timeout: Optional timeout in seconds + progress_callback: Optional progress callback function + + Returns: + List of enhanced prompts + """ + await self.ensureConnection() + + operation_timeout = timeout or self.default_timeout + + operation = PromptEnhanceOperation(promptEnhancer, self) + + if progress_callback: + operation.add_progress_callback(progress_callback) + + return await self._execute_operation(operation, operation_timeout) + + async def photoMaker( + self, + requestPhotoMaker: IPhotoMaker, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[ProgressUpdate], None]] = None, + ) -> List[IImage]: + """ + Generate images with PhotoMaker. + + Args: + requestPhotoMaker: PhotoMaker request parameters + timeout: Optional timeout in seconds + progress_callback: Optional progress callback function + + Returns: + List of generated images + """ + await self.ensureConnection() + + operation_timeout = self._get_operation_timeout("photoMaker", timeout) + + operation = PhotoMakerOperation(requestPhotoMaker, self) + + if progress_callback: + operation.add_progress_callback(progress_callback) + + return await self._execute_operation(operation, operation_timeout) + + async def modelUpload( + self, + requestModel: IUploadModelBaseType, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[ProgressUpdate], None]] = None, + ) -> Optional[List[IUploadModelResponse]]: + """ + Upload a model to Runware. + + Args: + requestModel: Model upload request parameters + timeout: Optional timeout in seconds + progress_callback: Optional progress callback function + + Returns: + List of upload model responses + """ + await self.ensureConnection() + + operation_timeout = timeout or self.default_timeout + + operation = ModelUploadOperation(requestModel, self) + + if progress_callback: + operation.add_progress_callback(progress_callback) + + return await self._execute_operation(operation, operation_timeout) + + async def modelSearch( + self, + payload: IModelSearch, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[ProgressUpdate], None]] = None, + ) -> IModelSearchResponse: + """ + Search for models in the Runware model library. + + Args: + payload: Model search request parameters + timeout: Optional timeout in seconds + progress_callback: Optional progress callback function + + Returns: + Model search response with results + """ + await self.ensureConnection() + + operation_timeout = timeout or self.default_timeout + + operation = ModelSearchOperation(payload, self) + + if progress_callback: + operation.add_progress_callback(progress_callback) + + return await self._execute_operation(operation, operation_timeout) + + async def _check_auto_disconnect(self): + """Check if we should auto-disconnect after operation completion""" + if not self.auto_disconnect or not self._is_started: + return + + # Check if there are any active operations + active_operations = len(self.operation_manager.operations) + if active_operations == 0: + # Schedule disconnect to run after current operation completes + asyncio.create_task(self._delayed_disconnect()) + + async def _delayed_disconnect(self): + """Disconnect after a short delay to ensure operation cleanup""" + try: + await asyncio.sleep(0.1) # Small delay to ensure cleanup + if len(self.operation_manager.operations) == 0: + await self.disconnect() + except Exception as e: + self.logger.debug(f"Error in delayed disconnect: {e}") + + # Operation management methods + def get_operation_status(self, operation_id: str) -> Optional[Dict[str, Any]]: + """Get status of a specific operation""" + context = self.operation_manager.get_operation_context(operation_id) + if context: + return { + "operation_id": context.operation_id, + "operation_type": context.operation_type, + "status": context.status.value, + "progress": context.progress, + "created_at": context.created_at, + "completed_at": context.completed_at, + "results_count": len(context.results) if context.results else 0, + } + return None + + def list_operations(self, include_completed: bool = False) -> List[Dict[str, Any]]: + """List all current operations""" + contexts = self.operation_manager.list_operations() + + operations = [] + for context in contexts: + if not include_completed and context.status.value in [ + "completed", + "failed", + "cancelled", + ]: + continue + + operations.append( + { + "operation_id": context.operation_id, + "operation_type": context.operation_type, + "status": context.status.value, + "progress": context.progress, + "created_at": context.created_at, + "completed_at": context.completed_at, + } + ) + + return operations + + async def cancel_operation(self, operation_id: str) -> bool: + """Cancel a specific operation""" + return await self.operation_manager.cancel_operation(operation_id) + + async def cancel_all_operations(self) -> int: + """Cancel all active operations""" + return await self.operation_manager.cancel_all_operations() + + async def wait_for_connection(self, timeout: Optional[float] = 30.0) -> bool: + """Wait for connection to be established""" + if not self._is_started: + return False + return await self.connection_manager.wait_for_authentication(timeout) + + def _on_connection_state_change(self, new_state: ConnectionState): + """Handle connection state changes""" + for callback in self._connection_callbacks: + try: + callback(new_state) + except Exception as e: + self.logger.error("Error in connection callback", exc_info=e) + + async def _cleanup(self): + """Cleanup all resources""" + try: + if self._is_started: + await self.operation_manager.stop() + await self.message_router.stop() + await self.connection_manager.stop() + except Exception as e: + self.logger.error("Error during cleanup", exc_info=e) + + +# Backward compatibility alias +Runware = RunwareClient +RunwareServer = RunwareClient # lol diff --git a/runware/connection/__init__.py b/runware/connection/__init__.py new file mode 100644 index 0000000..4ac24b8 --- /dev/null +++ b/runware/connection/__init__.py @@ -0,0 +1,11 @@ +from .manager import ConnectionManager +from .types import ( + ConnectionState, + ConnectionStateCallback, +) + +__all__ = [ + "ConnectionManager", + "ConnectionState", + "ConnectionStateCallback", +] diff --git a/runware/connection/manager.py b/runware/connection/manager.py new file mode 100644 index 0000000..ff6f90c --- /dev/null +++ b/runware/connection/manager.py @@ -0,0 +1,435 @@ +import asyncio +import time +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +import websockets +from websockets.protocol import State + +from .types import ConnectionState +from ..core.cpu_bound import cpu_executor +from ..exceptions import RunwareAuthenticationError, RunwareConnectionError +from ..logging_config import get_logger + +if TYPE_CHECKING: + from ..messaging.router import MessageRouter + +logger = get_logger(__name__) + + +class ConnectionManager: + """Manages WebSocket connection with proper error handling.""" + + def __init__(self, api_key: str, url: str, message_router: "MessageRouter"): + self.api_key = api_key + self.url = url + self.message_router = message_router + + self.state = ConnectionState.DISCONNECTED + self.websocket: Optional[websockets.WebSocketServerProtocol] = None + self.connection_session_uuid: Optional[str] = None + + # Message sending + self.message_queue = asyncio.Queue(maxsize=1000) + + # Lifecycle management + self.is_running = False + self._should_reconnect = True + + # Background tasks + self._message_handler_task: Optional[asyncio.Task] = None + self._sender_task: Optional[asyncio.Task] = None + self._heartbeat_task: Optional[asyncio.Task] = None + + # Events for synchronization + self._connection_event = asyncio.Event() + self._authentication_event = asyncio.Event() + self._authentication_error_event = asyncio.Event() + self._auth_error: Optional[Exception] = None + + # Configuration + self._heartbeat_interval = 30.0 + self._heartbeat_timeout = 60.0 + self._last_pong_time = 0.0 + + # Reconnection settings + self._reconnect_attempts = 0 + self._max_reconnect_attempts = 3 + self._base_reconnect_delay = 2.0 + self._max_reconnect_delay = 30.0 + + # Callbacks + self._connection_callbacks: List[Callable[[ConnectionState], None]] = [] + + logger.info(f"ConnectionManager initialized for URL: {url}") + + async def start(self): + if self.is_running: + return + + self.is_running = True + self._should_reconnect = True + logger.info("Starting ConnectionManager") + + try: + await self._connect() + except Exception as e: + self.is_running = False + raise + + async def stop(self): + if not self.is_running: + return + + logger.info("Stopping ConnectionManager") + self.is_running = False + self._should_reconnect = False + + await self._disconnect() + await self._cleanup_tasks() + + async def send_message(self, content: List[Dict[str, Any]]) -> str: + if not self.is_connected(): + raise RunwareConnectionError("Cannot send message: not connected") + + message_id = f"msg_{int(time.time() * 1000000)}" + message_data = {"id": message_id, "content": content} + + try: + await asyncio.wait_for(self.message_queue.put(message_data), timeout=5.0) + return message_id + except asyncio.TimeoutError: + raise RunwareConnectionError("Message queue full") + + def is_connected(self) -> bool: + return ( + self.websocket is not None + and self.websocket.state == State.OPEN + and self.state in [ConnectionState.CONNECTED, ConnectionState.AUTHENTICATED] + ) + + def is_authenticated(self) -> bool: + return ( + self.state == ConnectionState.AUTHENTICATED + and self.connection_session_uuid is not None + ) + + async def wait_for_authentication(self, timeout: Optional[float] = 30.0) -> bool: + try: + await asyncio.wait_for(self._authentication_event.wait(), timeout) + return True + except asyncio.TimeoutError: + return False + + def add_connection_callback(self, callback: Callable[[ConnectionState], None]): + self._connection_callbacks.append(callback) + + def remove_connection_callback(self, callback: Callable[[ConnectionState], None]): + if callback in self._connection_callbacks: + self._connection_callbacks.remove(callback) + + async def _connect(self): + if self.state in [ConnectionState.CONNECTING, ConnectionState.RECONNECTING]: + return + + self._set_state(ConnectionState.CONNECTING) + + try: + logger.info(f"Connecting to WebSocket: {self.url}") + + self.websocket = await asyncio.wait_for( + websockets.connect( + self.url, + close_timeout=2, + max_size=None, + ping_interval=None, + ping_timeout=None, + ), + timeout=15.0, + ) + + self._set_state(ConnectionState.CONNECTED) + self._connection_event.set() + + # Start background tasks + await self._start_connection_tasks() + + # Authenticate - this will throw RunwareAuthenticationError if auth fails + await self._authenticate() + + logger.info("WebSocket connection authenticated successfully") + + except RunwareAuthenticationError: + # Authentication errors are not recoverable - don't retry + self._should_reconnect = False + self._set_state(ConnectionState.FAILED) + raise + except Exception as e: + self._set_state(ConnectionState.FAILED) + + if self.is_running and self._should_reconnect: + await self._schedule_reconnect() + else: + raise RunwareConnectionError(f"Connection failed: {e}") + + async def _start_connection_tasks(self): + self._message_handler_task = asyncio.create_task(self._handle_messages()) + self._sender_task = asyncio.create_task(self._send_messages()) + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + async def _authenticate(self): + if not self.is_connected(): + raise RunwareConnectionError("Cannot authenticate: not connected") + + self._authentication_event.clear() + self._authentication_error_event.clear() + self._auth_error = None + + auth_message = [{"taskType": "authentication", "apiKey": self.api_key}] + + if self.connection_session_uuid: + auth_message[0]["connectionSessionUUID"] = self.connection_session_uuid + + await self.send_message(auth_message) + + # Wait for either success or error + done, pending = await asyncio.wait( + [ + asyncio.create_task(self._authentication_event.wait()), + asyncio.create_task(self._authentication_error_event.wait()), + ], + timeout=15.0, + return_when=asyncio.FIRST_COMPLETED, + ) + + # Cancel pending tasks + for task in pending: + task.cancel() + + if not done: + raise RunwareAuthenticationError("Authentication timeout") + + # Check if there was an error + if self._authentication_error_event.is_set(): + if self._auth_error: + raise self._auth_error + else: + raise RunwareAuthenticationError("Authentication failed") + + async def _send_messages(self): + try: + while self.is_running and self.is_connected(): + try: + message = await asyncio.wait_for( + self.message_queue.get(), timeout=1.0 + ) + await self._send_single_message(message) + except asyncio.TimeoutError: + continue + except Exception as e: + logger.error("Error in message sender", exc_info=e) + if not self.is_running: + break + except asyncio.CancelledError: + pass + + async def _send_single_message(self, message: Dict[str, Any]): + try: + serialized = await cpu_executor.serialize_json(message["content"]) + await self.websocket.send(serialized) + except Exception as e: + logger.error(f"Failed to send message {message['id']}", exc_info=e) + raise + + async def _handle_messages(self): + try: + async for raw_message in self.websocket: + if not self.is_running: + break + await self._process_incoming_message(raw_message) + except websockets.exceptions.ConnectionClosed: + if self.is_running and self._should_reconnect: + await self._handle_connection_loss() + except Exception as e: + logger.error("Error in message handler", exc_info=e) + if self.is_running and self._should_reconnect: + await self._handle_connection_loss() + + async def _process_incoming_message(self, raw_message: str): + try: + message = await cpu_executor.parse_json(raw_message) + + # Handle system messages first + if await self._handle_system_message(message): + return + + # Route to operations + await self.message_router.route_message(message) + + except Exception as e: + logger.error("Error processing incoming message", exc_info=e) + + async def _handle_system_message(self, message: Dict[str, Any]) -> bool: + if "data" in message: + for item in message["data"]: + if item.get("taskType") == "authentication": + return await self._handle_authentication_response(item) + elif item.get("pong"): + self._last_pong_time = time.time() + return True + + if "errors" in message: + for error in message["errors"]: + if error.get("taskType") == "authentication": + await self._handle_authentication_error(error) + return True + + return False + + async def _handle_authentication_response(self, auth_data: Dict[str, Any]) -> bool: + self.connection_session_uuid = auth_data.get("connectionSessionUUID") + + if self.connection_session_uuid: + self._set_state(ConnectionState.AUTHENTICATED) + self._authentication_event.set() + self._reconnect_attempts = 0 + return True + else: + self._auth_error = RunwareAuthenticationError( + "Authentication response missing session UUID" + ) + self._authentication_error_event.set() + return False + + async def _handle_authentication_error(self, error_data: Dict[str, Any]): + error_message = error_data.get("message", "Authentication failed") + + # Store the error and signal error event + self._auth_error = RunwareAuthenticationError(error_message) + self._authentication_error_event.set() + + # Don't allow reconnection for auth errors + self._should_reconnect = False + self._set_state(ConnectionState.FAILED) + + async def _heartbeat_loop(self): + try: + while self.is_running and self.is_connected(): + try: + ping_message = [{"taskType": "ping", "ping": True}] + await self.send_message(ping_message) + + await asyncio.sleep(self._heartbeat_interval) + + if (time.time() - self._last_pong_time) > self._heartbeat_timeout: + if self.is_running and self._should_reconnect: + await self._handle_connection_loss() + break + + except Exception as e: + logger.error("Heartbeat error", exc_info=e) + if not self.is_running: + break + except asyncio.CancelledError: + pass + + async def _handle_connection_loss(self): + if self.state == ConnectionState.RECONNECTING: + return + + await self._disconnect() + + if self.is_running and self._should_reconnect: + await self._schedule_reconnect() + + async def _schedule_reconnect(self): + self._set_state(ConnectionState.RECONNECTING) + + while ( + self.is_running + and self._should_reconnect + and self._reconnect_attempts < self._max_reconnect_attempts + and self.state != ConnectionState.AUTHENTICATED + ): + self._reconnect_attempts += 1 + + delay = min( + self._base_reconnect_delay * (2 ** (self._reconnect_attempts - 1)), + self._max_reconnect_delay, + ) + + logger.info( + f"Reconnect attempt {self._reconnect_attempts}/{self._max_reconnect_attempts} in {delay:.2f}s" + ) + + try: + await asyncio.sleep(delay) + except asyncio.CancelledError: + break + + try: + await self._connect() + if self.is_authenticated(): + self._reconnect_attempts = 0 + break + except RunwareAuthenticationError: + # Authentication errors should stop reconnection attempts + self._should_reconnect = False + break + except Exception as e: + logger.error( + f"Reconnect attempt {self._reconnect_attempts} failed", exc_info=e + ) + + if self._reconnect_attempts >= self._max_reconnect_attempts: + self._set_state(ConnectionState.FAILED) + + async def _disconnect(self): + if self.state == ConnectionState.DISCONNECTED: + return + + self._set_state(ConnectionState.DISCONNECTED) + + if self.websocket and self.websocket.state == State.OPEN: + try: + await asyncio.wait_for(self.websocket.close(), timeout=3.0) + except Exception: + pass + + self.websocket = None + self.connection_session_uuid = None + + # Clear events + self._connection_event.clear() + self._authentication_event.clear() + self._authentication_error_event.clear() + + async def _cleanup_tasks(self): + tasks = [self._message_handler_task, self._sender_task, self._heartbeat_task] + + for task in tasks: + if task and not task.done(): + task.cancel() + try: + await asyncio.wait_for(task, timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + + # Clear task references + self._message_handler_task = None + self._sender_task = None + self._heartbeat_task = None + + def _set_state(self, new_state: ConnectionState): + if self.state != new_state: + old_state = self.state + self.state = new_state + + logger.info( + f"Connection state changed: {old_state.value} -> {new_state.value}" + ) + + for callback in self._connection_callbacks: + try: + callback(new_state) + except Exception as e: + logger.error("Error in connection callback", exc_info=e) diff --git a/runware/connection/types.py b/runware/connection/types.py new file mode 100644 index 0000000..f417601 --- /dev/null +++ b/runware/connection/types.py @@ -0,0 +1,14 @@ +from enum import Enum +from typing import Callable + + +class ConnectionState(Enum): + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + AUTHENTICATED = "authenticated" + RECONNECTING = "reconnecting" + FAILED = "failed" + + +ConnectionStateCallback = Callable[[ConnectionState], None] diff --git a/runware/core/__init__.py b/runware/core/__init__.py new file mode 100644 index 0000000..4565b79 --- /dev/null +++ b/runware/core/__init__.py @@ -0,0 +1,15 @@ +from .types import ( + Message, + MessageType, + OperationContext, + OperationStatus, + ProgressUpdate, +) + +__all__ = [ + "OperationStatus", + "MessageType", + "OperationContext", + "ProgressUpdate", + "Message", +] diff --git a/runware/core/cpu_bound.py b/runware/core/cpu_bound.py new file mode 100644 index 0000000..be642e8 --- /dev/null +++ b/runware/core/cpu_bound.py @@ -0,0 +1,81 @@ +import asyncio +import functools +import json +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict, is_dataclass +from typing import Any, Callable, Dict, List, TypeVar + +T = TypeVar("T") + + +class CPUBoundExecutor: + """Manages CPU-bound operations in thread pool.""" + + def __init__(self, max_workers: int = None): + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._loop = None + + async def __aenter__(self): + self._loop = asyncio.get_event_loop() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._executor.shutdown(wait=True) + + async def parse_json(self, data: str) -> Dict[str, Any]: + """Parse JSON in thread pool.""" + return await self._run_in_executor(json.loads, data) + + async def serialize_json(self, obj: Any) -> str: + """Serialize to JSON in thread pool with enum support.""" + + def json_dumps_with_enums(obj): + def convert_enums(data): + if hasattr(data, "value") and hasattr(data, "name"): # It's an enum + return data.value + elif isinstance(data, dict): + return {key: convert_enums(value) for key, value in data.items()} + elif isinstance(data, list): + return [convert_enums(item) for item in data] + else: + return data + + converted_obj = convert_enums(obj) + return json.dumps(converted_obj) + + return await self._run_in_executor(json_dumps_with_enums, obj) + + async def serialize_dataclass(self, obj: Any) -> Dict[str, Any]: + """Serialize dataclass in thread pool with enum support.""" + if is_dataclass(obj): + + def serialize_with_enums(obj): + result = asdict(obj) + # Convert enums to their string values + for key, value in result.items(): + if hasattr(value, "value") and hasattr( + value, "name" + ): # It's an enum + result[key] = value.value + return result + + return await self._run_in_executor(serialize_with_enums, obj) + return obj + + async def batch_serialize_dataclasses( + self, objects: List[Any] + ) -> List[Dict[str, Any]]: + """Serialize multiple dataclasses concurrently.""" + tasks = [self.serialize_dataclass(obj) for obj in objects] + return await asyncio.gather(*tasks) + + async def _run_in_executor(self, func: Callable, *args, **kwargs): + """Run function in thread pool.""" + loop = self._loop or asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, functools.partial(func, *args, **kwargs) + ) + + +# Global instance +cpu_executor = CPUBoundExecutor(max_workers=4) diff --git a/runware/core/error_context.py b/runware/core/error_context.py new file mode 100644 index 0000000..dda9a4b --- /dev/null +++ b/runware/core/error_context.py @@ -0,0 +1,56 @@ +from contextlib import asynccontextmanager + +from ..exceptions import RunwareOperationError + + +class ErrorContext: + """Provides context for error handling.""" + + def __init__(self, operation_id: str, operation_type: str): + self.operation_id = operation_id + self.operation_type = operation_type + self.context_stack = [] + + @asynccontextmanager + async def phase(self, phase_name: str): + """Context manager for operation phases.""" + self.context_stack.append(phase_name) + try: + yield + except Exception as e: + if not hasattr(e, "_runware_context"): + e._runware_context = { + "operation_id": self.operation_id, + "operation_type": self.operation_type, + "phase": phase_name, + "stack": self.context_stack.copy(), + } + raise + finally: + if self.context_stack and self.context_stack[-1] == phase_name: + self.context_stack.pop() + + def wrap_error(self, error: Exception) -> RunwareOperationError: + """Wrap exception with full context.""" + context = { + "operation_id": self.operation_id, + "operation_type": self.operation_type, + "phases": self.context_stack, + "original_error": str(error), + "error_type": type(error).__name__, + } + + current_phase = ( + self.context_stack[-1] if self.context_stack else "unknown phase" + ) + + wrapped_error = RunwareOperationError( + message=f"{self.operation_type} failed in {current_phase}: {error}", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + + # Store context in the error for debugging + wrapped_error.details = context + + return wrapped_error diff --git a/runware/core/types.py b/runware/core/types.py new file mode 100644 index 0000000..ca3cbef --- /dev/null +++ b/runware/core/types.py @@ -0,0 +1,69 @@ +import time +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional + + +class OperationStatus(Enum): + PENDING = "pending" + EXECUTING = "executing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + TIMEOUT = "timeout" + + +class MessageType(Enum): + OPERATION_UPDATE = "operation_update" + OPERATION_COMPLETE = "operation_complete" + OPERATION_ERROR = "operation_error" + PROGRESS_UPDATE = "progress_update" + SERVER_MESSAGE = "server_message" + + +@dataclass +class OperationContext: + operation_id: str + operation_type: str + status: OperationStatus + progress: float = 0.0 + results: List[Any] = None + error: Optional[Exception] = None + start_time: Optional[float] = None + end_time: Optional[float] = None + metadata: Dict[str, Any] = None + completed_at: Optional[float] = None + created_at: Optional[float] = None + + def __post_init__(self): + if self.results is None: + self.results = [] + if self.metadata is None: + self.metadata = {} + + +@dataclass +class ProgressUpdate: + operation_id: str + progress: float + message: str = "" + partial_results: List[Any] = None + timestamp: float = None + + def __post_init__(self): + if self.partial_results is None: + self.partial_results = [] + if self.timestamp is None: + self.timestamp = time.time() + + +@dataclass +class Message: + type: MessageType + operation_id: Optional[str] + data: Dict[str, Any] + timestamp: float = None + + def __post_init__(self): + if self.timestamp is None: + self.timestamp = time.time() diff --git a/runware/exceptions.py b/runware/exceptions.py new file mode 100644 index 0000000..530032b --- /dev/null +++ b/runware/exceptions.py @@ -0,0 +1,127 @@ +from typing import Any, Dict, Optional + + +class RunwareError(Exception): + """Base exception for all Runware SDK errors.""" + + def __init__( + self, + message: str, + code: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, + parameter: Optional[str] = None, + error_type: Optional[str] = None, + documentation: Optional[str] = None, + task_uuid: Optional[str] = None, + ): + super().__init__(message) + self.message = message + self.code = code + self.details = details or {} + self.parameter = parameter + self.error_type = error_type + self.documentation = documentation + self.task_uuid = task_uuid + + def format_error(self): + """Format error for backward compatibility.""" + return { + "errors": [ + { + "code": self.code, + "message": self.message, + "parameter": self.parameter, + "type": self.error_type, + "documentation": self.documentation, + "taskUUID": self.task_uuid, + } + ] + } + + def __str__(self): + return str(self.format_error()) + + +class RunwareConnectionError(RunwareError): + """Raised when there are connection-related issues.""" + + def __init__(self, message: str, connection_state: Optional[str] = None, **kwargs): + super().__init__(message, **kwargs) + self.connection_state = connection_state + + +class RunwareAuthenticationError(RunwareError): + """Raised when authentication fails.""" + + def __init__(self, message: str = "Authentication failed", **kwargs): + super().__init__(message, **kwargs) + + +class RunwareOperationError(RunwareError): + """Raised when an operation fails.""" + + def __init__( + self, + message: str, + operation_id: str, + operation_type: Optional[str] = None, + **kwargs, + ): + super().__init__(message, **kwargs) + self.operation_id = operation_id + self.operation_type = operation_type + + +class RunwareTimeoutError(RunwareError): + """Raised when an operation times out.""" + + def __init__( + self, message: str, timeout_duration: Optional[float] = None, **kwargs + ): + super().__init__(message, **kwargs) + self.timeout_duration = timeout_duration + + +class RunwareParseError(RunwareOperationError): + """Raised when response parsing fails.""" + + def __init__( + self, + message: str, + operation_id: str = "unknown", + operation_type: str = "unknown", + raw_data: Optional[Dict[str, Any]] = None, + **kwargs, + ): + super().__init__(message, operation_id, operation_type, **kwargs) + self.raw_data = raw_data or {} + + +class RunwareAPIError(Exception): + """API error for backward compatibility with old SDK.""" + + def __init__(self, error_data: Dict[str, Any]): + self.error_data = error_data + self.code = error_data.get("code") + super().__init__(str(error_data)) + + def __str__(self): + return f"RunwareAPIError: {self.error_data}" + + +class RunwareValidationError(RunwareOperationError): + """Raised when validation fails.""" + + pass + + +class RunwareResourceError(RunwareOperationError): + """Raised when resource constraints are hit.""" + + pass + + +class RunwareServerError(RunwareOperationError): + """Raised when server returns an error.""" + + pass diff --git a/runware/logging_config.py b/runware/logging_config.py index 7d7dabd..674be25 100644 --- a/runware/logging_config.py +++ b/runware/logging_config.py @@ -1,30 +1,184 @@ import logging +import sys +from typing import Any, Dict -def add_console_handler(logger, formatter): - # does it already exist? if so, return None - for handler in logger.handlers: - if isinstance(handler, logging.StreamHandler): - return None - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) +def setup_logging( + log_level: str = "INFO", detailed_websockets: bool = False +) -> logging.Logger: + """ + Setup comprehensive logging for the entire Runware SDK with proper string level support. + Args: + log_level: String log level ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL') + detailed_websockets: Whether to enable detailed websocket message logging -def configure_logging(log_level=logging.DEBUG): - # Create a formatter - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + Returns: + The main SDK logger instance + """ + if isinstance(log_level, str): + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError( + f"Invalid log level: {log_level}. Valid levels: DEBUG, INFO, WARNING, ERROR, CRITICAL" + ) + else: + numeric_level = log_level + + # Create comprehensive formatter with more detailed format + detailed_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d:%(funcName)s] - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Create simple formatter for less verbose components + simple_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", ) - logger = logging.getLogger(__name__) - logger.setLevel(log_level) - add_console_handler(logger, formatter) + # Setup root logger for the SDK + sdk_logger = logging.getLogger("runware") + sdk_logger.setLevel(numeric_level) + + # Remove existing handlers to avoid duplicates + for handler in sdk_logger.handlers[:]: + sdk_logger.removeHandler(handler) + + # Add console handler for main SDK + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(detailed_formatter) + console_handler.setLevel(numeric_level) + sdk_logger.addHandler(console_handler) + + # Setup websockets logger with configurable detail level + websockets_logger = logging.getLogger("websockets") + + # Remove existing handlers + for handler in websockets_logger.handlers[:]: + websockets_logger.removeHandler(handler) + + if detailed_websockets: + # Detailed websocket logging when requested + websockets_logger.setLevel(max(numeric_level, logging.DEBUG)) + websockets_logger.propagate = False + + websocket_handler = logging.StreamHandler(sys.stdout) + websocket_handler.setFormatter(detailed_formatter) + websocket_handler.setLevel(max(numeric_level, logging.DEBUG)) + websockets_logger.addHandler(websocket_handler) + else: + # Standard websocket logging + websockets_logger.setLevel(max(numeric_level, logging.INFO)) + websockets_logger.propagate = False + + websocket_handler = logging.StreamHandler(sys.stdout) + websocket_handler.setFormatter(simple_formatter) + websocket_handler.setLevel(max(numeric_level, logging.INFO)) + websockets_logger.addHandler(websocket_handler) + + # Setup asyncio logger with appropriate level asyncio_logger = logging.getLogger("asyncio") - asyncio_logger.setLevel(log_level) - add_console_handler(asyncio_logger, formatter) + asyncio_logger.setLevel( + max(numeric_level, logging.WARNING) + ) # Only warnings and errors for asyncio + + # Configure specific component loggers with appropriate levels + component_configs = { + "runware.client": numeric_level, + "runware.operations": numeric_level, + "runware.connection": numeric_level, + "runware.messaging": numeric_level, + "runware.core": numeric_level, + } + + for component_name, level in component_configs.items(): + component_logger = logging.getLogger(component_name) + component_logger.setLevel(level) + + # Log the successful setup + sdk_logger.info(f"Runware SDK logging initialized with level: {log_level}") + if detailed_websockets: + sdk_logger.debug("Detailed websocket logging enabled") + + return sdk_logger + +def get_logger(name: str) -> logging.Logger: + """ + Get a logger for a specific module within the SDK. + + Args: + name: The name of the module/component requesting the logger + + Returns: + A configured logger instance for the specified component + """ + if not name.startswith("runware"): + if name == "__main__": + logger_name = "runware.main" + elif "." in name: + module_parts = name.split(".") + if len(module_parts) >= 2: + logger_name = f"runware.{module_parts[-2]}.{module_parts[-1]}" + else: + logger_name = f"runware.{module_parts[-1]}" + else: + logger_name = f"runware.{name}" + else: + logger_name = name + + return logging.getLogger(logger_name) + + +def configure_component_logging(component_name: str, level: str) -> logging.Logger: + """ + Configure logging for a specific component with a custom level. + + Args: + component_name: Name of the component (e.g., 'websockets', 'connection') + level: Log level for this component + + Returns: + The configured logger for the component + """ + if isinstance(level, str): + numeric_level = getattr(logging, level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f"Invalid log level: {level}") + else: + numeric_level = level + + logger = get_logger(component_name) + logger.setLevel(numeric_level) + + return logger + + +def get_logging_stats() -> Dict[str, Any]: + """ + Get statistics about the current logging configuration. + + Returns: + Dictionary containing logging statistics and configuration info + """ + runware_logger = logging.getLogger("runware") websockets_logger = logging.getLogger("websockets") - websockets_logger.setLevel(log_level) - websockets_logger.propagate = False - add_console_handler(websockets_logger, formatter) + asyncio_logger = logging.getLogger("asyncio") + + return { + "main_level": logging.getLevelName(runware_logger.level), + "websockets_level": logging.getLevelName(websockets_logger.level), + "asyncio_level": logging.getLevelName(asyncio_logger.level), + "handlers_count": { + "runware": len(runware_logger.handlers), + "websockets": len(websockets_logger.handlers), + "asyncio": len(asyncio_logger.handlers), + }, + "effective_levels": { + "runware": logging.getLevelName(runware_logger.getEffectiveLevel()), + "websockets": logging.getLevelName(websockets_logger.getEffectiveLevel()), + "asyncio": logging.getLevelName(asyncio_logger.getEffectiveLevel()), + }, + } + diff --git a/runware/messaging/__init__.py b/runware/messaging/__init__.py new file mode 100644 index 0000000..57a85a5 --- /dev/null +++ b/runware/messaging/__init__.py @@ -0,0 +1,22 @@ +# Import router separately to avoid circular imports +from .router import MessageRouter +from .types import CompletionCallback, MessageHandler, ProgressCallback +from ..core.types import ( + Message, + MessageType, + OperationContext, + OperationStatus, + ProgressUpdate, +) + +__all__ = [ + "MessageType", + "OperationStatus", + "Message", + "OperationContext", + "ProgressUpdate", + "MessageHandler", + "ProgressCallback", + "CompletionCallback", + "MessageRouter", +] diff --git a/runware/messaging/router.py b/runware/messaging/router.py new file mode 100644 index 0000000..ea7166f --- /dev/null +++ b/runware/messaging/router.py @@ -0,0 +1,323 @@ +import asyncio +import time +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ..logging_config import get_logger + +if TYPE_CHECKING: + from .. import BaseOperation + +logger = get_logger(__name__) + + +class MessageRouter: + """Routes messages to appropriate operations.""" + + def __init__(self): + self.operations: Dict[str, "BaseOperation"] = {} + self.message_queue = asyncio.Queue(maxsize=1000) + self.router_task: Optional[asyncio.Task] = None + self.is_running = False + + # Store messages for operations that aren't registered yet + self._pending_messages: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + self._message_ttl = 300.0 # 5 minutes + logger.info("MessageRouter initialized") + + async def start(self): + if self.is_running: + return + + self.is_running = True + self.router_task = asyncio.create_task(self._route_messages()) + logger.info("Message router started") + + async def stop(self): + if not self.is_running: + return + + self.is_running = False + + if self.router_task and not self.router_task.done(): + self.router_task.cancel() + try: + await asyncio.wait_for(self.router_task, timeout=3.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + self.router_task = None + + # Clear queues + while not self.message_queue.empty(): + try: + self.message_queue.get_nowait() + except asyncio.QueueEmpty: + break + + async def register_operation(self, operation: "BaseOperation"): + if operation.operation_id in self.operations: + logger.warning(f"Operation {operation.operation_id} already registered") + return + + self.operations[operation.operation_id] = operation + + # Deliver any pending messages for this operation + pending_messages = self._pending_messages.pop(operation.operation_id, []) + if pending_messages: + logger.info( + f"Delivering {len(pending_messages)} pending messages to operation {operation.operation_id}" + ) + for message in pending_messages: + await self._deliver_message_to_operation(operation, message) + + logger.info( + f"Registered operation {operation.operation_id} ({operation.operation_type})" + ) + + async def unregister_operation(self, operation_id: str): + if operation_id in self.operations: + del self.operations[operation_id] + logger.debug(f"Unregistered operation {operation_id}") + + # Remove any pending messages for this operation + if operation_id in self._pending_messages: + del self._pending_messages[operation_id] + + async def route_message(self, message: Dict[str, Any]): + if not self.is_running: + logger.warning("Message router not running, dropping message") + return + + try: + await asyncio.wait_for(self.message_queue.put(message), timeout=5.0) + except asyncio.TimeoutError: + logger.error("Message queue full, dropping message") + + async def _route_messages(self): + logger.debug("Starting message routing loop") + + try: + while self.is_running: + try: + message = await asyncio.wait_for( + self.message_queue.get(), timeout=1.0 + ) + await self._process_message(message) + except asyncio.TimeoutError: + # Cleanup expired pending messages + await self._cleanup_expired_messages() + continue + except asyncio.CancelledError: + break + except Exception as e: + logger.error("Error in message routing loop", exc_info=e) + except asyncio.CancelledError: + pass + finally: + logger.debug("Message routing loop stopped") + + async def _process_message(self, raw_message: Dict[str, Any]): + try: + logger.debug(f"Processing raw message: {raw_message}") + if self._is_error_message(raw_message): + logger.error(f"Processing ERROR message: {raw_message}") + await self._handle_error_message(raw_message) + elif self._is_data_message(raw_message): + await self._handle_data_message(raw_message) + else: + logger.warning(f"Unknown message format: {raw_message}") + + except Exception as e: + logger.error("Error processing message", exc_info=e) + + def _is_data_message(self, message: Dict[str, Any]) -> bool: + has_data = ( + "data" in message + and isinstance(message["data"], list) + and len(message["data"]) > 0 + ) + has_errors = ( + "errors" in message + and isinstance(message["errors"], list) + and len(message["errors"]) > 0 + ) + return has_data and not has_errors + + def _is_error_message(self, message: Dict[str, Any]) -> bool: + return ( + "errors" in message + and isinstance(message["errors"], list) + and len(message["errors"]) > 0 + ) + + async def _handle_data_message(self, raw_message: Dict[str, Any]): + data_items = raw_message.get("data", []) + logger.debug(f"Handling data message with {len(data_items)} items") + + for item in data_items: + operation_id = item.get("taskUUID") + task_type = item.get("taskType") + + logger.debug( + f"Processing data item - taskUUID: {operation_id}, taskType: {task_type}" + ) + + if not operation_id: + logger.warning(f"Data item missing taskUUID: {item}") + continue + + operation = self.operations.get(operation_id) + if operation: + logger.debug(f"Delivering message to operation {operation_id}") + await self._deliver_message_to_operation(operation, item) + else: + logger.debug( + f"Storing message for unregistered operation {operation_id}" + ) + item["_received_at"] = time.time() + self._pending_messages[operation_id].append(item) + + async def _handle_error_message(self, raw_message: Dict[str, Any]): + error_items = raw_message.get("errors", []) + logger.error(f"Handling error message with {len(error_items)} error items") + + for item in error_items: + operation_id = item.get("taskUUID") + error_message = item.get("message", "Unknown error") + error_code = item.get("code") + + logger.error( + f"Processing server error - taskUUID: {operation_id}, message: {error_message}, code: {error_code}" + ) + logger.error(f"Full error item: {item}") + + error_message_data = {**item, "taskType": "error"} + + if operation_id and operation_id != "N/A": + operation = self.operations.get(operation_id) + if operation: + logger.error(f"Delivering error to operation {operation_id}") + await self._deliver_message_to_operation( + operation, error_message_data + ) + else: + # Store error for later delivery + logger.error( + f"Storing error for unregistered operation {operation_id}" + ) + error_message_data["_received_at"] = time.time() + self._pending_messages[operation_id].append(error_message_data) + else: + # Unmatched error - deliver to all active operations + logger.error( + f"Unmatched error, delivering to all active operations: {error_message}" + ) + await self._handle_unmatched_error(error_message_data) + + async def _handle_unmatched_error(self, error_message: Dict[str, Any]): + """Handle errors that don't have a specific taskUUID.""" + if not self.operations: + logger.warning("No active operations to deliver unmatched error to") + return + + error_text = error_message.get("message", "") + + target_operations = [] + + if "image" in error_text.lower(): + target_operations = [ + op + for op in self.operations.values() + if op.operation_type + in ["imageInference", "imageGeneration", "photoMaker"] + ] + elif "video" in error_text.lower(): + target_operations = [ + op + for op in self.operations.values() + if op.operation_type == "videoInference" + ] + + if not target_operations: + target_operations = [ + max(self.operations.values(), key=lambda op: op.created_at) + ] + logger.warning( + f"Delivering unmatched error to most recent operation: {target_operations[0].operation_id}" + ) + + for operation in target_operations: + enhanced_error = { + **error_message, + "taskUUID": operation.operation_id, + "message": f"Server error: {error_text}", + } + logger.debug( + f"Delivering unmatched error to operation {operation.operation_id}" + ) + await self._deliver_message_to_operation(operation, enhanced_error) + + async def _deliver_message_to_operation(self, operation, message: Dict[str, Any]): + try: + task_type = message.get("taskType") + operation_id = operation.operation_id + + logger.debug( + f"Delivering message to operation {operation_id}: taskType={task_type}" + ) + + # Special logging for different message types + match task_type: + case "getResponse": + status = message.get("status") + error_info = message.get("error") or message.get("message", "") + logger.info( + f"Delivering getResponse to operation {operation_id}: status={status}, error={error_info}, full_message={message}" + ) + case "videoInference": + status = message.get("status") + error_info = message.get("error", "") + logger.info( + f"Delivering videoInference to operation {operation_id}: status={status}, error={error_info}" + ) + case "error": + error_msg = message.get("message", "Unknown error") + error_code = message.get("code", "") + logger.error( + f"Delivering ERROR to operation {operation_id}: {error_msg} (code: {error_code})" + ) + + await operation.handle_message(message) + logger.debug(f"Successfully delivered message to operation {operation_id}") + + except Exception as e: + logger.error( + f"Error delivering message to operation {operation.operation_id}", + exc_info=e, + ) + + async def _cleanup_expired_messages(self): + """Remove messages that are too old.""" + current_time = time.time() + expired_operations = [] + total_expired = 0 + + for operation_id, messages in self._pending_messages.items(): + valid_messages = [] + for message in messages: + received_at = message.get("_received_at", current_time) + if current_time - received_at < self._message_ttl: + valid_messages.append(message) + else: + total_expired += 1 + + if valid_messages: + self._pending_messages[operation_id] = valid_messages + else: + expired_operations.append(operation_id) + + for operation_id in expired_operations: + del self._pending_messages[operation_id] + + if total_expired > 0: + logger.debug(f"Cleaned up {total_expired} expired messages") diff --git a/runware/messaging/types.py b/runware/messaging/types.py new file mode 100644 index 0000000..0f67e04 --- /dev/null +++ b/runware/messaging/types.py @@ -0,0 +1,38 @@ +import time +from dataclasses import dataclass +from enum import Enum +from typing import Any, Awaitable, Callable, Dict + + +class MessageType(Enum): + IMAGE_INFERENCE = "imageInference" + VIDEO_INFERENCE = "videoInference" + PHOTO_MAKER = "photoMaker" + IMAGE_UPLOAD = "imageUpload" + IMAGE_UPSCALE = "imageUpscale" + IMAGE_BACKGROUND_REMOVAL = "imageBackgroundRemoval" + IMAGE_CAPTION = "imageCaption" + PROMPT_ENHANCE = "promptEnhance" + AUTHENTICATION = "authentication" + MODEL_UPLOAD = "modelUpload" + MODEL_SEARCH = "modelSearch" + GET_RESPONSE = "getResponse" + PING = "ping" + ERROR = "error" + + +@dataclass +class Message: + message_type: MessageType + operation_id: str + data: Dict[str, Any] + timestamp: float = None + + def __post_init__(self): + if self.timestamp is None: + self.timestamp = time.time() + + +MessageHandler = Callable[[Message], Awaitable[None]] +ProgressCallback = Callable[[Any], None] +CompletionCallback = Callable[[Any], None] diff --git a/runware/operations/__init__.py b/runware/operations/__init__.py new file mode 100644 index 0000000..30baba3 --- /dev/null +++ b/runware/operations/__init__.py @@ -0,0 +1,25 @@ +from .base import BaseOperation +from .image_background_removal import ImageBackgroundRemovalOperation +from .image_caption import ImageCaptionOperation +from .image_inference import ImageInferenceOperation +from .image_upscale import ImageUpscaleOperation +from .manager import OperationManager +from .model_search import ModelSearchOperation +from .model_upload import ModelUploadOperation +from .photo_maker import PhotoMakerOperation +from .prompt_enhance import PromptEnhanceOperation +from .video_inference import VideoInferenceOperation + +__all__ = [ + "BaseOperation", + "OperationManager", + "ImageInferenceOperation", + "VideoInferenceOperation", + "ImageCaptionOperation", + "ImageBackgroundRemovalOperation", + "ImageUpscaleOperation", + "PromptEnhanceOperation", + "PhotoMakerOperation", + "ModelUploadOperation", + "ModelSearchOperation", +] diff --git a/runware/operations/base.py b/runware/operations/base.py new file mode 100644 index 0000000..d24620a --- /dev/null +++ b/runware/operations/base.py @@ -0,0 +1,593 @@ +import asyncio +import time +import uuid +import weakref +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional + +from ..core.error_context import ErrorContext +from ..core.types import OperationContext, OperationStatus, ProgressUpdate +from ..exceptions import RunwareOperationError, RunwareTimeoutError, RunwareParseError +from ..logging_config import get_logger +from ..types import ETaskType +from ..utils import instantiateDataclass + +if TYPE_CHECKING: + from ..client import RunwareClient + +logger = get_logger(__name__) + + +class BaseOperation(ABC): + """ + Abstract base class for all Runware operations. + + The execution flow follows these steps: + 1. Initialization and validation + 2. Request payload building + 3. Request sending + 4. Response waiting and processing + 5. Cleanup + """ + + # Class-level configuration that subclasses should override + field_mappings: Dict[str, str] = {} + response_class: Any = None + status: OperationStatus + + def __init__( + self, operation_id: Optional[str] = None, client: "RunwareClient" = None + ): + """ + Initialize the base operation. + + Args: + operation_id: Unique identifier for this operation + client: Reference to the Runware client + """ + self.operation_id = operation_id or self._generate_operation_id() + + # Use weak reference to client to prevent circular references + self._client_ref = weakref.ref(client) if client else None + + # Operation state + self._initialize_state() + + # Logging + self.logger = get_logger(f"{self.__class__.__name__}.{self.operation_id}") + + # Event handling + self._initialize_events() + + # Message handling + self._message_handlers: Dict[ + str, Callable[[Dict[str, Any]], Awaitable[None]] + ] = {} + self._setup_message_handlers() + + # Timeout management + self._timeout_task: Optional[asyncio.Task] = None + self._is_cancelled = False + + self.logger.debug( + f"Operation {self.operation_id} created ({self.operation_type})" + ) + + def _generate_operation_id(self) -> str: + """Generate a unique operation ID.""" + return str(uuid.uuid4()) + + def _initialize_state(self): + """Initialize operation state variables.""" + self.status = OperationStatus.PENDING + self.created_at = time.time() + self.completed_at: Optional[float] = None + self.results: List[Any] = [] + self.error: Optional[Exception] = None + self.progress: float = 0.0 + self.metadata: Dict[str, Any] = {} + + def _initialize_events(self): + """Initialize event handling components.""" + self.completion_event = asyncio.Event() + self.progress_callbacks: List[Callable[[ProgressUpdate], None]] = [] + self.completion_callbacks: List[Callable[["BaseOperation"], None]] = [] + + @property + def client(self) -> Optional["RunwareClient"]: + """Get client from weak reference.""" + return self._client_ref() if self._client_ref else None + + @property + @abstractmethod + def operation_type(self) -> str: + """Return the type of this operation. Must be implemented by subclasses.""" + pass + + async def execute(self) -> Any: + """ + Main execution method implementing the Template Method pattern. + + This method defines the skeleton of the operation execution algorithm. + Subclasses can override hook methods to customize specific steps. + """ + error_ctx = ErrorContext(self.operation_id, self.operation_type) + + try: + # Step 1: Pre-execution initialization + async with error_ctx.phase("initialization"): + await self._pre_execution_hook() + await self._validate_operation() + + # Step 2: Build request payload + async with error_ctx.phase("build_payload"): + request_payload = await self._build_request_payload() + await self._post_payload_build_hook(request_payload) + + # Step 3: Send request + async with error_ctx.phase("send_request"): + await self._send_request(request_payload) + await self._post_request_send_hook() + + # Step 4: Wait for completion and process results + async with error_ctx.phase("wait_completion"): + results = await self._wait_for_results() + processed_results = await self._process_results(results) + + # Step 5: Post-execution cleanup and finalization + async with error_ctx.phase("finalization"): + await self._post_execution_hook(processed_results) + + self.logger.info(f"Operation {self.operation_id} completed successfully") + return processed_results + + except RunwareOperationError: + raise + except Exception as e: + wrapped = error_ctx.wrap_error(e) + await self._handle_error(wrapped) + raise wrapped + + # Abstract methods that subclasses must implement + @abstractmethod + async def _build_request_payload(self) -> List[Dict[str, Any]]: + """Build the request payload for this operation. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _setup_message_handlers(self): + """Setup message handlers for this operation. Must be implemented by subclasses.""" + pass + + async def _pre_execution_hook(self): + """Hook called before operation execution begins.""" + self.logger.info(f"Operation {self.operation_id} starting execution") + + async def _post_payload_build_hook(self, request_payload: List[Dict[str, Any]]): + """Hook called after request payload is built.""" + self.logger.debug(f"Operation {self.operation_id} built request payload") + + async def _post_request_send_hook(self): + """Hook called after request is sent.""" + self.logger.debug(f"Operation {self.operation_id} request sent") + + async def _post_execution_hook(self, results: Any): + """Hook called after successful execution.""" + self.logger.debug(f"Operation {self.operation_id} execution completed") + + # Core operation methods + async def _validate_operation(self): + """Validate operation preconditions.""" + if not self.client: + raise RunwareOperationError( + "No client available for operation", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + + if not self.client.connection_manager: + raise RunwareOperationError( + "No connection manager available", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + + async def _send_request(self, request_payload: List[Dict[str, Any]]): + """Send the request payload to the server.""" + if self.client and self.client.connection_manager: + self.logger.debug(f"Operation {self.operation_id} sending request payload") + await self.client.connection_manager.send_message(request_payload) + else: + raise RunwareOperationError( + "No connection manager available", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + + async def _wait_for_results(self) -> Any: + """Wait for operation completion and return results.""" + return await self.wait_for_completion() + + async def _process_results(self, results: Any) -> Any: + """Process and validate results before returning.""" + return results + + # Operation lifecycle methods + async def start(self, timeout: Optional[float] = None) -> Any: + """ + Start the operation execution. + + Args: + timeout: Optional timeout in seconds + + Returns: + Operation results + """ + if self.status != OperationStatus.PENDING: + raise RunwareOperationError( + f"Operation {self.operation_id} cannot be started in status {self.status.value}", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + + self.status = OperationStatus.EXECUTING + self.logger.info( + f"Starting operation {self.operation_id} ({self.operation_type})" + ) + + if timeout: + self._timeout_task = asyncio.create_task(self._handle_timeout(timeout)) + + try: + return await self.execute() + except Exception as e: + self.logger.error( + f"Operation {self.operation_id} execution failed", exc_info=e + ) + await self._handle_error(e) + raise + finally: + await self._cleanup() + + async def wait_for_completion(self, timeout: Optional[float] = None) -> Any: + """ + Wait for the operation to complete. + + Args: + timeout: Optional timeout in seconds + + Returns: + Operation results + """ + if self.status == OperationStatus.COMPLETED: + return self.get_results() + + if self.status in [OperationStatus.FAILED, OperationStatus.TIMEOUT]: + if self.error: + raise self.error + raise RunwareOperationError( + f"Operation {self.operation_id} failed with status {self.status.value}", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + + try: + await asyncio.wait_for(self.completion_event.wait(), timeout) + return self.get_results() + except asyncio.TimeoutError: + await self.cancel() + raise RunwareTimeoutError( + f"Operation {self.operation_id} timed out after {timeout}s", + timeout_duration=timeout, + ) + + async def cancel(self): + """Cancel the operation.""" + if self.status in [ + OperationStatus.COMPLETED, + OperationStatus.FAILED, + OperationStatus.CANCELLED, + OperationStatus.TIMEOUT, + ]: + return + + if self._is_cancelled: + return + + self._is_cancelled = True + self.status = OperationStatus.CANCELLED + self.logger.info(f"Cancelling operation {self.operation_id}") + + self.completion_event.set() + await self._cleanup() + + def get_results(self) -> Any: + """Get operation results.""" + if ( + self.status in [OperationStatus.FAILED, OperationStatus.TIMEOUT] + and self.error + ): + raise self.error + return self.results + + def get_context(self) -> OperationContext: + """Get operation context for monitoring.""" + return OperationContext( + operation_id=self.operation_id, + operation_type=self.operation_type, + status=self.status, + created_at=self.created_at, + completed_at=self.completed_at, + results=self.results, + error=self.error, + progress=self.progress, + metadata=self.metadata, + ) + + # Event handling methods + def add_progress_callback(self, callback: Callable[[ProgressUpdate], None]): + """Add a progress callback.""" + self.progress_callbacks.append(callback) + + def add_completion_callback(self, callback: Callable[["BaseOperation"], None]): + """Add a completion callback.""" + self.completion_callbacks.append(callback) + + # Message handling methods + async def handle_message(self, message: Dict[str, Any]): + """ + Handle incoming messages from the server. + + Args: + message: Message from the server + """ + # Check if operation is already completed or cancelled + if self.status in [ + OperationStatus.COMPLETED, + OperationStatus.CANCELLED, + OperationStatus.FAILED, + ]: + return + + try: + message_type = message.get("taskType") + handler = self._message_handlers.get(message_type) + + if handler: + await handler(message) + else: + await self._handle_unknown_message(message) + + except Exception as e: + self.logger.error( + f"Error handling message for operation {self.operation_id}", exc_info=e + ) + await self._handle_error(e) + + async def _handle_unknown_message(self, message: Dict[str, Any]): + """Handle unknown message types.""" + message_type = message.get("taskType", "unknown") + self.logger.warning( + f"Unknown message type '{message_type}' for operation {self.operation_id}" + ) + + # Error handling and completion methods + async def _handle_error(self, error: Exception): + """Handle operation errors.""" + if self.status in [ + OperationStatus.COMPLETED, + OperationStatus.CANCELLED, + OperationStatus.FAILED, + OperationStatus.TIMEOUT, + ]: + return + + self.logger.error( + f"Operation {self.operation_id} encountered error: {str(error)}", + exc_info=error, + ) + self.error = error + + # Set appropriate status based on error type + if isinstance(error, RunwareTimeoutError): + self.status = OperationStatus.TIMEOUT + else: + self.status = OperationStatus.FAILED + + self.completed_at = time.time() + self.completion_event.set() + await self._notify_completion() + + async def _handle_timeout(self, timeout: float): + """Handle operation timeout.""" + try: + await asyncio.sleep(timeout) + if self.status == OperationStatus.EXECUTING: + self.logger.error( + f"Operation {self.operation_id} timed out after {timeout}s" + ) + await self._handle_error( + RunwareTimeoutError( + f"Operation {self.operation_id} timed out after {timeout}s", + timeout_duration=timeout, + ) + ) + except asyncio.CancelledError: + pass + + async def _complete_operation(self, results: Any = None): + """Complete the operation successfully.""" + if self.status in [ + OperationStatus.COMPLETED, + OperationStatus.CANCELLED, + OperationStatus.FAILED, + OperationStatus.TIMEOUT, + ]: + return + + if results is not None: + if isinstance(results, list): + self.results.extend(results) + else: + self.results.append(results) + + self.logger.info( + f"Operation {self.operation_id} completed with {len(self.results)} results" + ) + self.status = OperationStatus.COMPLETED + self.completed_at = time.time() + self.progress = 1.0 + + self.completion_event.set() + await self._notify_progress() + await self._notify_completion() + + # Progress and notification methods + async def _update_progress( + self, + progress: float, + message: Optional[str] = None, + partial_results: Optional[List[Any]] = None, + ): + """Update operation progress.""" + if self.status in [ + OperationStatus.COMPLETED, + OperationStatus.CANCELLED, + OperationStatus.FAILED, + ]: + return + + self.progress = max(0.0, min(1.0, progress)) + + if partial_results: + self.results.extend(partial_results) + + await self._notify_progress(message, partial_results) + + async def _notify_progress( + self, message: Optional[str] = None, partial_results: Optional[List[Any]] = None + ): + """Notify progress callbacks.""" + if not self.progress_callbacks: + return + + progress_update = ProgressUpdate( + operation_id=self.operation_id, + progress=self.progress, + message=message, + partial_results=partial_results, + ) + + for callback in self.progress_callbacks: + try: + callback(progress_update) + except Exception as e: + self.logger.error( + f"Error in progress callback for operation {self.operation_id}", + exc_info=e, + ) + + async def _notify_completion(self): + """Notify completion callbacks.""" + for callback in self.completion_callbacks: + try: + callback(self) + except Exception as e: + self.logger.error( + f"Error in completion callback for operation {self.operation_id}", + exc_info=e, + ) + + # Response parsing methods + def _parse_response(self, message: Dict[str, Any]) -> Any: + """ + Parse server response into expected data structure. + + This method uses the field_mappings and response_class defined by subclasses + to convert server messages into properly typed response objects. + """ + if not self.field_mappings or not self.response_class: + raise RunwareParseError( + f"Operation missing field_mappings or response_class", + operation_id=self.operation_id, + operation_type=self.operation_type, + raw_data=message, + ) + + try: + processed_fields = {} + + for field_name, message_key in self.field_mappings.items(): + if message_key in message: + value = message[message_key] + + if field_name == "taskType" and hasattr(ETaskType, "value"): + # Only convert to enum if the response class expects an enum + if hasattr(self.response_class, "__annotations__"): + field_type = self.response_class.__annotations__.get( + field_name + ) + if field_type == ETaskType: + processed_fields[field_name] = ETaskType(value) + else: + processed_fields[field_name] = value + else: + processed_fields[field_name] = value + elif field_name == "cost" and value is not None: + processed_fields[field_name] = float(value) + else: + processed_fields[field_name] = value + + return instantiateDataclass(self.response_class, processed_fields) + + except Exception as e: + self.logger.error( + f"Operation {self.operation_id} failed to parse response", + exc_info=e, + ) + raise RunwareParseError( + f"Failed to parse response: {e}", + operation_id=self.operation_id, + operation_type=self.operation_type, + raw_data=message, + ) + + async def _handle_error_message(self, message: Dict[str, Any]): + error_message = message.get("message", "Unknown error") + + logger.error(f"Operation {self.operation_id} received error: {error_message}") + + error = RunwareOperationError( + f"{self.__class__.__name__} error: {error_message}", + operation_id=self.operation_id, + operation_type=self.operation_type, + code=message.get("code"), + parameter=message.get("parameter"), + error_type=message.get("type"), + documentation=message.get("documentation"), + task_uuid=message.get("taskUUID", self.operation_id), + ) + + await self._handle_error(error) + + # Cleanup methods + async def _cleanup(self): + """Cleanup operation resources.""" + if self._timeout_task and not self._timeout_task.done(): + self._timeout_task.cancel() + try: + await asyncio.wait_for(self._timeout_task, timeout=1.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + self._timeout_task = None + + # Clear callbacks to prevent memory leaks + self.progress_callbacks.clear() + self.completion_callbacks.clear() + + # String representation methods + def __str__(self): + return f"{self.__class__.__name__}({self.operation_id}, {self.status.value})" + + def __repr__(self): + return self.__str__() diff --git a/runware/operations/image_background_removal.py b/runware/operations/image_background_removal.py new file mode 100644 index 0000000..14cf15f --- /dev/null +++ b/runware/operations/image_background_removal.py @@ -0,0 +1,90 @@ +from typing import Any, Dict, List + +from .base import BaseOperation +from ..core.cpu_bound import cpu_executor +from ..logging_config import get_logger +from ..types import ETaskType, IImage, IImageBackgroundRemoval + +logger = get_logger(__name__) + + +class ImageBackgroundRemovalOperation(BaseOperation): + field_mappings = { + "taskType": "taskType", + "imageUUID": "imageUUID", + "taskUUID": "taskUUID", + "inputImageUUID": "inputImageUUID", + "imageURL": "imageURL", + "imageBase64Data": "imageBase64Data", + "imageDataURI": "imageDataURI", + "cost": "cost", + } + response_class = IImage + + def __init__(self, request: IImageBackgroundRemoval, client=None): + super().__init__(request.taskUUID, client) + self.request = request + self._image_uploaded = None + + logger.info( + f"Image background removal operation {self.operation_id} initialized" + ) + + @property + def operation_type(self) -> str: + return "imageBackgroundRemoval" + + def _setup_message_handlers(self): + self._message_handlers = { + "imageBackgroundRemoval": self._handle_background_removal, + "error": self._handle_error_message, + } + + async def _build_request_payload(self) -> List[Dict[str, Any]]: + self._image_uploaded = await self.client.uploadImage(self.request.inputImage) + task_params = await cpu_executor.serialize_dataclass( + { + "taskType": ETaskType.IMAGE_BACKGROUND_REMOVAL.value, + "taskUUID": self.operation_id, + "inputImage": ( + self._image_uploaded.imageUUID + if self._image_uploaded + else self.request.inputImage + ), + } + ) + + optional_fields = { + "outputType": self.request.outputType, + "outputFormat": self.request.outputFormat, + "includeCost": self.request.includeCost, + "model": self.request.model, + "outputQuality": self.request.outputQuality, + } + + for key, value in optional_fields.items(): + if value is not None: + task_params[key] = value + + if self.request.settings: + settings_dict = { + k: v for k, v in vars(self.request.settings).items() if v is not None + } + task_params.update(settings_dict) + + return [task_params] + + async def _handle_background_removal(self, message: Dict[str, Any]): + try: + logger.debug( + f"Operation {self.operation_id} handling background removal message: {message}" + ) + image_data = self._parse_response(message) + await self._complete_operation([image_data]) + + except Exception as e: + logger.error( + f"Error handling background removal message for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) diff --git a/runware/operations/image_caption.py b/runware/operations/image_caption.py new file mode 100644 index 0000000..cda3374 --- /dev/null +++ b/runware/operations/image_caption.py @@ -0,0 +1,75 @@ +from typing import Any, Dict + +from .base import BaseOperation +from ..core.cpu_bound import cpu_executor +from ..logging_config import get_logger +from ..types import ETaskType, IImageCaption, IImageToText + +logger = get_logger(__name__) + + +class ImageCaptionOperation(BaseOperation): + field_mappings = { + "taskType": "taskType", + "taskUUID": "taskUUID", + "text": "text", + "cost": "cost", + } + response_class = IImageToText + + def __init__(self, request: IImageCaption, client=None): + super().__init__(operation_id=None, client=client) + self.request = request + self._image_uploaded = None + + logger.info(f"Image caption operation {self.operation_id} initialized") + + @property + def operation_type(self) -> str: + return "imageCaption" + + async def execute(self) -> IImageToText | None: + results = await super().execute() + if results is not None: + return results[0] + return None + + def _setup_message_handlers(self): + self._message_handlers = { + "imageCaption": self._handle_image_caption, + "error": self._handle_error_message, + } + + async def _build_request_payload(self) -> list[Dict[str, Any]]: + self._image_uploaded = await self.client.uploadImage(self.request.inputImage) + task_params = await cpu_executor.serialize_dataclass( + { + "taskType": ETaskType.IMAGE_CAPTION.value, + "taskUUID": self.operation_id, + "inputImage": ( + self._image_uploaded.imageUUID + if self._image_uploaded + else self.request.inputImage + ), + } + ) + + if self.request.includeCost: + task_params["includeCost"] = self.request.includeCost + + return [task_params] + + async def _handle_image_caption(self, message: Dict[str, Any]): + try: + logger.debug( + f"Operation {self.operation_id} handling image caption message: {message}" + ) + image_to_text = self._parse_response(message) + await self._complete_operation([image_to_text]) + + except Exception as e: + logger.error( + f"Error handling image caption message for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) diff --git a/runware/operations/image_inference.py b/runware/operations/image_inference.py new file mode 100644 index 0000000..b1a4c42 --- /dev/null +++ b/runware/operations/image_inference.py @@ -0,0 +1,400 @@ +from typing import Any, Dict, List, Optional + +from .base import BaseOperation +from ..core.cpu_bound import cpu_executor +from ..logging_config import get_logger +from ..types import ETaskType, IImage, IImageInference +from ..utils import process_image + +logger = get_logger(__name__) + + +class ImageInferenceOperation(BaseOperation): + field_mappings = { + "taskType": "taskType", + "imageUUID": "imageUUID", + "taskUUID": "taskUUID", + "seed": "seed", + "inputImageUUID": "inputImageUUID", + "imageURL": "imageURL", + "imageBase64Data": "imageBase64Data", + "imageDataURI": "imageDataURI", + "NSFWContent": "NSFWContent", + "cost": "cost", + } + response_class = IImage + + def __init__(self, request: IImageInference, client=None): + super().__init__(request.taskUUID, client) + self.request = request + self.expected_results = request.numberResults or 1 + self.received_results = 0 + self._processed_images: Dict[str, Any] = {} + + logger.info(f"Image inference operation {self.operation_id} initialized") + logger.debug( + f"Operation {self.operation_id} expects {self.expected_results} results" + ) + + @property + def operation_type(self) -> str: + return "imageInference" + + def _setup_message_handlers(self): + """Setup message handlers for different message types""" + self._message_handlers = { + # Primary message types for image inference + "imageInference": self._handle_image_inference, + "imageGeneration": self._handle_image_inference, + # Alternative message types the server might send + "image_inference": self._handle_image_inference, + "image_generation": self._handle_image_inference, + "inference": self._handle_image_inference, + "generation": self._handle_image_inference, + # Error handlers + "error": self._handle_error_message, + "Error": self._handle_error_message, + } + + logger.info( + f"Operation {self.operation_id} setup message handlers: {list(self._message_handlers.keys())}" + ) + + async def handle_message(self, message: Dict[str, Any]): + """Enhanced message handling with detailed logging""" + logger.debug(f"Operation {self.operation_id} received message: {message}") + + try: + message_type = message.get("taskType") + logger.debug(f"Operation {self.operation_id} message type: {message_type}") + + handler = self._message_handlers.get(message_type) + + if handler: + logger.debug( + f"Operation {self.operation_id} found handler for message type: {message_type}" + ) + await handler(message) + else: + logger.warning( + f"Operation {self.operation_id} no handler for message type: {message_type}" + ) + await self._handle_unknown_message(message) + + except Exception as e: + logger.error( + f"Error handling message for operation {self.operation_id}", exc_info=e + ) + await self._handle_error(e) + + async def _handle_unknown_message(self, message: Dict[str, Any]): + """Enhanced fallback handler for unknown message types""" + logger.warning(f"Operation {self.operation_id} handling unknown message type") + logger.debug(f"Operation {self.operation_id} unknown message: {message}") + + # Try to detect if this looks like an image inference result + has_image_fields = any( + field in message for field in ["imageUUID", "imageURL", "imageBase64Data"] + ) + has_task_uuid = message.get("taskUUID") == self.operation_id + + if has_image_fields and has_task_uuid: + logger.info( + f"Operation {self.operation_id} unknown message looks like image result, treating as image inference" + ) + await self._handle_image_inference(message) + else: + logger.warning( + f"Operation {self.operation_id} unknown message doesn't look like image result, ignoring" + ) + + async def _build_request_payload(self) -> List[Dict[str, Any]]: + logger.debug(f"Operation {self.operation_id} building request payload") + + control_net_data: List[Dict[str, Any]] = [] + + # Process images + try: + if self.request.maskImage: + logger.debug(f"Operation {self.operation_id} processing mask image") + self.request.maskImage = await process_image(self.request.maskImage) + if self.request.seedImage: + logger.debug(f"Operation {self.operation_id} processing seed image") + self.request.seedImage = await process_image(self.request.seedImage) + if self.request.referenceImages: + logger.debug( + f"Operation {self.operation_id} processing reference images" + ) + self.request.referenceImages = await process_image( + self.request.referenceImages + ) + except Exception as e: + logger.error( + f"Operation {self.operation_id} failed to process images", exc_info=e + ) + raise + + # Process ControlNet with ThreadPoolExecutor + if self.request.controlNet: + logger.debug( + f"Operation {self.operation_id} processing {len(self.request.controlNet)} ControlNet items" + ) + try: + # Prepare control data + control_items = [] + for i, control_data in enumerate(self.request.controlNet): + if self.client: + image_uploaded = await self.client.uploadImage( + control_data.guideImage + ) + if image_uploaded: + control_data.guideImage = image_uploaded.imageUUID + control_items.append(control_data) + + control_net_data = await cpu_executor.batch_serialize_dataclasses( + control_items + ) + + except Exception as e: + logger.error( + f"Operation {self.operation_id} failed to process ControlNet", + exc_info=e, + ) + raise + + # Process InstantID + instant_id_data = {} + if self.request.instantID: + logger.debug(f"Operation {self.operation_id} processing InstantID") + try: + instant_id_data = { + k: v + for k, v in vars(self.request.instantID).items() + if v is not None + } + if "inputImage" in instant_id_data: + instant_id_data["inputImage"] = await process_image( + instant_id_data["inputImage"] + ) + if "poseImage" in instant_id_data: + instant_id_data["poseImage"] = await process_image( + instant_id_data["poseImage"] + ) + except Exception as e: + logger.error( + f"Operation {self.operation_id} failed to process InstantID", + exc_info=e, + ) + raise + + # Process IP Adapters + ip_adapters_data = [] + if self.request.ipAdapters: + logger.debug( + f"Operation {self.operation_id} processing {len(self.request.ipAdapters)} IP Adapters" + ) + try: + for ip_adapter in self.request.ipAdapters: + ip_adapter_data = { + k: v for k, v in vars(ip_adapter).items() if v is not None + } + if "guideImage" in ip_adapter_data: + ip_adapter_data["guideImage"] = await process_image( + ip_adapter_data["guideImage"] + ) + ip_adapters_data.append(ip_adapter_data) + except Exception as e: + logger.error( + f"Operation {self.operation_id} failed to process IP Adapters", + exc_info=e, + ) + raise + + # Process ACE++ + ace_plus_plus_data = {} + if self.request.acePlusPlus: + logger.debug(f"Operation {self.operation_id} processing ACE++") + try: + ace_plus_plus_data = { + "inputImages": [], + "repaintingScale": self.request.acePlusPlus.repaintingScale, + "type": self.request.acePlusPlus.taskType, + } + if self.request.acePlusPlus.inputImages: + ace_plus_plus_data["inputImages"] = await process_image( + self.request.acePlusPlus.inputImages + ) + if self.request.acePlusPlus.inputMasks: + ace_plus_plus_data["inputMasks"] = await process_image( + self.request.acePlusPlus.inputMasks + ) + except Exception as e: + logger.error( + f"Operation {self.operation_id} failed to process ACE++", exc_info=e + ) + raise + + # Build main request object + request_object = { + "taskType": ETaskType.IMAGE_INFERENCE.value, + "taskUUID": self.operation_id, + "modelId": self.request.model, + "positivePrompt": self.request.positivePrompt.strip(), + "numberResults": self.expected_results, + } + + # Add optional fields + optional_fields = { + "steps": self.request.steps, + "height": self.request.height, + "width": self.request.width, + "controlNet": control_net_data if control_net_data else None, + "lora": ( + [ + {"model": lora.model, "weight": lora.weight} + for lora in self.request.lora + ] + if self.request.lora + else None + ), + "lycoris": ( + [ + {"model": lycoris.model, "weight": lycoris.weight} + for lycoris in self.request.lycoris + ] + if self.request.lycoris + else None + ), + "embeddings": ( + [{"model": embedding.model} for embedding in self.request.embeddings] + if self.request.embeddings + else None + ), + "seed": self.request.seed, + "refiner": self._build_refiner_data() if self.request.refiner else None, + "instantID": instant_id_data if instant_id_data else None, + "outpaint": ( + {k: v for k, v in vars(self.request.outpaint).items() if v is not None} + if self.request.outpaint + else None + ), + "ipAdapters": ip_adapters_data if ip_adapters_data else None, + "acePlusPlus": ace_plus_plus_data if ace_plus_plus_data else None, + "outputType": self.request.outputType, + "outputFormat": self.request.outputFormat, + "includeCost": self.request.includeCost, + "checkNSFW": self.request.checkNsfw, + "negativePrompt": self.request.negativePrompt, + "CFGScale": self.request.CFGScale, + "seedImage": self.request.seedImage, + "maskImage": self.request.maskImage, + "referenceImages": self.request.referenceImages, + "strength": self.request.strength, + "scheduler": self.request.scheduler, + "vae": self.request.vae, + "promptWeighting": self.request.promptWeighting, + "maskMargin": self.request.maskMargin, + "outputQuality": self.request.outputQuality, + } + + # Add accelerator options + if self.request.acceleratorOptions: + pipeline_options_dict = { + k: v + for k, v in vars(self.request.acceleratorOptions).items() + if v is not None + } + optional_fields["acceleratorOptions"] = pipeline_options_dict + + # Add advanced features + if self.request.advancedFeatures: + pipeline_options_dict = { + k: v.__dict__ + for k, v in vars(self.request.advancedFeatures).items() + if v is not None + } + optional_fields["advancedFeatures"] = pipeline_options_dict + + # Add non-null optional fields + for key, value in optional_fields.items(): + if value is not None: + request_object[key] = value + + # Add extra args + if hasattr(self.request, "extraArgs") and isinstance( + self.request.extraArgs, dict + ): + request_object.update(self.request.extraArgs) + + logger.debug( + f"Operation {self.operation_id} built request payload: {request_object}" + ) + return [request_object] + + def _build_refiner_data(self) -> Optional[Dict[str, Any]]: + if not self.request.refiner: + return None + + refiner_data = {"model": self.request.refiner.model} + + if self.request.refiner.startStep is not None: + refiner_data["startStep"] = self.request.refiner.startStep + + if self.request.refiner.startStepPercentage is not None: + refiner_data["startStepPercentage"] = ( + self.request.refiner.startStepPercentage + ) + + return refiner_data + + async def _handle_image_inference(self, message: Dict[str, Any]): + try: + logger.debug( + f"Operation {self.operation_id} handling image inference message: {message}" + ) + + image_uuid = message.get("imageUUID") + if not image_uuid: + logger.warning( + f"Operation {self.operation_id} image inference message missing imageUUID: {message}" + ) + return + + if image_uuid in self._processed_images: + logger.debug( + f"Operation {self.operation_id} already processed image {image_uuid}" + ) + return + + logger.info( + f"Operation {self.operation_id} processing new image {image_uuid}" + ) + + image_data = self._parse_response(message) + self._processed_images[image_uuid] = image_data + self.received_results += 1 + + progress = min(self.received_results / self.expected_results, 1.0) + + await self._update_progress(progress, partial_results=[image_data]) + + # Call user callback if provided + if self.request.onPartialImages: + try: + self.request.onPartialImages([image_data], None) + except Exception as e: + logger.error(f"Error in onPartialImages callback", exc_info=e) + + # Check if we have all expected results + if self.received_results >= self.expected_results: + logger.info( + f"Operation {self.operation_id} received all expected results, completing" + ) + await self._complete_operation() + + except Exception as e: + logger.error( + f"Error handling image inference message for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) diff --git a/runware/operations/image_upscale.py b/runware/operations/image_upscale.py new file mode 100644 index 0000000..c8b6f69 --- /dev/null +++ b/runware/operations/image_upscale.py @@ -0,0 +1,89 @@ +from typing import Any, Dict, List + +from .base import BaseOperation +from ..core.cpu_bound import cpu_executor +from ..exceptions import RunwareOperationError +from ..logging_config import get_logger +from ..types import ETaskType, IImage, IImageUpscale + +logger = get_logger(__name__) + + +class ImageUpscaleOperation(BaseOperation): + field_mappings = { + "taskType": "taskType", + "imageUUID": "imageUUID", + "taskUUID": "taskUUID", + "inputImageUUID": "inputImageUUID", + "imageURL": "imageURL", + "imageBase64Data": "imageBase64Data", + "imageDataURI": "imageDataURI", + "cost": "cost", + } + response_class = IImage + + def __init__(self, request: IImageUpscale, client=None): + super().__init__(operation_id=None, client=client) + self.request = request + self._image_uploaded = None + + logger.info(f"Image upscale operation {self.operation_id} initialized") + + @property + def operation_type(self) -> str: + return "imageUpscale" + + def _setup_message_handlers(self): + self._message_handlers = { + "imageUpscale": self._handle_image_upscale, + "error": self._handle_error_message, + } + + async def _build_request_payload(self) -> List[Dict[str, Any]]: + logger.debug(f"Operation {self.operation_id} uploading image") + self._image_uploaded = await self.client.uploadImage(self.request.inputImage) + if not self._image_uploaded or not self._image_uploaded.imageUUID: + raise RunwareOperationError( + "Failed to upload image", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + task_params = await cpu_executor.serialize_dataclass( + { + "taskType": ETaskType.IMAGE_UPSCALE.value, + "taskUUID": self.operation_id, + "inputImage": ( + self._image_uploaded.imageUUID + if self._image_uploaded + else self.request.inputImage + ), + "upscaleFactor": self.request.upscaleFactor, + } + ) + + optional_fields = { + "outputType": self.request.outputType, + "outputFormat": self.request.outputFormat, + "includeCost": self.request.includeCost, + } + + for key, value in optional_fields.items(): + if value is not None: + task_params[key] = value + + return [task_params] + + async def _handle_image_upscale(self, message: Dict[str, Any]): + try: + logger.debug( + f"Operation {self.operation_id} handling image upscale message: {message}" + ) + image_data = self._parse_response(message) + await self._complete_operation([image_data]) + + except Exception as e: + logger.error( + f"Error handling image upscale message for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) diff --git a/runware/operations/manager.py b/runware/operations/manager.py new file mode 100644 index 0000000..39210ba --- /dev/null +++ b/runware/operations/manager.py @@ -0,0 +1,286 @@ +import asyncio +import time +from typing import Any, Callable, Dict, List, Optional + +from .base import BaseOperation +from ..core.types import OperationContext, OperationStatus +from ..exceptions import RunwareOperationError, RunwareResourceError +from ..logging_config import get_logger + +logger = get_logger(__name__) + + +class OperationManager: + """Manages operation execution and lifecycle.""" + + def __init__( + self, max_concurrent_operations: int = 100, operation_timeout: float = 300.0 + ): + self.operations: Dict[str, BaseOperation] = {} + self.max_concurrent_operations = max_concurrent_operations + self.default_operation_timeout = operation_timeout + + self._completion_callbacks: List[Callable[[BaseOperation], None]] = [] + self._cleanup_task: Optional[asyncio.Task] = None + self._is_running = False + self._operation_semaphore = asyncio.Semaphore(max_concurrent_operations) + + logger.debug( + f"OperationManager initialized with max_concurrent_operations={max_concurrent_operations}" + ) + + async def start(self): + if self._is_running: + return + + self._is_running = True + self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) + logger.info( + f"Operation manager started (max concurrent: {self.max_concurrent_operations})" + ) + + async def stop(self): + if not self._is_running: + return + + self._is_running = False + + if self._cleanup_task and not self._cleanup_task.done(): + self._cleanup_task.cancel() + try: + await asyncio.wait_for(self._cleanup_task, timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + self._cleanup_task = None + + cancelled_count = await self._cancel_all_operations() + logger.info( + f"Operation manager stopped, cancelled {cancelled_count} operations" + ) + + async def register_operation(self, operation: BaseOperation) -> BaseOperation: + if len(self.operations) >= self.max_concurrent_operations: + raise RunwareResourceError( + f"Maximum concurrent operations limit reached ({self.max_concurrent_operations})", + resource_type="operation_slots", + ) + + if operation.operation_id in self.operations: + raise RunwareOperationError( + f"Operation {operation.operation_id} already registered", + operation_id=operation.operation_id, + operation_type=operation.operation_type, + ) + + self.operations[operation.operation_id] = operation + operation.add_completion_callback(self._on_operation_completed) + + logger.debug( + f"Registered operation {operation.operation_id} ({operation.operation_type})" + ) + return operation + + async def unregister_operation( + self, operation_id: str + ) -> Optional[OperationContext]: + operation = self.operations.pop(operation_id, None) + if not operation: + return None + + context = operation.get_context() + logger.debug( + f"Unregistered operation {operation_id}, status: {context.status.value}" + ) + return context + + async def execute_operation( + self, operation: BaseOperation, timeout: Optional[float] = None + ) -> Any: + operation_timeout = timeout or self.default_operation_timeout + logger.info( + f"Executing operation {operation.operation_id} ({operation.operation_type}) with timeout {operation_timeout}s" + ) + + async with self._operation_semaphore: + await self.register_operation(operation) + + try: + start_time = time.time() + result = await operation.start(operation_timeout) + execution_time = time.time() - start_time + logger.info( + f"Operation {operation.operation_id} completed successfully in {execution_time:.2f}s" + ) + return result + except Exception as e: + logger.error(f"Operation {operation.operation_id} failed", exc_info=e) + raise + finally: + await self.unregister_operation(operation.operation_id) + + def get_operation(self, operation_id: str) -> Optional[BaseOperation]: + return self.operations.get(operation_id) + + def get_operation_context(self, operation_id: str) -> Optional[OperationContext]: + operation = self.operations.get(operation_id) + if operation: + return operation.get_context() + return None + + def list_operations( + self, + status_filter: Optional[OperationStatus] = None, + operation_type_filter: Optional[str] = None, + ) -> List[OperationContext]: + contexts = [] + + for operation in self.operations.values(): + context = operation.get_context() + + if status_filter and context.status != status_filter: + continue + if ( + operation_type_filter + and context.operation_type != operation_type_filter + ): + continue + + contexts.append(context) + + return contexts + + async def cancel_operation(self, operation_id: str) -> bool: + operation = self.operations.get(operation_id) + if not operation: + return False + + logger.info(f"Cancelling operation {operation_id}") + await operation.cancel() + return True + + async def cancel_operations_by_type(self, operation_type: str) -> int: + operations_to_cancel = [ + op for op in self.operations.values() if op.operation_type == operation_type + ] + + logger.info( + f"Cancelling {len(operations_to_cancel)} operations of type {operation_type}" + ) + + for operation in operations_to_cancel: + try: + await operation.cancel() + except Exception as e: + logger.error( + f"Error cancelling operation {operation.operation_id}", exc_info=e + ) + + return len(operations_to_cancel) + + async def cancel_all_operations(self) -> int: + return await self._cancel_all_operations() + + def add_completion_callback(self, callback: Callable[[BaseOperation], None]): + self._completion_callbacks.append(callback) + + def remove_completion_callback(self, callback: Callable[[BaseOperation], None]): + if callback in self._completion_callbacks: + self._completion_callbacks.remove(callback) + + async def wait_for_all_operations(self, timeout: Optional[float] = None) -> bool: + if not self.operations: + return True + + operations = list(self.operations.values()) + logger.info(f"Waiting for {len(operations)} operations to complete") + tasks = [op.wait_for_completion() for op in operations] + + try: + if timeout: + await asyncio.wait_for( + asyncio.gather(*tasks, return_exceptions=True), timeout + ) + else: + await asyncio.gather(*tasks, return_exceptions=True) + logger.info("All operations completed") + return True + except asyncio.TimeoutError: + logger.warning( + f"Timeout waiting for operations to complete after {timeout}s" + ) + return False + + def _on_operation_completed(self, operation: BaseOperation): + logger.debug( + f"Operation {operation.operation_id} completed with status: {operation.status.value}" + ) + + for callback in self._completion_callbacks: + try: + callback(operation) + except Exception as e: + logger.error("Error in completion callback", exc_info=e) + + async def _cancel_all_operations(self) -> int: + operations_to_cancel = list(self.operations.values()) + + if not operations_to_cancel: + return 0 + + logger.info(f"Cancelling {len(operations_to_cancel)} operations") + cancel_tasks = [] + + for op in operations_to_cancel: + try: + cancel_tasks.append(op.cancel()) + except Exception as e: + logger.error( + f"Error creating cancel task for operation {op.operation_id}", + exc_info=e, + ) + + if cancel_tasks: + try: + await asyncio.wait_for( + asyncio.gather(*cancel_tasks, return_exceptions=True), timeout=5.0 + ) + except asyncio.TimeoutError: + logger.warning("Some operations did not cancel within timeout") + + return len(operations_to_cancel) + + async def _periodic_cleanup(self): + logger.debug("Starting periodic cleanup loop") + try: + while self._is_running: + try: + await asyncio.sleep(60.0) # Check every minute + await self._cleanup_completed_operations() + except asyncio.CancelledError: + break + except Exception as e: + logger.error("Error in periodic cleanup", exc_info=e) + if not self._is_running: + break + except asyncio.CancelledError: + pass + finally: + logger.debug("Periodic cleanup stopped") + + async def _cleanup_completed_operations(self): + operations_to_remove = [] + + for operation_id, operation in self.operations.items(): + if operation.status in [ + OperationStatus.COMPLETED, + OperationStatus.FAILED, + OperationStatus.CANCELLED, + OperationStatus.TIMEOUT, + ]: + operations_to_remove.append(operation_id) + + for operation_id in operations_to_remove: + await self.unregister_operation(operation_id) + + if operations_to_remove: + logger.debug(f"Cleaned up {len(operations_to_remove)} completed operations") diff --git a/runware/operations/model_search.py b/runware/operations/model_search.py new file mode 100644 index 0000000..810cf22 --- /dev/null +++ b/runware/operations/model_search.py @@ -0,0 +1,143 @@ +from dataclasses import fields +from typing import Any, Dict + +from .base import BaseOperation +from ..core.cpu_bound import cpu_executor +from ..exceptions import RunwareOperationError +from ..logging_config import get_logger +from ..types import ETaskType, IModel, IModelSearch, IModelSearchResponse + +logger = get_logger(__name__) + + +class ModelSearchOperation(BaseOperation): + field_mappings = { + "taskUUID": "taskUUID", + "taskType": "taskType", + "totalResults": "totalResults", + } + response_class = IModelSearchResponse + + def __init__(self, request: IModelSearch, client=None): + super().__init__(operation_id=None, client=client) + self.request = request + logger.info(f"Model search operation {self.operation_id} initialized") + + @property + def operation_type(self) -> str: + return "modelSearch" + + async def execute(self) -> IModelSearchResponse: + results = await super().execute() + if results is not None: + return results[0] + return None + + def _setup_message_handlers(self): + self._message_handlers = { + "modelSearch": self._handle_model_search, + "error": self._handle_error_message, + } + + async def _build_request_payload(self) -> list[Dict[str, Any]]: + request_object = await cpu_executor.serialize_dataclass( + { + "taskUUID": self.operation_id, + "taskType": ETaskType.MODEL_SEARCH.value, + } + ) + + # Add tags if present + if self.request.tags: + request_object["tags"] = self.request.tags + + # Add all other fields from payload, excluding additional_params + for key, value in vars(self.request).items(): + if value is not None and key != "additional_params": + request_object[key] = value + + return [request_object] + + def _parse_response(self, message: Dict[str, Any]) -> IModelSearchResponse: + """ + Override base _parse_response to handle complex model search response. + """ + try: + # Parse models with additional fields support + models = [] + for model_data in message.get("results", []): + model = self._create_model_from_data(model_data) + models.append(model) + + # Create response with parsed models + response_data = { + "results": models, + "taskUUID": message.get("taskUUID"), + "taskType": message.get("taskType"), + "totalResults": message.get("totalResults", 0), + } + + return IModelSearchResponse(**response_data) + except Exception as e: + logger.error( + f"Operation {self.operation_id} failed to parse model search response", + exc_info=e, + ) + raise RunwareOperationError( + f"Failed to parse model search response: {e}", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + + def _create_model_from_data(self, model_data: Dict[str, Any]) -> IModel: + """Create IModel instance from API data with additional fields support.""" + # Get valid fields for IModel + valid_fields = {f.name for f in fields(IModel)} + + # Separate known and unknown fields + known_fields = {} + additional_fields = {} + + for key, value in model_data.items(): + if key in valid_fields: + known_fields[key] = value + else: + additional_fields[key] = value + + # Add additional_fields if there are any unknown fields + if additional_fields: + known_fields["additional_fields"] = additional_fields + + return IModel(**known_fields) + + async def _handle_model_search(self, message: Dict[str, Any]): + try: + logger.debug( + f"Operation {self.operation_id} handling model search message: {message}" + ) + + if message.get("error") or message.get("code"): + error_message = message.get( + "message", message.get("error", "Unknown error") + ) + error_code = message.get("code") + + error = RunwareOperationError( + f"Model search error: {error_message}", + operation_id=self.operation_id, + operation_type=self.operation_type, + code=error_code, + ) + await self._handle_error(error) + return + + # Parse and complete the operation using base _parse_response + model_search_response = self._parse_response(message) + await self._complete_operation([model_search_response]) + + except Exception as e: + logger.error( + f"Error handling model search message for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) diff --git a/runware/operations/model_upload.py b/runware/operations/model_upload.py new file mode 100644 index 0000000..957b345 --- /dev/null +++ b/runware/operations/model_upload.py @@ -0,0 +1,135 @@ +from typing import Any, Dict + +from .base import BaseOperation +from ..core.cpu_bound import cpu_executor +from ..exceptions import RunwareOperationError +from ..logging_config import get_logger +from ..types import ETaskType, IUploadModelBaseType, IUploadModelResponse + +logger = get_logger(__name__) + + +class ModelUploadOperation(BaseOperation): + field_mappings = { + "air": "air", + "taskUUID": "taskUUID", + "taskType": "taskType", + } + response_class = IUploadModelResponse + + def __init__(self, request: IUploadModelBaseType, client=None): + super().__init__(operation_id=None, client=client) + self.request = request + self._processed_statuses = set() + logger.info(f"Model upload operation {self.operation_id} initialized") + + @property + def operation_type(self) -> str: + return "modelUpload" + + async def execute(self) -> IUploadModelResponse: + results = await super().execute() + if results is not None: + return results[0] + return None + + def _setup_message_handlers(self): + self._message_handlers = { + "modelUpload": self._handle_model_upload, + "error": self._handle_error_message, + } + + async def _build_request_payload(self) -> list[Dict[str, Any]]: + base_fields = await cpu_executor.serialize_dataclass( + { + "taskType": ETaskType.MODEL_UPLOAD.value, + "taskUUID": self.operation_id, + "air": self.request.air, + "name": self.request.name, + "downloadURL": self.request.downloadURL, + "uniqueIdentifier": self.request.uniqueIdentifier, + "version": self.request.version, + "format": self.request.format, + "private": self.request.private, + "category": self.request.category, + "architecture": self.request.architecture, + } + ) + + optional_fields = [ + "retry", + "heroImageURL", + "tags", + "shortDescription", + "comment", + "positiveTriggerWords", + "type", + "negativeTriggerWords", + "defaultWeight", + "defaultStrength", + "defaultGuidanceScale", + "defaultSteps", + "defaultScheduler", + "conditioning", + ] + + request_object = { + **base_fields, + **{ + field: getattr(self.request, field) + for field in optional_fields + if getattr(self.request, field, None) is not None + }, + } + + return [request_object] + + async def _handle_model_upload(self, message: Dict[str, Any]): + try: + logger.debug( + f"Operation {self.operation_id} handling model upload message: {message}" + ) + + if message.get("code"): + error = RunwareOperationError( + f"Model upload error: {message.get('message', 'Unknown error')}", + operation_id=self.operation_id, + operation_type=self.operation_type, + code=message.get("code"), + ) + await self._handle_error(error) + return + + status = message.get("status") + if not status: + logger.warning( + f"Operation {self.operation_id} received message without status" + ) + return + + # Track unique statuses + if status not in self._processed_statuses: + self._processed_statuses.add(status) + + # Check for error in status + if status and "error" in str(status).lower(): + error = RunwareOperationError( + f"Model upload failed with status: {status}", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + await self._handle_error(error) + return + + # Complete when status is "ready" + if status == "ready": + logger.info(f"Operation {self.operation_id} model upload ready") + upload_response = self._parse_response(message) + await self._complete_operation([upload_response]) + + except Exception as e: + logger.error( + f"Error handling model upload message for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) diff --git a/runware/operations/photo_maker.py b/runware/operations/photo_maker.py new file mode 100644 index 0000000..743f145 --- /dev/null +++ b/runware/operations/photo_maker.py @@ -0,0 +1,137 @@ +from typing import Any, Dict, List + +from .base import BaseOperation +from ..core.cpu_bound import cpu_executor +from ..logging_config import get_logger +from ..types import ETaskType, IImage, IPhotoMaker +from ..utils import process_image + +logger = get_logger(__name__) + + +class PhotoMakerOperation(BaseOperation): + field_mappings = { + "taskType": "taskType", + "imageUUID": "imageUUID", + "taskUUID": "taskUUID", + "seed": "seed", + "inputImageUUID": "inputImageUUID", + "imageURL": "imageURL", + "imageBase64Data": "imageBase64Data", + "imageDataURI": "imageDataURI", + "NSFWContent": "NSFWContent", + "cost": "cost", + } + response_class = IImage + + def __init__(self, request: IPhotoMaker, client=None): + super().__init__(request.taskUUID, client) + self.request = request + self.expected_results = request.numberResults + self.received_results = 0 + self._processed_images: Dict[str, Any] = {} + + logger.info(f"Photo maker operation {self.operation_id} initialized") + logger.debug( + f"Operation {self.operation_id} expects {self.expected_results} results" + ) + + @property + def operation_type(self) -> str: + return "photoMaker" + + def _setup_message_handlers(self): + self._message_handlers = { + "photoMaker": self._handle_photo_maker, + "error": self._handle_error_message, + } + + async def _process_input_images(self): + if self.request.inputImages: + try: + processed_images = [] + for image in self.request.inputImages: + processed_image = await process_image(image) + processed_images.append(processed_image) + self.request.inputImages = processed_images + except Exception as e: + logger.error( + f"Operation {self.operation_id} failed to process input images", + exc_info=e, + ) + raise + + async def _build_request_payload(self) -> List[Dict[str, Any]]: + await self._process_input_images() + request_object = await cpu_executor.serialize_dataclass( + { + "taskUUID": self.operation_id, + "model": self.request.model, + "positivePrompt": self.request.positivePrompt.strip(), + "numberResults": self.expected_results, + "height": self.request.height, + "width": self.request.width, + "taskType": ETaskType.PHOTO_MAKER.value, + "style": self.request.style, + "strength": self.request.strength, + } + ) + + if self.request.inputImages: + request_object["inputImages"] = self.request.inputImages + + optional_fields = { + "steps": self.request.steps, + "outputFormat": self.request.outputFormat, + "includeCost": self.request.includeCost, + "outputType": self.request.outputType, + } + + for key, value in optional_fields.items(): + if value is not None: + request_object[key] = value + + return [request_object] + + async def _handle_photo_maker(self, message: Dict[str, Any]): + try: + logger.debug( + f"Operation {self.operation_id} handling photo maker message: {message}" + ) + + image_uuid = message.get("imageUUID") + if not image_uuid: + logger.warning( + f"Operation {self.operation_id} photo maker message missing imageUUID: {message}" + ) + return + + if image_uuid in self._processed_images: + logger.debug( + f"Operation {self.operation_id} already processed image {image_uuid}" + ) + return + + logger.info( + f"Operation {self.operation_id} processing new image {image_uuid}" + ) + + image_data = self._parse_response(message) + self._processed_images[image_uuid] = image_data + self.received_results += 1 + + progress = min(self.received_results / self.expected_results, 1.0) + await self._update_progress(progress, partial_results=[image_data]) + + if self.received_results >= self.expected_results: + logger.info( + f"Operation {self.operation_id} received all expected results" + ) + await self._complete_operation() + + except Exception as e: + logger.error( + f"Error handling photo maker message for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) diff --git a/runware/operations/prompt_enhance.py b/runware/operations/prompt_enhance.py new file mode 100644 index 0000000..af28a52 --- /dev/null +++ b/runware/operations/prompt_enhance.py @@ -0,0 +1,81 @@ +from typing import Any, Dict, List + +from .base import BaseOperation +from ..core.cpu_bound import cpu_executor +from ..logging_config import get_logger +from ..types import ETaskType, IEnhancedPrompt, IPromptEnhance + +logger = get_logger(__name__) + + +class PromptEnhanceOperation(BaseOperation): + field_mappings = { + "taskType": "taskType", + "taskUUID": "taskUUID", + "text": "text", + "cost": "cost", + } + response_class = IEnhancedPrompt + + def __init__(self, request: IPromptEnhance, client=None): + super().__init__(operation_id=None, client=client) + self.request = request + self.expected_results = request.promptVersions + self.received_results = 0 + self._processed_prompts = [] + + logger.info(f"Prompt enhance operation {self.operation_id} initialized") + logger.debug( + f"Operation {self.operation_id} expects {self.expected_results} prompt versions" + ) + + @property + def operation_type(self) -> str: + return "promptEnhance" + + def _setup_message_handlers(self): + self._message_handlers = { + "promptEnhance": self._handle_prompt_enhance, + "error": self._handle_error_message, + } + + async def _build_request_payload(self) -> List[Dict[str, Any]]: + task_params = await cpu_executor.serialize_dataclass( + { + "taskType": ETaskType.PROMPT_ENHANCE.value, + "taskUUID": self.operation_id, + "prompt": self.request.prompt, + "promptMaxLength": getattr(self.request, "promptMaxLength", 380), + "promptVersions": self.request.promptVersions, + } + ) + + if self.request.includeCost: + task_params["includeCost"] = self.request.includeCost + + return [task_params] + + async def _handle_prompt_enhance(self, message: Dict[str, Any]): + try: + logger.debug( + f"Operation {self.operation_id} handling prompt enhance message: {message}" + ) + enhanced_prompt = self._parse_response(message) + self._processed_prompts.append(enhanced_prompt) + self.received_results += 1 + + progress = min(self.received_results / self.expected_results, 1.0) + await self._update_progress(progress, partial_results=[enhanced_prompt]) + + if self.received_results >= self.expected_results: + logger.info( + f"Operation {self.operation_id} received all expected prompt versions" + ) + await self._complete_operation() + + except Exception as e: + logger.error( + f"Error handling prompt enhance message for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) diff --git a/runware/operations/video_inference.py b/runware/operations/video_inference.py new file mode 100644 index 0000000..2b3ad47 --- /dev/null +++ b/runware/operations/video_inference.py @@ -0,0 +1,466 @@ +import asyncio +import time +from typing import Any, Dict, List, Optional + +from .base import BaseOperation +from ..core.cpu_bound import cpu_executor +from ..core.types import OperationStatus +from ..exceptions import RunwareOperationError, RunwareTimeoutError +from ..logging_config import get_logger +from ..types import ETaskType, IFrameImage, IVideo, IVideoInference +from ..utils import process_image + +logger = get_logger(__name__) + + +class VideoInferenceOperation(BaseOperation): + field_mappings = { + "taskType": "taskType", + "taskUUID": "taskUUID", + "status": "status", + "videoUUID": "videoUUID", + "videoURL": "videoURL", + "cost": "cost", + "seed": "seed", + } + response_class = IVideo + + def __init__(self, request: IVideoInference, client=None): + super().__init__(request.taskUUID, client) + self.request = request + self.expected_results = request.numberResults or 1 + self.received_results = 0 + self._processed_videos: Dict[str, Any] = {} + self._status_monitoring_task: Optional[asyncio.Task] = None + self._initial_response_received = False + self._initial_response_event = asyncio.Event() + self._processed_final_status = False + self._final_status_lock = asyncio.Lock() + self._last_server_update = None + self._consecutive_failed_requests = 0 + self._max_failed_requests = 10 + + logger.info(f"Video inference operation {self.operation_id} initialized") + logger.debug( + f"Operation {self.operation_id} expects {self.expected_results} results" + ) + + @property + def operation_type(self) -> str: + return "videoInference" + + def _setup_message_handlers(self): + self._message_handlers = { + "videoInference": self._handle_video_inference, + "getResponse": self._handle_status_response, + "error": self._handle_error_message, + } + + async def _process_video_images(self): + logger.debug(f"Operation {self.operation_id} processing video images") + + frame_tasks = [] + reference_tasks = [] + + try: + if self.request.frameImages: + logger.debug( + f"Operation {self.operation_id} processing {len(self.request.frameImages)} frame images" + ) + frame_tasks = [ + process_image(frame_item.inputImage) + for frame_item in self.request.frameImages + if isinstance(frame_item, IFrameImage) + ] + + if self.request.referenceImages: + logger.debug( + f"Operation {self.operation_id} processing {len(self.request.referenceImages)} reference images" + ) + reference_tasks = [ + process_image(reference_item) + for reference_item in self.request.referenceImages + ] + + if frame_tasks: + frame_results = await asyncio.gather(*frame_tasks) + if frame_results: + processed_frame_images = [] + result_index = 0 + for frame_item in self.request.frameImages: + if isinstance(frame_item, IFrameImage): + frame_item.inputImages = frame_results[result_index] + result_index += 1 + processed_frame_images.append(frame_item) + self.request.frameImages = processed_frame_images + logger.debug( + f"Operation {self.operation_id} processed frame images successfully" + ) + + if reference_tasks: + reference_results = await asyncio.gather(*reference_tasks) + if reference_results: + self.request.referenceImages = reference_results + logger.debug( + f"Operation {self.operation_id} processed reference images successfully" + ) + + except Exception as e: + logger.error( + f"Operation {self.operation_id} failed to process video images", + exc_info=e, + ) + raise + + async def _build_request_payload(self) -> List[Dict[str, Any]]: + logger.debug( + f"Operation {self.operation_id} building video inference request payload" + ) + await self._process_video_images() + + request_object = { + "deliveryMethod": self.request.deliveryMethod or "async", + "taskType": ETaskType.VIDEO_INFERENCE.value, + "taskUUID": self.operation_id, + "model": self.request.model, + "positivePrompt": self.request.positivePrompt.strip(), + "numberResults": self.expected_results, + } + + optional_fields = { + "outputType": self.request.outputType, + "outputFormat": self.request.outputFormat, + "outputQuality": self.request.outputQuality, + "uploadEndpoint": self.request.uploadEndpoint, + "includeCost": self.request.includeCost, + "negativePrompt": self.request.negativePrompt, + "fps": self.request.fps, + "steps": self.request.steps, + "seed": self.request.seed, + "CFGScale": self.request.CFGScale, + "duration": self.request.duration, + "width": self.request.width, + "height": self.request.height, + } + + for key, value in optional_fields.items(): + if value is not None: + request_object[key] = value + + # Use ThreadPoolExecutor for serializing frame images + if self.request.frameImages: + frame_images_data = [] + for frame_item in self.request.frameImages: + # Serialize dataclass using ThreadPoolExecutor + serialized_frame = await cpu_executor.serialize_dataclass(frame_item) + frame_images_data.append( + {k: v for k, v in serialized_frame.items() if v is not None} + ) + request_object["frameImages"] = frame_images_data + + if self.request.referenceImages: + request_object["referenceImages"] = self.request.referenceImages + + if self.request.providerSettings: + provider_dict = self.request.providerSettings.to_request_dict() + if provider_dict: + request_object["providerSettings"] = provider_dict + + logger.debug( + f"Operation {self.operation_id} built video inference request payload" + ) + return [request_object] + + async def _handle_video_inference(self, message: Dict[str, Any]): + try: + logger.debug( + f"Operation {self.operation_id} handling video inference message: {message}" + ) + + self._initial_response_received = True + self._initial_response_event.set() + + # Update last server update timestamp + self._last_server_update = time.time() + + status = message.get("status") + logger.info(f"Operation {self.operation_id} received status: {status}") + + # Use lock to prevent duplicate processing of final status + async with self._final_status_lock: + if self._processed_final_status: + logger.debug( + f"Operation {self.operation_id} already processed final status, ignoring" + ) + return + + if status == "success": + logger.info( + f"Operation {self.operation_id} processing success status" + ) + self._processed_final_status = True + await self._stop_status_monitoring() + video_data = self._parse_response(message) + await self._complete_operation([video_data]) + elif status == "failed": + logger.error( + f"Operation {self.operation_id} processing failed status" + ) + self._processed_final_status = True + await self._stop_status_monitoring() + error_message = message.get("error", "Video generation failed") + error = RunwareOperationError( + f"Video generation failed: {error_message}", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + await self._handle_error(error) + else: + # For non-final statuses, start monitoring if not already started + await self._start_status_monitoring() + + except Exception as e: + logger.error( + f"Error handling video inference message for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) + + async def _start_status_monitoring(self): + if self._status_monitoring_task and not self._status_monitoring_task.done(): + logger.debug( + f"Operation {self.operation_id} status monitoring already running" + ) + return + + logger.info(f"Operation {self.operation_id} starting status monitoring") + self._status_monitoring_task = asyncio.create_task(self._monitor_status()) + + async def _stop_status_monitoring(self): + """Explicitly stop status monitoring task""" + if self._status_monitoring_task and not self._status_monitoring_task.done(): + logger.info(f"Operation {self.operation_id} stopping status monitoring") + self._status_monitoring_task.cancel() + try: + await asyncio.wait_for(self._status_monitoring_task, timeout=2.0) + logger.debug( + f"Operation {self.operation_id} status monitoring stopped successfully" + ) + except (asyncio.CancelledError, asyncio.TimeoutError): + logger.debug( + f"Operation {self.operation_id} status monitoring stopped (cancelled/timeout)" + ) + except Exception as e: + logger.warning( + f"Operation {self.operation_id} error stopping status monitoring: {e}" + ) + finally: + self._status_monitoring_task = None + + async def _monitor_status(self): + check_interval = 3.0 # Increased initial interval + max_interval = 15.0 # Increased max interval + max_checks = 600 # Increased max checks for longer operations + check_count = 0 + + # Add timeout tracking for server responsiveness + import time + + no_response_threshold = 120.0 # 2 minutes without server response + + logger.debug(f"Operation {self.operation_id} status monitoring loop started") + + try: + while check_count < max_checks and self.status == OperationStatus.EXECUTING: + check_count += 1 + + logger.debug( + f"Operation {self.operation_id} status check {check_count}/{max_checks}" + ) + + # Check if we haven't received server updates for too long + current_time = time.time() + if ( + self._last_server_update + and current_time - self._last_server_update > no_response_threshold + ): + logger.warning( + f"Operation {self.operation_id} no server updates for {current_time - self._last_server_update:.1f}s" + ) + + status_request = { + "taskType": ETaskType.GET_RESPONSE.value, + "taskUUID": self.operation_id, + } + + if self.client and self.client.connection_manager: + logger.debug( + f"Operation {self.operation_id} sending status request" + ) + try: + await self.client.connection_manager.send_message( + [status_request] + ) + logger.debug( + f"Operation {self.operation_id} status request sent successfully" + ) + except Exception as e: + logger.error( + f"Operation {self.operation_id} failed to send status request: {e}" + ) + self._consecutive_failed_requests += 1 + + if ( + self._consecutive_failed_requests + >= self._max_failed_requests + ): + logger.error( + f"Operation {self.operation_id} too many failed requests, stopping monitoring" + ) + await self._handle_error( + RunwareOperationError( + f"Failed to send status requests {self._consecutive_failed_requests} times", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + ) + break + else: + logger.error( + f"Operation {self.operation_id} no client or connection manager available" + ) + break + + progress = min(check_count / max_checks, 0.9) + await self._update_progress(progress, "Generating video") + + # Check status again before sleeping to catch quick status changes + if self.status != OperationStatus.EXECUTING: + logger.debug( + f"Operation {self.operation_id} status changed to {self.status.value}, stopping monitoring" + ) + break + + try: + logger.debug( + f"Operation {self.operation_id} sleeping for {check_interval}s" + ) + await asyncio.sleep(check_interval) + # Reset failed request counter on successful sleep + self._consecutive_failed_requests = 0 + except asyncio.CancelledError: + logger.debug( + f"Operation {self.operation_id} status monitoring cancelled during sleep" + ) + break + + check_interval = min( + check_interval * 1.05, max_interval + ) # Slower growth + + if self.status == OperationStatus.EXECUTING and check_count >= max_checks: + logger.error( + f"Operation {self.operation_id} timed out after {max_checks} status checks" + ) + await self._handle_error( + RunwareTimeoutError( + f"Video generation timed out after {max_checks} status checks", + timeout_duration=max_checks * max_interval, + ) + ) + + except asyncio.CancelledError: + logger.debug(f"Operation {self.operation_id} status monitoring cancelled") + except Exception as e: + logger.error( + f"Error in status monitoring for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) + finally: + logger.debug( + f"Operation {self.operation_id} status monitoring loop finished" + ) + + async def _handle_status_response(self, message: Dict[str, Any]): + try: + status = message.get("status") + logger.info( + f"Operation {self.operation_id} received status response: {status} (raw message: {message})" + ) + + # Update last server update timestamp + self._last_server_update = time.time() + + # Reset failed request counter on successful response + self._consecutive_failed_requests = 0 + + # Use same lock and logic as video inference handler + async with self._final_status_lock: + if self._processed_final_status: + logger.debug( + f"Operation {self.operation_id} already processed final status, ignoring status response" + ) + return + + if status == "success": + logger.info( + f"Operation {self.operation_id} processing success status from getResponse" + ) + self._processed_final_status = True + await self._stop_status_monitoring() + video_data = self._parse_response(message) + await self._complete_operation([video_data]) + elif status == "failed": + logger.error( + f"Operation {self.operation_id} processing failed status from getResponse" + ) + self._processed_final_status = True + await self._stop_status_monitoring() + error_message = message.get("error", "Video generation failed") + error = RunwareOperationError( + f"Video generation failed: {error_message}", + operation_id=self.operation_id, + operation_type=self.operation_type, + ) + await self._handle_error(error) + elif status == "pending" or status == "processing": + progress_message = message.get("message", f"Status: {status}") + current_progress = min(self.progress + 0.1, 0.9) + await self._update_progress(current_progress, progress_message) + logger.debug( + f"Operation {self.operation_id} continuing with status: {status}" + ) + else: + logger.warning( + f"Operation {self.operation_id} unknown status in response: {status}" + ) + + except Exception as e: + logger.error( + f"Error handling status response for operation {self.operation_id}", + exc_info=e, + ) + await self._handle_error(e) + + async def _handle_error_message(self, message: Dict[str, Any]): + # Set initial response event to unblock waiting + if not self._initial_response_received: + self._initial_response_received = True + self._initial_response_event.set() + + # Stop monitoring on error + await self._stop_status_monitoring() + return super()._handle_error_message(message) + + async def _cleanup(self): + """Clean up video inference specific resources""" + logger.debug(f"Operation {self.operation_id} starting video inference cleanup") + + # Stop status monitoring first + await self._stop_status_monitoring() + + # Then call parent cleanup + await super()._cleanup() + + logger.debug(f"Operation {self.operation_id} video inference cleanup completed") diff --git a/runware/server.py b/runware/server.py deleted file mode 100644 index 79fa065..0000000 --- a/runware/server.py +++ /dev/null @@ -1,298 +0,0 @@ -import asyncio -import json -import logging -import websockets -from websockets.protocol import State -from typing import Any, Dict, Optional - - -from .types import SdkType -from .utils import ( - BASE_RUNWARE_URLS, - PING_INTERVAL, - PING_TIMEOUT_DURATION, - TIMEOUT_DURATION, -) -from .base import RunwareBase -from .types import ( - Environment, - ListenerType, -) - -from .logging_config import configure_logging - - -class RunwareServer(RunwareBase): - def __init__( - self, - api_key: str, - url: str = BASE_RUNWARE_URLS[Environment.PRODUCTION], - log_level=logging.CRITICAL, - timeout: int = TIMEOUT_DURATION, - ): - super().__init__(api_key=api_key, url=url, timeout=timeout) - self._instantiated: bool = False - self._reconnecting_task: Optional[asyncio.Task] = None - self._pingTimeout: Optional[asyncio.Task] = None - self._pongListener: Optional[ListenerType] = None - self._loginListener: Optional[ListenerType] = None - self._sdkType: SdkType = SdkType.SERVER - self._apiKey: str = api_key - self._message_handler_task: Optional[asyncio.Task] = None - self._last_pong_time: float = 0.0 - self._is_shutting_down: bool = False - - # Configure logging - configure_logging(log_level) - self.logger = logging.getLogger(__name__) - self.logger.setLevel(log_level) - - async def connect(self): - self.logger.info("Connecting to Runware server from server") - - try: - self._ws = await websockets.connect(self._url) - # update close_timeout so that we end the script sooner for inference examples - self._ws.close_timeout = 1 - self._ws.max_size = None - self.logger.info(f"Connected to WebSocket URL: {self._url}") - - async def on_open(ws): - def login_check(m): - if ( - m.get("data") - and len(m["data"]) > 0 - and m["data"][0].get("connectionSessionUUID") - ): - return True - if m.get("errors"): - for error in m["errors"]: - if error.get("taskType") == "authentication": - return True - return False - - def login_lis(m): - if m.get("errors"): - for error in m["errors"]: - if error.get("taskType") == "authentication": - err_msg = "Authentication error" - self._invalidAPIkey = error.get("message") or err_msg - self._connection_session_uuid_event.set() - return - if m.get("data") and len(m["data"]) > 0: - self._connectionSessionUUID = m["data"][0].get( - "connectionSessionUUID" - ) - self._invalidAPIkey = None - self._connection_session_uuid_event.set() - - if not self._loginListener: - self._loginListener = self.addListener( - check=login_check, lis=login_lis - ) - - def pong_check(m): - return m.get("data", [])[0].get("pong") if m.get("data") else None - - def pong_lis(m): - if m.get("data", [])[0].get("pong"): - self._last_pong_time = asyncio.get_event_loop().time() - - self._connection_session_uuid_event = asyncio.Event() - - if not self._pongListener: - self._pongListener = self.addListener( - check=pong_check, lis=pong_lis - ) - - if self._reconnecting_task: - self._reconnecting_task.cancel() - - if self._connectionSessionUUID and self.isWebsocketReadyState(): - self.logger.info( - f"Starting new connection with connectionSessionUUID {self._connectionSessionUUID}" - ) - await self.send( - [ - { - "taskType": "authentication", - "apiKey": self._apiKey, - "connectionSessionUUID": self._connectionSessionUUID, - } - ] - ) - elif self.isWebsocketReadyState(): - self.logger.info("Starting new connection with apiKey only") - await self.send( - [ - { - "taskType": "authentication", - "apiKey": self._apiKey, - } - ] - ) - - if self.isWebsocketReadyState(): - self.logger.info("Starting heartbeat task") - self._heartbeat_task = asyncio.create_task( - self.heartBeat(), name="Task_Heartbeat" - ) - - self._message_handler_task = asyncio.create_task( - self._handle_messages(), name="Task_Message_Handler" - ) - await on_open(self._ws) - # Wait for the _connectionSessionUUID to be set - await self._connection_session_uuid_event.wait() - - except websockets.exceptions.ConnectionClosedError: - await self.handleClose() - - async def on_message(self, ws, message): - if not message: - return - - try: - m = json.loads(message) - except json.JSONDecodeError as e: - self.logger.error(f"Failed to parse JSON message:", exc_info=e) - return - - for lis in self._listeners: - try: - result = lis.listener(m) - except Exception as e: - self.logger.error(f"Error in listener {lis.key}:", exc_info=e) - continue - if result: - return - - async def _handle_messages(self): - try: - self.logger.debug( - f"Starting message handler task {self._message_handler_task}" - ) - async for message in self._ws: - if self._is_shutting_down: - break - try: - await self.on_message(self._ws, message) - except Exception as e: - self.logger.error(f"Error in on_message:", exc_info=e) - continue - except websockets.exceptions.ConnectionClosedError as e: - if not self._is_shutting_down: - self.logger.error(f"Connection Closed Error:", exc_info=e) - await self.handleClose() - except Exception as e: - self.logger.error(f"Critical error in _handle_messages:", exc_info=e) - if not self._is_shutting_down: - await self.handleClose() - - async def send(self, msg: Dict[str, Any]): - self.logger.debug(f"Sending message: {msg}") - if self._ws and self._ws.state is State.OPEN and not self._is_shutting_down: - await self._ws.send(json.dumps(msg)) - - def _get_task_by_name(self, name): - tasks = asyncio.all_tasks() - for task in tasks: - if task.get_name() == name: - return task - return None - - async def handleClose(self): - self.logger.debug("Handling close") - - if self._invalidAPIkey: - self.logger.error(f"Error: {self._invalidAPIkey}") - return - - reconnecting_task = self._get_task_by_name("Task_Reconnecting") - if reconnecting_task is not None: - if not reconnecting_task.done() and not reconnecting_task.cancelled(): - self.logger.debug(f"Cancelling Task_Reconnecting {reconnecting_task}") - try: - reconnecting_task.cancel() - except Exception as e: - self.logger.error(f"Error while cancelling Task_Reconnecting:", exc_info=e) - - message_handler_task = self._get_task_by_name("Task_Message_Handler") - if message_handler_task is not None: - if not message_handler_task.done() and not message_handler_task.cancelled(): - self.logger.debug( - f"Cancelling Task_Message_Handler {message_handler_task}" - ) - try: - message_handler_task.cancel() - except Exception as e: - self.logger.error( - f"Error while cancelling Task_Message_Handler:", exc_info=e - ) - - heartbeat_task = self._get_task_by_name("Task_Heartbeat") - if heartbeat_task is not None: - if not heartbeat_task.done() and not heartbeat_task.cancelled(): - self.logger.debug(f"Cancelling Task_Heartbeat {heartbeat_task}") - try: - heartbeat_task.cancel() - except Exception as e: - self.logger.error(f"Error while cancelling Task_Heartbeat:", exc_info=e) - - async def reconnect(): - reconnect_attempts = 0 - max_reconnect_attempts = 5 - - while reconnect_attempts < max_reconnect_attempts and not self._is_shutting_down: - self.logger.info(f"Reconnecting... (attempt {reconnect_attempts + 1})") - await asyncio.sleep(min(reconnect_attempts * 2 + 1, 10)) - try: - await self.connect() - if self.isWebsocketReadyState(): - self.logger.info("Reconnected successfully") - break # Break out of the loop if the connection is successful and in a ready state - else: - self.logger.warning( - "WebSocket connection is not in a ready state after reconnecting" - ) - except Exception as e: - self.logger.error(f"Error while reconnecting:", exc_info=e) - - reconnect_attempts += 1 - - if reconnect_attempts >= max_reconnect_attempts: - self.logger.error("Max reconnection attempts reached. Giving up.") - self._is_shutting_down = True - - # Attempting to reconnect... - if not self._is_shutting_down: - self._reconnecting_task = asyncio.create_task( - reconnect(), name="Task_Reconnecting" - ) - - async def heartBeat(self): - while not self._is_shutting_down: - if self.isWebsocketReadyState(): - self.logger.debug("Sending ping") - try: - await self.send([{"taskType": "ping", "ping": True}]) - except websockets.exceptions.ConnectionClosedError as e: - self.logger.error( - f"Error sending ping. Connection likely closed.", exc_info=e - ) - break - except Exception as e: - self.logger.error(f"Unexpected error sending ping", exc_info=e) - break - - await asyncio.sleep(PING_INTERVAL / 1000) - - if ( - asyncio.get_event_loop().time() - self._last_pong_time - > PING_TIMEOUT_DURATION / 1000 - ): - self.logger.warning("No pong received. Connection may be lost.") - await self.handleClose() - break - else: - break diff --git a/runware/types.py b/runware/types.py index 51dcc21..764a2fa 100644 --- a/runware/types.py +++ b/runware/types.py @@ -1,7 +1,7 @@ -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field from enum import Enum -from dataclasses import dataclass, field, asdict -from typing import List, Union, Optional, Callable, Any, Dict, TypeVar, Literal +from typing import Any, Callable, Dict, List, Literal, Optional, Union class Environment(Enum): @@ -408,7 +408,7 @@ class IImageInference: checkNsfw: Optional[bool] = None negativePrompt: Optional[str] = None seedImage: Optional[Union[File, str]] = None - referenceImages: Optional[Union[File, str]] = None + referenceImages: Optional[Union[File, str, list]] = None maskImage: Optional[Union[File, str]] = None strength: Optional[float] = None height: Optional[int] = None @@ -435,7 +435,6 @@ class IImageInference: outpaint: Optional[IOutpaint] = None instantID: Optional[IInstantID] = None ipAdapters: Optional[List[IIpAdapter]] = field(default_factory=list) - referenceImages: Optional[List[Union[str, File]]] = field(default_factory=list) acePlusPlus: Optional[IAcePlusPlus] = None extraArgs: Optional[Dict[str, Any]] = field(default_factory=dict) @@ -591,8 +590,11 @@ class IFrameImage: class SerializableMixin: def serialize(self) -> Dict[str, Any]: - return {k: v for k, v in asdict(self).items() - if v is not None and not k.startswith('_')} + return { + k: v + for k, v in asdict(self).items() + if v is not None and not k.startswith("_") + } @dataclass @@ -686,7 +688,6 @@ class IPixverseProviderSettings(BaseProviderSettings): cameraMovement: Optional[str] = None style: Optional[str] = None motionMode: Optional[str] = None - watermark: Optional[bool] = None soundEffectSwitch: Optional[bool] = None soundEffectContent: Optional[str] = None @@ -706,7 +707,15 @@ def provider_key(self) -> str: return "vidu" -VideoProviderSettings = IKlingAIProviderSettings | IGoogleProviderSettings | IMinimaxProviderSettings | IBytedanceProviderSettings | IPixverseProviderSettings | IViduProviderSettings +VideoProviderSettings = ( + IKlingAIProviderSettings + | IGoogleProviderSettings + | IMinimaxProviderSettings + | IBytedanceProviderSettings + | IPixverseProviderSettings + | IViduProviderSettings +) + @dataclass class IVideoInference: @@ -732,6 +741,7 @@ class IVideoInference: numberResults: Optional[int] = 1 providerSettings: Optional[VideoProviderSettings] = None + @dataclass class IVideo: taskType: str @@ -741,136 +751,3 @@ class IVideo: videoURL: Optional[str] = None cost: Optional[float] = None seed: Optional[int] = None - - -# The GetWithPromiseCallBackType is defined using the Callable type from the typing module. It represents a function that takes a dictionary -# with specific keys and returns either a boolean or None. -# The dictionary should have the following keys: -# resolve: A function that takes a value of any type and returns None. -# reject: A function that takes a value of any type and returns None. -# intervalId: A value of any type representing the interval ID. -# You can use these types in your Python code to define variables, parameters, or return types that match the corresponding TypeScript types. -# -# def on_message(event: Any): -# # Handle WebSocket message event -# pass -# -# websocket = ReconnectingWebsocketProps(websocket_object) -# websocket.add_event_listener("message", on_message, {}) -# -# uploaded_image = UploadImageType("abc123", "image.png", "task123") -# -# def get_with_promise(callback_data: Dict[str, Union[Callable[[Any], None], Any]]) -> Union[bool, None]: -# # Implement the callback function logic here -# pass - - -GetWithPromiseCallBackType = Callable[ - [Dict[str, Union[Callable[[Any], None], Any]]], Union[bool, None] -] - - -# The ListenerType class is defined to represent the structure of a listener. -# The key parameter is a string that represents a unique identifier for the listener. -# The listener parameter is a callable function that takes a single argument msg of type Any and returns None. -# It represents the function to be called when the corresponding event occurs. -# The group_key parameter is an optional string that represents a group identifier for the listener. It allows grouping listeners together based on a common key. -# You can create instances of ListenerType by providing the required parameters: -# -# def on_message(msg: Any): -# # Handle the message -# print(msg) -# -# listener = ListenerType("message_listener", on_message, group_key="message_group") - -# In this example, we define a function on_message that takes a single argument msg and handles the received message. -# We then create an instance of ListenerType called listener by providing the key "message_listener", -# the on_message function as the listener, and an optional group key "message_group". -# You can store instances of ListenerType in a list or dictionary to manage multiple listeners in your application. - -# listeners = [ -# ListenerType("listener1", on_message1), -# ListenerType("listener2", on_message2, group_key="group1"), -# ListenerType("listener3", on_message3, group_key="group1"), -# ] - - -class ListenerType: - def __init__( - self, - key: str, - listener: Callable[[Any], None], - group_key: Optional[str] = None, - debug_message: Optional[str] = None, - ): - """ - Initialize a new ListenerType instance. - - :param key: str, a unique identifier for the listener. - :param listener: Callable[[Any], None], the function to be called when the listener is triggered. - :param group_key: Optional[str], an optional grouping key that can be used to categorize listeners. - """ - self.key = key - self.listener = listener - self.group_key = group_key - self.debug_message = debug_message - - def __str__(self): - return f"ListenerType(key={self.key}, listener={self.listener}, group_key={self.group_key}, debug_message={self.debug_message})" - - def __repr__(self): - return self.__str__() - - -T = TypeVar("T") -Keys = TypeVar("Keys") - - -class RequireAtLeastOne: - def __init__(self, data: Dict[str, Any], required_keys: Union[str, Keys]): - if not isinstance(data, dict): - raise TypeError("data must be a dictionary") - - self.data = data - self.required_keys = required_keys - - if not isinstance(required_keys, (list, tuple)): - required_keys = [required_keys] - - missing_keys = [key for key in required_keys if key not in data] - if len(missing_keys) == len(required_keys): - raise ValueError( - f"At least one of the required keys must be present: {', '.join(required_keys)}" - ) - - def __getitem__(self, key: str): - return self.data[key] - - def __setitem__(self, key: str, value: Any): - self.data[key] = value - - def __delitem__(self, key: str): - del self.data[key] - - def __contains__(self, key: str): - return key in self.data - - def __len__(self): - return len(self.data) - - def __iter__(self): - return iter(self.data) - - -class RequireOnlyOne(RequireAtLeastOne): - def __init__(self, data: Dict[str, Any], required_keys: Union[str, Keys]): - super().__init__(data, required_keys) - - if not isinstance(required_keys, (list, tuple)): - required_keys = [required_keys] - - provided_keys = [key for key in required_keys if key in data] - if len(provided_keys) > 1: - raise ValueError( - f"Only one key can be provided: {', '.join(provided_keys)}" - ) diff --git a/runware/utils.py b/runware/utils.py index 65796fb..fc37c03 100644 --- a/runware/utils.py +++ b/runware/utils.py @@ -1,38 +1,18 @@ -import asyncio import base64 +import mimetypes import os import re +import uuid +from dataclasses import fields +from typing import Any, Optional, Type, Union from urllib.parse import urlparse import aiofiles -import datetime -import uuid -import json -import mimetypes -import inspect -from typing import Any, Dict, List, Union, Optional, TypeVar, Type, Coroutine -from dataclasses import fields -from .types import ( - Environment, - EPreProcessor, - EPreProcessorGroup, - ListenerType, - IControlNet, - File, - GetWithPromiseCallBackType, - IImage, - ETaskType, - IImageToText, - IEnhancedPrompt, - IError, - UploadImageType, -) -import logging -logger = logging.getLogger(__name__) +from .logging_config import get_logger +from .types import Environment, File, UploadImageType -if not mimetypes.guess_type("test.webp")[0]: - mimetypes.add_type("image/webp", ".webp") +logger = get_logger(__name__) BASE_RUNWARE_URLS = { Environment.PRODUCTION: "wss://ws-api.runware.ai/v1", @@ -44,7 +24,6 @@ "REQUEST_IMAGES": 2, } - PING_TIMEOUT_DURATION = 10000 # 10 seconds PING_INTERVAL = 5000 # 5 seconds @@ -56,147 +35,6 @@ class LISTEN_TO_IMAGES_KEY: REQUEST_IMAGES = "REQUEST_IMAGES" -class RunwareAPIError(Exception): - def __init__(self, error_data: dict): - self.code = error_data.get("code") - self.error_data = error_data - super().__init__(str(error_data)) - - def __str__(self): - return f"RunwareAPIError: {self.error_data}" - - -class RunwareError(Exception): - def __init__(self, ierror: IError): - self.ierror = ierror - super().__init__(f"Runware Error: {ierror.error_message}") - - def format_error(self): - return { - "errors": [ - { - "code": self.ierror.error_code, - "message": self.ierror.error_message, - "parameter": self.ierror.parameter, - "type": self.ierror.error_type, - "documentation": self.ierror.documentation, - "taskUUID": self.ierror.task_uuid, - } - ] - } - - def __str__(self): - return f"Runware Error: {self.format_error()}" - - -class Blob: - """ - A Python equivalent of the Blob class to simulate file-like behavior in an immutable manner. - - This class is designed to closely align with the TypeScript implementation of the Blob class. - It provides a way to represent and manipulate immutable binary data, similar to how files are handled. - - :param blob_parts: List[bytes], content parts of the blob, typically bytes. - :param options: Dict[str, Any], containing options such as type (MIME type). - - Example: - content = b"Hello, world!" - blob = Blob([content], {"type": "text/plain"}) - print(len(blob)) # Output: 13 - print(str(blob)) # Output: "Hello, world!" - print(blob.size()) # Output: 13 - """ - - def __init__( - self, - blob_parts: Optional[List[bytes]] = None, - options: Optional[Dict[str, Any]] = None, - ): - """ - Initialize the Blob object. - :param blob_parts: List[bytes], content parts of the blob, typically bytes. - :param options: Dict[str, Any], containing options such as type (MIME type). - """ - self._content = b"".join(blob_parts) if blob_parts else b"" - self.type = options.get("type", "") if options else "" - - def __len__(self) -> int: - return len(self._content) - - def __str__(self) -> str: - return self._content.decode("utf-8") - - def size(self) -> int: - return len(self) - - -class MockFile: - """ - A class that provides a method to create mock file objects for testing purposes. - - The `create` method generates a Blob object that simulates a file with specified attributes - such as name, size, and MIME type. This is useful when you need to work with file-like objects - in tests or when actual files are not available. - - Example: - mock_file = MockFile() - file_obj = mock_file.create(name="example.txt", size=2048, mime_type="text/plain") - print(file_obj.name) # Output: "example.txt" - print(file_obj.size()) # Output: 2048 - print(file_obj.type) # Output: "text/plain" - print(file_obj.lastModifiedDate) # Output: current datetime - """ - - def create( - self, name: str = "mock.txt", size: int = 1024, mime_type: str = "plain/txt" - ) -> Blob: - """ - Create a mock file object with specified attributes. - - This method generates a Blob object that simulates a file. The content of the file is - created as a sequence of 'a' characters repeated 'size' times. The Blob object is then - enhanced with additional attributes to mimic a real file, such as name and lastModifiedDate. - - :param name: str, the name of the file (default: "mock.txt") - :param size: int, the size of the file in bytes (default: 1024) - :param mime_type: str, the MIME type of the file (default: "plain/txt") - :return: Blob, a Blob object simulating a file with the specified attributes - """ - content = ("a" * size).encode() # Create content as bytes - blob = Blob(blob_parts=[content], options={"type": mime_type}) - - # Simulate File attributes by adding properties to the Blob object - setattr(blob, "name", name) - setattr(blob, "lastModifiedDate", datetime.datetime.now()) - return blob - - -T = TypeVar("T") - - -def removeFromAray(col: Optional[List[T]], targetElem: T) -> None: - """ - Remove the first occurrence of an element from an array. - - :param col: Optional[List[T]], the collection from which to remove the element. None is safely handled. - :param target_elem: T, the element to remove from the collection. - """ - if col is None: - return - try: - col.remove(targetElem) - except ValueError: - # If targetElem is not in col, do nothing - return - - -def getUUID() -> str: - """ - Generate and return a new UUID as a string. - """ - return str(uuid.uuid4()) - - def isValidUUID(uuid_str: str) -> bool: """ Check if a given string is a valid UUID. @@ -211,57 +49,6 @@ def isValidUUID(uuid_str: str) -> bool: return False -def evaluateToBoolean(*args: Any) -> bool: - """ - Evaluate to boolean by checking if all arguments are truthy. - - :param args: Variable length argument list of any type. - :return: Returns True if all arguments are truthy, otherwise False. - """ - return all(args) - - -def compact(key: Any, data: Dict[str, Any]) -> Dict[str, Any]: - """ - Returns a dictionary containing the data if the key is truthy, otherwise returns an empty dictionary. - - :param key: Any, the key to check for truthiness. - :param data: Dict[str, Any], the dictionary to return if the key is truthy. - :return: A dictionary containing the data if key is truthy; otherwise, an empty dictionary. - - Example: - lowThresholdCanny = 5 # or None if it should be omitted - highThresholdCanny = 10 # or None if it should be omitted - send_data = { - "newPreProcessControlNet": { - "taskUUID": "some-uuid", - "preProcessorType": "some-type", - "guideImageUUID": "another-uuid", - "includeHandsAndFaceOpenPose": True, - **compact(lowThresholdCanny, {"lowThresholdCanny": lowThresholdCanny}), - **compact(highThresholdCanny, {"highThresholdCanny": highThresholdCanny}), - }, - } - """ - return data if key else {} - - -# originally range() in Typescipt library, renamed nu avoid conflict with Python's built-in range function -# TODO: only used in tests/test_utils.py, consider removing -def generateString(count: int) -> str: - return "a" * count - - -# TODO: function it's not used in the code anywhere else, consider removing -def remove1Mutate(col: List[Any], targetElem: Any) -> None: - if col is None: - return - try: - col.remove(targetElem) - except ValueError: - return - - async def fileToBase64(file_path: str) -> str: """ Asynchronously convert a file at a given path to a Base64-encoded string. @@ -270,480 +57,32 @@ async def fileToBase64(file_path: str) -> str: :return: str, Base64-encoded content of the file. :raises FileNotFoundError: if the file does not exist. :raises IOError: if the file cannot be read. - - Example: - try: - if isinstance(file, str) and file.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')): - image_base64 = await fileToBase64(file) - else: - # Otherwise, use the string directly as it might be a Base64 string - image_base64 = file - - await send({ - "newImageUpload": { - "imageBase64": image_base64, - "taskUUID": task_uuid, - "taskType": task_type, - } - }) - except FileNotFoundError as e: - print(f"Error: {e}") - except IOError as e: - print(f"Error: {e}") - """ try: + logger.debug(f"Converting file to base64: {file_path}") async with aiofiles.open(file_path, "rb") as file: file_contents = await file.read() mime_type, _ = mimetypes.guess_type(file_path) if mime_type is None: + logger.warning(f"Unable to determine MIME type for file: {file_path}") raise ValueError( f"Unable to determine the MIME type for file: {file_path}" ) base64_content = base64.b64encode(file_contents).decode("utf-8") + logger.debug( + f"Successfully converted file to base64, size: {len(base64_content)} chars" + ) return f"data:{mime_type};base64,{base64_content}" except FileNotFoundError: + logger.error(f"File not found: {file_path}") raise FileNotFoundError(f"The file at {file_path} does not exist.") - except IOError: + except IOError as e: + logger.error(f"Failed to read file: {file_path}", exc_info=e) raise IOError(f"The file at {file_path} could not be read.") -def removeListener( - listeners: List[ListenerType], listener: ListenerType -) -> List[ListenerType]: - """ - Remove a specified listener from a list of listeners. - - This function filters out the listener whose `key` attribute matches that of the given `listener` object. - It returns a new list without altering the original list of listeners. - - :param listeners: List[ListenerType], the list from which to remove the listener. - :param listener: ListenerType, the listener to be removed based on its 'key'. - :return: List[ListenerType], a new list with the specified listener removed. - """ - - return [lis for lis in listeners if lis.key != listener.key] - - -def removeAllKeyListener(listeners: List[ListenerType], key: str) -> List[ListenerType]: - """ - Remove all listeners from the list that have a specific key. - - This function filters out all listeners whose `key` attribute matches the specified `key`. - It creates a new list of listeners without those that have the matching key, without altering the original list. - - :param listeners: List[ListenerType], the list from which to remove listeners. - :param key: str, the key associated with listeners to be removed. - :return: List[ListenerType], a new list with all matching key listeners removed. - """ - return [lis for lis in listeners if lis.key != key] - - -async def delay(time: float, milliseconds: int = 1000) -> None: - """ - Asynchronously wait for a specified amount of time. - - :param time: float, the number of time units to wait. - :param milliseconds: int, number of milliseconds each time unit represents. - """ - await asyncio.sleep(time * milliseconds / 1000) - - -def getTaskType( - prompt: str, - controlNet: Optional[List[IControlNet]], - imageMaskInitiator: Optional[Union[File, str]], - imageInitiator: Optional[Union[File, str]], -) -> int: - """ - Determine the task type based on the presence or absence of various parameters. - - :param prompt: str, the prompt text. - :param control_net: Optional[List[IControlNet]], a list of settings for controlling the network, which can be None. - :param image_initiator: Optional[Union[File, str]], a File object or a string path indicating the image initiator. - :param image_mask_initiator: Optional[Union[File, str]], a File object or a string path indicating the image mask initiator. - :return: int, the task type determined by the conditions. - """ - if evaluateToBoolean( - prompt, not controlNet, not imageMaskInitiator, not imageInitiator - ): - return 1 - if evaluateToBoolean( - prompt, not controlNet, not imageMaskInitiator, imageInitiator - ): - return 2 - if evaluateToBoolean(prompt, not controlNet, imageMaskInitiator, imageInitiator): - return 3 - if evaluateToBoolean( - prompt, controlNet, not imageMaskInitiator, not imageInitiator - ): - return 9 - if evaluateToBoolean(prompt, controlNet, not imageMaskInitiator, imageInitiator): - return 10 - if evaluateToBoolean(prompt, controlNet, imageMaskInitiator, imageInitiator): - return 10 - # TODO: Better handling of invalid task types, e.g. raise an exception - return -1 - - -def getPreprocessorType(processor: EPreProcessor) -> EPreProcessorGroup: - """ - Determine the preprocessor group based on the given preprocessor. - - :param processor: EPreProcessor, the preprocessor for which to determine the group. - :return: EPreProcessorGroup, the corresponding preprocessor group for the given preprocessor. - - This function maps an EPreProcessor enum value to its corresponding EPreProcessorGroup enum value. - It helps in identifying the group or category to which a specific preprocessor belongs. - - Example: - preprocessor = EPreProcessor.canny - group = getPreprocessorType(preprocessor) - print(group) # Output: EPreProcessorGroup.canny - """ - if processor == EPreProcessor.canny: - return EPreProcessorGroup.canny - elif processor in [ - EPreProcessor.depth_leres, - EPreProcessor.depth_midas, - EPreProcessor.depth_zoe, - ]: - return EPreProcessorGroup.depth - elif processor == EPreProcessor.inpaint_global_harmonious: - return EPreProcessorGroup.depth - elif processor == EPreProcessor.lineart_anime: - return EPreProcessorGroup.lineart_anime - elif processor in [ - EPreProcessor.lineart_coarse, - EPreProcessor.lineart_realistic, - EPreProcessor.lineart_standard, - ]: - return EPreProcessorGroup.lineart - elif processor == EPreProcessor.mlsd: - return EPreProcessorGroup.mlsd - elif processor == EPreProcessor.normal_bae: - return EPreProcessorGroup.normalbae - elif processor in [ - EPreProcessor.openpose_face, - EPreProcessor.openpose_faceonly, - EPreProcessor.openpose_full, - EPreProcessor.openpose_hand, - EPreProcessor.openpose, - ]: - return EPreProcessorGroup.openpose - elif processor in [EPreProcessor.scribble_hed, EPreProcessor.scribble_pidinet]: - return EPreProcessorGroup.scribble - elif processor in [ - EPreProcessor.seg_ofade20k, - EPreProcessor.seg_ofcoco, - EPreProcessor.seg_ufade20k, - ]: - return EPreProcessorGroup.seg - elif processor == EPreProcessor.shuffle: - return EPreProcessorGroup.shuffle - elif processor in [ - EPreProcessor.softedge_hed, - EPreProcessor.softedge_hedsafe, - EPreProcessor.softedge_pidinet, - EPreProcessor.softedge_pidisafe, - ]: - return EPreProcessorGroup.softedge - elif processor == EPreProcessor.tile_gaussian: - return EPreProcessorGroup.tile - else: - return EPreProcessorGroup.canny - - -def accessDeepObject( - key: str, - data: Dict[str, Any], - useZero: bool = True, - shouldReturnString: bool = False, -) -> Any: - """ - Navigate deeply nested data structures based on a dot/bracket-notated key. - - :param key: str, the key path (e.g., "person.address[0].street"). - :param data: dict, the data to navigate. - :param useZero: bool, return 0 instead of None for non-existent values. - :param shouldReturnString: bool, return a JSON string of the object if True. - :return: The value found at the key path or a default value. - """ - - # Get the current frame - current_frame = inspect.currentframe() - - # Get the caller's frame - caller_frame = current_frame.f_back - - # Get the caller's function name - caller_name = caller_frame.f_code.co_name - - # Get the caller's line number - caller_line_number = caller_frame.f_lineno - - logger.debug( - f"Function {accessDeepObject.__name__} called by {caller_name} at line {caller_line_number}" - ) - logger.debug(f"accessDeepObject key: {key}") - logger.debug(f"accessDeepObject data: {data}") - - default_value = 0 if useZero else None - - # if "data" in data and isinstance(data["data"], list): - # # Iterate through each item in the data list - # for item in data["data"]: - # # Check if 'taskType' is in the item and matches the target_task_type - # if "taskType" in item and item["taskType"] == key: - # # Return the entire item if there's a match - # matching_tasks.append(item) - matching_tasks = [] - - for field in ["data", "errors"]: - if field in data and isinstance(data[field], list): - for item in data[field]: - if "taskUUID" in item and item["taskUUID"] == key: - matching_tasks.append(item) - - # Check for successful messages - if "data" in data and isinstance(data["data"], list): - for item in data["data"]: - if "taskUUID" in item and item["taskUUID"] == key: - matching_tasks.append(item) - - # Check for error messages - if "errors" in data and isinstance(data["errors"], list): - for error in data["errors"]: - if "taskUUID" in error and error["taskUUID"] == key: - matching_tasks.append(error) - - if len(matching_tasks) == 0: - return default_value - - logger.debug(f"accessDeepObject matching_tasks: {matching_tasks}") - - if shouldReturnString and isinstance(matching_tasks, (dict, list)): - return json.dumps(matching_tasks) - - return matching_tasks - - # keys = re.split(r"\.|\[", key) - # keys = [k.replace("]", "") for k in keys] - - # logger.debug(f"accessDeepObject keys: {keys}") - - # current_value = data - # for k in keys: - # logger.debug(f"accessDeepObject key: {k}") - # # logger.debug( - # # "isinstance(current_value, (dict, list))", - # # str(isinstance(current_value, (dict, list))), - # # ) - # if isinstance(current_value, (dict, list)): - # logger.debug(f"accessDeepObject current_value: {current_value}") - # logger.debug(f"k in current_value: {k in current_value}") - # if k.isdigit() and isinstance(current_value, list): - # index = int(k) - # if 0 <= index < len(current_value): - # current_value = current_value[index] - # else: - # return default_value - # elif k in current_value: - # current_value = current_value[k] - # else: - # return default_value - # else: - # return default_value - - # logger.debug(f"accessDeepObject current_value: {current_value}") - - # if shouldReturnString and isinstance(current_value, (dict, list)): - # return json.dumps(current_value) - - # return current_value - - -def createEnhancedPromptsFromResponse(response: List[dict]) -> List[IEnhancedPrompt]: - def process_single_prompt(prompt_data: dict) -> IEnhancedPrompt: - processed_fields = {} - - for field in fields(IEnhancedPrompt): - if field.name in prompt_data: - if field.name == "taskType": - processed_fields[field.name] = ETaskType(prompt_data[field.name]) - elif field.type == float or field.type == Optional[float]: - processed_fields[field.name] = float(prompt_data[field.name]) - else: - processed_fields[field.name] = prompt_data[field.name] - - return instantiateDataclass(IEnhancedPrompt, processed_fields) - - return [process_single_prompt(prompt) for prompt in response] - - -def createImageFromResponse(response: dict) -> IImage: - processed_fields = {} - - for field in fields(IImage): - if field.name in response: - if field.type == bool or field.type == Optional[bool]: - processed_fields[field.name] = bool(response[field.name]) - elif field.type == float or field.type == Optional[float]: - processed_fields[field.name] = float(response[field.name]) - else: - processed_fields[field.name] = response[field.name] - - return instantiateDataclass(IImage, processed_fields) - - -def createImageToTextFromResponse(response: dict) -> IImageToText: - processed_fields = {} - - for field in fields(IImageToText): - if field.name in response: - if field.name == "taskType": - # Convert string to ETaskType enum - processed_fields[field.name] = ETaskType(response[field.name]) - elif field.type == float or field.type == Optional[float]: - processed_fields[field.name] = float(response[field.name]) - else: - processed_fields[field.name] = response[field.name] - - return instantiateDataclass(IImageToText, processed_fields) - - -async def getIntervalWithPromise( - callback: GetWithPromiseCallBackType, - debugKey: str = "debugKey", - timeOutDuration: int = TIMEOUT_DURATION, # in milliseconds - shouldThrowError: bool = True, - pollingInterval: int = 100, # in milliseconds -) -> Any: - """ - Set up an interval to repeatedly call a callback function until a condition is met or a timeout occurs. - - :param callback: A function that is called repeatedly within the interval. It receives an object with - `resolve`, `reject`, and `intervalId` properties, allowing the callback to control - the promise's resolution or rejection and clear the interval if needed. - The callback should return a boolean value indicating whether to clear the interval. - :param debugKey: A string used for debugging purposes. Default is "debugKey". - :param timeOutDuration: The duration in milliseconds after which the promise will be rejected if the - callback hasn't resolved or rejected it. Default is TIMEOUT_DURATION. - :param shouldThrowError: A boolean indicating whether to reject the promise with an error message if - the timeout is reached. Default is True. - :param pollingInterval: The interval in milliseconds at which the callback is invoked. Default is 100. - :return: The result of the callback function or the rejection reason if the timeout is reached. - - Example: - async def upload_image(task_uuid): - image = await getIntervalWithPromise( - lambda params: params["resolve"]("uploadedImage") if "uploadedImage" in globals() else None, - debugKey="upload-image", - pollingInterval=200, - ) - return image - - uploaded_image = await upload_image("task123") - print(uploaded_image) # Output: "uploadedImage" - """ - logger = logging.getLogger(__name__) - - loop = asyncio.get_running_loop() - future = loop.create_future() - intervalHandle = None - - async def check_callback(): - nonlocal intervalHandle, future - try: - if not future.done(): - # logger.debug(f"Checking callback for {debugKey}") - # logger.debug(f"Future done: {future.done()}") - # logger.debug(f"Future callback: {callback}") - # logger.debug(f"Future result: {future.result}") - - # logger.debug(f"Future exception: {future.exception}") - # logger.debug(f"callback: {callback}") - - result = callback( - future.set_result, future.set_exception, intervalHandle - ) - - if result: - if intervalHandle: - intervalHandle.cancel() - logger.debug(f"Interval cleared for {debugKey}") - else: - # TODO: Find a better way than polling, as it's not very efficient. - # Consider using asyncio.Event or asyncio.Condition triggered by an incoming message - # as the state won't change unless I have a new message from the service - intervalHandle = loop.call_later( - pollingInterval / 1000, - lambda: ( - # logger.debug("Creating task for check_callback"), - asyncio.create_task(check_callback()), - )[-1], - ) - else: - logger.debug( - f"Future already done for {debugKey}, interval not rescheduled" - ) - except Exception as e: - future.set_exception(e) - logger.exception(f"Error in check_callback 2 for {debugKey}: {str(e)}") - - await check_callback() - # intervalHandle = loop.call_later( - # pollingInterval / 1000, - # lambda: ( - # logger.debug("Creating task for check_callback"), - # asyncio.create_task(check_callback()), - # )[ - # -1 - # ], # Return the task itself) - # ) - - async def timeout_handler(): - nonlocal future, intervalHandle - try: - if not future.done(): - if shouldThrowError: - future.set_exception( - Exception(f"Message could not be received for {debugKey}") - ) - logger.error(f"Error: Message could not be received for {debugKey}") - else: - future.set_result(None) - if intervalHandle: - intervalHandle.cancel() - logger.debug(f"Interval cleared due to timeout for {debugKey}") - except Exception as e: - future.set_exception(e) - logger.exception(f"Error in timeout_handler for {debugKey}: {str(e)}") - - # Schedule the timeout handler - timeoutHandle = loop.call_later( - timeOutDuration / 1000, - lambda: ( - logger.debug("Creating task for timeout_handler"), - asyncio.create_task(timeout_handler()), - )[-1], - ) - - try: - await future - finally: - if intervalHandle and not intervalHandle.cancelled(): - intervalHandle.cancel() - logger.debug(f"Interval canceled for {debugKey}") - if timeoutHandle and not timeoutHandle.cancelled(): - timeoutHandle.cancel() - logger.debug(f"Timeout canceled for {debugKey}") - - return await future - - def instantiateDataclass(dataclass_type: Type[Any], data: dict) -> Any: """ Instantiates a dataclass object from a dictionary, filtering out any unknown attributes. @@ -759,35 +98,22 @@ def instantiateDataclass(dataclass_type: Type[Any], data: dict) -> Any: return dataclass_type(**filtered_data) -def instantiateDataclassList( - dataclass_type: Type[Any], data_list: List[dict] -) -> List[Any]: - """ - Instantiates a list of dataclass objects from a list of dictionaries, - filtering out any unknown attributes. - - :param dataclass_type: The dataclass type to instantiate. - :param data_list: A list of dictionaries with data. - :return: A list of instantiated dataclass objects. - """ - # Get the set of valid field names for the dataclass - instances = [] - for data in data_list: - instances.append(instantiateDataclass(dataclass_type, data)) - return instances - - def isLocalFile(file): + logger.debug(f"Checking if file is local: {file}") + if os.path.isfile(file): + logger.debug(f"File exists locally: {file}") return True # Check if the string is a valid UUID if isValidUUID(file): + logger.debug(f"File is a valid UUID: {file}") return False # Check if the string is a valid URL parsed_url = urlparse(file) if parsed_url.scheme and parsed_url.netloc: + logger.debug(f"File is a valid URL: {file}") return False # Use the URL as is else: # Handle case with no scheme and no netloc @@ -799,6 +125,7 @@ def isLocalFile(file): # Check if it's a base64 string (with or without data URI prefix) if file.startswith("data:") or re.match(r"^[A-Za-z0-9+/]+={0,2}$", file): # Assume it's a base64 string (with or without data URI prefix) + logger.debug(f"File is a base64 string: {file}") return False # Assume it's a URL without scheme (e.g., 'example.com/some/path') @@ -806,10 +133,13 @@ def isLocalFile(file): file = f"https://{file}" parsed_url = urlparse(file) if parsed_url.netloc: # Now it should have a valid netloc + logger.debug(f"File is a URL without scheme: {file}") return False else: + logger.error(f"File or URL '{file}' not found or invalid") raise FileNotFoundError(f"File or URL '{file}' not found.") + logger.error(f"File or URL '{file}' not valid or not found") raise FileNotFoundError(f"File or URL '{file}' not valid or not found.") @@ -819,12 +149,18 @@ async def process_image( if image is None: return None elif isinstance(image, list): + logger.debug(f"Processing {len(image)} images in list") images = [] - for image in image: - images.append(await process_image(image)) + for img in image: + images.append(await process_image(img)) return images elif isinstance(image, UploadImageType): + logger.debug(f"Using uploaded image UUID: {image.imageUUID}") return image.imageUUID + if isLocalFile(image) and not image.startswith("http"): + logger.debug(f"Converting local file to base64: {image}") return await fileToBase64(image) + + logger.debug(f"Using image as-is: {image}") return image diff --git a/setup.py b/setup.py index 3149088..39193b0 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() @@ -6,7 +6,7 @@ setup( name="runware", license="MIT", - version="0.4.16", + version="0.5.0", author="Runware Inc.", author_email="python.sdk@runware.ai", description="The Python Runware SDK is used to run image inference with the Runware API, powered by the Runware inference platform. It can be used to generate images with text-to-image and image-to-image. It also allows the use of an existing gallery of models or selecting any model or LoRA from the CivitAI gallery. The API also supports upscaling, background removal, inpainting and outpainting, and a series of other ControlNet models.", @@ -36,6 +36,6 @@ install_requires=[ "aiofiles==23.2.1", "python-dotenv==1.0.1", - "websockets==12.0", + "websockets>=12.0,<16.0", ], ) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_base.py b/tests/test_base.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_server.py b/tests/test_server.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_types.py b/tests/test_types.py deleted file mode 100644 index a532276..0000000 --- a/tests/test_types.py +++ /dev/null @@ -1,221 +0,0 @@ -import pytest - -from typing import Any -from runware.types import ( - IControlNet, - IControlNetCanny, - IControlNetOpenPose, - IImageInference, - File, - RequireAtLeastOne, - RequireOnlyOne, - ListenerType, - EPreProcessor, - EOpenPosePreProcessor, - EControlMode, -) - - -def test_icontrol_net_union(): - canny_control_net = IControlNetCanny( - model='qatests:68487@08629', - weight=0.5, - startStep=0, - endStep=10, - guideImage="canny_image.png", - controlMode=EControlMode.BALANCED, - lowThresholdCanny=100, - highThresholdCanny=200, - - ) - assert isinstance(canny_control_net, IControlNet) - - hands_face_control_net = IControlNetOpenPose( - preprocessor=EOpenPosePreProcessor.openpose_face, - weight=0.7, - startStep=5, - endStep=15, - guideImage="hands_face_image.png", - controlMode=EControlMode.PROMPT, - includeHandsAndFaceOpenPose=True, - ) - assert isinstance(hands_face_control_net, IControlNet) - - -def test_irequest_image(): - request_image = IImageInference( - positivePrompt="A beautiful landscape", - width=512, - height=512, - model='qatests:68487@08629', - seedImage=File(b"image_data"), - maskImage="mask.png", - ) - assert request_image.positivePrompt == "A beautiful landscape" - assert request_image.width == 512 - assert request_image.height == 512 - assert request_image.model == 'qatests:68487@08629' - assert isinstance(request_image.seedImage, File) - assert request_image.maskImage == "mask.png" - - -def test_require_at_least_one(): - data = {"key1": "value1", "key2": "value2", "key3": "value3"} - obj = RequireAtLeastOne(data, ["key1", "key3"]) - assert obj["key1"] == "value1" - assert "key2" in obj - assert len(obj) == 3 - - -def test_require_only_one(): - data = {"key1": "value1"} - obj = RequireOnlyOne(data, ["key1"]) - assert obj["key1"] == "value1" - data = {"key1": "value1", "key2": "value2"} - with pytest.raises(ValueError): - RequireOnlyOne(data, ["key1", "key2"]) - - -def test_require_at_least_one_missing_keys(): - data = {"key1": "value1"} - obj = RequireAtLeastOne(data, ["key1", "key2"]) - assert obj["key1"] == "value1" - assert "key2" not in obj - assert len(obj) == 1 - - -def test_require_at_least_one_non_string_keys(): - data = {1: "value1", "key2": "value2"} # One key is an integer - obj = RequireAtLeastOne(data, [1, "key2"]) - assert obj[1] == "value1" - - -def test_require_at_least_one_extra_keys(): - data = {"key1": "value1", "key2": "value2", "extra_key": "extra"} - obj = RequireAtLeastOne(data, ["key1"]) - assert obj["extra_key"] == "extra" - assert obj["key1"] == "value1" - - -def test_require_at_least_one_non_dict_data(): - with pytest.raises(TypeError, match="data must be a dictionary"): - RequireAtLeastOne("some_string", ["key1"]) - - -def test_require_at_least_one_invalid_required_keys(): - data = {"key1": "value1"} - with pytest.raises(TypeError): - RequireAtLeastOne(data, 123) # 123 as an example of a non-iterable, non-string - - -def test_require_at_least_one_removing_key(): - data = {"key1": "value1", "key2": "value2"} - obj = RequireAtLeastOne(data, ["key1", "key2"]) - del obj["key1"] - assert "key2" in obj - assert len(obj) == 1 - - del obj["key2"] - assert len(obj) == 0 - - with pytest.raises( - ValueError, - match="At least one of the required keys must be present: key1, key2", - ): - RequireAtLeastOne({"extra_key": "extra"}, ["key1", "key2"]) - - -def test_require_only_one_adding_keys(): - data = {"key1": "value1"} - obj = RequireOnlyOne(data, ["key1"]) # Specify only one required key - - obj["key2"] = "value2" # Adding a non-required key should be allowed - assert "key2" in obj - assert len(obj) == 2 - - with pytest.raises(ValueError): - RequireOnlyOne({"key1": "value1", "key2": "value2"}, ["key1", "key2"]) - - -def test_listener_type(): - def on_message(msg: Any): - print(msg) - - listener = ListenerType("message_listener", on_message, group_key="group1") - assert listener.key == "message_listener" - assert listener.group_key == "group1" - listener.listener("Hello") # Prints "Hello" - - -def test_icontrol_net_canny_creation(): - control_net_canny = IControlNetCanny( - model='civitai:38784@44716', - weight=0.8, - startStep=2, - endStep=8, - guideImage="canny_guide_image.png", - controlMode=EControlMode.PROMPT, - lowThresholdCanny=100, - highThresholdCanny=200, - ) - assert isinstance(control_net_canny, IControlNetCanny) - assert isinstance(control_net_canny, IControlNet) - assert control_net_canny.preprocessor == EPreProcessor.canny - assert control_net_canny.weight == 0.8 - assert control_net_canny.startStep == 2 - assert control_net_canny.endStep == 8 - assert control_net_canny.guideImage == "canny_guide_image.png" - assert control_net_canny.controlMode == EControlMode.PROMPT - assert control_net_canny.lowThresholdCanny == 100 - assert control_net_canny.highThresholdCanny == 200 - - -def test_icontrol_net_hands_and_face_creation(): - control_net_hands_and_face = IControlNetOpenPose( - preprocessor=EOpenPosePreProcessor.openpose_face, - weight=0.6, - startStep=1, - endStep=9, - guideImage="hands_face_guide_image_unprocessed.png", - controlMode=EControlMode.CONTROL_NET, - includeHandsAndFaceOpenPose=True, - ) - assert isinstance(control_net_hands_and_face, IControlNetOpenPose) - assert isinstance(control_net_hands_and_face, IControlNet) - assert ( - control_net_hands_and_face.preprocessor == EOpenPosePreProcessor.openpose_face - ) - assert control_net_hands_and_face.weight == 0.6 - assert control_net_hands_and_face.startStep == 1 - assert control_net_hands_and_face.endStep == 9 - assert ( - control_net_hands_and_face.guideImage - == "hands_face_guide_image_unprocessed.png" - ) - assert control_net_hands_and_face.controlMode == EControlMode.CONTROL_NET - assert control_net_hands_and_face.includeHandsAndFaceOpenPose == True - - -def test_icontrol_net_union(): - control_net_canny = IControlNetCanny( - model='qatests:68487@08629', - weight=0.7, - startStep=3, - endStep=7, - guideImage="canny_guide_image.png", - controlMode=EControlMode.BALANCED, - lowThresholdCanny=150, - highThresholdCanny=250, - ) - control_net_hands_and_face = IControlNetOpenPose( - preprocessor=EOpenPosePreProcessor.openpose_full, - weight=0.9, - startStep=4, - endStep=6, - guideImage="hands_face_guide_image_unprocessed.png", - controlMode=EControlMode.PROMPT, - includeHandsAndFaceOpenPose=False, - ) - control_nets = [control_net_canny, control_net_hands_and_face] - for control_net in control_nets: - assert isinstance(control_net, IControlNet) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index dd8e9f6..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,297 +0,0 @@ -import asyncio -import json -from unittest.mock import Mock, patch, MagicMock, AsyncMock -import datetime -import base64 -from unittest import TestCase, mock -import pytest -from runware.utils import ( - removeFromAray, - getIntervalWithPromise, - fileToBase64, - getUUID, - isValidUUID, - getTaskType, - evaluateToBoolean, - compact, - getPreprocessorType, - accessDeepObject, - delay, - MockFile, - generateString, - remove1Mutate, - removeListener, - removeAllKeyListener, -) -from runware.types import Environment, EPreProcessor, EPreProcessorGroup - - -# @pytest.fixture -# def mock_file(): -# # Setup code before each test function -# with patch( -# "runware.utils.datetime", -# MagicMock(now=MagicMock(return_value=datetime.datetime(2020, 1, 1))), -# ): -# yield MockFile() # This replaces setUp and yields the MockFile object for each test - -# def test_create_defaults(mock_file): - - -def test_create_defaults(): - with patch( - "runware.utils.datetime", - MagicMock(now=MagicMock(return_value=datetime.datetime(2020, 1, 1))), - ): - # Test default parameters - mock_file = MockFile() - blob = mock_file.create() - assert blob.name == "mock.txt" - assert len(blob) == 1024 # Testing __len__ - assert str(blob) == "a" * 1024 # Testing __str__ - assert blob.size() == 1024 # Testing size() - assert blob.type == "plain/txt" - # assert blob.lastModifiedDate == datetime.datetime( - # 2020, 1, 1 - # ) # Check for mocked date - - -def test_create_custom_params(): - # Test custom parameters - with patch( - "runware.utils.datetime", - MagicMock(now=MagicMock(return_value=datetime.datetime(2020, 1, 1))), - ): - mock_file = MockFile() - blob = mock_file.create("example.txt", 2048, "text/plain") - print(f"Blob:{blob.lastModifiedDate}:") - assert blob.name == "example.txt" - assert len(blob) == 2048 - assert str(blob) == "a" * 2048 - assert blob.size() == 2048 - assert blob.type == "text/plain" - # assert blob.lastModifiedDate == datetime.datetime( - # 2020, 1, 1 - # ) # Check for mocked date - - -def test_create_empty_file(): - # Test creating an empty file - mock_file = MockFile() - blob = mock_file.create("empty.txt", 0, "text/plain") - assert blob.name == "empty.txt" - assert len(blob) == 0 - assert str(blob) == "" - assert blob.size() == 0 - assert blob.type == "text/plain" - # assert blob.lastModifiedDate == datetime.datetime( - # 2020, 1, 1 - # ) # Check for mocked date - - -def test_removeFromAray(): - arr = [1, 2, 3, 4] - removeFromAray(arr, 2) - assert arr == [1, 3, 4] - - -def test_getUUID(): - uuid = getUUID() - assert isinstance(uuid, str) - assert isValidUUID(uuid) - - -def test_isValidUUID(): - valid_uuid = "123e4567-e89b-12d3-a456-426655440000" - invalid_uuid = "invalid-uuid" - assert isValidUUID(valid_uuid) - assert not isValidUUID(invalid_uuid) - - -def test_getTaskType(): - assert getTaskType("prompt", None, None, None) == 1 - assert getTaskType("prompt", None, None, "image") == 2 - assert getTaskType("prompt", None, "mask", "image") == 3 - assert getTaskType("prompt", "controlnet", None, None) == 9 - assert getTaskType("prompt", "controlnet", None, "image") == 10 - assert getTaskType("prompt", "controlnet", "mask", "image") == 10 - - -def test_evaluateToBoolean(): - assert evaluateToBoolean(True, True, True) - assert not evaluateToBoolean(True, False, True) - - -def test_compact(): - assert compact(True, {"a": 1}) == {"a": 1} - assert compact(False, {"a": 1}) == {} - - -def test_getPreprocessorType(): - assert getPreprocessorType(EPreProcessor.canny) == EPreProcessorGroup.canny - assert getPreprocessorType(EPreProcessor.depth_leres) == EPreProcessorGroup.depth - assert ( - getPreprocessorType(EPreProcessor.lineart_anime) - == EPreProcessorGroup.lineart_anime - ) - assert ( - getPreprocessorType(EPreProcessor.openpose_face) == EPreProcessorGroup.openpose - ) - assert getPreprocessorType(EPreProcessor.seg_ofade20k) == EPreProcessorGroup.seg - - -def test_accessDeepObject(): - data = { - "a": {"b": [1, 2, 3], "c": {"d": "text"}}, - "e": [{"f": "value1"}, {"f": "value2"}], - } - - # Test existing keys and array index - assert accessDeepObject("a.b[1]", data) == 2 - assert accessDeepObject("a.c.d", data) == "text" - - # Test non-existent keys - assert accessDeepObject("a.x", data, useZero=False) is None - assert accessDeepObject("a.b[3]", data, useZero=True) == 0 - - # Test boundary conditions for arrays - assert accessDeepObject("e[0].f", data) == "value1" - assert accessDeepObject("e[1].f", data) == "value2" - assert accessDeepObject("e[2].f", data, useZero=False) is None - - # Test the shouldReturnString flag - assert accessDeepObject("a", data, shouldReturnString=True) == json.dumps(data["a"]) - - # Test invalid key format - assert accessDeepObject("a..b", data, useZero=True) == 0 - assert accessDeepObject("a.[b]", data, useZero=False) is None - - -def test_generateString(): - assert generateString(3) == "aaa" - - -def test_remove1Mutate(): - arr = [1, 2, 3, 4] - remove1Mutate(arr, 2) - assert arr == [1, 3, 4] - - -def test_removeListener(): - listeners = [Mock(key="a"), Mock(key="b"), Mock(key="c")] - listener = Mock(key="b") - updated_listeners = removeListener(listeners, listener) - assert len(updated_listeners) == 2 - assert listener not in updated_listeners - - -def test_removeAllKeyListener(): - listeners = [Mock(key="a"), Mock(key="b"), Mock(key="b"), Mock(key="c")] - updated_listeners = removeAllKeyListener(listeners, "b") - assert len(updated_listeners) == 2 - assert Mock(key="b") not in updated_listeners - - -@pytest.mark.asyncio -async def test_file_to_base64_success(tmpdir): - # Create a temporary file - file_path = tmpdir / "test_image.jpg" - file_contents = b"test_file_contents" # Sample content - # Use the built-in 'open' to write - with open(file_path, "wb") as f: - f.write(file_contents) - - # Patch aiofiles.open to return an AsyncMock - with patch("aiofiles.open") as mock_open: - mock_file = AsyncMock() - mock_file.read = AsyncMock(return_value=file_contents) - mock_open.return_value.__aenter__ = AsyncMock(return_value=mock_file) - result = await fileToBase64(str(file_path)) - - # Assert the expected Base64 representation of file_contents - expected_base64 = base64.b64encode(file_contents).decode("utf-8") - assert result == expected_base64 - - -@pytest.mark.asyncio -async def test_file_to_base64_file_not_found(tmpdir): - file_path = tmpdir / "nonexistent_file.txt" - # Patch aiofiles.open to raise FileNotFoundError - with patch("aiofiles.open", side_effect=FileNotFoundError): - with pytest.raises(FileNotFoundError) as error_info: - await fileToBase64(str(file_path)) - - assert str(error_info.value) == f"The file at {file_path} does not exist." - - -@pytest.mark.asyncio -async def test_delay(): - with patch("asyncio.sleep") as mock_sleep: - await delay(1.5) - mock_sleep.assert_called_once_with(1.5) - - -@pytest.mark.asyncio -async def test_immediate_resolution(): - async def callback(params): - params["resolve"]("resolved immediately") - return True # Stop the interval immediately - - result = await getIntervalWithPromise( - callback, timeOutDuration=5000, pollingInterval=100 - ) - assert ( - result == "resolved immediately" - ), "The future should have been resolved immediately." - - -@pytest.mark.asyncio -async def test_timeout(): - async def callback(params): - await asyncio.sleep( - 1 - ) # Simulate delay longer than polling but not long enough to resolve - return False # Continue the interval - - with pytest.raises(asyncio.TimeoutError): - await getIntervalWithPromise( - callback, - debugKey="timeout_test", - timeOutDuration=500, - shouldThrowError=True, - ) - - -@pytest.mark.asyncio -async def test_callback_error_handling(): - async def callback(params): - raise Exception("Deliberate exception") - - with pytest.raises(Exception) as exc_info: - await getIntervalWithPromise( - callback, debugKey="error_test", timeOutDuration=2000 - ) - assert "Deliberate exception" in str( - exc_info.value - ), "The specific error should be caught and raised." - - -@pytest.mark.asyncio -async def test_proper_interval_handling(): - call_count = 0 - - async def callback(params): - nonlocal call_count - call_count += 1 - if call_count >= 3: # Resolve after 3 calls - params["resolve"]("resolved after several intervals") - return True - return False - - result = await getIntervalWithPromise( - callback, debugKey="interval_test", timeOutDuration=2000, pollingInterval=500 - ) - assert ( - result == "resolved after several intervals" - ), "Should resolve after exactly 3 intervals." - assert call_count == 3, "Callback should be called exactly 3 times."