File size: 7,811 Bytes
import os
import re
import types
def _normalize_1k_coords(text):
def _replace_tuple(match):
inner = match.group(1)
nums = [s.strip() for s in inner.split(",")]
if all(re.fullmatch(r"\d+", n) for n in nums):
floats = [f"{int(n) / 1000:.3f}" for n in nums]
return "(" + ", ".join(floats) + ")"
return match.group(0)
return re.sub(r"\(([^()]+)\)", _replace_tuple, text)
def _parse_think_and_answer(text):
thinking = ""
if "</think>" in text:
parts = text.split("</think>", 1)
thinking = parts[0].strip()
after = parts[1]
answer_match = re.search(r"<answer>(.*?)</answer>", after, re.DOTALL)
if answer_match:
return thinking, answer_match.group(1).strip()
answer = after.replace("<answer>", "").replace("</answer>", "").strip()
return thinking, answer
return thinking, text
def _normalize_binary_answer(text):
content = text.strip()
if not content:
return content
line = content.splitlines()[-1].strip()
line = re.sub(r"(?i)^answer\s*[::]\s*", "", line).strip()
if line.lower().startswith("yes"):
return "yes"
if line.lower().startswith("no"):
return "no"
m = re.search(r"(?i)\b(yes|no)\b", line)
if m:
return m.group(1).lower()
return content
def load_robobrain(model_path):
    """Load a RoboBrain checkpoint and processor and patch two model methods.

    The model is loaded with ``device_map="auto"``.  If that raises a
    ``ValueError`` mentioning that ``accelerate`` is required, the model is
    loaded without a device map and moved to CUDA when available, otherwise
    left on CPU.  Any other ``ValueError`` is re-raised.

    Two instance-level patches are then installed:

    * ``model.model.get_placeholder_mask`` is replaced by a version that
      only builds the image/video placeholder masks.  It accepts
      ``image_features`` / ``video_features`` for signature compatibility
      but never inspects them.  NOTE(review): presumably this bypasses a
      feature-count check in the upstream implementation - confirm against
      the installed ``transformers`` version.
    * ``model.model.language_model._deepstack_process`` is wrapped so that
      it returns ``hidden_states`` unchanged when ``visual_pos_masks``
      contains no ``True`` entry.

    Args:
        model_path: Checkpoint location passed to ``from_pretrained``.

    Returns:
        A dict with keys ``"model"`` and ``"processor"``.

    Raises:
        ValueError: Propagated from ``from_pretrained`` unless it is the
            "requires `accelerate`" error handled by the fallback.
    """
    import torch as _torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    print(f"Loading RoboBrain checkpoint from {model_path} ...")
    try:
        model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto", device_map="auto"
        )
    except ValueError as exc:
        if "requires `accelerate`" not in str(exc):
            raise
        # Without accelerate there is no automatic placement; choose a
        # device explicitly instead of assuming CUDA exists.
        fallback_device = "cuda" if _torch.cuda.is_available() else "cpu"
        model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto"
        ).to(fallback_device)

    def _patched_get_placeholder_mask(self, input_ids, inputs_embeds, image_features=None, video_features=None):
        # image_features / video_features are intentionally unused.
        if input_ids is None:
            # No token ids: find placeholder positions by comparing every
            # embedding vector with the embedding of the placeholder token.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                _torch.tensor(self.config.image_token_id, dtype=_torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
            special_video_mask = inputs_embeds == self.get_input_embeddings()(
                _torch.tensor(self.config.video_token_id, dtype=_torch.long, device=inputs_embeds.device)
            )
            special_video_mask = special_video_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.video_token_id
        # Expand each mask over the last axis so it matches inputs_embeds.
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        return special_image_mask, special_video_mask

    model.model.get_placeholder_mask = types.MethodType(_patched_get_placeholder_mask, model.model)

    _text_model = model.model.language_model
    # Bound method captured before patching; the wrapper delegates to it.
    _orig_deepstack = _text_model._deepstack_process

    def _safe_deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds):
        # Nothing is marked as visual: skip the original processing.
        if not visual_pos_masks.any():
            return hidden_states
        return _orig_deepstack(hidden_states, visual_pos_masks, visual_embeds)

    _text_model._deepstack_process = types.MethodType(_safe_deepstack_process, _text_model)

    # Pixel budget handed to the processor.
    min_pixels = 256 * 28 * 28
    max_pixels = 1280 * 28 * 28
    processor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)
    return {"model": model, "processor": processor}
