| import os |
| import re |
| import sys |
|
|
| _SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__)) |
| _SUBMIT_ROOT = os.path.dirname(os.path.dirname(_SCRIPTS_DIR)) |
| _DEFAULT_LM_PATH = os.path.join(_SUBMIT_ROOT, "model", "LM") |
|
|
| _VALID_CATEGORIES = ["context", "configuration", "compatibility"] |
|
|
# System prompt for the LM classifier. It describes the three labels that
# also appear in _VALID_CATEGORIES and asks the model to reply with exactly
# one of them. The backslash after the opening quotes and the one before the
# closing quotes keep the string free of a leading and a trailing newline.
_SYSTEM_PROMPT = """\
You are a question classifier for a spatial reasoning benchmark.
Classify the given question into exactly one of three categories:

- context: The question asks to locate or pinpoint spatial coordinates/positions.
These questions contain words like "pinpoint", "coordinates", "point to", or
ask to mark positions as coordinate tuples. They describe a scene object and
ask where certain spatial areas are.

- configuration: The question asks whether an object is currently arranged in a
certain spatial relationship with another object (its existing configuration).
These questions typically start with "Is...", "Are...", "Does...", "Do...",
"Was...", "Were..." and ask about real spatial state, answered with yes or no.

- compatibility: The question asks whether an object *can* physically fit in or
be placed at a certain position relative to another object.
These questions typically start with "Can..." and ask about spatial feasibility,
answered with yes or no.

Reply with ONLY one word: context, configuration, or compatibility.
Do not explain your answer.\
"""
|
|
|
|
| def _parse_category(raw_output): |
| text = raw_output.strip().lower() |
| if "</think>" in text: |
| text = text.split("</think>", 1)[1].strip() |
| text = re.sub(r"[^\w\s]", " ", text) |
| stripped = text.strip() |
|
|
| for label in _VALID_CATEGORIES: |
| if stripped == label: |
| return label |
|
|
| words = stripped.split() |
| if words: |
| for label in _VALID_CATEGORIES: |
| if words[0] == label: |
| return label |
|
|
| for label in sorted(_VALID_CATEGORIES, key=len, reverse=True): |
| if label in stripped: |
| return label |
|
|
| return None |
|
|
|
|
| def _apply_chat_template(tokenizer, messages, enable_thinking): |
| try: |
| return tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True, |
| enable_thinking=enable_thinking, |
| ) |
| except TypeError: |
| return tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True, |
| ) |
|
|
|
|
def load_lm_classifier(model_path=None):
    """Load the causal-LM classifier and its tokenizer.

    Args:
        model_path: Checkpoint directory or hub id. When falsy, the module's
            default path (``_DEFAULT_LM_PATH``) is used.

    Returns:
        ``{"model": ..., "tokenizer": ...}``. This is the dict that
        ``classify_single`` expects as ``clf_kwargs``.
    """
    # Imported lazily so the module can be imported without transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = model_path if model_path else _DEFAULT_LM_PATH
    print(f"[LM Classifier] Loading from {checkpoint} ...")

    # device_map="auto" lets transformers place weights across available devices.
    lm = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype="auto", device_map="auto"
    )
    tok = AutoTokenizer.from_pretrained(checkpoint)

    # Report where the first parameter landed.
    placed_on = next(iter(lm.parameters())).device
    print(f"[LM Classifier] Loaded on {placed_on}.")

    return dict(model=lm, tokenizer=tok)
|
|
|
|
def classify_single(question, clf_kwargs, enable_thinking=False, max_new_tokens=32):
    """Classify one question with the LM classifier using greedy decoding.

    Args:
        question: Question text to classify.
        clf_kwargs: Dict with ``"model"`` and ``"tokenizer"`` entries (the
            shape returned by ``load_lm_classifier``).
        enable_thinking: Forwarded to the chat template via
            ``_apply_chat_template``. That helper drops the flag if the
            tokenizer rejects it.
        max_new_tokens: Maximum number of tokens to generate.
            NOTE(review): with thinking enabled, 32 tokens is probably too
            few for the model to close its reasoning and emit a label; confirm.

    Returns:
        The label parsed by ``_parse_category``, or ``"context"`` when no
        label can be parsed from the output.
    """
    import torch

    clf_model = clf_kwargs["model"]
    clf_tokenizer = clf_kwargs["tokenizer"]
    # Inputs go to the device of the first parameter.
    first_device = next(clf_model.parameters()).device

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": f'Classify this question:\n"{question}"'},
    ]
    text = _apply_chat_template(clf_tokenizer, messages, enable_thinking)
    # The rendered template already contains the model's special tokens
    # (e.g. BOS). Letting the tokenizer add them again would duplicate them.
    inputs = clf_tokenizer(
        [text], return_tensors="pt", add_special_tokens=False
    ).to(first_device)

    with torch.no_grad():
        generated_ids = clf_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
    raw = clf_tokenizer.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0].strip()

    predicted = _parse_category(raw)
    # Unparseable output falls back to "context".
    return predicted if predicted is not None else "context"
|
|