"""SAM3.1 helpers: asset lookup, keyword prompt encoding and video mask propagation."""

import importlib
import importlib.util
import os
from pathlib import Path
from typing import Iterable

import numpy as np
import torch
from PIL import Image

from shared.utils import files_locator as fl

from .logger import get_logger
from .model.device_utils import accelerator_autocast, empty_accelerator_cache, get_accelerator_device, is_accelerator_device

_PACKAGE_ROOT = Path(__file__).resolve().parent
_SAM3_FOLDER = "sam3"
_SAM3_CHECKPOINT_NAME = "sam3.1_multiplex_bf16.safetensors"
_SAM3_BPE_NAME = "bpe_simple_vocab_16e6.txt.gz"
KEEP_VIDEO_FRAMES_ON_CUDA = True
# Text encoder kept between calls when `keep_text_encoder_loaded=True`,
# together with the (checkpoint_path, bpe_path) pair it was built from.
_TEXT_ENCODER_CACHE = None
_TEXT_ENCODER_CACHE_KEY = None
logger = get_logger(__name__)


def _cleanup():
    """Run the garbage collector, then empty the accelerator cache."""
    import gc

    gc.collect()
    empty_accelerator_cache()


def _load_model_builder():
    """Import and return the sibling ``.model_builder`` module.

    Raises:
        FileNotFoundError: if ``.model_builder`` itself is missing.
        ModuleNotFoundError: re-raised unchanged when a *different* module
            (i.e. one of the builder's own dependencies) is missing.
    """
    try:
        return importlib.import_module(".model_builder", package=__package__)
    except ModuleNotFoundError as exc:
        # Only translate the error when the builder module itself is absent;
        # a missing transitive dependency must surface as-is.
        if exc.name != importlib.util.resolve_name(".model_builder", __package__):
            raise
        raise FileNotFoundError("SAM3.1 code was not found under preprocessing/sam3.") from exc


def _checkpoint_path():
    """Locate the SAM3.1 checkpoint.

    Tries ``files_locator`` with three candidate relative paths, then falls
    back to a file next to this module.

    Returns:
        A ``(checkpoint_path, version)`` tuple; ``version`` is always ``"sam3.1"``.

    Raises:
        FileNotFoundError: if no candidate exists.
    """
    candidates = [
        os.path.join(_SAM3_FOLDER, _SAM3_CHECKPOINT_NAME),
        os.path.join("sam3.1", _SAM3_CHECKPOINT_NAME),
        _SAM3_CHECKPOINT_NAME,
    ]
    for candidate in candidates:
        checkpoint = fl.locate_file(candidate, error_if_none=False)
        if checkpoint is not None:
            return checkpoint, "sam3.1"
    checkpoint = _PACKAGE_ROOT / _SAM3_CHECKPOINT_NAME
    if checkpoint.is_file():
        return os.fspath(checkpoint), "sam3.1"
    raise FileNotFoundError("SAM3.1 bf16 safetensors checkpoint was not found by files_locator as sam3/sam3.1_multiplex_bf16.safetensors, sam3.1/sam3.1_multiplex_bf16.safetensors, or sam3.1_multiplex_bf16.safetensors, nor under preprocessing/sam3.")


def _bpe_path():
    """Locate the BPE vocabulary file.

    Tries ``files_locator`` with three candidate relative paths, then falls
    back to ``assets/`` next to this module.

    Raises:
        FileNotFoundError: if no candidate exists.
    """
    candidates = [
        os.path.join(_SAM3_FOLDER, _SAM3_BPE_NAME),
        os.path.join("sam3.1", _SAM3_BPE_NAME),
        _SAM3_BPE_NAME,
    ]
    for candidate in candidates:
        bpe_path = fl.locate_file(candidate, error_if_none=False)
        if bpe_path is not None:
            return bpe_path
    bpe_path = _PACKAGE_ROOT / "assets" / _SAM3_BPE_NAME
    if bpe_path.is_file():
        return os.fspath(bpe_path)
    raise FileNotFoundError("SAM3 BPE vocabulary was not found by files_locator as sam3/bpe_simple_vocab_16e6.txt.gz, sam3.1/bpe_simple_vocab_16e6.txt.gz, or bpe_simple_vocab_16e6.txt.gz, nor under preprocessing/sam3/assets.")


