diff --git a/invokeai/app/api/routers/utilities.py b/invokeai/app/api/routers/utilities.py index 568546603ab..06f3e4a44c2 100644 --- a/invokeai/app/api/routers/utilities.py +++ b/invokeai/app/api/routers/utilities.py @@ -74,6 +74,9 @@ class ExpandPromptRequest(BaseModel): model_key: str max_tokens: int = Field(default=300, ge=1, le=2048) system_prompt: str | None = None + task_id: str | None = Field( + default=None, description="Client-supplied task ID used to correlate socket progress events to this request" + ) class ExpandPromptResponse(BaseModel): @@ -90,14 +93,25 @@ def _resolve_model_path(model_config_path: str) -> Path: return (base_models_path / model_path).resolve() -def _run_expand_prompt(prompt: str, model_key: str, max_tokens: int, system_prompt: str | None) -> str: +def _run_expand_prompt( + prompt: str, + model_key: str, + max_tokens: int, + system_prompt: str | None, + task_id: str | None, + user_id: str, +) -> str: """Run text LLM inference synchronously (called from thread).""" model_manager = ApiDependencies.invoker.services.model_manager + events = ApiDependencies.invoker.services.events model_config = model_manager.store.get_model(model_key) if model_config.type != ModelType.TextLLM: raise ValueError(f"Model '{model_key}' is not a TextLLM model (got {model_config.type})") + if task_id is not None: + events.emit_llm_task_progress(task_id=task_id, user_id=user_id, phase="loading_model", message="Loading model") + with _model_load_lock: loaded_model = model_manager.load.load_model(model_config) @@ -107,12 +121,28 @@ def _run_expand_prompt(prompt: str, model_key: str, max_tokens: int, system_prom pipeline = TextLLMPipeline(model, tokenizer) model_device = next(model.parameters()).device + + progress_callback = None + if task_id is not None: + + def progress_callback(current: int, total: int) -> None: + events.emit_llm_task_progress( + task_id=task_id, + user_id=user_id, + phase="generating", + message="Generating", + percentage=(current / total) if total > 0 else None, + current_tokens=current, + total_tokens=total, + ) + output = pipeline.run( prompt=prompt, system_prompt=system_prompt or DEFAULT_SYSTEM_PROMPT, max_new_tokens=max_tokens, device=model_device, dtype=TorchDevice.choose_torch_dtype(), + progress_callback=progress_callback, ) return output @@ -127,6 +157,7 @@ def _run_expand_prompt(prompt: str, model_key: str, max_tokens: int, system_prom ) async def expand_prompt(current_user: CurrentUserOrDefault, body: ExpandPromptRequest) -> ExpandPromptResponse: """Expand a brief prompt into a detailed image generation prompt using a text LLM.""" + events = ApiDependencies.invoker.services.events try: expanded = await asyncio.to_thread( _run_expand_prompt, @@ -134,13 +165,23 @@ async def expand_prompt(current_user: CurrentUserOrDefault, body: ExpandPromptRe body.model_key, body.max_tokens, body.system_prompt, + body.task_id, + current_user.user_id, ) + if body.task_id is not None: + events.emit_llm_task_complete(task_id=body.task_id, user_id=current_user.user_id) return ExpandPromptResponse(expanded_prompt=expanded) except UnknownModelException: + if body.task_id is not None: + events.emit_llm_task_error(task_id=body.task_id, user_id=current_user.user_id, error="Model not found") raise HTTPException(status_code=404, detail=f"Model '{body.model_key}' not found") except ValueError as e: + if body.task_id is not None: + events.emit_llm_task_error(task_id=body.task_id, user_id=current_user.user_id, error=str(e)) raise HTTPException(status_code=422, detail=str(e)) except Exception as e: + if body.task_id is not None: + events.emit_llm_task_error(task_id=body.task_id, user_id=current_user.user_id, error=str(e)) logger.error(f"Error expanding prompt: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -152,6 +193,9 @@ class ImageToPromptRequest(BaseModel): image_name: str model_key: str instruction: str = "Describe this image in detail for use as an AI image generation prompt." + task_id: str | None = Field( + default=None, description="Client-supplied task ID used to correlate socket progress events to this request" + ) class ImageToPromptResponse(BaseModel): @@ -159,14 +203,24 @@ class ImageToPromptResponse(BaseModel): error: str | None = None -def _run_image_to_prompt(image_name: str, model_key: str, instruction: str) -> str: +def _run_image_to_prompt( + image_name: str, + model_key: str, + instruction: str, + task_id: str | None, + user_id: str, +) -> str: """Run LLaVA OneVision inference synchronously (called from thread).""" model_manager = ApiDependencies.invoker.services.model_manager + events = ApiDependencies.invoker.services.events model_config = model_manager.store.get_model(model_key) if model_config.type != ModelType.LlavaOnevision: raise ValueError(f"Model '{model_key}' is not a LLaVA OneVision model (got {model_config.type})") + if task_id is not None: + events.emit_llm_task_progress(task_id=task_id, user_id=user_id, phase="loading_model", message="Loading model") + with _model_load_lock: loaded_model = model_manager.load.load_model(model_config) @@ -185,11 +239,27 @@ def _run_image_to_prompt(image_name: str, model_key: str, instruction: str) -> s pipeline = LlavaOnevisionPipeline(model, processor) model_device = next(model.parameters()).device + + progress_callback = None + if task_id is not None: + + def progress_callback(current: int, total: int) -> None: + events.emit_llm_task_progress( + task_id=task_id, + user_id=user_id, + phase="generating", + message="Generating", + percentage=(current / total) if total > 0 else None, + current_tokens=current, + total_tokens=total, + ) + output = pipeline.run( prompt=instruction, images=[image], device=model_device, dtype=TorchDevice.choose_torch_dtype(), + progress_callback=progress_callback, ) return output @@ -208,20 +278,33 @@ async def image_to_prompt(current_user: CurrentUserOrDefault, body: ImageToPromp # via this endpoint (mirrors the policy in routers/images.py). assert_image_read_access(body.image_name, current_user) + events = ApiDependencies.invoker.services.events try: prompt = await asyncio.to_thread( _run_image_to_prompt, body.image_name, body.model_key, body.instruction, + body.task_id, + current_user.user_id, ) + if body.task_id is not None: + events.emit_llm_task_complete(task_id=body.task_id, user_id=current_user.user_id) return ImageToPromptResponse(prompt=prompt) except UnknownModelException: + if body.task_id is not None: + events.emit_llm_task_error(task_id=body.task_id, user_id=current_user.user_id, error="Model not found") raise HTTPException(status_code=404, detail=f"Model '{body.model_key}' not found") except ImageFileNotFoundException: + if body.task_id is not None: + events.emit_llm_task_error(task_id=body.task_id, user_id=current_user.user_id, error="Image not found") raise HTTPException(status_code=404, detail=f"Image '{body.image_name}' not found") except (ValueError, TypeError) as e: + if body.task_id is not None: + events.emit_llm_task_error(task_id=body.task_id, user_id=current_user.user_id, error=str(e)) raise HTTPException(status_code=422, detail=str(e)) except Exception as e: + if body.task_id is not None: + events.emit_llm_task_error(task_id=body.task_id, user_id=current_user.user_id, error=str(e)) logger.error(f"Error generating prompt from image: {e}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 5783b804c0b..2c6e428d8f0 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -24,6 +24,10 @@ InvocationErrorEvent, InvocationProgressEvent, InvocationStartedEvent, + LLMTaskCompleteEvent, + LLMTaskErrorEvent, + LLMTaskEventBase, + LLMTaskProgressEvent, ModelEventBase, ModelInstallCancelledEvent, ModelInstallCompleteEvent, @@ -87,6 +91,8 @@ class BulkDownloadSubscriptionEvent(BaseModel): BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent} +LLM_TASK_EVENTS = {LLMTaskProgressEvent, LLMTaskCompleteEvent, LLMTaskErrorEvent} + class SocketIO: _sub_queue = "subscribe_queue" @@ -115,6 +121,7 @@ def __init__(self, app: FastAPI): register_events(QUEUE_EVENTS, self._handle_queue_event) register_events(MODEL_EVENTS, self._handle_model_event) register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event) + register_events(LLM_TASK_EVENTS, self._handle_llm_task_event) async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> bool: """Handle socket connection and authenticate the user. @@ -345,6 +352,18 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None: await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json")) + async def _handle_llm_task_event(self, event: FastAPIEvent[LLMTaskEventBase]) -> None: + """Route LLM utility task events privately to the originating user + admins. + + These events carry partial prompt content (via the task_id correlation) and + must not be broadcast to other users. + """ + event_name, event_data = event + user_room = f"user:{event_data.user_id}" + payload = event_data.model_dump(mode="json") + await self._sio.emit(event=event_name, data=payload, room=user_room) + await self._sio.emit(event=event_name, data=payload, room="admin") + async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None: event_name, event_data = event # Route to user-specific + admin rooms so that other authenticated diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 935b422a732..1dfd02728da 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -1,7 +1,7 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, @@ -19,6 +19,9 @@ InvocationErrorEvent, InvocationProgressEvent, InvocationStartedEvent, + LLMTaskCompleteEvent, + LLMTaskErrorEvent, + LLMTaskProgressEvent, ModelInstallCancelledEvent, ModelInstallCompleteEvent, ModelInstallDownloadProgressEvent, @@ -191,6 +194,41 @@ def emit_model_install_error(self, job: "ModelInstallJob") -> None: # endregion + # region LLM utility tasks + + def emit_llm_task_progress( + self, + task_id: str, + user_id: str, + phase: Literal["loading_model", "generating"], + message: str, + percentage: float | None = None, + current_tokens: int | None = None, + total_tokens: int | None = None, + ) -> None: + """Emit a progress event for an LLM utility task (expand-prompt, image-to-prompt).""" + self.dispatch( + LLMTaskProgressEvent( + task_id=task_id, + user_id=user_id, + phase=phase, + message=message, + percentage=percentage, + current_tokens=current_tokens, + total_tokens=total_tokens, + ) + ) + + def emit_llm_task_complete(self, task_id: str, user_id: str) -> None: + """Emit a completion event for an LLM utility task.""" + self.dispatch(LLMTaskCompleteEvent(task_id=task_id, user_id=user_id)) + + def emit_llm_task_error(self, task_id: str, user_id: str, error: str) -> None: + """Emit an error event for an LLM utility task.""" + self.dispatch(LLMTaskErrorEvent(task_id=task_id, user_id=user_id, error=error)) + + # endregion + # region Bulk image download def emit_bulk_download_started( diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 0c530f9a2f7..6b06c7be060 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Literal, Optional, Protocol, TypeAlias, TypeVar from fastapi_events.handlers.local import local_handler from fastapi_events.registry.payload_schema import registry as payload_schema @@ -689,6 +689,51 @@ def build( ) +class LLMTaskEventBase(EventBase): + """Base class for LLM utility task events (expand-prompt, image-to-prompt). + + These events are correlated to a specific HTTP request via a client-supplied + task_id and routed privately to the originating user so partial prompt content + is not broadcast. + """ + + task_id: str = Field(description="Client-supplied task ID correlating events to a single request") + user_id: str = Field(default="system", description="ID of the user who initiated the task") + + +@payload_schema.register +class LLMTaskProgressEvent(LLMTaskEventBase): + """Event model for llm_task_progress""" + + __event_name__ = "llm_task_progress" + + phase: Literal["loading_model", "generating"] = Field(description="Which phase of the task is in progress") + message: str = Field(description="A short message describing the current phase") + percentage: float | None = Field( + default=None, ge=0, le=1, description="Progress fraction in [0, 1]; omit for indeterminate progress" + ) + current_tokens: int | None = Field(default=None, description="Number of tokens generated so far (generating phase)") + total_tokens: int | None = Field( + default=None, description="Max tokens the request will generate (generating phase)" + ) + + +@payload_schema.register +class LLMTaskCompleteEvent(LLMTaskEventBase): + """Event model for llm_task_complete""" + + __event_name__ = "llm_task_complete" + + +@payload_schema.register +class LLMTaskErrorEvent(LLMTaskEventBase): + """Event model for llm_task_error""" + + __event_name__ = "llm_task_error" + + error: str = Field(description="The error message") + + @payload_schema.register class RecallParametersUpdatedEvent(QueueEventBase): """Event model for recall_parameters_updated""" diff --git a/invokeai/backend/llava_onevision_pipeline.py b/invokeai/backend/llava_onevision_pipeline.py index 93614f40654..abb136ba9fe 100644 --- a/invokeai/backend/llava_onevision_pipeline.py +++ b/invokeai/backend/llava_onevision_pipeline.py @@ -1,6 +1,11 @@ +import threading +from typing import Callable + import torch from PIL.Image import Image -from transformers import LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor +from transformers import LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor, TextIteratorStreamer + +ProgressCallback = Callable[[int, int], None] class LlavaOnevisionPipeline: @@ -10,7 +15,15 @@ def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor self._vllm_model = vllm_model self._processor = processor - def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str: + def run( + self, + prompt: str, + images: list[Image], + device: torch.device, + dtype: torch.dtype, + max_new_tokens: int = 400, + progress_callback: ProgressCallback | None = None, + ) -> str: # TODO(ryand): Tune the max number of images that are useful for the model. if len(images) > 3: raise ValueError( @@ -21,15 +34,46 @@ def run(self, prompt: str, images: list[Image], device: torch.device, dtype: tor # Define a chat history and use `apply_chat_template` to get correctly formatted prompt. # "content" is a list of dicts with types "text" or "image". content = [{"type": "text", "text": prompt}] - # Add the correct number of images. for _ in images: content.append({"type": "image"}) conversation = [{"role": "user", "content": content}] - prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) - inputs = self._processor(images=images or None, text=prompt, return_tensors="pt").to(device=device, dtype=dtype) - output = self._vllm_model.generate(**inputs, max_new_tokens=400, do_sample=False) - output_str: str = self._processor.decode(output[0][2:], skip_special_tokens=True) - # The output_str will include the prompt, so we extract the response. - response = output_str.split("assistant\n", 1)[1].strip() - return response + formatted_prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) + inputs = self._processor(images=images or None, text=formatted_prompt, return_tensors="pt").to( + device=device, dtype=dtype + ) + + tokenizer = self._processor.tokenizer + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + generation_kwargs = dict( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + streamer=streamer, + ) + + generation_error: list[BaseException] = [] + + def _generate() -> None: + try: + self._vllm_model.generate(**generation_kwargs) + except BaseException as e: + generation_error.append(e) + + thread = threading.Thread(target=_generate, daemon=True) + thread.start() + + chunks: list[str] = [] + for chunk in streamer: + if not chunk: + continue + chunks.append(chunk) + if progress_callback is not None: + token_count = len(tokenizer.encode("".join(chunks), add_special_tokens=False)) + progress_callback(min(token_count, max_new_tokens), max_new_tokens) + + thread.join() + if generation_error: + raise generation_error[0] + + return "".join(chunks).strip() diff --git a/invokeai/backend/text_llm_pipeline.py b/invokeai/backend/text_llm_pipeline.py index 69815c1a7f7..d0eb534adb4 100644 --- a/invokeai/backend/text_llm_pipeline.py +++ b/invokeai/backend/text_llm_pipeline.py @@ -1,5 +1,8 @@ +import threading +from typing import Callable + import torch -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import PreTrainedModel, PreTrainedTokenizerBase, TextIteratorStreamer DEFAULT_SYSTEM_PROMPT = ( "You are an expert prompt writer for AI image generation. " @@ -8,6 +11,9 @@ ) +ProgressCallback = Callable[[int, int], None] + + class TextLLMPipeline: """A wrapper for a causal language model + tokenizer for text generation.""" @@ -22,6 +28,7 @@ def run( max_new_tokens: int = 300, device: torch.device = torch.device("cpu"), dtype: torch.dtype = torch.float16, + progress_callback: ProgressCallback | None = None, ) -> str: # Build messages for chat template if supported, otherwise use raw prompt. if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template is not None: @@ -33,24 +40,51 @@ def run( messages, tokenize=False, add_generation_prompt=True ) else: - # Fallback for models without chat template if system_prompt: formatted_prompt = f"{system_prompt}\n\nUser: {prompt}\nAssistant:" else: formatted_prompt = prompt inputs = self._tokenizer(formatted_prompt, return_tensors="pt").to(device=device) - output = self._model.generate( + + streamer = TextIteratorStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True) + generation_kwargs = dict( **inputs, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7, top_p=0.9, + streamer=streamer, ) - # Decode only the newly generated tokens (exclude the input prompt tokens). - input_length = inputs["input_ids"].shape[1] - generated_tokens = output[0][input_length:] - response = self._tokenizer.decode(generated_tokens, skip_special_tokens=True).strip() + # model.generate blocks until done; run it in a thread so we can consume the + # streamer iteratively and emit progress. + generation_error: list[BaseException] = [] + + def _generate() -> None: + try: + self._model.generate(**generation_kwargs) + except BaseException as e: + generation_error.append(e) + + thread = threading.Thread(target=_generate, daemon=True) + thread.start() + + chunks: list[str] = [] + token_count = 0 + for chunk in streamer: + if not chunk: + continue + chunks.append(chunk) + # The streamer yields decoded text chunks rather than individual tokens. + # Re-tokenizing each chunk to count tokens is expensive; instead approximate + # by re-tokenizing the accumulated text. This is exact enough for a progress bar. + token_count = len(self._tokenizer.encode("".join(chunks), add_special_tokens=False)) + if progress_callback is not None: + progress_callback(min(token_count, max_new_tokens), max_new_tokens) + + thread.join() + if generation_error: + raise generation_error[0] - return response + return "".join(chunks).strip() diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index e13946511e2..e44571afc1d 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -22178,6 +22178,18 @@ } ], "title": "System Prompt" + }, + "task_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Task Id", + "description": "Client-supplied task ID used to correlate socket progress events to this request" } }, "type": "object", @@ -34684,6 +34696,18 @@ "type": "string", "title": "Instruction", "default": "Describe this image in detail for use as an AI image generation prompt." + }, + "task_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Task Id", + "description": "Client-supplied task ID used to correlate socket progress events to this request" } }, "type": "object", @@ -42623,6 +42647,144 @@ "type": "object" }, "JsonValue": {}, + "LLMTaskCompleteEvent": { + "description": "Event model for llm_task_complete", + "properties": { + "timestamp": { + "description": "The timestamp of the event", + "title": "Timestamp", + "type": "integer" + }, + "task_id": { + "description": "Client-supplied task ID correlating events to a single request", + "title": "Task Id", + "type": "string" + }, + "user_id": { + "default": "system", + "description": "ID of the user who initiated the task", + "title": "User Id", + "type": "string" + } + }, + "required": ["timestamp", "task_id", "user_id"], + "title": "LLMTaskCompleteEvent", + "type": "object" + }, + "LLMTaskErrorEvent": { + "description": "Event model for llm_task_error", + "properties": { + "timestamp": { + "description": "The timestamp of the event", + "title": "Timestamp", + "type": "integer" + }, + "task_id": { + "description": "Client-supplied task ID correlating events to a single request", + "title": "Task Id", + "type": "string" + }, + "user_id": { + "default": "system", + "description": "ID of the user who initiated the task", + "title": "User Id", + "type": "string" + }, + "error": { + "description": "The error message", + "title": "Error", + "type": "string" + } + }, + "required": ["timestamp", "task_id", "user_id", "error"], + "title": "LLMTaskErrorEvent", + "type": "object" + }, + "LLMTaskProgressEvent": { + "description": "Event model for llm_task_progress", + "properties": { + "timestamp": { + "description": "The timestamp of the event", + "title": "Timestamp", + "type": "integer" + }, + "task_id": { + "description": "Client-supplied task ID correlating events to a single request", + "title": "Task Id", + "type": "string" + }, + "user_id": { + "default": "system", + "description": "ID of the user who initiated the task", + "title": "User Id", + "type": "string" + }, + "phase": { + "description": "Which phase of the task is in progress", + "enum": ["loading_model", "generating"], + "title": "Phase", + "type": "string" + }, + "message": { + "description": "A short message describing the current phase", + "title": "Message", + "type": "string" + }, + "percentage": { + "anyOf": [ + { + "maximum": 1, + "minimum": 0, + "type": "number" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Progress fraction in [0, 1]; omit for indeterminate progress", + "title": "Percentage" + }, + "current_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of tokens generated so far (generating phase)", + "title": "Current Tokens" + }, + "total_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Max tokens the request will generate (generating phase)", + "title": "Total Tokens" + } + }, + "required": [ + "timestamp", + "task_id", + "user_id", + "phase", + "message", + "percentage", + "current_tokens", + "total_tokens" + ], + "title": "LLMTaskProgressEvent", + "type": "object" + }, "LaMaInfillInvocation": { "category": "inpaint", "class": "invocation", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 3e88d460e55..da256e17fb1 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -392,7 +392,9 @@ "noTextLLMInstalledDescription": "Prompt expansion needs a Text LLM (causal language model). We recommend Qwen2.5-1.5B-Instruct (~3 GB) — small, fast, and available as a starter model.", "noVisionModelInstalledTitle": "No vision model installed", "noVisionModelInstalledDescription": "Image-to-prompt needs a vision-language model (e.g. LLaVA Onevision). The 0.5B starter (~1 GB) is the lightweight default.", - "openModelManager": "Open Model Manager" + "openModelManager": "Open Model Manager", + "llmTaskLoadingModel": "Loading model…", + "llmTaskGenerating": "Generating…" }, "queue": { "queue": "Queue", diff --git a/invokeai/frontend/web/src/features/prompt/ExpandPromptButton.tsx b/invokeai/frontend/web/src/features/prompt/ExpandPromptButton.tsx index e0f035963a2..272b7f6810b 100644 --- a/invokeai/frontend/web/src/features/prompt/ExpandPromptButton.tsx +++ b/invokeai/frontend/web/src/features/prompt/ExpandPromptButton.tsx @@ -18,6 +18,7 @@ import { useDisclosure } from 'common/hooks/useBoolean'; import { positivePromptChanged, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice'; import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore'; import { ModelPicker } from 'features/parameters/components/ModelPicker'; +import { LLMTaskProgressDisplay } from 'features/prompt/LLMTaskProgressDisplay'; import { setPromptUndo } from 'features/prompt/promptUndo'; import { navigationApi } from 'features/ui/layouts/navigation-api'; import { memo, useCallback, useState } from 'react'; @@ -26,6 +27,8 @@ import { PiSparkleBold } from 'react-icons/pi'; import { useExpandPromptMutation } from 'services/api/endpoints/utilities'; import { useTextLLMModels } from 'services/api/hooks/modelsByType'; import type { AnyModelConfig } from 'services/api/types'; +import { clearLLMTaskState } from 'services/events/stores'; +import { v4 as uuidv4 } from 'uuid'; const loadingStyles: SystemStyleObject = { svg: { animation: spinAnimation }, @@ -38,6 +41,7 @@ export const ExpandPromptButton = memo(() => { const [modelConfigs] = useTextLLMModels(); const popover = useDisclosure(false); const [selectedModel, setSelectedModel] = useState(undefined); + const [taskId, setTaskId] = useState(null); const [expandPrompt, { isLoading }] = useExpandPromptMutation(); const hasModels = modelConfigs.length > 0; @@ -50,10 +54,13 @@ export const ExpandPromptButton = memo(() => { if (!selectedModel || !prompt.trim()) { return; } + const newTaskId = uuidv4(); + setTaskId(newTaskId); try { const result = await expandPrompt({ prompt, model_key: selectedModel.key, + task_id: newTaskId, }).unwrap(); if (result.expanded_prompt) { setPromptUndo(prompt); @@ -62,6 +69,9 @@ export const ExpandPromptButton = memo(() => { popover.close(); } catch { // Error is handled by RTK Query + } finally { + clearLLMTaskState(newTaskId); + setTaskId(null); } }, [selectedModel, prompt, expandPrompt, dispatch, popover]); @@ -110,6 +120,7 @@ export const ExpandPromptButton = memo(() => { onChange={handleModelChange} placeholder={t('prompt.selectTextLLM')} /> + {isLoading ? : null}