@@ -155,7 +155,7 @@ def random_rotate(img: torch.Tensor) -> torch.Tensor:
155155 random_rotate ,
156156 ]
157157 choice = torch .randint (len (transforms ), (1 ,)).item ()
158- return transforms [choice ](img )
158+ return transforms [int ( choice ) ](img )
159159
160160 def random_resize_and_pad (img : torch .Tensor , dim : int = 245 ) -> torch .Tensor :
161161 """
@@ -168,21 +168,24 @@ def random_resize_and_pad(img: torch.Tensor, dim: int = 245) -> torch.Tensor:
168168 img , size = (target , target ), mode = 'bilinear' , align_corners = False
169169 )
170170
171- pad_total = dim - target
172- pad_top = torch .randint (0 , pad_total , (1 ,)).item () # type: ignore[arg-type]
173- pad_bottom = pad_total - pad_top
174- pad_left = torch .randint (0 , pad_total , (1 ,)).item () # type: ignore[arg-type]
175- pad_right = pad_total - pad_left
171+ pad_total = int ( dim - target )
172+ pad_top = int ( torch .randint (0 , pad_total , (1 ,)).item ())
173+ pad_bottom = int ( pad_total - pad_top )
174+ pad_left = int ( torch .randint (0 , pad_total , (1 ,)).item ())
175+ pad_right = int ( pad_total - pad_left )
176176
177- padded = f .pad (resized , [pad_left , pad_right , pad_top , pad_bottom ], value = 0 ) # type: ignore[list-item]
178- return f .interpolate (
177+ padded : torch .Tensor = f .pad (
178+ resized , [pad_left , pad_right , pad_top , pad_bottom ], value = 0
179+ )
180+ padded = f .interpolate (
179181 padded , size = (orig , orig ), mode = 'bilinear' , align_corners = False
180182 )
183+ return padded
181184
182185 # Choose one augmentation at random
183186 transforms = [random_affine , random_resize_and_pad ]
184187 idx = torch .randint (len (transforms ), (1 ,)).item ()
185- aug_x : torch .Tensor = transforms [idx ](x )
188+ aug_x : torch .Tensor = transforms [int ( idx ) ](x ) # type: ignore[operator]
186189 return aug_x
187190
188191 def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
@@ -235,15 +238,7 @@ def __init__(self, region_num: int, is_channels_first: bool = False) -> None:
235238 def get_params (
236239 self , x : torch .Tensor
237240 ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
238- """
239- Compute per-channel min, max and number of splits.
240-
241- Args:
242- x: Tensor of shape (C,H,W)
243-
244- Returns:
245- min_val: (C,), max_val: (C,), counts: (C,) = region_num - 1 splits
246- """
241+ """Compute per-channel min, max and number of splits."""
247242
248243 c , _ , _ = x .size ()
249244 flat = x .view (c , - 1 )
@@ -265,7 +260,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
265260
266261 # sample random percentiles for splits
267262 total_splits = counts .sum ().item ()
268- rand_perc = torch .rand (total_splits , device = x .device )
263+ rand_perc = torch .rand (int ( total_splits ) , device = x .device )
269264 splits = rand_perc .view (- 1 , self .region_num - 1 )
270265
271266 # compute split positions: in [min_val, max_val) per channel
0 commit comments