File size: 3,775 Bytes
3bbe20b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
065cb81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bbe20b
 
 
 
 
 
 
 
 
 
 
 
 
065cb81
 
 
3bbe20b
065cb81
3bbe20b
 
065cb81
3bbe20b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import sys

# Directory that contains this file.
_SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above _SCRIPTS_DIR.
_SUBMIT_ROOT = os.path.dirname(os.path.dirname(_SCRIPTS_DIR))

# Put both directories at the front of sys.path (skipping any already listed).
# _SUBMIT_ROOT is inserted last, so when both are added it ends up ahead of
# _SCRIPTS_DIR in the search order.
# NOTE(review): presumably this is what lets the function-local imports in
# this file (lm_classifier, robobrain_runner, object_direction_remap)
# resolve -- confirm where those modules actually live.
if _SCRIPTS_DIR not in sys.path:
    sys.path.insert(0, _SCRIPTS_DIR)
if _SUBMIT_ROOT not in sys.path:
    sys.path.insert(0, _SUBMIT_ROOT)

# Paths under <_SUBMIT_ROOT>/model that are passed to the loader functions.
_LM_PATH = os.path.join(_SUBMIT_ROOT, "model", "LM")
_VL_B_PATH = os.path.join(_SUBMIT_ROOT, "model", "VL-B")
_VL_F_PATH = os.path.join(_SUBMIT_ROOT, "model", "VL-F")
# YAML file located next to this script; read with yaml.safe_load at load time.
_DIRECTION_MAP_PATH = os.path.join(_SCRIPTS_DIR, "object_direction_map.yaml")


def load_robospatialBrain_models(low_memory=False):
    """Load the LM classifier, the direction map and, optionally, both VL models.

    Args:
        low_memory: When true, neither VL model is loaded here. The returned
            dict instead carries the two model paths together with empty
            ``_current_vl`` / ``_current_vl_type`` slots so that a model can
            be loaded on demand later.

    Returns:
        A dict that always contains ``"lm_clf"``, ``"direction_map"`` and
        ``"low_memory"``. With ``low_memory=False`` it also contains the
        loaded ``"vl_b"`` and ``"vl_f"`` models; with ``low_memory=True`` it
        contains ``"_vl_b_path"``, ``"_vl_f_path"``, ``"_current_vl"`` (None)
        and ``"_current_vl_type"`` (None) instead.
    """
    import yaml
    import torch
    from lm_classifier import load_lm_classifier
    from robobrain_runner import load_robobrain

    lm_clf = load_lm_classifier(_LM_PATH)
    torch.cuda.empty_cache()

    # Pin the encoding: without it, open() decodes with the platform locale,
    # which can mis-read non-ASCII entries in the YAML on some systems.
    with open(_DIRECTION_MAP_PATH, encoding="utf-8") as f:
        direction_map = yaml.safe_load(f)

    if low_memory:
        # Defer the VL models; only their locations and empty slots are kept.
        return {
            "lm_clf": lm_clf,
            "direction_map": direction_map,
            "low_memory": True,
            "_vl_b_path": _VL_B_PATH,
            "_vl_f_path": _VL_F_PATH,
            "_current_vl": None,
            "_current_vl_type": None,
        }

    # Release cached CUDA memory after each load before starting the next one.
    vl_b = load_robobrain(_VL_B_PATH)
    torch.cuda.empty_cache()

    vl_f = load_robobrain(_VL_F_PATH)
    torch.cuda.empty_cache()

    return {
        "lm_clf": lm_clf,
        "vl_b": vl_b,
        "vl_f": vl_f,
        "direction_map": direction_map,
        "low_memory": False,
    }


