Skip to content

Commit add539a

Browse files
Replaced argument type annotation: int -> typing.SupportsInt
Same for float->typing.SupportsFloat. Result types remain int/float
1 parent 88a3ad0 commit add539a

File tree

1 file changed

+32
-22
lines changed

1 file changed

+32
-22
lines changed

python/cuda/nvbench/__init__.pyi

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# with definitions given here.
2727

2828
from collections.abc import Callable, Sequence
29-
from typing import Optional, Self, Union
29+
from typing import Optional, Self, SupportsFloat, SupportsInt, Union
3030

3131
class CudaStream:
3232
"""Represents CUDA stream
@@ -71,13 +71,15 @@ class Benchmark:
7171
def get_name(self) -> str:
7272
"Get benchmark name"
7373
...
74-
def add_int64_axis(self, name: str, values: Sequence[int]) -> Self:
74+
def add_int64_axis(self, name: str, values: Sequence[SupportsInt]) -> Self:
7575
"Add integral type parameter axis with given name and values to sweep over"
7676
...
77-
def add_int64_power_of_two_axis(self, name: str, values: Sequence[int]) -> Self:
77+
def add_int64_power_of_two_axis(
78+
self, name: str, values: Sequence[SupportsInt]
79+
) -> Self:
7880
"Add integral type parameter axis with given name and values to sweep over"
7981
...
80-
def add_float64_axis(self, name: str, values: Sequence[float]) -> Self:
82+
def add_float64_axis(self, name: str, values: Sequence[SupportsFloat]) -> Self:
8183
"Add floating-point type parameter axis with given name and values to sweep over"
8284
...
8385
def add_string_axis(sef, name: str, values: Sequence[str]) -> Self:
@@ -92,31 +94,31 @@ class Benchmark:
9294
def set_run_once(self, v: bool) -> Self:
9395
"Set whether all benchmark configurations are executed only once"
9496
...
95-
def set_skip_time(self, duration_seconds: float) -> Self:
97+
def set_skip_time(self, duration_seconds: SupportsFloat) -> Self:
9698
"Set run durations, in seconds, that should be skipped"
9799
...
98-
def set_throttle_recovery_delay(self, delay_seconds: float) -> Self:
100+
def set_throttle_recovery_delay(self, delay_seconds: SupportsFloat) -> Self:
99101
"Set throttle recovery delay, in seconds"
100102
...
101-
def set_throttle_threshold(self, threshold: float) -> Self:
103+
def set_throttle_threshold(self, threshold: SupportsFloat) -> Self:
102104
"Set throttle threshold, as a fraction of maximal GPU frequency"
103105
...
104-
def set_timeout(self, duration_seconds: float) -> Self:
106+
def set_timeout(self, duration_seconds: SupportsFloat) -> Self:
105107
"Set benchmark run duration timeout value, in seconds"
106108
...
107109
def set_stopping_criterion(self, criterion: str) -> Self:
108110
"Set stopping criterion to be used"
109111
...
110-
def set_criterion_param_float64(self, name: str, value: float) -> Self:
112+
def set_criterion_param_float64(self, name: str, value: SupportsFloat) -> Self:
111113
"Set stopping criterion floating point parameter value"
112114
...
113-
def set_criterion_param_int64(self, name: str, value: int) -> Self:
115+
def set_criterion_param_int64(self, name: str, value: SupportsInt) -> Self:
114116
"Set stopping criterion integer parameter value"
115117
...
116118
def set_criterion_param_string(self, name: str, value: str) -> Self:
117119
"Set stopping criterion string parameter value"
118120
...
119-
def set_min_samples(self, count: int) -> Self:
121+
def set_min_samples(self, count: SupportsInt) -> Self:
120122
"Set minimal samples count before stopping criterion applies"
121123
...
122124

@@ -153,13 +155,13 @@ class State:
153155
def get_int64(self, name: str) -> int:
154156
"Get value for given Int64 axis from this configuration"
155157
...
156-
def get_int64_or_default(self, name: str, default_value: int) -> int:
158+
def get_int64_or_default(self, name: str, default_value: SupportsInt) -> int:
157159
"Get value for given Int64 axis from this configuration"
158160
...
159161
def get_float64(self, name: str) -> float:
160162
"Get value for given Float64 axis from this configuration"
161163
...
162-
def get_float64_or_default(self, name: str, default_value: float) -> float:
164+
def get_float64_or_default(self, name: str, default_value: SupportsFloat) -> float:
163165
"Get value for given Float64 axis from this configuration"
164166
...
165167
def get_string(self, name: str) -> str:
@@ -168,10 +170,12 @@ class State:
168170
def get_string_or_default(self, name: str, default_value: str) -> str:
169171
"Get value for given String axis from this configuration"
170172
...
171-
def add_element_count(self, count: int, column_name: Optional[str] = None) -> None:
173+
def add_element_count(
174+
self, count: SupportsInt, column_name: Optional[str] = None
175+
) -> None:
172176
"Add element count"
173177
...
174-
def set_element_count(self, count: int) -> None:
178+
def set_element_count(self, count: SupportsInt) -> None:
175179
"Set element count"
176180
...
177181
def get_element_count(self) -> int:
@@ -186,10 +190,14 @@ class State:
186190
def get_skip_reason(self) -> str:
187191
"Get reason provided for skipping this configuration"
188192
...
189-
def add_global_memory_reads(self, nbytes: int, /, column_name: str = "") -> None:
193+
def add_global_memory_reads(
194+
self, nbytes: SupportsInt, /, column_name: str = ""
195+
) -> None:
190196
"Inform NVBench that given amount of bytes is being read by the benchmark from global memory"
191197
...
192-
def add_global_memory_writes(self, nbytes: int, /, column_name: str = "") -> None:
198+
def add_global_memory_writes(
199+
self, nbytes: SupportsInt, /, column_name: str = ""
200+
) -> None:
193201
"Inform NVBench that given amount of bytes is being written by the benchmark into global memory"
194202
...
195203
def get_benchmark(self) -> Benchmark:
@@ -198,13 +206,13 @@ class State:
198206
def get_throttle_threshold(self) -> float:
199207
"Get throttle threshold value, as fraction of maximal frequency"
200208
...
201-
def set_throttle_threshold(self, threshold_fraction: float) -> None:
209+
def set_throttle_threshold(self, threshold_fraction: SupportsFloat) -> None:
202210
"Set throttle threshold fraction to specified value, expected to be between 0 and 1"
203211
...
204212
def get_min_samples(self) -> int:
205213
"Get the number of benchmark timings NVBench performs before stopping criterion begins being used"
206214
...
207-
def set_min_samples(self, min_samples_count: int) -> None:
215+
def set_min_samples(self, min_samples_count: SupportsInt) -> None:
208216
"Set the number of benchmark timings for NVBench to perform before stopping criterion begins being used"
209217
...
210218
def get_disable_blocking_kernel(self) -> bool:
@@ -222,13 +230,13 @@ class State:
222230
def get_timeout(self) -> float:
223231
"Get time-out value for benchmark execution of this configuration"
224232
...
225-
def set_timeout(self, duration: float) -> None:
233+
def set_timeout(self, duration: SupportsFloat) -> None:
226234
"Set time-out value for benchmark execution of this configuration, in seconds"
227235
...
228236
def get_blocking_kernel_timeout(self) -> float:
229237
"Get time-out value for execution of blocking kernel"
230238
...
231-
def set_blocking_kernel_timeout(self, duration: float) -> None:
239+
def set_blocking_kernel_timeout(self, duration: SupportsFloat) -> None:
232240
"Set time-out value for execution of blocking kernel, in seconds"
233241
...
234242
def collect_cupti_metrics(self) -> None:
@@ -265,7 +273,9 @@ class State:
265273
def get_short_description(self) -> str:
266274
"Get short description for this configuration"
267275
...
268-
def add_summary(self, column_name: str, value: Union[int, float, str]) -> None:
276+
def add_summary(
277+
self, column_name: str, value: Union[SupportsInt, SupportsFloat, str]
278+
) -> None:
269279
"Add summary column with a value"
270280
...
271281
def get_axis_values(self) -> dict[str, int | float | str]:

0 commit comments

Comments
 (0)