File size: 4,064 Bytes
3bbe20b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import re
import sys

# Absolute path of the directory containing this script.
_SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above _SCRIPTS_DIR; used as the submission root.
_SUBMIT_ROOT = os.path.dirname(os.path.dirname(_SCRIPTS_DIR))
# Default location of the classifier LM: <submission root>/model/LM.
# Used by load_lm_classifier when no explicit model_path is given.
_DEFAULT_LM_PATH = os.path.join(_SUBMIT_ROOT, "model", "LM")

# Closed set of labels the classifier may return; _parse_category matches the
# model's raw output against these.
_VALID_CATEGORIES = ["context", "configuration", "compatibility"]

# System prompt sent before every question. It defines the three labels and
# asks for a single-word reply so the generation is easy to parse.
_SYSTEM_PROMPT = """\
You are a question classifier for a spatial reasoning benchmark.
Classify the given question into exactly one of three categories:

- context: The question asks to locate or pinpoint spatial coordinates/positions.
  These questions contain words like "pinpoint", "coordinates", "point to", or
  ask to mark positions as coordinate tuples. They describe a scene object and
  ask where certain spatial areas are.

- configuration: The question asks whether an object is currently arranged in a
  certain spatial relationship with another object (its existing configuration).
  These questions typically start with "Is...", "Are...", "Does...", "Do...",
  "Was...", "Were..." and ask about real spatial state, answered with yes or no.

- compatibility: The question asks whether an object *can* physically fit in or
  be placed at a certain position relative to another object.
  These questions typically start with "Can..." and ask about spatial feasibility,
  answered with yes or no.

Reply with ONLY one word: context, configuration, or compatibility.
Do not explain your answer.\
"""


def _parse_category(raw_output):
    """Extract a category label from the classifier's raw generation.

    The output is lower-cased, everything up to and including the first
    ``</think>`` marker is discarded, and punctuation is replaced by spaces.
    Matching then goes from strictest to loosest:

    1. the first word is a valid label (this also covers an exact match);
    2. otherwise, the valid label that occurs *earliest* anywhere in the text.

    Args:
        raw_output: Decoded model output (a string).

    Returns:
        One of ``_VALID_CATEGORIES``, or ``None`` if no label is found.
    """
    text = raw_output.strip().lower()
    if "</think>" in text:
        # Keep only the answer that follows the reasoning block.
        text = text.split("</think>", 1)[1]
    text = re.sub(r"[^\w\s]", " ", text).strip()

    words = text.split()
    if words and words[0] in _VALID_CATEGORIES:
        return words[0]

    # Substring fallback. Prefer the label the model mentioned first
    # (e.g. "category: compatibility, not configuration"), rather than
    # whichever label happens to be listed first in _VALID_CATEGORIES.
    positions = {
        label: text.find(label) for label in _VALID_CATEGORIES if label in text
    }
    if not positions:
        return None
    return min(positions, key=positions.get)


def _apply_chat_template(tokenizer, messages, enable_thinking):
    """Render ``messages`` into a prompt string ending with a generation cue.

    ``enable_thinking`` is forwarded to ``tokenizer.apply_chat_template``.
    If that call raises ``TypeError`` (presumably a tokenizer whose method
    does not accept the keyword), the template is rendered again without it.
    """
    common = {"tokenize": False, "add_generation_prompt": True}
    try:
        return tokenizer.apply_chat_template(
            messages, enable_thinking=enable_thinking, **common
        )
    except TypeError:
        return tokenizer.apply_chat_template(messages, **common)


def load_lm_classifier(model_path=None):
    """Load the classifier causal LM and its tokenizer.

    Args:
        model_path: Model location passed to ``from_pretrained``. When falsy,
            ``_DEFAULT_LM_PATH`` is used instead.

    Returns:
        Dict with ``"model"`` and ``"tokenizer"`` entries, in the shape that
        ``classify_single`` expects for its ``clf_kwargs`` argument.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = model_path if model_path else _DEFAULT_LM_PATH
    print(f"[LM Classifier] Loading from {path} ...")

    model = AutoModelForCausalLM.from_pretrained(
        path, torch_dtype="auto", device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(path)

    # With device_map="auto" the weights may be spread over several devices;
    # only the device of the first parameter is reported.
    device = next(iter(model.parameters())).device
    print(f"[LM Classifier] Loaded on {device}.")

    return dict(model=model, tokenizer=tokenizer)


def classify_single(question, clf_kwargs, enable_thinking=False, max_new_tokens=32):
    """Classify one benchmark question into a single category label.

    Args:
        question: The question text to classify.
        clf_kwargs: Dict with ``"model"`` and ``"tokenizer"`` entries, as
            returned by ``load_lm_classifier``.
        enable_thinking: Forwarded to the chat template via
            ``_apply_chat_template``.
        max_new_tokens: Generation budget for the reply. NOTE(review): with
            ``enable_thinking=True`` the default of 32 is likely too small for
            a reasoning block to finish; confirm before enabling thinking.

    Returns:
        The label parsed by ``_parse_category``, or ``"context"`` when no
        valid label can be recovered from the generation.
    """
    import torch

    clf_model = clf_kwargs["model"]
    clf_tokenizer = clf_kwargs["tokenizer"]
    first_device = next(clf_model.parameters()).device

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": f'Classify this question:\n"{question}"'},
    ]
    text = _apply_chat_template(clf_tokenizer, messages, enable_thinking)
    # The rendered chat template already contains whatever special tokens the
    # template emits (e.g. BOS for some model families); letting the tokenizer
    # add its own special tokens again could duplicate them.
    inputs = clf_tokenizer(
        [text], return_tensors="pt", add_special_tokens=False
    ).to(first_device)

    with torch.no_grad():
        generated_ids = clf_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = inputs["input_ids"].shape[1]
    raw = clf_tokenizer.decode(
        generated_ids[0, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    ).strip()

    predicted = _parse_category(raw)
    return predicted if predicted is not None else "context"