RoboSpatialBrain / scripts /inference.py
lbx511's picture
Add files using upload-large-folder tool
065cb81 verified
import os
import sys
# Directory holding this script, and the directory two levels above it.
_SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))
_SUBMIT_ROOT = os.path.dirname(os.path.dirname(_SCRIPTS_DIR))

# Put both directories at the front of sys.path (scripts first, then the
# root, so the root ends up ahead) unless they are already listed.
for _import_dir in (_SCRIPTS_DIR, _SUBMIT_ROOT):
    if _import_dir not in sys.path:
        sys.path.insert(0, _import_dir)
del _import_dir  # keep the module namespace identical to the unrolled form

# Model directories live under <root>/model; the direction map sits next to
# this script.
_LM_PATH = os.path.join(_SUBMIT_ROOT, "model", "LM")
_VL_B_PATH = os.path.join(_SUBMIT_ROOT, "model", "VL-B")
_VL_F_PATH = os.path.join(_SUBMIT_ROOT, "model", "VL-F")
_DIRECTION_MAP_PATH = os.path.join(_SCRIPTS_DIR, "object_direction_map.yaml")
def load_robospatialBrain_models(low_memory=False):
    """Load the question classifier, the direction map and the VL models.

    Args:
        low_memory: When true, no VL model is loaded here. The returned dict
            instead carries the two VL model paths plus an empty slot
            (``_current_vl`` / ``_current_vl_type``) so that one model at a
            time can be loaded on demand later.

    Returns:
        A dict that always contains ``lm_clf``, ``direction_map`` and
        ``low_memory``. With ``low_memory=False`` it also contains the
        loaded ``vl_b`` and ``vl_f`` models; with ``low_memory=True`` it
        contains ``_vl_b_path``, ``_vl_f_path``, ``_current_vl`` and
        ``_current_vl_type`` instead.
    """
    import yaml
    import torch
    from lm_classifier import load_lm_classifier
    from robobrain_runner import load_robobrain

    lm_clf = load_lm_classifier(_LM_PATH)
    torch.cuda.empty_cache()

    # Read the map as UTF-8 explicitly: the platform default encoding is
    # locale-dependent and can fail on (or mangle) non-ASCII entries.
    with open(_DIRECTION_MAP_PATH, encoding="utf-8") as f:
        direction_map = yaml.safe_load(f)

    models = {
        "lm_clf": lm_clf,
        "direction_map": direction_map,
        "low_memory": bool(low_memory),
    }

    if low_memory:
        # Defer VL loading; only record where the models live.
        models.update(
            {
                "_vl_b_path": _VL_B_PATH,
                "_vl_f_path": _VL_F_PATH,
                "_current_vl": None,
                "_current_vl_type": None,
            }
        )
        return models

    # Load both VL models up front, releasing cached allocator blocks after
    # each load.
    models["vl_b"] = load_robobrain(_VL_B_PATH)
    torch.cuda.empty_cache()
    models["vl_f"] = load_robobrain(_VL_F_PATH)
    torch.cuda.empty_cache()
    return models
def _swap_vl(kwargs, needed_type, load_robobrain):
import torch
if kwargs["_current_vl_type"] != needed_type:
if kwargs["_current_vl"] is not None:
kwargs["_current_vl"].clear()
kwargs["_current_vl"] = None
torch.cuda.empty_cache()
path = kwargs["_vl_f_path"] if needed_type == "F" else kwargs["_vl_b_path"]
kwargs["_current_vl"] = load_robobrain(path)
kwargs["_current_vl_type"] = needed_type
def _run_compatibility(question, image_path, kwargs, low_memory, run_robobrain, load_robobrain):
if low_memory:
_swap_vl(kwargs, "F", load_robobrain)
result_f = run_robobrain(question, image_path, None, kwargs["_current_vl"], LM_classify="compatibility")
if result_f != "yes":
return result_f
_swap_vl(kwargs, "B", load_robobrain)
return run_robobrain(question, image_path, None, kwargs["_current_vl"], LM_classify="compatibility")
else:
result_f = run_robobrain(question, image_path, None, kwargs["vl_f"], LM_classify="compatibility")
if result_f != "yes":
return result_f
return run_robobrain(question, image_path, None, kwargs["vl_b"], LM_classify="compatibility")
def inference_single(question, image_path, kwargs):
    """Answer one question about one image.

    The question is first classified with the LM classifier, and the
    resulting category decides the route:

    * ``"compatibility"`` goes through ``_run_compatibility``.
    * ``"context"`` goes through ``run_context_with_remap`` when
      ``should_remap`` says so.
    * everything else (including "context" without remap) is answered by
      ``run_robobrain`` on the B model.

    Args:
        question: Question text.
        image_path: Path of the image the question refers to.
        kwargs: State dict as returned by ``load_robospatialBrain_models``.

    Returns:
        Whatever the selected runner returns for this question.
    """
    import torch
    from lm_classifier import classify_single
    from object_direction_remap import should_remap, run_context_with_remap
    from robobrain_runner import run_robobrain, load_robobrain

    classifier = kwargs["lm_clf"]
    remap_table = kwargs["direction_map"]
    single_slot = kwargs.get("low_memory", False)

    label = classify_single(question, classifier)
    torch.cuda.empty_cache()

    if label == "compatibility":
        return _run_compatibility(
            question, image_path, kwargs, single_slot, run_robobrain, load_robobrain
        )

    # Every remaining category is served by the B model.
    if single_slot:
        _swap_vl(kwargs, "B", load_robobrain)
        backbone = kwargs["_current_vl"]
    else:
        backbone = kwargs["vl_b"]

    if label == "context":
        # Only the first element of the triple is needed here.
        wants_remap, _, _ = should_remap(question)
        if wants_remap:
            return run_context_with_remap(
                question, image_path, None, backbone, classifier, remap_table
            )

    return run_robobrain(question, image_path, None, backbone, LM_classify=label)