1+ import torch
2+
3+ def torch_bias_field (image : torch .Tensor , clip_to_input_range : bool = False ) -> torch .Tensor :
4+ img_min = image .min ()
5+ img_max = image .max ()
6+
7+ if len (image .shape ) == 3 :
8+ assert image .ndim == 3 , "Expected [H, W, D] tensor"
9+
10+ x , y , z = image .shape
11+ X , Y , Z = torch .meshgrid (
12+ torch .linspace (0 , x - 1 , x ),
13+ torch .linspace (0 , y - 1 , y ),
14+ torch .linspace (0 , z - 1 , z ),
15+ indexing = 'ij'
16+ )
17+ x0 = torch .randint (0 , x , (1 ,))
18+ y0 = torch .randint (0 , y , (1 ,))
19+ z0 = torch .randint (0 , z , (1 ,))
20+ G = 1 - ((X - x0 )** 2 / (x ** 2 ) + (Y - y0 )** 2 / (y ** 2 ) + (Z - z0 )** 2 / (z ** 2 ))
21+ else :
22+ assert image .ndim == 2 , "Expected [H, W] tensor"
23+
24+ x , y = image .shape
25+ X , Y = torch .meshgrid (
26+ torch .linspace (0 , x - 1 , x ),
27+ torch .linspace (0 , y - 1 , y ),
28+ indexing = 'ij'
29+ )
30+ x0 = torch .randint (0 , x , (1 ,))
31+ y0 = torch .randint (0 , y , (1 ,))
32+ G = 1 - ((X - x0 )** 2 / (x ** 2 ) + (Y - y0 )** 2 / (y ** 2 ))
33+
34+ image = G * image
35+
36+ if clip_to_input_range :
37+ image = torch .clamp (image , min = img_min , max = img_max )
38+
39+ return image
0 commit comments