Description
Hi, I am trying to return an image using model prediction and it's throwing me an unknown error, In local I am able to get the output out of the model but in replicate it throws the above error.
##Base Code
from cog import BasePredictor, Input, Path,File
import os
import cv2
from PIL import Image
import pandas as pd
import os
import torch
from src.dataset import DatasetInference
from src.model import ModelAlpha, Model
from models.alpha_model_config import cfg as alpha_model_cfg
from models.trimap_model_config import cfg as trimap_model_cfg
from tqdm import tqdm
import numpy as np
import pandas as pd
import cv2
import glob
import torch
from torch.utils.data import DataLoader
import albumentations as alb
from os import path
import tempfile
os.environ['TRANSFORMERS_CACHE'] = './cache/'
os.environ['TORCH_HOME'] = './cache/'
class Predictor(BasePredictor):
def setup(self):
self.device = torch.device("cpu")
self.alpha_model_general , self.trimap_model_general = self.get_models("models/alpha_model_general.ckpt" , "models/trimap_model_general.ckpt" , self.device)
def remove_data(self):
normal_images = os.listdir("data/output/val/images/")
mask_path = "output_mask/"
mask_images = os.listdir(mask_path)
for mask , normal in zip(mask_images , normal_images):
os.remove(mask_path + mask)
os.remove("data/output/val/images/" + normal)
print("data cleaning complete!")
def create_csv(self,image_name , tracking_id):
test_dataset = {
"image_path" : [image_name],
"id" : [tracking_id]
}
dataframe = pd.DataFrame(test_dataset)
csv_file = "data/test_dataset_" + tracking_id + ".csv"
dataframe.to_csv(csv_file)
# print("CSV file created!")
return csv_file
def get_models(self,alpha_path , trimap_path , device = torch.device("cpu")):
# print('Load TriMap Model...')
trimap_model = Model(cfg=trimap_model_cfg)
trimap_model = trimap_model.load_from_checkpoint(
trimap_path, cfg=trimap_model_cfg)
trimap_model.freeze()
trimap_model = trimap_model.to(device)
# print('OK!\n')
# print('Load Alpha Model...')
alpha_model = ModelAlpha(cfg=alpha_model_cfg)
alpha_model = alpha_model.load_from_checkpoint(
alpha_path, cfg=alpha_model_cfg)
alpha_model.freeze()
alpha_model = alpha_model.to(device)
# print('OK!\n')
return alpha_model , trimap_model
def blend(self, image_path, alpha_path):
foreground = cv2.imread(image_path)
alpha = cv2.imread(alpha_path)
background = np.ones(foreground.shape)
foreground = foreground.astype(float)
background = background.astype(float)
alpha = alpha.astype(float)/255
# print(foreground.shape , alpha.shape)
foreground = cv2.multiply(alpha, foreground)
background = cv2.multiply(1.0 - alpha, background)
blend = cv2.add(foreground, background)
return blend
def prepare_dataset(self,**kwargs):
src_df = pd.read_csv(kwargs['image_csv_path'])
# print(f'Alpha masks will be built for {len(src_df)} image(s).')
# get path to save results
dst_dir = kwargs['mask_path']
if not os.path.exists(dst_dir):
os.makedirs(dst_dir, exist_ok=True)
# print(f'Alpha masks will be saved in {dst_dir}\n')
output_dir = kwargs["output_path"]
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# print('Load the models...')
# print('Create dataset...')
t = alb.Compose([
alb.Resize(512, 512),
alb.Normalize(mean=trimap_model_cfg.MEAN, std=trimap_model_cfg.STD)])
test_dataset = DatasetInference(df=src_df, transform=t)
test_dataloader = DataLoader(
dataset=test_dataset,
batch_size=kwargs.get('batch_size', 1),
shuffle=False,
num_workers=kwargs.get('workers', 2))
# print('OK!\n')
# iterate over data and build alphas
# print('Get alpha masks...')
return test_dataloader , dst_dir , output_dir
def predict_output(self,dataset , test_dataloader , dst_dir, output_dir , device = torch.device("cpu")):
public_urls = []
mask_paths = []
k = 0
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
# print(batch)
img, data = batch
img = img.to(device)
# print(img)
# print(img.shape)
trimap = self.trimap_model_general(img)
trimap = self.trimap_model_general.activation(trimap)
trimap = trimap.argmax(1, keepdim=True)
trimap_processed = trimap.clone().detach()
trimap_processed[trimap_processed == 1] = 255
trimap_processed[trimap_processed == 2] = 128
trimap_processed = trimap_processed.to(torch.float32)
trimap_processed = trimap_processed / 255
img_trimap = torch.cat([img, trimap_processed], dim=1)
trimap_output = trimap_processed.detach().cpu().numpy()*255
trimap_output = np.reshape(trimap_output , (trimap_output.shape[0] , trimap_output.shape[2] , trimap_output.shape[3] , trimap_output.shape[1]))
alpha = self.alpha_model_general(img_trimap)
alpha = self.alpha_model_general.activation(alpha['refine_output'])
trimap_final = alpha
k+=1
# print("reached for loop")
for i in range(len(img)):
# print(dataset['id'][i] , f'_{k}_{i}.png')
mask_path = os.path.join(
dst_dir,
dataset['id'][i] + f'_{k}_{i}.png')
m = trimap_final[i].cpu().detach().numpy()[0, ...]
m = np.uint8(m * 255)
m = cv2.resize(m, (int(data['size'][0][i]), int(data['size'][1][i] )))
# m = np.expand_dims(m, axis=2)
numpy_image = img[i].cpu().detach().numpy()
numpy_image = np.reshape(numpy_image , (numpy_image.shape[1] ,numpy_image.shape[2] ,numpy_image.shape[0]))
resized_image = cv2.resize(numpy_image , (int(data['size'][0][i]), int(data['size'][1][i])))
# print(resized_image.shape , m.shape)
# blended_image = blend(resized_image , m)
cv2.imwrite(f"{mask_path}" , m)
# cv2.imwrite(mask_path, m)
mask_paths.append(mask_path)
image_path = sorted(glob.glob('data/output/val/images/*'))
alpha_path = sorted(glob.glob(dst_dir + '/*'))
for i in tqdm(image_path):
# try:
ID = dataset[dataset["image_path"]==i]["id"].to_string().split(" ")[-1]
# print(dataset[dataset["image_path"]==i]["id"].to_string().split(" ")[-1])
for j in tqdm(alpha_path):
if ID in j:
blended_image = self.blend(i, j)
output_path = Path(tempfile.mkdtemp()) / "output.png"
cv2.imwrite(str(output_path) , blended_image)
# os.remove(output_location)
# except Exception as e:
# print(e)
self.remove_data()
return output_path
# The arguments and types the model takes as input
def predict(self,
image: Path = Input(description="Image to run inference on"),
tracking_id: str = Input(description="Insert Your tracking id",default="test")
) -> Path:
"""Run a single prediction on the model"""
filepath = str(image)
tracking_id = tracking_id
device = "cpu"
raw_image = Image.open(filepath).convert("RGB")
input_image_dir = "data/output/val/images"
if path.exists(input_image_dir):
pass
else:
os.makedirs(input_image_dir)
image_download_path = input_image_dir+"/"+"test.png"
# print(image_download_path)
raw_image.save(image_download_path)
csv_path = self.create_csv(image_download_path , tracking_id)
# print(csv_path)
if path.exists("blended_images/") and path.exists("output_mask/"):
pass
else:
os.makedirs("blended_images/")
os.makedirs("output_mask/")
req = {}
req["image_csv_path"] = csv_path
req["mask_path"] = "output_mask/"
req["output_path"] = "output/"
test_dataloader , dst_dir , output_dir = self.prepare_dataset(**req)
dataset = pd.read_csv(req["image_csv_path"])
# print(test_dataloader , dst_dir , output_dir , dataset.head())
output_location = self.predict_output(
dataset,
test_dataloader,
dst_dir,
output_dir,
device)
# print(output_location,type(output_location))
return output_location