File size: 14,357 Bytes
7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import importlib
import importlib.util
import os
from pathlib import Path
from typing import Iterable

import numpy as np
import torch
from PIL import Image

from shared.utils import files_locator as fl
from .logger import get_logger
from .model.device_utils import accelerator_autocast, empty_accelerator_cache, get_accelerator_device, is_accelerator_device


# Directory containing this module; used as the last-resort search location for
# the checkpoint (next to this file) and the BPE vocabulary (under "assets/").
_PACKAGE_ROOT = Path(__file__).resolve().parent
# Sub-folder name tried first when asking files_locator for SAM3 assets.
_SAM3_FOLDER = "sam3"
# File names of the SAM3.1 bf16 checkpoint and of the text tokenizer's BPE vocabulary.
_SAM3_CHECKPOINT_NAME = "sam3.1_multiplex_bf16.safetensors"
_SAM3_BPE_NAME = "bpe_simple_vocab_16e6.txt.gz"
# Default for run_sam3_video(keep_video_frames_on_cuda=...); when False the video
# session is started with "offload_video_to_cpu" set to True.
KEEP_VIDEO_FRAMES_ON_CUDA = True
# Module-level text encoder kept by _encode_keyword_prompts when a caller passes
# keep_text_encoder_loaded=True, together with the (checkpoint_path, bpe_path) key
# it was built for. Reset by clear_sam3_text_encoder_cache().
_TEXT_ENCODER_CACHE = None
_TEXT_ENCODER_CACHE_KEY = None
logger = get_logger(__name__)


def _cleanup():
    """Run a full garbage-collection pass, then flush the accelerator memory cache.

    Collecting first drops unreachable Python objects that may still reference
    device tensors, presumably letting the cache flush return that memory.
    """
    from gc import collect

    collect()
    empty_accelerator_cache()


def _load_model_builder():
    """Import and return this package's ``model_builder`` module.

    Raises:
        FileNotFoundError: when ``model_builder`` itself does not exist in the
            package (the SAM3.1 sources are missing). The underlying
            ModuleNotFoundError is chained as ``__cause__`` for diagnosis.
        ModuleNotFoundError: re-raised unchanged when ``model_builder`` exists but
            one of *its own* imports is missing, so the real culprit is reported.
    """
    relative_name = ".model_builder"
    try:
        return importlib.import_module(relative_name, package=__package__)
    except ModuleNotFoundError as exc:
        if exc.name != importlib.util.resolve_name(relative_name, __package__):
            raise
        # Raise inside the handler and chain explicitly; raising after the
        # try/except (as before) discarded the original import error entirely.
        raise FileNotFoundError("SAM3.1 code was not found under preprocessing/sam3.") from exc


def _checkpoint_path():
    """Locate the SAM3.1 bf16 checkpoint and return ``(path, "sam3.1")``.

    files_locator is queried for "sam3/<name>", then "sam3.1/<name>", then the
    bare file name; the first hit wins. If none is found, the file is looked up
    next to this module.

    Raises:
        FileNotFoundError: if the checkpoint is found in none of those places.
    """
    search_names = (
        os.path.join(_SAM3_FOLDER, _SAM3_CHECKPOINT_NAME),
        os.path.join("sam3.1", _SAM3_CHECKPOINT_NAME),
        _SAM3_CHECKPOINT_NAME,
    )
    located = next(
        (hit for hit in (fl.locate_file(name, error_if_none=False) for name in search_names) if hit is not None),
        None,
    )
    if located is None:
        bundled = _PACKAGE_ROOT / _SAM3_CHECKPOINT_NAME
        if not bundled.is_file():
            raise FileNotFoundError("SAM3.1 bf16 safetensors checkpoint was not found by files_locator as sam3/sam3.1_multiplex_bf16.safetensors, sam3.1/sam3.1_multiplex_bf16.safetensors, or sam3.1_multiplex_bf16.safetensors, nor under preprocessing/sam3.")
        located = os.fspath(bundled)
    return located, "sam3.1"


