-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
29 lines (23 loc) · 897 Bytes
/
test.py
File metadata and controls
29 lines (23 loc) · 897 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import matplotlib.pyplot as plt
import numpy as np
from model import model, denoise_process
def visualize_image(tensor_image):
    """Display a single image tensor with matplotlib.

    Args:
        tensor_image: A 3-D channels-first tensor ``(C, H, W)`` with C in
            {1, 3, 4}, or a 2-D ``(H, W)`` tensor. Values are clipped to
            [0, 1] before display. The tensor is presumably already scaled
            to that range (TODO confirm against the model's output scaling).
            Single-channel and 2-D inputs are shown in grayscale.
    """
    # Detach first so the host copy is never tracked by autograd.
    image = tensor_image.detach().cpu().numpy()
    if image.ndim == 3:
        # matplotlib expects channels-last (H, W, C).
        image = np.transpose(image, (1, 2, 0))
        if image.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the singleton channel axis.
            image = image[..., 0]
    image = np.clip(image, 0, 1)
    if image.ndim == 2:
        # Fix the colour range to [0, 1], matching how float RGB data is
        # interpreted, instead of auto-stretching to the data min/max.
        plt.imshow(image, cmap='gray', vmin=0, vmax=1)
    else:
        plt.imshow(image)
    plt.axis('off')
    plt.show()
def generate_image(model, input_image, num_steps):
    """Denoise ``input_image`` with ``model`` over ``num_steps`` steps.

    The model is switched to evaluation mode, and it is left in that mode
    afterwards. The call to ``denoise_process`` runs under
    ``torch.no_grad()``, so no autograd graph is recorded. Whatever
    ``denoise_process`` returns is handed straight back to the caller.
    """
    model.eval()
    with torch.no_grad():
        return denoise_process(model, input_image, num_steps)
# Inference entry point: restore trained weights, denoise a random-noise
# tensor, and display the result.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
# Without map_location, torch.load restores tensors onto the device they were
# saved from, which raises on a CPU-only machine for a GPU-saved checkpoint.
# Mapping straight to `device` also avoids a cross-device copy in
# load_state_dict.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load('Stable_Diffusion_2_epochs.pth', map_location=device))
# Noise of shape (1, 3, 32, 32). Presumably this must match the channel
# count and resolution the model was trained on (TODO confirm).
input_image = torch.randn((1, 3, 32, 32))
input_image = input_image.to(device)
num_steps = 20
generated_image = generate_image(model, input_image, num_steps)
# Drop the leading size-1 (batch) axis so visualize_image gets (C, H, W).
visualize_image(generated_image.squeeze(0))