|
23 | 23 | "tensorflow-inference-neuronx", |
24 | 24 | ] |
25 | 25 | DISPLAY_NAMES = { |
26 | | - "pytorch-inference-neuronx": "PyTorch Inference (Neuronx)", |
27 | | - "pytorch-training-neuronx": "PyTorch Training (Neuronx)", |
28 | | - "tensorflow-inference-neuronx": "TensorFlow Inference (Neuronx)", |
| 26 | + "pytorch-inference-neuronx": "PyTorch Inference (NeuronX)", |
| 27 | + "pytorch-training-neuronx": "PyTorch Training (NeuronX)", |
| 28 | + "tensorflow-inference-neuronx": "TensorFlow Inference (NeuronX)", |
29 | 29 | } |
30 | 30 | COLUMNS = ["Framework", "Python", "SDK", "Accelerator", "Platform", "Example URL"] |
31 | 31 |
|
32 | 32 |
|
33 | 33 | def parse_tag(tag: str) -> dict: |
34 | 34 | """Parse Neuron tag format: 2.8.0-neuronx-py311-sdk2.26.1-ubuntu22.04""" |
35 | | - result = {"version": "", "sdk": "", "python": ""} |
| 35 | + result = {"version": "", "sdk": "", "python": "", "accelerator": ""} |
36 | 36 |
|
37 | 37 | match = re.match( |
38 | 38 | r"^(\d+\.\d+\.\d+)-" # framework version |
39 | | - r"neuronx-" # neuronx marker |
| 39 | + r"(neuronx)-" # neuronx marker |
40 | 40 | r"(py\d+)-" # python |
41 | 41 | r"sdk([\d.]+)-", # sdk version |
42 | 42 | tag, |
43 | 43 | ) |
44 | 44 | if match: |
45 | 45 | result["version"] = match.group(1) |
46 | | - result["python"] = match.group(2) |
47 | | - result["sdk"] = match.group(3) |
| 46 | + result["accelerator"] = "NeuronX" if match.group(2) == "neuronx" else "-" |
| 47 | + result["python"] = match.group(3) |
| 48 | + result["sdk"] = match.group(4) |
48 | 49 |
|
49 | 50 | return result |
50 | 51 |
|
@@ -75,7 +76,7 @@ def generate(yaml_data: dict) -> str: |
75 | 76 | f"{framework_name} {parsed['version']}", |
76 | 77 | parsed["python"], |
77 | 78 | parsed["sdk"], |
78 | | - "Neuron", |
| 79 | + parsed["accelerator"], |
79 | 80 | "SageMaker", |
80 | 81 | build_ecr_url(repo_key, tag), |
81 | 82 | ] |
|
0 commit comments