# upsample / handler.py
# Last change: commit 0edcf54 "Update handler.py" by jayyap (file size 1.4 kB).
from typing import Dict, List, Any
from diffusers import DiffusionPipeline
import torch
from io import BytesIO
import requests
from PIL import Image
import base64
# Compute device. This handler is GPU-only: importing the module raises
# ValueError when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("need to run on GPU")

# Mixed-precision dtype for the pipeline weights. bfloat16 requires compute
# capability 8.0 or newer, so test the major version with `>=` (not `==`):
# capability 9.x and later GPUs support bfloat16 as well. Older GPUs fall
# back to float16.
dtype = (
    torch.bfloat16
    if torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)
class EndpointHandler:
    """Custom inference handler that 4x-upscales a base64-encoded image.

    Wraps the ``CompVis/ldm-super-resolution-4x-openimages`` diffusers
    pipeline, loaded onto the module-level ``device`` with the module-level
    mixed-precision ``dtype``.
    """

    def __init__(self, path=""):
        """Load the super-resolution pipeline onto the GPU.

        Args:
            path: Not used; the weights are always fetched by Hub model id.
        """
        # NOTE(review): `path` is presumably the local repository directory
        # supplied by the serving toolkit — confirm whether loading from it
        # would be preferable to the hard-coded Hub id.
        self.pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/ldm-super-resolution-4x-openimages", torch_dtype=dtype
        ).to(device)
        # NOTE: as an alternative to `.to(device)`,
        # `self.pipeline.enable_model_cpu_offload()` loads the individual
        # model components onto the GPU on demand.

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Upscale the image carried in the request payload.

        Args:
            data: Request payload. ``data["image"]`` must hold the
                base64-encoded contents of an image file; the key is popped
                from ``data``.

        Returns:
            The first image produced by the super-resolution pipeline.

        Raises:
            ValueError: If ``data`` has no ``"image"`` entry or it is empty.
        """
        image_b64 = data.pop("image", None)
        if not image_b64:
            # Fail with a clear message instead of a TypeError from
            # base64.b64decode(None) deep inside the helper.
            raise ValueError(
                "request payload must contain a base64-encoded 'image'"
            )

        image = self.decode_base64_image(image_b64)
        # The super-resolution pipeline works on 3-channel RGB input; convert
        # palette, grayscale and RGBA uploads so they don't break its
        # preprocessing or feed the model the wrong number of channels.
        low_res_img = image.convert("RGB")

        with torch.no_grad():
            upscaled_image = self.pipeline(
                low_res_img, num_inference_steps=100, eta=1
            ).images[0]
        return upscaled_image

    def decode_base64_image(self, image_string):
        """Decode base64-encoded image-file bytes into a PIL image.

        Args:
            image_string: Base64 text (str or bytes) of an image file.

        Returns:
            The opened PIL image, in whatever mode the file stores.

        Raises:
            binascii.Error: If the input is not valid base64 (a ValueError
                subclass).
            PIL.UnidentifiedImageError: If the bytes are not a recognized
                image format.
        """
        raw_bytes = base64.b64decode(image_string)
        # The BytesIO stays referenced by the returned image, so PIL's lazy
        # loading can still read from it later.
        return Image.open(BytesIO(raw_bytes))