RoboSpatialBrain / scripts /lm_classifier.py
lbx511's picture
Upload folder using huggingface_hub
3bbe20b verified
import os
import re
import sys
# Absolute path of the directory that contains this file.
_SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory two levels above _SCRIPTS_DIR.
_SUBMIT_ROOT = os.path.dirname(os.path.dirname(_SCRIPTS_DIR))
# Checkpoint location (<_SUBMIT_ROOT>/model/LM) that load_lm_classifier() falls
# back to when the caller does not pass a model path.
_DEFAULT_LM_PATH = os.path.join(_SUBMIT_ROOT, "model", "LM")
# The labels _parse_category() is allowed to return; the order here is the
# order in which the exact-match and first-word checks try them.
_VALID_CATEGORIES = ["context", "configuration", "compatibility"]
# System message sent with every classification request in classify_single().
# The backslash right after the opening quotes and the one before the closing
# quotes are line continuations, so the prompt carries no leading or trailing
# newline.
_SYSTEM_PROMPT = """\
You are a question classifier for a spatial reasoning benchmark.
Classify the given question into exactly one of three categories:
- context: The question asks to locate or pinpoint spatial coordinates/positions.
These questions contain words like "pinpoint", "coordinates", "point to", or
ask to mark positions as coordinate tuples. They describe a scene object and
ask where certain spatial areas are.
- configuration: The question asks whether an object is currently arranged in a
certain spatial relationship with another object (its existing configuration).
These questions typically start with "Is...", "Are...", "Does...", "Do...",
"Was...", "Were..." and ask about real spatial state, answered with yes or no.
- compatibility: The question asks whether an object *can* physically fit in or
be placed at a certain position relative to another object.
These questions typically start with "Can..." and ask about spatial feasibility,
answered with yes or no.
Reply with ONLY one word: context, configuration, or compatibility.
Do not explain your answer.\
"""
def _parse_category(raw_output):
    """Extract a category label from the classifier's raw text reply.

    The reply is lower-cased and trimmed; if it contains a ``</think>``
    marker, only the text after the first marker is kept. Punctuation is
    then replaced by spaces and three matching strategies are tried in order:

    1. the whole cleaned reply equals a label,
    2. the first whitespace-separated word equals a label,
    3. a label occurs anywhere in the reply as a substring (longer labels
       are tried before shorter ones).

    Args:
        raw_output: Decoded text produced by the model.

    Returns:
        One of the entries of ``_VALID_CATEGORIES``, or ``None`` when no
        label can be found.
    """
    cleaned = raw_output.strip().lower()

    # Drop everything up to and including the first closing think tag.
    _, marker, tail = cleaned.partition("</think>")
    if marker:
        cleaned = tail.strip()

    # Turn punctuation into spaces so "context." still matches "context".
    cleaned = re.sub(r"[^\w\s]", " ", cleaned).strip()

    # Strategy 1: the reply is exactly a label.
    if cleaned in _VALID_CATEGORIES:
        return cleaned

    # Strategy 2: the reply starts with a label.
    tokens = cleaned.split()
    if tokens and tokens[0] in _VALID_CATEGORIES:
        return tokens[0]

    # Strategy 3: a label appears somewhere in the reply.
    longest_first = sorted(_VALID_CATEGORIES, key=len, reverse=True)
    return next((name for name in longest_first if name in cleaned), None)
def _apply_chat_template(tokenizer, messages, enable_thinking):
    """Render ``messages`` into a prompt string using the tokenizer's chat template.

    The call is first attempted with the ``enable_thinking`` keyword; if the
    tokenizer raises ``TypeError`` the call is repeated without that keyword.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        messages: Chat messages forwarded unchanged to the template.
        enable_thinking: Value passed as ``enable_thinking`` on the first
            attempt (presumably toggles a reasoning mode on templates that
            support it -- confirm against the model in use).

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns when called with
        ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    # Arguments common to both attempts.
    shared_options = {"tokenize": False, "add_generation_prompt": True}
    render = tokenizer.apply_chat_template
    try:
        return render(messages, **shared_options, enable_thinking=enable_thinking)
    except TypeError:
        # Retry without the keyword the first attempt added.
        return render(messages, **shared_options)
def load_lm_classifier(model_path=None):
    """Load the causal LM and tokenizer used for question classification.

    Args:
        model_path: Path or identifier handed to ``from_pretrained``. When
            falsy, ``_DEFAULT_LM_PATH`` is used instead.

    Returns:
        A dict with the keys ``"model"`` and ``"tokenizer"``, in the shape
        ``classify_single`` expects for its ``clf_kwargs`` argument.
    """
    # Imported lazily so importing this module does not require transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = model_path if model_path else _DEFAULT_LM_PATH
    print(f"[LM Classifier] Loading from {checkpoint} ...")

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype="auto",
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Report the device of the first parameter only.
    print(f"[LM Classifier] Loaded on {next(model.parameters()).device}.")
    return {"model": model, "tokenizer": tokenizer}
def classify_single(question, clf_kwargs, enable_thinking=False, max_new_tokens=32):
    """Classify one question with the loaded language model.

    Builds a two-message chat (system prompt plus the quoted question),
    generates a reply without sampling, and parses the reply with
    ``_parse_category``.

    Args:
        question: The question text to classify.
        clf_kwargs: Dict with ``"model"`` and ``"tokenizer"`` entries, as
            returned by ``load_lm_classifier``.
        enable_thinking: Forwarded to ``_apply_chat_template``.
        max_new_tokens: Upper bound on generated tokens passed to
            ``generate``.

    Returns:
        The parsed category label, or ``"context"`` when the reply could not
        be parsed.
    """
    import torch

    model, tokenizer = clf_kwargs["model"], clf_kwargs["tokenizer"]
    # Inputs are moved to the device of the model's first parameter.
    device = next(model.parameters()).device

    chat = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": f'Classify this question:\n"{question}"'},
    ]
    prompt = _apply_chat_template(tokenizer, chat, enable_thinking)
    encoded = tokenizer([prompt], return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # Slice off the leading prompt-length tokens so only the completion is decoded.
    completions = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(encoded.input_ids, output_ids)
    ]
    decoded = tokenizer.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    reply = decoded[0].strip()

    category = _parse_category(reply)
    # Unparseable replies fall back to the "context" label.
    return "context" if category is None else category