import contextlib
import copy
import asyncio
import functools
import inspect
import threading
import time
import uuid
from pathlib import Path
from typing import Any, Sequence

from shared.api import GeneratedArtifact, GenerationError, GenerationResult, PreviewUpdate, SessionJob, WanGPSession, _pushd

# Sentinel returned by _WrappedCallState.pop_yielded_result when nothing is pending
# (None is a legitimate yielded value, so it cannot be used as the marker).
_NO_YIELDED_RESULT = object()
# Guards the one-time monkey patch of gradio.helpers.log_message.
_GRADIO_LOG_PATCH_LOCK = threading.Lock()
# The original gradio.helpers.log_message, captured when the patch is installed.
_ORIGINAL_GRADIO_LOG_MESSAGE = None
# Thread-local carrying `owner` / `call_id` so worker-thread log calls can be buffered.
_WRAPPED_LOG_LOCAL = threading.local()


def _buffered_gradio_log_message(message: str, title: str, level: str = "info", duration: float | None = 10, visible: bool = True):
    """Replacement for ``gradio.helpers.log_message``.

    Inside a live Gradio event (both ``LocalContext.blocks`` and
    ``LocalContext.event_id`` are set) this delegates to the original function.
    Otherwise, if the current thread carries an ``owner`` and non-empty
    ``call_id`` on ``_WRAPPED_LOG_LOCAL`` and that wrapped call is still
    registered, the message is buffered on the call state so it can be flushed
    later from an event that does have a Gradio context. In all remaining cases
    the original function is used.
    """
    from gradio.context import LocalContext

    blocks = LocalContext.blocks.get()
    event_id = LocalContext.event_id.get()
    if blocks is not None and event_id is not None:
        return _ORIGINAL_GRADIO_LOG_MESSAGE(message, title=title, level=level, duration=duration, visible=visible)
    owner = getattr(_WRAPPED_LOG_LOCAL, "owner", None)
    call_id = str(getattr(_WRAPPED_LOG_LOCAL, "call_id", "") or "").strip()
    if owner is not None and len(call_id) > 0:
        state = owner._get_wrapped_call(call_id)
        if state is not None:
            state.queue_log_message(message=message, title=title, level=level, duration=duration, visible=visible)
            return
    return _ORIGINAL_GRADIO_LOG_MESSAGE(message, title=title, level=level, duration=duration, visible=visible)


def _ensure_gradio_log_message_patch() -> None:
    """Install ``_buffered_gradio_log_message`` over ``gradio.helpers.log_message`` once."""
    global _ORIGINAL_GRADIO_LOG_MESSAGE
    with _GRADIO_LOG_PATCH_LOCK:
        if _ORIGINAL_GRADIO_LOG_MESSAGE is not None:
            return
        import gradio.helpers as gr_helpers

        _ORIGINAL_GRADIO_LOG_MESSAGE = gr_helpers.log_message
        gr_helpers.log_message = _buffered_gradio_log_message


def _normalize_queue_request(request):
    """Return a request object acceptable to Gradio's queue.

    If ``route_utils.get_api_call_path`` accepts *request* unchanged, it is
    returned as-is. Otherwise an HTTP request is rebuilt as a FastAPI
    ``Request`` whose path points at ``<API_PREFIX>/queue/join`` with an empty
    query string. Every failure path falls back to returning *request*.
    """
    if request is None:
        return None
    try:
        from gradio import route_utils
    except Exception:
        # Without route_utils there is no API prefix to rewrite against.
        return request
    try:
        route_utils.get_api_call_path(request)
        return request
    except Exception:
        pass
    try:
        from fastapi import Request as FastAPIRequest
    except Exception:
        return request
    scope = dict(getattr(request, "scope", {}) or {})
    if scope.get("type") != "http":
        return request
    queue_path = f"{route_utils.API_PREFIX}/queue/join"
    scope["path"] = queue_path
    scope["raw_path"] = queue_path.encode("utf-8")
    scope["query_string"] = b""
    try:
        return FastAPIRequest(scope, request.receive)
    except Exception:
        return request


class GradioProgressCallbacks:
    """Adapter forwarding session status/progress callbacks to a Gradio progress callable."""

    def __init__(self, progress) -> None:
        self._progress = progress
        # Last known completion ratio in [0, 1]; reused for status-only updates.
        self._ratio = 0.0

    def on_status(self, status) -> None:
        """Report a non-empty status string at the last known ratio."""
        status = str(status or "").strip()
        if status:
            self._progress(self._ratio, desc=status)

    def on_progress(self, update) -> None:
        """Report ``update.progress`` (a 0-100 percentage) clamped into [0, 1].

        A missing or ``None`` progress value is treated as 0.
        """
        self._ratio = max(0.0, min(1.0, float(getattr(update, "progress", 0) or 0) / 100.0))
        self._progress(self._ratio, desc=str(getattr(update, "status", "") or "Generating..."))


class _WrappedCallState:
    """Shared mailbox between a wrapped plugin click's worker thread and its Gradio events.

    The worker publishes results, yielded intermediate values, errors, jobs,
    abort requests, and buffered log messages; the Gradio start/wait handlers
    drain them. Lists touched from several threads are guarded by their own locks.
    """

    def __init__(self, output_count: int) -> None:
        self.output_count = output_count
        self.done = threading.Event()
        self.result: Any = None
        self.has_result = False
        self.error: BaseException | None = None
        self.job: SessionJob | None = None
        # Most recently popped abort client id.
        self.abort_client_id = ""
        self._abort_client_ids: list[str] = []
        self._abort_client_ids_lock = threading.Lock()
        self.callback_context_ready = threading.Event()
        self.callback_context: dict[str, Any] | None = None
        self._followup_jobs: list[SessionJob] = []
        self._followup_lock = threading.Lock()
        # Follow-up jobs are only recorded once the primary job's token has been forwarded.
        self._followup_enabled = False
        self._primary_job_forwarded = False
        self._yielded_results: list[Any] = []
        self._yielded_results_lock = threading.Lock()
        self._log_messages: list[dict[str, Any]] = []
        self._log_messages_lock = threading.Lock()

    def set_result(self, result: Any) -> None:
        """Store the final callback result and mark the call done."""
        self.result = result
        self.has_result = True
        self.done.set()

    def set_completed(self) -> None:
        """Mark the call done without a result."""
        self.done.set()

    def set_error(self, error: BaseException) -> None:
        """Store the callback's exception and mark the call done."""
        self.error = error
        self.done.set()

    def set_callback_context(self, context: dict[str, Any]) -> None:
        """Store a copy of *context* and wake anyone waiting for it."""
        self.callback_context = dict(context)
        self.callback_context_ready.set()

    def enable_followup_queue_triggers(self) -> None:
        self._followup_enabled = True

    def add_followup_job(self, job: SessionJob) -> None:
        """Record *job* as a follow-up; ignored until follow-ups are enabled."""
        if not self._followup_enabled:
            return
        with self._followup_lock:
            self._followup_jobs.append(job)

    def pop_ready_followup_load_queue_token(self) -> str:
        """Remove and return the token of the first submission-ready follow-up job, or ``""``."""
        with self._followup_lock:
            for index, job in enumerate(self._followup_jobs):
                if job.webui_submission_ready:
                    self._followup_jobs.pop(index)
                    return job.webui_load_queue_token
            return ""

    def pop_primary_load_queue_token(self) -> str:
        """Return the primary job's token exactly once, after it is submission-ready.

        Forwarding the primary token also enables follow-up job tracking.
        """
        if self._primary_job_forwarded or self.job is None or not self.job.webui_submission_ready:
            return ""
        self._primary_job_forwarded = True
        self.enable_followup_queue_triggers()
        return self.job.webui_load_queue_token

    def queue_abort_client_id(self, client_id: str) -> None:
        """Queue a non-empty client id to be forwarded to the WebUI abort trigger."""
        client_id = str(client_id or "").strip()
        if len(client_id) == 0:
            return
        with self._abort_client_ids_lock:
            self._abort_client_ids.append(client_id)

    def pop_abort_client_id(self) -> str:
        """Pop the oldest queued abort client id (FIFO), or return ``""``."""
        with self._abort_client_ids_lock:
            if not self._abort_client_ids:
                return ""
            client_id = self._abort_client_ids.pop(0)
            self.abort_client_id = client_id
            return client_id

    def push_yielded_result(self, result: Any) -> None:
        with self._yielded_results_lock:
            self._yielded_results.append(result)

    def pop_yielded_result(self) -> Any:
        """Pop the oldest yielded result, or ``_NO_YIELDED_RESULT`` when none is pending."""
        with self._yielded_results_lock:
            if not self._yielded_results:
                return _NO_YIELDED_RESULT
            return self._yielded_results.pop(0)

    def queue_log_message(self, *, message: str, title: str, level: str, duration: float | None, visible: bool) -> None:
        """Buffer a log message as keyword arguments for ``blocks._queue.log_message``."""
        with self._log_messages_lock:
            self._log_messages.append({"log": str(message or ""), "title": str(title or ""), "level": str(level or "info"), "duration": duration, "visible": bool(visible)})

    def pop_log_messages(self) -> list[dict[str, Any]]:
        """Drain and return all buffered log messages."""
        with self._log_messages_lock:
            if not self._log_messages:
                return []
            messages = list(self._log_messages)
            self._log_messages.clear()
            return messages


