2626# with definitions given here.
2727
2828from collections .abc import Callable , Sequence
29- from typing import Optional , Self , Union
29+ from typing import Optional , Self , SupportsFloat , SupportsInt , Union
3030
3131class CudaStream :
3232 """Represents CUDA stream
@@ -71,13 +71,15 @@ class Benchmark:
7171 def get_name (self ) -> str :
7272 "Get benchmark name"
7373 ...
74- def add_int64_axis (self , name : str , values : Sequence [int ]) -> Self :
74+ def add_int64_axis (self , name : str , values : Sequence [SupportsInt ]) -> Self :
7575 "Add integral type parameter axis with given name and values to sweep over"
7676 ...
77- def add_int64_power_of_two_axis (self , name : str , values : Sequence [int ]) -> Self :
77+ def add_int64_power_of_two_axis (
78+ self , name : str , values : Sequence [SupportsInt ]
79+ ) -> Self :
7880 "Add integral type parameter axis with given name and values to sweep over"
7981 ...
80- def add_float64_axis (self , name : str , values : Sequence [float ]) -> Self :
82+ def add_float64_axis (self , name : str , values : Sequence [SupportsFloat ]) -> Self :
8183 "Add floating-point type parameter axis with given name and values to sweep over"
8284 ...
8385 def add_string_axis (sef , name : str , values : Sequence [str ]) -> Self :
@@ -92,31 +94,31 @@ class Benchmark:
9294 def set_run_once (self , v : bool ) -> Self :
9395 "Set whether all benchmark configurations are executed only once"
9496 ...
95- def set_skip_time (self , duration_seconds : float ) -> Self :
97+ def set_skip_time (self , duration_seconds : SupportsFloat ) -> Self :
9698 "Set run durations, in seconds, that should be skipped"
9799 ...
98- def set_throttle_recovery_delay (self , delay_seconds : float ) -> Self :
100+ def set_throttle_recovery_delay (self , delay_seconds : SupportsFloat ) -> Self :
99101 "Set throttle recovery delay, in seconds"
100102 ...
101- def set_throttle_threshold (self , threshold : float ) -> Self :
103+ def set_throttle_threshold (self , threshold : SupportsFloat ) -> Self :
102104 "Set throttle threshold, as a fraction of maximal GPU frequency"
103105 ...
104- def set_timeout (self , duration_seconds : float ) -> Self :
106+ def set_timeout (self , duration_seconds : SupportsFloat ) -> Self :
105107 "Set benchmark run duration timeout value, in seconds"
106108 ...
107109 def set_stopping_criterion (self , criterion : str ) -> Self :
108110 "Set stopping criterion to be used"
109111 ...
110- def set_criterion_param_float64 (self , name : str , value : float ) -> Self :
112+ def set_criterion_param_float64 (self , name : str , value : SupportsFloat ) -> Self :
111113 "Set stopping criterion floating point parameter value"
112114 ...
113- def set_criterion_param_int64 (self , name : str , value : int ) -> Self :
115+ def set_criterion_param_int64 (self , name : str , value : SupportsInt ) -> Self :
114116 "Set stopping criterion integer parameter value"
115117 ...
116118 def set_criterion_param_string (self , name : str , value : str ) -> Self :
117119 "Set stopping criterion string parameter value"
118120 ...
119- def set_min_samples (self , count : int ) -> Self :
121+ def set_min_samples (self , count : SupportsInt ) -> Self :
120122 "Set minimal samples count before stopping criterion applies"
121123 ...
122124
@@ -153,13 +155,13 @@ class State:
153155 def get_int64 (self , name : str ) -> int :
154156 "Get value for given Int64 axis from this configuration"
155157 ...
156- def get_int64_or_default (self , name : str , default_value : int ) -> int :
158+ def get_int64_or_default (self , name : str , default_value : SupportsInt ) -> int :
157159 "Get value for given Int64 axis from this configuration"
158160 ...
159161 def get_float64 (self , name : str ) -> float :
160162 "Get value for given Float64 axis from this configuration"
161163 ...
162- def get_float64_or_default (self , name : str , default_value : float ) -> float :
164+ def get_float64_or_default (self , name : str , default_value : SupportsFloat ) -> float :
163165 "Get value for given Float64 axis from this configuration"
164166 ...
165167 def get_string (self , name : str ) -> str :
@@ -168,10 +170,12 @@ class State:
168170 def get_string_or_default (self , name : str , default_value : str ) -> str :
169171 "Get value for given String axis from this configuration"
170172 ...
171- def add_element_count (self , count : int , column_name : Optional [str ] = None ) -> None :
173+ def add_element_count (
174+ self , count : SupportsInt , column_name : Optional [str ] = None
175+ ) -> None :
172176 "Add element count"
173177 ...
174- def set_element_count (self , count : int ) -> None :
178+ def set_element_count (self , count : SupportsInt ) -> None :
175179 "Set element count"
176180 ...
177181 def get_element_count (self ) -> int :
@@ -186,10 +190,14 @@ class State:
186190 def get_skip_reason (self ) -> str :
187191 "Get reason provided for skipping this configuration"
188192 ...
189- def add_global_memory_reads (self , nbytes : int , / , column_name : str = "" ) -> None :
193+ def add_global_memory_reads (
194+ self , nbytes : SupportsInt , / , column_name : str = ""
195+ ) -> None :
190196 "Inform NVBench that given amount of bytes is being read by the benchmark from global memory"
191197 ...
192- def add_global_memory_writes (self , nbytes : int , / , column_name : str = "" ) -> None :
198+ def add_global_memory_writes (
199+ self , nbytes : SupportsInt , / , column_name : str = ""
200+ ) -> None :
193201 "Inform NVBench that given amount of bytes is being written by the benchmark into global memory"
194202 ...
195203 def get_benchmark (self ) -> Benchmark :
@@ -198,13 +206,13 @@ class State:
198206 def get_throttle_threshold (self ) -> float :
199207 "Get throttle threshold value, as fraction of maximal frequency"
200208 ...
201- def set_throttle_threshold (self , threshold_fraction : float ) -> None :
209+ def set_throttle_threshold (self , threshold_fraction : SupportsFloat ) -> None :
202210 "Set throttle threshold fraction to specified value, expected to be between 0 and 1"
203211 ...
204212 def get_min_samples (self ) -> int :
205213 "Get the number of benchmark timings NVBench performs before stopping criterion begins being used"
206214 ...
207- def set_min_samples (self , min_samples_count : int ) -> None :
215+ def set_min_samples (self , min_samples_count : SupportsInt ) -> None :
208216 "Set the number of benchmark timings for NVBench to perform before stopping criterion begins being used"
209217 ...
210218 def get_disable_blocking_kernel (self ) -> bool :
@@ -222,13 +230,13 @@ class State:
222230 def get_timeout (self ) -> float :
223231 "Get time-out value for benchmark execution of this configuration"
224232 ...
225- def set_timeout (self , duration : float ) -> None :
233+ def set_timeout (self , duration : SupportsFloat ) -> None :
226234 "Set time-out value for benchmark execution of this configuration, in seconds"
227235 ...
228236 def get_blocking_kernel_timeout (self ) -> float :
229237 "Get time-out value for execution of blocking kernel"
230238 ...
231- def set_blocking_kernel_timeout (self , duration : float ) -> None :
239+ def set_blocking_kernel_timeout (self , duration : SupportsFloat ) -> None :
232240 "Set time-out value for execution of blocking kernel, in seconds"
233241 ...
234242 def collect_cupti_metrics (self ) -> None :
@@ -265,7 +273,9 @@ class State:
265273 def get_short_description (self ) -> str :
266274 "Get short description for this configuration"
267275 ...
268- def add_summary (self , column_name : str , value : Union [int , float , str ]) -> None :
276+ def add_summary (
277+ self , column_name : str , value : Union [SupportsInt , SupportsFloat , str ]
278+ ) -> None :
269279 "Add summary column with a value"
270280 ...
271281 def get_axis_values (self ) -> dict [str , int | float | str ]:
0 commit comments