| |
| |
| |
| from __future__ import annotations |
|
|
| import gc |
| import math |
| import os |
| from contextlib import contextmanager |
| from dataclasses import dataclass |
| from typing import Any |
|
|
| import torch |
| import torch.nn.functional as F |
| from accelerate import init_empty_weights |
| from einops import rearrange |
| from safetensors.torch import load_file |
| from tqdm import tqdm |
|
|
| from mmgp import offload |
| from models.wan.modules.vae import WanVAE |
| from .attention_backend import log_sparse_backend, require_sparge_attention |
| from .tcdecoder import build_tcdecoder |
| from .utils import Causal_LQ4x_Proj |
| from .wan_video_dit import WanModel, precompute_freqs_cis_3d |
|
|
|
|
# Variant names accepted by FlashVSRRuntime.load. "tiny" and "tiny-long" make
# load build the TCDecoder; any other value (e.g. "full") makes it build a WanVAE.
FLASHVSR_VARIANT_TINY_LONG = "tiny-long"
FLASHVSR_VARIANT_TINY = "tiny"
FLASHVSR_VARIANT_FULL = "full"


# Default for upscale(topk_ratio=...). A value <= 0 makes upscale derive the
# ratio from the padded output area instead of using a fixed one.
FLASHVSR_TOPK_RATIO = 0.0
# Lower bound applied to the automatically derived top-k ratio.
FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO = 1.5
# Default kv_ratio of _denoise_stream_chunk (forwarded to the blocks as kv_len).
FLASHVSR_KV_CACHE_WINDOWS = 1
# Default number of trailing frames stored by _make_continue_cache.
FLASHVSR_CONTINUE_CACHE_FRAMES = 11
# Passed to init_pipe / offload.profile under the "coTenantsMap" keyword.
FLASHVSR_COTENANTS_MAP = {"lq_proj": ["transformer"]}
# When True, _save_still_image_debug_video writes the frames it is given to disk.
FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO = False
# When True, still images keep the regular 6 startup latent frames instead of 2.
FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS = False
# NOTE(review): the next three settings are not read in the part of the file
# reviewed here; presumably they configure the shifted second pass used for
# still images -- confirm at the use site.
FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION = True
FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_INPUT_SHIFT = None
FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_PERIOD = 16
# When True, upscale returns the last decoded startup frame of a still image
# rather than frame 0.
FLASHVSR_STILL_IMAGE_RETURN_WARMED_FRAME = True
# Interpolation weight used by _apply_still_image_shift_correction
# (0 keeps the base frames, 1 keeps the shifted frames).
FLASHVSR_STILL_IMAGE_SHIFT_BLEND = 0.5
# Output location (made absolute with os.path.abspath) and frame rate of the
# optional still-image debug video.
FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH = "flashvsr_still_image_debug.mp4"
FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_FPS = 4
|
|
# Keyword arguments for WanModel(**WAN_1_3B_CONFIG) in FlashVSRRuntime.load.
# "dim" // "num_heads" (1536 // 12 = 128) is also the value handed to
# precompute_freqs_cis_3d when load rebuilds dit.freqs.
# NOTE(review): the name suggests the Wan 1.3B architecture -- confirm against
# the checkpoint that is actually loaded.
WAN_1_3B_CONFIG = {
    "has_image_input": False,
    "patch_size": (1, 2, 2),
    "in_dim": 16,
    "dim": 1536,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "text_dim": 4096,
    "out_dim": 16,
    "num_heads": 12,
    "num_layers": 30,
    "eps": 1e-6,
}
|
|
|
|
| @contextmanager |
| def _default_dtype(dtype: torch.dtype): |
| previous_dtype = torch.get_default_dtype() |
| torch.set_default_dtype(dtype) |
| try: |
| yield |
| finally: |
| torch.set_default_dtype(previous_dtype) |
|
|
|
|
@dataclass
class FlashVSRPaths:
    """Checkpoint file locations consumed by ``FlashVSRRuntime.load``."""

    # Weights loaded into the WanModel transformer.
    transformer: str
    # Weights loaded into the Causal_LQ4x_Proj conditioning projector.
    lq_proj: str
    # Safetensors file; its "context" tensor becomes the prompt context.
    posi_prompt: str
    # TCDecoder weights; read only for the "tiny" / "tiny-long" variants.
    tcdecoder: str | None = None
    # Passed as vae_pth to WanVAE for every other variant.
    vae: str | None = None
