import os
import re
import types

# Markup the model is expected to wrap its reasoning and final answer in.
# NOTE(review): these literals were empty strings in the previous revision
# (which made ``str.split`` raise ``ValueError``); they are restored to the
# RoboBrain think/answer tag format -- confirm against the checkpoint in use.
_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"
_ANSWER_OPEN = "<answer>"
_ANSWER_CLOSE = "</answer>"
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

# Innermost parenthesised groups, e.g. "(12, 340)".
_TUPLE_RE = re.compile(r"\(([^()]+)\)")

# Heuristics used to classify a question when no explicit label is supplied.
_POINTING_KEYWORDS = ["pinpoint", "coordinates", "point to", "locate the position"]
_BINARY_PHRASES = ["yes or no", "determine whether", "answer with only", "answer yes or no"]
_BINARY_STARTS = ("is ", "are ", "does ", "do ", "can ", "has ", "was ", "were ", "will ")

_POINTING_POST_PROMPT = (
    " You MUST provide at least 5 distinct 2D points that satisfy the conditions above."
    " Do NOT output only one point. Format your final answer strictly as a list of tuples:"
    " [(x1, y1), (x2, y2), (x3, y3), ...]."
)
_BINARY_POST_PROMPT = (
    "Your task is to answer the question above. Respond with a brief explanation"
    " if needed, followed by a yes or no answer in the last line of your response.\n\n"
    "Format your final answer strictly as follows on the last line of your response:\n"
    "Answer: yes or no\n\n"
    "Do not include additional text after this line.\n"
)


def _normalize_1k_coords(text):
    """Rewrite all-integer tuples in *text* as tuples of value/1000 floats.

    Every innermost ``(...)`` group whose comma-separated items are all
    unsigned integers is replaced by the same items divided by 1000 and
    formatted with three decimals. Groups containing anything else are left
    untouched.

    Args:
        text: Text that may contain coordinate tuples.

    Returns:
        The text with qualifying tuples rewritten.
    """

    def _replace_tuple(match):
        items = [item.strip() for item in match.group(1).split(",")]
        if all(re.fullmatch(r"\d+", item) for item in items):
            return "(" + ", ".join(f"{int(item) / 1000:.3f}" for item in items) + ")"
        return match.group(0)

    return _TUPLE_RE.sub(_replace_tuple, text)


def _parse_think_and_answer(text):
    """Split a raw model response into ``(thinking, answer)``.

    Args:
        text: Decoded model output.

    Returns:
        A ``(thinking, answer)`` tuple. When *text* contains no closing think
        tag, ``thinking`` is ``""`` and ``answer`` is *text* unchanged.
        Otherwise ``thinking`` is everything before the first closing think
        tag, and ``answer`` is the content of the first answer element after
        it, or the remaining text with stray answer tags removed when no
        complete element is present.
    """
    if _THINK_CLOSE not in text:
        return "", text

    before, after = text.split(_THINK_CLOSE, 1)
    # The opening tag is normally part of the prompt prefill, but drop it if
    # the model echoed it anyway.
    thinking = before.replace(_THINK_OPEN, "").strip()

    answer_match = _ANSWER_RE.search(after)
    if answer_match:
        return thinking, answer_match.group(1).strip()

    answer = after.replace(_ANSWER_OPEN, "").replace(_ANSWER_CLOSE, "").strip()
    return thinking, answer


def _normalize_binary_answer(text):
    """Reduce a free-form response to ``"yes"`` or ``"no"`` where possible.

    Only the last line of the stripped text is inspected. A leading
    ``Answer:`` label (ASCII or full-width colon, case-insensitive) is
    removed first.

    Args:
        text: Answer text produced by the model.

    Returns:
        ``"yes"`` or ``"no"`` when the last line starts with, or contains as
        a whole word, one of them; otherwise the stripped input unchanged
        (including ``""`` for blank input).
    """
    content = text.strip()
    if not content:
        return content

    line = content.splitlines()[-1].strip()
    line = re.sub(r"(?i)^answer\s*[:\uFF1A]\s*", "", line).strip()

    lowered = line.lower()
    if lowered.startswith("yes"):
        return "yes"
    if lowered.startswith("no"):
        return "no"

    word = re.search(r"(?i)\b(yes|no)\b", line)
    if word:
        return word.group(1).lower()
    return content


