@@ -234,18 +234,19 @@ cu12 = [
234234 " torch==2.12.0+cu126 ; sys_platform == 'linux'" ,
235235 " numba-cuda[cu12] ; platform_system != 'Darwin'" ,
236236 " cuda-python>=12,<13 ; platform_system != 'Darwin'" ,
237+ " transformer-engine[pytorch,core_cu12]==2.15"
237238]
238239
239240cu13 = [
240241 " torch==2.12.0+cu132 ; sys_platform == 'linux'" ,
241242 " numba-cuda[cu13] ; platform_system != 'Darwin'" ,
242243 " cuda-python>=13,<14 ; platform_system != 'Darwin'" ,
244+ " transformer-engine[pytorch,core_cu13]==2.15"
243245]
244246
245247compiled = [
246248 " onnx-ir==0.2.1" ,
247249 " onnxscript==0.7.0" ,
248- " transformer-engine==2.15" ,
249250 " deep_ep==1.2.1" ,
250251 " nv-grouped-gemm==1.1.4.post8" ,
251252 " causal-conv1d==1.6.2.post1" ,
@@ -258,7 +259,6 @@ compiled = [
258259compiled-a100 = [
259260 " onnx-ir==0.2.1" ,
260261 " onnxscript==0.7.0" ,
261- " transformer-engine==2.15" ,
262262 " nv-grouped-gemm==1.1.4.post8" ,
263263 " causal-conv1d==1.6.2.post1" ,
264264 " mamba-ssm==2.3.2.post1" ,
@@ -458,6 +458,7 @@ no-build-isolation-package = [
458458 " mamba-ssm" ,
459459 " nv-grouped-gemm" ,
460460 " transformer-engine" ,
461+ " transformer-engine-torch"
461462]
462463
463464# --- uv configuration ---
0 commit comments