def _bpe_path():
    """Locate the SAM3 BPE vocabulary file and return its path.

    files_locator is queried for "sam3/<name>", "sam3.1/<name>" and the bare
    file name, in that order. If none is found, the file is looked up in the
    "assets" directory next to this module.

    Raises:
        FileNotFoundError: if the vocabulary is found in none of those places.
    """
    for relative_name in (
        os.path.join(_SAM3_FOLDER, _SAM3_BPE_NAME),
        os.path.join("sam3.1", _SAM3_BPE_NAME),
        _SAM3_BPE_NAME,
    ):
        found = fl.locate_file(relative_name, error_if_none=False)
        if found is not None:
            return found
    bundled = _PACKAGE_ROOT.joinpath("assets", _SAM3_BPE_NAME)
    if not bundled.is_file():
        raise FileNotFoundError("SAM3 BPE vocabulary was not found by files_locator as sam3/bpe_simple_vocab_16e6.txt.gz, sam3.1/bpe_simple_vocab_16e6.txt.gz, or bpe_simple_vocab_16e6.txt.gz, nor under preprocessing/sam3/assets.")
    return os.fspath(bundled)


def _autocast_context():
    """Return a new autocast context manager from ``accelerator_autocast()``.

    All SAM3 inference in this module enters its autocast region through here.
    """
    context = accelerator_autocast()
    return context


def _bf16_prompt_payload(value):
    """Recursively cast the floating-point tensors inside ``value`` to bfloat16.

    Dicts, lists and tuples are rebuilt as plain dict/list/tuple with converted
    items. Non-floating tensors and all other objects are returned as-is.
    """
    if torch.is_tensor(value):
        if not value.is_floating_point():
            return value
        return value.to(dtype=torch.bfloat16)
    if isinstance(value, dict):
        return dict((key, _bf16_prompt_payload(item)) for key, item in value.items())
    if isinstance(value, (list, tuple)):
        converted = [_bf16_prompt_payload(item) for item in value]
        return converted if isinstance(value, list) else tuple(converted)
    return value


def _format_keywords_for_log(keywords: list[str]):
    """Render keywords as a comma-separated list of single-quoted items for log lines."""
    return ", ".join(map("'{}'".format, keywords))


def _to_numpy(value):
    """Convert a tensor or array-like to a host NumPy array.

    Tensors are detached and copied to the CPU first. bfloat16 tensors are
    upcast to float32, because NumPy has no bfloat16 dtype and
    ``Tensor.numpy()`` rejects them. Every other tensor dtype is converted
    unchanged. Non-tensors go through ``np.asarray``.
    """
    if torch.is_tensor(value):
        host = value.detach().cpu()
        if host.dtype == torch.bfloat16:
            host = host.float()
        return host.numpy()
    return np.asarray(value)


def _sam3_outputs_to_binary_mask(outputs, height: int, width: int):
    """Merge the binary masks in a SAM3 ``outputs`` mapping into one foreground mask.

    Args:
        outputs: mapping that may hold ``"out_binary_masks"`` (tensor or
            array-like), or ``None``.
        height, width: spatial size of the returned mask.

    Accepted mask layouts are ``(H, W)``, ``(N, H, W)``, ``(N, 1, H, W)`` and any
    higher rank whose last two axes are spatial. Every nonzero value counts as
    foreground. All masks are OR-ed together, and if their spatial size differs
    from ``(height, width)`` the union is resized with nearest-neighbour sampling.

    Returns:
        Boolean array of shape ``(height, width)``; all False when there is no mask.
    """
    if outputs is None or "out_binary_masks" not in outputs:
        return np.zeros((height, width), dtype=np.bool_)
    masks = _to_numpy(outputs["out_binary_masks"])
    if masks.size == 0:
        return np.zeros((height, width), dtype=np.bool_)
    if masks.ndim == 2:
        masks = masks[None, :, :]
    elif masks.ndim == 4 and masks.shape[1] == 1:
        masks = masks[:, 0]
    elif masks.ndim > 3:
        masks = masks.reshape((-1, *masks.shape[-2:]))
    # Binarize before any resampling. Casting raw values to uint8 first would
    # truncate fractional scores (0.5 -> 0) and wrap negatives, so the result
    # would disagree with the same-size path. Nearest-neighbour resampling only
    # selects source pixels, so resizing the union once is equivalent to OR-ing
    # individually resized masks, and cheaper.
    union = masks.astype(np.bool_).any(axis=0)
    if union.shape != (height, width):
        resized = Image.fromarray(union.astype(np.uint8)).resize((width, height), resample=Image.Resampling.NEAREST)
        union = np.asarray(resized) > 0
    return union


