# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""
Transformer decoder.
Inspired by PyTorch's version; adds the pre-norm variant.
"""
import math
from functools import partial
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as torchF
from ..sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis, compute_axial_cis_real
from ..sam.transformer import RoPEAttention
from torch import nn, Tensor
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchvision.ops.roi_align import RoIAlign
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import (
chunked_ffn_forward,
gen_sineembed_for_position,
get_activation_fn,
get_clones,
inverse_sigmoid,
MLP,
)
class TransformerDecoderLayer(nn.Module):
    """Post-norm decoder layer: self-attention, optional text cross-attention,
    image cross-attention and a feed-forward network, each followed by a
    residual connection and a LayerNorm.

    Residual additions are performed in place (``Tensor.add_``) and
    temporaries are ``del``-ed eagerly to keep peak memory low.
    """

    def __init__(
        self,
        activation: str,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        cross_attention: nn.Module,
        n_heads: int,
        use_text_cross_attention: bool = False,
    ):
        super().__init__()
        # cross attention to the image memory (module supplied by the caller)
        self.cross_attn = cross_attention
        self.norm1 = nn.LayerNorm(d_model)

        # optional cross attention to text tokens
        self.use_text_cross_attention = use_text_cross_attention
        if use_text_cross_attention:
            self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=0.0)
            self.catext_norm = nn.LayerNorm(d_model)

        # self attention among the queries
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=0.0)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        # NOTE: `dropout` is currently unused; the attention modules above are
        # built with dropout=0.0 and the FFN has no dropout.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = get_activation_fn(activation)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Apply the FFN with a residual connection followed by ``norm3``.

        The FFN input is a clone of ``tgt`` handed over in a list to
        ``chunked_ffn_forward``; the residual is then added to ``tgt`` in place.
        """

        def _forward(x):
            return self.linear2(self.activation(self.linear1(x)))

        tgt2 = chunked_ffn_forward(
            [tgt.clone()], self.linear1.out_features, self.linear1.in_features, _forward
        )
        tgt.add_(tgt2)
        del tgt2
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: Optional[Tensor],  # nq, bs, d_model
        tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
        tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4
        memory_text: Optional[Tensor] = None,  # num_token, bs, d_model
        text_attention_mask: Optional[Tensor] = None,  # bs, num_token
        # for memory
        memory: Optional[Tensor] = None,  # hw, bs, d_model
        memory_key_padding_mask: Optional[Tensor] = None,
        memory_level_start_index: Optional[Tensor] = None,  # num_levels
        memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        memory_pos: Optional[Tensor] = None,  # pos for memory
        # sa
        self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
        cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
        # dac
        dac=False,
        dac_use_selfatt_ln=True,
        presence_token=None,
        # skip inside deformable attn
        identity=0.0,
        **kwargs,  # additional kwargs for compatibility
    ):
        """Run one decoder layer.

        Input:
            - tgt / tgt_query_pos: nq, bs, d_model
            - memory / memory_pos: hw, bs, d_model
            - memory_text: num_token, bs, d_model (only if text cross-attn is on)
            - cross_attn_mask: optional mask for the image cross-attention; when
              a presence token is given it is padded with a zero row.
            - dac: if True, only the first half of the queries goes through
              self-attention; the second half is concatenated back afterwards.
            - presence_token: optional extra token prepended to the queries
              (with a zero positional embedding) and split off at the end.
            ``tgt_query_sine_embed``, ``tgt_key_padding_mask``,
            ``tgt_reference_points``, ``memory_level_start_index``,
            ``memory_spatial_shapes`` and ``identity`` are unused by this layer.

        Returns:
            (tgt, presence_token_out): the updated queries, and the updated
            presence token row (``None`` when no presence token was given).

        NOTE: residual updates use ``add_``; when ``tgt_o2o`` aliases ``tgt``
        (no presence token) this writes into the caller's ``tgt`` tensor.
        """
        # self attention
        if self.self_attn is not None:
            if dac:
                # we only apply self attention to the first half of the queries
                assert tgt.shape[0] % 2 == 0
                num_o2o_queries = tgt.shape[0] // 2
                tgt_o2o = tgt[:num_o2o_queries]
                tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
                tgt_o2m = tgt[num_o2o_queries:]
            else:
                tgt_o2o = tgt
                tgt_query_pos_o2o = tgt_query_pos

            if presence_token is not None:
                # prepend the presence token; it gets a zero positional embedding
                tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
                tgt_query_pos_o2o = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0
                )
                tgt_query_pos = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos], dim=0
                )

            q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
            tgt2 = self.self_attn(
                q, k, tgt_o2o, attn_mask=self_attn_mask, need_weights=False
            )[0]
            del q, k
            tgt_o2o.add_(tgt2)
            del tgt2
            if dac:
                if not dac_use_selfatt_ln:
                    tgt_o2o = self.norm2(tgt_o2o)
                tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)  # Recombine
                if dac_use_selfatt_ln:
                    tgt = self.norm2(tgt)
            else:
                tgt = tgt_o2o
                tgt = self.norm2(tgt)

        if self.use_text_cross_attention:
            tgt2 = self.ca_text(
                self.with_pos_embed(tgt, tgt_query_pos),
                memory_text,
                memory_text,
                key_padding_mask=text_attention_mask,
                need_weights=False,
            )[0]
            tgt.add_(tgt2)
            del tgt2
            tgt = self.catext_norm(tgt)

        if presence_token is not None and cross_attn_mask is not None:
            # Pad the mask with a zero (no-op) row for the prepended presence
            # token. Without a mask there is nothing to pad (the original code
            # crashed here when cross_attn_mask was None).
            presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
            cross_attn_mask = torch.cat(
                [presence_token_mask, cross_attn_mask], dim=1
            )  # (bs*nheads, 1+nq, hw)

        # Cross attention to image
        tgt2 = self.cross_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos),
            key=self.with_pos_embed(memory, memory_pos),
            value=memory,
            attn_mask=cross_attn_mask,
            key_padding_mask=(
                memory_key_padding_mask.transpose(0, 1)
                if memory_key_padding_mask is not None
                else None
            ),
        )[0]
        tgt.add_(tgt2)
        del tgt2
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        presence_token_out = None
        if presence_token is not None:
            # split the presence token row back off the front
            presence_token_out = tgt[:1]
            tgt = tgt[1:]

        return tgt, presence_token_out