def _autocast_context():
    """Return the context manager produced by ``accelerator_autocast()``."""
    return accelerator_autocast()


def _bf16_prompt_payload(value):
    """Recursively cast floating-point tensors inside ``value`` to bfloat16.

    Dicts, lists and tuples are rebuilt with converted members; non-floating
    tensors and any other object are returned unchanged.
    """
    if torch.is_tensor(value):
        return value.to(dtype=torch.bfloat16) if value.is_floating_point() else value
    if isinstance(value, dict):
        return {key: _bf16_prompt_payload(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_bf16_prompt_payload(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_bf16_prompt_payload(item) for item in value)
    return value


def _format_keywords_for_log(keywords: list[str]):
    """Return the keywords single-quoted and comma-separated for log output."""
    return ", ".join(f"'{keyword}'" for keyword in keywords)


def _normalize_keywords(keywords) -> list[str]:
    """Stringify and strip each keyword, dropping the ones that end up empty.

    A bare ``str`` is treated as a single keyword; iterating it would
    otherwise yield one prompt per character.
    """
    if isinstance(keywords, str):
        keywords = [keywords]
    stripped = (str(keyword).strip() for keyword in keywords)
    return [keyword for keyword in stripped if keyword]


def _to_numpy(value):
    """Convert a tensor (detached, moved to CPU) or array-like to ``np.ndarray``."""
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _sam3_outputs_to_binary_mask(outputs, height: int, width: int):
    """Collapse ``outputs["out_binary_masks"]`` into one ``(height, width)`` bool mask.

    Returns an all-False mask when ``outputs`` is None, lacks the key, or the
    mask array is empty. Otherwise the masks are brought to a
    ``(count, H, W)`` layout, resized with nearest-neighbour when their
    spatial size differs from ``(height, width)``, and OR-ed over ``count``.
    """
    if outputs is None or "out_binary_masks" not in outputs:
        return np.zeros((height, width), dtype=np.bool_)
    masks = _to_numpy(outputs["out_binary_masks"])
    if masks.size == 0:
        return np.zeros((height, width), dtype=np.bool_)
    if masks.ndim == 2:
        masks = masks[None, :, :]
    elif masks.ndim == 4 and masks.shape[1] == 1:
        masks = masks[:, 0]
    elif masks.ndim > 3:
        # Flatten every leading axis so only (count, H, W) remains.
        masks = masks.reshape((-1, *masks.shape[-2:]))
    if masks.shape[-2:] != (height, width):
        resized = [
            np.asarray(
                Image.fromarray(mask.astype(np.uint8)).resize((width, height), resample=Image.Resampling.NEAREST)
            )
            for mask in masks
        ]
        masks = np.stack(resized, axis=0)
    return masks.astype(bool).any(axis=0)


def resolve_sam3_grounding_batch_size(batch_size=None) -> int:
    """Return the grounding batch size to use.

    A positive ``batch_size`` is returned as an ``int``. Otherwise (None or
    non-positive) the value is chosen automatically: 2 without CUDA, else 4
    when CUDA device 0 reports at least 8 GiB of total memory and 2 below that.
    """
    if batch_size is not None:
        batch_size = int(batch_size)
        if batch_size > 0:
            return batch_size
    if not torch.cuda.is_available():
        return 2
    total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    return 4 if total_vram_gb >= 8 else 2


def _encode_text_outputs(text_encoder, captions: list[str], device: torch.device):
    """Encode each caption on its own and concatenate the results on CPU.

    The encoder is moved to ``device`` in bfloat16 first when ``device`` is an
    accelerator device.

    Returns:
        Dict with ``language_features`` (memories concatenated on dim 1),
        ``language_mask`` (attention masks concatenated on dim 0) and
        ``language_embeds`` (embeddings concatenated on dim 1).
    """
    masks, memories, embeds = [], [], []
    if is_accelerator_device(device):
        text_encoder.to(device=device, dtype=torch.bfloat16)
    for caption in captions:
        with torch.inference_mode(), _autocast_context():
            text_attention_mask, text_memory, text_embeds = text_encoder([caption], device=device)
        masks.append(text_attention_mask.detach().cpu())
        memories.append(text_memory.detach().cpu())
        embeds.append(text_embeds.detach().cpu())
        # Drop the device-side results before the next caption is encoded.
        del text_attention_mask, text_memory, text_embeds
        _cleanup()
    return {
        "language_features": torch.cat(memories, dim=1),
        "language_mask": torch.cat(masks, dim=0),
        "language_embeds": torch.cat(embeds, dim=1),
    }


def _encode_keyword_prompts(model_builder, checkpoint_path: str, bpe_path: str, keywords: list[str], keep_text_encoder_loaded: bool = False):
    """Encode every keyword with the SAM3 text encoder.

    Each keyword is encoded together with the fixed captions ``"visual"`` and
    ``"geometric"``. With ``keep_text_encoder_loaded`` the encoder is reused
    from / stored in the module-level cache (keyed by the checkpoint and BPE
    paths) and moved to CPU afterwards; otherwise it is released.

    Returns:
        Dict mapping each keyword to the output of ``_encode_text_outputs``.
    """
    global _TEXT_ENCODER_CACHE, _TEXT_ENCODER_CACHE_KEY
    text_encoder = None
    device = get_accelerator_device()
    cache_key = (checkpoint_path, bpe_path)
    preencoded = {}
    try:
        cache_hit = (
            keep_text_encoder_loaded
            and _TEXT_ENCODER_CACHE is not None
            and _TEXT_ENCODER_CACHE_KEY == cache_key
        )
        if cache_hit:
            text_encoder = _TEXT_ENCODER_CACHE
        else:
            text_encoder = model_builder.build_sam3_text_encoder(checkpoint_path=checkpoint_path, bpe_path=bpe_path)
            if keep_text_encoder_loaded:
                _TEXT_ENCODER_CACHE = text_encoder
                _TEXT_ENCODER_CACHE_KEY = cache_key
        for keyword in keywords:
            preencoded[keyword] = _encode_text_outputs(text_encoder, [keyword, "visual", "geometric"], device)
    finally:
        if text_encoder is not None:
            if keep_text_encoder_loaded:
                # Keep the cached encoder alive but off the accelerator.
                text_encoder.to("cpu")
            else:
                del text_encoder
        _cleanup()
    return preencoded


def encode_sam3_keyword_prompts(keywords: Iterable[str], keep_text_encoder_loaded: bool = False):
    """Encode keyword prompts using the located checkpoint and BPE vocabulary.

    Keywords are stringified and stripped; empty ones are dropped. Returns an
    empty dict (without loading anything) when no keyword remains.
    """
    keywords = _normalize_keywords(keywords)
    if not keywords:
        return {}
    model_builder = _load_model_builder()
    checkpoint_path, _ = _checkpoint_path()
    bpe_path = _bpe_path()
    return _encode_keyword_prompts(model_builder, checkpoint_path, bpe_path, keywords, keep_text_encoder_loaded=keep_text_encoder_loaded)


def clear_sam3_text_encoder_cache():
    """Drop the cached text encoder (if any) and run ``_cleanup()``."""
    global _TEXT_ENCODER_CACHE, _TEXT_ENCODER_CACHE_KEY
    # Rebinding to None releases the module's reference; no `del` is needed.
    _TEXT_ENCODER_CACHE = None
    _TEXT_ENCODER_CACHE_KEY = None
    _cleanup()


def fill_sam3_binary_mask_holes(mask: np.ndarray, fill_hole_area: int):
    """Fill holes in a binary mask via ``fill_holes_in_mask_scores``.

    ``fill_hole_area`` is clamped to be non-negative and forwarded as
    ``max_area``. When it is 0 or the mask has no True value, the mask is
    returned as bool without further processing.
    """
    fill_hole_area = max(0, int(fill_hole_area))
    if fill_hole_area == 0 or not np.any(mask):
        return mask.astype(np.bool_, copy=False)
    from .model.sam3_tracker_utils import fill_holes_in_mask_scores

    # Map {0, 1} to {-1, +1} scores and add two leading singleton axes.
    scores = torch.from_numpy(mask.astype(np.float32, copy=False))[None, None]
    scores = scores * 2 - 1
    filled = fill_holes_in_mask_scores(scores, max_area=fill_hole_area, fill_holes=True, remove_sprinkles=False)
    return filled[0, 0].numpy() > 0


def _load_predictor(
    model_builder=None,
    checkpoint_path=None,
    bpe_path=None,
    version=None,
    include_text_encoder=True,
    batched_grounding_batch_size=None,
    postprocess_batch_size=1,
    use_batched_grounding=True,
    trim_past_non_cond_mem_for_eval=True,
    fill_hole_area: int = 0,
    manual_model_loading: bool = False,
):
    """Build a SAM3 predictor through ``model_builder.build_sam3_predictor``.

    Missing ``model_builder`` / ``bpe_path`` are resolved automatically. The
    checkpoint and version are looked up together unless *both* are supplied.
    The batch size goes through ``resolve_sam3_grounding_batch_size``.
    """
    model_builder = model_builder or _load_model_builder()
    if checkpoint_path is None or version is None:
        checkpoint_path, version = _checkpoint_path()
    bpe_path = bpe_path or _bpe_path()
    grounding_batch_size = resolve_sam3_grounding_batch_size(batched_grounding_batch_size)
    return model_builder.build_sam3_predictor(
        checkpoint_path=checkpoint_path,
        bpe_path=bpe_path,
        version=version,
        use_fa3=False,
        use_rope_real=True,
        compile=False,
        warm_up=False,
        include_text_encoder=include_text_encoder,
        postprocess_batch_size=postprocess_batch_size,
        use_batched_grounding=use_batched_grounding,
        batched_grounding_batch_size=grounding_batch_size,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
        fill_hole_area=fill_hole_area,
        manual_model_loading=manual_model_loading,
    )


def load_sam3_mask_predictor(
    *,
    include_text_encoder: bool = True,
    postprocess_batch_size: int = 1,
    use_batched_grounding: bool = True,
    batched_grounding_batch_size=None,
    trim_past_non_cond_mem_for_eval: bool = True,
    fill_hole_area: int = 0,
    manual_model_loading: bool = False,
):
    """Locate the SAM3.1 assets and return a predictor built from them.

    All keyword arguments are forwarded unchanged to ``_load_predictor``.
    """
    model_builder = _load_model_builder()
    checkpoint_path, version = _checkpoint_path()
    bpe_path = _bpe_path()
    return _load_predictor(
        model_builder,
        checkpoint_path,
        bpe_path,
        version,
        include_text_encoder=include_text_encoder,
        batched_grounding_batch_size=batched_grounding_batch_size,
        postprocess_batch_size=postprocess_batch_size,
        use_batched_grounding=use_batched_grounding,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
        fill_hole_area=fill_hole_area,
        manual_model_loading=manual_model_loading,
    )


def run_sam3_video(
    video: np.ndarray,
    keywords: Iterable[str],
    *,
    include_text_encoder: bool = False,
    preencode_text: bool = True,
    batched_grounding_batch_size=None,
    postprocess_batch_size: int = 1,
    use_batched_grounding: bool = True,
    trim_past_non_cond_mem_for_eval: bool = True,
    keep_video_frames_on_cuda: bool = KEEP_VIDEO_FRAMES_ON_CUDA,
    cache_frame_outputs: bool = False,
    fill_hole_area: int = 0,
    progress_callback=None,
):
    """Segment every keyword in ``video`` and return the union of the masks.

    For each keyword a text prompt is added on frame 0 and propagated forward
    through the video; the per-frame masks of all keywords are OR-ed together.

    Args:
        video: 4-D array unpacked as ``(num_frames, height, width, _)``; each
            frame is converted with ``Image.fromarray``.
        keywords: Text prompts. Stringified and stripped; empty ones dropped.
        include_text_encoder: Build the predictor with its text encoder. It is
            always included when the prompts were not pre-encoded.
        preencode_text: Encode the keywords before the predictor is built
            (only applies when the resolved version is ``"sam3.1"``).
        batched_grounding_batch_size, postprocess_batch_size,
        use_batched_grounding, trim_past_non_cond_mem_for_eval: Forwarded to
            ``_load_predictor``.
        keep_video_frames_on_cuda: Sent to the session as
            ``offload_video_to_cpu = not keep_video_frames_on_cuda``.
        cache_frame_outputs: Forwarded to the ``start_session`` request.
        fill_hole_area: When > 0, every final frame mask goes through
            ``fill_sam3_binary_mask_holes``. The predictor itself is always
            built with ``fill_hole_area=0``.
        progress_callback: Optional ``callback(done, total)`` where ``total``
            is ``len(keywords) * num_frames``.

    Returns:
        Bool array of shape ``(num_frames, height, width)``. With no usable
        keyword, an all-False array of shape ``video.shape[:3]`` is returned
        without loading any model.

    Raises:
        ValueError: if ``video`` is not 4-D (raised before any model is loaded).
        FileNotFoundError: if the SAM3.1 code, checkpoint or BPE file is missing.
    """
    keywords = _normalize_keywords(keywords)
    if not keywords:
        return np.zeros(video.shape[:3], dtype=np.bool_)
    # Unpack first so a malformed video fails before the expensive model load.
    num_frames, height, width, _ = video.shape

    model_builder = _load_model_builder()
    checkpoint_path, version = _checkpoint_path()
    bpe_path = _bpe_path()
    _cleanup()
    if version == "sam3.1" and preencode_text:
        logger.info("SAM3 encoding keywords before propagation: %s", _format_keywords_for_log(keywords))
        preencoded_prompts = _encode_keyword_prompts(model_builder, checkpoint_path, bpe_path, keywords)
    else:
        preencoded_prompts = None
    video_predictor = _load_predictor(
        model_builder,
        checkpoint_path,
        bpe_path,
        version,
        include_text_encoder=include_text_encoder or preencoded_prompts is None,
        batched_grounding_batch_size=batched_grounding_batch_size,
        postprocess_batch_size=postprocess_batch_size,
        use_batched_grounding=use_batched_grounding,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
        fill_hole_area=0,
    )

    session_id = None
    try:
        # Everything after the predictor exists runs inside the try so that a
        # failure while preparing or starting the session still shuts it down.
        dynamic_mask = np.zeros((num_frames, height, width), dtype=np.bool_)
        video_pil = [Image.fromarray(frame) for frame in video]
        response = video_predictor.handle_request(
            {
                "type": "start_session",
                "resource_path": video_pil,
                "offload_video_to_cpu": not keep_video_frames_on_cuda,
                "cache_frame_outputs": cache_frame_outputs,
            }
        )
        session_id = response["session_id"]

        total_progress_steps = len(keywords) * num_frames
        for keyword_index, keyword in enumerate(keywords):
            progress_base = keyword_index * num_frames
            logger.info("SAM3 keyword currently being processed: '%s'", keyword)
            request = {"type": "add_prompt", "session_id": session_id, "frame_index": 0, "text": keyword}
            if preencoded_prompts is not None:
                request["preencoded_text_outputs"] = _bf16_prompt_payload(preencoded_prompts[keyword])
            # NOTE(review): this block was reconstructed from a flattened
            # source; confirm the autocast scope covers only this request and
            # not the propagation loop below.
            with _autocast_context():
                result = video_predictor.handle_request(request)
            first_outputs = result.get("outputs") if isinstance(result, dict) else None
            dynamic_mask[0] |= _sam3_outputs_to_binary_mask(first_outputs, height, width)
            if progress_callback is not None:
                progress_callback(progress_base, total_progress_steps)

            # Set once the predictor reports progress itself; from then on the
            # per-frame fallback reporting in the loop below is skipped.
            internal_progress_seen = False

            def model_progress_callback(done, total):
                nonlocal internal_progress_seen
                internal_progress_seen = True
                progress_callback(min(progress_base + int(done), total_progress_steps), total_progress_steps)

            stream_request = {
                "type": "propagate_in_video",
                "session_id": session_id,
                "propagation_direction": "forward",
                "start_frame_index": 0,
                "max_frame_num_to_track": num_frames,
            }
            if progress_callback is not None:
                stream_request["progress_callback"] = model_progress_callback

            propagated_frames = 0
            for result in video_predictor.handle_stream_request(stream_request):
                propagated_frames += 1
                if progress_callback is not None and not internal_progress_seen:
                    progress_callback(min(progress_base + propagated_frames, total_progress_steps), total_progress_steps)
                outputs = result["outputs"]
                dynamic_mask[result["frame_index"]] |= _sam3_outputs_to_binary_mask(outputs, height, width)
    finally:
        try:
            if session_id is not None:
                video_predictor.handle_request({"type": "close_session", "session_id": session_id})
        finally:
            # Shut down even when closing the session itself fails.
            video_predictor.shutdown()
            del video_predictor
            _cleanup()

    if fill_hole_area > 0:
        dynamic_mask = np.stack([fill_sam3_binary_mask_holes(mask, fill_hole_area) for mask in dynamic_mask], axis=0)
    return dynamic_mask