Skip to content

Commit 177f398

Browse files
committed
add: torch bias field
1 parent ee1e156 commit 177f398

2 files changed

Lines changed: 41 additions & 0 deletions

File tree

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from .croppad import torch_croppad
2+
from .bias_field import torch_bias_field
3+
from .blur import torch_blur
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import torch
2+
3+
def torch_bias_field(image: torch.Tensor, clip_to_input_range: bool = False) -> torch.Tensor:
4+
img_min = image.min()
5+
img_max = image.max()
6+
7+
if len(image.shape) == 3:
8+
assert image.ndim == 3, "Expected [H, W, D] tensor"
9+
10+
x, y, z = image.shape
11+
X, Y, Z = torch.meshgrid(
12+
torch.linspace(0, x-1, x),
13+
torch.linspace(0, y-1, y),
14+
torch.linspace(0, z-1, z),
15+
indexing='ij'
16+
)
17+
x0 = torch.randint(0, x, (1,))
18+
y0 = torch.randint(0, y, (1,))
19+
z0 = torch.randint(0, z, (1,))
20+
G = 1 - ((X - x0)**2 / (x**2) + (Y - y0)**2 / (y**2) + (Z - z0)**2 / (z**2))
21+
else:
22+
assert image.ndim == 2, "Expected [H, W] tensor"
23+
24+
x, y = image.shape
25+
X, Y = torch.meshgrid(
26+
torch.linspace(0, x-1, x),
27+
torch.linspace(0, y-1, y),
28+
indexing='ij'
29+
)
30+
x0 = torch.randint(0, x, (1,))
31+
y0 = torch.randint(0, y, (1,))
32+
G = 1 - ((X - x0)**2 / (x**2) + (Y - y0)**2 / (y**2))
33+
34+
image = G * image
35+
36+
if clip_to_input_range:
37+
image = torch.clamp(image, min=img_min, max=img_max)
38+
39+
return image

0 commit comments

Comments
 (0)