| import os |
| import re |
| import types |
|
|
|
|
| def _normalize_1k_coords(text): |
| def _replace_tuple(match): |
| inner = match.group(1) |
| nums = [s.strip() for s in inner.split(",")] |
| if all(re.fullmatch(r"\d+", n) for n in nums): |
| floats = [f"{int(n) / 1000:.3f}" for n in nums] |
| return "(" + ", ".join(floats) + ")" |
| return match.group(0) |
| return re.sub(r"\(([^()]+)\)", _replace_tuple, text) |
|
|
|
|
| def _parse_think_and_answer(text): |
| thinking = "" |
| if "</think>" in text: |
| parts = text.split("</think>", 1) |
| thinking = parts[0].strip() |
| after = parts[1] |
| answer_match = re.search(r"<answer>(.*?)</answer>", after, re.DOTALL) |
| if answer_match: |
| return thinking, answer_match.group(1).strip() |
| answer = after.replace("<answer>", "").replace("</answer>", "").strip() |
| return thinking, answer |
| return thinking, text |
|
|
|
|
| def _normalize_binary_answer(text): |
| content = text.strip() |
| if not content: |
| return content |
| line = content.splitlines()[-1].strip() |
| line = re.sub(r"(?i)^answer\s*[::]\s*", "", line).strip() |
| if line.lower().startswith("yes"): |
| return "yes" |
| if line.lower().startswith("no"): |
| return "no" |
| m = re.search(r"(?i)\b(yes|no)\b", line) |
| if m: |
| return m.group(1).lower() |
| return content |
|
|
|
|
def load_robobrain(model_path):
    """Load a RoboBrain checkpoint and processor, then patch two model internals.

    The model is loaded with ``AutoModelForImageTextToText`` using
    ``device_map="auto"``. If that fails only because ``accelerate`` is
    missing, it is loaded without a device map. The whole model is then moved
    to CUDA when available, otherwise to CPU.

    Two instance-level patches follow. They are assigned on this model's
    sub-objects, so the classes themselves are not modified:

    * ``model.model.get_placeholder_mask`` is replaced by a version that only
      builds the image/video placeholder masks and ignores the feature tensors.
    * ``model.model.language_model._deepstack_process`` is wrapped so it
      returns ``hidden_states`` unchanged when there are no visual positions.

    Args:
        model_path: Checkpoint location passed to ``from_pretrained``.

    Returns:
        Dict with keys ``"model"`` and ``"processor"``. This is the ``kwargs``
        mapping ``run_robobrain`` reads.

    Raises:
        ValueError: Re-raised from ``from_pretrained`` unless it is the
            "requires `accelerate`" error.
    """
    import torch as _torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    print(f"Loading RoboBrain checkpoint from {model_path} ...")
    try:
        model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto", device_map="auto"
        )
    except ValueError as exc:
        if "requires `accelerate`" not in str(exc):
            raise
        # Single-device fallback. This used to hard-code "cuda", which crashed
        # on CPU-only hosts.
        fallback_device = "cuda" if _torch.cuda.is_available() else "cpu"
        model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto"
        ).to(fallback_device)

    def _patched_get_placeholder_mask(self, input_ids, inputs_embeds, image_features=None, video_features=None):
        # image_features / video_features are accepted only to keep the call
        # signature; they are not used. So no token-vs-feature consistency
        # check happens here (presumably the stock method does one; verify
        # against the installed transformers version).
        if input_ids is None:
            # No ids available: locate placeholders by comparing each embedding
            # vector with the embedding of the special token id.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                _torch.tensor(self.config.image_token_id, dtype=_torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
            special_video_mask = inputs_embeds == self.get_input_embeddings()(
                _torch.tensor(self.config.video_token_id, dtype=_torch.long, device=inputs_embeds.device)
            )
            special_video_mask = special_video_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.video_token_id
        # Broadcast the per-position masks over the last (embedding) axis so
        # they have the same shape as inputs_embeds.
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        return special_image_mask, special_video_mask

    model.model.get_placeholder_mask = types.MethodType(_patched_get_placeholder_mask, model.model)

    _text_model = model.model.language_model
    # Capture the original bound method BEFORE overriding it; the wrapper
    # below delegates to it.
    _orig_deepstack = _text_model._deepstack_process

    def _safe_deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds):
        # No visual positions (mask missing or all False): nothing to inject.
        if visual_pos_masks is None or not visual_pos_masks.any():
            return hidden_states
        return _orig_deepstack(hidden_states, visual_pos_masks, visual_embeds)

    _text_model._deepstack_process = types.MethodType(_safe_deepstack_process, _text_model)

    # Pixel-budget bounds for the image processor: 256 and 1280 units of
    # 28*28 pixels (presumably vision patches; confirm with the processor).
    min_pixels = 256 * 28 * 28
    max_pixels = 1280 * 28 * 28
    processor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)

    return {"model": model, "processor": processor}
|
|
|
|
# Substrings that mark a question as a pointing (coordinate-output) task.
_POINTING_KEYWORDS = ["pinpoint", "coordinates", "point to", "locate the position"]
# Substrings / leading words that mark a question as a yes-no task.
_BINARY_PHRASES = ["yes or no", "determine whether", "answer with only", "answer yes or no"]
_BINARY_STARTS = ("is ", "are ", "does ", "do ", "can ", "has ", "was ", "were ", "will ")

