# RoboSpatialBrain / scripts/robobrain_runner.py
# Uploaded by lbx511 using huggingface_hub (commit 3bbe20b, verified).
import os
import re
import types
def _normalize_1k_coords(text):
def _replace_tuple(match):
inner = match.group(1)
nums = [s.strip() for s in inner.split(",")]
if all(re.fullmatch(r"\d+", n) for n in nums):
floats = [f"{int(n) / 1000:.3f}" for n in nums]
return "(" + ", ".join(floats) + ")"
return match.group(0)
return re.sub(r"\(([^()]+)\)", _replace_tuple, text)
def _parse_think_and_answer(text):
thinking = ""
if "</think>" in text:
parts = text.split("</think>", 1)
thinking = parts[0].strip()
after = parts[1]
answer_match = re.search(r"<answer>(.*?)</answer>", after, re.DOTALL)
if answer_match:
return thinking, answer_match.group(1).strip()
answer = after.replace("<answer>", "").replace("</answer>", "").strip()
return thinking, answer
return thinking, text
def _normalize_binary_answer(text):
content = text.strip()
if not content:
return content
line = content.splitlines()[-1].strip()
line = re.sub(r"(?i)^answer\s*[::]\s*", "", line).strip()
if line.lower().startswith("yes"):
return "yes"
if line.lower().startswith("no"):
return "no"
m = re.search(r"(?i)\b(yes|no)\b", line)
if m:
return m.group(1).lower()
return content
def load_robobrain(model_path):
    """Load a RoboBrain checkpoint and processor, patching two model methods.

    The model is loaded with ``AutoModelForImageTextToText`` and
    ``device_map="auto"``. If that raises a ``ValueError`` saying
    ``accelerate`` is required, it is reloaded without a device map and moved
    to CUDA when available, otherwise kept on CPU. Any other ``ValueError``
    propagates.

    Two methods are replaced on this model instance only:

    * ``model.model.get_placeholder_mask`` builds the image / video
      placeholder masks (from ``input_ids`` if given, else by matching
      embeddings of the placeholder token ids) and returns them expanded to
      ``inputs_embeds``'s shape; ``image_features`` / ``video_features`` are
      accepted but ignored. NOTE(review): presumably this bypasses the
      placeholder-count validation in the stock transformers implementation
      -- confirm against the installed transformers version.
    * ``model.model.language_model._deepstack_process`` becomes a no-op
      (returns ``hidden_states`` unchanged) when ``visual_pos_masks`` selects
      no positions; otherwise it defers to the original method.

    Args:
        model_path: Local directory or Hub id passed to ``from_pretrained``.

    Returns:
        dict: ``{"model": model, "processor": processor}`` -- the mapping
        ``run_robobrain`` reads as its ``kwargs`` argument.
    """
    import torch as _torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    print(f"Loading RoboBrain checkpoint from {model_path} ...")
    try:
        model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto", device_map="auto"
        )
    except ValueError as exc:
        # Only the "accelerate is missing" failure is recoverable here.
        if "requires `accelerate`" not in str(exc):
            raise
        # Don't assume a GPU: the previous hard-coded "cuda" crashed on CPU-only hosts.
        fallback_device = "cuda" if _torch.cuda.is_available() else "cpu"
        model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto"
        ).to(fallback_device)

    def _patched_get_placeholder_mask(self, input_ids, inputs_embeds, image_features=None, video_features=None):
        if input_ids is None:
            # No token ids available: find placeholders by comparing every
            # position's embedding with the embedding of the placeholder id.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                _torch.tensor(self.config.image_token_id, dtype=_torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
            special_video_mask = inputs_embeds == self.get_input_embeddings()(
                _torch.tensor(self.config.video_token_id, dtype=_torch.long, device=inputs_embeds.device)
            )
            special_video_mask = special_video_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.video_token_id
        # Broadcast over the hidden dimension and co-locate with the embeddings.
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        return special_image_mask, special_video_mask

    model.model.get_placeholder_mask = types.MethodType(_patched_get_placeholder_mask, model.model)

    _text_model = model.model.language_model
    _orig_deepstack = _text_model._deepstack_process

    def _safe_deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds):
        # Nothing visual to inject -> skip the original call entirely.
        if not visual_pos_masks.any():
            return hidden_states
        return _orig_deepstack(hidden_states, visual_pos_masks, visual_embeds)

    _text_model._deepstack_process = types.MethodType(_safe_deepstack_process, _text_model)

    # NOTE(review): 28 looks like a patch-size x merge-size factor; confirm it
    # matches this checkpoint's image processor.
    min_pixels = 256 * 28 * 28
    max_pixels = 1280 * 28 * 28
    processor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)
    return {"model": model, "processor": processor}