class TransformerDecoder(nn.Module):
    """Stack of decoder layers with iterative box refinement.

    Every layer refines ``reference_boxes`` (cxcywh, after sigmoid) through a
    box head. Optional features:
      - a box relative-position bias (``boxRPB``) used as the image
        cross-attention mask,
      - a learned presence token with its own logit head,
      - DAC, where the queries are duplicated into a self-attending first half
        and a second half that skips self-attention,
      - instance queries with an optional separate box head / output norm.
    """

    def __init__(
        self,
        d_model: int,
        frozen: bool,
        interaction_layer,
        layer,
        num_layers: int,
        num_queries: int,
        return_intermediate: bool,
        box_refine: bool = False,
        num_o2m_queries: int = 0,
        dac: bool = False,
        boxRPB: str = "none",
        # Experimental: An object query for SAM 2 tasks
        instance_query: bool = False,
        # Defines the number of additional instance queries,
        # 1 or 4 are the most likely for single vs multi mask support
        num_instances: int = 1,  # Irrelevant if instance_query is False
        dac_use_selfatt_ln: bool = True,
        use_act_checkpoint: bool = False,
        compile_mode=None,
        presence_token: bool = False,
        clamp_presence_logits: bool = True,
        clamp_presence_logit_max_val: float = 10.0,
        use_normed_output_consistently: bool = True,
        separate_box_head_instance: bool = False,
        separate_norm_instance: bool = False,
        resolution: Optional[int] = None,
        stride: Optional[int] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.fine_layers = (
            get_clones(interaction_layer, num_layers)
            if interaction_layer is not None
            else [None] * num_layers
        )
        self.num_layers = num_layers
        self.num_queries = num_queries
        self.dac = dac
        if dac:
            self.num_o2m_queries = num_queries
            tot_num_queries = num_queries
        else:
            self.num_o2m_queries = num_o2m_queries
            tot_num_queries = num_queries + num_o2m_queries
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate
        self.bbox_embed = MLP(d_model, d_model, 4, 3)
        self.query_embed = nn.Embedding(tot_num_queries, d_model)
        self.instance_query_embed = None
        self.instance_query_reference_points = None
        self.use_instance_query = instance_query
        self.num_instances = num_instances
        self.use_normed_output_consistently = use_normed_output_consistently

        self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
        self.instance_bbox_embed = None
        if separate_box_head_instance:
            self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
        if instance_query:
            self.instance_query_embed = nn.Embedding(num_instances, d_model)

        self.box_refine = box_refine
        if box_refine:
            # zero-init the last box-head layer so the first refinement is a no-op delta
            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
            self.reference_points = nn.Embedding(num_queries, 4)
            if instance_query:
                self.instance_reference_points = nn.Embedding(num_instances, 4)

        assert boxRPB in ["none", "log", "linear", "both"]
        self.boxRPB = boxRPB
        if boxRPB != "none":
            try:
                nheads = self.layers[0].cross_attn_image.num_heads
            except AttributeError:
                nheads = self.layers[0].cross_attn.num_heads
            n_input = 4 if boxRPB == "both" else 2
            self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
            self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)

        # Coordinate caches for the boxRPB matrix: a single "compilable" entry
        # for the expected feature size, plus a dict fallback for other sizes.
        self.compilable_cord_cache = None
        self.compilable_stored_size = None
        self.coord_cache = {}
        if resolution is not None and stride is not None:
            feat_size = resolution // stride
            # Fall back to CPU when CUDA is unavailable; `_get_rpb_matrix`
            # moves the cached coordinates to the boxes' device on first use.
            cache_device = "cuda" if torch.cuda.is_available() else "cpu"
            coords_h, coords_w = self._get_coords(
                feat_size, feat_size, device=cache_device
            )
            self.compilable_cord_cache = (coords_h, coords_w)
            self.compilable_stored_size = (feat_size, feat_size)

        self.roi_pooler = (
            RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
            if interaction_layer is not None
            else None
        )
        if frozen:
            # NOTE(review): modules created after this loop (presence token
            # modules, ref_point_head) are not frozen — confirm this is intended.
            for p in self.parameters():
                p.requires_grad_(False)

        self.presence_token = None
        self.clamp_presence_logits = clamp_presence_logits
        self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
        if presence_token:
            self.presence_token = nn.Embedding(1, d_model)
            self.presence_token_head = MLP(d_model, d_model, 1, 3)
            self.presence_token_out_norm = nn.LayerNorm(d_model)

        self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
        self.dac_use_selfatt_ln = dac_use_selfatt_ln
        self.use_act_checkpoint = use_act_checkpoint

        nn.init.normal_(self.query_embed.weight.data)
        if self.instance_query_embed is not None:
            nn.init.normal_(self.instance_query_embed.weight.data)

        assert self.roi_pooler is None
        assert self.return_intermediate, "support return_intermediate only"
        assert self.box_refine, "support box refine only"

        self.compile_mode = compile_mode
        self.compiled = False
        # We defer compilation till after the first forward, to first warm-up the boxRPB cache

        # assign layer index to each layer so that some layers can decide what to do
        # based on which layer index they are (e.g. cross attention to memory bank only
        # in selected layers)
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx

    @staticmethod
    def _get_coords(H, W, device):
        """Return float32 grids ``arange(H) / H`` and ``arange(W) / W`` on ``device``."""
        coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H
        coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W
        return coords_h, coords_w

    def _get_rpb_matrix(self, reference_boxes, feat_size):
        """Build the box relative-position bias used as cross-attention mask.

        Args:
            reference_boxes: (num_queries, bs, 4) boxes in cxcywh format,
                presumably normalized to [0, 1] like the coordinate grid.
            feat_size: (H, W) of the single-scale feature map; may be ints or
                0-d tensors.

        Returns:
            Tensor of shape (bs, n_heads, num_queries, H*W), contiguous.
        """
        H, W = feat_size
        boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1)
        bs, num_queries, _ = boxes_xyxy.shape

        if self.compilable_cord_cache is None:
            self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
            self.compilable_stored_size = (H, W)

        if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
            H,
            W,
        ):
            # good, hitting the cache, will be compilable
            coords_h, coords_w = self.compilable_cord_cache
            if coords_h.device != reference_boxes.device:
                # The cache may have been built (e.g. in __init__) on a different
                # device than the one the module now runs on; migrate it once.
                coords_h = coords_h.to(reference_boxes.device)
                coords_w = coords_w.to(reference_boxes.device)
                self.compilable_cord_cache = (coords_h, coords_w)
        else:
            # cache miss, will create compilation issue.
            # In case we're not compiling, we'll still rely on the dict-based cache.
            # This branch only runs when not compiling, so int() is safe here.
            # H/W may be 0-d tensors whose hash is identity-based: keying on them
            # directly would miss on every call and grow the dict without bound.
            cache_key = (int(H), int(W), reference_boxes.device)
            if cache_key not in self.coord_cache:
                self.coord_cache[cache_key] = self._get_coords(
                    H, W, reference_boxes.device
                )
            coords_h, coords_w = self.coord_cache[cache_key]

        assert coords_h.shape == (H,)
        assert coords_w.shape == (W,)

        # offsets of every row/column coordinate to the box edges:
        # y uses columns (y0, y1) = [1, 3], x uses columns (x0, x1) = [0, 2]
        deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
        deltas_y = deltas_y.view(bs, num_queries, -1, 2)
        deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
        deltas_x = deltas_x.view(bs, num_queries, -1, 2)

        if self.boxRPB in ["log", "both"]:
            deltas_x_log = deltas_x * 8  # normalize to -8, 8
            deltas_x_log = (
                torch.sign(deltas_x_log)
                * torch.log2(torch.abs(deltas_x_log) + 1.0)
                / np.log2(8)
            )
            deltas_y_log = deltas_y * 8  # normalize to -8, 8
            deltas_y_log = (
                torch.sign(deltas_y_log)
                * torch.log2(torch.abs(deltas_y_log) + 1.0)
                / np.log2(8)
            )
            if self.boxRPB == "log":
                deltas_x = deltas_x_log
                deltas_y = deltas_y_log
            else:
                deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
                deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)

        if self.training:
            assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
        deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)(
            x=deltas_x,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, W, n_heads
        deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)(
            x=deltas_y,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, H, n_heads

        if not torch.compiler.is_dynamo_compiling():
            assert deltas_x.shape[:3] == (bs, num_queries, W)
            assert deltas_y.shape[:3] == (bs, num_queries, H)

        B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
            2
        )  # bs, num_queries, H, W, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[:4] == (bs, num_queries, H, W)
        B = B.flatten(2, 3)  # bs, num_queries, H*W, n_heads
        B = B.permute(0, 3, 1, 2)  # bs, n_heads, num_queries, H*W
        B = B.contiguous()  # memeff attn likes ordered strides
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[2:] == (num_queries, H * W)
        return B

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        reference_boxes: Optional[Tensor] = None,  # num_queries, bs, 4
        # for memory
        level_start_index: Optional[Tensor] = None,  # num_levels
        spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        valid_ratios: Optional[Tensor] = None,
        # for text
        memory_text: Optional[Tensor] = None,
        text_attention_mask: Optional[Tensor] = None,
        # if `apply_dac` is None, it will default to `self.dac`
        apply_dac: Optional[bool] = None,
        is_instance_prompt=False,
        decoder_extra_kwargs: Optional[Dict] = None,
        # ROI memory bank
        obj_roi_memory_feat=None,
        obj_roi_memory_mask=None,
        box_head_trk=None,
    ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: \\sum{hw}, bs, d_model
            - pos: \\sum{hw}, bs, d_model
            - reference_boxes: nq, bs, 4 (after sigmoid)
            - valid_ratios/spatial_shapes: bs, nlevel, 2

        Returns:
            A 4-tuple:
              - stacked normalized outputs of every layer (num_layers, nq, bs, d_model),
              - stacked reference boxes: the initial boxes plus the refined boxes of
                every layer except the last (num_layers entries),
              - stacked presence logits per layer, or None when the presence token
                is disabled or ``is_instance_prompt`` is set,
              - features of the presence token after the last layer, or None.
        """
        if memory_mask is not None:
            assert self.boxRPB == "none", (
                "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
            )

        apply_dac = apply_dac if apply_dac is not None else self.dac
        if apply_dac:
            assert (tgt.shape[0] == self.num_queries) or (
                self.use_instance_query
                and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
            )
            tgt = tgt.repeat(2, 1, 1)
            # note that we don't tile tgt_mask, since DAC doesn't
            # use self-attention in o2m queries
            if reference_boxes is not None:
                assert (reference_boxes.shape[0] == self.num_queries) or (
                    self.use_instance_query
                    and (
                        reference_boxes.shape[0]
                        == self.instance_query_embed.num_embeddings
                    )
                )
                reference_boxes = reference_boxes.repeat(2, 1, 1)

        bs = tgt.shape[1]
        intermediate = []
        intermediate_presence_logits = []
        presence_feats = None

        if self.box_refine:
            if reference_boxes is None:
                # In this case, we're in a one-stage model, so we generate the reference boxes
                reference_boxes = self.reference_points.weight.unsqueeze(1)
                reference_boxes = (
                    reference_boxes.repeat(2, bs, 1)
                    if apply_dac
                    else reference_boxes.repeat(1, bs, 1)
                )
                reference_boxes = reference_boxes.sigmoid()
            intermediate_ref_boxes = [reference_boxes]
        else:
            reference_boxes = None
            intermediate_ref_boxes = None

        output = tgt
        presence_out = None
        if self.presence_token is not None and is_instance_prompt is False:
            # expand to batch dim
            presence_out = self.presence_token.weight[None].expand(1, bs, -1)

        box_head = self.bbox_embed
        if is_instance_prompt and self.instance_bbox_embed is not None:
            box_head = self.instance_bbox_embed

        out_norm = self.norm
        if is_instance_prompt and self.instance_norm is not None:
            out_norm = self.instance_norm

        for layer_idx, layer in enumerate(self.layers):
            reference_points_input = (
                reference_boxes[:, :, None]
                * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
            )  # nq, bs, nlevel, 4

            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :], self.d_model
            )  # nq, bs, d_model*2

            # conditional query
            query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, d_model

            if self.boxRPB != "none" and reference_boxes is not None:
                assert spatial_shapes.shape[0] == 1, (
                    "only single scale support implemented"
                )
                memory_mask = self._get_rpb_matrix(
                    reference_boxes,
                    (spatial_shapes[0, 0], spatial_shapes[0, 1]),
                )
                memory_mask = memory_mask.flatten(0, 1)  # (bs*n_heads, nq, H*W)

            if self.training:
                assert self.use_act_checkpoint, (
                    "Activation checkpointing not enabled in the decoder"
                )
            output, presence_out = activation_ckpt_wrapper(layer)(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,
                memory_text=memory_text,
                text_attention_mask=text_attention_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
                dac=apply_dac,
                dac_use_selfatt_ln=self.dac_use_selfatt_ln,
                presence_token=presence_out,
                **(decoder_extra_kwargs or {}),
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                # ROI memory bank
                obj_roi_memory_feat=obj_roi_memory_feat,
                obj_roi_memory_mask=obj_roi_memory_mask,
            )

            # iter update
            if self.box_refine:
                reference_before_sigmoid = inverse_sigmoid(reference_boxes)
                if box_head_trk is None:
                    if not self.use_normed_output_consistently:
                        delta_unsig = box_head(output)
                    else:
                        delta_unsig = box_head(out_norm(output))
                else:
                    # box_head_trk use a separate box head for tracking queries:
                    # the first Q_det queries use bbox_embed, the rest box_head_trk
                    Q_det = decoder_extra_kwargs["Q_det"]
                    assert output.size(0) >= Q_det
                    delta_unsig_det = self.bbox_embed(output[:Q_det])
                    delta_unsig_trk = box_head_trk(output[Q_det:])
                    delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                # the next layer refines from detached boxes
                reference_boxes = new_reference_points.detach()
                if layer_idx != self.num_layers - 1:
                    intermediate_ref_boxes.append(new_reference_points)
            else:
                raise NotImplementedError("not implemented yet")

            intermediate.append(out_norm(output))
            if self.presence_token is not None and is_instance_prompt is False:
                # norm, mlp head
                intermediate_layer_presence_logits = self.presence_token_head(
                    self.presence_token_out_norm(presence_out)
                ).squeeze(-1)

                # clamp to mitigate numerical issues.
                # Tensor.clamp is out-of-place: the result must be assigned,
                # otherwise the clamp is a no-op.
                if self.clamp_presence_logits:
                    intermediate_layer_presence_logits = (
                        intermediate_layer_presence_logits.clamp(
                            min=-self.clamp_presence_logit_max_val,
                            max=self.clamp_presence_logit_max_val,
                        )
                    )

                intermediate_presence_logits.append(intermediate_layer_presence_logits)
                presence_feats = presence_out.clone()

        if not self.compiled and self.compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=self.compile_mode, fullgraph=True
            )
            self.compiled = True

        return (
            torch.stack(intermediate),
            torch.stack(intermediate_ref_boxes),
            (
                torch.stack(intermediate_presence_logits)
                if self.presence_token is not None and is_instance_prompt is False
                else None
            ),
            presence_feats,
        )
class TransformerEncoderCrossAttention(nn.Module):
    """Stack of layers that self-attend over ``src`` and cross-attend to
    ``prompt``, followed by a final LayerNorm.

    Cross-attention can be disabled per layer via
    ``remove_cross_attention_layers``; for those layers ``cross_attn_image``
    and ``norm2`` are set to None.
    """

    def __init__(
        self,
        d_model: int,
        frozen: bool,
        pos_enc_at_input: bool,
        layer,
        num_layers: int,
        use_act_checkpoint: bool = False,
        batch_first: bool = False,  # Do layers expect batch first input?
        # which layers to exclude cross attention? default: None, means all
        # layers use cross attention
        remove_cross_attention_layers: Optional[list] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.use_act_checkpoint = use_act_checkpoint

        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)

        self.batch_first = batch_first

        # remove cross attention layers if specified
        self.remove_cross_attention_layers = [False] * self.num_layers
        if remove_cross_attention_layers is not None:
            for i in remove_cross_attention_layers:
                self.remove_cross_attention_layers[i] = True
        assert len(self.remove_cross_attention_layers) == len(self.layers)

        for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers):
            if remove_cross_attention:
                self.layers[i].cross_attn_image = None
                self.layers[i].norm2 = None

    def forward(
        self,
        src,  # self-attention inputs
        prompt,  # cross-attention inputs
        src_mask: Optional[Tensor] = None,  # att.mask for self-attention inputs
        prompt_mask: Optional[Tensor] = None,  # att.mask for cross-attention inputs
        src_key_padding_mask: Optional[Tensor] = None,
        prompt_key_padding_mask: Optional[Tensor] = None,
        src_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        prompt_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        feat_sizes: Optional[list] = None,
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
        """Run all layers and return a dict with keys
        ``"memory"`` (normalized output, sequence-first), ``"pos_embed"``
        (``src_pos``) and ``"padding_mask"`` (``src_key_padding_mask``).

        ``src``, ``src_key_padding_mask`` and ``src_pos`` may each be given as a
        single-element list. ``src`` is not modified.
        """
        if isinstance(src, list):
            assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list)
            assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1
            src, src_key_padding_mask, src_pos = (
                src[0],
                src_key_padding_mask[0],
                src_pos[0],
            )

        assert src.shape[1] == prompt.shape[1], (
            "Batch size must be the same for src and prompt"
        )

        # Always work on a private copy: the layers in this file (e.g.
        # TransformerDecoderLayerv1/v2) apply residual updates in place via
        # `add_`, which would otherwise write into the caller's `src`.
        output = src.clone()
        if self.pos_enc_at_input and src_pos is not None:
            output.add_(src_pos, alpha=0.1)

        if self.batch_first:
            # Convert to batch first (positional encodings are optional)
            output = output.transpose(0, 1)
            if src_pos is not None:
                src_pos = src_pos.transpose(0, 1)
            prompt = prompt.transpose(0, 1)
            if prompt_pos is not None:
                prompt_pos = prompt_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = activation_ckpt_wrapper(layer)(
                tgt=output,
                memory=prompt,
                tgt_mask=src_mask,
                memory_mask=prompt_mask,
                tgt_key_padding_mask=src_key_padding_mask,
                memory_key_padding_mask=prompt_key_padding_mask,
                pos=prompt_pos,
                query_pos=src_pos,
                dac=False,
                attn_bias=None,
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                **kwds,
            )

        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            if src_pos is not None:
                src_pos = src_pos.transpose(0, 1)

        return {
            "memory": normed_output,
            "pos_embed": src_pos,
            "padding_mask": src_key_padding_mask,
        }
class TransformerDecoderLayerv1(nn.Module):
    """Decoder layer with self-attention, image cross-attention and an FFN.

    Supports post-norm (``forward_post``) and pre-norm (``forward_pre``)
    orderings, selected by ``pre_norm``. The attention modules are injected;
    ``cross_attn_image`` is called with ``query``/``key``/``value`` keyword
    arguments. Residual updates are applied in place with ``add_``.
    """

    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        pre_norm: bool,
        self_attention: nn.Module,
    ):
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Implementation of Feedforward model
        # NOTE: `dropout` is currently unused in this layer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)
        self.pre_norm = pre_norm

        # where to add positional encodings
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def forward_ffn(self, x):
        """Apply linear2(activation(linear1(.))) via ``chunked_ffn_forward``.

        ``x`` is a list holding the input tensor (callers pass ``[tensor]``).
        """

        def _forward(x):
            return self.linear2(self.activation(self.linear1(x)))

        return chunked_ffn_forward(
            x, self.linear1.out_features, self.linear1.in_features, _forward
        )

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        **kwargs,
    ):
        """Post-norm ordering: sublayer -> residual -> LayerNorm.

        ``dac`` and ``attn_bias`` passed by ``forward`` land in ``kwargs`` and
        are ignored here.
        """
        q = k = tgt + query_pos if self.pos_enc_at_attn else tgt

        # Self attention
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        del q, k
        tgt.add_(tgt2)
        del tgt2
        tgt = self.norm1(tgt)

        # Cross attention to image
        tgt2 = self.cross_attn_image(
            query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt.add_(tgt2)
        del tgt2
        tgt = self.norm2(tgt)

        # FFN
        tgt2 = self.forward_ffn([tgt.clone()])
        tgt.add_(tgt2)
        del tgt2
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwargs,
    ):
        """Pre-norm ordering: LayerNorm -> sublayer -> residual.

        With ``dac=True`` only the first half of the queries goes through
        self-attention (using the matching first half of ``query_pos``); the
        full set then goes through cross-attention and the FFN.
        """
        self_attn_query_pos = query_pos
        if dac:
            # we only apply self attention to the first half of the queries
            assert tgt.shape[0] % 2 == 0
            num_o2o = tgt.shape[0] // 2
            other_tgt = tgt[num_o2o:]
            tgt = tgt[:num_o2o]
            if query_pos is not None:
                # Keep the positional embedding aligned with the halved queries
                # (slicing is a no-op for a broadcastable size-1 first dim).
                self_attn_query_pos = query_pos[:num_o2o]

        tgt2 = self.norm1(tgt)
        q = k = tgt2 + self_attn_query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt2,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        del q, k
        tgt.add_(tgt2)
        del tgt2

        if dac:
            # Recombine
            tgt = torch.cat((tgt, other_tgt), dim=0)

        tgt2 = self.norm2(tgt)
        if self.pos_enc_at_cross_attn_queries:
            tgt2.add_(query_pos)
        tgt2 = self.cross_attn_image(
            query=tgt2,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            attn_bias=attn_bias,
        )[0]
        tgt.add_(tgt2)
        del tgt2

        tgt2 = self.norm3(tgt)
        # hand the normed tensor over through a list so the FFN helper holds
        # the only reference to it
        tgt2_list = [tgt2]
        del tgt2
        tgt2 = self.forward_ffn(tgt2_list)
        tgt.add_(tgt2)
        del tgt2
        return tgt

    def forward(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwds: Any,
    ) -> torch.Tensor:
        """Dispatch to ``forward_pre`` or ``forward_post`` based on ``pre_norm``."""
        fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
        return fwd_fn(
            tgt,
            memory,
            dac=dac,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
            query_pos=query_pos,
            attn_bias=attn_bias,
            **kwds,
        )
class TransformerDecoderLayerv2(TransformerDecoderLayerv1):
    """Pre-norm-only variant whose attention modules take ``q``/``k``/``v``
    keyword arguments and return the attended tensor directly.

    ``cross_attention_first`` swaps the order of the self-attention and
    cross-attention sublayers. Masks, attention bias and DAC are not supported.
    """

    def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any):
        super().__init__(*args, **kwds)
        self.cross_attention_first = cross_attention_first

    def forward_ffn(self, x):
        """Run the two-layer MLP through ``chunked_ffn_forward``; ``x`` is a
        one-element list holding the input tensor."""

        def _mlp(inp):
            hidden = self.activation(self.linear1(inp))
            return self.linear2(hidden)

        in_dim = self.linear1.in_features
        hidden_dim = self.linear1.out_features
        return chunked_ffn_forward(x, hidden_dim, in_dim, _mlp)

    def _forward_sa(self, tgt, query_pos):
        """Pre-norm self-attention with an in-place residual update of ``tgt``."""
        normed = self.norm1(tgt)
        if self.pos_enc_at_attn:
            q = k = normed + query_pos
        else:
            q = k = normed
        attn_out = self.self_attn(q, k, v=normed)
        del q, k
        tgt.add_(attn_out)
        del attn_out
        return tgt

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
        """Pre-norm cross-attention to ``memory`` with an in-place residual.

        Returns ``tgt`` untouched when ``cross_attn_image`` is None (e.g. the
        layer had its cross-attention removed by TransformerEncoderCrossAttention).
        """
        if self.cross_attn_image is None:
            return tgt

        extra = {}
        if num_k_exclude_rope > 0:
            # only RoPEAttention understands `num_k_exclude_rope`
            assert isinstance(self.cross_attn_image, RoPEAttention)
            extra["num_k_exclude_rope"] = num_k_exclude_rope

        query = self.norm2(tgt)
        if self.pos_enc_at_cross_attn_queries:
            query.add_(query_pos)
        key = memory + pos if self.pos_enc_at_cross_attn_keys else memory
        attn_out = self.cross_attn_image(q=query, k=key, v=memory, **extra)
        del query, key
        tgt.add_(attn_out)
        del attn_out
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ):
        """Self-attention and cross-attention (order set by
        ``cross_attention_first``), then a pre-norm FFN with residual."""
        # this variant supports neither DAC nor any kind of mask / bias
        assert dac is False
        for unsupported in (
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            attn_bias,
        ):
            assert unsupported is None

        if not self.cross_attention_first:
            tgt = self._forward_sa(tgt, query_pos)
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        else:
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
            tgt = self._forward_sa(tgt, query_pos)

        # MLP: the list is the only holder of the normed tensor
        ffn_out = self.forward_ffn([self.norm3(tgt)])
        tgt.add_(ffn_out)
        del ffn_out
        return tgt

    def forward(self, *args: Any, **kwds: Any) -> torch.Tensor:
        """Only the pre-norm path is implemented."""
        if not self.pre_norm:
            raise NotImplementedError
        return self.forward_pre(*args, **kwds)
def functional_attention(
qkv_list: list,
*,
dropout: float,
num_heads: int,
num_k_exclude_rope: int = 0,
freqs_cis: Optional[Tensor] = None,
freqs_cis_real: Optional[Tensor] = None,
freqs_cis_imag: Optional[Tensor] = None,
use_fa3: bool = False,
use_rope_real: bool = False,
rope_k_repeat: bool,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
q, k, v = qkv_list
qkv_list.clear()
b, n, cq = q.shape
_, m, ck = k.shape
_, _, cv = v.shape
if b > 1:
assert k.shape[0] == v.shape[0] == b
else:
# broadcast-able
assert k.shape[0] == b == 1, f"{q.shape=} {k.shape=} {v.shape=}"
assert v.shape[1] == m
q = q.reshape(b, n, num_heads, cq // num_heads).transpose(1, 2)
k = k.reshape(b, m, num_heads, ck // num_heads).transpose(1, 2)
v = v.reshape(v.shape[0], m, num_heads, cv // num_heads).transpose(1, 2)
if freqs_cis is not None or freqs_cis_real is not None:
num_k_rope = k.size(-2) - num_k_exclude_rope
if use_rope_real:
qk_list = [q, k[:, :, :num_k_rope]]
del q
q, k_rope = apply_rotary_enc_real(
qk_list,
freqs_cis_real=freqs_cis_real,
freqs_cis_imag=freqs_cis_imag,
repeat_freqs_k=rope_k_repeat,
)
k[:, :, :num_k_rope] = k_rope
del k_rope
else:
qk_list = [q, k[:, :, :num_k_rope]]
del q
q, k_rope = apply_rotary_enc(
qk_list,
freqs_cis=freqs_cis,
repeat_freqs_k=rope_k_repeat,
)
k[:, :, :num_k_rope] = k_rope
del k_rope
if use_fa3:
from ..perflib.fa3 import flash_attn_func
assert dropout == 0.0
out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
del q, k, v
else:
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]):
out = torchF.scaled_dot_product_attention(q, k, v, dropout_p=dropout)
del q, k, v
out = out.transpose(1, 2) # B * n * n_heads * (cv // num_heads)
out = out.reshape(b, n, cv)
return out
class SimpleRoPEAttention(nn.Module):
    """Rotary-position-encoded attention with no projection layers of its own.

    Callers pass already-projected ``[q, k, v]``.  The module keeps an axial
    RoPE table sized to the query length -- a single table, or a real/imag
    pair when ``use_rope_real`` -- and delegates the attention itself to
    ``functional_attention``.

    Note: ``dropout_p`` is accepted but neither stored nor used; attention is
    always invoked with ``dropout=0.0``.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout_p: float,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
        use_fa3: bool = False,
        use_rope_real: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.use_fa3 = use_fa3
        self.use_rope_real = use_rope_real
        table_fn = compute_axial_cis_real if use_rope_real else compute_axial_cis
        self.compute_cis = partial(table_fn, dim=d_model // num_heads, theta=rope_theta)
        self.freqs_cis = None
        self.freqs_cis_real = None
        self.freqs_cis_imag = None
        end_x, end_y = feat_sizes[0], feat_sizes[1]
        self._set_freqs(self.compute_cis(end_x=end_x, end_y=end_y, device=None))
        self.rope_k_repeat = rope_k_repeat

    def _set_freqs(self, table):
        """Store a table returned by ``compute_cis`` in the matching attribute(s)."""
        if self.use_rope_real:
            self.freqs_cis_real, self.freqs_cis_imag = table
        else:
            self.freqs_cis = table

    def forward(
        self,
        qkv_list: list,
        num_k_exclude_rope: int = 0,
    ) -> Union[Tensor, tuple[Tensor, Tensor]]:
        """Refresh the cached RoPE table if needed, then run the attention.

        The table is rebuilt for a square ``sqrt(n) x sqrt(n)`` grid when its
        length differs from the query length ``n`` (``qkv_list[0].shape[-2]``);
        otherwise it is only moved to the query's device.  A key length that
        differs from ``n`` requires ``rope_k_repeat``.
        """
        n_queries = qkv_list[0].shape[-2]
        cached = self.freqs_cis_real if self.use_rope_real else self.freqs_cis
        if cached.shape[0] == n_queries:
            device = qkv_list[0].device
            if self.use_rope_real:
                self.freqs_cis_real = self.freqs_cis_real.to(device)
                self.freqs_cis_imag = self.freqs_cis_imag.to(device)
            else:
                self.freqs_cis = self.freqs_cis.to(device)
        else:
            side = math.sqrt(n_queries)
            self._set_freqs(self.compute_cis(end_x=side, end_y=side, device=qkv_list[0].device))

        if n_queries != qkv_list[1].shape[-2]:
            assert self.rope_k_repeat

        return functional_attention(
            qkv_list,
            dropout=0.0,
            num_heads=self.num_heads,
            num_k_exclude_rope=num_k_exclude_rope,
            freqs_cis=self.freqs_cis,
            freqs_cis_real=self.freqs_cis_real if self.use_rope_real else None,
            freqs_cis_imag=self.freqs_cis_imag if self.use_rope_real else None,
            use_fa3=self.use_fa3,
            use_rope_real=self.use_rope_real,
            rope_k_repeat=self.rope_k_repeat,
        )
class DecoupledTransformerDecoderLayerv2(nn.Module):
    """Pre-norm decoder layer with separate projections for object and image streams.

    Self-attention runs over ``tgt`` only.  In cross-attention:

    * query = ``image_cross_attn_q_proj(image) + cross_attn_q_proj(norm2(tgt))``
      (plus ``query_pos`` when ``pos_enc_at_cross_attn_queries``);
    * key = ``image_cross_attn_k_proj(memory_image) + cross_attn_k_proj(memory)``
      (plus ``memory_image_pos`` when ``pos_enc_at_cross_attn_keys``);
    * value = ``cross_attn_v_proj(memory)``.

    Rotary encoding and the attention itself are delegated to the injected
    ``SimpleRoPEAttention`` modules.  Only the pre-norm path is implemented.
    """

    def __init__(
        self,
        *,
        activation: str,
        d_model: int,
        num_heads: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        pre_norm: bool,
        cross_attention_first: bool = False,
        self_attention_rope: SimpleRoPEAttention,
        cross_attention_rope: SimpleRoPEAttention,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dim_feedforward = dim_feedforward
        # NOTE: `dropout` is accepted but not used anywhere in this layer.
        # Module construction order below is kept fixed: it determines the
        # random init sequence and the state_dict key order.

        # Self-attention projections (inputs come from norm1(tgt)).
        self.self_attn_q_proj = nn.Linear(d_model, d_model)
        self.self_attn_k_proj = nn.Linear(d_model, d_model)
        self.self_attn_v_proj = nn.Linear(d_model, d_model)
        self.self_attn_out_proj = nn.Linear(d_model, d_model)
        # Cross-attention projections for the object stream (tgt / memory).
        self.cross_attn_q_proj = nn.Linear(d_model, d_model)
        self.cross_attn_k_proj = nn.Linear(d_model, d_model)
        self.cross_attn_v_proj = nn.Linear(d_model, d_model)
        self.cross_attn_out_proj = nn.Linear(d_model, d_model)
        # Cross-attention projections for the image stream; summed with the
        # object-stream q / k projections in _forward_ca.
        self.image_cross_attn_q_proj = nn.Linear(d_model, d_model)
        self.image_cross_attn_k_proj = nn.Linear(d_model, d_model)
        self.self_attention_rope = self_attention_rope
        self.cross_attention_rope = cross_attention_rope
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.activation_str = activation
        self.activation = get_activation_fn(activation)
        self.pre_norm = pre_norm
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
        self.cross_attention_first = cross_attention_first

    def forward_ffn(self, x):
        """Apply linear1 -> activation -> linear2 via ``chunked_ffn_forward``.

        ``x`` is the one-element list built in ``forward_pre``.
        """

        def _forward(x):
            return self.linear2(self.activation(self.linear1(x)))

        return chunked_ffn_forward(x, self.linear1.out_features, self.linear1.in_features, _forward)

    def _forward_sa(self, tgt, query_pos):
        """Self-attention sub-block; adds the result into ``tgt`` in place and returns it."""
        tgt2 = self.norm1(tgt)
        # Values are projected before query_pos is added; q / k see the positions.
        v = self.self_attn_v_proj(tgt2)
        if self.pos_enc_at_attn:
            tgt2.add_(query_pos)
        q = self.self_attn_q_proj(tgt2)
        k = self.self_attn_k_proj(tgt2)
        del tgt2
        qkv_list = [q, k, v]
        del q, k, v
        out = self.self_attention_rope(qkv_list)
        tgt2 = self.self_attn_out_proj(out)
        del out
        tgt.add_(tgt2)
        del tgt2
        return tgt

    def _forward_ca(
        self,
        *,
        image,
        tgt,
        memory_image,
        memory,
        query_pos,
        memory_image_pos,
        num_k_exclude_rope=0,
    ):
        """Decoupled cross-attention sub-block; adds into ``tgt`` in place and returns it.

        ``num_k_exclude_rope`` is forwarded to ``cross_attention_rope`` only when > 0.
        """
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attention_rope, SimpleRoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}
        # Cross-Attention
        tgt2 = self.norm2(tgt)
        q = self.image_cross_attn_q_proj(image)
        q.add_(self.cross_attn_q_proj(tgt2))
        if self.pos_enc_at_cross_attn_queries:
            q.add_(query_pos)
        k = self.image_cross_attn_k_proj(memory_image)
        k.add_(self.cross_attn_k_proj(memory))
        if self.pos_enc_at_cross_attn_keys:
            k.add_(memory_image_pos)
        v = self.cross_attn_v_proj(memory)
        del tgt2
        qkv_list = [q, k, v]
        del q, k, v
        out = self.cross_attention_rope(qkv_list, **kwds)
        tgt2 = self.cross_attn_out_proj(out)
        del out
        tgt.add_(tgt2)
        del tgt2
        return tgt

    def forward_pre(
        self,
        *,
        image,
        tgt,
        memory_image,
        memory,
        image_pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        memory_image_pos: Optional[Tensor] = None,
        memory_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ):
        """Run self-/cross-attention (order set by ``cross_attention_first``) then the MLP.

        ``image_pos`` and ``memory_pos`` are accepted but not used by this layer.

        Returns:
            ``(image, tgt)``: ``image`` is returned as received (it is only
            read), ``tgt`` is the input updated in place by the residual adds.
        """

        def cross_attend(t):
            return self._forward_ca(
                image=image,
                tgt=t,
                memory_image=memory_image,
                memory=memory,
                query_pos=query_pos,
                memory_image_pos=memory_image_pos,
                num_k_exclude_rope=num_k_exclude_rope,
            )

        if self.cross_attention_first:
            tgt = cross_attend(tgt)
            tgt = self._forward_sa(tgt, query_pos)
        else:
            tgt = self._forward_sa(tgt, query_pos)
            tgt = cross_attend(tgt)
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2_list = [tgt2]
        del tgt2
        tgt2 = self.forward_ffn(tgt2_list)
        tgt.add_(tgt2)
        del tgt2
        return image, tgt

    def forward(self, *args: Any, **kwds: Any) -> tuple[Tensor, Tensor]:
        """Pre-norm entry point; returns ``(image, tgt)`` from ``forward_pre``."""
        if self.pre_norm:
            return self.forward_pre(*args, **kwds)
        raise NotImplementedError
class TransformerEncoderDecoupledCrossAttention(nn.Module):
    """Stack of decoupled cross-attention layers over object and image tokens.

    ``src``/``memory`` carry object features and ``image``/``memory_image``
    carry image features.  Each layer is called with both streams by keyword
    and must return ``(image, output)``.  Inputs are sequence-first (dim 1
    is the batch, as the assertions in ``forward`` check); when
    ``batch_first`` is set, inputs and positional encodings are transposed
    before the layers and the result is transposed back.
    """

    def __init__(
        self,
        d_model: int,
        frozen: bool,
        pos_enc_at_input: bool,
        layer,
        num_layers: int,
        use_act_checkpoint: bool = False,
        batch_first: bool = False,  # Do layers expect batch first input?
        use_image_in_output: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.use_act_checkpoint = use_act_checkpoint
        self.use_image_in_output = use_image_in_output
        if frozen:
            # Disable gradients for every parameter registered on this module so far.
            for p in self.parameters():
                p.requires_grad_(False)
        self.batch_first = batch_first

    @staticmethod
    def _swap_dims01(x: Optional[Tensor]) -> Optional[Tensor]:
        """Transpose dims 0 and 1; ``None`` (absent positional encoding) passes through."""
        return None if x is None else x.transpose(0, 1)

    @staticmethod
    def _pad_memory_image(memory_image, memory_image_pos, memory, memory_pos, num_obj_ptr_tokens):
        """Zero-pad ``memory_image`` along dim 1 so it matches ``memory``.

        The extra ``num_obj_ptr_tokens`` positions stand in for object-pointer
        tokens present only in ``memory``.  ``memory_image_pos`` (if given) is
        extended with the trailing positions of the first batch entry of
        ``memory_pos``.

        NOTE(review): dim 1 is the sequence axis only after the batch-first
        transpose; with ``batch_first=False`` this compares batch sizes --
        confirm the intended usage.

        Returns:
            ``(memory_image, memory_image_pos)``, unchanged when no padding is needed.
        """
        if memory_image.shape[1] == memory.shape[1]:
            return memory_image, memory_image_pos
        # Pad memory_image with zeros, to accommodate object pointers
        assert (memory.shape[1] - memory_image.shape[1]) == num_obj_ptr_tokens, (
            f"{memory.shape[1]} - {memory_image.shape[1]} != {num_obj_ptr_tokens}"
        )
        padding = torch.zeros(
            (memory_image.shape[0], num_obj_ptr_tokens) + memory_image.shape[2:],
            dtype=memory_image.dtype,
            device=memory_image.device,
        )
        memory_image = torch.cat([memory_image, padding], dim=1)
        if memory_image_pos is not None:
            assert (
                memory_pos.shape[1] - memory_image_pos.shape[1]
            ) == num_obj_ptr_tokens, (
                f"{memory_pos.shape[1]} - {memory_image_pos.shape[1]} != {num_obj_ptr_tokens}"
            )
            # tpos is the same in the batch anyway; note that memory_image always has a batch size of 1
            memory_image_pos = torch.cat(
                [
                    memory_image_pos,
                    memory_pos[0:1, -num_obj_ptr_tokens:],
                ],
                dim=1,
            )
        return memory_image, memory_image_pos

    def forward(
        self,
        image: Tensor,  # image features
        src: Tensor,  # self-attention inputs; object features
        memory_image: Tensor,  # cross-attention inputs; image features
        memory: Tensor,  # cross-attention inputs; object features
        image_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        src_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_image_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
        """Run all layers and return ``{"memory": normed_output, "pos_embed": src_pos}``.

        Both outputs are in the same (sequence-first) layout as the inputs.
        Any positional encoding may be ``None``, including with ``batch_first``.
        ``src`` is not modified in place.
        """
        assert src.shape[1] == memory.shape[1], (
            "Batch size must be the same for src and memory"
        )
        assert image.shape[1] == memory_image.shape[1], (
            "Batch size must be the same for image and memory_image"
        )
        output = src
        if self.pos_enc_at_input and src_pos is not None:
            # Clone first so the in-place add does not write into the caller's src.
            output = output.clone()
            output.add_(src_pos, alpha=0.1)
        if self.batch_first:
            # Convert to batch first; absent (None) encodings stay None.
            swap = self._swap_dims01
            output = swap(output)
            src_pos = swap(src_pos)
            image = swap(image)
            image_pos = swap(image_pos)
            memory = swap(memory)
            memory_pos = swap(memory_pos)
            memory_image = swap(memory_image)
            memory_image_pos = swap(memory_image_pos)
        memory_image, memory_image_pos = self._pad_memory_image(
            memory_image, memory_image_pos, memory, memory_pos, num_obj_ptr_tokens
        )
        for layer in self.layers:
            image, output = activation_ckpt_wrapper(layer)(
                image=image,
                tgt=output,
                memory_image=memory_image,
                memory=memory,
                image_pos=image_pos,
                query_pos=src_pos,
                memory_image_pos=memory_image_pos,
                memory_pos=memory_pos,
                num_k_exclude_rope=num_obj_ptr_tokens,
                act_ckpt_enable=self.training and self.use_act_checkpoint,
            )
        if self.use_image_in_output:
            # Clone before the in-place add so the tensor returned by the last
            # layer (or src itself when there are no layers) is not modified.
            output = output.clone()
            output.add_(image)
        normed_output = self.norm(output)
        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            src_pos = self._swap_dims01(src_pos)
        return {
            "memory": normed_output,
            "pos_embed": src_pos,
        }
|