# Copied from the Hugging Face Hub file viewer for api.py:
# author MissingBreath · last commit "Update api.py" (65d6b85, verified) · 2.57 kB
import io
import logging
import os
import threading

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# # os.environ['HF_TOKEN']=''
# from huggingface_hub import login
# hf_token = os.getenv("HF_TOKEN")
# login(token=hf_token)
# Read token from environment
# hf_token = os.getenv("HF_TOKEN")
# print("HF_TOKEN:", hf_token)
# Load tokenizer directly with the token (no login)
# tokenizer = AutoTokenizer.from_pretrained(
# "chillies/distilbert-course-review-classification",
# token=hf_token # Pass it directly
# )
# tokenizer = AutoTokenizer.from_pretrained("chillies/distilbert-course-review-classification")
# from transformers import DistilBertTokenizer
# tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# model = AutoModelForSequenceClassification.from_pretrained("chillies/distilbert-course-review-classification")
# from transformers import DistilBertTokenizerFast
# tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
# from transformers import pipeline
# model = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Directories holding the fine-tuned classifier and its tokenizer. Each can be
# overridden with the MODEL_DIR / TOKENIZER_DIR environment variable; the
# defaults are relative paths, resolved against the process working directory.
MODEL_DIR = os.environ.get("MODEL_DIR", "./my_model")
TOKENIZER_DIR = os.environ.get("TOKENIZER_DIR", "./my_tokenizer")

# Load the model and tokenizer once at import time so every request reuses them.
try:
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    # The service stays importable (as before), but both names are now always
    # bound. Code using them can then detect "not loaded" (`is None`) instead
    # of crashing later with a NameError.
    model = None
    tokenizer = None
    print(f"Error loading model or tokenizer: {e}")
# Output labels indexed by the model's predicted class id. The order must match
# the label order the model was fine-tuned with (not visible in this file).
CLASS_LABELS = (
    "Improvement Suggestions", "Questions", "Confusion", "Support Request",
    "Discussion", "Course Comparison", "Related Course Suggestions",
    "Negative", "Positive",
)


def inference(review):
    """Classify a single course review and return its label.

    Args:
        review: The review text to classify.

    Returns:
        The entry of ``CLASS_LABELS`` at the model's highest-scoring class id.

    Raises:
        RuntimeError: If the module-level ``model`` or ``tokenizer`` is None.
        IndexError: If the model predicts a class id with no entry in
            ``CLASS_LABELS``.
    """
    # torch is already a hard requirement: the tokenizer is asked for "pt" tensors.
    import torch

    if model is None or tokenizer is None:
        raise RuntimeError("Model or tokenizer is not loaded; cannot classify.")

    inputs = tokenizer(review, return_tensors="pt", padding=True, truncation=True)
    # Inference only: disable autograd so no gradient graph is built per request.
    with torch.no_grad():
        outputs = model(**inputs)
    # Assuming the model outputs logits for a single input row.
    predicted_class = outputs.logits.argmax(dim=-1).item()
    if predicted_class >= len(CLASS_LABELS):
        raise IndexError(
            f"Predicted class id {predicted_class} has no label; "
            f"only {len(CLASS_LABELS)} labels are defined."
        )
    return CLASS_LABELS[predicted_class]
from pydantic import BaseModel
from typing import List
class ReviewRequest(BaseModel):
    """JSON request body for ``POST /classify``: the review texts to label."""

    # One string per review; the endpoint returns one prediction per entry, in order.
    reviews: List[str]
app = FastAPI()

logger = logging.getLogger(__name__)

# Serializes use of the shared model/tokenizer across worker threads. Hugging
# Face fast tokenizers may fail when called from several threads at once, and
# the original handler never ran inference concurrently either.
_inference_lock = threading.Lock()


def _classify_all(reviews):
    """Label each review in order while holding the inference lock."""
    with _inference_lock:
        return [inference(review) for review in reviews]


@app.post("/classify")
async def classify(request: ReviewRequest):
    """Classify every review in the request body.

    Args:
        request: Parsed body carrying a ``reviews`` list of strings.

    Returns:
        ``{"predictions": [...]}`` with one label per review, in input order.
    """
    reviews = request.reviews
    # Log only the size; the previous debug print dumped raw user text.
    logger.debug("Classifying %d review(s)", len(reviews))
    # Model inference is blocking work. Running it in the threadpool keeps this
    # coroutine from stalling the event loop for all other requests.
    predictions = await run_in_threadpool(_classify_all, reviews)
    return {"predictions": predictions}