Skip to content

Output Unknown error handling prediction. #1269

Open
@swapnil-lader

Description

@swapnil-lader

Hi, I am trying to return an image using model prediction and it's throwing me an unknown error, In local I am able to get the output out of the model but in replicate it throws the above error.

##Base Code

from cog import BasePredictor, Input, Path,File
import os
import cv2
from PIL import Image
import pandas as pd
import os
import torch
from src.dataset import DatasetInference
from src.model import ModelAlpha, Model
from models.alpha_model_config import cfg as alpha_model_cfg
from models.trimap_model_config import cfg as trimap_model_cfg
from tqdm import tqdm
import numpy as np
import pandas as pd
import cv2
import glob
import torch
from torch.utils.data import DataLoader
import albumentations as alb
from os import path
import tempfile
os.environ['TRANSFORMERS_CACHE'] = './cache/'
os.environ['TORCH_HOME'] = './cache/'

class Predictor(BasePredictor):

def setup(self):
    self.device = torch.device("cpu")
    self.alpha_model_general , self.trimap_model_general = self.get_models("models/alpha_model_general.ckpt" , "models/trimap_model_general.ckpt" , self.device)

def remove_data(self):
    normal_images = os.listdir("data/output/val/images/")
    mask_path = "output_mask/"
    mask_images = os.listdir(mask_path)

    for mask , normal in zip(mask_images , normal_images):
        os.remove(mask_path + mask)
        os.remove("data/output/val/images/" + normal)

    print("data cleaning complete!")
    
def create_csv(self,image_name , tracking_id):    
    test_dataset = {
        "image_path" : [image_name],
        "id" : [tracking_id]
    }
    dataframe = pd.DataFrame(test_dataset)
    csv_file = "data/test_dataset_" + tracking_id + ".csv"
    dataframe.to_csv(csv_file)
    # print("CSV file created!")
    return csv_file

def get_models(self,alpha_path , trimap_path , device = torch.device("cpu")):
    # print('Load TriMap Model...')
    trimap_model = Model(cfg=trimap_model_cfg)
    trimap_model = trimap_model.load_from_checkpoint(
        trimap_path, cfg=trimap_model_cfg)
    trimap_model.freeze()
    trimap_model = trimap_model.to(device)
    # print('OK!\n')

    # print('Load Alpha Model...')
    alpha_model = ModelAlpha(cfg=alpha_model_cfg)
    alpha_model = alpha_model.load_from_checkpoint(
        alpha_path, cfg=alpha_model_cfg)
    alpha_model.freeze()
    alpha_model = alpha_model.to(device)
    # print('OK!\n')

    return alpha_model , trimap_model

def blend(self, image_path, alpha_path):   
    foreground = cv2.imread(image_path)
    alpha = cv2.imread(alpha_path)
    background = np.ones(foreground.shape)

    foreground = foreground.astype(float)
    background = background.astype(float)

    alpha = alpha.astype(float)/255
    
    # print(foreground.shape , alpha.shape)

    foreground = cv2.multiply(alpha, foreground)
    background = cv2.multiply(1.0 - alpha, background)

    blend = cv2.add(foreground, background)

    return blend

def prepare_dataset(self,**kwargs):
    src_df = pd.read_csv(kwargs['image_csv_path'])
    # print(f'Alpha masks will be built for {len(src_df)} image(s).')

    #  get path to save results
    dst_dir = kwargs['mask_path']
    if not os.path.exists(dst_dir):
        os.makedirs(dst_dir, exist_ok=True)

    # print(f'Alpha masks will be saved in {dst_dir}\n')

    output_dir = kwargs["output_path"]
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # print('Load the models...')

    # print('Create dataset...')
    t = alb.Compose([
        alb.Resize(512, 512),
        alb.Normalize(mean=trimap_model_cfg.MEAN, std=trimap_model_cfg.STD)])

    test_dataset = DatasetInference(df=src_df, transform=t)

    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=kwargs.get('batch_size', 1),
        shuffle=False,
        num_workers=kwargs.get('workers', 2))
    # print('OK!\n')

    #  iterate over data and build alphas
    # print('Get alpha masks...')

    return test_dataloader , dst_dir , output_dir

