Skip to content

Commit a0d42d5

Browse files
committed
Live mosaic style transfer working in sobel.py
1 parent 5efe5d5 commit a0d42d5

1 file changed

Lines changed: 23 additions & 17 deletions

File tree

demos/video/style-transfer/sobel.py

Lines changed: 23 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -110,11 +110,11 @@ def tensor_to_bgr(frame_tensor, *, undo_normalise=False, mean=None, std=None):
110110

111111
# 4) scale back to 0‑255, clamp, uint8
112112
img = (img * 255.0)
113-
img = img.cpu().to(torch.float32)
113+
# img = img # .to(torch.float16)
114114
img = img.clamp(0,255).byte()
115115

116116
# 5) channel‑last & numpy
117-
img = img.permute(1,2,0).numpy() # H,W,C RGB
117+
img = img.permute(1,2,0).cpu().numpy() # H,W,C RGB
118118
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # → BGR
119119
img = np.ascontiguousarray(img) # ensure OpenCV‑happy
120120
return img
@@ -169,6 +169,10 @@ def forward(self, rgb):
169169
sm = torch.jit.load("models/sobel_edge_float32.pt")
170170
# sm = torch.jit.load("models/mosaic_float32.pt")
171171
# sm.to('mps')
172+
173+
mosaic = torch.jit.load("models/mosaic_float16.pt")
174+
mosaic.to('mps')
175+
172176
# print(sm)
173177

174178
import sys
@@ -205,32 +209,34 @@ def forward(self, rgb):
205209
# 5) (Optional) add a batch dim and push to GPU ------------------------------
206210
tensor = tensor.unsqueeze(0) # 1 x C x H x W
207211

208-
if ticks == 3:
209-
mosaic = torch.jit.load("models/mosaic_float32.pt")
210-
mosaic.to('mps')
211-
mosaic_output = mosaic(tensor) / 255.0
212-
# mosaic_output = undo_normalize(mosaic_output)
213-
print('input:',tensor.shape,tensor.dtype)
214-
print('mosaic output:',mosaic_output.shape)
215-
torchvision.utils.save_image(tensor[0], 'input_tensor.png')
216-
torchvision.utils.save_image(mosaic_output[0], 'mosaic_output.png')
212+
# if ticks == 3:
213+
# tensor = tensor.to(torch.float16)
214+
# mosaic = torch.jit.load("models/mosaic_float16.pt")
215+
# mosaic.to('mps')
216+
# mosaic_output = mosaic(tensor) / 255.0
217+
# # mosaic_output = undo_normalize(mosaic_output)
218+
# print('input:',tensor.shape,tensor.dtype)
219+
# print('mosaic output:',mosaic_output.shape,mosaic_output.dtype)
220+
# torchvision.utils.save_image(tensor[0], 'input_tensor.png')
221+
# torchvision.utils.save_image(mosaic_output[0], 'mosaic_output.png')
217222

218-
sys.exit(0)
223+
# sys.exit(0)
219224

220-
output_tensor = sm(tensor)
221-
print('input:',tensor.shape,tensor.dtype)
222-
print('output:',output_tensor.shape)
225+
output_tensor = mosaic(tensor.to(torch.float16)) / 255.0
226+
# print('input:',tensor.shape,tensor.dtype)
227+
# print('output:',output_tensor.shape)
223228

224229

225-
frame_bgr_out = tensor_to_bgr(output_tensor, undo_normalise=True,mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
230+
# frame_bgr_out = tensor_to_bgr(output_tensor, undo_normalise=True,mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
231+
frame_bgr_out = tensor_to_bgr(output_tensor)
226232

227233
# Display the captured frame
228234
cv2.imshow('Camera', frame_bgr_out)
229235

230236
# time.sleep(1.0)
231237

232238
# Press 'q' to exit the loop
233-
# if ticks > 3:
239+
# if ticks > 10:
234240
# break
235241

236242
if cv2.waitKey(1) == ord('q'):

0 commit comments

Comments
 (0)