def run_robobrain(question, image_path, depth_path, kwargs, return_thinking=False, LM_classify=None, add_think_override=None):
    """Ask the loaded model one question about one image and post-process the reply.

    The question is classified as ``"pointing"``, ``"binary"`` or ``"open"``
    and a matching instruction is appended to it.  With ``LM_classify``
    given, ``"context"`` selects pointing and ``"compatibility"`` selects
    binary (both with thinking enabled); any other value selects binary
    with no appended instruction and thinking disabled.  Without
    ``LM_classify`` the type is derived from keywords in the question and
    thinking is enabled.

    Generation is greedy (``do_sample=False``) and limited by the
    ``ROBOBRAIN_MAX_NEW_TOKENS`` environment variable (default 2048).

    Args:
        question: Question text.
        image_path: Local path or URI of the image.  ``http://``,
            ``https://``, ``file://`` and ``data:`` inputs are passed
            through unchanged; anything else is treated as a local path and
            prefixed with ``file://``.
        depth_path: Not used by this function.
        kwargs: Dict holding ``"model"`` and ``"processor"``.
        return_thinking: When true, return ``(thinking, answer)``.
        LM_classify: Optional externally supplied classification label.
        add_think_override: When not ``None``, overrides whether
            ``"<think>\n"`` is appended to the prompt.

    Returns:
        The post-processed answer string, or ``(thinking, answer)`` when
        ``return_thinking`` is true.

    Raises:
        KeyError: If ``kwargs`` lacks ``"model"`` or ``"processor"``.
        ValueError: If ``ROBOBRAIN_MAX_NEW_TOKENS`` is not an integer.
    """
    from qwen_vl_utils import process_vision_info

    model = kwargs["model"]
    processor = kwargs["processor"]
    question_lower = question.lower().strip()

    _POINTING_KEYWORDS = ["pinpoint", "coordinates", "point to", "locate the position"]
    _BINARY_PHRASES = ["yes or no", "determine whether", "answer with only", "answer yes or no"]
    _BINARY_STARTS = ("is ", "are ", "does ", "do ", "can ", "has ", "was ", "were ", "will ")
    _POINTING_POST_PROMPT = (
        " You MUST provide at least 5 distinct 2D points that satisfy the conditions above."
        " Do NOT output only one point. Format your final answer strictly as a list of tuples:"
        " [(x1, y1), (x2, y2), (x3, y3), ...]."
    )
    # NOTE(review): unlike the pointing prompt this one has no leading
    # separator, so it is glued directly onto the question text.  Left
    # as-is because changing prompt text changes model behaviour.
    _BINARY_POST_PROMPT = (
        "Your task is to answer the question above. Respond with a brief explanation"
        " if needed, followed by a yes or no answer in the last line of your response.\n\n"
        "Format your final answer strictly as follows on the last line of your response:\n"
        "Answer: yes or no\n\n"
        "Do not include additional text after this line.\n"
    )

    if LM_classify is not None:
        # An external classifier decided the question type.
        if LM_classify == "context":
            post_prompt = _POINTING_POST_PROMPT
            add_think = True
            q_type = "pointing"
        elif LM_classify == "compatibility":
            post_prompt = _BINARY_POST_PROMPT
            add_think = True
            q_type = "binary"
        else:
            post_prompt = ""
            add_think = False
            q_type = "binary"
    else:
        # Keyword heuristics; pointing keywords take precedence.
        if any(kw in question_lower for kw in _POINTING_KEYWORDS):
            q_type = "pointing"
        elif (
            any(phrase in question_lower for phrase in _BINARY_PHRASES)
            or question_lower.startswith(_BINARY_STARTS)
        ):
            q_type = "binary"
        else:
            q_type = "open"
        if q_type == "pointing":
            post_prompt = _POINTING_POST_PROMPT
        elif q_type == "binary":
            post_prompt = _BINARY_POST_PROMPT
        else:
            post_prompt = ""
        add_think = True

    if add_think_override is not None:
        add_think = add_think_override

    full_question = question + post_prompt

    # Only bare local paths get the file:// prefix.  Inputs that already
    # carry a scheme (including file:// and data:) must not be prefixed
    # again, and a local path that merely begins with the letters "http"
    # is not a URL.
    if image_path.startswith(("http://", "https://", "file://", "data:")):
        image_uri = image_path
    else:
        image_uri = f"file://{image_path}"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_uri},
                {"type": "text", "text": full_question},
            ],
        },
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if add_think:
        # Start the assistant turn inside a think block.
        text = text + "<think>\n"

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Place the inputs on the device of the model's first parameter.
    _first_device = next(model.parameters()).device
    inputs = inputs.to(_first_device)

    max_new_tokens = int(os.environ.get("ROBOBRAIN_MAX_NEW_TOKENS", "2048"))
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Drop the prompt tokens so only newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    response = output_text[0].strip() if output_text else ""

    thinking, answer_text = _parse_think_and_answer(response)
    if q_type == "pointing":
        answer_text = _normalize_1k_coords(answer_text)
    elif q_type == "binary":
        answer_text = _normalize_binary_answer(answer_text)

    if return_thinking:
        return thinking, answer_text
    return answer_text
|