class _BoundGradioCallbacks:
    """Proxy that runs user callbacks inside the Gradio context captured for a wrapped call."""

    def __init__(self, callbacks: object, state: _WrappedCallState, owner: "GradioWanGPSession") -> None:
        self._callbacks = callbacks
        self._state = state
        self._owner = owner

    def __getattr__(self, name: str) -> Any:
        """Return the attribute; callables are wrapped to run under the call's callback context."""
        target = getattr(self._callbacks, name)
        if not callable(target):
            return target

        @functools.wraps(target)
        def wrapped(*args, **kwargs):
            # Wait (bounded) for the Gradio event to publish its context.
            self._state.callback_context_ready.wait(timeout=30.0)
            context = self._state.callback_context
            if not isinstance(context, dict):
                return target(*args, **kwargs)
            with self._owner._push_callback_context(context):
                return target(*args, **kwargs)

        return wrapped


class WebUIQueueProbe:
    """Submit a job's tasks through the WanGP WebUI inline queue and poll until each finishes.

    Tasks are tracked by their ``client_id`` setting; tasks without one are not
    tracked. Each poll reads the shared ``gen`` state to detect admission,
    queue errors, outputs, progress/preview updates, and cancellation, and
    publishes events to the job and its callbacks.
    """

    _POLL_INTERVAL_SECONDS = 0.2
    _MISSING_OUTPUT_TIMEOUT_SECONDS = 5.0
    _QUEUE_ADMISSION_SUSPEND_NOTICE_SECONDS = 10.0
    _INLINE_QUEUE_SLOT_TIMEOUT_SECONDS = 10.0
    _CANCEL_GRACE_SECONDS = 1.0

    def __init__(self, session: WanGPSession, runtime, tasks: list[dict[str, Any]], job: SessionJob) -> None:
        self._session = session
        self._runtime = runtime
        self._tasks = tasks
        self._job = job
        self._wgp = runtime.module
        self._gen = session._state["gen"]
        self._manifest = job.webui_manifest or self._build_manifest(tasks)
        self._client_ids: list[str] = []
        self._task_index_by_client_id: dict[str, int] = {}
        self._task_id_by_client_id: dict[str, Any] = {}
        self._outputs_by_client_id: dict[str, str] = {}
        self._artifacts_by_client_id: dict[str, GeneratedArtifact] = {}
        self._errors_by_client_id: dict[str, GenerationError] = {}
        self._admitted_client_ids: set[str] = set()
        self._missing_output_since: dict[str, float] = {}
        self._last_status_text = ""
        self._last_active_client_id = ""
        self._last_progress_key: tuple[Any, ...] | None = None
        self._last_preview_key: tuple[Any, ...] | None = None
        self._active_progress_seed: Any = None
        self._cancel_issued = False
        self._cancel_requested_at: float | None = None
        self._cancel_dispatched_client_ids: set[str] = set()
        self._submitted_at = 0.0
        self._queue_wait_suspended = False
        self._logged_admitted_client_ids: set[str] = set()
        self._logged_missing_output_client_ids: set[str] = set()
        self._live_started_client_ids: set[str] = set()
        for index, task in enumerate(self._tasks, start=1):
            params = self._session._get_task_settings(task)
            client_id = str(params.get("client_id", "") or "").strip()
            if len(client_id) == 0:
                continue
            self._client_ids.append(client_id)
            self._task_index_by_client_id[client_id] = index
            self._task_id_by_client_id[client_id] = task.get("id")

    def run(self) -> GenerationResult:
        """Submit the manifest, poll until every tracked client finishes, and summarize."""
        self._submit_inline_manifest()
        while not self._all_clients_finished():
            self._poll_once()
            if self._all_clients_finished():
                break
            time.sleep(self._POLL_INTERVAL_SECONDS)
        generated_files = [self._outputs_by_client_id[client_id] for client_id in self._client_ids if client_id in self._outputs_by_client_id]
        errors = [self._errors_by_client_id[client_id] for client_id in self._client_ids if client_id in self._errors_by_client_id]
        successful_tasks = len(generated_files)
        failed_tasks = len(self._client_ids) - successful_tasks
        return GenerationResult(
            success=len(errors) == 0 and failed_tasks == 0,
            generated_files=generated_files,
            errors=errors,
            total_tasks=len(self._client_ids),
            successful_tasks=successful_tasks,
            failed_tasks=failed_tasks,
            artifacts=tuple(self._artifacts_by_client_id[client_id] for client_id in self._client_ids if client_id in self._artifacts_by_client_id),
        )

    def _submit_inline_manifest(self) -> None:
        """Place the manifest in ``gen["inline_queue"]`` and, unless deferred, trigger the WebUI load.

        If the job is cancelled while waiting for the inline slot, every client
        is marked cancelled and nothing is submitted.
        """
        self._reset_idle_state()
        self._wait_for_inline_queue_slot()
        if self._job.cancel_requested:
            for client_id in self._client_ids:
                self._register_error(client_id, "Generation was cancelled", stage="cancelled")
            return
        self._gen.setdefault("queue_errors", {})
        self._gen["inline_queue"] = copy.deepcopy(self._manifest)
        self._job._mark_webui_submission_ready()
        print(f"WanGP API queued client_ids={self._client_ids}")
        gradio_context = getattr(self._session, "_gradio_webui_context", None)
        # When deferred, the load-queue token is forwarded by the Gradio wrapper instead.
        if not isinstance(gradio_context, dict) or not gradio_context.get("defer_load_queue_trigger", False):
            self._trigger_load_queue_event()
        self._submitted_at = time.time()
        self._publish("status", "Queued in WanGP...", "on_status")

    def _trigger_load_queue_event(self) -> None:
        """Push the WebUI "load queue" event into the Gradio queue for the bound session.

        Raises:
            RuntimeError: if no Gradio context is bound, it is incomplete, or the push fails.
        """
        gradio_context = getattr(self._session, "_gradio_webui_context", None)
        if not isinstance(gradio_context, dict):
            raise RuntimeError("WanGP WebUI queue submission requires an active Gradio session context.")
        fn_index = gradio_context.get("load_queue_fn_index")
        blocks = gradio_context.get("blocks")
        request = gradio_context.get("request")
        session_hash = gradio_context.get("session_hash")
        if not isinstance(fn_index, int) or blocks is None or request is None or not session_hash:
            raise RuntimeError("WanGP WebUI queue trigger is unavailable for the current Gradio session.")
        from gradio.data_classes import PredictBodyInternal

        request = _normalize_queue_request(request)
        if getattr(blocks._queue, "server_app", None) is None and getattr(blocks, "app", None) is not None:
            blocks._queue.set_server_app(blocks.app)
        body = PredictBodyInternal(session_hash=session_hash, fn_index=fn_index, data=[None, None], request=request)
        # asyncio.run: this runs on a non-event-loop thread.
        success, error_or_event_id = asyncio.run(blocks._queue.push(body=body, request=request, username=getattr(request, "username", None)))
        if not success:
            raise RuntimeError(str(error_or_event_id))

    def _wait_for_inline_queue_slot(self) -> None:
        """Wait until ``gen["inline_queue"]`` is free; return early on cancel.

        Raises:
            RuntimeError: if the slot stays occupied past the timeout.
        """
        deadline = time.time() + self._INLINE_QUEUE_SLOT_TIMEOUT_SECONDS
        while self._gen.get("inline_queue") is not None:
            if self._job.cancel_requested:
                return
            if time.time() >= deadline:
                raise RuntimeError("WanGP inline queue bridge is busy")
            time.sleep(0.05)

    def _reset_idle_state(self) -> None:
        """Clear stale abort/status/progress/preview fields, but only when WanGP is idle."""
        if self._gen.get("in_progress", False) or list(self._gen.get("queue", []) or []):
            return
        self._gen["abort"] = False
        self._gen["resume"] = False
        self._gen["early_stop"] = False
        self._gen["early_stop_forwarded"] = False
        self._gen["status"] = ""
        self._gen["status_display"] = False
        self._gen["last_progress_args"] = None
        self._gen["progress_args"] = None
        self._gen["preview"] = None

    def _poll_once(self) -> None:
        """Run one polling pass over the shared WanGP state."""
        queue_client_ids, active_client_id = self._get_queue_snapshot()
        for client_id in queue_client_ids:
            if client_id in self._client_ids:
                self._admitted_client_ids.add(client_id)
                if client_id not in self._logged_admitted_client_ids:
                    print(f"WanGP API admitted client_id={client_id}")
                    self._logged_admitted_client_ids.add(client_id)
        if self._job.cancel_requested:
            self._request_cancel()
        if self._queue_wait_suspended and any(client_id in self._admitted_client_ids for client_id in self._client_ids):
            print("WanGP back in focus API queue resumed")
            self._queue_wait_suspended = False
        self._check_queue_errors()
        self._check_outputs(queue_client_ids)
        self._emit_live_updates(queue_client_ids, active_client_id)
        self._check_queue_admission_timeout()
        self._finalize_cancelled_clients(queue_client_ids)

    def _get_queue_snapshot(self) -> tuple[list[str], str]:
        """Return (non-empty client ids in the main queue, client id of the first queue entry)."""
        queue_client_ids: list[str] = []
        active_client_id = ""
        first_queue_task = True
        for task in list(self._gen.get("queue", []) or []):
            if not isinstance(task, dict):
                continue
            params = self._session._get_task_settings(task)
            client_id = str(params.get("client_id", "") or "").strip()
            if first_queue_task:
                # The head of the queue is treated as the actively generating task.
                active_client_id = client_id
                first_queue_task = False
            if len(client_id) == 0:
                continue
            queue_client_ids.append(client_id)
        return queue_client_ids, active_client_id

    def _check_queue_errors(self) -> None:
        """Convert ``gen["queue_errors"]`` entries ``(message, aborted)`` into client errors."""
        queue_errors = self._gen.get("queue_errors", {}) or {}
        for client_id in self._client_ids:
            if client_id in self._outputs_by_client_id or client_id in self._errors_by_client_id:
                continue
            error_tuple = queue_errors.get(client_id)
            if error_tuple is None:
                continue
            error_text = str(error_tuple[0] if len(error_tuple) > 0 else "WanGP queue error")
            aborted = bool(error_tuple[1]) if len(error_tuple) > 1 else False
            print(f"WanGP API queue error client_id={client_id} aborted={aborted} error={error_text}")
            if aborted:
                self._remove_queue_client_id(client_id)
                self._register_error(client_id, "Generation was cancelled", stage="cancelled")
            else:
                self._register_error(client_id, error_text or "WanGP queue error", stage="generation")

    def _store_artifact(self, client_id: str, path: str, artifact: Any) -> None:
        """Record a copy of *artifact* for *client_id* whose ``path`` is replaced by *path*."""
        if artifact is None:
            return
        self._artifacts_by_client_id[client_id] = GeneratedArtifact(
            path=path,
            media_type=artifact.media_type,
            client_id=artifact.client_id,
            video_tensor_uint8=artifact.video_tensor_uint8,
            video_tensor_hdr=artifact.video_tensor_hdr,
            hdr=artifact.hdr,
            audio_tensor=artifact.audio_tensor,
            audio_sampling_rate=artifact.audio_sampling_rate,
            fps=artifact.fps,
            flashvsr_continue_cache=artifact.flashvsr_continue_cache,
        )

    def _check_outputs(self, queue_client_ids: list[str]) -> None:
        """Resolve completions from API artifacts or gallery entries; time out missing outputs.

        Completion is delayed while the main queue is empty but WanGP still
        reports ``in_progress``. An admitted client that has left the queue
        without any output is failed after ``_MISSING_OUTPUT_TIMEOUT_SECONDS``.
        """
        processed = self._wgp.get_processed_queue(self._gen)
        if not isinstance(processed, tuple) or len(processed) != 4:
            return
        file_list, file_settings_list, audio_file_list, audio_file_settings_list = processed
        for client_id in self._client_ids:
            if client_id in self._outputs_by_client_id or client_id in self._errors_by_client_id:
                continue
            output_path = self._find_output_for_client(client_id, file_list, file_settings_list, audio_file_list, audio_file_settings_list)
            pending_artifact = self._session._peek_output_artifact(client_id)

            # Completion via an API artifact published by the generator.
            if pending_artifact is not None and client_id not in queue_client_ids:
                if not queue_client_ids and self._gen.get("in_progress", False):
                    if client_id not in self._logged_missing_output_client_ids:
                        print(f"WanGP API delaying completion for client_id={client_id} until main queue settles")
                        self._logged_missing_output_client_ids.add(client_id)
                    continue
                artifact = self._session._consume_output_artifact(client_id)
                resolved_output_path = str(output_path or (artifact.path if artifact is not None else "") or "").strip()
                if len(resolved_output_path) == 0:
                    self._register_error(client_id, f"Generation produced an API artifact for client_id '{client_id}' without an output path.", stage="generation")
                    continue
                self._outputs_by_client_id[client_id] = resolved_output_path
                self._store_artifact(client_id, resolved_output_path, artifact)
                self._missing_output_since.pop(client_id, None)
                self._logged_missing_output_client_ids.discard(client_id)
                print(f"WanGP API completed client_id={client_id} via artifact path={resolved_output_path}")
                payload = {"client_id": client_id, "path": resolved_output_path}
                self._publish("output", payload, "on_output")
                continue

            # Completion via a gallery entry carrying this client_id.
            if output_path is not None and client_id not in queue_client_ids:
                if not queue_client_ids and self._gen.get("in_progress", False):
                    if client_id not in self._logged_missing_output_client_ids:
                        print(f"WanGP API delaying gallery completion for client_id={client_id} until main queue settles")
                        self._logged_missing_output_client_ids.add(client_id)
                    continue
                self._outputs_by_client_id[client_id] = output_path
                artifact = self._session._consume_output_artifact(client_id)
                self._store_artifact(client_id, output_path, artifact)
                self._missing_output_since.pop(client_id, None)
                self._logged_missing_output_client_ids.discard(client_id)
                print(f"WanGP API completed client_id={client_id} via gallery path={output_path}")
                payload = {"client_id": client_id, "path": output_path}
                self._publish("output", payload, "on_output")
                continue

            # Still queued: reset the missing-output clock.
            if client_id in queue_client_ids:
                self._missing_output_since.pop(client_id, None)
                self._logged_missing_output_client_ids.discard(client_id)
                continue
            if client_id not in self._admitted_client_ids:
                continue
            started_missing_at = self._missing_output_since.setdefault(client_id, time.time())
            if client_id not in self._logged_missing_output_client_ids:
                print(f"WanGP API waiting for output client_id={client_id} queue_empty={not queue_client_ids} artifact_ready={pending_artifact is not None}")
                self._logged_missing_output_client_ids.add(client_id)
            if time.time() - started_missing_at >= self._MISSING_OUTPUT_TIMEOUT_SECONDS:
                self._register_error(
                    client_id,
                    f"Generation finished queue processing but no output with client_id '{client_id}' was found in the gallery.",
                    stage="generation",
                )

    def _emit_live_updates(self, queue_client_ids: list[str], active_client_id: str) -> None:
        """Publish de-duplicated status/progress/preview updates for the active client.

        Progress is only reported once ``gen["last_progress_args"]`` differs from
        the value snapshotted when the active client changed, so stale progress
        from a previous task is not re-emitted.
        """
        if active_client_id != self._last_active_client_id:
            self._last_active_client_id = active_client_id
            self._last_progress_key = None
            self._last_preview_key = None
            self._last_status_text = ""
            self._active_progress_seed = copy.deepcopy(self._gen.get("last_progress_args"))
        live_generation_running = bool(self._gen.get("in_progress", False))
        active_client_is_live = live_generation_running and active_client_id in self._client_ids and active_client_id not in self._outputs_by_client_id and active_client_id not in self._errors_by_client_id
        if active_client_is_live:
            self._live_started_client_ids.add(active_client_id)
            progress_args = self._gen.get("last_progress_args")
            progress_ready = progress_args != self._active_progress_seed
            progress_update = self._session._build_progress_update(progress_args, include_state_fallback=False) if progress_ready else None
            if progress_update is not None:
                if len(progress_update.status) > 0 and progress_update.status != self._last_status_text:
                    self._last_status_text = progress_update.status
                    self._publish("status", progress_update.status, "on_status")
                progress_key = (
                    active_client_id,
                    progress_update.phase,
                    progress_update.progress,
                    progress_update.current_step,
                    progress_update.total_steps,
                    progress_update.status,
                    progress_update.unit,
                )
                if progress_key != self._last_progress_key:
                    self._last_progress_key = progress_key
                    self._publish("progress", progress_update, "on_progress")
            preview_image = self._gen.get("preview")
            if preview_image is not None and progress_update is not None:
                preview_key = (active_client_id, id(preview_image), getattr(preview_image, "size", None), progress_update.progress)
                if preview_key != self._last_preview_key:
                    self._last_preview_key = preview_key
                    self._publish(
                        "preview",
                        PreviewUpdate(
                            image=preview_image,
                            phase=progress_update.phase,
                            status=progress_update.status,
                            progress=progress_update.progress,
                            current_step=progress_update.current_step,
                            total_steps=progress_update.total_steps,
                        ),
                        "on_preview",
                    )
            return
        queued_client_ids = [
            client_id
            for client_id in queue_client_ids
            if client_id in self._client_ids and client_id not in self._outputs_by_client_id and client_id not in self._errors_by_client_id
        ]
        if queued_client_ids and any(client_id not in self._live_started_client_ids for client_id in queued_client_ids):
            status_text = "Waiting in WanGP queue..."
            if status_text != self._last_status_text:
                self._last_status_text = status_text
                self._publish("status", status_text, "on_status")

    def _check_queue_admission_timeout(self) -> None:
        """Announce a suspended queue when clients stay unadmitted while WanGP is idle.

        While WanGP is busy, the admission clock is continually reset.
        """
        pending_client_ids = [
            client_id
            for client_id in self._client_ids
            if client_id not in self._outputs_by_client_id and client_id not in self._errors_by_client_id and client_id not in self._admitted_client_ids
        ]
        if not pending_client_ids:
            self._queue_wait_suspended = False
            return
        if self._gen.get("in_progress", False) or list(self._gen.get("queue", []) or []):
            self._submitted_at = time.time()
            return
        if self._submitted_at <= 0 or time.time() - self._submitted_at < self._QUEUE_ADMISSION_SUSPEND_NOTICE_SECONDS or self._queue_wait_suspended:
            return
        print("WanGP API queue suspended while waiting for Video Generator to get browser focus")
        self._publish("status", "Waiting for WanGP Video Generator to get browser focus...", "on_status")
        self._queue_wait_suspended = True

    def _finalize_cancelled_clients(self, queue_client_ids: list[str]) -> None:
        """After the cancel grace period, mark clients gone from both queues as cancelled."""
        if not self._cancel_issued or self._cancel_requested_at is None:
            return
        if time.time() - self._cancel_requested_at < self._CANCEL_GRACE_SECONDS:
            return
        for client_id in self._client_ids:
            if client_id in self._outputs_by_client_id or client_id in self._errors_by_client_id:
                continue
            if client_id in queue_client_ids or self._inline_queue_contains_client_id(client_id):
                continue
            self._register_error(client_id, "Generation was cancelled", stage="cancelled")

    def _request_cancel(self) -> None:
        """Dispatch cancellation once per unfinished client.

        A client still in the inline queue is removed from it. An admitted
        client gets an abort request.
        """
        dispatched_any = False
        for client_id in self._client_ids:
            if client_id in self._outputs_by_client_id or client_id in self._errors_by_client_id:
                continue
            if client_id in self._cancel_dispatched_client_ids:
                continue
            if self._remove_inline_queue_client_id(client_id):
                self._cancel_dispatched_client_ids.add(client_id)
                dispatched_any = True
                continue
            if client_id in self._admitted_client_ids:
                self._queue_abort_client_id(client_id)
                self._cancel_dispatched_client_ids.add(client_id)
                dispatched_any = True
        if dispatched_any:
            self._cancel_issued = True
            self._cancel_requested_at = time.time()

    def _queue_abort_client_id(self, client_id: str) -> None:
        """Route an abort through the Gradio proxy, falling back to ``gen["abort"] = True``."""
        owner = getattr(self._session, "_gradio_session_proxy", None)
        enqueue = getattr(owner, "_enqueue_abort_client_id", None)
        if not callable(enqueue) or not enqueue(self._job, client_id):
            self._gen["abort"] = True
            print("WanGP API set direct abort flag because the WebUI abort trigger was unavailable.")

    def _remove_inline_queue_client_id(self, client_id: str) -> bool:
        """Remove *client_id*'s entries from ``gen["inline_queue"]``; return True if any were removed.

        The key is dropped entirely when nothing remains.
        """
        inline_queue = self._gen.get("inline_queue")
        if inline_queue is None:
            return False
        if self._inline_item_matches_client_id(inline_queue, client_id):
            self._gen.pop("inline_queue", None)
            return True
        if isinstance(inline_queue, list):
            remaining = [item for item in inline_queue if not self._inline_item_matches_client_id(item, client_id)]
            if len(remaining) != len(inline_queue):
                if remaining:
                    self._gen["inline_queue"] = remaining
                else:
                    self._gen.pop("inline_queue", None)
                return True
        return False

    def _remove_queue_client_id(self, client_id: str) -> bool:
        """Remove *client_id*'s entries from the main queue list in place; return True if any were removed."""
        queue = self._gen.get("queue")
        if not isinstance(queue, list):
            return False
        remaining = []
        removed = False
        for item in list(queue):
            if self._inline_item_matches_client_id(item, client_id):
                removed = True
                continue
            remaining.append(item)
        if removed:
            # Mutate in place so other holders of this list see the change.
            queue[:] = remaining
            self._gen["queue"] = queue
        return removed

    def _inline_queue_contains_client_id(self, client_id: str) -> bool:
        """Return True if ``gen["inline_queue"]`` (a single item or a list) still holds *client_id*."""
        inline_queue = self._gen.get("inline_queue")
        if inline_queue is None:
            return False
        if isinstance(inline_queue, list):
            return any(self._inline_item_matches_client_id(item, client_id) for item in inline_queue)
        return self._inline_item_matches_client_id(inline_queue, client_id)

    @staticmethod
    def _inline_item_matches_client_id(item: Any, client_id: str) -> bool:
        """Match a dict by ``item["params"]["client_id"]`` or by a top-level ``item["client_id"]``."""
        if not isinstance(item, dict):
            return False
        params = item.get("params")
        if isinstance(params, dict) and str(params.get("client_id", "") or "").strip() == client_id:
            return True
        return str(item.get("client_id", "") or "").strip() == client_id

    @staticmethod
    def _find_output_for_client(
        client_id: str,
        file_list: Sequence[Any],
        file_settings_list: Sequence[Any],
        audio_file_list: Sequence[Any],
        audio_file_settings_list: Sequence[Any],
    ) -> str | None:
        """Return the resolved path of the newest file whose settings carry *client_id*, else None.

        Video/image files are searched before audio files. Each path list is
        paired with its settings list from the end, so the newest entries win.
        """
        for paths, settings_list in ((file_list, file_settings_list), (audio_file_list, audio_file_settings_list)):
            for path, settings in zip(reversed(list(paths or [])), reversed(list(settings_list or []))):
                if not isinstance(settings, dict):
                    continue
                if str(settings.get("client_id", "") or "").strip() == client_id:
                    return str(Path(path).resolve())
        return None

    def _register_error(self, client_id: str, message: str, *, stage: str) -> None:
        """Record and publish a failure for *client_id* unless it already finished."""
        if client_id in self._errors_by_client_id or client_id in self._outputs_by_client_id:
            return
        failure = GenerationError(
            message=message,
            task_index=self._task_index_by_client_id.get(client_id),
            task_id=self._task_id_by_client_id.get(client_id),
            stage=stage,
        )
        self._errors_by_client_id[client_id] = failure
        self._publish("error", failure, "on_error")

    def _publish(self, kind: str, payload: Any, callback_name: str | None = None) -> None:
        """Put an event on the job's stream and optionally invoke the named session callback."""
        self._job.events.put(kind, payload)
        if callback_name is not None:
            self._session._emit_callback(callback_name, payload, job=self._job)

    def _all_clients_finished(self) -> bool:
        """Return True once every tracked client has an output or an error.

        Uses membership rather than counts so duplicate client ids cannot keep
        the poll loop running forever.
        """
        return all(client_id in self._outputs_by_client_id or client_id in self._errors_by_client_id for client_id in self._client_ids)

    @staticmethod
    def _build_manifest(tasks: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
        """Build inline-queue entries (``id``, deep-copied ``params`` and ``plugin_data``) from *tasks*."""
        manifest = []
        for index, task in enumerate(tasks, start=1):
            params = copy.deepcopy(WanGPSession._get_task_settings(task))
            manifest.append({"id": task.get("id", index), "params": params, "plugin_data": copy.deepcopy(task.get("plugin_data", {}))})
        return manifest


def run_webui_job(session, job: SessionJob, tasks: list[dict[str, Any]]) -> None:
    """Run *tasks* for *job* through ``WebUIQueueProbe`` and publish lifecycle events.

    Any exception (including ``BaseException``) is converted into a failed
    ``GenerationResult`` with a single "runtime" error. The job's event stream
    is always closed, and the session's active job is cleared if it is *job*.
    """
    try:
        runtime = session._ensure_runtime()
        job.events.put("started", {"tasks": len(tasks), "backend": "webui_queue"})
        result = WebUIQueueProbe(session, runtime, tasks, job).run()
        job.events.put("completed", result)
        session._emit_callback("on_complete", result, job=job)
        job._set_result(result)
    except BaseException as exc:
        failure = session._make_generation_error(exc, task_index=None, task_id=None, stage="runtime")
        result = GenerationResult(
            success=False,
            generated_files=[],
            errors=[failure],
            total_tasks=len(tasks),
            successful_tasks=0,
            failed_tasks=max(1, len(tasks)),
            artifacts=(),
        )
        job.events.put("error", failure)
        session._emit_callback("on_error", failure, job=job)
        job.events.put("completed", result)
        session._emit_callback("on_complete", result, job=job)
        job._set_result(result)
    finally:
        job.events.close()
        with session._job_lock:
            if session._active_job is job:
                session._active_job = None


class GradioWanGPSession:
    """Lazily created ``WanGPSession`` wrapper bound to the current Gradio request.

    Plugin button clicks that use the API session can be split into a start
    event and a follow-up wait event. This lets the WanGP queue triggers and
    progress be forwarded through Gradio outputs while the callback runs on a
    worker thread.
    """

    def __init__(self, *, init_fn, plugin=None, state_component: Any = None, session_kwargs: dict[str, Any] | None = None) -> None:
        self._init_fn = init_fn
        self._plugin = plugin
        self._state_component = state_component
        self._session_kwargs = dict(session_kwargs or {})
        self._session_kwargs.setdefault("console_output", False)
        self._session: WanGPSession | None = None
        self._defer_load_queue_trigger = False
        self._ui_local = threading.local()
        self._wrapped_calls: dict[str, _WrappedCallState] = {}
        self._wrapped_calls_lock = threading.Lock()
        self._ui_call_component = None

    @classmethod
    def for_plugin(cls, plugin, *, init_fn, session_kwargs: dict[str, Any] | None = None):
        """Build a session for *plugin*, requesting its "state" component first."""
        plugin.request_component("state")
        return cls(init_fn=init_fn, plugin=plugin, session_kwargs=session_kwargs)

    def submit(self, source, callbacks: object | None = None) -> SessionJob:
        session = self._ensure_session()
        self._bind_gradio_context(session)
        job = session.submit(source, callbacks=self._wrap_callbacks_for_current_call(callbacks))
        self._capture_job_for_current_call(job)
        return job

    def submit_task(self, settings: dict[str, Any], callbacks: object | None = None) -> SessionJob:
        session = self._ensure_session()
        self._bind_gradio_context(session)
        job = session.submit_task(settings, callbacks=self._wrap_callbacks_for_current_call(callbacks))
        self._capture_job_for_current_call(job)
        return job

    def submit_manifest(self, settings_list: list[dict[str, Any]], callbacks: object | None = None) -> SessionJob:
        session = self._ensure_session()
        self._bind_gradio_context(session)
        job = session.submit_manifest(settings_list, callbacks=self._wrap_callbacks_for_current_call(callbacks))
        self._capture_job_for_current_call(job)
        return job

    def run(self, source, callbacks: object | None = None) -> GenerationResult:
        session = self._ensure_session()
        self._bind_gradio_context(session)
        return session.run(source, callbacks=self._wrap_callbacks_for_current_call(callbacks))

    def run_task(self, settings: dict[str, Any], callbacks: object | None = None) -> GenerationResult:
        session = self._ensure_session()
        self._bind_gradio_context(session)
        return session.run_task(settings, callbacks=self._wrap_callbacks_for_current_call(callbacks))

    def run_manifest(self, settings_list: list[dict[str, Any]], callbacks: object | None = None) -> GenerationResult:
        session = self._ensure_session()
        self._bind_gradio_context(session)
        return session.run_manifest(settings_list, callbacks=self._wrap_callbacks_for_current_call(callbacks))

    def ensure_ready(self):
        """Create the session if needed, make it ready, and return ``self``."""
        self._ensure_session().ensure_ready()
        return self

    def close(self) -> None:
        """Close and drop the underlying session, if any."""
        if self._session is None:
            return
        self._session.close()
        self._session = None

    def cancel(self) -> None:
        if self._session is not None:
            self._session.cancel()

    @contextlib.contextmanager
    def plugin_ui_context(self):
        """Temporarily patch ``gr.Button.click`` so API-session callbacks are wrapped.

        Clicks whose callback does not use the API session are passed through
        unchanged. The original ``click`` is always restored on exit.
        """
        import gradio as gr

        original_click = gr.Button.click
        if self._ui_call_component is None:
            self._ui_call_component = gr.State("")

        @functools.wraps(original_click)
        def patched_click(button, *args, **kwargs):
            fn = kwargs.get("fn")
            if fn is None and args:
                fn = args[0]
            if not callable(fn):
                return original_click(button, *args, **kwargs)
            if not self._callback_uses_api_session(fn):
                return original_click(button, *args, **kwargs)
            return self._wrap_button_click(original_click, button, *args, **kwargs)

        gr.Button.click = patched_click
        try:
            yield
        finally:
            gr.Button.click = original_click

    def __getattr__(self, name: str) -> Any:
        """Delegate public attribute lookups to the (lazily created) session."""
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._ensure_session(), name)

    def _wrap_button_click(self, original_click, button, *args, **kwargs):
        """Register the click as a start event plus a chained ``.then`` wait event.

        Both events also write the load-queue trigger, abort client id, and
        call-state components. Positional ``fn`` / ``inputs`` / ``outputs`` are
        moved into keyword arguments, because positional args are not forwarded.
        """
        fn = kwargs.get("fn")
        if fn is None and args:
            fn = args[0]
        inputs = kwargs.get("inputs")
        if inputs is None and len(args) > 1:
            inputs = args[1]
        outputs = kwargs.get("outputs")
        if outputs is None and len(args) > 2:
            outputs = args[2]
        original_outputs = self._normalize_outputs(outputs)
        explicit_show_progress = kwargs.get("show_progress")
        explicit_progress_targets = self._normalize_outputs(kwargs.get("show_progress_on")) if "show_progress_on" in kwargs else None
        load_queue_trigger = self._resolve_main_bridge_component("wangp_main_load_queue_trigger")
        abort_client_id = self._resolve_main_bridge_component("wangp_main_abort_client_id")
        call_state = self._ui_call_component
        wrapped_start = self._make_wrapped_click_start(fn, len(original_outputs))
        kwargs["fn"] = wrapped_start
        if inputs is not None:
            # Preserve inputs that were passed positionally; args is cleared below.
            kwargs["inputs"] = inputs
        kwargs["outputs"] = [*original_outputs, load_queue_trigger, abort_client_id, call_state]
        kwargs["show_progress"] = "hidden"
        args = ()
        dependency = original_click(button, *args, **kwargs)
        wait_outputs = [*original_outputs, load_queue_trigger, abort_client_id, call_state]
        progress_targets = explicit_progress_targets if explicit_progress_targets is not None else [component for component in original_outputs if hasattr(component, "_id")]

        def wait_wrapped_call(call_id):
            yield from self._wait_wrapped_call(call_id, len(original_outputs))

        then_kwargs = {"fn": wait_wrapped_call, "inputs": [call_state], "outputs": wait_outputs, "show_progress": explicit_show_progress or "full"}
        if progress_targets is not None and len(progress_targets) > 0:
            then_kwargs["show_progress_on"] = progress_targets
        dependency.then(
            **then_kwargs,
        )
        return dependency

    def _finish_wrapped_call_start(self, call_id: str, state: _WrappedCallState, output_count: int):
        """Build the start event's return value once the worker has finished.

        A pending yielded result keeps *call_id* alive so the wait event can
        drain the rest. Otherwise the call is forgotten and the final result, or
        blank outputs, is returned with an empty call id.

        Raises:
            The gradio error produced by ``_as_gradio_error`` when the worker failed.
        """
        import gradio as gr

        if state.error is not None:
            self._forget_wrapped_call(call_id)
            raise self._as_gradio_error(state.error)
        yielded_result = state.pop_yielded_result()
        if yielded_result is not _NO_YIELDED_RESULT:
            return [*self._normalize_callback_result(yielded_result, output_count), gr.skip(), gr.skip(), call_id]
        self._forget_wrapped_call(call_id)
        if not state.has_result:
            return [*self._blank_outputs(output_count), gr.skip(), gr.skip(), ""]
        return [*self._normalize_callback_result(state.result, output_count), gr.skip(), gr.skip(), ""]

    def _make_wrapped_click_start(self, fn, output_count: int):
        """Return the start handler that runs *fn* on a worker thread.

        The handler returns as soon as it has something to forward: a yielded
        result, the primary load-queue token, or the final result. The returned
        call id tells the chained wait event which call to keep draining.
        """
        import gradio as gr

        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            # Time alone can repeat for concurrent clicks on coarse clocks; add a random suffix.
            call_id = f"{time.time_ns()}-{uuid.uuid4().hex}"
            state = _WrappedCallState(output_count)
            self._remember_wrapped_call(call_id, state)
            bound_state = self._resolve_state()
            bound_context = self._capture_current_gradio_context()
            state.set_callback_context(bound_context)
            worker = threading.Thread(target=self._run_wrapped_click_worker, args=(call_id, fn, args, kwargs, bound_state, bound_context), daemon=True, name="wangp-plugin-click")
            worker.start()

            # Fast path: answer within ~0.5s if the worker produces something forwardable.
            deadline = time.time() + 0.5
            while time.time() < deadline:
                yielded_result = state.pop_yielded_result()
                if yielded_result is not _NO_YIELDED_RESULT:
                    return [*self._normalize_callback_result(yielded_result, output_count), gr.skip(), gr.skip(), call_id]
                load_queue_token = state.pop_primary_load_queue_token()
                if load_queue_token:
                    return [*self._blank_outputs(output_count), load_queue_token, gr.skip(), call_id]
                if state.job is not None:
                    state.job._webui_submission_ready.wait(timeout=0.05)
                if state.done.wait(timeout=0.05):
                    return self._finish_wrapped_call_start(call_id, state, output_count)

            # Slow path: wait until the worker either finishes or submits a job.
            # Blocking on `done` alone would keep the primary load-queue token
            # from being forwarded if the job is submitted after the fast-path window.
            while state.job is None and not state.done.wait(timeout=0.05):
                pass
            if state.job is not None:
                if not state.job._webui_submission_ready.wait(timeout=5):
                    self._forget_wrapped_call(call_id)
                    raise gr.Error("WanGP WebUI submission did not become ready in time.")
                load_queue_token = state.pop_primary_load_queue_token()
                if load_queue_token:
                    return [*self._blank_outputs(output_count), load_queue_token, gr.skip(), call_id]
            state.done.wait()
            return self._finish_wrapped_call_start(call_id, state, output_count)

        # Gradio maps inputs by signature; expose the original callback's one.
        wrapped.__signature__ = inspect.signature(fn)
        return wrapped

    def _wait_wrapped_call(self, call_id: str, output_count: int):
        """Generator for the chained wait event: drain a wrapped call until the worker finishes.

        Each pass flushes buffered logs, then forwards at most one item: the
        primary or a follow-up load-queue token, an abort client id, or a
        yielded result. It finishes with the final result (or blank outputs) and
        an empty call id. The call is always forgotten on exit.
        """
        import gradio as gr

        call_id = str(call_id or "").strip()
        if len(call_id) == 0:
            yield [*self._blank_outputs(output_count), gr.skip(), gr.skip(), ""]
            return
        state = self._get_wrapped_call(call_id)
        if state is None:
            yield [*self._blank_outputs(output_count), gr.skip(), gr.skip(), ""]
            return
        try:
            state.set_callback_context(self._capture_progress_callback_context())
            self._flush_buffered_log_messages(state)
            while True:
                self._flush_buffered_log_messages(state)
                load_queue_token = state.pop_primary_load_queue_token()
                if load_queue_token:
                    yield [*self._blank_outputs(output_count), load_queue_token, gr.skip(), call_id]
                    continue
                load_queue_token = state.pop_ready_followup_load_queue_token()
                if load_queue_token:
                    print(f"WanGP API forwarding follow-up load_queue_trigger token={load_queue_token}")
                    yield [*self._blank_outputs(output_count), load_queue_token, gr.skip(), call_id]
                    continue
                abort_client_id = state.pop_abort_client_id()
                if abort_client_id:
                    print(f"WanGP API forwarding abort_client_id={abort_client_id}")
                    yield [*self._blank_outputs(output_count), gr.skip(), abort_client_id, call_id]
                    continue
                yielded_result = state.pop_yielded_result()
                if yielded_result is not _NO_YIELDED_RESULT:
                    yield [*self._normalize_callback_result(yielded_result, output_count), gr.skip(), gr.skip(), call_id]
                    continue
                if state.done.is_set():
                    # Flush messages logged between the last flush and completion.
                    self._flush_buffered_log_messages(state)
                    if state.error is not None:
                        raise self._as_gradio_error(state.error)
                    if state.has_result:
                        yield [*self._normalize_callback_result(state.result, output_count), gr.skip(), gr.skip(), ""]
                    else:
                        yield [*self._blank_outputs(output_count), gr.skip(), gr.skip(), ""]
                    break
                time.sleep(0.05)
        finally:
            self._forget_wrapped_call(call_id)

    @staticmethod
    def _flush_buffered_log_messages(state: _WrappedCallState) -> None:
        """Send buffered log messages through ``blocks._queue.log_message`` for the call's event.

        Does nothing until the callback context carries both ``blocks`` and ``event_id``.
        """
        context = state.callback_context
        if not isinstance(context, dict):
            return
        blocks = context.get("blocks")
        event_id = context.get("event_id")
        if blocks is None or event_id is None:
            return
        for message in state.pop_log_messages():
            blocks._queue.log_message(event_id=event_id, **message)
