ColabWan / postprocessing /flashvsr /process_handler.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import gradio as gr
import torch
from safetensors.torch import save_file
from postprocessing.flashvsr.wgp_bridge import FlashVSRBridge
from shared.utils.virtual_media import build_virtual_media_path
class FlashVSRProcessHandler:
system_handler = "flashvsr"
model_type = "__system_flashvsr"
model_label = "WanGP System Postprocessing"
target_control_label = "Upsampling"
target_control_choices = [(f"x{FlashVSRBridge.format_ratio_label(scale)}", FlashVSRBridge.upsampling_value(scale)) for scale in FlashVSRBridge.UPSAMPLING_RATIOS]
default_target_control = FlashVSRBridge.upsampling_value(2.0)
default_chunk_size_seconds = 3.0
frame_step = 1
minimum_requested_frames = 1
# FlashVSR's streaming output has an 11-frame tail that must be regenerated with the next source chunk before writing.
overlap_frames = 11
hide_sliding_window_overlap = True
hide_output_resolution = True
hide_prompt = True
def get_overlap_frames(self, chunk_frames: int) -> int:
return max(0, min(int(self.overlap_frames), int(chunk_frames) - 1))
def normalize_target_control(self, value: str | None) -> str:
value = str(value or "").strip()
if scale_for_lanczos(value) is not None or FlashVSRBridge.scale_for_upsampling(value) is not None:
return value
scale = scale_for_any(value)
return FlashVSRBridge.upsampling_value(scale) if scale is not None else self.default_target_control
def target_control_choices_for_process(self, process_settings: dict) -> list[tuple[str, str]]:
prefix = upsampling_prefix_for_process(process_settings)
return [(f"x{FlashVSRBridge.format_ratio_label(scale)}", upsampling_value(prefix, scale)) for scale in FlashVSRBridge.UPSAMPLING_RATIOS]
def target_control_default_for_process(self, process_settings: dict) -> str:
return self.normalize_target_control_for_process(process_settings.get("target_ratio"), process_settings)
def normalize_target_control_for_process(self, value: str | None, process_settings: dict) -> str:
scale = scale_for_any(value) or scale_for_any(process_settings.get("target_ratio")) or 2.0
return upsampling_value(upsampling_prefix_for_process(process_settings), scale if scale in FlashVSRBridge.UPSAMPLING_RATIOS else 2.0)
def output_resolution_token(self, value: str | None) -> str:
value = self.normalize_target_control(value)
scale = scale_for_lanczos(value) or FlashVSRBridge.scale_for_upsampling(value) or 2.0
prefix = "lanczos-" if value.startswith("lanczos") else ("flashvsr2pass-" if FlashVSRBridge.is_two_pass_upsampling(value) else "")
return f"{prefix}x{FlashVSRBridge.format_ratio(scale)}"
def build_queue_settings(self, process_settings: dict, *, source_path: str, start_frame: int, frame_count: int, target_control: str, seed: int, continue_cache: Any, audio_track_no: int | None = None) -> dict:
target_control = self.normalize_target_control_for_process(target_control, process_settings)
video_path = build_virtual_media_path(source_path, start_frame=start_frame, end_frame=start_frame + frame_count - 1, audio_track_no=audio_track_no)
api_options = dict(process_settings.get("_api", {})) if isinstance(process_settings.get("_api"), dict) else {}
api_options.update({"return_media": True, "suppress_source_audio": False, "suppress_metadata_images": True})
if self.supports_continue_cache_for_target(target_control):
api_options.update({"return_flashvsr_continue_cache": True, "flashvsr_continue_cache": continue_cache})
else:
api_options.pop("return_flashvsr_continue_cache", None)
api_options.pop("flashvsr_continue_cache", None)
settings = dict(process_settings)
settings.update({
"mode": "edit_postprocessing",
"model_type": self.model_type,
"prompt": str(settings.get("prompt") or "FlashVSR upsampling"),
"image_mode": 0,
"video_source": video_path,
"video_length": int(frame_count),
"keep_frames_video_source": str(int(frame_count)),
"temporal_upsampling": "",
"spatial_upsampling": target_control,
"film_grain_intensity": 0,
"film_grain_saturation": 0.5,
"postprocess_audio": "",
"repeat_generation": 1,
"batch_size": 1,
"seed": int(seed),
"_api": api_options,
})
return settings
def supports_continue_cache(self) -> bool:
return True
def supports_continue_cache_for_target(self, value: str | None) -> bool:
value = self.normalize_target_control(value)
return FlashVSRBridge.scale_for_upsampling(value) is not None
def cache_sidecar_path(self, output_filename: str) -> str:
output_path = Path(output_filename).resolve()
return str(output_path.with_suffix(output_path.suffix + ".flashvsr_cache.safetensors"))
def can_resume_without_output_metadata(self, output_filename: str) -> bool:
return Path(self.cache_sidecar_path(output_filename)).is_file() or Path(output_filename).is_file()
def move_continue_cache(self, source_output_filename: str, target_output_filename: str) -> bool:
source_path = Path(self.cache_sidecar_path(source_output_filename))
if not source_path.is_file():
return False
target_path = Path(self.cache_sidecar_path(target_output_filename))
target_path.parent.mkdir(parents=True, exist_ok=True)
source_path.replace(target_path)
return True
def delete_continue_cache(self, output_filename: str) -> None:
cache_path = Path(self.cache_sidecar_path(output_filename))
if cache_path.is_file():
cache_path.unlink()
def save_continue_cache(self, cache: Any, output_filename: str, metadata: dict | None = None) -> str:
if not isinstance(cache, dict):
return ""
tail = _cache_tail_to_uint8(cache.get("tail_frames"))
if tail is None:
return ""
tensors = {"tail_frames": tail}
shifted_tail = _cache_tail_to_uint8(cache.get("tail_frames_shifted"))
if shifted_tail is not None:
tensors["tail_frames_shifted"] = shifted_tail
cache_metadata = {
"version": "2" if shifted_tail is not None else "1",
"handler": self.system_handler,
"scale": str(cache.get("scale", "")),
"variant": str(cache.get("variant", "")),
"metadata": json.dumps(metadata or {}, ensure_ascii=True, sort_keys=True),
}
cache_metadata.update({key: str(cache[key]) for key in ("two_pass", "shift_y", "shift_x", "out_shift_y", "out_shift_x") if key in cache})
sidecar_path = self.cache_sidecar_path(output_filename)
Path(sidecar_path).parent.mkdir(parents=True, exist_ok=True)
save_file(tensors, sidecar_path, metadata=cache_metadata)
return sidecar_path
def load_continue_cache(self, output_filename: str) -> Any:
sidecar_path = self.cache_sidecar_path(output_filename)
if not Path(sidecar_path).is_file():
raise gr.Error(f"FlashVSR continuation cache is missing: {sidecar_path}")
from safetensors import safe_open
with safe_open(sidecar_path, framework="pt", device="cpu") as handle:
metadata = dict(handle.metadata() or {})
cache = {"tail_frames": _load_tail_tensor(handle, "tail_frames", sidecar_path), "scale": _coerce_float(metadata.get("scale"), 0.0), "variant": str(metadata.get("variant") or "")}
if "tail_frames_shifted" in set(handle.keys()):
cache["tail_frames_shifted"] = _load_tail_tensor(handle, "tail_frames_shifted", sidecar_path)
cache.update({key: _coerce_float(metadata.get(key), 0.0) for key in ("shift_y", "shift_x", "out_shift_y", "out_shift_x") if key in metadata})
if "two_pass" in metadata:
cache["two_pass"] = str(metadata.get("two_pass")).lower() == "true"
return cache
def continue_cache_from_tail_frames(self, tail_frames: Any, target_control: str | None = None) -> Any:
tail = _cache_tail_to_uint8(tail_frames)
if tail is None:
return None
return {"tail_frames": tail, "scale": FlashVSRBridge.scale_for_upsampling(self.normalize_target_control(target_control)) or 0.0, "variant": "", "fallback": True}
def _coerce_float(value: Any, default: float) -> float:
try:
return float(value)
except (TypeError, ValueError):
return float(default)
def _cache_tail_to_uint8(tail: Any) -> torch.Tensor | None:
if not torch.is_tensor(tail) or tail.ndim != 4 or int(tail.shape[1]) <= 0:
return None
if tail.dtype == torch.uint8:
return tail.detach().cpu().contiguous()
return tail.detach().cpu().float().clamp(-1.0, 1.0).add(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8).contiguous()
def _load_tail_tensor(handle, key: str, sidecar_path: str) -> torch.Tensor:
tail = handle.get_tensor(key)
if not torch.is_tensor(tail) or tail.ndim != 4:
raise gr.Error(f"FlashVSR continuation cache is invalid: {sidecar_path}")
return tail.clone().contiguous() if tail.dtype == torch.uint8 else tail.float().clamp_(-1.0, 1.0).contiguous()
def scale_for_lanczos(value: str | None) -> float | None:
text = str(value or "").strip().lower()
if not text.startswith("lanczos"):
return None
try:
scale = float(text[len("lanczos"):])
except ValueError:
return None
return scale if scale in FlashVSRBridge.UPSAMPLING_RATIOS else None
def scale_for_any(value: str | None) -> float | None:
text = str(value or "").strip()
if len(text) == 0:
return None
scale = scale_for_lanczos(text) or FlashVSRBridge.scale_for_upsampling(text)
if scale is not None:
return scale
try:
scale = float(text)
except ValueError:
return None
return scale if scale in FlashVSRBridge.UPSAMPLING_RATIOS else None
def upsampling_prefix_for_process(process_settings: dict | None) -> str:
settings = process_settings if isinstance(process_settings, dict) else {}
method = str(settings.get("spatial_upsampling_method") or "").strip().lower()
if method in ("lanczos", FlashVSRBridge.UPSAMPLING_VALUE_PREFIX, FlashVSRBridge.UPSAMPLING_TWO_PASS_VALUE_PREFIX):
return method
target = str(settings.get("target_ratio") or "").strip().lower()
if target.startswith("lanczos"):
return "lanczos"
if target.startswith(FlashVSRBridge.UPSAMPLING_TWO_PASS_VALUE_PREFIX):
return FlashVSRBridge.UPSAMPLING_TWO_PASS_VALUE_PREFIX
return FlashVSRBridge.UPSAMPLING_VALUE_PREFIX
def upsampling_value(prefix: str, scale: float) -> str:
return f"{prefix}{FlashVSRBridge.format_ratio(scale)}"
HANDLER = FlashVSRProcessHandler()