# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved # pyre-unsafe """ Transformer decoder. Inspired from Pytorch's version, adds the pre-norm variant """ import math from functools import partial from typing import Any, Dict, List, Optional, Union import numpy as np import torch import torch.nn.functional as torchF from ..sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis, compute_axial_cis_real from ..sam.transformer import RoPEAttention from torch import nn, Tensor from torch.nn.attention import sdpa_kernel, SDPBackend from torchvision.ops.roi_align import RoIAlign from .act_ckpt_utils import activation_ckpt_wrapper from .box_ops import box_cxcywh_to_xyxy from .model_misc import ( chunked_ffn_forward, gen_sineembed_for_position, get_activation_fn, get_clones, inverse_sigmoid, MLP, ) class TransformerDecoderLayer(nn.Module): def __init__( self, activation: str, d_model: int, dim_feedforward: int, dropout: float, cross_attention: nn.Module, n_heads: int, use_text_cross_attention: bool = False, ): super().__init__() # cross attention self.cross_attn = cross_attention self.norm1 = nn.LayerNorm(d_model) # cross attention text self.use_text_cross_attention = use_text_cross_attention if use_text_cross_attention: self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=0.0) self.catext_norm = nn.LayerNorm(d_model) # self attention self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=0.0) self.norm2 = nn.LayerNorm(d_model) # ffn self.linear1 = nn.Linear(d_model, dim_feedforward) self.activation = get_activation_fn(activation) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm3 = nn.LayerNorm(d_model) @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, tgt): def _forward(x): return self.linear2(self.activation(self.linear1(x))) tgt2 = chunked_ffn_forward([tgt.clone()], self.linear1.out_features, self.linear1.in_features, _forward) tgt.add_(tgt2) del tgt2 tgt = self.norm3(tgt) return tgt def forward( self, # for tgt tgt: Optional[Tensor], # nq, bs, d_model tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos) tgt_key_padding_mask: Optional[Tensor] = None, tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 memory_text: Optional[Tensor] = None, # num_token, bs, d_model text_attention_mask: Optional[Tensor] = None, # bs, num_token # for memory memory: Optional[Tensor] = None, # hw, bs, d_model memory_key_padding_mask: Optional[Tensor] = None, memory_level_start_index: Optional[Tensor] = None, # num_levels memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 memory_pos: Optional[Tensor] = None, # pos for memory # sa self_attn_mask: Optional[Tensor] = None, # mask used for self-attention cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention # dac dac=False, dac_use_selfatt_ln=True, presence_token=None, # skip inside deformable attn identity=0.0, **kwargs, # additional kwargs for compatibility ): """ Input: - tgt/tgt_query_pos: nq, bs, d_model - """ # self attention if self.self_attn is not None: if dac: # we only apply self attention to the first half of the queries assert tgt.shape[0] % 2 == 0 num_o2o_queries = tgt.shape[0] // 2 tgt_o2o = tgt[:num_o2o_queries] tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries] tgt_o2m = tgt[num_o2o_queries:] else: tgt_o2o = tgt tgt_query_pos_o2o = tgt_query_pos if presence_token is not None: tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0) tgt_query_pos_o2o = torch.cat( [torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0 ) tgt_query_pos = torch.cat( [torch.zeros_like(presence_token), tgt_query_pos], dim=0 ) q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o) tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask, need_weights=False)[0] del q, k tgt_o2o.add_(tgt2) del tgt2 if dac: if not dac_use_selfatt_ln: tgt_o2o = self.norm2(tgt_o2o) tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0) # Recombine if dac_use_selfatt_ln: tgt = self.norm2(tgt) else: tgt = tgt_o2o tgt = self.norm2(tgt) if self.use_text_cross_attention: tgt2 = self.ca_text( self.with_pos_embed(tgt, tgt_query_pos), memory_text, memory_text, key_padding_mask=text_attention_mask, need_weights=False, )[0] tgt.add_(tgt2) del tgt2 tgt = self.catext_norm(tgt) if presence_token is not None: presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :]) cross_attn_mask = torch.cat( [presence_token_mask, cross_attn_mask], dim=1 ) # (bs*nheads, 1+nq, hw) # Cross attention to image tgt2 = self.cross_attn( query=self.with_pos_embed(tgt, tgt_query_pos), key=self.with_pos_embed(memory, memory_pos), value=memory, attn_mask=cross_attn_mask, key_padding_mask=( memory_key_padding_mask.transpose(0, 1) if memory_key_padding_mask is not None else None ), )[0] tgt.add_(tgt2) del tgt2 tgt = self.norm1(tgt) # ffn tgt = self.forward_ffn(tgt) presence_token_out = None if presence_token is not None: presence_token_out = tgt[:1] tgt = tgt[1:] return tgt, presence_token_out class TransformerDecoder(nn.Module): def __init__( self, d_model: int, frozen: bool, interaction_layer, layer, num_layers: int, num_queries: int, return_intermediate: bool, box_refine: bool = False, num_o2m_queries: int = 0, dac: bool = False, boxRPB: str = "none", # Experimental: An object query for SAM 2 tasks instance_query: bool = False, # Defines the number of additional instance queries, # 1 or 4 are the most likely for single vs multi mask support num_instances: int = 1, # Irrelevant if instance_query is False dac_use_selfatt_ln: bool = True, use_act_checkpoint: bool = False, compile_mode=None, presence_token: bool = False, clamp_presence_logits: bool = True, clamp_presence_logit_max_val: float = 10.0, use_normed_output_consistently: bool = True, separate_box_head_instance: bool = False, separate_norm_instance: bool = False, resolution: Optional[int] = None, stride: Optional[int] = None, ): super().__init__() self.d_model = d_model self.layers = get_clones(layer, num_layers) self.fine_layers = ( get_clones(interaction_layer, num_layers) if interaction_layer is not None else [None] * num_layers ) self.num_layers = num_layers self.num_queries = num_queries self.dac = dac if dac: self.num_o2m_queries = num_queries tot_num_queries = num_queries else: self.num_o2m_queries = num_o2m_queries tot_num_queries = num_queries + num_o2m_queries self.norm = nn.LayerNorm(d_model) self.return_intermediate = return_intermediate self.bbox_embed = MLP(d_model, d_model, 4, 3) self.query_embed = nn.Embedding(tot_num_queries, d_model) self.instance_query_embed = None self.instance_query_reference_points = None self.use_instance_query = instance_query self.num_instances = num_instances self.use_normed_output_consistently = use_normed_output_consistently self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None self.instance_bbox_embed = None if separate_box_head_instance: self.instance_bbox_embed = MLP(d_model, d_model, 4, 3) if instance_query: self.instance_query_embed = nn.Embedding(num_instances, d_model) self.box_refine = box_refine if box_refine: nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) self.reference_points = nn.Embedding(num_queries, 4) if instance_query: self.instance_reference_points = nn.Embedding(num_instances, 4) assert boxRPB in ["none", "log", "linear", "both"] self.boxRPB = boxRPB if boxRPB != "none": try: nheads = self.layers[0].cross_attn_image.num_heads except AttributeError: nheads = self.layers[0].cross_attn.num_heads n_input = 4 if boxRPB == "both" else 2 self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2) self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2) self.compilable_cord_cache = None self.compilable_stored_size = None self.coord_cache = {} if resolution is not None and stride is not None: feat_size = resolution // stride coords_h, coords_w = self._get_coords( feat_size, feat_size, device="cuda" ) self.compilable_cord_cache = (coords_h, coords_w) self.compilable_stored_size = (feat_size, feat_size) self.roi_pooler = ( RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True) if interaction_layer is not None else None ) if frozen: for p in self.parameters(): p.requires_grad_(False) self.presence_token = None self.clamp_presence_logits = clamp_presence_logits self.clamp_presence_logit_max_val = clamp_presence_logit_max_val if presence_token: self.presence_token = nn.Embedding(1, d_model) self.presence_token_head = MLP(d_model, d_model, 1, 3) self.presence_token_out_norm = nn.LayerNorm(d_model) self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2) self.dac_use_selfatt_ln = dac_use_selfatt_ln self.use_act_checkpoint = use_act_checkpoint nn.init.normal_(self.query_embed.weight.data) if self.instance_query_embed is not None: nn.init.normal_(self.instance_query_embed.weight.data) assert self.roi_pooler is None assert self.return_intermediate, "support return_intermediate only" assert self.box_refine, "support box refine only" self.compile_mode = compile_mode self.compiled = False # We defer compilation till after the first forward, to first warm-up the boxRPB cache # assign layer index to each layer so that some layers can decide what to do # based on which layer index they are (e.g. cross attention to memory bank only # in selected layers) for layer_idx, layer in enumerate(self.layers): layer.layer_idx = layer_idx @staticmethod def _get_coords(H, W, device): coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W return coords_h, coords_w def _get_rpb_matrix(self, reference_boxes, feat_size): H, W = feat_size boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1) bs, num_queries, _ = boxes_xyxy.shape if self.compilable_cord_cache is None: self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device) self.compilable_stored_size = (H, W) if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == ( H, W, ): # good, hitting the cache, will be compilable coords_h, coords_w = self.compilable_cord_cache else: # cache miss, will create compilation issue # In case we're not compiling, we'll still rely on the dict-based cache if feat_size not in self.coord_cache: self.coord_cache[feat_size] = self._get_coords( H, W, reference_boxes.device ) coords_h, coords_w = self.coord_cache[feat_size] assert coords_h.shape == (H,) assert coords_w.shape == (W,) deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2] deltas_y = deltas_y.view(bs, num_queries, -1, 2) deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2] deltas_x = deltas_x.view(bs, num_queries, -1, 2) if self.boxRPB in ["log", "both"]: deltas_x_log = deltas_x * 8 # normalize to -8, 8 deltas_x_log = ( torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8) ) deltas_y_log = deltas_y * 8 # normalize to -8, 8 deltas_y_log = ( torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / np.log2(8) ) if self.boxRPB == "log": deltas_x = deltas_x_log deltas_y = deltas_y_log else: deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1) deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1) if self.training: assert self.use_act_checkpoint, "activation ckpt not enabled in decoder" deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)( x=deltas_x, act_ckpt_enable=self.training and self.use_act_checkpoint, ) # bs, num_queries, W, n_heads deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)( x=deltas_y, act_ckpt_enable=self.training and self.use_act_checkpoint, ) # bs, num_queries, H, n_heads if not torch.compiler.is_dynamo_compiling(): assert deltas_x.shape[:3] == (bs, num_queries, W) assert deltas_y.shape[:3] == (bs, num_queries, H) B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze( 2 ) # bs, num_queries, H, W, n_heads if not torch.compiler.is_dynamo_compiling(): assert B.shape[:4] == (bs, num_queries, H, W) B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W B = B.contiguous() # memeff attn likes ordered strides if not torch.compiler.is_dynamo_compiling(): assert B.shape[2:] == (num_queries, H * W) return B def forward( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, reference_boxes: Optional[Tensor] = None, # num_queries, bs, 4 # for memory level_start_index: Optional[Tensor] = None, # num_levels spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 valid_ratios: Optional[Tensor] = None, # for text memory_text: Optional[Tensor] = None, text_attention_mask: Optional[Tensor] = None, # if `apply_dac` is None, it will default to `self.dac` apply_dac: Optional[bool] = None, is_instance_prompt=False, decoder_extra_kwargs: Optional[Dict] = None, # ROI memory bank obj_roi_memory_feat=None, obj_roi_memory_mask=None, box_head_trk=None, ): """ Input: - tgt: nq, bs, d_model - memory: \\sum{hw}, bs, d_model - pos: \\sum{hw}, bs, d_model - reference_boxes: nq, bs, 4 (after sigmoid) - valid_ratios/spatial_shapes: bs, nlevel, 2 """ if memory_mask is not None: assert self.boxRPB == "none", ( "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented" ) apply_dac = apply_dac if apply_dac is not None else self.dac if apply_dac: assert (tgt.shape[0] == self.num_queries) or ( self.use_instance_query and (tgt.shape[0] == self.instance_query_embed.num_embeddings) ) tgt = tgt.repeat(2, 1, 1) # note that we don't tile tgt_mask, since DAC doesn't # use self-attention in o2m queries if reference_boxes is not None: assert (reference_boxes.shape[0] == self.num_queries) or ( self.use_instance_query and ( reference_boxes.shape[0] == self.instance_query_embed.num_embeddings ) ) reference_boxes = reference_boxes.repeat(2, 1, 1) bs = tgt.shape[1] intermediate = [] intermediate_presence_logits = [] presence_feats = None if self.box_refine: if reference_boxes is None: # In this case, we're in a one-stage model, so we generate the reference boxes reference_boxes = self.reference_points.weight.unsqueeze(1) reference_boxes = ( reference_boxes.repeat(2, bs, 1) if apply_dac else reference_boxes.repeat(1, bs, 1) ) reference_boxes = reference_boxes.sigmoid() intermediate_ref_boxes = [reference_boxes] else: reference_boxes = None intermediate_ref_boxes = None output = tgt presence_out = None if self.presence_token is not None and is_instance_prompt is False: # expand to batch dim presence_out = self.presence_token.weight[None].expand(1, bs, -1) box_head = self.bbox_embed if is_instance_prompt and self.instance_bbox_embed is not None: box_head = self.instance_bbox_embed out_norm = self.norm if is_instance_prompt and self.instance_norm is not None: out_norm = self.instance_norm for layer_idx, layer in enumerate(self.layers): reference_points_input = ( reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :] ) # nq, bs, nlevel, 4 query_sine_embed = gen_sineembed_for_position( reference_points_input[:, :, 0, :], self.d_model ) # nq, bs, d_model*2 # conditional query query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model if self.boxRPB != "none" and reference_boxes is not None: assert spatial_shapes.shape[0] == 1, ( "only single scale support implemented" ) memory_mask = self._get_rpb_matrix( reference_boxes, (spatial_shapes[0, 0], spatial_shapes[0, 1]), ) memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W) if self.training: assert self.use_act_checkpoint, ( "Activation checkpointing not enabled in the decoder" ) output, presence_out = activation_ckpt_wrapper(layer)( tgt=output, tgt_query_pos=query_pos, tgt_query_sine_embed=query_sine_embed, tgt_key_padding_mask=tgt_key_padding_mask, tgt_reference_points=reference_points_input, memory_text=memory_text, text_attention_mask=text_attention_mask, memory=memory, memory_key_padding_mask=memory_key_padding_mask, memory_level_start_index=level_start_index, memory_spatial_shapes=spatial_shapes, memory_pos=pos, self_attn_mask=tgt_mask, cross_attn_mask=memory_mask, dac=apply_dac, dac_use_selfatt_ln=self.dac_use_selfatt_ln, presence_token=presence_out, **(decoder_extra_kwargs or {}), act_ckpt_enable=self.training and self.use_act_checkpoint, # ROI memory bank obj_roi_memory_feat=obj_roi_memory_feat, obj_roi_memory_mask=obj_roi_memory_mask, ) # iter update if self.box_refine: reference_before_sigmoid = inverse_sigmoid(reference_boxes) if box_head_trk is None: # delta_unsig = self.bbox_embed(output) if not self.use_normed_output_consistently: delta_unsig = box_head(output) else: delta_unsig = box_head(out_norm(output)) else: # box_head_trk use a separate box head for tracking queries Q_det = decoder_extra_kwargs["Q_det"] assert output.size(0) >= Q_det delta_unsig_det = self.bbox_embed(output[:Q_det]) delta_unsig_trk = box_head_trk(output[Q_det:]) delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0) outputs_unsig = delta_unsig + reference_before_sigmoid new_reference_points = outputs_unsig.sigmoid() reference_boxes = new_reference_points.detach() if layer_idx != self.num_layers - 1: intermediate_ref_boxes.append(new_reference_points) else: raise NotImplementedError("not implemented yet") intermediate.append(out_norm(output)) if self.presence_token is not None and is_instance_prompt is False: # norm, mlp head intermediate_layer_presence_logits = self.presence_token_head( self.presence_token_out_norm(presence_out) ).squeeze(-1) # clamp to mitigate numerical issues if self.clamp_presence_logits: intermediate_layer_presence_logits.clamp( min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val, ) intermediate_presence_logits.append(intermediate_layer_presence_logits) presence_feats = presence_out.clone() if not self.compiled and self.compile_mode is not None: self.forward = torch.compile( self.forward, mode=self.compile_mode, fullgraph=True ) self.compiled = True return ( torch.stack(intermediate), torch.stack(intermediate_ref_boxes), ( torch.stack(intermediate_presence_logits) if self.presence_token is not None and is_instance_prompt is False else None ), presence_feats, ) class TransformerEncoderCrossAttention(nn.Module): def __init__( self, d_model: int, frozen: bool, pos_enc_at_input: bool, layer, num_layers: int, use_act_checkpoint: bool = False, batch_first: bool = False, # Do layers expect batch first input? # which layers to exclude cross attention? default: None, means all # layers use cross attention remove_cross_attention_layers: Optional[list] = None, ): super().__init__() self.d_model = d_model self.layers = get_clones(layer, num_layers) self.num_layers = num_layers self.norm = nn.LayerNorm(d_model) self.pos_enc_at_input = pos_enc_at_input self.use_act_checkpoint = use_act_checkpoint if frozen: for p in self.parameters(): p.requires_grad_(False) self.batch_first = batch_first # remove cross attention layers if specified self.remove_cross_attention_layers = [False] * self.num_layers if remove_cross_attention_layers is not None: for i in remove_cross_attention_layers: self.remove_cross_attention_layers[i] = True assert len(self.remove_cross_attention_layers) == len(self.layers) for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers): if remove_cross_attention: self.layers[i].cross_attn_image = None self.layers[i].norm2 = None def forward( self, src, # self-attention inputs prompt, # cross-attention inputs src_mask: Optional[Tensor] = None, # att.mask for self-attention inputs prompt_mask: Optional[Tensor] = None, # att.mask for cross-attention inputs src_key_padding_mask: Optional[Tensor] = None, prompt_key_padding_mask: Optional[Tensor] = None, src_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs prompt_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs feat_sizes: Optional[list] = None, num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* ): if isinstance(src, list): assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list) assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1 src, src_key_padding_mask, src_pos = ( src[0], src_key_padding_mask[0], src_pos[0], ) assert src.shape[1] == prompt.shape[1], ( "Batch size must be the same for src and prompt" ) output = src if self.pos_enc_at_input and src_pos is not None: output = output.clone() output.add_(src_pos, alpha=0.1) if self.batch_first: # Convert to batch first output = output.transpose(0, 1) src_pos = src_pos.transpose(0, 1) prompt = prompt.transpose(0, 1) prompt_pos = prompt_pos.transpose(0, 1) for layer in self.layers: kwds = {} if isinstance(layer.cross_attn_image, RoPEAttention): kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} output = activation_ckpt_wrapper(layer)( tgt=output, memory=prompt, tgt_mask=src_mask, memory_mask=prompt_mask, tgt_key_padding_mask=src_key_padding_mask, memory_key_padding_mask=prompt_key_padding_mask, pos=prompt_pos, query_pos=src_pos, dac=False, attn_bias=None, act_ckpt_enable=self.training and self.use_act_checkpoint, **kwds, ) normed_output = self.norm(output) if self.batch_first: # Convert back to seq first normed_output = normed_output.transpose(0, 1) src_pos = src_pos.transpose(0, 1) return { "memory": normed_output, "pos_embed": src_pos, "padding_mask": src_key_padding_mask, } class TransformerDecoderLayerv1(nn.Module): def __init__( self, activation: str, cross_attention: nn.Module, d_model: int, dim_feedforward: int, dropout: float, pos_enc_at_attn: bool, pos_enc_at_cross_attn_keys: bool, pos_enc_at_cross_attn_queries: bool, pre_norm: bool, self_attention: nn.Module, ): super().__init__() self.d_model = d_model self.dim_feedforward = dim_feedforward self.self_attn = self_attention self.cross_attn_image = cross_attention # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.activation_str = activation self.activation = get_activation_fn(activation) self.pre_norm = pre_norm self.pos_enc_at_attn = pos_enc_at_attn self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys def forward_ffn(self, x): def _forward(x): return self.linear2(self.activation(self.linear1(x))) return chunked_ffn_forward(x, self.linear1.out_features, self.linear1.in_features, _forward) def forward_post( self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, **kwargs, ): q = k = tgt + query_pos if self.pos_enc_at_attn else tgt # Self attention tgt2 = self.self_attn( q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask, )[0] del q, k tgt.add_(tgt2) del tgt2 tgt = self.norm1(tgt) # Cross attention to image tgt2 = self.cross_attn_image( query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt, key=memory + pos if self.pos_enc_at_cross_attn_keys else memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, )[0] tgt.add_(tgt2) del tgt2 tgt = self.norm2(tgt) # FFN tgt2 = self.forward_ffn([tgt.clone()]) tgt.add_(tgt2) del tgt2 tgt = self.norm3(tgt) return tgt def forward_pre( self, tgt, memory, dac: bool = False, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, attn_bias: Optional[Tensor] = None, **kwargs, ): if dac: # we only apply self attention to the first half of the queries assert tgt.shape[0] % 2 == 0 other_tgt = tgt[tgt.shape[0] // 2 :] tgt = tgt[: tgt.shape[0] // 2] tgt2 = self.norm1(tgt) q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 tgt2 = self.self_attn( q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask, )[0] del q, k tgt.add_(tgt2) del tgt2 if dac: # Recombine tgt = torch.cat((tgt, other_tgt), dim=0) tgt2 = self.norm2(tgt) if self.pos_enc_at_cross_attn_queries: tgt2.add_(query_pos) tgt2 = self.cross_attn_image( query=tgt2, key=memory + pos if self.pos_enc_at_cross_attn_keys else memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, attn_bias=attn_bias, )[0] tgt.add_(tgt2) del tgt2 tgt2 = self.norm3(tgt) tgt2_list = [tgt2] del tgt2 tgt2 = self.forward_ffn(tgt2_list) tgt.add_(tgt2) del tgt2 return tgt def forward( self, tgt, memory, dac: bool = False, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, attn_bias: Optional[Tensor] = None, **kwds: Any, ) -> torch.Tensor: fwd_fn = self.forward_pre if self.pre_norm else self.forward_post return fwd_fn( tgt, memory, dac=dac, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, pos=pos, query_pos=query_pos, attn_bias=attn_bias, **kwds, ) class TransformerDecoderLayerv2(TransformerDecoderLayerv1): def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any): super().__init__(*args, **kwds) self.cross_attention_first = cross_attention_first def forward_ffn(self, x): def _forward(x): return self.linear2(self.activation(self.linear1(x))) return chunked_ffn_forward(x, self.linear1.out_features, self.linear1.in_features, _forward) def _forward_sa(self, tgt, query_pos): # Self-Attention tgt2 = self.norm1(tgt) q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 tgt2 = self.self_attn(q, k, v=tgt2) del q, k tgt.add_(tgt2) del tgt2 return tgt def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): if self.cross_attn_image is None: return tgt kwds = {} if num_k_exclude_rope > 0: assert isinstance(self.cross_attn_image, RoPEAttention) kwds = {"num_k_exclude_rope": num_k_exclude_rope} # Cross-Attention tgt2 = self.norm2(tgt) if self.pos_enc_at_cross_attn_queries: tgt2.add_(query_pos) tgt2 = self.cross_attn_image( q=tgt2, k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, v=memory, **kwds, ) tgt.add_(tgt2) del tgt2 return tgt def forward_pre( self, tgt, memory, dac: bool, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, attn_bias: Optional[Tensor] = None, num_k_exclude_rope: int = 0, ): assert dac is False assert tgt_mask is None assert memory_mask is None assert tgt_key_padding_mask is None assert memory_key_padding_mask is None assert attn_bias is None if self.cross_attention_first: tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) tgt = self._forward_sa(tgt, query_pos) else: tgt = self._forward_sa(tgt, query_pos) tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) # MLP tgt2 = self.norm3(tgt) tgt2_list = [tgt2] del tgt2 tgt2 = self.forward_ffn(tgt2_list) tgt.add_(tgt2) del tgt2 return tgt def forward(self, *args: Any, **kwds: Any) -> torch.Tensor: if self.pre_norm: return self.forward_pre(*args, **kwds) raise NotImplementedError def functional_attention( qkv_list: list, *, dropout: float, num_heads: int, num_k_exclude_rope: int = 0, freqs_cis: Optional[Tensor] = None, freqs_cis_real: Optional[Tensor] = None, freqs_cis_imag: Optional[Tensor] = None, use_fa3: bool = False, use_rope_real: bool = False, rope_k_repeat: bool, ) -> Union[Tensor, tuple[Tensor, Tensor]]: q, k, v = qkv_list qkv_list.clear() b, n, cq = q.shape _, m, ck = k.shape _, _, cv = v.shape if b > 1: assert k.shape[0] == v.shape[0] == b else: # broadcast-able assert k.shape[0] == b == 1, f"{q.shape=} {k.shape=} {v.shape=}" assert v.shape[1] == m q = q.reshape(b, n, num_heads, cq // num_heads).transpose(1, 2) k = k.reshape(b, m, num_heads, ck // num_heads).transpose(1, 2) v = v.reshape(v.shape[0], m, num_heads, cv // num_heads).transpose(1, 2) if freqs_cis is not None or freqs_cis_real is not None: num_k_rope = k.size(-2) - num_k_exclude_rope if use_rope_real: qk_list = [q, k[:, :, :num_k_rope]] del q q, k_rope = apply_rotary_enc_real( qk_list, freqs_cis_real=freqs_cis_real, freqs_cis_imag=freqs_cis_imag, repeat_freqs_k=rope_k_repeat, ) k[:, :, :num_k_rope] = k_rope del k_rope else: qk_list = [q, k[:, :, :num_k_rope]] del q q, k_rope = apply_rotary_enc( qk_list, freqs_cis=freqs_cis, repeat_freqs_k=rope_k_repeat, ) k[:, :, :num_k_rope] = k_rope del k_rope if use_fa3: from ..perflib.fa3 import flash_attn_func assert dropout == 0.0 out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)) del q, k, v else: with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]): out = torchF.scaled_dot_product_attention(q, k, v, dropout_p=dropout) del q, k, v out = out.transpose(1, 2) # B * n * n_heads * (cv // num_heads) out = out.reshape(b, n, cv) return out class SimpleRoPEAttention(nn.Module): """ Attention with rotary position encoding. This class is "simple" because it does not perform q/k/v/out projections. """ def __init__( self, d_model: int, num_heads: int, dropout_p: float, rope_theta=10000.0, # whether to repeat q rope to match k length # this is needed for cross-attention to memories rope_k_repeat=False, feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution use_fa3: bool = False, use_rope_real: bool = False, ): super().__init__() self.num_heads = num_heads compute_fn = compute_axial_cis_real if use_rope_real else compute_axial_cis self.compute_cis = partial(compute_fn, dim=d_model // num_heads, theta=rope_theta) device = None self.freqs_cis = None self.freqs_cis_real = None self.freqs_cis_imag = None self.use_fa3 = use_fa3 self.use_rope_real = use_rope_real if self.use_rope_real: self.freqs_cis_real, self.freqs_cis_imag = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1], device=device) else: self.freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1], device=device) self.rope_k_repeat = rope_k_repeat def forward( self, qkv_list: list, num_k_exclude_rope: int = 0, ) -> Union[Tensor, tuple[Tensor, Tensor]]: q = qkv_list[0] # Apply rotary position encoding w = h = math.sqrt(q.shape[-2]) freqs_len = self.freqs_cis_real.shape[0] if self.use_rope_real else self.freqs_cis.shape[0] if freqs_len != q.shape[-2]: if self.use_rope_real: self.freqs_cis_real, self.freqs_cis_imag = self.compute_cis(end_x=w, end_y=h, device=q.device) else: self.freqs_cis = self.compute_cis(end_x=w, end_y=h, device=q.device) elif self.use_rope_real: self.freqs_cis_real = self.freqs_cis_real.to(q.device) self.freqs_cis_imag = self.freqs_cis_imag.to(q.device) else: self.freqs_cis = self.freqs_cis.to(q.device) if q.shape[-2] != qkv_list[1].shape[-2]: assert self.rope_k_repeat del q out = functional_attention( qkv_list, dropout=0.0, num_heads=self.num_heads, num_k_exclude_rope=num_k_exclude_rope, freqs_cis=self.freqs_cis, freqs_cis_real=self.freqs_cis_real if self.use_rope_real else None, freqs_cis_imag=self.freqs_cis_imag if self.use_rope_real else None, use_fa3=self.use_fa3, use_rope_real=self.use_rope_real, rope_k_repeat=self.rope_k_repeat, ) return out class DecoupledTransformerDecoderLayerv2(nn.Module): def __init__( self, *, activation: str, d_model: int, num_heads: int, dim_feedforward: int, dropout: float, pos_enc_at_attn: bool, pos_enc_at_cross_attn_keys: bool, pos_enc_at_cross_attn_queries: bool, pre_norm: bool, cross_attention_first: bool = False, self_attention_rope: SimpleRoPEAttention, cross_attention_rope: SimpleRoPEAttention, ): super().__init__() self.d_model = d_model self.num_heads = num_heads self.dim_feedforward = dim_feedforward self.self_attn_q_proj = nn.Linear(d_model, d_model) self.self_attn_k_proj = nn.Linear(d_model, d_model) self.self_attn_v_proj = nn.Linear(d_model, d_model) self.self_attn_out_proj = nn.Linear(d_model, d_model) self.cross_attn_q_proj = nn.Linear(d_model, d_model) self.cross_attn_k_proj = nn.Linear(d_model, d_model) self.cross_attn_v_proj = nn.Linear(d_model, d_model) self.cross_attn_out_proj = nn.Linear(d_model, d_model) self.image_cross_attn_q_proj = nn.Linear(d_model, d_model) self.image_cross_attn_k_proj = nn.Linear(d_model, d_model) self.self_attention_rope = self_attention_rope self.cross_attention_rope = cross_attention_rope # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.activation_str = activation self.activation = get_activation_fn(activation) self.pre_norm = pre_norm self.pos_enc_at_attn = pos_enc_at_attn self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys self.cross_attention_first = cross_attention_first def forward_ffn(self, x): def _forward(x): return self.linear2(self.activation(self.linear1(x))) return chunked_ffn_forward(x, self.linear1.out_features, self.linear1.in_features, _forward) def _forward_sa(self, tgt, query_pos): # Self-Attention tgt2 = self.norm1(tgt) v = self.self_attn_v_proj(tgt2) if self.pos_enc_at_attn: tgt2.add_(query_pos) q = self.self_attn_q_proj(tgt2) k = self.self_attn_k_proj(tgt2) del tgt2 qkv_list = [q, k, v] del q, k, v out = self.self_attention_rope(qkv_list) tgt2 = self.self_attn_out_proj(out) del out tgt.add_(tgt2) del tgt2 return tgt def _forward_ca( self, *, image, tgt, memory_image, memory, query_pos, memory_image_pos, num_k_exclude_rope=0, ): kwds = {} if num_k_exclude_rope > 0: assert isinstance(self.cross_attention_rope, SimpleRoPEAttention) kwds = {"num_k_exclude_rope": num_k_exclude_rope} # Cross-Attention tgt2 = self.norm2(tgt) q = self.image_cross_attn_q_proj(image) q.add_(self.cross_attn_q_proj(tgt2)) if self.pos_enc_at_cross_attn_queries: q.add_(query_pos) k = self.image_cross_attn_k_proj(memory_image) k.add_(self.cross_attn_k_proj(memory)) if self.pos_enc_at_cross_attn_keys: k.add_(memory_image_pos) v = self.cross_attn_v_proj(memory) del tgt2 qkv_list = [q, k, v] del q, k, v out = self.cross_attention_rope(qkv_list, **kwds) tgt2 = self.cross_attn_out_proj(out) del out tgt.add_(tgt2) del tgt2 return tgt def forward_pre( self, *, image, tgt, memory_image, memory, image_pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, memory_image_pos: Optional[Tensor] = None, memory_pos: Optional[Tensor] = None, num_k_exclude_rope: int = 0, ): if self.cross_attention_first: tgt = self._forward_ca( image=image, tgt=tgt, memory_image=memory_image, memory=memory, query_pos=query_pos, memory_image_pos=memory_image_pos, num_k_exclude_rope=num_k_exclude_rope, ) tgt = self._forward_sa(tgt, query_pos) else: tgt = self._forward_sa(tgt, query_pos) tgt = self._forward_ca( image=image, tgt=tgt, memory_image=memory_image, memory=memory, query_pos=query_pos, memory_image_pos=memory_image_pos, num_k_exclude_rope=num_k_exclude_rope, ) # MLP tgt2 = self.norm3(tgt) tgt2_list = [tgt2] del tgt2 tgt2 = self.forward_ffn(tgt2_list) tgt.add_(tgt2) del tgt2 return image, tgt def forward(self, *args: Any, **kwds: Any) -> torch.Tensor: if self.pre_norm: return self.forward_pre(*args, **kwds) raise NotImplementedError class TransformerEncoderDecoupledCrossAttention(nn.Module): def __init__( self, d_model: int, frozen: bool, pos_enc_at_input: bool, layer, num_layers: int, use_act_checkpoint: bool = False, batch_first: bool = False, # Do layers expect batch first input? use_image_in_output: bool = True, ): super().__init__() self.d_model = d_model self.layers = get_clones(layer, num_layers) self.num_layers = num_layers self.norm = nn.LayerNorm(d_model) self.pos_enc_at_input = pos_enc_at_input self.use_act_checkpoint = use_act_checkpoint self.use_image_in_output = use_image_in_output if frozen: for p in self.parameters(): p.requires_grad_(False) self.batch_first = batch_first def forward( self, image: Tensor, # image features src: Tensor, # self-attention inputs; object features memory_image: Tensor, # cross-attention inputs; image features memory: Tensor, # cross-attention inputs; object features image_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs src_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs memory_image_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* ): assert src.shape[1] == memory.shape[1], ( "Batch size must be the same for src and memory" ) assert image.shape[1] == memory_image.shape[1], ( "Batch size must be the same for image and memory_image" ) output = src if self.pos_enc_at_input and src_pos is not None: output = output.clone() output.add_(src_pos, alpha=0.1) if self.batch_first: # Convert to batch first output = output.transpose(0, 1) src_pos = src_pos.transpose(0, 1) image = image.transpose(0, 1) memory = memory.transpose(0, 1) memory_pos = memory_pos.transpose(0, 1) memory_image = memory_image.transpose(0, 1) memory_image_pos = memory_image_pos.transpose(0, 1) if memory_image.shape[1] != memory.shape[1]: # Pad memory_image with zeros, to accodmate object pointers assert (memory.shape[1] - memory_image.shape[1]) == num_obj_ptr_tokens, ( f"{memory.shape[1]} - {memory_image.shape[1]} != {num_obj_ptr_tokens}" ) memory_image = torch.cat( [ memory_image, torch.zeros( (memory_image.shape[0], num_obj_ptr_tokens) + memory_image.shape[2:], dtype=memory_image.dtype, device=memory_image.device, ), ], dim=1, ) if memory_image_pos is not None: assert ( memory_pos.shape[1] - memory_image_pos.shape[1] ) == num_obj_ptr_tokens, ( f"{memory_pos.shape[1]} - {memory_image_pos.shape[1]} != {num_obj_ptr_tokens}" ) # tpos is the same in the batch anyway; note that memory_image always has a batch size of 1 memory_image_pos = torch.cat( [ memory_image_pos, memory_pos[0:1, -num_obj_ptr_tokens:], ], dim=1, ) for layer in self.layers: image, output = activation_ckpt_wrapper(layer)( image=image, tgt=output, memory_image=memory_image, memory=memory, image_pos=image_pos, query_pos=src_pos, memory_image_pos=memory_image_pos, memory_pos=memory_pos, num_k_exclude_rope=num_obj_ptr_tokens, act_ckpt_enable=self.training and self.use_act_checkpoint, ) if self.use_image_in_output: output = output.clone() output.add_(image) normed_output = self.norm(output) else: normed_output = self.norm(output) if self.batch_first: # Convert back to seq first normed_output = normed_output.transpose(0, 1) src_pos = src_pos.transpose(0, 1) return { "memory": normed_output, "pos_embed": src_pos, }