# Source: RoboSpatialBrain / scripts / object_direction_remap.py
# Uploaded by lbx511 using huggingface_hub (commit 3bbe20b, verified).
import os
import re
import sys
# Directory containing this script, and the directory two levels above it.
_SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))
_SUBMIT_ROOT = os.path.dirname(os.path.dirname(_SCRIPTS_DIR))

# Put both directories at the front of sys.path (skipping any already present)
# so the function-level imports further down can resolve sibling modules.
# _SUBMIT_ROOT is inserted last and therefore ends up ahead of _SCRIPTS_DIR.
for _extra_path in (_SCRIPTS_DIR, _SUBMIT_ROOT):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path
# Matches the wording "situated to the <relation> the" (case-insensitive).
# Group 1 captures the relation phrase exactly as it appears in the question.
_DIRECTION_PATTERN = re.compile(
    r"situated to the (left of|right of|in front of|behind|below|above) the",
    re.IGNORECASE,
)
# Relation phrases for which should_remap() reports "do not remap".
_SKIP_DIRECTIONS = {"below", "above"}
# Lower-cased relation phrase -> canonical direction key. should_remap() returns
# this key and _map_direction() looks it up in the caller-supplied direction map.
_DIRECTION_CANONICAL = {
    "left of": "left",
    "right of": "right",
    "in front of": "front",
    "behind": "behind",
}
# Canonical direction key -> relation phrase, used by
# _replace_direction_in_question() when rewriting the question.
# NOTE(review): "in front of" is keyed "in_front_of" here but "front" in
# _DIRECTION_CANONICAL. Presumably the direction map takes "front" as input and
# yields "in_front_of" as output - confirm against the map passed by callers.
_DIRECTION_TO_PHRASE = {
    "left": "left of",
    "right": "right of",
    "in_front_of": "in front of",
    "behind": "behind",
}
# Prompt asking which way the named object faces; formatted with `object_name`
# in run_context_with_remap(). The four listed choices are the answers that
# _parse_facing_direction() is written to recognise.
_FACING_PROMPT_TEMPLATE = (
    "In the image, there is a {object_name}. "
    "Which direction is the {object_name} facing from the camera's perspective? "
    "Choose exactly one and output only that choice: "
    "left, right, toward the camera, away from the camera."
)
# System message used by _extract_object_from_question() when it falls back to
# the classifier model to name the target object.
_OBJECT_EXTRACTION_SYSTEM = (
    "From the following sentence, identify the target object that is being acted upon, "
    "placed, or referred to as the main subject of interest. "
    "Output only the object name, nothing else."
)
def should_remap(question):
    """Decide whether the relation phrase in *question* is eligible for remapping.

    Args:
        question: Question text searched with ``_DIRECTION_PATTERN``.

    Returns:
        A 3-tuple ``(do_remap, canonical, phrase)``. When the pattern matches
        and the lower-cased relation phrase is not in ``_SKIP_DIRECTIONS``,
        this is ``(True, _DIRECTION_CANONICAL.get(phrase), phrase)``.
        Otherwise it is ``(False, None, None)``.
    """
    found = _DIRECTION_PATTERN.search(question)
    relation = found.group(1).lower() if found is not None else None
    if relation is None or relation in _SKIP_DIRECTIONS:
        return False, None, None
    return True, _DIRECTION_CANONICAL.get(relation), relation
