Skip to content

Commit 6382fe5

Browse files
committed
fix neuronx key
Signed-off-by: sirutBuasai <sirutbuasai27@outlook.com>
1 parent 21bb087 commit 6382fe5

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

docs/src/tables/huggingface_pytorch_table.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
DISPLAY_NAMES = {
2828
"huggingface-pytorch-training": "HuggingFace PyTorch Training",
2929
"huggingface-pytorch-inference": "HuggingFace PyTorch Inference",
30-
"huggingface-pytorch-inference-neuronx": "HuggingFace PyTorch Inference (Neuronx)",
31-
"huggingface-pytorch-training-neuronx": "HuggingFace PyTorch Training (Neuronx)",
30+
"huggingface-pytorch-inference-neuronx": "HuggingFace PyTorch Inference (NeuronX)",
31+
"huggingface-pytorch-training-neuronx": "HuggingFace PyTorch Training (NeuronX)",
3232
"huggingface-pytorch-trcomp-training": "HuggingFace PyTorch Training Compiler",
3333
}
3434
COLUMNS_STANDARD = [
@@ -74,7 +74,7 @@ def parse_tag(tag: str) -> dict:
7474
result["version"] = match.group(1)
7575
result["transformers"] = match.group(2)
7676
accel = match.group(3)
77-
result["accelerator"] = "Neuron" if accel == "neuronx" else accel.upper()
77+
result["accelerator"] = "NeuronX" if accel == "neuronx" else accel.upper()
7878
result["python"] = match.group(4)
7979
extra = match.group(5) or ""
8080
if extra.startswith("cu"):
@@ -96,7 +96,7 @@ def generate(yaml_data: dict) -> str:
9696
if not tags:
9797
continue
9898

99-
is_neuron = "neuronx" in repo_key
99+
is_neuron = "neuron" in repo_key
100100
columns = COLUMNS_NEURON if is_neuron else COLUMNS_STANDARD
101101

102102
rows = []

docs/src/tables/neuron_table.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,29 @@
2323
"tensorflow-inference-neuronx",
2424
]
2525
DISPLAY_NAMES = {
26-
"pytorch-inference-neuronx": "PyTorch Inference (Neuronx)",
27-
"pytorch-training-neuronx": "PyTorch Training (Neuronx)",
28-
"tensorflow-inference-neuronx": "TensorFlow Inference (Neuronx)",
26+
"pytorch-inference-neuronx": "PyTorch Inference (NeuronX)",
27+
"pytorch-training-neuronx": "PyTorch Training (NeuronX)",
28+
"tensorflow-inference-neuronx": "TensorFlow Inference (NeuronX)",
2929
}
3030
COLUMNS = ["Framework", "Python", "SDK", "Accelerator", "Platform", "Example URL"]
3131

3232

3333
def parse_tag(tag: str) -> dict:
3434
"""Parse Neuron tag format: 2.8.0-neuronx-py311-sdk2.26.1-ubuntu22.04"""
35-
result = {"version": "", "sdk": "", "python": ""}
35+
result = {"version": "", "sdk": "", "python": "", "accelerator": ""}
3636

3737
match = re.match(
3838
r"^(\d+\.\d+\.\d+)-" # framework version
39-
r"neuronx-" # neuronx marker
39+
r"(neuronx)-" # neuronx marker
4040
r"(py\d+)-" # python
4141
r"sdk([\d.]+)-", # sdk version
4242
tag,
4343
)
4444
if match:
4545
result["version"] = match.group(1)
46-
result["python"] = match.group(2)
47-
result["sdk"] = match.group(3)
46+
result["accelerator"] = "NeuronX" if match.group(2) == "neuronx" else "-"
47+
result["python"] = match.group(3)
48+
result["sdk"] = match.group(4)
4849

4950
return result
5051

@@ -75,7 +76,7 @@ def generate(yaml_data: dict) -> str:
7576
f"{framework_name} {parsed['version']}",
7677
parsed["python"],
7778
parsed["sdk"],
78-
"Neuron",
79+
parsed["accelerator"],
7980
"SageMaker",
8081
build_ecr_url(repo_key, tag),
8182
]

0 commit comments

Comments
 (0)