@@ -117,12 +117,20 @@ def run(self):
117117 return
118118 precmd = " "
119119 paralleJobs = MachineCPUCount
120+ user_tests = os.environ.get(" IB_TEST_TYPE" , " " )
121+ if self.xType == " GPU" :
122+ user_tests = " cuda"
123+ elif self.xType == " ROCM" :
124+ user_tests = " rocm"
120125 if (" ASAN" in os.environ[" CMSSW_VERSION" ]) or (" UBSAN" in os.environ[" CMSSW_VERSION" ]):
121126 paralleJobs = int(MachineCPUCount / 2)
122- if (self.xType == " GPU" ) or (" _GPU_X" in os.environ[" CMSSW_VERSION" ]):
123- precmd = " export USER_UNIT_TESTS=cuda ; "
124- if (self.xType == " ROCM" ) or (" _ROCM_X" in os.environ[" CMSSW_VERSION" ]):
125- precmd = " export USER_UNIT_TESTS=rocm ; "
127+ if user_tests == " " :
128+ if " _GPU_X" in os.environ[" CMSSW_VERSION" ]:
129+ user_tests = " cuda"
130+ elif " _ROCM_X" in os.environ[" CMSSW_VERSION" ]:
131+ user_tests = " rocm"
132+ if user_tests != " " :
133+ precmd = " export USER_UNIT_TESTS=%s ; " % user_tests
126134 skiptests = " "
127135 if " lxplus" in getHostName():
128136 skiptests = " SKIP_UNITTESTS=ExpressionEvaluatorUnitTest"
@@ -432,6 +440,10 @@ def doTest(self, only=None):
432440 print(" \n " + 80 * " -" + " gpu_unit \n " )
433441 self.threadList[" gpu_unit" ] = self.runUnitTests([], " GPU" )
434442
443+ if only and " rocm_unit" in only:
444+ print(" \n " + 80 * " -" + " rocm_unit \n " )
445+ self.threadList[" rocm_unit" ] = self.runUnitTests([], " ROCM" )
446+
435447 if not only or " codeRules" in only:
436448 print(" \n " + 80 * " -" + " codeRules \n " )
437449 self.threadList[" codeRules" ] = self.runCodeRulesChecker()
0 commit comments