# Appended verbatim to pointing questions (leading space separates it).
_POINTING_POST_PROMPT = (
    " You MUST provide at least 5 distinct 2D points that satisfy the conditions above."
    " Do NOT output only one point. Format your final answer strictly as a list of tuples:"
    " [(x1, y1), (x2, y2), (x3, y3), ...]."
)
# Appended verbatim to binary questions.
# NOTE(review): unlike the pointing prompt this has no leading separator, so it
# is glued directly onto the question text. Kept byte-identical for prompt
# reproducibility; confirm whether that is intended.
_BINARY_POST_PROMPT = (
    "Your task is to answer the question above. Respond with a brief explanation"
    " if needed, followed by a yes or no answer in the last line of your response.\n\n"
    "Format your final answer strictly as follows on the last line of your response:\n"
    "Answer: yes or no\n\n"
    "Do not include additional text after this line.\n"
)


def _classify_question(question, LM_classify=None):
    """Return ``(q_type, post_prompt, add_think)`` for a question.

    With an explicit ``LM_classify`` label:

    * ``"context"`` gives a pointing question with reasoning.
    * ``"compatibility"`` gives a binary question with reasoning.
    * Any other non-None value gives a binary question with no post-prompt
      and no reasoning.

    Without a label, keyword heuristics pick pointing, binary or open, and
    reasoning is always enabled.
    """
    if LM_classify is not None:
        if LM_classify == "context":
            return "pointing", _POINTING_POST_PROMPT, True
        if LM_classify == "compatibility":
            return "binary", _BINARY_POST_PROMPT, True
        return "binary", "", False

    question_lower = question.lower().strip()
    if any(kw in question_lower for kw in _POINTING_KEYWORDS):
        return "pointing", _POINTING_POST_PROMPT, True
    if (
        any(phrase in question_lower for phrase in _BINARY_PHRASES)
        or question_lower.startswith(_BINARY_STARTS)
    ):
        return "binary", _BINARY_POST_PROMPT, True
    return "open", "", True


def _to_image_uri(image_path):
    """Return ``image_path`` as a URI for the chat message's image entry.

    http(s) URLs, ``file://`` URIs and ``data:`` URIs pass through unchanged.
    Anything else is treated as a filesystem path and gets a ``file://``
    prefix. Checking for full schemes avoids double-prefixing ``file://``
    paths and mangling ``data:`` URIs, which the old bare ``"http"`` prefix
    test did.
    """
    if image_path.startswith(("http://", "https://", "file://", "data:")):
        return image_path
    return f"file://{image_path}"


def run_robobrain(question, image_path, depth_path, kwargs, return_thinking=False, LM_classify=None, add_think_override=None):
    """Answer one image question with a model loaded by ``load_robobrain``.

    The question is classified (pointing / binary / open), a task-specific
    post-prompt is appended, and optionally the assistant turn is pre-opened
    with ``<think>\\n``. The model then generates greedily. The reasoning is
    split from the answer, and the answer is normalised:

    * pointing answers have their 0-1000 integer tuples rescaled;
    * binary answers are reduced to ``yes``/``no``.

    Args:
        question: Question text.
        image_path: Local path, or an http(s) / ``file://`` / ``data:`` URI.
        depth_path: Accepted for interface compatibility; not used.
        kwargs: Mapping with ``"model"`` and ``"processor"`` entries, as
            returned by ``load_robobrain``.
        return_thinking: If True, return ``(thinking, answer)`` instead of
            only the answer.
        LM_classify: Optional external question label; see
            ``_classify_question``.
        add_think_override: When not None, forces whether ``<think>\\n`` is
            appended, regardless of the classification.

    Returns:
        The answer string, or ``(thinking, answer)`` if ``return_thinking``.

    The generation length comes from the ``ROBOBRAIN_MAX_NEW_TOKENS``
    environment variable (default 2048).
    """
    from qwen_vl_utils import process_vision_info

    model = kwargs["model"]
    processor = kwargs["processor"]

    q_type, post_prompt, add_think = _classify_question(question, LM_classify)
    if add_think_override is not None:
        add_think = add_think_override

    full_question = question + post_prompt
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": _to_image_uri(image_path)},
                {"type": "text", "text": full_question},
            ],
        },
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if add_think:
        # Pre-open the reasoning block so generation continues inside it.
        text = text + "<think>\n"

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Inputs go to the device of the first parameter (relevant when the model
    # was dispatched with device_map="auto").
    _first_device = next(model.parameters()).device
    inputs = inputs.to(_first_device)

    max_new_tokens = int(os.environ.get("ROBOBRAIN_MAX_NEW_TOKENS", "2048"))
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    # Drop the echoed prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    response = output_text[0].strip() if output_text else ""

    thinking, answer_text = _parse_think_and_answer(response)
    if q_type == "pointing":
        answer_text = _normalize_1k_coords(answer_text)
    elif q_type == "binary":
        answer_text = _normalize_binary_answer(answer_text)

    if return_thinking:
        return thinking, answer_text
    return answer_text
|
|