Kronos-small / handler.py
Imp3rtinence
fix: remove .dt accessor from DatetimeIndex
63d24c3
Raw
History Blame Contribute Delete
1.61 kB
import torch
import numpy as np
import pandas as pd
import os
from model.kronos import Kronos, KronosTokenizer, KronosPredictor
class EndpointHandler:
    """Callable request handler that forecasts candles with a Kronos predictor.

    The handler is constructed once with the model directory and then called
    with a request dict of the form::

        {"inputs": [...], "parameters": {"prediction_length": 8}}

    ``inputs`` is either a flat list of numbers (a single price series, which
    is copied into the open/high/low/close columns) or a list of rows shaped
    ``[open, high, low, close]`` / ``[open, high, low, close, volume]``.

    Invalid requests are answered with ``{"error": "<message>"}`` instead of
    raising, matching the existing "No input data" response.
    """

    # Column order assumed for row-style inputs.
    _COLUMNS = ("open", "high", "low", "close", "volume")
    # Columns filled from a flat series and read back from the prediction.
    _PRICE_COLUMNS = ("open", "high", "low", "close")
    # Spacing of the synthetic timestamps attached to the input/output rows.
    _FREQ = "15min"
    _DEFAULT_PREDICTION_LENGTH = 8

    def __init__(self, path=""):
        """Load tokenizer and model from *path* and build the predictor.

        Args:
            path: Model directory. The tokenizer is loaded from its
                ``tokenizer`` sub-directory, the model from *path* itself.
        """
        tokenizer_path = os.path.join(path, "tokenizer")
        tokenizer = KronosTokenizer.from_pretrained(tokenizer_path)
        model = Kronos.from_pretrained(path)
        self.predictor = KronosPredictor(
            model, tokenizer, device="cpu", max_context=512
        )

    def __call__(self, data):
        """Run a forecast for one request.

        Args:
            data: Mapping with an ``"inputs"`` entry and an optional
                ``"parameters"`` mapping; ``parameters["prediction_length"]``
                (default 8) is the number of future rows to produce.

        Returns:
            ``{"predictions": [[open, high, low, close], ...]}`` on success,
            or ``{"error": "<message>"}`` when the request is unusable.
        """
        # A JSON ``null`` arrives as None, so do not rely on .get() defaults.
        inputs = data.get("inputs")
        if inputs is None or len(inputs) == 0:
            return {"error": "No input data"}

        parameters = data.get("parameters") or {}
        raw_length = parameters.get(
            "prediction_length", self._DEFAULT_PREDICTION_LENGTH
        )
        try:
            prediction_length = int(raw_length)
        except (TypeError, ValueError, OverflowError):
            prediction_length = 0
        if prediction_length <= 0:
            return {"error": "prediction_length must be a positive integer"}

        first = inputs[0]
        if isinstance(first, (list, tuple)):
            width = len(first)
            # More than five values cannot be labelled with _COLUMNS.
            # NOTE(review): fewer than four leaves price columns missing; the
            # predictor is assumed to need all four -- confirm in model.kronos.
            if width not in (4, 5):
                return {
                    "error": "Each row must be [open, high, low, close] "
                             "or [open, high, low, close, volume]"
                }
            if any(
                not isinstance(row, (list, tuple)) or len(row) != width
                for row in inputs
            ):
                return {"error": "All rows must have the same number of values"}
            df = pd.DataFrame(inputs, columns=list(self._COLUMNS[:width]))
        else:
            # Flat series: reuse the same values for every price column.
            df = pd.DataFrame({col: inputs for col in self._PRICE_COLUMNS})

        if "volume" not in df.columns:
            df["volume"] = 0.0

        # The request carries no timestamps, so synthesize an evenly spaced
        # history ending now and a horizon starting one step later.
        now = pd.Timestamp.now()
        step = pd.Timedelta(self._FREQ)
        x_timestamps = pd.date_range(end=now, periods=len(df), freq=self._FREQ)
        y_timestamps = pd.date_range(
            start=now + step, periods=prediction_length, freq=self._FREQ
        )

        pred_df = self.predictor.predict(
            df, x_timestamps, y_timestamps,
            pred_len=prediction_length,
            T=1.0, top_k=0, top_p=0.9,
            sample_count=5, verbose=False,
        )
        result = pred_df[list(self._PRICE_COLUMNS)].values.tolist()
        return {"predictions": result}