def predict_output(self,dataset , test_dataloader  , dst_dir, output_dir , device = torch.device("cpu")):
    public_urls = []
    mask_paths = []
    k = 0

    for batch in tqdm(test_dataloader, total=len(test_dataloader)):
        # print(batch)
        img, data = batch
        img = img.to(device)
        # print(img)
        # print(img.shape)
        trimap = self.trimap_model_general(img)
        trimap = self.trimap_model_general.activation(trimap)
        trimap = trimap.argmax(1, keepdim=True)
        trimap_processed = trimap.clone().detach()
        trimap_processed[trimap_processed == 1] = 255
        trimap_processed[trimap_processed == 2] = 128
        trimap_processed = trimap_processed.to(torch.float32)
        trimap_processed = trimap_processed / 255

        img_trimap = torch.cat([img, trimap_processed], dim=1)
        trimap_output = trimap_processed.detach().cpu().numpy()*255
        trimap_output = np.reshape(trimap_output , (trimap_output.shape[0] , trimap_output.shape[2] , trimap_output.shape[3] , trimap_output.shape[1]))

        alpha = self.alpha_model_general(img_trimap)
        alpha = self.alpha_model_general.activation(alpha['refine_output'])

        trimap_final = alpha
        k+=1
        # print("reached for loop")
        for i in range(len(img)):
            # print(dataset['id'][i] , f'_{k}_{i}.png')
            mask_path = os.path.join(
                dst_dir,
                dataset['id'][i] + f'_{k}_{i}.png')

            m = trimap_final[i].cpu().detach().numpy()[0, ...]
            m = np.uint8(m * 255)

            m = cv2.resize(m, (int(data['size'][0][i]), int(data['size'][1][i] )))
            # m = np.expand_dims(m, axis=2)

            numpy_image = img[i].cpu().detach().numpy()
            
            numpy_image = np.reshape(numpy_image , (numpy_image.shape[1] ,numpy_image.shape[2] ,numpy_image.shape[0]))

            resized_image = cv2.resize(numpy_image , (int(data['size'][0][i]), int(data['size'][1][i])))
            # print(resized_image.shape , m.shape)

            # blended_image = blend(resized_image , m)

            cv2.imwrite(f"{mask_path}" , m)
            # cv2.imwrite(mask_path, m)

            mask_paths.append(mask_path)

        image_path = sorted(glob.glob('data/output/val/images/*'))
        alpha_path = sorted(glob.glob(dst_dir + '/*'))


        for i in tqdm(image_path):
            # try:
            ID = dataset[dataset["image_path"]==i]["id"].to_string().split(" ")[-1]
            # print(dataset[dataset["image_path"]==i]["id"].to_string().split(" ")[-1])

            for j in tqdm(alpha_path):
                if ID in j: 
                    blended_image =  self.blend(i, j)
                    output_path = Path(tempfile.mkdtemp()) / "output.png"
                    cv2.imwrite(str(output_path) , blended_image)

                    # os.remove(output_location)

            # except Exception as e:
            #     print(e)

        self.remove_data()

        return output_path

# The arguments and types the model takes as input
def predict(self,
        image: Path = Input(description="Image to run inference on"),
        tracking_id: str = Input(description="Insert Your tracking id",default="test")
) -> Path:
    """Run a single prediction on the model"""
    filepath = str(image)
    tracking_id = tracking_id
    device = "cpu"
    raw_image = Image.open(filepath).convert("RGB")
    input_image_dir = "data/output/val/images"
    if path.exists(input_image_dir):
        pass
    else:
        os.makedirs(input_image_dir)
    image_download_path = input_image_dir+"/"+"test.png"
    # print(image_download_path)
    raw_image.save(image_download_path)
    
    csv_path = self.create_csv(image_download_path , tracking_id)
    # print(csv_path)
    if path.exists("blended_images/") and path.exists("output_mask/"):
        pass
    else:
        os.makedirs("blended_images/")
        os.makedirs("output_mask/")
    req = {}    
    req["image_csv_path"] = csv_path
    req["mask_path"] = "output_mask/"
    req["output_path"] = "output/"
    test_dataloader , dst_dir , output_dir = self.prepare_dataset(**req)   
    dataset = pd.read_csv(req["image_csv_path"])
    
    # print(test_dataloader , dst_dir , output_dir , dataset.head())
    
    output_location = self.predict_output(
        dataset,
        test_dataloader,
        dst_dir,
        output_dir,
        device)
    # print(output_location,type(output_location))
    return output_location

Metadata

Metadata

Assignees

No one assigned

    Labels

    BackendIssues with the replicate backend

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions