|
22 | 22 |
|
23 | 23 | import os |
24 | 24 | import random |
| 25 | +import textwrap |
| 26 | + |
| 27 | +from tabulate import tabulate |
25 | 28 |
|
26 | 29 | from .rgg import get_randomizer_config_default |
27 | 30 |
|
@@ -298,3 +301,92 @@ def test_random_graph_algorithm( |
298 | 301 | graph_builder_type=RandomGraphAlgorithm, |
299 | 302 | framework=framework, |
300 | 303 | ) |
| 304 | + |
| 305 | + |
| 306 | +@dataclass |
| 307 | +class InfoColumn: |
| 308 | + name: str |
| 309 | + header: str |
| 310 | + width: Union[int, float] |
| 311 | + |
| 312 | + |
| 313 | +class InfoUtils: |
| 314 | + |
| 315 | + ALL_TESTS = RGGConfiguraionProvider.ALL_TESTS |
| 316 | + |
| 317 | + @classmethod |
| 318 | + def print_query_params(cls, max_width=80): |
| 319 | + print("Query parameters:") |
| 320 | + cls.print_query_values(max_width) |
| 321 | + print("Query examples:") |
| 322 | + cls.print_query_examples(max_width) |
| 323 | + |
| 324 | + @classmethod |
| 325 | + def print_query_values(cls, max_width=80): |
| 326 | + |
| 327 | + frameworks = [test.framework.template_name.upper() for test in cls.ALL_TESTS] |
| 328 | + frameworks = set(frameworks) |
| 329 | + frameworks = ", ".join(frameworks) |
| 330 | + |
| 331 | + test_names = [test.test_name.upper() for test in cls.ALL_TESTS] |
| 332 | + test_names = set(test_names) |
| 333 | + test_names = ", ".join(test_names) |
| 334 | + |
| 335 | + parameters = [ |
| 336 | + {"name": "FRAMEWORKS", "description": "List of frameworks.", "supported_values": f"{frameworks}", "default": ""}, |
| 337 | + {"name": "TEST_NAMES", "description": "List of test names.", "supported_values": f"{test_names}", "default": "DEFAULT"}, |
| 338 | + {"name": "RANDOM_TEST_SEED", "description": "Initial seed for RGG.", "supported_values": "", "default": "0"}, |
| 339 | + {"name": "RANDOM_TEST_COUNT", "description": "Number of random tests to be generated and executed.", "supported_values": "", "default": "5"}, |
| 340 | + {"name": "RANDOM_TESTS_SELECTED", "description": "Limiting random tests to only selected subset defined as comma separated list of test indexes.", "supported_values": "", "default": "no limitation if not specified or empty"}, |
| 341 | + {"name": "VERIFICATION_TIMEOUT", "description": "Limit time for inference verification in seconds.", "supported_values": "", "default": "60"}, |
| 342 | + {"name": "MIN_DIM", "description": "Minimal number of dimensions of input tensors.", "supported_values": "", "default": "3"}, |
| 343 | + {"name": "MAX_DIM", "description": "Maximum number of dimensions of input tensors.", "supported_values": "", "default": "4"}, |
| 344 | + {"name": "MIN_OP_SIZE_PER_DIM", "description": "Minimal size of an operand dimension.", "supported_values": "", "default": "16"}, |
| 345 | + {"name": "MAX_OP_SIZE_PER_DIM", "description": "Maximum size of an operand dimension. Smaller operand size results in fewer failed tests.", "supported_values": "", "default": "512"}, |
| 346 | + {"name": "OP_SIZE_QUANTIZATION", "description": "Quantization factor for operand size.", "supported_values": "", "default": "1"}, |
| 347 | + {"name": "MIN_MICROBATCH_SIZE", "description": "Minimal size of microbatch of an input tensor.", "supported_values": "", "default": "1"}, |
| 348 | + {"name": "MAX_MICROBATCH_SIZE", "description": "Maximum size of microbatch of an input tensor.", "supported_values": "", "default": "8"}, |
| 349 | + {"name": "NUM_OF_NODES_MIN", "description": "Minimal number of nodes to be generated by RGG.", "supported_values": "", "default": "5"}, |
| 350 | + {"name": "NUM_OF_NODES_MAX", "description": "Maximum number of nodes to be generated by RGG.", "supported_values": "", "default": "10"}, |
| 351 | + {"name": "NUM_OF_FORK_JOINS_MAX", "description": "Maximum number of fork joins to be generated by random graph algorithm in RGG.", "supported_values": "", "default": "50"}, |
| 352 | + {"name": "CONSTANT_INPUT_RATE", "description": "Rate of constant inputs in RGG in percents.", "supported_values": "", "default": "50"}, |
| 353 | + {"name": "SAME_INPUTS_PERCENT_LIMIT", "description": "Percent limit of nodes which have same value on multiple inputes.", "supported_values": "", "default": "10"}, |
| 354 | + ] |
| 355 | + |
| 356 | + cls.print_formatted_parameters(parameters, max_width, columns=[ |
| 357 | + InfoColumn("name", "Parameter", 25), |
| 358 | + InfoColumn("description", "Description", 0.6), |
| 359 | + InfoColumn("supported_values", "Supported values", 0.4), |
| 360 | + InfoColumn("default", "Default", 20), |
| 361 | + ]) |
| 362 | + |
| 363 | + @classmethod |
| 364 | + def print_query_examples(cls, max_width=80): |
| 365 | + |
| 366 | + parameters = [ |
| 367 | + {"name": "FRAMEWORKS", "example": "export FRAMEWORKS=FORGE"}, |
| 368 | + {"name": "TEST_NAMES", "example": "export TEST_NAMES=DEFAULT"}, |
| 369 | + {"name": "TEST_NAMES", "example": "export RANDOM_TEST_COUNT='3,4,6'"}, |
| 370 | + ] |
| 371 | + |
| 372 | + cls.print_formatted_parameters(parameters, max_width, columns=[ |
| 373 | + InfoColumn("name", "Parameter", 25), |
| 374 | + InfoColumn("example", "Examples", 0.8), |
| 375 | + ]) |
| 376 | + |
| 377 | + @classmethod |
| 378 | + def print_formatted_parameters(cls, parameters, max_width: int, columns: List[InfoColumn]): |
| 379 | + |
| 380 | + fixed_width = sum([col.width for col in columns if isinstance(col.width, int)]) |
| 381 | + for col in columns: |
| 382 | + if isinstance(col.width, float): |
| 383 | + col.width = int((max_width - fixed_width) * col.width) |
| 384 | + |
| 385 | + for param in parameters: |
| 386 | + for col in columns: |
| 387 | + param[col.name] = "\n".join(textwrap.wrap(param[col.name], width=col.width)) |
| 388 | + |
| 389 | + table_data = [[param[column.name] for column in columns] for param in parameters] |
| 390 | + |
| 391 | + headers = [column.header for column in columns] |
| 392 | + print(tabulate(table_data, headers, tablefmt="grid")) |
0 commit comments