-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathinference_mae.py
More file actions
52 lines (40 loc) · 1.88 KB
/
inference_mae.py
File metadata and controls
52 lines (40 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# gen2seg official inference pipeline code for the MAE model
#
# Please see our project website at https://reachomk.github.io/gen2seg
#
# Additionally, if you use our code, please cite our paper along with the works it builds on.
import os
import time
import torch
from gen2seg_mae_pipeline import gen2segMAEInstancePipeline # Custom pipeline for MAE
from transformers import AutoImageProcessor
from PIL import Image
import numpy as np
# Example usage: update these defaults as needed (or pass arguments to run_inference).
IMAGE_PATH = "/nfs_share3/om/diffusion-e2e-ft/plane1.jpg"  # Path to the input image, change as needed.
OUTPUT_PATH = "seg_mae.png"  # Path to save the output image.
# Use the first GPU when present; fall back to the CPU instead of crashing on CPU-only hosts.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# (width, height) the input image is resized to before being passed to the pipeline.
INPUT_SIZE = (224, 224)


def prediction_to_uint8(prediction):
    """Convert a raw pipeline prediction into a uint8 array for ``Image.fromarray``.

    Singleton dimensions are squeezed away, and values are clipped to [0, 255]
    before the cast. Casting out-of-range floats straight to uint8 is undefined
    in NumPy and typically wraps around, which produces spurious colors.

    Args:
        prediction: Array-like prediction (anything ``np.asarray`` accepts).

    Returns:
        A squeezed ``np.uint8`` array. In-range values are truncated as by
        ``astype``.
    """
    # NOTE(review): assumes the pipeline emits values on a 0-255 scale — TODO confirm.
    arr = np.asarray(prediction).squeeze()
    return np.clip(arr, 0, 255).astype(np.uint8)


def run_inference(image_path=IMAGE_PATH, output_path=OUTPUT_PATH, device=DEVICE):
    """Segment a single image with the gen2seg MAE pipeline and save the result.

    Args:
        image_path: Path of the input image.
        output_path: Destination path for the segmentation image.
        device: Torch device string the pipeline is moved to.

    Returns:
        The saved segmentation as a PIL image, resized back to the input's
        original size.
    """
    print(f"Loading MAE pipeline on {device} for single image inference...")
    # Load the image processor (using a pretrained processor from facebook/vit-mae-huge).
    image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")
    # Instantiate the pipeline and move it to the desired device.
    pipe = gen2segMAEInstancePipeline(
        model="reachomk/gen2seg-mae-h", image_processor=image_processor
    ).to(device)

    # Load the image, storing the original size, then resize for inference.
    # convert() returns a new in-memory image, so the file can be closed right away.
    with Image.open(image_path) as src:
        orig_image = src.convert("RGB")
    orig_size = orig_image.size  # (width, height)
    image = orig_image.resize(INPUT_SIZE)

    # Run inference; perf_counter is the monotonic clock intended for interval timing.
    start_time = time.perf_counter()
    with torch.no_grad():
        pipe_output = pipe([image])
    elapsed = time.perf_counter() - start_time
    print(f"Inference completed in {elapsed:.2f} seconds.")

    # Convert the prediction to an image.
    seg_img = Image.fromarray(prediction_to_uint8(pipe_output.prediction[0]))
    # Resize the segmentation output back to the original image size.
    seg_img = seg_img.resize(orig_size, Image.LANCZOS)
    # Save the output image.
    seg_img.save(output_path)
    print(f"Saved output image to {output_path}")
    return seg_img


if __name__ == "__main__":
    run_inference()