def resolve_sam3_grounding_batch_size(batch_size=None) -> int:
    """Return the batch size to use for batched grounding.

    An explicit ``batch_size`` is converted with ``int()`` and used as-is when
    positive. Otherwise a default is chosen: 2 without CUDA; with CUDA, 4 when
    device 0 reports at least 8 GiB of total memory, else 2.
    """
    requested = None if batch_size is None else int(batch_size)
    if requested is not None and requested > 0:
        return requested
    if torch.cuda.is_available():
        vram_gib = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
        return 4 if vram_gib >= 8 else 2
    return 2


def _encode_text_outputs(text_encoder, captions: list[str], device: torch.device):
    """Encode each caption separately and return the CPU copies concatenated.

    When ``device`` is an accelerator, the encoder is moved there in bfloat16
    and left there; the caller decides where it lives afterwards. Each caption
    is encoded in its own call under inference mode and autocast. Its outputs
    are copied to the CPU immediately, the device-side references are dropped,
    and memory is reclaimed before the next caption.

    Returns:
        dict with "language_features" (text memories concatenated on dim 1),
        "language_mask" (attention masks concatenated on dim 0) and
        "language_embeds" (text embeddings concatenated on dim 1).
    """
    if is_accelerator_device(device):
        text_encoder.to(device=device, dtype=torch.bfloat16)
    cpu_masks = []
    cpu_memories = []
    cpu_embeds = []
    for caption in captions:
        with torch.inference_mode(), _autocast_context():
            attn_mask, memory, embeds = text_encoder([caption], device=device)
        cpu_masks.append(attn_mask.detach().cpu())
        cpu_memories.append(memory.detach().cpu())
        cpu_embeds.append(embeds.detach().cpu())
        # Release the device-side outputs before reclaiming memory.
        del attn_mask, memory, embeds
        _cleanup()
    language_features = torch.cat(cpu_memories, dim=1)
    language_mask = torch.cat(cpu_masks, dim=0)
    language_embeds = torch.cat(cpu_embeds, dim=1)
    return {
        "language_features": language_features,
        "language_mask": language_mask,
        "language_embeds": language_embeds,
    }


def _encode_keyword_prompts(model_builder, checkpoint_path: str, bpe_path: str, keywords: list[str], keep_text_encoder_loaded: bool = False):
    """Pre-encode keywords with a standalone SAM3 text encoder.

    Each keyword is encoded together with the fixed captions "visual" and
    "geometric" through ``_encode_text_outputs``; the resulting tensors are on
    the CPU.

    Text-encoder handling:
      * If the module cache already holds an encoder for the same
        ``(checkpoint_path, bpe_path)``, it is reused whatever the value of
        ``keep_text_encoder_loaded``, instead of loading the weights again.
      * Otherwise a new encoder is built. It is stored in the cache only when
        ``keep_text_encoder_loaded`` is true.
      * A cached encoder is moved back to the CPU afterwards, and an uncached
        one is released. Both happen even if encoding raises.

    Args:
        model_builder: module providing ``build_sam3_text_encoder``.
        checkpoint_path, bpe_path: weights and vocabulary to build the encoder from.
        keywords: already-normalized keyword strings.
        keep_text_encoder_loaded: cache a newly built encoder for later calls.

    Returns:
        dict mapping each distinct keyword to its ``_encode_text_outputs`` result.
        Repeated keywords are encoded only once.
    """
    global _TEXT_ENCODER_CACHE, _TEXT_ENCODER_CACHE_KEY
    text_encoder = None
    encoder_is_cached = False
    device = get_accelerator_device()
    cache_key = (checkpoint_path, bpe_path)
    preencoded = {}
    try:
        if _TEXT_ENCODER_CACHE is not None and _TEXT_ENCODER_CACHE_KEY == cache_key:
            text_encoder = _TEXT_ENCODER_CACHE
            encoder_is_cached = True
        else:
            text_encoder = model_builder.build_sam3_text_encoder(checkpoint_path=checkpoint_path, bpe_path=bpe_path)
            if keep_text_encoder_loaded:
                _TEXT_ENCODER_CACHE = text_encoder
                _TEXT_ENCODER_CACHE_KEY = cache_key
                encoder_is_cached = True
        for keyword in keywords:
            if keyword in preencoded:
                # Duplicate keyword: same captions, same encoding; skip the extra passes.
                continue
            preencoded[keyword] = _encode_text_outputs(text_encoder, [keyword, "visual", "geometric"], device)
    finally:
        if encoder_is_cached:
            # The shared encoder stays alive, but only on the CPU.
            text_encoder.to("cpu")
        text_encoder = None
        _cleanup()
    return preencoded


