We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 535ff3a commit 53e667fCopy full SHA for 53e667f
.gitignore
@@ -10,6 +10,7 @@ deeplens/do_code/
10
!lenses/*.png
11
!lenses/**/*.png
12
!images/*.png
13
+!assets/*.png
14
debug.py
15
*.log
16
temp/*
@@ -22,3 +23,5 @@ datasets/BSDS300/
22
23
docs/_build/
24
25
# add my own files
26
+test/test_outputs
27
+visualization/
deeplens/optics/monte_carlo.py
@@ -192,7 +192,8 @@ def assign_points_to_pixels(
192
if phase is None:
193
raise ValueError("Phase must be provided for coherent mode")
194
195
- grid = torch.zeros(ks, ks, dtype=torch.complex128).to(device)
+ c_dtype = torch.complex64 if points.dtype == torch.float32 else torch.complex128
196
+ grid = torch.zeros(ks, ks, dtype=c_dtype).to(device)
197
grid.index_put_(
198
tuple(pixel_indices_tl.t()),
199
(1 - w_b) * (1 - w_r) * mask * amp * torch.exp(1j * phase),
@@ -246,7 +247,8 @@ def assign_points_to_pixels(
246
247
248
249
250
251
252
253
254
mask * amp * torch.exp(1j * phase),
deeplens/optics/wave.py
@@ -414,7 +414,7 @@ def pad(self, Hpad, Wpad):
414
self.phy_size[1] * self.res[1] / Worg,
415
]
416
self.x, self.y = self.gen_xy_grid()
417
- self.z = F.pad(self.z, (Hpad, Hpad, Wpad, Wpad), mode="replicate")
+ self.z = torch.full_like(self.x, self.z[0, 0].item())
418
419
def flip(self):
420
"""Flip the field horizontally and vertically."""
deeplens/sensor/sensor.py
@@ -143,7 +143,7 @@ def simu_noise(self, img_raw, iso):
143
144
# Calculate noise standard deviation
145
shotnoise_std = torch.clamp(
146
- self.shotnoise_std_alpha * torch.sqrt(img_raw - black_level)
+ self.shotnoise_std_alpha * torch.sqrt(torch.clamp(img_raw - black_level, min=0.0))
147
+ self.shotnoise_std_beta,
148
0.0,
149
)
deeplens/utils.py
@@ -136,8 +136,17 @@ def img2batch(img):
136
img = (
137
torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
138
) # (H, W, C) -> (1, C, H, W)
139
+ elif torch.is_tensor(img):
140
+ if img.shape[0] in [1, 3]:
141
+ # Assume (C, H, W) -> (1, C, H, W)
142
+ img = img.unsqueeze(0)
+ elif img.shape[-1] in [1, 3]:
+ # Assume (H, W, C) -> (1, C, H, W)
+ img = img.permute(2, 0, 1).unsqueeze(0)
+ else:
+ raise ValueError("Image channel should be 1 or 3.")
else:
- raise ValueError("Image should be numpy array.")
+ raise ValueError("Image should be numpy array or torch tensor.")
150
151
# Tensor dtype
152
if img.dtype == torch.uint8:
@@ -187,7 +196,7 @@ def batch_psnr(pred, target, max_val=1.0, eps=1e-8):
187
# Calculate PSNR
188
psnr = 20 * torch.log10(max_val / torch.sqrt(mse + eps))
189
190
- return psnr.item()
+ return psnr
191
200
201
202
def batch_SSIM(img, img_clean):
0 commit comments