|
|
|
|
def _preprocess_transformer_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Run a raw transformer checkpoint through WanModel's ``from_civitai`` converter.

    ``from_civitai`` returns a pair; only its first element (the converted
    state dict) is returned, the second is discarded.
    """
    converted, _ = WanModel.state_dict_converter().from_civitai(state_dict)
    return converted
|
|
|
|
| def _sinusoidal_embedding_1d(dim: int, position: torch.Tensor) -> torch.Tensor: |
| sinusoid = torch.outer(position.type(torch.float64), torch.pow(10000, -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(dim // 2))) |
| return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).to(position.dtype) |
|
|
|
|
| def _next_conditioning_frame_count(frame_count: int) -> int: |
| padded = max(25, frame_count + 4) |
| remainder = padded % 8 |
| if remainder != 1: |
| padded += (1 - remainder) % 8 |
| return padded |
|
|
|
|
| def _aligned_output_size(height: int, width: int, scale: float) -> tuple[int, int]: |
| target_h = max(1, int(height * scale)) |
| target_w = max(1, int(width * scale)) |
| return max(128, math.ceil(target_h / 128) * 128), max(128, math.ceil(target_w / 128) * 128) |
|
|
|
|
def _conditioning_sizes(sample: torch.Tensor, scale: float) -> tuple[int, int, int, int]:
    """Return ``(output_height, output_width, padded_output_height, padded_output_width)``.

    ``sample`` must be 4-D; its last two dimensions are read as height and
    width. The output size is the scaled size (truncated, at least 1); the
    padded size comes from ``_aligned_output_size``. A message is printed
    whenever the padded canvas differs from the output size.
    """
    _, _, height, width = sample.shape
    output_height = max(1, int(height * scale))
    output_width = max(1, int(width * scale))
    padded_output_height, padded_output_width = _aligned_output_size(height, width, scale)
    needs_padding = (padded_output_height, padded_output_width) != (output_height, output_width)
    if needs_padding:
        print(f"[FlashVSR] Edge padding output canvas {output_width}x{output_height} -> {padded_output_width}x{padded_output_height}; final crop restores {output_width}x{output_height}")
    return output_height, output_width, padded_output_height, padded_output_width
|
|
|
|
| def _prepare_conditioning_range(sample: torch.Tensor, start: int, end: int, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: |
| frames = int(sample.shape[1]) |
| pad_h = padded_output_height - output_height |
| pad_w = padded_output_width - output_width |
| frame_indices = [min(max(frame_idx, 0), frames - 1) for frame_idx in range(start, end)] |
| lq = sample[:, frame_indices] |
| if lq.dtype == torch.uint8: |
| lq = lq.float().div_(127.5).sub_(1.0) |
| else: |
| lq = lq.detach().float().clamp_(-1.0, 1.0) |
| lq = F.interpolate(lq.permute(1, 0, 2, 3).contiguous(), size=(output_height, output_width), mode="bicubic", align_corners=False) |
| if pad_h or pad_w: |
| lq = F.pad(lq, (0, pad_w, 0, pad_h), mode="replicate") |
| return lq.clamp_(-1.0, 1.0).to(dtype=dtype).permute(1, 0, 2, 3).contiguous() |
|
|
|
|
| def _pad_conditioning_frames(lq_video: torch.Tensor, target_frames: int) -> torch.Tensor: |
| missing = target_frames - lq_video.shape[2] |
| if missing <= 0: |
| return lq_video[:, :, :target_frames] |
| tail = lq_video[:, :, -1:].repeat(1, 1, missing, 1, 1) |
| return torch.cat([lq_video, tail], dim=2) |
|
|
|
|
| def _crop_output_frames(frames: torch.Tensor, height: int, width: int) -> torch.Tensor: |
| if frames.shape[-2:] == (height, width): |
| return frames |
| return frames[..., :height, :width].contiguous() |
|
|
|
|
| def _shift_spatial_replicate(tensor: torch.Tensor, shift_y: int, shift_x: int) -> torch.Tensor: |
| if shift_y == 0 and shift_x == 0: |
| return tensor.clone() |
| height, width = tensor.shape[-2:] |
| shift_y = max(1 - height, min(height - 1, int(shift_y))) |
| shift_x = max(1 - width, min(width - 1, int(shift_x))) |
| crop = tensor[..., max(0, -shift_y):height - max(0, shift_y), max(0, -shift_x):width - max(0, shift_x)] |
| return F.pad(crop, (max(0, shift_x), max(0, -shift_x), max(0, shift_y), max(0, -shift_y)), mode="replicate") |
|
|
|
|
def _apply_still_image_shift_correction(base: torch.Tensor, shifted: torch.Tensor, scale: float) -> torch.Tensor:
    """Blend ``base`` towards ``shifted`` by ``FLASHVSR_STILL_IMAGE_SHIFT_BLEND``.

    The interpolation runs in float32 on a copy, so ``base`` is left untouched.
    uint8 input is rounded and clamped to [0, 255]; any other dtype is clamped
    to [-1, 1] and cast back to ``base.dtype``. ``scale`` is not read.
    """
    weight = float(FLASHVSR_STILL_IMAGE_SHIFT_BLEND)
    blended = base.to(dtype=torch.float32, copy=True)
    blended.lerp_(shifted.to(dtype=torch.float32), weight)
    if base.dtype != torch.uint8:
        return blended.clamp_(-1.0, 1.0).to(dtype=base.dtype)
    return blended.round_().clamp_(0, 255).to(torch.uint8)
|
|
|
|
| def _shift_continue_cache(continue_cache: Any, shift_y: int, shift_x: int) -> Any: |
| if not isinstance(continue_cache, dict): |
| return continue_cache |
| tail = continue_cache.get("tail_frames") |
| if not torch.is_tensor(tail) or tail.ndim != 4: |
| return continue_cache |
| shifted_cache = dict(continue_cache) |
| shifted_cache["tail_frames"] = _shift_spatial_replicate(tail, shift_y, shift_x) |
| return shifted_cache |
|
|
|
|
| def _two_pass_shifted_continue_cache(continue_cache: Any, shift_y: int, shift_x: int) -> Any: |
| if not isinstance(continue_cache, dict): |
| return continue_cache |
| tail = continue_cache.get("tail_frames_shifted") |
| if not torch.is_tensor(tail) or tail.ndim != 4: |
| return _shift_continue_cache(continue_cache, shift_y, shift_x) |
| shifted_cache = dict(continue_cache) |
| shifted_cache["tail_frames"] = tail |
| return shifted_cache |
|
|
|
|
| def _make_two_pass_continue_cache(base_cache: Any, shifted_cache: Any, shift_y: int, shift_x: int, out_shift_y: int, out_shift_x: int) -> Any: |
| if not isinstance(base_cache, dict): |
| return base_cache |
| cache = dict(base_cache) |
| shifted_tail = shifted_cache.get("tail_frames") if isinstance(shifted_cache, dict) else None |
| if torch.is_tensor(shifted_tail) and shifted_tail.ndim == 4: |
| cache["tail_frames_shifted"] = shifted_tail.contiguous() |
| cache.update({"two_pass": True, "shift_y": shift_y, "shift_x": shift_x, "out_shift_y": out_shift_y, "out_shift_x": out_shift_x}) |
| return cache |
|
|
|
|
| def _select_still_image_frame(frames: torch.Tensor, frame_index: int) -> torch.Tensor: |
| return frames[:, :, frame_index:frame_index + 1].contiguous() if frames.ndim == 5 else frames[:, frame_index:frame_index + 1].contiguous() |
|
|
|
|
| def _decoded_frames_to_cpu(frames: torch.Tensor, frame_count: int, height: int, width: int) -> torch.Tensor: |
| frames = frames.detach()[0, :, :frame_count, :height, :width] |
| if frames.device.type == "cpu" and frames.dtype == torch.float32 and frames.is_contiguous(): |
| return frames |
| frames_cpu = torch.empty(tuple(frames.shape), dtype=torch.float32, device="cpu") |
| frames_cpu.copy_(frames) |
| return frames_cpu |
|
|
|
|
def _save_still_image_debug_video(frames: torch.Tensor) -> None:
    """Write ``frames`` to the debug video file when the debug flag is enabled.

    Does nothing unless ``FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO`` is truthy.
    This is best-effort: any failure (including a missing ``save_video``
    helper) is printed and swallowed so it cannot interrupt the caller.
    """
    if not FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO:
        return
    path = os.path.abspath(FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH)
    try:
        # Imported lazily so the dependency is only needed when debugging.
        from shared.utils.audio_video import save_video

        debug_frames = frames.detach().cpu()
        save_video(
            tensor=debug_frames,
            save_file=path,
            fps=FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_FPS,
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
            codec_type="libx264_8",
            container="mp4",
        )
        print(f"[FlashVSR] Still image debug video saved to {path} ({int(debug_frames.shape[2])} frames)")
        del debug_frames
    except Exception as exc:
        print(f"[FlashVSR] Failed to save still image debug video: {exc}")
|
|
|
|
| def _nested_tensors_to(value: Any, device: torch.device | str, dtype: torch.dtype | None = None) -> Any: |
| if torch.is_tensor(value): |
| return value.detach().to(device=device, dtype=dtype or value.dtype) |
| if isinstance(value, list): |
| return [_nested_tensors_to(item, device, dtype) for item in value] |
| return value |
|
|
|
|
| def _tcdecoder_mem_halo_latents(tcdecoder: torch.nn.Module) -> int: |
| radius = 0.0 |
| jump = 1.0 |
| decoder = tcdecoder.taehv.decoder if hasattr(tcdecoder, "taehv") else tcdecoder.decoder |
| for module in decoder: |
| if isinstance(module, torch.nn.Conv2d): |
| kernel = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else int(module.kernel_size) |
| radius += ((kernel - 1) / 2) * jump |
| elif module.__class__.__name__ == "MemBlock": |
| for submodule in module.conv: |
| if isinstance(submodule, torch.nn.Conv2d): |
| kernel = submodule.kernel_size[0] if isinstance(submodule.kernel_size, tuple) else int(submodule.kernel_size) |
| radius += ((kernel - 1) / 2) * jump |
| elif isinstance(module, torch.nn.Upsample): |
| scale = module.scale_factor[0] if isinstance(module.scale_factor, tuple) else module.scale_factor |
| jump /= float(scale or 1) |
| return max(1, int(math.ceil(radius))) |
|
|
|
|
| def _report_progress(progress_callback, phase: str, current_step: int | None = None, total_steps: int | None = None) -> None: |
| if callable(progress_callback): |
| progress_callback(phase, current_step, total_steps) |
|
|
|
|
| def _abort_requested(abort_callback) -> bool: |
| return callable(abort_callback) and abort_callback() |
|
|
|
|
| def _apply_continue_cache(frames: torch.Tensor, continue_cache: Any) -> torch.Tensor: |
| if not isinstance(continue_cache, dict): |
| return frames |
| tail = continue_cache.get("tail_frames") |
| if not torch.is_tensor(tail) or tail.ndim != 4: |
| return frames |
| if tail.shape[0] != frames.shape[0] or tail.shape[-2:] != frames.shape[-2:]: |
| return frames |
| overlap = min(int(tail.shape[1]), int(frames.shape[1])) |
| if overlap <= 0: |
| return frames |
| if frames.dtype == torch.uint8: |
| if tail.dtype != torch.uint8: |
| tail = tail.float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8) |
| frames[:, :overlap].copy_(tail[:, -overlap:].to(device=frames.device)) |
| return frames |
| if tail.dtype == torch.uint8: |
| tail = tail.to(device=frames.device, dtype=frames.dtype).div(127.5).sub(1.0) |
| else: |
| tail = tail.to(device=frames.device, dtype=frames.dtype) |
| frames[:, :overlap].copy_(tail[:, -overlap:]) |
| return frames |
|
|
|
|
def _make_continue_cache(frames: torch.Tensor, scale: float, variant: str, overlap_frames: int = FLASHVSR_CONTINUE_CACHE_FRAMES) -> dict[str, Any]:
    """Capture the trailing frames of ``frames`` for a later continuation.

    Args:
        frames: Tensor whose dimension 1 indexes frames.
        scale: Stored verbatim under ``"scale"``.
        variant: Stored verbatim under ``"variant"``.
        overlap_frames: Maximum number of trailing frames to keep. Values
            below 1 yield an empty tail.

    Returns:
        A dict with ``"tail_frames"`` (contiguous uint8 CPU tensor),
        ``"scale"`` and ``"variant"``. Non-uint8 input is mapped from [-1, 1]
        to [0, 255]; ``frames`` itself is not modified.
    """
    total = int(frames.shape[1])
    tail_len = max(0, min(overlap_frames, total))
    # Slice from an explicit start index: `frames[:, -tail_len:]` would return
    # *every* frame when tail_len is 0, because -0 == 0.
    tail = frames[:, total - tail_len:].detach().cpu()
    if tail.dtype != torch.uint8:
        # clamp/add are out-of-place, so the source frames are left untouched.
        tail = tail.float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8)
    return {"tail_frames": tail.contiguous(), "scale": scale, "variant": variant}
|
|
|
|
| def _wavelet_color_fix(frames: torch.Tensor, lq_video: torch.Tensor) -> torch.Tensor: |
| if frames.shape != lq_video[:, :, :frames.shape[2]].shape: |
| return frames |
| for start in range(0, frames.shape[2], 4): |
| end = min(start + 4, frames.shape[2]) |
| frame_chunk = frames[:, :, start:end] |
| lq_chunk = lq_video[:, :, start:end].to(device=frames.device, dtype=frames.dtype) |
| mean_frames = frame_chunk.mean(dim=(3, 4), keepdim=True) |
| std_frames = frame_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) |
| mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True) |
| std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) |
| frame_chunk.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).clamp_(-1.0, 1.0) |
| return frames |
|
|
|
|
| def _wavelet_color_fix_from_sample(frames: torch.Tensor, sample: torch.Tensor, scale: float, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int) -> torch.Tensor: |
| step = 1 if frames.dtype == torch.uint8 else 4 |
| for start in range(0, min(int(frames.shape[2]), int(sample.shape[1])), step): |
| end = min(start + step, int(frames.shape[2]), int(sample.shape[1])) |
| frame_chunk = frames[:, :, start:end] |
| if frames.dtype == torch.uint8: |
| frame_float = frame_chunk.float() |
| lq_chunk = sample[:, start:end].unsqueeze(0).to(device=frames.device, dtype=torch.float32) |
| if sample.dtype != torch.uint8: |
| lq_chunk.clamp_(-1.0, 1.0).add_(1.0).mul_(127.5) |
| mean_frames = frame_float.mean(dim=(3, 4), keepdim=True) |
| std_frames = frame_float.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) |
| mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True) |
| std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) |
| frame_float.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).round_().clamp_(0, 255) |
| frame_chunk.copy_(frame_float.to(torch.uint8)) |
| del frame_float, lq_chunk, mean_frames, std_frames, mean_lq, std_lq |
| continue |
| lq_chunk = sample[:, start:end].unsqueeze(0).to(device=frames.device, dtype=frames.dtype) |
| if sample.dtype == torch.uint8: |
| lq_chunk.div_(127.5).sub_(1.0) |
| else: |
| lq_chunk.clamp_(-1.0, 1.0) |
| mean_frames = frame_chunk.mean(dim=(3, 4), keepdim=True) |
| std_frames = frame_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) |
| mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True) |
| std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) |
| frame_chunk.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).clamp_(-1.0, 1.0) |
| del lq_chunk, mean_frames, std_frames, mean_lq, std_lq |
| return frames |
|
|
|
|
def _denoise_stream_chunk(
    dit: WanModel,
    x: torch.Tensor,
    context: torch.Tensor | None,
    lq_layer_chunks: list[list[torch.Tensor | None]],
    block_cache_k: list[torch.Tensor | None],
    block_cache_v: list[torch.Tensor | None],
    chunk_index: int,
    timestep_embed: torch.Tensor,
    timestep_mod: torch.Tensor,
    *,
    topk_ratio: float = 2.0,
    kv_ratio: float = FLASHVSR_KV_CACHE_WINDOWS,
    local_range: int = 9,
    cache_next: bool = True,
    allow_short_start: bool = False,
    abort_callback=None,
) -> tuple[torch.Tensor | None, list[torch.Tensor | None], list[torch.Tensor | None]]:
    """Run one streaming chunk of latents through every block of ``dit``.

    Args:
        dit: Transformer providing ``patchify``, ``freqs``, ``blocks``, ``head``
            and ``unpatchify``.
        x: Latents for this chunk, handed to ``dit.patchify``.
        context: Forwarded unchanged to every block.
        lq_layer_chunks: One list per conditioning chunk holding one tensor per
            block. Each tensor is added to the tokens and then replaced by
            ``None`` in the list, so the input lists are consumed.
        block_cache_k, block_cache_v: Per-block caches from the previous chunk.
            They are updated in place with the caches each block returns.
        chunk_index: 0 for the first chunk; selects the temporal frequency offset.
        timestep_embed, timestep_mod: Passed to ``dit.head`` and the blocks.
        topk_ratio: Multiplier used to derive ``topk`` from the window count.
        kv_ratio: Forwarded to the blocks as ``kv_len`` (at least 1).
        local_range, cache_next, allow_short_start: Forwarded to the blocks.
        abort_callback: Optional callable polled before every block and once
            after the loop.

    Returns:
        ``(prediction, block_cache_k, block_cache_v)``; ``prediction`` is
        ``None`` when an abort was requested.
    """
    x, (frames, height, width) = dit.patchify(x)
    # NOTE(review): presumably the (frames, height, width) attention window in
    # tokens; 128 below equals 2 * 8 * 8 -- confirm against the block implementation.
    win = (2, 8, 8)
    seqlen = frames // win[0]
    window_size = win[0] * height * width // 128
    topk = int(window_size * window_size * topk_ratio) - 1
    kv_len = max(1, int(kv_ratio))
    if chunk_index == 0:
        freqs_t = dit.freqs[0][:frames]
    else:
        # Later chunks read the temporal table from 4 + 2 * chunk_index, the
        # same offset upscale uses for its latent_start.
        start = 4 + chunk_index * 2
        freqs_t = dit.freqs[0][start:start + frames]
    # Split each complex table into (real, imag) tensors on x's device and dtype.
    freqs = tuple((freq.real.to(device=x.device, dtype=x.dtype), freq.imag.to(device=x.device, dtype=x.dtype)) for freq in (freqs_t, dit.freqs[1][:height], dit.freqs[2][:width]))
    for block_id, block in enumerate(dit.blocks):
        if _abort_requested(abort_callback):
            return None, block_cache_k, block_cache_v
        # Only the first len(lq_layer_chunks[0]) blocks receive conditioning.
        if block_id < len(lq_layer_chunks[0]):
            offset = 0
            for chunk in lq_layer_chunks:
                lq = chunk[block_id].to(x.device, dtype=x.dtype)
                next_offset = offset + lq.shape[1]
                # Chunks are laid out back to back along dimension 1 of x.
                x[:, offset:next_offset].add_(lq)
                offset = next_offset
                # Drop the reference right away so the tensor can be freed.
                chunk[block_id] = None
                del lq
        cache_refs = None
        if block_cache_k[block_id] is not None:
            cache_refs = [block_cache_k[block_id].to(x.device, dtype=x.dtype), block_cache_v[block_id].to(x.device, dtype=x.dtype)]
            # Clear the list slots so the previous caches are only held by cache_refs.
            block_cache_k[block_id] = None
            block_cache_v[block_id] = None
        # x is passed inside a list and the local name is cleared --
        # presumably so the block can release the input early; confirm in the block.
        x_ref = [x]
        x = None
        x, next_cache_k, next_cache_v = block(
            x_ref, context, timestep_mod, freqs, frames, height, width, seqlen, topk,
            block_id=block_id, kv_len=kv_len, is_stream=True,
            pre_cache_refs=cache_refs, local_range=local_range, cache_next=cache_next, allow_short_start=allow_short_start,
        )
        x_ref.clear()
        block_cache_k[block_id] = next_cache_k
        del next_cache_k
        block_cache_v[block_id] = next_cache_v
        del next_cache_v, cache_refs
    if _abort_requested(abort_callback):
        return None, block_cache_k, block_cache_v
    x = dit.head([x], timestep_embed)
    return dit.unpatchify([x], (frames, height, width)), block_cache_k, block_cache_v
|
|
|
|
| class FlashVSRRuntime: |
| def __init__(self) -> None: |
| self.variant: str | None = None |
| self.dtype = torch.bfloat16 |
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| self.dit: WanModel | None = None |
| self.lq_proj: Causal_LQ4x_Proj | None = None |
| self.tcdecoder: torch.nn.Module | None = None |
| self.vae: WanVAE | None = None |
| self.offloadobj = None |
| self.prompt_context: torch.Tensor | None = None |
| self.timestep: torch.Tensor | None = None |
| self.timestep_embed: torch.Tensor | None = None |
| self.timestep_mod: torch.Tensor | None = None |
| self.profile = None |
|
|
    def load(self, paths: FlashVSRPaths, variant: str, profile, init_pipe) -> None:
        """Build the models for ``variant`` and register them for offloading.

        Args:
            paths: Checkpoint locations (see ``FlashVSRPaths``).
            variant: Variant name; a falsy value selects ``"tiny-long"``.
            profile: Offload profile selector. It is stored on the instance and
                passed to ``init_pipe``.
            init_pipe: Callable invoked as ``init_pipe(pipe, kwargs, profile)``;
                its return value is used as ``profile_no`` for ``offload.profile``.

        Returns early, keeping the current models, when a transformer is
        already loaded for the same variant and profile. Otherwise everything
        previously loaded is released first.
        """
        # NOTE(review): presumably raises when the sparse attention backend is
        # unavailable -- confirm in attention_backend.
        require_sparge_attention()
        variant = variant or FLASHVSR_VARIANT_TINY_LONG
        if self.dit is not None and self.variant == variant and self.profile == profile:
            return
        self.release()
        self.variant = variant
        self.profile = profile
        # Instantiate without allocating real weights; they are filled in by
        # load_model_data below.
        with init_empty_weights(include_buffers=True), _default_dtype(self.dtype):
            self.dit = WanModel(**WAN_1_3B_CONFIG).eval()
            self.lq_proj = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).eval()
        # NOTE(review): method names handed to mmgp via _offload_hooks --
        # presumably so the module is resident when they run; confirm in mmgp.
        self.dit._offload_hooks = ["reinit_cross_kv"]
        self.lq_proj._offload_hooks = ["stream_forward"]
        offload.load_model_data(self.dit, paths.transformer, writable_tensors=False, preprocess_sd=_preprocess_transformer_state_dict, default_dtype=self.dtype, ignore_unused_weights=True, verboseLevel=-1)
        # freqs is recomputed after the weights are loaded, using the
        # per-head size dim // num_heads.
        self.dit.freqs = precompute_freqs_cis_3d(WAN_1_3B_CONFIG["dim"] // WAN_1_3B_CONFIG["num_heads"])
        offload.load_model_data(self.lq_proj, paths.lq_proj, writable_tensors=False, default_dtype=self.dtype, verboseLevel=-1)
        self.dit.requires_grad_(False)
        self.lq_proj.requires_grad_(False)
        self.prompt_context = load_file(paths.posi_prompt, device="cpu")["context"].to(self.dtype)
        pipe = {"transformer": self.dit, "lq_proj": self.lq_proj}
        if variant in (FLASHVSR_VARIANT_TINY, FLASHVSR_VARIANT_TINY_LONG):
            # Tiny variants decode with the TCDecoder.
            self.tcdecoder = build_tcdecoder(new_channels=[512, 256, 128, 128], device="cpu", dtype=self.dtype, new_latent_channels=16 + 768).eval()
            self.tcdecoder._offload_hooks = ["decode_video"]
            offload.load_model_data(self.tcdecoder, paths.tcdecoder, writable_tensors=False, default_dtype=self.dtype, ignore_unused_weights=True, verboseLevel=-1)
            self.tcdecoder.requires_grad_(False)
            pipe["tcdecoder"] = self.tcdecoder
        else:
            # Every other variant decodes with the Wan VAE.
            self.vae = WanVAE(vae_pth=paths.vae, dtype=self.dtype, upsampler_factor=1, device="cpu")
            self.vae.device = self.device
            self.vae.model.requires_grad_(False)
            pipe["vae"] = self.vae.model
        kwargs = {"coTenantsMap": FLASHVSR_COTENANTS_MAP}
        profile_no = init_pipe(pipe, kwargs, profile)
        self.offloadobj = offload.profile(pipe, profile_no=profile_no, quantizeTransformer=False, convertWeightsFloatTo=self.dtype, verboseLevel=-1, **kwargs)
        log_sparse_backend()
|
|
| def _prepare_run_state(self) -> None: |
| if self.device.type != "cuda": |
| raise RuntimeError("FlashVSR requires CUDA.") |
| context = self.prompt_context.to(self.device, dtype=self.dtype) |
| self.dit.reinit_cross_kv(context) |
| self.timestep = torch.tensor([1000.0], device=self.device, dtype=self.dtype) |
| self.timestep_embed = self.dit.time_embedding(_sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep)) |
| self.timestep_mod = self.dit.time_projection(self.timestep_embed).unflatten(1, (6, self.dit.dim)) |
|
|
| def _clear_runtime_caches(self) -> None: |
| if self.dit is not None: |
| self.dit.clear_cross_kv() |
| if self.lq_proj is not None: |
| self.lq_proj.clear_cache() |
| if self.tcdecoder is not None: |
| self.tcdecoder.clean_mem() |
| if self.vae is not None: |
| self.vae.model.clear_cache() |
| self.timestep = None |
| self.timestep_embed = None |
| self.timestep_mod = None |
|
|
| def _unload_mmgp(self) -> None: |
| self._clear_runtime_caches() |
| if self.offloadobj is not None: |
| self.offloadobj.unload_all() |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
    def _decode_tcdecoder(self, latents: torch.Tensor, sample: torch.Tensor, lq_start: int, lq_end: int, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int, tile_size: int, tile_mems: dict[tuple[int, int], Any] | None, abort_callback=None, progress_callback=None, progress_step: int | None = None, progress_total: int | None = None) -> tuple[torch.Tensor | None, dict[tuple[int, int], Any] | None]:
        """Decode ``latents`` with the TCDecoder, optionally in spatial tiles.

        Args:
            latents: Latents to decode; dimensions 1 and 2 are swapped before
                they are passed to ``decode_video``.
            sample: Source clip from which the conditioning frames
                ``lq_start .. lq_end - 1`` are prepared.
            output_height, output_width: Size of the useful output region.
            padded_output_height, padded_output_width: Size of the padded canvas.
            tile_size: Tile edge in pixels. Values <= 0, or a canvas that fits in
                a single tile, select the untiled path.
            tile_mems: Decoder memory saved per tile by an earlier call, keyed by
                the tile's latent origin ``(y0, x0)``. ``None`` starts fresh.
            abort_callback: Optional callable polled before every tile.
            progress_callback, progress_step, progress_total: Progress reporting.

        Returns:
            ``(frames, tile_mems)``. The untiled path returns the decoder output
            directly. The tiled path returns a float32 CPU tensor of size
            ``output_height`` x ``output_width`` and the updated tile memories.
            ``frames`` is ``None`` when an abort was requested.

        Raises:
            RuntimeError: If no TCDecoder is loaded.
        """
        if self.tcdecoder is None:
            raise RuntimeError("FlashVSR tiny variants require TCDecoder.")
        _report_progress(progress_callback, "TCDecoder Decoding", progress_step, progress_total)
        tile_size = int(tile_size or 0)
        cur_lq = _prepare_conditioning_range(sample, lq_start, lq_end, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0)
        if tile_size <= 0 or (padded_output_height <= tile_size and padded_output_width <= tile_size):
            # Untiled path: decode the whole canvas in one call.
            cur_lq = cur_lq.to(self.device, dtype=self.dtype)
            # NOTE(review): mul_(2).sub_(1) presumably maps a [0, 1] decoder
            # output to [-1, 1] -- confirm in the decoder.
            frames = self.tcdecoder.decode_video(latents.transpose(1, 2), parallel=False, show_progress_bar=False, cond=cur_lq).transpose(1, 2).mul_(2).sub_(1)
            del cur_lq
            _report_progress(progress_callback, "TCDecoder Decoding", progress_step + 1 if progress_step is not None else None, progress_total)
            return frames, tile_mems

        # Tiled path. Coordinates named latent_* are in latent units; the
        # factor 8 converts them to pixels.
        halo = _tcdecoder_mem_halo_latents(self.tcdecoder)
        latent_tile = max(1, tile_size // 8)
        latent_height = padded_output_height // 8
        latent_width = padded_output_width // 8
        tile_mems = {} if tile_mems is None else tile_mems
        frames_out = None
        for latent_y0 in range(0, latent_height, latent_tile):
            latent_y1 = min(latent_y0 + latent_tile, latent_height)
            write_y0, write_y1 = latent_y0 * 8, min(latent_y1 * 8, output_height)
            # Skip tiles that lie entirely inside the bottom padding.
            if write_y1 <= write_y0:
                continue
            # Decode a region enlarged by the halo; crop_y0 is where the tile
            # itself starts inside that region, in pixels.
            expanded_y0, expanded_y1 = max(0, latent_y0 - halo), min(latent_height, latent_y1 + halo)
            crop_y0 = (latent_y0 - expanded_y0) * 8
            for latent_x0 in range(0, latent_width, latent_tile):
                if _abort_requested(abort_callback):
                    del cur_lq
                    return None, tile_mems
                latent_x1 = min(latent_x0 + latent_tile, latent_width)
                write_x0, write_x1 = latent_x0 * 8, min(latent_x1 * 8, output_width)
                # Skip tiles that lie entirely inside the right padding.
                if write_x1 <= write_x0:
                    continue
                expanded_x0, expanded_x1 = max(0, latent_x0 - halo), min(latent_width, latent_x1 + halo)
                crop_x0 = (latent_x0 - expanded_x0) * 8
                tile_key = (latent_y0, latent_x0)
                # Restore this tile's decoder memory from the previous call, or
                # start clean when there is none.
                saved_mem = tile_mems.get(tile_key)
                if saved_mem is None:
                    self.tcdecoder.clean_mem()
                else:
                    self.tcdecoder.mem = _nested_tensors_to(saved_mem, self.device, self.dtype)
                cur_lq_tile = cur_lq[:, :, :, expanded_y0 * 8:expanded_y1 * 8, expanded_x0 * 8:expanded_x1 * 8].contiguous().to(self.device, dtype=self.dtype)
                cur_latents = latents[:, :, :, expanded_y0:expanded_y1, expanded_x0:expanded_x1].to(self.device, dtype=self.dtype)
                tile_frames = self.tcdecoder.decode_video(cur_latents.transpose(1, 2), parallel=False, show_progress_bar=False, cond=cur_lq_tile).transpose(1, 2).mul_(2).sub_(1)
                # Park the updated memory on the CPU until the next call.
                tile_mems[tile_key] = _nested_tensors_to(self.tcdecoder.mem, "cpu")
                self.tcdecoder.clean_mem()
                # Remove the halo, keeping only this tile's own pixels.
                tile_frames = tile_frames[:, :, :, crop_y0:crop_y0 + latent_y1 * 8 - latent_y0 * 8, crop_x0:crop_x0 + latent_x1 * 8 - latent_x0 * 8]
                if frames_out is None:
                    # Allocated lazily because the leading dimensions are only
                    # known after the first tile is decoded.
                    frames_out = torch.empty((tile_frames.shape[0], tile_frames.shape[1], tile_frames.shape[2], output_height, output_width), dtype=torch.float32, device="cpu")
                # Clip the tile to the unpadded output before writing it.
                tile_cpu = tile_frames[:, :, :, :write_y1 - write_y0, :write_x1 - write_x0].detach().cpu().float()
                frames_out[:, :, :, write_y0:write_y1, write_x0:write_x1].copy_(tile_cpu)
                del cur_lq_tile, cur_latents, tile_frames, tile_cpu
        del cur_lq
        _report_progress(progress_callback, "TCDecoder Decoding", progress_step + 1 if progress_step is not None else None, progress_total)
        return frames_out, tile_mems
|
|
| def release(self) -> None: |
| self._clear_runtime_caches() |
| if self.offloadobj is not None: |
| self.offloadobj.release() |
| self.offloadobj = None |
| self.dit = None |
| self.lq_proj = None |
| self.tcdecoder = None |
| self.vae = None |
| self.prompt_context = None |
| self.timestep = None |
| self.timestep_embed = None |
| self.timestep_mod = None |
| self.variant = None |
| self.profile = None |
| gc.collect() |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
| @torch.inference_mode() |
| def upscale( |
| self, |
| sample: torch.Tensor, |
| scale: float, |
| *, |
| seed: int = 0, |
| continue_cache: Any = None, |
| return_continue_cache: bool = False, |
| persistent_models: bool = False, |
| vae_tile_size: int | None = None, |
| topk_ratio: float = FLASHVSR_TOPK_RATIO, |
| still_image: bool = False, |
| abort_callback=None, |
| progress_callback=None, |
| ) -> tuple[torch.Tensor | None, dict[str, Any] | None]: |
| if self.dit is None or self.lq_proj is None: |
| raise RuntimeError("FlashVSR models are not loaded.") |
| def abort_result(): |
| self._unload_mmgp() |
| if not persistent_models: |
| self.release() |
| return None, None |
|
|
| input_frames = sample.shape[1] |
| num_frames = _next_conditioning_frame_count(input_frames) |
| output_height, output_width, padded_output_height, padded_output_width = _conditioning_sizes(sample, scale) |
| configured_topk_ratio = max(0.0, min(4.0, float(topk_ratio or 0.0))) |
| if configured_topk_ratio > 0: |
| topk_ratio = configured_topk_ratio |
| print(f"[FlashVSR] Sparse top-k ratio fixed to {topk_ratio:.3f}") |
| else: |
| raw_topk_ratio = min(2.0, 2.0 * 768 * 1280 / max(int(padded_output_height) * int(padded_output_width), 1)) |
| topk_ratio = max(raw_topk_ratio, FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO) |
| if topk_ratio != raw_topk_ratio: |
| print(f"[FlashVSR] Sparse top-k ratio adjusted to {topk_ratio:.3f} for {padded_output_width}x{padded_output_height} (minimum; raw auto {raw_topk_ratio:.3f})") |
| elif topk_ratio < 2.0: |
| print(f"[FlashVSR] Sparse top-k ratio adjusted to {topk_ratio:.3f} for {padded_output_width}x{padded_output_height}") |
| self._prepare_run_state() |
| self.lq_proj.clear_cache() |
| if self.tcdecoder is not None: |
| self.tcdecoder.clean_mem() |
| if self.vae is not None: |
| self.vae.model.clear_cache() |
| print(f"[FlashVSR] Stream KV cache windows: {max(1, int(FLASHVSR_KV_CACHE_WINDOWS))}") |
| tcdecoder_tile_size = int(vae_tile_size or 0) if self.tcdecoder is not None else 0 |
| tcdecoder_tile_mems = None |
| if self.tcdecoder is not None: |
| if tcdecoder_tile_size > 0 and (padded_output_height > tcdecoder_tile_size or padded_output_width > tcdecoder_tile_size): |
| print(f"[FlashVSR] TCDecoder spatial tiling policy: tile_size={tcdecoder_tile_size}px, halo={_tcdecoder_mem_halo_latents(self.tcdecoder) * 8}px") |
| tcdecoder_tile_mems = {} |
| else: |
| print("[FlashVSR] TCDecoder spatial tiling policy: tile_size=0px") |
| generator = torch.Generator(device="cpu").manual_seed(0 if seed is None or seed < 0 else int(seed)) |
| still_image = bool(still_image and input_frames == 1) |
| self.lq_proj.shift_start_prefix = still_image |
| optimize_still_image = still_image and not FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS |
| first_chunk_latent_frames = 2 if optimize_still_image else 6 |
| first_chunk_lq_steps = first_chunk_latent_frames + 1 if optimize_still_image else 7 |
| still_debug_frame_count = (first_chunk_latent_frames - 1) * 4 + 1 |
| still_output_frame = still_debug_frame_count - 1 if still_image and FLASHVSR_STILL_IMAGE_RETURN_WARMED_FRAME else 0 |
| if optimize_still_image: |
| print(f"[FlashVSR] Still image mode: denoising {first_chunk_latent_frames} startup latent frames instead of 6; returning decoded frame {still_output_frame}") |
| elif still_image and FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS: |
| print(f"[FlashVSR] Still image debug mode: image optimizations disabled; denoising original 6 startup latent frames; returning decoded frame {still_output_frame}") |
| latent_frame_count = first_chunk_latent_frames if still_image else (num_frames - 1) // 4 |
| latents = torch.empty((1, 16, latent_frame_count, padded_output_height // 8, padded_output_width // 8), device="cpu", dtype=self.dtype) |
| latents.normal_(generator=generator) |
| process_total = (num_frames - 1) // 8 - 2 |
| pre_cache_k = [None] * len(self.dit.blocks) |
| pre_cache_v = [None] * len(self.dit.blocks) |
| frames_out = None |
| frames_cursor = 0 |
| lq_pre_idx = 0 |
| lq_cur_idx = 0 |
| _report_progress(progress_callback, "Denoising", 0, process_total) |
| for process_idx in tqdm(range(process_total), desc="FlashVSR"): |
| if _abort_requested(abort_callback): |
| return abort_result() |
| lq_layer_chunks = [] |
| torch.cuda.empty_cache() |
| if process_idx == 0: |
| for inner_idx in range(first_chunk_lq_steps): |
| if _abort_requested(abort_callback): |
| return abort_result() |
| lq_chunk = _prepare_conditioning_range(sample, max(0, inner_idx * 4 - 3), (inner_idx + 1) * 4 - 3, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0).to(self.device, dtype=self.dtype) |
| lq_list = [lq_chunk] |
| del lq_chunk |
| cur = self.lq_proj.stream_forward(lq_list) |
| if cur is not None: |
| lq_layer_chunks.append(cur) |
| del cur |
| lq_cur_idx = 1 if optimize_still_image else 21 |
| latent_start, latent_end = 0, first_chunk_latent_frames |
| cur_latents = latents[:, :, :first_chunk_latent_frames].to(self.device, dtype=self.dtype) |
| else: |
| for inner_idx in range(2): |
| if _abort_requested(abort_callback): |
| return abort_result() |
| lq_start = process_idx * 8 + 17 + inner_idx * 4 |
| lq_chunk = _prepare_conditioning_range(sample, lq_start, lq_start + 4, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0).to(self.device, dtype=self.dtype) |
| lq_list = [lq_chunk] |
| del lq_chunk |
| cur = self.lq_proj.stream_forward(lq_list) |
| if cur is not None: |
| lq_layer_chunks.append(cur) |
| del cur |
| lq_cur_idx = process_idx * 8 + 21 |
| latent_start, latent_end = 4 + process_idx * 2, 6 + process_idx * 2 |
| cur_latents = latents[:, :, latent_start:latent_end].to(self.device, dtype=self.dtype) |
| torch.cuda.empty_cache() |
|
|
| noise_pred, pre_cache_k, pre_cache_v = _denoise_stream_chunk( |
| self.dit, cur_latents, None, lq_layer_chunks, pre_cache_k, pre_cache_v, process_idx, |
| self.timestep_embed, self.timestep_mod, topk_ratio=topk_ratio, cache_next=process_idx + 1 < process_total, allow_short_start=optimize_still_image and process_idx == 0, abort_callback=abort_callback, |
| ) |
| if noise_pred is None: |
| return abort_result() |
| cur_latents = cur_latents - noise_pred |
| _report_progress(progress_callback, "Denoising", process_idx + 1, process_total) |
| if self.variant == FLASHVSR_VARIANT_TINY_LONG: |
| save_still_debug_video = still_image and frames_cursor == 0 and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO |
| decode_latents = cur_latents if still_image and frames_cursor == 0 else cur_latents |
| decode_lq_cur_idx = still_debug_frame_count if still_image and frames_cursor == 0 else lq_cur_idx |
| cur_frames, tcdecoder_tile_mems = self._decode_tcdecoder(decode_latents, sample, lq_pre_idx, decode_lq_cur_idx, output_height, output_width, padded_output_height, padded_output_width, tcdecoder_tile_size, tcdecoder_tile_mems, abort_callback=abort_callback, progress_callback=progress_callback, progress_step=process_idx, progress_total=process_total) |
| if cur_frames is None: |
| return abort_result() |
| cur_frames = _crop_output_frames(cur_frames.detach().cpu(), output_height, output_width) |
| if save_still_debug_video: |
| _save_still_image_debug_video(cur_frames) |
| if still_image and frames_cursor == 0: |
| cur_frames = _select_still_image_frame(cur_frames, still_output_frame) |
| copy_frames = min(int(cur_frames.shape[2]), input_frames - frames_cursor) |
| if copy_frames > 0: |
| if frames_out is None: |
| frames_out = torch.empty((cur_frames.shape[0], cur_frames.shape[1], input_frames, output_height, output_width), dtype=torch.float32, device="cpu") |
| frames_out[:, :, frames_cursor:frames_cursor + copy_frames].copy_(cur_frames[:, :, :copy_frames].float()) |
| frames_cursor += copy_frames |
| lq_pre_idx = lq_cur_idx |
| del cur_frames |
| else: |
| latents[:, :, latent_start:latent_end].copy_(cur_latents.detach().cpu()) |
| lq_layer_chunks = None |
| self.lq_proj.clear_cache() |
| pre_cache_k = pre_cache_v = None |
| self.dit.clear_cross_kv() |
| gc.collect() |
| if self.variant == FLASHVSR_VARIANT_TINY_LONG: |
| frames = frames_out |
| else: |
| if self.variant == FLASHVSR_VARIANT_TINY: |
| if _abort_requested(abort_callback): |
| return abort_result() |
| self.tcdecoder.clean_mem() |
| frames_out = None |
| frames_cursor = 0 |
| lq_pre_idx = 0 |
| for decode_idx in range(process_total): |
| if _abort_requested(abort_callback): |
| return abort_result() |
| if decode_idx == 0: |
| lq_cur_idx = 1 if optimize_still_image else 21 |
| latent_start, latent_end = 0, first_chunk_latent_frames |
| else: |
| lq_cur_idx = decode_idx * 8 + 21 |
| latent_start, latent_end = 4 + decode_idx * 2, 6 + decode_idx * 2 |
| cur_latents = latents[:, :, latent_start:latent_end].to(self.device, dtype=self.dtype) |
| save_still_debug_video = still_image and frames_cursor == 0 and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO |
| decode_latents = cur_latents if still_image and frames_cursor == 0 else cur_latents |
| decode_lq_cur_idx = still_debug_frame_count if still_image and frames_cursor == 0 else lq_cur_idx |
| cur_frames, tcdecoder_tile_mems = self._decode_tcdecoder(decode_latents, sample, lq_pre_idx, decode_lq_cur_idx, output_height, output_width, padded_output_height, padded_output_width, tcdecoder_tile_size, tcdecoder_tile_mems, abort_callback=abort_callback, progress_callback=progress_callback, progress_step=decode_idx, progress_total=process_total) |
| if cur_frames is None: |
| return abort_result() |
| cur_frames = _crop_output_frames(cur_frames.detach().cpu(), output_height, output_width) |
| if save_still_debug_video: |
| _save_still_image_debug_video(cur_frames) |
| if still_image and frames_cursor == 0: |
| cur_frames = _select_still_image_frame(cur_frames, still_output_frame) |
| copy_frames = min(int(cur_frames.shape[2]), input_frames - frames_cursor) |
| if copy_frames > 0: |
| if frames_out is None: |
| frames_out = torch.empty((cur_frames.shape[0], cur_frames.shape[1], input_frames, output_height, output_width), dtype=torch.float32, device="cpu") |
| frames_out[:, :, frames_cursor:frames_cursor + copy_frames].copy_(cur_frames[:, :, :copy_frames].float()) |
| frames_cursor += copy_frames |
| lq_pre_idx = lq_cur_idx |
| del cur_latents, cur_frames |
| frames = frames_out |
| else: |
| if _abort_requested(abort_callback): |
| return abort_result() |
| _report_progress(progress_callback, "VAE Decoding") |
| if self.vae is None: |
| raise RuntimeError("FlashVSR full variant requires the Wan VAE.") |
| vae_tile_size = int(vae_tile_size or 0) |
| print(f"[FlashVSR] Wan VAE tiling policy: tile_size={vae_tile_size}px") |
| save_still_debug_video = still_image and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO |
| decode_latents = latents[0, :, :first_chunk_latent_frames].contiguous() if still_image else latents[0] |
| frames = self.vae.decode_to_cpu_uint8([decode_latents], vae_tile_size, target_frames=None if save_still_debug_video else 1 if still_image else input_frames, target_height=output_height, target_width=output_width, frame_start=0 if save_still_debug_video or not still_image else still_output_frame)[0] |
| if save_still_debug_video: |
| _save_still_image_debug_video(frames) |
| if still_image: |
| frames = _select_still_image_frame(frames, still_output_frame if save_still_debug_video else 0) |
| if self.tcdecoder is not None: |
| self.tcdecoder.clean_mem() |
| if self.vae is not None: |
| self.vae.model.clear_cache() |
| latents = frames_out = pre_cache_k = pre_cache_v = tcdecoder_tile_mems = None |
| noise_pred = cur_latents = lq_layer_chunks = None |
| lq_chunk = cur = cur_lq = cur_frames = None |
| if torch.is_tensor(frames) and frames.dtype == torch.uint8 and frames.ndim == 4: |
| if frames.shape[1:] != (input_frames, output_height, output_width): |
| frames = frames[:, :input_frames, :output_height, :output_width].contiguous() |
| else: |
| decoded_frames = frames |
| frames = _decoded_frames_to_cpu(decoded_frames, input_frames, output_height, output_width) |
| del decoded_frames |
| gc.collect() |
| _report_progress(progress_callback, "Color Correction") |
| _wavelet_color_fix_from_sample(frames.unsqueeze(0), sample, scale, output_height, output_width, output_height, output_width) |
| if frames.dtype != torch.uint8: |
| frames.clamp_(-1.0, 1.0) |
| frames = _apply_continue_cache(frames, continue_cache) |
| cache = _make_continue_cache(frames, scale, self.variant) if return_continue_cache else None |
| sample = None |
| self._unload_mmgp() |
| if not persistent_models: |
| self.release() |
| return frames, cache |
|
|
|
|
# Module-level FlashVSRRuntime instance, constructed once at import time.
# The module-level entry points in this file (upscale_video, release_models)
# all operate on this single shared object rather than creating their own.
| _RUNTIME = FlashVSRRuntime() |
|
|
|
|
def upscale_video(
    sample: torch.Tensor,
    scale: float,
    paths: FlashVSRPaths,
    *,
    variant: str = FLASHVSR_VARIANT_TINY_LONG,
    seed: int = 0,
    continue_cache: Any = None,
    return_continue_cache: bool = False,
    persistent_models: bool = False,
    vae_tile_size: int | None = None,
    topk_ratio: float = FLASHVSR_TOPK_RATIO,
    init_pipe,
    profile,
    still_image: bool = False,
    two_pass: bool = False,
    abort_callback=None,
    progress_callback=None,
) -> tuple[torch.Tensor | None, dict[str, Any] | None]:
    """Upscale ``sample`` with the shared FlashVSR runtime.

    The runtime is loaded first (``paths``, ``variant``, ``profile`` and
    ``init_pipe`` are handed to ``_RUNTIME.load``).  Then one of two paths runs:

    * Single pass: one ``_RUNTIME.upscale`` call whose result is returned as is.
    * Two-pass blend: taken when both ``FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION``
      and ``two_pass`` are truthy.  The sample is upscaled once unchanged and
      once after a spatial shift; the shifted output is shifted back by the
      scaled offset and combined with the base output through
      ``_apply_still_image_shift_correction``, and the two continue caches are
      combined through ``_make_two_pass_continue_cache``.  Both inner passes
      run with ``persistent_models=True``; the runtime is released afterwards
      when the caller did not ask for persistent models.

    All remaining keyword arguments (``seed``, ``continue_cache``,
    ``return_continue_cache``, ``vae_tile_size``, ``topk_ratio``,
    ``still_image``, ``abort_callback``, ``progress_callback``) are forwarded
    to ``_RUNTIME.upscale``.

    Returns:
        A ``(frames, continue_cache)`` pair.  ``(None, None)`` is returned by
        the two-pass path when either pass yields no frames (presumably an
        abort -- confirm against ``FlashVSRRuntime.upscale``).

    Raises:
        Any ``Exception`` raised while upscaling is re-raised after the
        runtime has been unloaded (persistent) or released (non-persistent).
    """

    def _free_runtime() -> None:
        # Persistent callers only get the mmgp unload; everyone else gets a
        # full release of the runtime.
        if persistent_models:
            _RUNTIME._unload_mmgp()
        else:
            _RUNTIME.release()

    _report_progress(progress_callback, "Caching")
    _RUNTIME.load(paths, variant, profile=profile, init_pipe=init_pipe)

    # Keyword arguments that are identical for every _RUNTIME.upscale call below.
    shared_kwargs = dict(
        seed=seed,
        return_continue_cache=return_continue_cache,
        vae_tile_size=vae_tile_size,
        topk_ratio=topk_ratio,
        still_image=still_image,
        abort_callback=abort_callback,
        progress_callback=progress_callback,
    )

    try:
        if not (FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION and two_pass):
            # Plain single pass; the runtime honours persistent_models itself.
            result = _RUNTIME.upscale(
                sample,
                scale,
                continue_cache=continue_cache,
                persistent_models=persistent_models,
                **shared_kwargs,
            )
        else:
            # Input-space shift: explicit override if configured, otherwise
            # half the correction period expressed in input pixels (at least
            # one pixel vertically, none horizontally).
            configured_shift = FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_INPUT_SHIFT
            if configured_shift:
                shift_y, shift_x = configured_shift
            else:
                half_period = int(round(FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_PERIOD * 0.5 / scale))
                shift_y, shift_x = max(1, half_period), 0
            # The same shift expressed in output pixels.
            out_shift_y = int(round(shift_y * scale))
            out_shift_x = int(round(shift_x * scale))
            print(f"[FlashVSR] x{scale:g} shifted two-pass blend: extra shifted pass ({shift_y}px input / {out_shift_y}px output), blend={FLASHVSR_STILL_IMAGE_SHIFT_BLEND:g}")

            base, base_cache = _RUNTIME.upscale(
                sample,
                scale,
                continue_cache=continue_cache,
                persistent_models=True,
                **shared_kwargs,
            )
            # Stays (None, None) unless both passes deliver frames.
            result = (None, None)
            if base is not None:
                shifted_sample = _shift_spatial_replicate(sample, shift_y, shift_x)
                shifted_continue_cache = _two_pass_shifted_continue_cache(continue_cache, out_shift_y, out_shift_x)
                shifted, shifted_cache = _RUNTIME.upscale(
                    shifted_sample,
                    scale,
                    continue_cache=shifted_continue_cache,
                    persistent_models=True,
                    **shared_kwargs,
                )
                if shifted is not None:
                    # Undo the shift on the second output before blending.
                    blended = _apply_still_image_shift_correction(
                        base,
                        _shift_spatial_replicate(shifted, -out_shift_y, -out_shift_x),
                        scale,
                    )
                    result = (
                        blended,
                        _make_two_pass_continue_cache(base_cache, shifted_cache, shift_y, shift_x, out_shift_y, out_shift_x),
                    )
                del shifted_sample, shifted
            del base
            # Both passes were forced persistent, so release here if needed.
            if not persistent_models:
                _RUNTIME.release()

        if result[0] is None:
            _free_runtime()
        return result
    except Exception:
        _free_runtime()
        raise
|
|
|
|
def release_models() -> None:
    """Call ``release()`` on the module-level FlashVSR runtime."""
    shared_runtime = _RUNTIME
    shared_runtime.release()
|
|