def encode_sam3_keyword_prompts(keywords: Iterable[str], keep_text_encoder_loaded: bool = False):
    """Pre-encode text keywords for SAM3 prompting.

    Every keyword is converted with ``str()`` and stripped; blank entries are
    dropped. An empty result returns ``{}`` without touching any model.
    Otherwise the model builder, checkpoint and BPE vocabulary are resolved and
    the keywords are encoded by the standalone text encoder.

    Returns:
        dict mapping each keyword to its encoded text outputs.
    """
    cleaned = [text for text in (str(keyword).strip() for keyword in keywords) if text]
    if not cleaned:
        return {}
    builder = _load_model_builder()
    checkpoint, _version = _checkpoint_path()
    vocabulary = _bpe_path()
    return _encode_keyword_prompts(builder, checkpoint, vocabulary, cleaned, keep_text_encoder_loaded=keep_text_encoder_loaded)


def clear_sam3_text_encoder_cache():
    """Forget the module-level cached text encoder (if any) and reclaim memory."""
    global _TEXT_ENCODER_CACHE, _TEXT_ENCODER_CACHE_KEY
    _TEXT_ENCODER_CACHE, _TEXT_ENCODER_CACHE_KEY = None, None
    _cleanup()


def fill_sam3_binary_mask_holes(mask: np.ndarray, fill_hole_area: int):
    """Fill small holes in a binary mask with SAM3's tracker hole-filling utility.

    Args:
        mask: mask expected to hold only 0/1 (or False/True); the indexing below
            assumes a 2-D ``(height, width)`` array.
        fill_hole_area: forwarded as ``max_area``. It is converted with ``int()``,
            negative values are clamped to 0, and 0 disables filling.

    Returns:
        Boolean mask. When filling is disabled or the mask has no foreground,
        this is just ``mask`` cast to bool (no copy if it already is bool).
    """
    max_hole_area = max(0, int(fill_hole_area))
    if not max_hole_area or not np.any(mask):
        return mask.astype(np.bool_, copy=False)
    from .model.sam3_tracker_utils import fill_holes_in_mask_scores

    # Map {0, 1} to signed scores {-1, +1} and add two leading singleton axes
    # (presumably batch and channel); the filled scores are thresholded at 0.
    signed_scores = torch.from_numpy(mask.astype(np.float32, copy=False)).mul(2).sub(1)[None, None]
    filled = fill_holes_in_mask_scores(signed_scores, max_area=max_hole_area, fill_holes=True, remove_sprinkles=False)
    return filled[0, 0].numpy() > 0


def _load_predictor(
    model_builder=None,
    checkpoint_path=None,
    bpe_path=None,
    version=None,
    include_text_encoder=True,
    batched_grounding_batch_size=None,
    postprocess_batch_size=1,
    use_batched_grounding=True,
    trim_past_non_cond_mem_for_eval=True,
    fill_hole_area: int = 0,
    manual_model_loading: bool = False,
):
    """Build a SAM3 predictor, resolving whatever the caller did not supply.

    Any argument left as ``None`` is resolved independently: the model builder
    via ``_load_model_builder``, the checkpoint (and its version) via
    ``_checkpoint_path``, and the vocabulary via ``_bpe_path``. An explicitly
    passed ``checkpoint_path`` or ``version`` is never overridden.
    ``batched_grounding_batch_size`` goes through
    ``resolve_sam3_grounding_batch_size``. The predictor is always built with
    ``use_fa3=False``, ``use_rope_real=True``, ``compile=False`` and
    ``warm_up=False``; the remaining options are forwarded unchanged.
    """
    model_builder = model_builder or _load_model_builder()
    if checkpoint_path is None:
        checkpoint_path, resolved_version = _checkpoint_path()
        if version is None:
            version = resolved_version
    elif version is None:
        # _checkpoint_path() only ever resolves SAM3.1 checkpoints, so an explicit
        # path without a version is treated the same way.
        version = "sam3.1"
    bpe_path = bpe_path or _bpe_path()
    grounding_batch_size = resolve_sam3_grounding_batch_size(batched_grounding_batch_size)
    return model_builder.build_sam3_predictor(
        checkpoint_path=checkpoint_path,
        bpe_path=bpe_path,
        version=version,
        use_fa3=False,
        use_rope_real=True,
        compile=False,
        warm_up=False,
        include_text_encoder=include_text_encoder,
        postprocess_batch_size=postprocess_batch_size,
        use_batched_grounding=use_batched_grounding,
        batched_grounding_batch_size=grounding_batch_size,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
        fill_hole_area=fill_hole_area,
        manual_model_loading=manual_model_loading,
    )


