
Commit 6f57196

Copilot and Borda authored
Add runtime type annotations to rfdetr util helpers (#551)
* Add runtime type annotations (re-authored from Copilot)
* Standardize docstring argument and return type formatting across utility functions
* Apply suggestions from code review

Co-authored-by: jirka <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent 950de8c commit 6f57196

10 files changed: +187 -125 lines

rfdetr/util/benchmark.py

Lines changed: 51 additions & 46 deletions
@@ -32,19 +32,19 @@
 import time
 
 
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Sequence, Union
 from numbers import Number
 
 Handle = Callable[[List[Any], List[Any]], Union[typing.Counter[str], Number]]
 
 
-def get_shape(val: object) -> typing.List[int]:
+def get_shape(val: Any) -> typing.List[int]:
     """
     Get the shapes from a jit value object.
     Args:
-        val (torch._C.Value): jit value object.
+        val: jit value object.
     Returns:
-        list(int): return a list of ints.
+        return a list of ints.
     """
     if val.isCompleteTensor():  # pyre-ignore
         r = val.type().sizes()  # pyre-ignore
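For context on where these jit values come from (a hedged aside, not part of the commit): tracing a module records concrete tensor shapes into the graph, and get_shape reads them back through the same isCompleteTensor()/sizes() calls seen above.

    import torch
    import torch.nn as nn

    # Tracing bakes concrete shapes into the graph's values.
    traced = torch.jit.trace(nn.Linear(4, 2), torch.randn(1, 4))
    val = next(traced.graph.outputs())
    if val.isCompleteTensor():
        print(val.type().sizes())  # expected: [1, 2]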
@@ -64,17 +64,17 @@ def get_shape(val: object) -> typing.List[int]:
 
 
 def addmm_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
 ) -> typing.Counter[str]:
     """
     This method counts the flops for fully connected layers with torch script.
     Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
             jit object.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
             of jit object.
     Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
             operation.
     """
     # Count flop for nn.Linear
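As a sanity check on the convention addmm_flop_jit implements (an illustration, not code from this diff): an nn.Linear counts batch_size * input_dim * output_dim flops, i.e. one multiply per weight per batch element.

    # Expected count for Linear(512 -> 256) on a batch of 1:
    batch_size, input_dim, output_dim = 1, 512, 256
    print(batch_size * input_dim * output_dim)  # 131072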
@@ -91,7 +91,7 @@ def addmm_flop_jit(
     return flop_counter
 
 
-def bmm_flop_jit(inputs, outputs):
+def bmm_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any]) -> Counter[str]:
     # Count flop for nn.Linear
     # inputs is a list of length 3.
     input_shapes = [get_shape(v) for v in inputs]
@@ -106,7 +106,7 @@ def bmm_flop_jit(inputs, outputs):
     return flop_counter
 
 
-def basic_binary_op_flop_jit(inputs, outputs, name):
+def basic_binary_op_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any], name: str) -> Counter[str]:
     input_shapes = [get_shape(v) for v in inputs]
     # for broadcasting
     input_shapes = [s[::-1] for s in input_shapes]
@@ -116,29 +116,34 @@ def basic_binary_op_flop_jit(inputs, outputs, name):
     return flop_counter
 
 
-def rsqrt_flop_jit(inputs, outputs):
+def rsqrt_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any]) -> Counter[str]:
     input_shapes = [get_shape(v) for v in inputs]
     flop = prod(input_shapes[0]) * 2
     flop_counter = Counter({"rsqrt": flop})
     return flop_counter
 
 
-def dropout_flop_jit(inputs, outputs):
+def dropout_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any]) -> Counter[str]:
     input_shapes = [get_shape(v) for v in inputs[:1]]
     flop = prod(input_shapes[0])
     flop_counter = Counter({"dropout": flop})
     return flop_counter
 
 
-def softmax_flop_jit(inputs, outputs):
+def softmax_flop_jit(inputs: typing.List[Any], outputs: typing.List[Any]) -> Counter[str]:
     # from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/profiler/internal/flops_registry.py
     input_shapes = [get_shape(v) for v in inputs[:1]]
     flop = prod(input_shapes[0]) * 5
     flop_counter = Counter({"softmax": flop})
     return flop_counter
 
 
-def _reduction_op_flop_jit(inputs, outputs, reduce_flops=1, finalize_flops=0):
+def _reduction_op_flop_jit(
+    inputs: typing.List[Any],
+    outputs: typing.List[Any],
+    reduce_flops: int = 1,
+    finalize_flops: int = 0,
+) -> int:
     input_shapes = [get_shape(v) for v in inputs]
     output_shapes = [get_shape(v) for v in outputs]
 
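The per-element costs these handlers hard-code are visible in the hunk: 2 flops per element for rsqrt, 1 for dropout, and 5 for softmax (following the TensorFlow flops registry linked above). A small illustrative check:

    n_elements = 8 * 1024  # hypothetical tensor size
    per_element = {"rsqrt": 2, "dropout": 1, "softmax": 5}
    print({op: k * n_elements for op, k in per_element.items()})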
@@ -161,11 +166,11 @@ def conv_flop_count(
     This method counts the flops for convolution. Note only multiplication is
     counted. Computation for addition and bias is ignored.
     Args:
-        x_shape (list(int)): The input shape before convolution.
-        w_shape (list(int)): The filter shape.
-        out_shape (list(int)): The output shape after convolution.
+        x_shape: The input shape before convolution.
+        w_shape: The filter shape.
+        out_shape: The output shape after convolution.
     Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
             operation.
     """
     batch_size, Cin_dim, Cout_dim = x_shape[0], w_shape[1], out_shape[1]
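The hunk shows only conv_flop_count's docstring; assuming the standard multiplication-only convention it describes, the count is batch * Cout * out_spatial * Cin * kernel_size. A worked sketch under that assumption:

    from math import prod

    x_shape = [1, 3, 224, 224]     # (N, Cin, H, W)
    w_shape = [64, 3, 7, 7]        # (Cout, Cin, kH, kW)
    out_shape = [1, 64, 112, 112]  # (N, Cout, H', W')
    batch, cin, cout = x_shape[0], w_shape[1], out_shape[1]
    flops = batch * cout * prod(out_shape[2:]) * cin * prod(w_shape[2:])
    print(f"{flops / 1e9:.3f} GFLOPs")  # ~0.118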
@@ -177,17 +182,17 @@
 
 
 def conv_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
 ) -> typing.Counter[str]:
     """
     This method counts the flops for convolution using torch script.
     Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
             jit object before convolution.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
             of jit object after convolution.
     Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
             operation.
     """
     # Inputs of Convolution should be a list of length 12. They represent:
@@ -206,18 +211,18 @@
 
 
 def einsum_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
 ) -> typing.Counter[str]:
     """
     This method counts the flops for the einsum operation. We currently support
     two einsum operations: "nct,ncp->ntp" and "ntg,ncg->nct".
     Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
             jit object before einsum.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
             of jit object after einsum.
     Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
             operation.
     """
     # Inputs of einsum should be a list of length 2.
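For the supported pattern "nct,ncp->ntp", every output element is a length-c dot product, so the expected count is n * t * p * c. A worked illustration based on the docstring (the function body is not shown in this hunk):

    n, c, t, p = 2, 64, 100, 50
    print(n * t * p * c)  # 640000 multiplications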
@@ -254,17 +259,17 @@ def einsum_flop_jit(
 
 
 def matmul_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
 ) -> typing.Counter[str]:
     """
     This method counts the flops for matmul.
     Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
             jit object before matmul.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
             of jit object after matmul.
     Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
             operation.
     """
 
@@ -287,17 +292,17 @@ def matmul_flop_jit(
 
 
 def batchnorm_flop_jit(
-    inputs: typing.List[object], outputs: typing.List[object]
+    inputs: typing.List[Any], outputs: typing.List[Any]
 ) -> typing.Counter[str]:
     """
     This method counts the flops for batch norm.
     Args:
-        inputs (list(torch._C.Value)): The input shape in the form of a list of
+        inputs: The input shape in the form of a list of
             jit object before batch norm.
-        outputs (list(torch._C.Value)): The output shape in the form of a list
+        outputs: The output shape in the form of a list
             of jit object after batch norm.
     Returns:
-        Counter: A Counter dictionary that records the number of flops for each
+        A Counter dictionary that records the number of flops for each
             operation.
     """
     # Inputs[0] contains the shape of the input.
@@ -470,26 +475,26 @@ def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
 
 def flop_count(
     model: nn.Module,
-    inputs: typing.Tuple[object, ...],
-    whitelist: typing.Union[typing.List[str], None] = None,
-    customized_ops: typing.Union[typing.Dict[str, typing.Callable], None] = None,
+    inputs: typing.Tuple[Any, ...],
+    whitelist: typing.Optional[typing.List[str]] = None,
+    customized_ops: typing.Optional[typing.Dict[str, typing.Callable]] = None,
 ) -> typing.DefaultDict[str, float]:
     """
     Given a model and an input to the model, compute the Gflops of the given
     model. Note the input should have a batch size of 1.
     Args:
-        model (nn.Module): The model to compute flop counts.
-        inputs (tuple): Inputs that are passed to `model` to count flops.
+        model: The model to compute flop counts.
+        inputs: Inputs that are passed to `model` to count flops.
             Inputs need to be in a tuple.
-        whitelist (list(str)): Whitelist of operations that will be counted. It
+        whitelist: Whitelist of operations that will be counted. It
             needs to be a subset of _SUPPORTED_OPS. By default, the function
             computes flops for all supported operations.
-        customized_ops (dict(str,Callable)) : A dictionary contains customized
+        customized_ops: A dictionary contains customized
             operations and their flop handles. If customized_ops contains an
             operation in _SUPPORTED_OPS, then the default handle in
             _SUPPORTED_OPS will be overwritten.
     Returns:
-        defaultdict: A dictionary that records the number of gflops for each
+        A dictionary that records the number of gflops for each
             operation.
     """
     # Copy _SUPPORTED_OPS to flop_count_ops.
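A minimal usage sketch for the newly annotated signature (the toy model and input size are hypothetical; per the docstring, inputs must be a tuple and the batch size should be 1):

    import torch
    import torch.nn as nn
    from rfdetr.util.benchmark import flop_count

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
    dummy = torch.randn(1, 3, 32, 32)
    gflops = flop_count(model, (dummy,))  # DefaultDict[str, float] of gflops per op
    print(dict(gflops))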
@@ -557,13 +562,13 @@ def flop_count(
     return final_count
 
 
-def warmup(model, inputs, N=10):
+def warmup(model: torch.nn.Module, inputs: Any, N: int = 10) -> None:
     for i in range(N):
         out = model(inputs)
     torch.cuda.synchronize()
 
 
-def measure_time(model, inputs, N=10):
+def measure_time(model: torch.nn.Module, inputs: Any, N: int = 10) -> float:
     warmup(model, inputs)
     s = time.time()
     for i in range(N):
@@ -573,7 +578,7 @@ def measure_time(model, inputs, N=10):
     return t
 
 
-def fmt_res(data):
+def fmt_res(data: np.ndarray) -> Dict[str, float]:
     # return data.mean(), data.std(), data.min(), data.max()
     return {
         "mean": data.mean(),
@@ -583,7 +588,7 @@
     }
 
 
-def benchmark(model, dataset, output_dir):
+def benchmark(model: torch.nn.Module, dataset: Sequence[Any], output_dir: Any) -> Dict[str, Any]:
     print("Get model size, FLOPs, and FPS")
     # import pdb; pdb.set_trace()
     _outputs = {}
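A hedged sketch of how the timing helpers compose (requires a CUDA device, since warmup and measure_time call torch.cuda.synchronize(); the repetition counts here are arbitrary):

    import numpy as np
    import torch
    import torch.nn as nn
    from rfdetr.util.benchmark import measure_time, fmt_res

    model = nn.Linear(1024, 1024).cuda().eval()
    x = torch.randn(64, 1024, device="cuda")
    with torch.no_grad():
        times = np.array([measure_time(model, x) for _ in range(5)])
    print(fmt_res(times))  # {'mean': ..., 'std': ..., 'min': ..., 'max': ...}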

rfdetr/util/box_ops.py

Lines changed: 16 additions & 7 deletions
@@ -14,27 +14,36 @@
 """
 Utilities for bounding box manipulation and GIoU.
 """
+from typing import Tuple
+
 import torch
 import torch.nn.functional as F
 from torchvision.ops.boxes import box_area
 
 
-def box_cxcywh_to_xyxy(x):
+def box_cxcywh_to_xyxy(x: torch.Tensor) -> torch.Tensor:
     x_c, y_c, w, h = x.unbind(-1)
     b = [(x_c - 0.5 * w.clamp(min=0.0)), (y_c - 0.5 * h.clamp(min=0.0)),
          (x_c + 0.5 * w.clamp(min=0.0)), (y_c + 0.5 * h.clamp(min=0.0))]
     return torch.stack(b, dim=-1)
 
 
-def box_xyxy_to_cxcywh(x):
+def box_xyxy_to_cxcywh(x: torch.Tensor) -> torch.Tensor:
     x0, y0, x1, y1 = x.unbind(-1)
     b = [(x0 + x1) / 2, (y0 + y1) / 2,
          (x1 - x0), (y1 - y0)]
     return torch.stack(b, dim=-1)
 
 
 # modified from torchvision to also return the union
-def box_iou(boxes1, boxes2):
+def box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Returns:
+        iou: the NxM matrix containing the pairwise
+            IoU values for every element in boxes1 and boxes2
+        union: the NxM matrix containing the pairwise
+            union values for every element in boxes1 and boxes2
+    """
     area1 = box_area(boxes1)
     area2 = box_area(boxes2)
 
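Exercising the annotated conversions and the new box_iou docstring (a small self-contained sketch, not from the diff):

    import torch
    from rfdetr.util.box_ops import box_cxcywh_to_xyxy, box_iou

    boxes1 = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.4, 0.4]]))  # -> [[0.3, 0.3, 0.7, 0.7]]
    boxes2 = torch.tensor([[0.3, 0.3, 0.7, 0.7], [0.0, 0.0, 0.2, 0.2]])
    iou, union = box_iou(boxes1, boxes2)  # both are N x M, here 1 x 2
    print(iou)  # tensor([[1., 0.]])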
@@ -50,7 +59,7 @@ def box_iou(boxes1, boxes2):
     return iou, union
 
 
-def generalized_box_iou(boxes1, boxes2):
+def generalized_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
     """
     Generalized IoU from https://giou.stanford.edu/
 
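Unlike plain IoU, GIoU stays informative for disjoint boxes thanks to the enclosing-area penalty (the return iou - (area - union) / area visible at the top of the next hunk); values fall in [-1, 1]. For example:

    import torch
    from rfdetr.util.box_ops import generalized_box_iou

    a = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
    b = torch.tensor([[2.0, 0.0, 3.0, 1.0]])  # disjoint from a
    print(generalized_box_iou(a, b))  # tensor([[-0.3333]])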
@@ -72,7 +81,7 @@ def generalized_box_iou(boxes1, boxes2):
     return iou - (area - union) / area
 
 
-def masks_to_boxes(masks):
+def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
     """Compute the bounding boxes around the provided masks
 
     The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
@@ -99,7 +108,7 @@ def masks_to_boxes(masks):
     return torch.stack([x_min, y_min, x_max, y_max], 1)
 
 
-def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
+def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
     """
     Compute the DICE loss, similar to generalized IOU for masks
     Args:
@@ -122,7 +131,7 @@ def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
 )  # type: torch.jit.ScriptModule
 
 
-def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
+def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
     """
     Args:
         inputs: A float tensor of arbitrary shape.

rfdetr/util/drop_scheduler.py

Lines changed: 9 additions & 1 deletion
@@ -5,9 +5,17 @@
 # ------------------------------------------------------------------------
 """util for drop scheduler."""
 import numpy as np
+from typing import Literal
 
 
-def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode='standard', schedule='constant'):
+def drop_scheduler(
+    drop_rate: float,
+    epochs: int,
+    niter_per_ep: int,
+    cutoff_epoch: int = 0,
+    mode: Literal['standard', 'early', 'late'] = 'standard',
+    schedule: Literal['constant', 'linear'] = 'constant',
+) -> np.ndarray:
     """drop scheduler"""
     assert mode in ['standard', 'early', 'late']
     if mode == 'standard':
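A usage sketch under the annotated signature (hedged: the hunk cuts off before the schedule logic, but with schedule='constant' the function presumably yields one drop rate per training iteration, epochs * niter_per_ep values in total):

    import numpy as np
    from rfdetr.util.drop_scheduler import drop_scheduler

    rates = drop_scheduler(0.1, epochs=12, niter_per_ep=100)  # mode and schedule use their defaults
    assert isinstance(rates, np.ndarray)
    print(rates.shape)  # expected: (1200,)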

0 commit comments