##### Enjoy this spagheti VRAM optimizations done by DeepBeepMeep ! # I am sure you are a nice person and as you copy this code, you will give me officially proper credits: # Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter from __future__ import annotations import gc import math import os from contextlib import contextmanager from dataclasses import dataclass from typing import Any import torch import torch.nn.functional as F from accelerate import init_empty_weights from einops import rearrange from safetensors.torch import load_file from tqdm import tqdm from mmgp import offload from models.wan.modules.vae import WanVAE from .attention_backend import log_sparse_backend, require_sparge_attention from .tcdecoder import build_tcdecoder from .utils import Causal_LQ4x_Proj from .wan_video_dit import WanModel, precompute_freqs_cis_3d FLASHVSR_VARIANT_TINY_LONG = "tiny-long" FLASHVSR_VARIANT_TINY = "tiny" FLASHVSR_VARIANT_FULL = "full" FLASHVSR_TOPK_RATIO = 0.0 # 0 = auto area-scaled ratio; >0 = fixed sparse attention ratio. FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO = 1.5 FLASHVSR_KV_CACHE_WINDOWS = 1 # Stream cache windows kept between denoise chunks; each window is two latent frames. FLASHVSR_CONTINUE_CACHE_FRAMES = 11 FLASHVSR_COTENANTS_MAP = {"lq_proj": ["transformer"]} FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO = False FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS = False FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION = True FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_INPUT_SHIFT = None # None = half-period output phase shift. FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_PERIOD = 16 FLASHVSR_STILL_IMAGE_RETURN_WARMED_FRAME = True FLASHVSR_STILL_IMAGE_SHIFT_BLEND = 0.5 FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH = "flashvsr_still_image_debug.mp4" FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_FPS = 4 WAN_1_3B_CONFIG = { "has_image_input": False, "patch_size": (1, 2, 2), "in_dim": 16, "dim": 1536, "ffn_dim": 8960, "freq_dim": 256, "text_dim": 4096, "out_dim": 16, "num_heads": 12, "num_layers": 30, "eps": 1e-6, } @contextmanager def _default_dtype(dtype: torch.dtype): previous_dtype = torch.get_default_dtype() torch.set_default_dtype(dtype) try: yield finally: torch.set_default_dtype(previous_dtype) @dataclass class FlashVSRPaths: transformer: str lq_proj: str posi_prompt: str tcdecoder: str | None = None vae: str | None = None def _preprocess_transformer_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: converter = WanModel.state_dict_converter() state_dict, _ = converter.from_civitai(state_dict) return state_dict def _sinusoidal_embedding_1d(dim: int, position: torch.Tensor) -> torch.Tensor: sinusoid = torch.outer(position.type(torch.float64), torch.pow(10000, -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(dim // 2))) return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).to(position.dtype) def _next_conditioning_frame_count(frame_count: int) -> int: padded = max(25, frame_count + 4) remainder = padded % 8 if remainder != 1: padded += (1 - remainder) % 8 return padded def _aligned_output_size(height: int, width: int, scale: float) -> tuple[int, int]: target_h = max(1, int(height * scale)) target_w = max(1, int(width * scale)) return max(128, math.ceil(target_h / 128) * 128), max(128, math.ceil(target_w / 128) * 128) def _conditioning_sizes(sample: torch.Tensor, scale: float) -> tuple[int, int, int, int]: _, frames, height, width = sample.shape output_height = max(1, int(height * scale)) output_width = max(1, int(width * scale)) padded_output_height, padded_output_width = _aligned_output_size(height, width, scale) pad_h = padded_output_height - output_height pad_w = padded_output_width - output_width if pad_h or pad_w: print(f"[FlashVSR] Edge padding output canvas {output_width}x{output_height} -> {padded_output_width}x{padded_output_height}; final crop restores {output_width}x{output_height}") return output_height, output_width, padded_output_height, padded_output_width def _prepare_conditioning_range(sample: torch.Tensor, start: int, end: int, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: frames = int(sample.shape[1]) pad_h = padded_output_height - output_height pad_w = padded_output_width - output_width frame_indices = [min(max(frame_idx, 0), frames - 1) for frame_idx in range(start, end)] lq = sample[:, frame_indices] if lq.dtype == torch.uint8: lq = lq.float().div_(127.5).sub_(1.0) else: lq = lq.detach().float().clamp_(-1.0, 1.0) lq = F.interpolate(lq.permute(1, 0, 2, 3).contiguous(), size=(output_height, output_width), mode="bicubic", align_corners=False) if pad_h or pad_w: lq = F.pad(lq, (0, pad_w, 0, pad_h), mode="replicate") return lq.clamp_(-1.0, 1.0).to(dtype=dtype).permute(1, 0, 2, 3).contiguous() def _pad_conditioning_frames(lq_video: torch.Tensor, target_frames: int) -> torch.Tensor: missing = target_frames - lq_video.shape[2] if missing <= 0: return lq_video[:, :, :target_frames] tail = lq_video[:, :, -1:].repeat(1, 1, missing, 1, 1) return torch.cat([lq_video, tail], dim=2) def _crop_output_frames(frames: torch.Tensor, height: int, width: int) -> torch.Tensor: if frames.shape[-2:] == (height, width): return frames return frames[..., :height, :width].contiguous() def _shift_spatial_replicate(tensor: torch.Tensor, shift_y: int, shift_x: int) -> torch.Tensor: if shift_y == 0 and shift_x == 0: return tensor.clone() height, width = tensor.shape[-2:] shift_y = max(1 - height, min(height - 1, int(shift_y))) shift_x = max(1 - width, min(width - 1, int(shift_x))) crop = tensor[..., max(0, -shift_y):height - max(0, shift_y), max(0, -shift_x):width - max(0, shift_x)] return F.pad(crop, (max(0, shift_x), max(0, -shift_x), max(0, shift_y), max(0, -shift_y)), mode="replicate") def _apply_still_image_shift_correction(base: torch.Tensor, shifted: torch.Tensor, scale: float) -> torch.Tensor: base_float = base.to(dtype=torch.float32, copy=True) corrected = base_float.lerp_(shifted.to(dtype=torch.float32), float(FLASHVSR_STILL_IMAGE_SHIFT_BLEND)) if base.dtype == torch.uint8: return corrected.round_().clamp_(0, 255).to(torch.uint8) return corrected.clamp_(-1.0, 1.0).to(dtype=base.dtype) def _shift_continue_cache(continue_cache: Any, shift_y: int, shift_x: int) -> Any: if not isinstance(continue_cache, dict): return continue_cache tail = continue_cache.get("tail_frames") if not torch.is_tensor(tail) or tail.ndim != 4: return continue_cache shifted_cache = dict(continue_cache) shifted_cache["tail_frames"] = _shift_spatial_replicate(tail, shift_y, shift_x) return shifted_cache def _two_pass_shifted_continue_cache(continue_cache: Any, shift_y: int, shift_x: int) -> Any: if not isinstance(continue_cache, dict): return continue_cache tail = continue_cache.get("tail_frames_shifted") if not torch.is_tensor(tail) or tail.ndim != 4: return _shift_continue_cache(continue_cache, shift_y, shift_x) shifted_cache = dict(continue_cache) shifted_cache["tail_frames"] = tail return shifted_cache def _make_two_pass_continue_cache(base_cache: Any, shifted_cache: Any, shift_y: int, shift_x: int, out_shift_y: int, out_shift_x: int) -> Any: if not isinstance(base_cache, dict): return base_cache cache = dict(base_cache) shifted_tail = shifted_cache.get("tail_frames") if isinstance(shifted_cache, dict) else None if torch.is_tensor(shifted_tail) and shifted_tail.ndim == 4: cache["tail_frames_shifted"] = shifted_tail.contiguous() cache.update({"two_pass": True, "shift_y": shift_y, "shift_x": shift_x, "out_shift_y": out_shift_y, "out_shift_x": out_shift_x}) return cache def _select_still_image_frame(frames: torch.Tensor, frame_index: int) -> torch.Tensor: return frames[:, :, frame_index:frame_index + 1].contiguous() if frames.ndim == 5 else frames[:, frame_index:frame_index + 1].contiguous() def _decoded_frames_to_cpu(frames: torch.Tensor, frame_count: int, height: int, width: int) -> torch.Tensor: frames = frames.detach()[0, :, :frame_count, :height, :width] if frames.device.type == "cpu" and frames.dtype == torch.float32 and frames.is_contiguous(): return frames frames_cpu = torch.empty(tuple(frames.shape), dtype=torch.float32, device="cpu") frames_cpu.copy_(frames) return frames_cpu def _save_still_image_debug_video(frames: torch.Tensor) -> None: if not FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO: return path = os.path.abspath(FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_PATH) try: from shared.utils.audio_video import save_video debug_frames = frames.detach().cpu() save_video(tensor=debug_frames, save_file=path, fps=FLASHVSR_STILL_IMAGE_DEBUG_VIDEO_FPS, nrow=1, normalize=True, value_range=(-1, 1), codec_type="libx264_8", container="mp4") print(f"[FlashVSR] Still image debug video saved to {path} ({int(debug_frames.shape[2])} frames)") del debug_frames except Exception as exc: print(f"[FlashVSR] Failed to save still image debug video: {exc}") def _nested_tensors_to(value: Any, device: torch.device | str, dtype: torch.dtype | None = None) -> Any: if torch.is_tensor(value): return value.detach().to(device=device, dtype=dtype or value.dtype) if isinstance(value, list): return [_nested_tensors_to(item, device, dtype) for item in value] return value def _tcdecoder_mem_halo_latents(tcdecoder: torch.nn.Module) -> int: radius = 0.0 jump = 1.0 decoder = tcdecoder.taehv.decoder if hasattr(tcdecoder, "taehv") else tcdecoder.decoder for module in decoder: if isinstance(module, torch.nn.Conv2d): kernel = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else int(module.kernel_size) radius += ((kernel - 1) / 2) * jump elif module.__class__.__name__ == "MemBlock": for submodule in module.conv: if isinstance(submodule, torch.nn.Conv2d): kernel = submodule.kernel_size[0] if isinstance(submodule.kernel_size, tuple) else int(submodule.kernel_size) radius += ((kernel - 1) / 2) * jump elif isinstance(module, torch.nn.Upsample): scale = module.scale_factor[0] if isinstance(module.scale_factor, tuple) else module.scale_factor jump /= float(scale or 1) return max(1, int(math.ceil(radius))) def _report_progress(progress_callback, phase: str, current_step: int | None = None, total_steps: int | None = None) -> None: if callable(progress_callback): progress_callback(phase, current_step, total_steps) def _abort_requested(abort_callback) -> bool: return callable(abort_callback) and abort_callback() def _apply_continue_cache(frames: torch.Tensor, continue_cache: Any) -> torch.Tensor: if not isinstance(continue_cache, dict): return frames tail = continue_cache.get("tail_frames") if not torch.is_tensor(tail) or tail.ndim != 4: return frames if tail.shape[0] != frames.shape[0] or tail.shape[-2:] != frames.shape[-2:]: return frames overlap = min(int(tail.shape[1]), int(frames.shape[1])) if overlap <= 0: return frames if frames.dtype == torch.uint8: if tail.dtype != torch.uint8: tail = tail.float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8) frames[:, :overlap].copy_(tail[:, -overlap:].to(device=frames.device)) return frames if tail.dtype == torch.uint8: tail = tail.to(device=frames.device, dtype=frames.dtype).div(127.5).sub(1.0) else: tail = tail.to(device=frames.device, dtype=frames.dtype) frames[:, :overlap].copy_(tail[:, -overlap:]) return frames def _make_continue_cache(frames: torch.Tensor, scale: float, variant: str, overlap_frames: int = FLASHVSR_CONTINUE_CACHE_FRAMES) -> dict[str, Any]: tail_len = min(overlap_frames, frames.shape[1]) tail = frames[:, -tail_len:].detach().cpu() if tail.dtype != torch.uint8: tail = tail.float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8) return {"tail_frames": tail.contiguous(), "scale": scale, "variant": variant} def _wavelet_color_fix(frames: torch.Tensor, lq_video: torch.Tensor) -> torch.Tensor: if frames.shape != lq_video[:, :, :frames.shape[2]].shape: return frames for start in range(0, frames.shape[2], 4): end = min(start + 4, frames.shape[2]) frame_chunk = frames[:, :, start:end] lq_chunk = lq_video[:, :, start:end].to(device=frames.device, dtype=frames.dtype) mean_frames = frame_chunk.mean(dim=(3, 4), keepdim=True) std_frames = frame_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True) std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) frame_chunk.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).clamp_(-1.0, 1.0) return frames def _wavelet_color_fix_from_sample(frames: torch.Tensor, sample: torch.Tensor, scale: float, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int) -> torch.Tensor: step = 1 if frames.dtype == torch.uint8 else 4 for start in range(0, min(int(frames.shape[2]), int(sample.shape[1])), step): end = min(start + step, int(frames.shape[2]), int(sample.shape[1])) frame_chunk = frames[:, :, start:end] if frames.dtype == torch.uint8: frame_float = frame_chunk.float() lq_chunk = sample[:, start:end].unsqueeze(0).to(device=frames.device, dtype=torch.float32) if sample.dtype != torch.uint8: lq_chunk.clamp_(-1.0, 1.0).add_(1.0).mul_(127.5) mean_frames = frame_float.mean(dim=(3, 4), keepdim=True) std_frames = frame_float.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True) std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) frame_float.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).round_().clamp_(0, 255) frame_chunk.copy_(frame_float.to(torch.uint8)) del frame_float, lq_chunk, mean_frames, std_frames, mean_lq, std_lq continue lq_chunk = sample[:, start:end].unsqueeze(0).to(device=frames.device, dtype=frames.dtype) if sample.dtype == torch.uint8: lq_chunk.div_(127.5).sub_(1.0) else: lq_chunk.clamp_(-1.0, 1.0) mean_frames = frame_chunk.mean(dim=(3, 4), keepdim=True) std_frames = frame_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) mean_lq = lq_chunk.mean(dim=(3, 4), keepdim=True) std_lq = lq_chunk.std(dim=(3, 4), keepdim=True).clamp_min_(1e-5) frame_chunk.sub_(mean_frames).div_(std_frames).mul_(std_lq).add_(mean_lq).clamp_(-1.0, 1.0) del lq_chunk, mean_frames, std_frames, mean_lq, std_lq return frames def _denoise_stream_chunk( dit: WanModel, x: torch.Tensor, context: torch.Tensor | None, lq_layer_chunks: list[list[torch.Tensor | None]], block_cache_k: list[torch.Tensor | None], block_cache_v: list[torch.Tensor | None], chunk_index: int, timestep_embed: torch.Tensor, timestep_mod: torch.Tensor, *, topk_ratio: float = 2.0, kv_ratio: float = FLASHVSR_KV_CACHE_WINDOWS, local_range: int = 9, cache_next: bool = True, allow_short_start: bool = False, abort_callback=None, ) -> tuple[torch.Tensor | None, list[torch.Tensor | None], list[torch.Tensor | None]]: x, (frames, height, width) = dit.patchify(x) win = (2, 8, 8) seqlen = frames // win[0] window_size = win[0] * height * width // 128 topk = int(window_size * window_size * topk_ratio) - 1 kv_len = max(1, int(kv_ratio)) if chunk_index == 0: freqs_t = dit.freqs[0][:frames] else: start = 4 + chunk_index * 2 freqs_t = dit.freqs[0][start:start + frames] freqs = tuple((freq.real.to(device=x.device, dtype=x.dtype), freq.imag.to(device=x.device, dtype=x.dtype)) for freq in (freqs_t, dit.freqs[1][:height], dit.freqs[2][:width])) for block_id, block in enumerate(dit.blocks): if _abort_requested(abort_callback): return None, block_cache_k, block_cache_v if block_id < len(lq_layer_chunks[0]): offset = 0 for chunk in lq_layer_chunks: lq = chunk[block_id].to(x.device, dtype=x.dtype) next_offset = offset + lq.shape[1] x[:, offset:next_offset].add_(lq) offset = next_offset chunk[block_id] = None del lq cache_refs = None if block_cache_k[block_id] is not None: cache_refs = [block_cache_k[block_id].to(x.device, dtype=x.dtype), block_cache_v[block_id].to(x.device, dtype=x.dtype)] block_cache_k[block_id] = None block_cache_v[block_id] = None x_ref = [x] x = None x, next_cache_k, next_cache_v = block( x_ref, context, timestep_mod, freqs, frames, height, width, seqlen, topk, block_id=block_id, kv_len=kv_len, is_stream=True, pre_cache_refs=cache_refs, local_range=local_range, cache_next=cache_next, allow_short_start=allow_short_start, ) x_ref.clear() block_cache_k[block_id] = next_cache_k del next_cache_k block_cache_v[block_id] = next_cache_v del next_cache_v, cache_refs if _abort_requested(abort_callback): return None, block_cache_k, block_cache_v x = dit.head([x], timestep_embed) return dit.unpatchify([x], (frames, height, width)), block_cache_k, block_cache_v class FlashVSRRuntime: def __init__(self) -> None: self.variant: str | None = None self.dtype = torch.bfloat16 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.dit: WanModel | None = None self.lq_proj: Causal_LQ4x_Proj | None = None self.tcdecoder: torch.nn.Module | None = None self.vae: WanVAE | None = None self.offloadobj = None self.prompt_context: torch.Tensor | None = None self.timestep: torch.Tensor | None = None self.timestep_embed: torch.Tensor | None = None self.timestep_mod: torch.Tensor | None = None self.profile = None def load(self, paths: FlashVSRPaths, variant: str, profile, init_pipe) -> None: require_sparge_attention() variant = variant or FLASHVSR_VARIANT_TINY_LONG if self.dit is not None and self.variant == variant and self.profile == profile: return self.release() self.variant = variant self.profile = profile with init_empty_weights(include_buffers=True), _default_dtype(self.dtype): self.dit = WanModel(**WAN_1_3B_CONFIG).eval() self.lq_proj = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).eval() self.dit._offload_hooks = ["reinit_cross_kv"] self.lq_proj._offload_hooks = ["stream_forward"] offload.load_model_data(self.dit, paths.transformer, writable_tensors=False, preprocess_sd=_preprocess_transformer_state_dict, default_dtype=self.dtype, ignore_unused_weights=True, verboseLevel=-1) self.dit.freqs = precompute_freqs_cis_3d(WAN_1_3B_CONFIG["dim"] // WAN_1_3B_CONFIG["num_heads"]) offload.load_model_data(self.lq_proj, paths.lq_proj, writable_tensors=False, default_dtype=self.dtype, verboseLevel=-1) self.dit.requires_grad_(False) self.lq_proj.requires_grad_(False) self.prompt_context = load_file(paths.posi_prompt, device="cpu")["context"].to(self.dtype) pipe = {"transformer": self.dit, "lq_proj": self.lq_proj} if variant in (FLASHVSR_VARIANT_TINY, FLASHVSR_VARIANT_TINY_LONG): self.tcdecoder = build_tcdecoder(new_channels=[512, 256, 128, 128], device="cpu", dtype=self.dtype, new_latent_channels=16 + 768).eval() self.tcdecoder._offload_hooks = ["decode_video"] offload.load_model_data(self.tcdecoder, paths.tcdecoder, writable_tensors=False, default_dtype=self.dtype, ignore_unused_weights=True, verboseLevel=-1) self.tcdecoder.requires_grad_(False) pipe["tcdecoder"] = self.tcdecoder else: self.vae = WanVAE(vae_pth=paths.vae, dtype=self.dtype, upsampler_factor=1, device="cpu") self.vae.device = self.device self.vae.model.requires_grad_(False) pipe["vae"] = self.vae.model kwargs = {"coTenantsMap": FLASHVSR_COTENANTS_MAP} profile_no = init_pipe(pipe, kwargs, profile) self.offloadobj = offload.profile(pipe, profile_no=profile_no, quantizeTransformer=False, convertWeightsFloatTo=self.dtype, verboseLevel=-1, **kwargs) log_sparse_backend() def _prepare_run_state(self) -> None: if self.device.type != "cuda": raise RuntimeError("FlashVSR requires CUDA.") context = self.prompt_context.to(self.device, dtype=self.dtype) self.dit.reinit_cross_kv(context) self.timestep = torch.tensor([1000.0], device=self.device, dtype=self.dtype) self.timestep_embed = self.dit.time_embedding(_sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep)) self.timestep_mod = self.dit.time_projection(self.timestep_embed).unflatten(1, (6, self.dit.dim)) def _clear_runtime_caches(self) -> None: if self.dit is not None: self.dit.clear_cross_kv() if self.lq_proj is not None: self.lq_proj.clear_cache() if self.tcdecoder is not None: self.tcdecoder.clean_mem() if self.vae is not None: self.vae.model.clear_cache() self.timestep = None self.timestep_embed = None self.timestep_mod = None def _unload_mmgp(self) -> None: self._clear_runtime_caches() if self.offloadobj is not None: self.offloadobj.unload_all() if torch.cuda.is_available(): torch.cuda.empty_cache() def _decode_tcdecoder(self, latents: torch.Tensor, sample: torch.Tensor, lq_start: int, lq_end: int, output_height: int, output_width: int, padded_output_height: int, padded_output_width: int, tile_size: int, tile_mems: dict[tuple[int, int], Any] | None, abort_callback=None, progress_callback=None, progress_step: int | None = None, progress_total: int | None = None) -> tuple[torch.Tensor | None, dict[tuple[int, int], Any] | None]: if self.tcdecoder is None: raise RuntimeError("FlashVSR tiny variants require TCDecoder.") _report_progress(progress_callback, "TCDecoder Decoding", progress_step, progress_total) tile_size = int(tile_size or 0) cur_lq = _prepare_conditioning_range(sample, lq_start, lq_end, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0) if tile_size <= 0 or (padded_output_height <= tile_size and padded_output_width <= tile_size): cur_lq = cur_lq.to(self.device, dtype=self.dtype) frames = self.tcdecoder.decode_video(latents.transpose(1, 2), parallel=False, show_progress_bar=False, cond=cur_lq).transpose(1, 2).mul_(2).sub_(1) del cur_lq _report_progress(progress_callback, "TCDecoder Decoding", progress_step + 1 if progress_step is not None else None, progress_total) return frames, tile_mems halo = _tcdecoder_mem_halo_latents(self.tcdecoder) latent_tile = max(1, tile_size // 8) latent_height = padded_output_height // 8 latent_width = padded_output_width // 8 tile_mems = {} if tile_mems is None else tile_mems frames_out = None for latent_y0 in range(0, latent_height, latent_tile): latent_y1 = min(latent_y0 + latent_tile, latent_height) write_y0, write_y1 = latent_y0 * 8, min(latent_y1 * 8, output_height) if write_y1 <= write_y0: continue expanded_y0, expanded_y1 = max(0, latent_y0 - halo), min(latent_height, latent_y1 + halo) crop_y0 = (latent_y0 - expanded_y0) * 8 for latent_x0 in range(0, latent_width, latent_tile): if _abort_requested(abort_callback): del cur_lq return None, tile_mems latent_x1 = min(latent_x0 + latent_tile, latent_width) write_x0, write_x1 = latent_x0 * 8, min(latent_x1 * 8, output_width) if write_x1 <= write_x0: continue expanded_x0, expanded_x1 = max(0, latent_x0 - halo), min(latent_width, latent_x1 + halo) crop_x0 = (latent_x0 - expanded_x0) * 8 tile_key = (latent_y0, latent_x0) saved_mem = tile_mems.get(tile_key) if saved_mem is None: self.tcdecoder.clean_mem() else: self.tcdecoder.mem = _nested_tensors_to(saved_mem, self.device, self.dtype) cur_lq_tile = cur_lq[:, :, :, expanded_y0 * 8:expanded_y1 * 8, expanded_x0 * 8:expanded_x1 * 8].contiguous().to(self.device, dtype=self.dtype) cur_latents = latents[:, :, :, expanded_y0:expanded_y1, expanded_x0:expanded_x1].to(self.device, dtype=self.dtype) tile_frames = self.tcdecoder.decode_video(cur_latents.transpose(1, 2), parallel=False, show_progress_bar=False, cond=cur_lq_tile).transpose(1, 2).mul_(2).sub_(1) tile_mems[tile_key] = _nested_tensors_to(self.tcdecoder.mem, "cpu") self.tcdecoder.clean_mem() tile_frames = tile_frames[:, :, :, crop_y0:crop_y0 + latent_y1 * 8 - latent_y0 * 8, crop_x0:crop_x0 + latent_x1 * 8 - latent_x0 * 8] if frames_out is None: frames_out = torch.empty((tile_frames.shape[0], tile_frames.shape[1], tile_frames.shape[2], output_height, output_width), dtype=torch.float32, device="cpu") tile_cpu = tile_frames[:, :, :, :write_y1 - write_y0, :write_x1 - write_x0].detach().cpu().float() frames_out[:, :, :, write_y0:write_y1, write_x0:write_x1].copy_(tile_cpu) del cur_lq_tile, cur_latents, tile_frames, tile_cpu del cur_lq _report_progress(progress_callback, "TCDecoder Decoding", progress_step + 1 if progress_step is not None else None, progress_total) return frames_out, tile_mems def release(self) -> None: self._clear_runtime_caches() if self.offloadobj is not None: self.offloadobj.release() self.offloadobj = None self.dit = None self.lq_proj = None self.tcdecoder = None self.vae = None self.prompt_context = None self.timestep = None self.timestep_embed = None self.timestep_mod = None self.variant = None self.profile = None gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() @torch.inference_mode() def upscale( self, sample: torch.Tensor, scale: float, *, seed: int = 0, continue_cache: Any = None, return_continue_cache: bool = False, persistent_models: bool = False, vae_tile_size: int | None = None, topk_ratio: float = FLASHVSR_TOPK_RATIO, still_image: bool = False, abort_callback=None, progress_callback=None, ) -> tuple[torch.Tensor | None, dict[str, Any] | None]: if self.dit is None or self.lq_proj is None: raise RuntimeError("FlashVSR models are not loaded.") def abort_result(): self._unload_mmgp() if not persistent_models: self.release() return None, None input_frames = sample.shape[1] num_frames = _next_conditioning_frame_count(input_frames) output_height, output_width, padded_output_height, padded_output_width = _conditioning_sizes(sample, scale) configured_topk_ratio = max(0.0, min(4.0, float(topk_ratio or 0.0))) if configured_topk_ratio > 0: topk_ratio = configured_topk_ratio print(f"[FlashVSR] Sparse top-k ratio fixed to {topk_ratio:.3f}") else: raw_topk_ratio = min(2.0, 2.0 * 768 * 1280 / max(int(padded_output_height) * int(padded_output_width), 1)) topk_ratio = max(raw_topk_ratio, FLASHVSR_FULL_MIN_AUTO_TOPK_RATIO) if topk_ratio != raw_topk_ratio: print(f"[FlashVSR] Sparse top-k ratio adjusted to {topk_ratio:.3f} for {padded_output_width}x{padded_output_height} (minimum; raw auto {raw_topk_ratio:.3f})") elif topk_ratio < 2.0: print(f"[FlashVSR] Sparse top-k ratio adjusted to {topk_ratio:.3f} for {padded_output_width}x{padded_output_height}") self._prepare_run_state() self.lq_proj.clear_cache() if self.tcdecoder is not None: self.tcdecoder.clean_mem() if self.vae is not None: self.vae.model.clear_cache() print(f"[FlashVSR] Stream KV cache windows: {max(1, int(FLASHVSR_KV_CACHE_WINDOWS))}") tcdecoder_tile_size = int(vae_tile_size or 0) if self.tcdecoder is not None else 0 tcdecoder_tile_mems = None if self.tcdecoder is not None: if tcdecoder_tile_size > 0 and (padded_output_height > tcdecoder_tile_size or padded_output_width > tcdecoder_tile_size): print(f"[FlashVSR] TCDecoder spatial tiling policy: tile_size={tcdecoder_tile_size}px, halo={_tcdecoder_mem_halo_latents(self.tcdecoder) * 8}px") tcdecoder_tile_mems = {} else: print("[FlashVSR] TCDecoder spatial tiling policy: tile_size=0px") generator = torch.Generator(device="cpu").manual_seed(0 if seed is None or seed < 0 else int(seed)) still_image = bool(still_image and input_frames == 1) self.lq_proj.shift_start_prefix = still_image optimize_still_image = still_image and not FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS first_chunk_latent_frames = 2 if optimize_still_image else 6 first_chunk_lq_steps = first_chunk_latent_frames + 1 if optimize_still_image else 7 still_debug_frame_count = (first_chunk_latent_frames - 1) * 4 + 1 still_output_frame = still_debug_frame_count - 1 if still_image and FLASHVSR_STILL_IMAGE_RETURN_WARMED_FRAME else 0 if optimize_still_image: print(f"[FlashVSR] Still image mode: denoising {first_chunk_latent_frames} startup latent frames instead of 6; returning decoded frame {still_output_frame}") elif still_image and FLASHVSR_DISABLE_STILL_IMAGE_OPTIMIZATIONS: print(f"[FlashVSR] Still image debug mode: image optimizations disabled; denoising original 6 startup latent frames; returning decoded frame {still_output_frame}") latent_frame_count = first_chunk_latent_frames if still_image else (num_frames - 1) // 4 latents = torch.empty((1, 16, latent_frame_count, padded_output_height // 8, padded_output_width // 8), device="cpu", dtype=self.dtype) latents.normal_(generator=generator) process_total = (num_frames - 1) // 8 - 2 pre_cache_k = [None] * len(self.dit.blocks) pre_cache_v = [None] * len(self.dit.blocks) frames_out = None frames_cursor = 0 lq_pre_idx = 0 lq_cur_idx = 0 _report_progress(progress_callback, "Denoising", 0, process_total) for process_idx in tqdm(range(process_total), desc="FlashVSR"): if _abort_requested(abort_callback): return abort_result() lq_layer_chunks = [] torch.cuda.empty_cache() if process_idx == 0: for inner_idx in range(first_chunk_lq_steps): if _abort_requested(abort_callback): return abort_result() lq_chunk = _prepare_conditioning_range(sample, max(0, inner_idx * 4 - 3), (inner_idx + 1) * 4 - 3, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0).to(self.device, dtype=self.dtype) lq_list = [lq_chunk] del lq_chunk cur = self.lq_proj.stream_forward(lq_list) if cur is not None: lq_layer_chunks.append(cur) del cur lq_cur_idx = 1 if optimize_still_image else 21 latent_start, latent_end = 0, first_chunk_latent_frames cur_latents = latents[:, :, :first_chunk_latent_frames].to(self.device, dtype=self.dtype) else: for inner_idx in range(2): if _abort_requested(abort_callback): return abort_result() lq_start = process_idx * 8 + 17 + inner_idx * 4 lq_chunk = _prepare_conditioning_range(sample, lq_start, lq_start + 4, output_height, output_width, padded_output_height, padded_output_width, dtype=self.dtype).unsqueeze(0).to(self.device, dtype=self.dtype) lq_list = [lq_chunk] del lq_chunk cur = self.lq_proj.stream_forward(lq_list) if cur is not None: lq_layer_chunks.append(cur) del cur lq_cur_idx = process_idx * 8 + 21 latent_start, latent_end = 4 + process_idx * 2, 6 + process_idx * 2 cur_latents = latents[:, :, latent_start:latent_end].to(self.device, dtype=self.dtype) torch.cuda.empty_cache() noise_pred, pre_cache_k, pre_cache_v = _denoise_stream_chunk( self.dit, cur_latents, None, lq_layer_chunks, pre_cache_k, pre_cache_v, process_idx, self.timestep_embed, self.timestep_mod, topk_ratio=topk_ratio, cache_next=process_idx + 1 < process_total, allow_short_start=optimize_still_image and process_idx == 0, abort_callback=abort_callback, ) if noise_pred is None: return abort_result() cur_latents = cur_latents - noise_pred _report_progress(progress_callback, "Denoising", process_idx + 1, process_total) if self.variant == FLASHVSR_VARIANT_TINY_LONG: save_still_debug_video = still_image and frames_cursor == 0 and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO decode_latents = cur_latents if still_image and frames_cursor == 0 else cur_latents decode_lq_cur_idx = still_debug_frame_count if still_image and frames_cursor == 0 else lq_cur_idx cur_frames, tcdecoder_tile_mems = self._decode_tcdecoder(decode_latents, sample, lq_pre_idx, decode_lq_cur_idx, output_height, output_width, padded_output_height, padded_output_width, tcdecoder_tile_size, tcdecoder_tile_mems, abort_callback=abort_callback, progress_callback=progress_callback, progress_step=process_idx, progress_total=process_total) if cur_frames is None: return abort_result() cur_frames = _crop_output_frames(cur_frames.detach().cpu(), output_height, output_width) if save_still_debug_video: _save_still_image_debug_video(cur_frames) if still_image and frames_cursor == 0: cur_frames = _select_still_image_frame(cur_frames, still_output_frame) copy_frames = min(int(cur_frames.shape[2]), input_frames - frames_cursor) if copy_frames > 0: if frames_out is None: frames_out = torch.empty((cur_frames.shape[0], cur_frames.shape[1], input_frames, output_height, output_width), dtype=torch.float32, device="cpu") frames_out[:, :, frames_cursor:frames_cursor + copy_frames].copy_(cur_frames[:, :, :copy_frames].float()) frames_cursor += copy_frames lq_pre_idx = lq_cur_idx del cur_frames else: latents[:, :, latent_start:latent_end].copy_(cur_latents.detach().cpu()) lq_layer_chunks = None self.lq_proj.clear_cache() pre_cache_k = pre_cache_v = None self.dit.clear_cross_kv() gc.collect() if self.variant == FLASHVSR_VARIANT_TINY_LONG: frames = frames_out else: if self.variant == FLASHVSR_VARIANT_TINY: if _abort_requested(abort_callback): return abort_result() self.tcdecoder.clean_mem() frames_out = None frames_cursor = 0 lq_pre_idx = 0 for decode_idx in range(process_total): if _abort_requested(abort_callback): return abort_result() if decode_idx == 0: lq_cur_idx = 1 if optimize_still_image else 21 latent_start, latent_end = 0, first_chunk_latent_frames else: lq_cur_idx = decode_idx * 8 + 21 latent_start, latent_end = 4 + decode_idx * 2, 6 + decode_idx * 2 cur_latents = latents[:, :, latent_start:latent_end].to(self.device, dtype=self.dtype) save_still_debug_video = still_image and frames_cursor == 0 and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO decode_latents = cur_latents if still_image and frames_cursor == 0 else cur_latents decode_lq_cur_idx = still_debug_frame_count if still_image and frames_cursor == 0 else lq_cur_idx cur_frames, tcdecoder_tile_mems = self._decode_tcdecoder(decode_latents, sample, lq_pre_idx, decode_lq_cur_idx, output_height, output_width, padded_output_height, padded_output_width, tcdecoder_tile_size, tcdecoder_tile_mems, abort_callback=abort_callback, progress_callback=progress_callback, progress_step=decode_idx, progress_total=process_total) if cur_frames is None: return abort_result() cur_frames = _crop_output_frames(cur_frames.detach().cpu(), output_height, output_width) if save_still_debug_video: _save_still_image_debug_video(cur_frames) if still_image and frames_cursor == 0: cur_frames = _select_still_image_frame(cur_frames, still_output_frame) copy_frames = min(int(cur_frames.shape[2]), input_frames - frames_cursor) if copy_frames > 0: if frames_out is None: frames_out = torch.empty((cur_frames.shape[0], cur_frames.shape[1], input_frames, output_height, output_width), dtype=torch.float32, device="cpu") frames_out[:, :, frames_cursor:frames_cursor + copy_frames].copy_(cur_frames[:, :, :copy_frames].float()) frames_cursor += copy_frames lq_pre_idx = lq_cur_idx del cur_latents, cur_frames frames = frames_out else: if _abort_requested(abort_callback): return abort_result() _report_progress(progress_callback, "VAE Decoding") if self.vae is None: raise RuntimeError("FlashVSR full variant requires the Wan VAE.") vae_tile_size = int(vae_tile_size or 0) print(f"[FlashVSR] Wan VAE tiling policy: tile_size={vae_tile_size}px") save_still_debug_video = still_image and FLASHVSR_SAVE_STILL_IMAGE_DEBUG_VIDEO decode_latents = latents[0, :, :first_chunk_latent_frames].contiguous() if still_image else latents[0] frames = self.vae.decode_to_cpu_uint8([decode_latents], vae_tile_size, target_frames=None if save_still_debug_video else 1 if still_image else input_frames, target_height=output_height, target_width=output_width, frame_start=0 if save_still_debug_video or not still_image else still_output_frame)[0] if save_still_debug_video: _save_still_image_debug_video(frames) if still_image: frames = _select_still_image_frame(frames, still_output_frame if save_still_debug_video else 0) if self.tcdecoder is not None: self.tcdecoder.clean_mem() if self.vae is not None: self.vae.model.clear_cache() latents = frames_out = pre_cache_k = pre_cache_v = tcdecoder_tile_mems = None noise_pred = cur_latents = lq_layer_chunks = None lq_chunk = cur = cur_lq = cur_frames = None if torch.is_tensor(frames) and frames.dtype == torch.uint8 and frames.ndim == 4: if frames.shape[1:] != (input_frames, output_height, output_width): frames = frames[:, :input_frames, :output_height, :output_width].contiguous() else: decoded_frames = frames frames = _decoded_frames_to_cpu(decoded_frames, input_frames, output_height, output_width) del decoded_frames gc.collect() _report_progress(progress_callback, "Color Correction") _wavelet_color_fix_from_sample(frames.unsqueeze(0), sample, scale, output_height, output_width, output_height, output_width) if frames.dtype != torch.uint8: frames.clamp_(-1.0, 1.0) frames = _apply_continue_cache(frames, continue_cache) cache = _make_continue_cache(frames, scale, self.variant) if return_continue_cache else None sample = None self._unload_mmgp() if not persistent_models: self.release() return frames, cache _RUNTIME = FlashVSRRuntime() def upscale_video( sample: torch.Tensor, scale: float, paths: FlashVSRPaths, *, variant: str = FLASHVSR_VARIANT_TINY_LONG, seed: int = 0, continue_cache: Any = None, return_continue_cache: bool = False, persistent_models: bool = False, vae_tile_size: int | None = None, topk_ratio: float = FLASHVSR_TOPK_RATIO, init_pipe, profile, still_image: bool = False, two_pass: bool = False, abort_callback=None, progress_callback=None, ) -> tuple[torch.Tensor | None, dict[str, Any] | None]: _report_progress(progress_callback, "Caching") _RUNTIME.load(paths, variant, profile=profile, init_pipe=init_pipe) try: shift_correction = bool( FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION and two_pass ) if shift_correction: shift_y, shift_x = FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_INPUT_SHIFT or (max(1, int(round(FLASHVSR_STILL_IMAGE_SHIFT_CORRECTION_PERIOD * 0.5 / scale))), 0) out_shift_y, out_shift_x = int(round(shift_y * scale)), int(round(shift_x * scale)) print(f"[FlashVSR] x{scale:g} shifted two-pass blend: extra shifted pass ({shift_y}px input / {out_shift_y}px output), blend={FLASHVSR_STILL_IMAGE_SHIFT_BLEND:g}") base, base_cache = _RUNTIME.upscale(sample, scale, seed=seed, continue_cache=continue_cache, return_continue_cache=return_continue_cache, persistent_models=True, vae_tile_size=vae_tile_size, topk_ratio=topk_ratio, still_image=still_image, abort_callback=abort_callback, progress_callback=progress_callback) if base is None: result = (None, None) else: shifted_sample = _shift_spatial_replicate(sample, shift_y, shift_x) shifted_continue_cache = _two_pass_shifted_continue_cache(continue_cache, out_shift_y, out_shift_x) shifted, shifted_cache = _RUNTIME.upscale(shifted_sample, scale, seed=seed, continue_cache=shifted_continue_cache, return_continue_cache=return_continue_cache, persistent_models=True, vae_tile_size=vae_tile_size, topk_ratio=topk_ratio, still_image=still_image, abort_callback=abort_callback, progress_callback=progress_callback) result = (None, None) if shifted is None else (_apply_still_image_shift_correction(base, _shift_spatial_replicate(shifted, -out_shift_y, -out_shift_x), scale), _make_two_pass_continue_cache(base_cache, shifted_cache, shift_y, shift_x, out_shift_y, out_shift_x)) del shifted_sample, shifted del base if not persistent_models: _RUNTIME.release() else: result = _RUNTIME.upscale(sample, scale, seed=seed, continue_cache=continue_cache, return_continue_cache=return_continue_cache, persistent_models=persistent_models, vae_tile_size=vae_tile_size, topk_ratio=topk_ratio, still_image=still_image, abort_callback=abort_callback, progress_callback=progress_callback) if result[0] is None: if persistent_models: _RUNTIME._unload_mmgp() else: _RUNTIME.release() return result except Exception: if persistent_models: _RUNTIME._unload_mmgp() else: _RUNTIME.release() raise def release_models() -> None: _RUNTIME.release()