def load_sam3_mask_predictor(
    *,
    include_text_encoder: bool = True,
    postprocess_batch_size: int = 1,
    use_batched_grounding: bool = True,
    batched_grounding_batch_size=None,
    trim_past_non_cond_mem_for_eval: bool = True,
    fill_hole_area: int = 0,
    manual_model_loading: bool = False,
):
    """Resolve the SAM3.1 assets and build a mask predictor from them.

    The model builder, checkpoint (with its version) and BPE vocabulary are
    located first. All keyword options are then forwarded to the predictor
    builder.
    """
    builder = _load_model_builder()
    checkpoint, resolved_version = _checkpoint_path()
    vocabulary = _bpe_path()
    return _load_predictor(
        model_builder=builder,
        checkpoint_path=checkpoint,
        bpe_path=vocabulary,
        version=resolved_version,
        include_text_encoder=include_text_encoder,
        batched_grounding_batch_size=batched_grounding_batch_size,
        postprocess_batch_size=postprocess_batch_size,
        use_batched_grounding=use_batched_grounding,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
        fill_hole_area=fill_hole_area,
        manual_model_loading=manual_model_loading,
    )


def _close_video_predictor(video_predictor, session_id):
    """Close ``session_id`` (when a session was opened) and shut ``video_predictor`` down.

    ``shutdown()`` runs even if closing the session raises.
    """
    try:
        if session_id is not None:
            video_predictor.handle_request({"type": "close_session", "session_id": session_id})
    finally:
        video_predictor.shutdown()