def _run_wrapped_click_worker(self, call_id: str, fn, args, kwargs, bound_state: dict[str, Any], bound_context: dict[str, Any]) -> None: state = self._get_wrapped_call(call_id) if state is None: return _ensure_gradio_log_message_patch() self._ui_local.call_id = call_id self._ui_local.defer_load_queue_trigger = True self._ui_local.bound_state = bound_state _WRAPPED_LOG_LOCAL.owner = self _WRAPPED_LOG_LOCAL.call_id = call_id try: exec_context = dict(bound_context) if isinstance(state.callback_context, dict): exec_context.update(state.callback_context) self._ui_local.bound_gradio_context = exec_context with self._push_callback_context(exec_context): result = fn(*args, **kwargs) if inspect.isgenerator(result): iterator = iter(result) while True: try: state.push_yielded_result(next(iterator)) except StopIteration as stop: if stop.value is not None: state.set_result(stop.value) else: state.set_completed() break else: state.set_result(result) except BaseException as exc: state.set_error(exc) finally: _WRAPPED_LOG_LOCAL.owner = None _WRAPPED_LOG_LOCAL.call_id = "" self._ui_local.call_id = "" self._ui_local.defer_load_queue_trigger = False self._ui_local.bound_state = None self._ui_local.bound_gradio_context = None def _capture_job_for_current_call(self, job: SessionJob) -> None: call_id = str(getattr(self._ui_local, "call_id", "") or "").strip() if len(call_id) == 0: return job._bind_webui_owner_call(call_id) state = self._get_wrapped_call(call_id) if state is not None: if state.job is None: state.job = job else: state.add_followup_job(job) def _capture_cancelled_job(self, job: SessionJob) -> None: return def _enqueue_abort_client_id(self, job: SessionJob, client_id: str) -> bool: call_id = str(getattr(job, "webui_owner_call_id", "") or getattr(self._ui_local, "call_id", "") or "").strip() if len(call_id) == 0: return False state = self._get_wrapped_call(call_id) if state is None: return False state.queue_abort_client_id(client_id) return True def 
_remember_wrapped_call(self, call_id: str, state: _WrappedCallState) -> None: with self._wrapped_calls_lock: self._wrapped_calls[call_id] = state def _get_wrapped_call(self, call_id: str) -> _WrappedCallState | None: with self._wrapped_calls_lock: return self._wrapped_calls.get(call_id) def _forget_wrapped_call(self, call_id: str) -> None: with self._wrapped_calls_lock: self._wrapped_calls.pop(call_id, None) def _callback_uses_api_session(self, fn) -> bool: candidates: list[Any] = [] try: closure_vars = inspect.getclosurevars(fn) except Exception: closure_vars = None if closure_vars is not None: candidates.extend(closure_vars.nonlocals.values()) candidates.extend(closure_vars.globals.values()) for values in (getattr(fn, "__defaults__", None) or (), (getattr(fn, "__kwdefaults__", None) or {}).values()): candidates.extend(values) for cell in getattr(fn, "__closure__", ()) or (): try: candidates.append(cell.cell_contents) except ValueError: continue for candidate in candidates: if candidate is self or candidate is self._session: return True if isinstance(candidate, (GradioWanGPSession, WanGPSession)): return True if inspect.ismethod(candidate) and candidate.__self__ in (self, self._session): return True return False def _wrap_callbacks_for_current_call(self, callbacks: object | None) -> object | None: if callbacks is None: return None call_id = str(getattr(self._ui_local, "call_id", "") or "").strip() if len(call_id) == 0: return callbacks state = self._get_wrapped_call(call_id) if state is None: return callbacks if isinstance(callbacks, _BoundGradioCallbacks): return callbacks return _BoundGradioCallbacks(callbacks, state, self) @staticmethod def _normalize_outputs(outputs: Any) -> list[Any]: if outputs is None: return [] if isinstance(outputs, (list, tuple)): return list(outputs) return [outputs] @staticmethod def _blank_outputs(output_count: int) -> list[Any]: import gradio as gr return [gr.skip()] * output_count @staticmethod def _normalize_callback_result(result: 
Any, output_count: int) -> list[Any]: import gradio as gr if output_count <= 0: return [] if output_count == 1: return [result] if isinstance(result, tuple): normalized = list(result) elif isinstance(result, list): normalized = list(result) else: normalized = [result] if len(normalized) < output_count: normalized.extend([gr.skip()] * (output_count - len(normalized))) return normalized[:output_count] @staticmethod def _as_gradio_error(error: BaseException): import gradio as gr return error if isinstance(error, gr.Error) else gr.Error(str(error)) def _ensure_session(self) -> WanGPSession: state = self._resolve_state() if self._session is None or self._session._state is not state: session_kwargs = copy.deepcopy(self._session_kwargs) session_kwargs["webui_state"] = state self._session = WanGPSession(**session_kwargs) self._session._gradio_session_proxy = self return self._session def _resolve_state(self) -> dict[str, Any]: bound_state = getattr(self._ui_local, "bound_state", None) if isinstance(bound_state, dict): return bound_state component = self._state_component if component is None and self._plugin is not None: component = getattr(self._plugin, "state", None) state = self._resolve_live_session_state(component) if not isinstance(state, dict): state = getattr(component, "value", None) if component is not None else None if not isinstance(state, dict): raise RuntimeError("WanGP WebUI session requires access to the live Gradio state component.") return state @staticmethod def _resolve_live_session_state(component: Any) -> dict[str, Any] | None: component_id = getattr(component, "_id", None) if component_id is None: return None try: from gradio.context import LocalContext except Exception: return None try: blocks = LocalContext.blocks.get(None) request = LocalContext.request.get(None) except LookupError: return None session_hash = getattr(request, "session_hash", None) if request is not None else None state_holder = getattr(blocks, "state_holder", None) if blocks is not 
None else None if not session_hash or state_holder is None: return None try: session_state = state_holder[session_hash] state = session_state[component_id] except Exception: return None return state if isinstance(state, dict) else None def _bind_gradio_context(self, session: WanGPSession) -> None: bound_context = getattr(self._ui_local, "bound_gradio_context", None) if isinstance(bound_context, dict): session._gradio_webui_context = dict(bound_context) session._gradio_webui_context["defer_load_queue_trigger"] = self._defer_load_queue_trigger or bool(getattr(self._ui_local, "defer_load_queue_trigger", False)) return try: from gradio.context import LocalContext except Exception: raise RuntimeError("WanGP WebUI session requires an active Gradio callback context.") try: blocks = LocalContext.blocks.get(None) request_wrapper = LocalContext.request.get(None) except LookupError as exc: raise RuntimeError("WanGP WebUI session requires an active Gradio callback context.") from exc session_hash = getattr(request_wrapper, "session_hash", None) if request_wrapper is not None else None request = getattr(request_wrapper, "request", request_wrapper) if blocks is None or request is None or not session_hash: raise RuntimeError("WanGP WebUI session requires a live Gradio request with a session hash.") session._gradio_webui_context = { "blocks": blocks, "request": request, "session_hash": session_hash, "load_queue_fn_index": self._resolve_trigger_fn_index(blocks, session_hash, "load_queue_action", "change"), "abort_fn_index": self._resolve_abort_fn_index(blocks, session_hash), "defer_load_queue_trigger": self._defer_load_queue_trigger or bool(getattr(self._ui_local, "defer_load_queue_trigger", False)), } def _set_defer_load_queue_trigger(self, value: bool) -> None: self._defer_load_queue_trigger = bool(value) def _capture_current_gradio_context(self) -> dict[str, Any]: try: from gradio.context import LocalContext except Exception: raise RuntimeError("WanGP WebUI session requires an 
active Gradio callback context.") try: blocks = LocalContext.blocks.get(None) blocks_config = LocalContext.blocks_config.get(None) renderable = LocalContext.renderable.get(None) render_block = LocalContext.render_block.get(None) in_event_listener = LocalContext.in_event_listener.get(False) event_id = LocalContext.event_id.get(None) request_wrapper = LocalContext.request.get(None) progress = LocalContext.progress.get(None) except LookupError as exc: raise RuntimeError("WanGP WebUI session requires an active Gradio callback context.") from exc session_hash = getattr(request_wrapper, "session_hash", None) if request_wrapper is not None else None request = getattr(request_wrapper, "request", request_wrapper) if blocks is None or request is None or not session_hash: raise RuntimeError("WanGP WebUI session requires a live Gradio request with a session hash.") return { "blocks": blocks, "blocks_config": blocks_config, "renderable": renderable, "render_block": render_block, "in_event_listener": in_event_listener, "event_id": event_id, "request": request, "progress": progress, "session_hash": session_hash, "load_queue_fn_index": self._resolve_trigger_fn_index(blocks, session_hash, "load_queue_action", "change"), "abort_fn_index": self._resolve_abort_fn_index(blocks, session_hash), "defer_load_queue_trigger": self._defer_load_queue_trigger or bool(getattr(self._ui_local, "defer_load_queue_trigger", False)), } @staticmethod def _capture_progress_callback_context() -> dict[str, Any]: try: from gradio.context import LocalContext except Exception as exc: raise RuntimeError("WanGP progress callbacks require an active Gradio callback context.") from exc return { "blocks": LocalContext.blocks.get(None), "blocks_config": LocalContext.blocks_config.get(None), "renderable": LocalContext.renderable.get(None), "render_block": LocalContext.render_block.get(None), "in_event_listener": LocalContext.in_event_listener.get(False), "event_id": LocalContext.event_id.get(None), "request": 
LocalContext.request.get(None), "progress": LocalContext.progress.get(None), } @staticmethod @contextlib.contextmanager def _push_callback_context(context: dict[str, Any]): try: from gradio.context import LocalContext except Exception: yield return tokens = [] mapping = { LocalContext.blocks: context.get("blocks"), LocalContext.blocks_config: context.get("blocks_config"), LocalContext.renderable: context.get("renderable"), LocalContext.render_block: context.get("render_block"), LocalContext.in_event_listener: context.get("in_event_listener", False), LocalContext.event_id: context.get("event_id"), LocalContext.request: context.get("request"), LocalContext.progress: context.get("progress"), } try: for var, value in mapping.items(): tokens.append((var, var.set(value))) yield finally: for var, token in reversed(tokens): var.reset(token) @staticmethod def _resolve_main_bridge_component(elem_id: str): try: from gradio.context import Context, get_blocks_context except Exception as exc: raise RuntimeError(f"WanGP WebUI bridge component '{elem_id}' is unavailable outside the Gradio build context.") from exc blocks_context = get_blocks_context() if blocks_context is None and getattr(Context, "root_block", None) is not None: blocks_context = Context.root_block.default_config blocks = getattr(blocks_context, "blocks", None) if not isinstance(blocks, dict): raise RuntimeError(f"WanGP WebUI bridge component '{elem_id}' was not found in the current Blocks tree.") for component in blocks.values(): if getattr(component, "elem_id", None) == elem_id: return component raise RuntimeError(f"WanGP WebUI bridge component '{elem_id}' could not be resolved.") @staticmethod def _resolve_trigger_fn_index(blocks, session_hash: str, api_name: str, event_name: str) -> int: session_state = blocks.state_holder[session_hash] for block_fn in session_state.blocks_config.fns.values(): targets = getattr(block_fn, "targets", ()) or () if getattr(block_fn, "api_name", None) == api_name and any(target[1] 
== event_name for target in targets): return int(getattr(block_fn, "_id")) raise RuntimeError(f"WanGP WebUI trigger '{api_name}' was not found.") @staticmethod def _resolve_abort_fn_index(blocks, session_hash: str) -> int: session_state = blocks.state_holder[session_hash] for block_fn in session_state.blocks_config.fns.values(): targets = getattr(block_fn, "targets", ()) or () api_name = str(getattr(block_fn, "api_name", "") or "") if api_name.startswith("abort_generation") and any(target[1] == "change" for target in targets): return int(getattr(block_fn, "_id")) raise RuntimeError("WanGP WebUI abort trigger was not found.") def create_gradio_webui_session(plugin, *, init_fn, session_kwargs: dict[str, Any] | None = None) -> GradioWanGPSession: return GradioWanGPSession.for_plugin(plugin, init_fn=init_fn, session_kwargs=session_kwargs) def create_gradio_progress_callbacks(progress) -> GradioProgressCallbacks: return GradioProgressCallbacks(progress)