# Substrings (matched against the lower-cased question) that mark a pointing query.
_POINTING_KEYWORDS = ["pinpoint", "coordinates", "point to", "locate the position"]
# Substrings that mark a yes/no query.
_BINARY_PHRASES = ["yes or no", "determine whether", "answer with only", "answer yes or no"]
# Question openers that mark a yes/no query (a tuple so str.startswith accepts it).
_BINARY_STARTS = ("is ", "are ", "does ", "do ", "can ", "has ", "was ", "were ", "will ")
# Appended to pointing questions: ask for >= 5 points as a list of (x, y) tuples.
_POINTING_POST_PROMPT = (
    " You MUST provide at least 5 distinct 2D points that satisfy the conditions above."
    " Do NOT output only one point. Format your final answer strictly as a list of tuples:"
    " [(x1, y1), (x2, y2), (x3, y3), ...]."
)
# Appended to yes/no questions: ask for a final "Answer: yes or no" line.
# Starts with a space, like _POINTING_POST_PROMPT, so it is not glued onto the
# question text (previously produced e.g. "...plate?Your task is ...").
_BINARY_POST_PROMPT = (
    " Your task is to answer the question above. Respond with a brief explanation"
    " if needed, followed by a yes or no answer in the last line of your response.\n\n"
    "Format your final answer strictly as follows on the last line of your response:\n"
    "Answer: yes or no\n\n"
    "Do not include additional text after this line.\n"
)


def _select_prompt_mode(question, lm_classify):
    """Return ``(q_type, post_prompt, add_think)`` for *question*.

    ``q_type`` is ``"pointing"``, ``"binary"`` or ``"open"``. When
    *lm_classify* is not None it decides directly: ``"context"`` -> pointing
    with thinking, ``"compatibility"`` -> binary with thinking, any other
    value -> binary with no post-prompt and no thinking. Otherwise keyword
    heuristics on the lower-cased question choose the type and thinking is
    always enabled.
    """
    if lm_classify is not None:
        if lm_classify == "context":
            return "pointing", _POINTING_POST_PROMPT, True
        if lm_classify == "compatibility":
            return "binary", _BINARY_POST_PROMPT, True
        return "binary", "", False

    question_lower = question.lower().strip()
    if any(kw in question_lower for kw in _POINTING_KEYWORDS):
        return "pointing", _POINTING_POST_PROMPT, True
    if (
        any(phrase in question_lower for phrase in _BINARY_PHRASES)
        or question_lower.startswith(_BINARY_STARTS)
    ):
        return "binary", _BINARY_POST_PROMPT, True
    return "open", "", True


def _image_uri(image_path):
    """Return *image_path* as an image URI for the chat message.

    Inputs already in URI form (``http...``, ``file://``, ``data:``) are
    passed through unchanged so they are not double-prefixed; anything else
    is treated as a filesystem path and prefixed with ``file://``.
    """
    if image_path.startswith(("http", "file://", "data:")):
        return image_path
    return f"file://{image_path}"


def run_robobrain(question, image_path, depth_path, kwargs, return_thinking=False, LM_classify=None, add_think_override=None):
    """Answer *question* about one image with a loaded RoboBrain model.

    Args:
        question: Natural-language question about the image.
        image_path: Local path or URL of the image.
        depth_path: Accepted for interface compatibility; not used here.
        kwargs: Mapping with ``"model"`` and ``"processor"`` entries, as
            returned by ``load_robobrain``.
        return_thinking: If True, return ``(thinking, answer)`` instead of
            just the answer.
        LM_classify: Optional externally assigned question category
            (``"context"`` -> pointing, ``"compatibility"`` -> binary, any
            other non-None value -> binary without post-prompt or thinking).
            When None, keyword heuristics classify the question.
        add_think_override: When not None, forces whether ``"<think>\\n"`` is
            appended to the generation prompt.

    Greedy decoding is used; the token budget comes from the
    ``ROBOBRAIN_MAX_NEW_TOKENS`` environment variable (default 2048).

    Returns:
        The answer text -- for pointing questions integer tuples are rescaled
        from 0-1000 to 0-1, for binary questions it is reduced to
        ``"yes"`` / ``"no"`` when detectable -- or ``(thinking, answer)``
        when *return_thinking* is True.
    """
    from qwen_vl_utils import process_vision_info

    model = kwargs["model"]
    processor = kwargs["processor"]

    q_type, post_prompt, add_think = _select_prompt_mode(question, LM_classify)
    if add_think_override is not None:
        add_think = add_think_override

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": _image_uri(image_path)},
                {"type": "text", "text": question + post_prompt},
            ],
        },
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if add_think:
        # Pre-open the reasoning block so the model starts by thinking.
        text = text + "<think>\n"

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # With device_map="auto" the model may be sharded; inputs go to the first shard's device.
    inputs = inputs.to(next(model.parameters()).device)

    max_new_tokens = int(os.environ.get("ROBOBRAIN_MAX_NEW_TOKENS", "2048"))
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Drop the echoed prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    response = output_text[0].strip() if output_text else ""

    thinking, answer_text = _parse_think_and_answer(response)
    if q_type == "pointing":
        answer_text = _normalize_1k_coords(answer_text)
    elif q_type == "binary":
        answer_text = _normalize_binary_answer(answer_text)
    if return_thinking:
        return thinking, answer_text
    return answer_text