def run_sam3_video(
    video: np.ndarray,
    keywords: Iterable[str],
    *,
    include_text_encoder: bool = False,
    preencode_text: bool = True,
    batched_grounding_batch_size=None,
    postprocess_batch_size: int = 1,
    use_batched_grounding: bool = True,
    trim_past_non_cond_mem_for_eval: bool = True,
    keep_video_frames_on_cuda: bool = KEEP_VIDEO_FRAMES_ON_CUDA,
    cache_frame_outputs: bool = False,
    fill_hole_area: int = 0,
    progress_callback=None,
):
    """Segment every frame of ``video`` with SAM3 text prompts and return the union mask.

    For each keyword, a text prompt is added on frame 0 and propagated forward
    through the clip. All resulting masks, for all keywords, are OR-ed into a
    single boolean mask per frame.

    Args:
        video: frames, unpacked as ``(num_frames, height, width, channels)``; each
            frame is converted with ``PIL.Image.fromarray``.
        keywords: text prompts; entries are str()-converted and stripped, and
            blanks are dropped.
        include_text_encoder: build the predictor with its text encoder even when
            prompts are pre-encoded. It is always included when they are not.
        preencode_text: for SAM3.1 checkpoints, encode the keywords up front with
            a standalone text encoder.
        batched_grounding_batch_size, postprocess_batch_size,
        use_batched_grounding, trim_past_non_cond_mem_for_eval: forwarded to the
            predictor builder.
        keep_video_frames_on_cuda: when False, the session is started with
            ``offload_video_to_cpu`` set.
        cache_frame_outputs: forwarded in the session start request.
        fill_hole_area: if > 0, holes up to this area are filled in each output
            frame after tracking. The predictor itself is built with 0.
        progress_callback: optional ``callback(done, total)``, where ``total`` is
            ``len(keywords) * num_frames``.

    Returns:
        Boolean array of shape ``(num_frames, height, width)``.
    """
    keywords = [str(keyword).strip() for keyword in keywords if str(keyword).strip()]
    # Nothing to track: return before loading any model. A zero-frame video
    # would otherwise load everything just to start an empty session.
    if len(keywords) == 0 or video.shape[0] == 0:
        return np.zeros(video.shape[:3], dtype=np.bool_)
    # Validate the video layout before paying for model loading.
    num_frames, height, width, _ = video.shape

    model_builder = _load_model_builder()
    checkpoint_path, version = _checkpoint_path()
    bpe_path = _bpe_path()
    _cleanup()
    if version == "sam3.1" and preencode_text:
        logger.info("SAM3 encoding keywords before propagation: %s", _format_keywords_for_log(keywords))
        preencoded_prompts = _encode_keyword_prompts(model_builder, checkpoint_path, bpe_path, keywords)
    else:
        preencoded_prompts = None
    video_predictor = _load_predictor(
        model_builder,
        checkpoint_path,
        bpe_path,
        version,
        include_text_encoder=include_text_encoder or preencoded_prompts is None,
        batched_grounding_batch_size=batched_grounding_batch_size,
        postprocess_batch_size=postprocess_batch_size,
        use_batched_grounding=use_batched_grounding,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
        fill_hole_area=0,  # hole filling is applied once, after tracking
    )
    dynamic_mask = np.zeros((num_frames, height, width), dtype=np.bool_)
    session_id = None
    try:
        # Frame conversion and session start-up are inside the try so the
        # predictor is always shut down, even if they fail.
        video_pil = [Image.fromarray(frame) for frame in video]
        response = video_predictor.handle_request(
            {
                "type": "start_session",
                "resource_path": video_pil,
                "offload_video_to_cpu": not keep_video_frames_on_cuda,
                "cache_frame_outputs": cache_frame_outputs,
            }
        )
        session_id = response["session_id"]
        total_progress_steps = len(keywords) * num_frames
        for keyword_index, keyword in enumerate(keywords):
            progress_base = keyword_index * num_frames
            logger.info("SAM3 keyword currently being processed: '%s'", keyword)
            request = {"type": "add_prompt", "session_id": session_id, "frame_index": 0, "text": keyword}
            if preencoded_prompts is not None:
                request["preencoded_text_outputs"] = _bf16_prompt_payload(preencoded_prompts[keyword])
            with _autocast_context():
                result = video_predictor.handle_request(request)
                dynamic_mask[0] |= _sam3_outputs_to_binary_mask(result.get("outputs") if isinstance(result, dict) else None, height, width)
                if progress_callback is not None:
                    progress_callback(progress_base, total_progress_steps)
                # Set once the predictor reports progress itself; per-frame
                # fallback reporting below is then skipped.
                internal_progress_seen = False

                def model_progress_callback(done, total):
                    nonlocal internal_progress_seen
                    internal_progress_seen = True
                    progress_callback(min(progress_base + int(done), total_progress_steps), total_progress_steps)

                stream_request = {
                    "type": "propagate_in_video",
                    "session_id": session_id,
                    "propagation_direction": "forward",
                    "start_frame_index": 0,
                    "max_frame_num_to_track": num_frames,
                }
                if progress_callback is not None:
                    stream_request["progress_callback"] = model_progress_callback
                propagated_frames = 0
                for result in video_predictor.handle_stream_request(stream_request):
                    propagated_frames += 1
                    if progress_callback is not None and not internal_progress_seen:
                        progress_callback(min(progress_base + propagated_frames, total_progress_steps), total_progress_steps)
                    outputs = result["outputs"]
                    dynamic_mask[result["frame_index"]] |= _sam3_outputs_to_binary_mask(outputs, height, width)
    finally:
        try:
            _close_video_predictor(video_predictor, session_id)
        finally:
            del video_predictor
            _cleanup()
    if fill_hole_area > 0:
        dynamic_mask = np.stack([fill_sam3_binary_mask_holes(mask, fill_hole_area) for mask in dynamic_mask], axis=0)
    return dynamic_mask