def load_robobrain(model_path):
    """Load a RoboBrain checkpoint and its processor, applying runtime patches.

    Two methods are monkey-patched on the loaded model instance:
    ``model.model.get_placeholder_mask`` and
    ``model.model.language_model._deepstack_process``.

    Args:
        model_path: Path or hub identifier passed to ``from_pretrained``.

    Returns:
        ``{"model": model, "processor": processor}``.

    Raises:
        ValueError: Re-raised from ``from_pretrained`` unless the message says
            that ``accelerate`` is required.
    """
    import torch as _torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    print(f"Loading RoboBrain checkpoint from {model_path} ...")
    try:
        model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto", device_map="auto"
        )
    except ValueError as exc:
        if "requires `accelerate`" not in str(exc):
            raise
        # device_map="auto" needs accelerate; fall back to a single device.
        fallback_device = "cuda" if _torch.cuda.is_available() else "cpu"
        model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto"
        ).to(fallback_device)

    def _patched_get_placeholder_mask(
        self, input_ids, inputs_embeds, image_features=None, video_features=None
    ):
        """Return (image_mask, video_mask), each expanded to ``inputs_embeds``.

        ``image_features`` / ``video_features`` are accepted so the signature
        stays call-compatible but are not used here.
        NOTE(review): presumably the stock method validates feature counts
        against the masks and this patch exists to skip that -- confirm.
        """

        def _token_mask(token_id):
            if input_ids is None:
                # No ids available: find positions whose embedding equals the
                # placeholder token's embedding in every component.
                token_embed = self.get_input_embeddings()(
                    _torch.tensor(token_id, dtype=_torch.long, device=inputs_embeds.device)
                )
                mask = (inputs_embeds == token_embed).all(-1)
            else:
                mask = input_ids == token_id
            return mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

        return (
            _token_mask(self.config.image_token_id),
            _token_mask(self.config.video_token_id),
        )

    model.model.get_placeholder_mask = types.MethodType(
        _patched_get_placeholder_mask, model.model
    )

    _text_model = model.model.language_model
    _orig_deepstack = _text_model._deepstack_process  # bound method, called as-is below

    def _safe_deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds):
        """Skip the original deepstack step when no position is marked visual."""
        if not visual_pos_masks.any():
            return hidden_states
        return _orig_deepstack(hidden_states, visual_pos_masks, visual_embeds)

    _text_model._deepstack_process = types.MethodType(_safe_deepstack_process, _text_model)

    # Pixel bounds forwarded to the processor.
    min_pixels = 256 * 28 * 28
    max_pixels = 1280 * 28 * 28
    processor = AutoProcessor.from_pretrained(
        model_path, min_pixels=min_pixels, max_pixels=max_pixels
    )
    return {"model": model, "processor": processor}


def _select_prompt(question, LM_classify=None):
    """Choose the question type, prompt suffix and thinking flag.

    Args:
        question: The user question.
        LM_classify: Optional externally supplied label. ``"context"`` selects
            pointing, ``"compatibility"`` selects binary, any other non-None
            value selects binary with no suffix and thinking disabled. When
            ``None`` the type is guessed from keywords in *question*.

    Returns:
        ``(q_type, post_prompt, add_think)`` where ``q_type`` is one of
        ``"pointing"``, ``"binary"`` or ``"open"``.
    """
    if LM_classify is not None:
        if LM_classify == "context":
            return "pointing", _POINTING_POST_PROMPT, True
        if LM_classify == "compatibility":
            return "binary", _BINARY_POST_PROMPT, True
        return "binary", "", False

    question_lower = question.lower().strip()
    if any(keyword in question_lower for keyword in _POINTING_KEYWORDS):
        return "pointing", _POINTING_POST_PROMPT, True
    if any(
        phrase in question_lower for phrase in _BINARY_PHRASES
    ) or question_lower.startswith(_BINARY_STARTS):
        return "binary", _BINARY_POST_PROMPT, True
    return "open", "", True


def run_robobrain(question, image_path, depth_path, kwargs, return_thinking=False,
                  LM_classify=None, add_think_override=None):
    """Run one image+question query through RoboBrain and post-process the answer.

    Args:
        question: The user question.
        image_path: Local path or ``http...`` URL of the image.
        depth_path: Unused by this function; kept for interface compatibility.
        kwargs: Mapping with ``"model"`` and ``"processor"`` entries, as
            returned by ``load_robobrain``.
        return_thinking: When true, return ``(thinking, answer)`` instead of
            just the answer.
        LM_classify: Optional question label; see ``_select_prompt``.
        add_think_override: When not ``None``, overrides whether the thinking
            prefill is appended to the prompt.

    Returns:
        The answer string, or ``(thinking, answer)`` when *return_thinking*
        is true. Pointing answers have integer tuples divided by 1000; binary
        answers are reduced to ``"yes"``/``"no"`` where possible.

    Environment:
        ``ROBOBRAIN_MAX_NEW_TOKENS`` sets the generation budget (default 2048).
    """
    from qwen_vl_utils import process_vision_info

    model = kwargs["model"]
    processor = kwargs["processor"]

    q_type, post_prompt, add_think = _select_prompt(question, LM_classify)
    if add_think_override is not None:
        add_think = add_think_override

    full_question = question + post_prompt
    image_uri = image_path if image_path.startswith("http") else f"file://{image_path}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_uri},
                {"type": "text", "text": full_question},
            ],
        },
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if add_think:
        # Prefill the opening think tag so the response continues with the
        # reasoning and closes it itself; _parse_think_and_answer splits on
        # the closing tag.
        text = text + _THINK_OPEN + "\n"

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # With a sharded model, inputs go to the device of the first parameter.
    _first_device = next(model.parameters()).device
    inputs = inputs.to(_first_device)

    max_new_tokens = int(os.environ.get("ROBOBRAIN_MAX_NEW_TOKENS", "2048"))
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Drop the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    response = output_text[0].strip() if output_text else ""

    thinking, answer_text = _parse_think_and_answer(response)
    if q_type == "pointing":
        answer_text = _normalize_1k_coords(answer_text)
    elif q_type == "binary":
        answer_text = _normalize_binary_answer(answer_text)

    if return_thinking:
        return thinking, answer_text
    return answer_text