File size: 11,110 Bytes
7344bef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 | from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import gradio as gr
import torch
from safetensors.torch import save_file
from postprocessing.flashvsr.wgp_bridge import FlashVSRBridge
from shared.utils.virtual_media import build_virtual_media_path
class FlashVSRProcessHandler:
system_handler = "flashvsr"
model_type = "__system_flashvsr"
model_label = "WanGP System Postprocessing"
target_control_label = "Upsampling"
target_control_choices = [(f"x{FlashVSRBridge.format_ratio_label(scale)}", FlashVSRBridge.upsampling_value(scale)) for scale in FlashVSRBridge.UPSAMPLING_RATIOS]
default_target_control = FlashVSRBridge.upsampling_value(2.0)
default_chunk_size_seconds = 3.0
frame_step = 1
minimum_requested_frames = 1
# FlashVSR's streaming output has an 11-frame tail that must be regenerated with the next source chunk before writing.
overlap_frames = 11
hide_sliding_window_overlap = True
hide_output_resolution = True
hide_prompt = True
def get_overlap_frames(self, chunk_frames: int) -> int:
return max(0, min(int(self.overlap_frames), int(chunk_frames) - 1))
def normalize_target_control(self, value: str | None) -> str:
value = str(value or "").strip()
if scale_for_lanczos(value) is not None or FlashVSRBridge.scale_for_upsampling(value) is not None:
return value
scale = scale_for_any(value)
return FlashVSRBridge.upsampling_value(scale) if scale is not None else self.default_target_control
def target_control_choices_for_process(self, process_settings: dict) -> list[tuple[str, str]]:
prefix = upsampling_prefix_for_process(process_settings)
return [(f"x{FlashVSRBridge.format_ratio_label(scale)}", upsampling_value(prefix, scale)) for scale in FlashVSRBridge.UPSAMPLING_RATIOS]
def target_control_default_for_process(self, process_settings: dict) -> str:
return self.normalize_target_control_for_process(process_settings.get("target_ratio"), process_settings)
def normalize_target_control_for_process(self, value: str | None, process_settings: dict) -> str:
scale = scale_for_any(value) or scale_for_any(process_settings.get("target_ratio")) or 2.0
return upsampling_value(upsampling_prefix_for_process(process_settings), scale if scale in FlashVSRBridge.UPSAMPLING_RATIOS else 2.0)
def output_resolution_token(self, value: str | None) -> str:
value = self.normalize_target_control(value)
scale = scale_for_lanczos(value) or FlashVSRBridge.scale_for_upsampling(value) or 2.0
prefix = "lanczos-" if value.startswith("lanczos") else ("flashvsr2pass-" if FlashVSRBridge.is_two_pass_upsampling(value) else "")
return f"{prefix}x{FlashVSRBridge.format_ratio(scale)}"
def build_queue_settings(self, process_settings: dict, *, source_path: str, start_frame: int, frame_count: int, target_control: str, seed: int, continue_cache: Any, audio_track_no: int | None = None) -> dict:
target_control = self.normalize_target_control_for_process(target_control, process_settings)
video_path = build_virtual_media_path(source_path, start_frame=start_frame, end_frame=start_frame + frame_count - 1, audio_track_no=audio_track_no)
api_options = dict(process_settings.get("_api", {})) if isinstance(process_settings.get("_api"), dict) else {}
api_options.update({"return_media": True, "suppress_source_audio": False, "suppress_metadata_images": True})
if self.supports_continue_cache_for_target(target_control):
api_options.update({"return_flashvsr_continue_cache": True, "flashvsr_continue_cache": continue_cache})
else:
api_options.pop("return_flashvsr_continue_cache", None)
api_options.pop("flashvsr_continue_cache", None)
settings = dict(process_settings)
settings.update({
"mode": "edit_postprocessing",
"model_type": self.model_type,
"prompt": str(settings.get("prompt") or "FlashVSR upsampling"),
"image_mode": 0,
"video_source": video_path,
"video_length": int(frame_count),
"keep_frames_video_source": str(int(frame_count)),
"temporal_upsampling": "",
"spatial_upsampling": target_control,
"film_grain_intensity": 0,
"film_grain_saturation": 0.5,
"postprocess_audio": "",
"repeat_generation": 1,
"batch_size": 1,
"seed": int(seed),
"_api": api_options,
})
return settings
def supports_continue_cache(self) -> bool:
return True
def supports_continue_cache_for_target(self, value: str | None) -> bool:
value = self.normalize_target_control(value)
return FlashVSRBridge.scale_for_upsampling(value) is not None
def cache_sidecar_path(self, output_filename: str) -> str:
output_path = Path(output_filename).resolve()
return str(output_path.with_suffix(output_path.suffix + ".flashvsr_cache.safetensors"))
def can_resume_without_output_metadata(self, output_filename: str) -> bool:
return Path(self.cache_sidecar_path(output_filename)).is_file() or Path(output_filename).is_file()
def move_continue_cache(self, source_output_filename: str, target_output_filename: str) -> bool:
source_path = Path(self.cache_sidecar_path(source_output_filename))
if not source_path.is_file():
return False
target_path = Path(self.cache_sidecar_path(target_output_filename))
target_path.parent.mkdir(parents=True, exist_ok=True)
source_path.replace(target_path)
return True
def delete_continue_cache(self, output_filename: str) -> None:
cache_path = Path(self.cache_sidecar_path(output_filename))
if cache_path.is_file():
cache_path.unlink()
def save_continue_cache(self, cache: Any, output_filename: str, metadata: dict | None = None) -> str:
if not isinstance(cache, dict):
return ""
tail = _cache_tail_to_uint8(cache.get("tail_frames"))
if tail is None:
return ""
tensors = {"tail_frames": tail}
shifted_tail = _cache_tail_to_uint8(cache.get("tail_frames_shifted"))
if shifted_tail is not None:
tensors["tail_frames_shifted"] = shifted_tail
cache_metadata = {
"version": "2" if shifted_tail is not None else "1",
"handler": self.system_handler,
"scale": str(cache.get("scale", "")),
"variant": str(cache.get("variant", "")),
"metadata": json.dumps(metadata or {}, ensure_ascii=True, sort_keys=True),
}
cache_metadata.update({key: str(cache[key]) for key in ("two_pass", "shift_y", "shift_x", "out_shift_y", "out_shift_x") if key in cache})
sidecar_path = self.cache_sidecar_path(output_filename)
Path(sidecar_path).parent.mkdir(parents=True, exist_ok=True)
save_file(tensors, sidecar_path, metadata=cache_metadata)
return sidecar_path
def load_continue_cache(self, output_filename: str) -> Any:
sidecar_path = self.cache_sidecar_path(output_filename)
if not Path(sidecar_path).is_file():
raise gr.Error(f"FlashVSR continuation cache is missing: {sidecar_path}")
from safetensors import safe_open
with safe_open(sidecar_path, framework="pt", device="cpu") as handle:
metadata = dict(handle.metadata() or {})
cache = {"tail_frames": _load_tail_tensor(handle, "tail_frames", sidecar_path), "scale": _coerce_float(metadata.get("scale"), 0.0), "variant": str(metadata.get("variant") or "")}
if "tail_frames_shifted" in set(handle.keys()):
cache["tail_frames_shifted"] = _load_tail_tensor(handle, "tail_frames_shifted", sidecar_path)
cache.update({key: _coerce_float(metadata.get(key), 0.0) for key in ("shift_y", "shift_x", "out_shift_y", "out_shift_x") if key in metadata})
if "two_pass" in metadata:
cache["two_pass"] = str(metadata.get("two_pass")).lower() == "true"
return cache
def continue_cache_from_tail_frames(self, tail_frames: Any, target_control: str | None = None) -> Any:
tail = _cache_tail_to_uint8(tail_frames)
if tail is None:
return None
return {"tail_frames": tail, "scale": FlashVSRBridge.scale_for_upsampling(self.normalize_target_control(target_control)) or 0.0, "variant": "", "fallback": True}
def _coerce_float(value: Any, default: float) -> float:
try:
return float(value)
except (TypeError, ValueError):
return float(default)
def _cache_tail_to_uint8(tail: Any) -> torch.Tensor | None:
if not torch.is_tensor(tail) or tail.ndim != 4 or int(tail.shape[1]) <= 0:
return None
if tail.dtype == torch.uint8:
return tail.detach().cpu().contiguous()
return tail.detach().cpu().float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8).contiguous()
def _load_tail_tensor(handle, key: str, sidecar_path: str) -> torch.Tensor:
tail = handle.get_tensor(key)
if not torch.is_tensor(tail) or tail.ndim != 4:
raise gr.Error(f"FlashVSR continuation cache is invalid: {sidecar_path}")
return tail.clone().contiguous() if tail.dtype == torch.uint8 else tail.float().clamp_(-1.0, 1.0).contiguous()
def scale_for_lanczos(value: str | None) -> float | None:
text = str(value or "").strip().lower()
if not text.startswith("lanczos"):
return None
try:
scale = float(text[len("lanczos"):])
except ValueError:
return None
return scale if scale in FlashVSRBridge.UPSAMPLING_RATIOS else None
def scale_for_any(value: str | None) -> float | None:
text = str(value or "").strip()
if len(text) == 0:
return None
scale = scale_for_lanczos(text) or FlashVSRBridge.scale_for_upsampling(text)
if scale is not None:
return scale
try:
scale = float(text)
except ValueError:
return None
return scale if scale in FlashVSRBridge.UPSAMPLING_RATIOS else None
def upsampling_prefix_for_process(process_settings: dict | None) -> str:
settings = process_settings if isinstance(process_settings, dict) else {}
method = str(settings.get("spatial_upsampling_method") or "").strip().lower()
if method in ("lanczos", FlashVSRBridge.UPSAMPLING_VALUE_PREFIX, FlashVSRBridge.UPSAMPLING_TWO_PASS_VALUE_PREFIX):
return method
target = str(settings.get("target_ratio") or "").strip().lower()
if target.startswith("lanczos"):
return "lanczos"
if target.startswith(FlashVSRBridge.UPSAMPLING_TWO_PASS_VALUE_PREFIX):
return FlashVSRBridge.UPSAMPLING_TWO_PASS_VALUE_PREFIX
return FlashVSRBridge.UPSAMPLING_VALUE_PREFIX
def upsampling_value(prefix: str, scale: float) -> str:
return f"{prefix}{FlashVSRBridge.format_ratio(scale)}"
HANDLER = FlashVSRProcessHandler()
|