"""Zero-shot question classifier backed by a locally stored causal LM.

A question is wrapped in a chat prompt, the model generates a short reply,
and the reply is mapped onto one of the labels in ``_VALID_CATEGORIES``.
"""

import os
import re
import sys
from typing import Optional

_SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above the directory that contains this file.
_SUBMIT_ROOT = os.path.dirname(os.path.dirname(_SCRIPTS_DIR))
_DEFAULT_LM_PATH = os.path.join(_SUBMIT_ROOT, "model", "LM")

_VALID_CATEGORIES = ["context", "configuration", "compatibility"]

# Label returned by classify_single() when the reply cannot be parsed.
_FALLBACK_CATEGORY = "context"

# Closing tag of a reasoning block in the generated text.
# NOTE(review): the marker in the previous revision was an empty string,
# which makes ``str.split`` raise ``ValueError: empty separator`` on every
# call. It is restored here as the ``</think>`` tag emitted by chat
# templates that accept ``enable_thinking`` -- confirm against the model
# actually deployed.
_THINK_END_TAG = "</think>"

# Anything that is neither a word character nor whitespace (punctuation,
# quotes, leftover tag brackets).
_PUNCTUATION_RE = re.compile(r"[^\w\s]")

# NOTE(review): the line breaks inside this prompt were reconstructed at
# bullet/sentence boundaries; the wording itself is unchanged.
_SYSTEM_PROMPT = """\
You are a question classifier for a spatial reasoning benchmark.
Classify the given question into exactly one of three categories:
- context: The question asks to locate or pinpoint spatial coordinates/positions. These questions contain words like "pinpoint", "coordinates", "point to", or ask to mark positions as coordinate tuples. They describe a scene object and ask where certain spatial areas are.
- configuration: The question asks whether an object is currently arranged in a certain spatial relationship with another object (its existing configuration). These questions typically start with "Is...", "Are...", "Does...", "Do...", "Was...", "Were..." and ask about real spatial state, answered with yes or no.
- compatibility: The question asks whether an object *can* physically fit in or be placed at a certain position relative to another object. These questions typically start with "Can..." and ask about spatial feasibility, answered with yes or no.
Reply with ONLY one word: context, configuration, or compatibility.
Do not explain your answer.\
"""


def _parse_category(raw_output: str) -> Optional[str]:
    """Map raw model output onto one of ``_VALID_CATEGORIES``.

    The text is lower-cased, everything up to and including the first
    reasoning-close tag is discarded, and punctuation is replaced by
    spaces. The label is then taken from, in order of preference:

    1. the first word of the remaining text, if it is a valid label
       (this also covers a reply that consists of the label alone);
    2. the first label, longest first, that occurs anywhere in the text.

    Args:
        raw_output: Decoded text produced by the model.

    Returns:
        The matched label, or ``None`` when no label can be recognised.
    """
    text = raw_output.strip().lower()

    # Keep only what follows the reasoning block, if there is one.
    if _THINK_END_TAG in text:
        text = text.split(_THINK_END_TAG, 1)[1]

    cleaned = _PUNCTUATION_RE.sub(" ", text).strip()
    words = cleaned.split()
    if words and words[0] in _VALID_CATEGORIES:
        return words[0]

    # Substring fallback; sorted() is stable, so equally long labels keep
    # their order from _VALID_CATEGORIES.
    for label in sorted(_VALID_CATEGORIES, key=len, reverse=True):
        if label in cleaned:
            return label
    return None


def _apply_chat_template(tokenizer, messages, enable_thinking):
    """Render ``messages`` into a single prompt string.

    ``enable_thinking`` is forwarded to ``tokenizer.apply_chat_template``;
    if that call raises ``TypeError`` the template is applied again without
    the argument, so tokenizers that reject the keyword still work.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        messages: Chat messages as a list of ``{"role", "content"}`` dicts.
        enable_thinking: Value passed through as ``enable_thinking``.

    Returns:
        Whatever ``apply_chat_template`` returns with ``tokenize=False``.
    """
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking,
        )
    except TypeError:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )


def load_lm_classifier(model_path=None):
    """Load the classifier model and its tokenizer.

    Args:
        model_path: Path or identifier handed to ``from_pretrained``.
            Falls back to ``_DEFAULT_LM_PATH`` when falsy.

    Returns:
        A dict with the keys ``"model"`` and ``"tokenizer"``, in the shape
        ``classify_single`` expects as ``clf_kwargs``.
    """
    # Imported lazily so importing this module does not require transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    resolved_path = model_path or _DEFAULT_LM_PATH
    print(f"[LM Classifier] Loading from {resolved_path} ...")

    clf_model = AutoModelForCausalLM.from_pretrained(
        resolved_path, torch_dtype="auto", device_map="auto"
    )
    clf_tokenizer = AutoTokenizer.from_pretrained(resolved_path)

    first_device = next(clf_model.parameters()).device
    print(f"[LM Classifier] Loaded on {first_device}.")

    return {
        "model": clf_model,
        "tokenizer": clf_tokenizer,
    }


def classify_single(question, clf_kwargs, enable_thinking=False, max_new_tokens=32):
    """Classify one question with the loaded LM.

    Args:
        question: Question text; it is embedded in the user message.
        clf_kwargs: Dict holding ``"model"`` and ``"tokenizer"``, as
            returned by ``load_lm_classifier``.
        enable_thinking: Forwarded to the chat template.
        max_new_tokens: Generation budget passed to ``generate``.

    Returns:
        One of ``_VALID_CATEGORIES``. When the reply cannot be parsed the
        result is ``"context"``.
    """
    import torch

    clf_model = clf_kwargs["model"]
    clf_tokenizer = clf_kwargs["tokenizer"]
    # Inputs are moved to the device of the model's first parameter.
    first_device = next(clf_model.parameters()).device

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": f'Classify this question:\n"{question}"'},
    ]
    prompt = _apply_chat_template(clf_tokenizer, messages, enable_thinking)
    inputs = clf_tokenizer([prompt], return_tensors="pt").to(first_device)

    with torch.no_grad():
        generated_ids = clf_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # generate() returns prompt + continuation; drop the prompt tokens.
    trimmed = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
    ]
    raw = clf_tokenizer.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0].strip()

    predicted = _parse_category(raw)
    return predicted if predicted is not None else _FALLBACK_CATEGORY