def _parse_facing_direction(raw_output):
    """Map a free-text facing answer onto one of four facing keys.

    Matching is by case-insensitive substring, checked in this order (the
    first rule that fires wins):

    1. ``"away from"`` or ``"backward"``  -> ``"facing_away_from_camera"``
    2. ``"toward(s) the camera"`` or ``"forward"`` -> ``"facing_toward_camera"``
    3. only ``"left"`` present -> ``"facing_left"``;
       only ``"right"`` present -> ``"facing_right"``
    4. a bare ``"toward"`` / ``"towards"`` -> ``"facing_toward_camera"``

    Args:
        raw_output: Text returned by the model for the facing prompt.

    Returns:
        The facing key, or ``None`` when the text is missing, is not a string,
        or is ambiguous (for example it mentions both left and right and
        nothing else decisive).
    """
    # A missing or non-text reply is simply "unparseable". Return None so the
    # caller takes its existing fallback path instead of raising AttributeError.
    if not isinstance(raw_output, str):
        return None

    text = raw_output.lower()
    if "away from" in text or "backward" in text:
        return "facing_away_from_camera"
    if "toward the camera" in text or "towards the camera" in text or "forward" in text:
        return "facing_toward_camera"

    has_left = "left" in text
    has_right = "right" in text
    if has_left != has_right:
        # Exactly one of the two sides was mentioned.
        return "facing_left" if has_left else "facing_right"

    # Checked last so that "toward the left" resolves to the side, not the camera.
    if "toward" in text:  # also covers "towards"
        return "facing_toward_camera"
    return None
def _map_direction(orig_canonical, facing_key, direction_map):
    """Look up the remapped direction for a given facing and original direction.

    Args:
        orig_canonical: Canonical direction key taken from the question.
        facing_key: Facing key produced by ``_parse_facing_direction``.
        direction_map: Two-level mapping ``facing_key -> {orig_canonical -> new}``.

    Returns:
        The mapped direction, or ``None`` when either key is ``None`` or the
        map has no entry for the pair.
    """
    if facing_key is None or orig_canonical is None:
        return None
    return direction_map.get(facing_key, {}).get(orig_canonical)
def _replace_direction_in_question(question, orig_phrase, new_canonical):
    """Rewrite the relation phrase in *question* to the remapped direction.

    Every occurrence of ``"to the <orig_phrase> the"`` is replaced by
    ``"to the <new_phrase> the"``, where ``new_phrase`` is looked up in
    ``_DIRECTION_TO_PHRASE``.

    Matching is case-insensitive. ``should_remap`` detects the phrase with a
    case-insensitive pattern and hands back a lower-cased ``orig_phrase``, so
    a case-sensitive replace would leave a question such as
    "Situated to the Left of the ..." unchanged. The surrounding
    "to the " / " the" words keep whatever casing the question used.

    Args:
        question: Original question text.
        orig_phrase: Relation phrase to replace (e.g. ``"left of"``).
        new_canonical: Canonical key of the replacement direction.

    Returns:
        The rewritten question, or *question* unchanged when ``new_canonical``
        has no phrase, ``orig_phrase`` is ``None``, or nothing matches.
    """
    new_phrase = _DIRECTION_TO_PHRASE.get(new_canonical)
    if new_phrase is None or orig_phrase is None:
        return question

    # Groups 1 and 2 capture the fixed words so their original casing survives.
    pattern = re.compile(
        r"(to the )" + re.escape(orig_phrase) + r"( the)",
        re.IGNORECASE,
    )
    # A callable replacement avoids any backslash/group-reference interpretation
    # of the substituted phrase.
    return pattern.sub(lambda m: m.group(1) + new_phrase + m.group(2), question)