def _swap_vl(kwargs, needed_type, load_robobrain):
    """Make sure the VL model of ``needed_type`` is the one held in ``kwargs``.

    Does nothing when ``kwargs["_current_vl_type"]`` already equals
    ``needed_type``. Otherwise any currently held model is released and the
    requested one is loaded into ``kwargs["_current_vl"]``.

    Args:
        kwargs: State dict with the keys ``"_current_vl"``,
            ``"_current_vl_type"``, ``"_vl_f_path"`` and ``"_vl_b_path"``.
            It is mutated in place.
        needed_type: ``"F"`` selects ``kwargs["_vl_f_path"]``; any other
            value selects ``kwargs["_vl_b_path"]``.
        load_robobrain: Callable taking a model path and returning the loaded
            model. The returned object must support ``.clear()``, which is
            called on it when it is later swapped out.

    Raises:
        Whatever ``load_robobrain`` raises. In that case the state is left
        consistently empty (both slots ``None``), so the next call retries
        the load instead of reporting a model that is not there.
    """
    import torch

    if kwargs["_current_vl_type"] == needed_type:
        return

    if kwargs["_current_vl"] is not None:
        kwargs["_current_vl"].clear()
        kwargs["_current_vl"] = None
        # Reset the tag together with the slot: if the load below fails, a
        # stale tag would make a later request for the old type look already
        # satisfied and hand callers a None model.
        kwargs["_current_vl_type"] = None
        torch.cuda.empty_cache()

    path = kwargs["_vl_f_path"] if needed_type == "F" else kwargs["_vl_b_path"]
    kwargs["_current_vl"] = load_robobrain(path)
    kwargs["_current_vl_type"] = needed_type


def _run_compatibility(question, image_path, kwargs, low_memory, run_robobrain, load_robobrain):
    """Answer a "compatibility" question with the F model, then the B model.

    The F model is queried first. Any answer other than the exact string
    ``"yes"`` is returned as-is; only on ``"yes"`` is the B model queried,
    and its answer is returned.

    Args:
        question: Passed through to ``run_robobrain``.
        image_path: Passed through to ``run_robobrain``.
        kwargs: State dict. Must hold ``"vl_f"`` / ``"vl_b"`` when
            ``low_memory`` is false, or the on-demand loading keys used by
            ``_swap_vl`` when it is true.
        low_memory: Selects on-demand model swapping over preloaded models.
        run_robobrain: Callable invoked as
            ``run_robobrain(question, image_path, None, model,
            LM_classify="compatibility")``.
        load_robobrain: Model loader, forwarded to ``_swap_vl`` in
            low-memory mode only.

    Returns:
        The F model's answer if it is not ``"yes"``, else the B model's.
    """

    def _model_for(vl_type):
        # Resolve the model lazily so the B model is only touched when needed.
        if low_memory:
            _swap_vl(kwargs, vl_type, load_robobrain)
            return kwargs["_current_vl"]
        return kwargs["vl_f"] if vl_type == "F" else kwargs["vl_b"]

    answer = run_robobrain(
        question, image_path, None, _model_for("F"), LM_classify="compatibility"
    )
    if answer == "yes":
        answer = run_robobrain(
            question, image_path, None, _model_for("B"), LM_classify="compatibility"
        )
    return answer


def inference_single(question, image_path, kwargs):
    """Classify one question and route it to the matching inference path.

    The question is first categorised with ``classify_single``. A
    ``"compatibility"`` category is delegated to ``_run_compatibility``.
    Every other category runs on the B model; a ``"context"`` question for
    which ``should_remap`` reports a remap goes through
    ``run_context_with_remap`` instead of the plain ``run_robobrain`` call.

    Args:
        question: The question to answer.
        image_path: Image reference forwarded to the runner functions.
        kwargs: State dict holding ``"lm_clf"``, ``"direction_map"``, an
            optional ``"low_memory"`` flag (defaults to False) and the
            model entries that go with that mode.

    Returns:
        Whatever the selected runner returns.
    """
    import torch
    from lm_classifier import classify_single
    from object_direction_remap import should_remap, run_context_with_remap
    from robobrain_runner import run_robobrain, load_robobrain

    classifier = kwargs["lm_clf"]
    remap_table = kwargs["direction_map"]
    memory_saving = kwargs.get("low_memory", False)

    question_kind = classify_single(question, classifier)
    torch.cuda.empty_cache()

    if question_kind == "compatibility":
        return _run_compatibility(
            question, image_path, kwargs, memory_saving, run_robobrain, load_robobrain
        )

    # All remaining categories are served by the B model.
    if not memory_saving:
        vl_model = kwargs["vl_b"]
    else:
        _swap_vl(kwargs, "B", load_robobrain)
        vl_model = kwargs["_current_vl"]

    if question_kind == "context":
        # Only the first element of the three-value result is used here.
        needs_remap, _, _ = should_remap(question)
        if needs_remap:
            return run_context_with_remap(
                question, image_path, None, vl_model, classifier, remap_table
            )

    return run_robobrain(question, image_path, None, vl_model, LM_classify=question_kind)