Skip to content

Commit c6bc478

Browse files
authored
Add ExtendedTestBase shared base class for benchmark and functional tests (#3689)
### Summary Extracts shared test infrastructure from `BenchmarkBase` into a new `ExtendedTestBase` base class, reducing code duplication between benchmark and functional tests. This addresses the reviewer feedback on #3648 about sharing common logic between `benchmark_base.py` and `functional_base.py`. ### Changes #### New: `utils/extended_test_base.py` - Created the `ExtendedTestBase` class with the following shared infrastructure: execute_command(), get_gpu_architecture(), detect_gpu_count(), create_test_result(), calculate_statistics(), upload_results() #### Updated: `benchmark/scripts/benchmark_base.py` - BenchmarkBase now inherits from ExtendedTestBase - Removed duplicated methods: execute_command, _detect_gpu_count, calculate_statistics, upload_results #### Updated: `benchmark/scripts/test_rccl_benchmark.py` - self._detect_gpu_count() → self.detect_gpu_count() (now inherited from base) #### Updated: `utils/__init__.py`, `README.md`, `utils/README.md` - Added ExtendedTestBase to exports and documentation #### Inheritance Hierarchy ``` ExtendedTestBase (utils/extended_test_base.py) ├── BenchmarkBase (benchmark/scripts/benchmark_base.py) │ ├── ROCfftBenchmark, RCCLBenchmark, ROCblasBenchmark, ... └── FunctionalBase (functional/scripts/functional_base.py) ← will inherit in follow-up PR ├── MIOpenDriverConv, RcclTestInfra, ... ``` #### Follow-up - Update FunctionalBase to inherit from ExtendedTestBase (after this PR merges) --------- Signed-off-by: Lenine Ajagappane <Lenine.Ajagappane@amd.com>
1 parent 864f3d3 commit c6bc478

6 files changed

Lines changed: 271 additions & 187 deletions

File tree

tests/extended_tests/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,14 @@ extended_tests/
8282
│ └── README.md # Functional-specific docs (placeholder - tests to be added in follow-up PRs)
8383
8484
└── utils/ # SHARED utilities for all test types
85+
├── extended_test_base.py # ExtendedTestBase - shared base class for all tests
86+
├── extended_test_client.py # ExtendedTestClient - system detection & result reporting
8587
├── exceptions.py # Custom exception classes
8688
│ ├── BenchmarkExecutionError # Execution/parsing failures
8789
│ ├── BenchmarkResultError # Result validation failures
8890
│ └── FrameworkException # Base exception
8991
9092
├── logger.py # Logging utilities
91-
├── extended_test_client.py # ExtendedTestClient API
9293
├── constants.py # Global constants
9394
9495
├── config/ # Configuration parsers

tests/extended_tests/benchmark/scripts/benchmark_base.py

Lines changed: 39 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,26 @@
33

44
"""Base class for benchmark tests with common functionality."""
55

6-
import os
7-
import shlex
86
import shutil
9-
import subprocess
107
import sys
118
from pathlib import Path
12-
from typing import Dict, List, Tuple, Any, IO
9+
from typing import Dict, List, Any
1310
from prettytable import PrettyTable
1411

1512
# Add parent directory to path for utils import
1613
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
17-
# Add build_tools/github_actions to path for github_actions_utils
18-
sys.path.insert(
19-
0, str(Path(__file__).resolve().parents[4] / "build_tools" / "github_actions")
20-
)
21-
from utils import ExtendedTestClient, HardwareDetector
14+
2215
from utils.logger import log
2316
from utils.exceptions import TestExecutionError
24-
from github_actions_utils import gha_append_step_summary
17+
from utils.extended_test_base import ExtendedTestBase, gha_append_step_summary
2518

2619

27-
class BenchmarkBase:
20+
class BenchmarkBase(ExtendedTestBase):
2821
"""Base class providing common benchmark logic.
2922
23+
Inherits shared infrastructure from ExtendedTestBase (execute_command,
24+
create_test_result, calculate_statistics, upload_results, etc.).
25+
3026
Child classes must implement run_benchmarks() and parse_results().
3127
"""
3228

@@ -37,88 +33,9 @@ def __init__(self, benchmark_name: str, display_name: str = None):
3733
benchmark_name: Internal benchmark name (e.g., 'rocfft')
3834
display_name: Display name for reports (e.g., 'ROCfft'), defaults to benchmark_name
3935
"""
36+
super().__init__(benchmark_name, display_name or benchmark_name.upper())
4037
self.benchmark_name = benchmark_name
41-
self.display_name = display_name or benchmark_name.upper()
42-
43-
# Environment variables
44-
self.therock_bin_dir = os.getenv("THEROCK_BIN_DIR")
45-
self.artifact_run_id = os.getenv("ARTIFACT_RUN_ID")
46-
self.amdgpu_families = os.getenv("AMDGPU_FAMILIES")
4738
self.script_dir = Path(__file__).resolve().parent
48-
self.therock_dir = Path(__file__).resolve().parents[4]
49-
50-
# Initialize test client (will be set in run())
51-
self.client = None
52-
53-
def execute_command(
54-
self, cmd: List[str], log_file_handle: IO, env: Dict[str, str] = None
55-
) -> int:
56-
"""Execute a command and stream output to log file.
57-
58-
Args:
59-
cmd: Command list to execute
60-
log_file_handle: File handle to write output
61-
env: Optional environment variables to set
62-
63-
Returns:
64-
Exit code from the command
65-
"""
66-
log.info(f"++ Exec [{self.therock_dir}]$ {shlex.join(cmd)}")
67-
log_file_handle.write(f"{shlex.join(cmd)}\n")
68-
69-
# Merge custom env with current environment
70-
process_env = os.environ.copy()
71-
if env:
72-
process_env.update(env)
73-
74-
process = subprocess.Popen(
75-
cmd,
76-
cwd=self.therock_dir,
77-
stdout=subprocess.PIPE,
78-
stderr=subprocess.STDOUT,
79-
text=True,
80-
bufsize=1,
81-
env=process_env,
82-
)
83-
84-
for line in process.stdout:
85-
log.info(line.strip())
86-
log_file_handle.write(f"{line}")
87-
88-
process.wait()
89-
return process.returncode
90-
91-
def _detect_gpu_count(self) -> int:
92-
"""Detect the number of available GPUs using HardwareDetector.
93-
94-
Returns:
95-
Number of GPUs detected
96-
97-
Raises:
98-
RuntimeError: If no GPUs detected or detection fails
99-
"""
100-
try:
101-
detector = HardwareDetector()
102-
gpu_list = detector.detect_gpu()
103-
gpu_count = len(gpu_list)
104-
105-
if gpu_count == 0:
106-
raise RuntimeError(
107-
"No GPUs detected. Benchmarks require at least one GPU. "
108-
"Ensure ROCm drivers are installed and GPU devices are accessible."
109-
)
110-
111-
log.info(f"Detected {gpu_count} GPU(s)")
112-
return gpu_count
113-
114-
except RuntimeError:
115-
# Re-raise RuntimeError as-is
116-
raise
117-
except Exception as e:
118-
raise RuntimeError(
119-
f"Failed to detect GPUs: {e}. "
120-
"Ensure ROCm drivers are installed and GPU devices are accessible."
121-
) from e
12239

12340
def _validate_openmpi(self) -> None:
12441
"""Check if OpenMPI is installed and available in the system.
@@ -143,7 +60,11 @@ def create_test_result(
14360
flag: str,
14461
**kwargs,
14562
) -> Dict[str, Any]:
146-
"""Create a standardized test result dictionary.
63+
"""Create a standardized benchmark test result dictionary.
64+
65+
Overrides ExtendedTestBase.create_test_result to enforce benchmark-specific
66+
required fields (score, unit, flag) and provide defaults for
67+
batch_size and ngpu.
14768
14869
Args:
14970
test_name: Benchmark name
@@ -157,91 +78,22 @@ def create_test_result(
15778
Returns:
15879
Dict[str, Any]: Test result dictionary with test data and configuration
15980
"""
160-
# Extract common parameters with defaults
161-
batch_size = kwargs.get("batch_size", 0)
162-
ngpu = kwargs.get("ngpu", 1)
163-
164-
# Build test config with all parameters
165-
test_config = {
166-
"test_name": test_name,
167-
"sub_test_name": subtest_name,
168-
"python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
169-
"environment_dependencies": [],
170-
"batch_size": batch_size,
171-
"ngpu": ngpu,
172-
}
173-
174-
# Add any additional kwargs to test_config
175-
for key, value in kwargs.items():
176-
if key not in ["batch_size", "ngpu"]:
177-
test_config[key] = value
178-
179-
return {
180-
"test_name": test_name,
181-
"subtest": subtest_name,
182-
"batch_size": batch_size,
183-
"ngpu": ngpu,
184-
"status": status,
185-
"score": float(score),
186-
"unit": unit,
187-
"flag": flag,
188-
"test_config": test_config,
189-
}
190-
191-
def calculate_statistics(
192-
self, test_results: List[Dict[str, Any]]
193-
) -> Dict[str, Any]:
194-
"""Calculate test statistics from results.
195-
196-
Args:
197-
test_results: List of test result dictionaries with 'status' key
198-
199-
Returns:
200-
Dictionary with:
201-
- passed: Number of passed tests
202-
- failed: Number of failed tests
203-
- total: Total number of tests
204-
- overall_status: 'PASS' if no failures, else 'FAIL'
205-
"""
206-
passed = sum(1 for r in test_results if r.get("status") == "PASS")
207-
failed = sum(1 for r in test_results if r.get("status") == "FAIL")
208-
overall_status = "PASS" if failed == 0 else "FAIL"
209-
210-
return {
211-
"passed": passed,
212-
"failed": failed,
213-
"total": len(test_results),
214-
"overall_status": overall_status,
215-
}
216-
217-
def upload_results(
218-
self, test_results: List[Dict[str, Any]], stats: Dict[str, Any]
219-
) -> bool:
220-
"""Upload results to API and save locally."""
221-
log.info("Uploading Results to API")
222-
success = self.client.upload_results(
223-
test_name=f"{self.benchmark_name}_benchmark",
224-
test_results=test_results,
225-
test_status=stats["overall_status"],
226-
test_metadata={
227-
"artifact_run_id": self.artifact_run_id,
228-
"amdgpu_families": self.amdgpu_families,
229-
"benchmark_name": self.benchmark_name,
230-
"total_subtests": stats["total"],
231-
"passed_subtests": stats["passed"],
232-
"failed_subtests": stats["failed"],
233-
},
234-
save_local=True,
235-
output_dir=str(self.script_dir / "results"),
81+
# Extract benchmark-specific parameters with defaults
82+
batch_size = kwargs.pop("batch_size", 0)
83+
ngpu = kwargs.pop("ngpu", 1)
84+
85+
return super().create_test_result(
86+
test_name=test_name,
87+
subtest_name=subtest_name,
88+
status=status,
89+
score=float(score),
90+
unit=unit,
91+
flag=flag,
92+
batch_size=batch_size,
93+
ngpu=ngpu,
94+
**kwargs,
23695
)
23796

238-
if success:
239-
log.info("Results uploaded successfully")
240-
else:
241-
log.info("Results saved locally only (API upload disabled or failed)")
242-
243-
return success
244-
24597
def compare_with_lkg(self, tables: Any) -> Any:
24698
"""Compare results with Last Known Good baseline."""
24799
log.info("Comparing results with LKG")
@@ -319,10 +171,6 @@ def run(self) -> int:
319171
"""Execute benchmark workflow and return exit code (0=PASS, 1=FAIL)."""
320172
log.info(f"Initializing {self.display_name} Benchmark Test")
321173

322-
# Initialize extended test client and print system info
323-
self.client = ExtendedTestClient(auto_detect=True)
324-
self.client.print_system_summary()
325-
326174
# Run benchmarks (implemented by child class)
327175
self.run_benchmarks()
328176

@@ -338,7 +186,18 @@ def run(self) -> int:
338186
log.info(f"Test Summary: {stats['passed']} passed, {stats['failed']} failed")
339187

340188
# Upload results
341-
self.upload_results(test_results, stats)
189+
self.upload_results(
190+
test_results=test_results,
191+
stats=stats,
192+
test_type="benchmark",
193+
output_dir=str(self.script_dir / "results"),
194+
extra_metadata={
195+
"benchmark_name": self.benchmark_name,
196+
"total_subtests": stats["total"],
197+
"passed_subtests": stats["passed"],
198+
"failed_subtests": stats["failed"],
199+
},
200+
)
342201

343202
# Compare with LKG (compares each table individually and prints results)
344203
final_tables = self.compare_with_lkg(tables)

tests/extended_tests/benchmark/scripts/test_rccl_benchmark.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
sys.path.insert(0, str(Path(__file__).resolve().parents[2])) # For extended_tests/utils
1818
sys.path.insert(0, str(Path(__file__).parent)) # For benchmark_base
1919
from benchmark_base import BenchmarkBase, run_benchmark_main
20+
from github_actions_utils import get_visible_gpu_count
2021
from utils.logger import log
2122

2223

@@ -26,7 +27,7 @@ class RCCLBenchmark(BenchmarkBase):
2627
def __init__(self):
2728
super().__init__(benchmark_name="rccl", display_name="RCCL")
2829
self.log_file = self.script_dir / "rccl_bench.log"
29-
self.ngpu = self._detect_gpu_count()
30+
self.ngpu = get_visible_gpu_count(therock_bin_dir=self.therock_bin_dir)
3031

3132
# Validate OpenMPI is available (from base class)
3233
self._validate_openmpi()

tests/extended_tests/utils/README.md

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ Utility modules organized into logical subdirectories for maintainability and sc
77
```
88
extended_tests/utils/
99
├── __init__.py # Public exports
10-
├── extended_test_client.py # Main ExtendedTestClient API
10+
├── extended_test_base.py # ExtendedTestBase - shared base class for all tests
11+
├── extended_test_client.py # ExtendedTestClient - system detection & result reporting
1112
├── constants.py # Framework constants
1213
├── exceptions.py # Custom exceptions
1314
├── logger.py # Logging configuration
@@ -36,12 +37,28 @@ extended_tests/utils/
3637

3738
## Usage
3839

39-
### From Benchmark Scripts
40+
### From Extended Test Base Classes
4041

41-
Benchmark scripts add `extended_tests/` to `sys.path`, then import:
42+
Both `BenchmarkBase` and `FunctionalBase` inherit from `ExtendedTestBase`, which provides
43+
shared infrastructure (command execution, GPU detection, result creation, statistics, uploads):
4244

4345
```python
44-
# Import path setup (already done in benchmark_base.py)
46+
# In benchmark_base.py / functional_base.py
47+
from utils.extended_test_base import ExtendedTestBase
48+
49+
50+
class BenchmarkBase(ExtendedTestBase): ...
51+
52+
53+
class FunctionalBase(ExtendedTestBase): ...
54+
```
55+
56+
### From Test Scripts
57+
58+
Test scripts add `extended_tests/` to `sys.path`, then import:
59+
60+
```python
61+
# Import path setup (already done in base classes)
4562
sys.path.insert(
4663
0, str(Path(__file__).resolve().parents[2])
4764
) # Adds extended_tests/ to path
@@ -52,6 +69,7 @@ from utils.constants import Constants
5269
from utils.exceptions import ConfigurationError
5370

5471
# Main API classes
72+
from utils.extended_test_base import ExtendedTestBase
5573
from utils.extended_test_client import ExtendedTestClient
5674
from utils.system.system_detector import SystemDetector
5775
from utils.config.config_helper import ConfigHelper
@@ -80,10 +98,11 @@ from utils.results import ResultsHandler, ResultsAPI
8098

8199
### Root Level
82100

101+
- **extended_test_base.py** - `ExtendedTestBase` shared base class for benchmark and functional tests (command execution, GPU detection, test result creation, statistics, result uploads)
102+
- **extended_test_client.py** - `ExtendedTestClient` API for system detection and result reporting
83103
- **constants.py** - Framework constants and defaults
84104
- **exceptions.py** - Custom exception classes
85105
- **logger.py** - Logging configuration
86-
- **extended_test_client.py** - Main ExtendedTestClient API
87106

88107
### Config
89108

tests/extended_tests/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
"exceptions",
2222
# Main API
2323
"ExtendedTestClient",
24+
# Shared test base class
25+
"ExtendedTestBase",
2426
# Commonly used exports
2527
"SystemContext",
2628
"SystemDetector",

0 commit comments

Comments (0)