def _extract_object_from_question(question, clf_kwargs):
    """Return the name of the object the question is about.

    Fast path: if the question contains "there is a <name>." (case-insensitive),
    the captured name is returned stripped, and the classifier is never touched.

    Fallback: the text up to and including the first period (or the whole
    question when it has none) is sent to the classifier model with
    ``_OBJECT_EXTRACTION_SYSTEM`` as the system message, and the decoded reply
    is returned.

    Args:
        question: Question text.
        clf_kwargs: Mapping providing ``"model"`` and ``"tokenizer"``; only
            read on the fallback path.

    Returns:
        The object name. On the fallback path an empty reply or ``"[]"``
        becomes the placeholder ``"object"``.
    """
    direct = re.search(r"there is a (.+?)\.", question, re.IGNORECASE)
    if direct is not None:
        return direct.group(1).strip()

    # Imported lazily so the regex fast path works without these packages.
    import torch
    from lm_classifier import _apply_chat_template

    # Keep everything up to and including the first "."; with no period the
    # separator is empty and this is just the whole question.
    head, dot, _ = question.partition(".")
    leading_sentence = head + dot

    model = clf_kwargs["model"]
    tokenizer = clf_kwargs["tokenizer"]
    device = next(model.parameters()).device

    chat = [
        {"role": "system", "content": _OBJECT_EXTRACTION_SYSTEM},
        {"role": "user", "content": f"Sentences: {leading_sentence}"},
    ]
    prompt = _apply_chat_template(tokenizer, chat, enable_thinking=False)
    encoded = tokenizer([prompt], return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(**encoded, max_new_tokens=16, do_sample=False)

    # Drop the echoed prompt tokens so only the newly generated ones are decoded.
    new_tokens = [
        full[len(prompt_ids):]
        for prompt_ids, full in zip(encoded.input_ids, output_ids)
    ]
    decoded = tokenizer.batch_decode(
        new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    reply = decoded[0].strip()

    # If the reply contains a closing think tag, keep only what follows it.
    _, marker, after = reply.partition("</think>")
    if marker:
        reply = after.strip()

    return "object" if reply in ("", "[]") else reply
def run_context_with_remap(question, image_path, depth_path, model_kwargs, clf_kwargs, direction_map):
    """Answer *question*, optionally re-asking it with a facing-adjusted direction.

    Steps:
        1. Ask the model the original question.
        2. If ``should_remap`` says no, return that answer as-is.
        3. Otherwise extract the object name, ask the model which way the
           object is facing, and translate the original direction through
           ``direction_map``.
        4. If no mapping is found, return the first point (or the raw first
           answer when no point could be extracted).
        5. Otherwise rewrite the question with the new direction, ask again,
           and return the extracted points joined by a single space.

    Args:
        question: Question text.
        image_path: Forwarded unchanged to ``run_robobrain``.
        depth_path: Forwarded unchanged to ``run_robobrain``.
        model_kwargs: Forwarded unchanged to ``run_robobrain``.
        clf_kwargs: Passed to ``_extract_object_from_question``.
        direction_map: Mapping ``facing_key -> {orig_direction -> new_direction}``.

    Returns:
        A string: the raw model answer, one ``"(x, y)"`` point, or two points
        separated by a space. When neither point can be extracted, the first
        non-empty raw answer is returned.
    """
    import torch
    from robobrain_runner import run_robobrain
    from evaluation import _extract_first_point

    def _point_text(answer):
        # Render the first point found in a model answer as "(x, y)", or "".
        point, _ = _extract_first_point(answer)
        if not point:
            return ""
        return f"({point[0]}, {point[1]})"

    remap_wanted, canonical_dir, dir_phrase = should_remap(question)

    # The original question is asked on both paths, so ask it once up front.
    first_answer = run_robobrain(question, image_path, depth_path, model_kwargs, LM_classify="context")
    if not remap_wanted:
        return first_answer
    first_point = _point_text(first_answer)

    target = _extract_object_from_question(question, clf_kwargs)
    torch.cuda.empty_cache()
    facing_prompt = _FACING_PROMPT_TEMPLATE.format(object_name=target.strip() or "object")
    facing_reply = run_robobrain(
        facing_prompt, image_path, depth_path, model_kwargs, add_think_override=False
    )

    remapped_dir = _map_direction(
        canonical_dir, _parse_facing_direction(facing_reply), direction_map
    )
    if remapped_dir is None:
        return first_point or first_answer

    rewritten = _replace_direction_in_question(question, dir_phrase, remapped_dir)
    second_answer = run_robobrain(rewritten, image_path, depth_path, model_kwargs, LM_classify="context")
    second_point = _point_text(second_answer)

    found = [text for text in (first_point, second_point) if text]
    if found:
        return " ".join(found)
    return first_answer or second_answer