Skip to content

Commit 53e667f

Browse files
committed
fix, minor bugs
1 parent 535ff3a commit 53e667f

File tree

5 files changed

+20
-6
lines changed

5 files changed

+20
-6
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ deeplens/do_code/
1010
!lenses/*.png
1111
!lenses/**/*.png
1212
!images/*.png
13+
!assets/*.png
1314
debug.py
1415
*.log
1516
temp/*
@@ -22,3 +23,5 @@ datasets/BSDS300/
2223
docs/_build/
2324

2425
# add my own files
26+
test/test_outputs
27+
visualization/

deeplens/optics/monte_carlo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ def assign_points_to_pixels(
192192
if phase is None:
193193
raise ValueError("Phase must be provided for coherent mode")
194194

195-
grid = torch.zeros(ks, ks, dtype=torch.complex128).to(device)
195+
c_dtype = torch.complex64 if points.dtype == torch.float32 else torch.complex128
196+
grid = torch.zeros(ks, ks, dtype=c_dtype).to(device)
196197
grid.index_put_(
197198
tuple(pixel_indices_tl.t()),
198199
(1 - w_b) * (1 - w_r) * mask * amp * torch.exp(1j * phase),
@@ -246,7 +247,8 @@ def assign_points_to_pixels(
246247
if phase is None:
247248
raise ValueError("Phase must be provided for coherent mode")
248249

249-
grid = torch.zeros(ks, ks, dtype=torch.complex128).to(device)
250+
c_dtype = torch.complex64 if points.dtype == torch.float32 else torch.complex128
251+
grid = torch.zeros(ks, ks, dtype=c_dtype).to(device)
250252
grid.index_put_(
251253
tuple(pixel_indices_tl.t()),
252254
mask * amp * torch.exp(1j * phase),

deeplens/optics/wave.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def pad(self, Hpad, Wpad):
414414
self.phy_size[1] * self.res[1] / Worg,
415415
]
416416
self.x, self.y = self.gen_xy_grid()
417-
self.z = F.pad(self.z, (Hpad, Hpad, Wpad, Wpad), mode="replicate")
417+
self.z = torch.full_like(self.x, self.z[0, 0].item())
418418

419419
def flip(self):
420420
"""Flip the field horizontally and vertically."""

deeplens/sensor/sensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def simu_noise(self, img_raw, iso):
143143

144144
# Calculate noise standard deviation
145145
shotnoise_std = torch.clamp(
146-
self.shotnoise_std_alpha * torch.sqrt(img_raw - black_level)
146+
self.shotnoise_std_alpha * torch.sqrt(torch.clamp(img_raw - black_level, min=0.0))
147147
+ self.shotnoise_std_beta,
148148
0.0,
149149
)

deeplens/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,17 @@ def img2batch(img):
136136
img = (
137137
torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
138138
) # (H, W, C) -> (1, C, H, W)
139+
elif torch.is_tensor(img):
140+
if img.shape[0] in [1, 3]:
141+
# Assume (C, H, W) -> (1, C, H, W)
142+
img = img.unsqueeze(0)
143+
elif img.shape[-1] in [1, 3]:
144+
# Assume (H, W, C) -> (1, C, H, W)
145+
img = img.permute(2, 0, 1).unsqueeze(0)
146+
else:
147+
raise ValueError("Image channel should be 1 or 3.")
139148
else:
140-
raise ValueError("Image should be numpy array.")
149+
raise ValueError("Image should be numpy array or torch tensor.")
141150

142151
# Tensor dtype
143152
if img.dtype == torch.uint8:
@@ -187,7 +196,7 @@ def batch_psnr(pred, target, max_val=1.0, eps=1e-8):
187196
# Calculate PSNR
188197
psnr = 20 * torch.log10(max_val / torch.sqrt(mse + eps))
189198

190-
return psnr.item()
199+
return psnr
191200

192201

193202
def batch_SSIM(img, img_clean):

0 commit comments

Comments
 (0)