
Commit 7c34043

add packaging to aoti
1 parent 479b24b commit 7c34043

5 files changed: +241 −9 lines

.gitignore

+1
@@ -0,0 +1 @@
__pycache__/

_package_aoti.py

+192
@@ -0,0 +1,192 @@
import glob
import os
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch._inductor
import torch.utils._pytree as pytree
from torch.export._tree_utils import reorder_kwargs
from torch.export import ExportedProgram
from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact


from _pt2_archive_constants import (
    AOTINDUCTOR_DIR,
    ARCHIVE_ROOT_NAME,
    CONSTANTS_DIR,
    MODELS_FILENAME_FORMAT,
    SAMPLE_INPUTS_DIR,
    WEIGHTS_DIR,
)


ARCHIVE_VERSION = 0

class PT2ArchiveWriter:
    def __init__(self, archive_path: str):
        self.archive_file = torch._C.PyTorchFileWriter(archive_path)
        self.archive_file.set_min_version(ARCHIVE_VERSION)
        self.write_string("archive_format", "pt2")

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def write_bytes(self, name: str, data: bytes) -> None:
        assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
        self.archive_file.write_record(name, data, len(data))

    def write_string(self, name: str, data: str) -> None:
        assert isinstance(data, str), f"Expected string but got {type(data)}"
        data_bytes = data.encode()
        self.write_bytes(name, data_bytes)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert os.path.isfile(file_path), f"{file_path} is not a valid file path"

        with open(file_path, "rb") as f:
            file_bytes = f.read()
            self.write_bytes(name, file_bytes)

    def close(self) -> None:
        self.archive_file.write_end_of_file()

class PT2ArchiveReader:
    def __init__(self, archive_path: str):
        self.archive_file = torch._C.PyTorchFileReader(archive_path)
        assert self.read_string("archive_format") == "pt2", "Invalid archive format"

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # torch._C.PyTorchFileReader doesn't have a close method
        pass

    def read_bytes(self, name: str) -> bytes:
        return self.archive_file.get_record(name)

    def read_string(self, name: str) -> str:
        data = self.read_bytes(name)
        return data.decode()

    def get_file_names(self) -> List[str]:
        return self.archive_file.get_all_records()

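# A minimal round-trip sketch for the two classes above (hypothetical
# archive path and record name; not part of this commit):
#
#     with PT2ArchiveWriter("/tmp/example.pt2") as writer:
#         writer.write_string("extra/note.txt", "hello")
#
#     with PT2ArchiveReader("/tmp/example.pt2") as reader:
#         assert reader.read_string("extra/note.txt") == "hello"
#         print(reader.get_file_names())  # every record written to the archive
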
def _package_exported_program(
    archive_writer: PT2ArchiveWriter, exported_program: ExportedProgram
) -> None:
    exported_artifact: SerializedArtifact = serialize(exported_program)
    archive_writer.write_bytes(MODELS_FILENAME_FORMAT.format("model"), exported_artifact.exported_program)
    archive_writer.write_bytes(os.path.join(WEIGHTS_DIR, "weights.pt"), exported_artifact.state_dict)
    archive_writer.write_bytes(os.path.join(CONSTANTS_DIR, "constants.pt"), exported_artifact.constants)
    archive_writer.write_bytes(os.path.join(SAMPLE_INPUTS_DIR, "example_inputs.pt"), exported_artifact.example_inputs)

def _package_aoti_files(archive_writer: PT2ArchiveWriter, so_path: str):
    # AOT Inductor writes its generated sources and metadata next to the
    # compiled .so, sharing its basename.
    cpp_file_path = so_path[:-3] + ".cpp"
    extern_nodes_file_path = so_path[:-3] + ".json"
    work_dir = pathlib.Path(so_path).parent
    cubin_file_paths = glob.glob(f"{work_dir}/*.cubin")

    package_files = [so_path, cpp_file_path]
    package_files.extend(cubin_file_paths)

    if os.path.isfile(extern_nodes_file_path):
        package_files.append(extern_nodes_file_path)

    for path in package_files:
        filename = os.path.basename(path)
        archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path)

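# For a hypothetical so_path of "/tmp/model.so", the loop above would copy
# roughly these records into the archive (the .json entry only when AOT
# Inductor emitted extern nodes):
#
#     data/aotinductor/model.so     (compiled shared library)
#     data/aotinductor/model.cpp    (generated C++ source)
#     data/aotinductor/*.cubin      (compiled CUDA kernels, if any)
#     data/aotinductor/model.json   (extern nodes, only if present on disk)
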
def _extract_exported_program(archive_reader: PT2ArchiveReader) -> ExportedProgram:
    exported_program_bytes = archive_reader.read_bytes(MODELS_FILENAME_FORMAT.format("model"))
    state_dict_bytes = archive_reader.read_bytes(os.path.join(WEIGHTS_DIR, "weights.pt"))
    constants_bytes = archive_reader.read_bytes(os.path.join(CONSTANTS_DIR, "constants.pt"))
    example_inputs_bytes = archive_reader.read_bytes(os.path.join(SAMPLE_INPUTS_DIR, "example_inputs.pt"))

    artifact: SerializedArtifact = SerializedArtifact(
        exported_program_bytes,
        state_dict_bytes,
        constants_bytes,
        example_inputs_bytes,
    )

    deserialized_exported_program = deserialize(artifact)
    return deserialized_exported_program

def _extract_so(archive_reader: PT2ArchiveReader, device: str) -> Callable:
    tmp_output_dir = pathlib.Path("/tmp/aotinductor_loaded_model")
    tmp_output_dir.mkdir(exist_ok=True)

    file_names = archive_reader.get_file_names()
    aoti_files = [file for file in file_names if file.startswith(AOTINDUCTOR_DIR)]

    # Unpack every AOTInductor artifact to disk; exactly one of them must be
    # the compiled shared library.
    so_path = None
    for file in aoti_files:
        filename = os.path.basename(file)
        with open(tmp_output_dir / filename, "wb") as f:
            f.write(archive_reader.read_bytes(file))
        if file.endswith(".so"):
            assert so_path is None
            so_path = tmp_output_dir / filename
    assert so_path is not None
    so_path = str(so_path)

    if device == "cpu":
        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
    elif device == "cuda" or device.startswith("cuda:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
    else:
        raise RuntimeError("Unsupported device " + device)

    def optimized(*args, **kwargs):
        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return optimized

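# The `optimized` wrapper relies on torch.utils._pytree to translate between
# the caller's nested (args, kwargs) and the flat list the runner expects.
# A standalone sketch of that round trip, with made-up values:
#
#     inputs = ((1, 2), {"scale": 3.0})
#     flat, spec = pytree.tree_flatten(inputs)   # flat == [1, 2, 3.0]
#     serialized = pytree.treespec_dumps(spec)   # same form as get_call_spec()
#     assert pytree.tree_unflatten(flat, pytree.treespec_loads(serialized)) == inputs
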
def aoti_compile(
    exported_program: ExportedProgram,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    options: Optional[Dict[str, Any]] = None,
):
    assert options is not None, "options must include 'aot_inductor.output_path'"
    archive_path = options["aot_inductor.output_path"]
    # Clear the output path so aot_compile picks its own location for the .so;
    # the archive itself is written to archive_path below.
    options["aot_inductor.output_path"] = ""

    so_path = torch._inductor.aot_compile(
        exported_program.module(), args, kwargs, options=options
    )

    with PT2ArchiveWriter(archive_path) as archive_writer:
        # _package_exported_program(archive_writer, exported_program)
        _package_aoti_files(archive_writer, so_path)

    return archive_path

def aoti_load(path: str, device: str):
    with PT2ArchiveReader(path) as archive_reader:
        # exported_program = _extract_exported_program(archive_reader)
        optimized = _extract_so(archive_reader, device)

    return optimized
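
End to end, the two entry points compose as in this minimal sketch (hypothetical module and output path, assuming a CPU build; any file suffix works for the archive path):

import torch
from _package_aoti import aoti_compile, aoti_load

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

example_args = (torch.randn(4),)
ep = torch.export.export(M(), example_args)
package_path = aoti_compile(
    ep, example_args, options={"aot_inductor.output_path": "/tmp/m.pt2"}
)
optimized = aoti_load(package_path, "cpu")
print(optimized(*example_args))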

_pt2_archive_constants.py

+36
@@ -0,0 +1,36 @@
# This file codifies the PT2 Inference Archive Spec
# https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit?usp=sharing

# Naming convention:
# *_DIR: path to a folder, e.g. "data/aotinductor/"
# *_PATH: absolute path to a file, e.g. "models/merge.json"
# *_FORMAT: naming format of a file, e.g. "models/{}.json"

ARCHIVE_ROOT_NAME: str = "package"

# Archive format
ARCHIVE_FORMAT_PATH: str = "archive_format"

# Model definitions
MODELS_DIR: str = "models/"
MODELS_FILENAME_FORMAT: str = "models/{}.json"  # {model_name}

# AOTInductor artifacts
AOTINDUCTOR_DIR: str = "data/aotinductor/"

# Weights, including parameters and buffers
WEIGHTS_DIR: str = "data/weights/"
WEIGHT_FILENAME_PREFIX: str = "weight_"

# Constants, including tensor_constants, non-persistent buffers, and script objects
CONSTANTS_DIR: str = "data/constants/"
TENSOR_CONSTANT_FILENAME_PREFIX: str = "tensor_"
CUSTOM_OBJ_FILENAME_PREFIX: str = "custom_obj_"

# Sample inputs
SAMPLE_INPUTS_DIR: str = "data/sample_inputs/"
SAMPLE_INPUTS_FILENAME_FORMAT: str = "data/sample_inputs/{}.pt"  # {model_name}

# Extra folder
EXTRA_DIR: str = "extra/"
MODULE_INFO_PATH: str = "extra/module_info.json"
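
Taken together, these constants describe an archive that would lay out roughly as follows for a single model named "model" (illustrative; this commit only populates the aotinductor entries, since _package_exported_program is still commented out):

archive_format                   # "pt2"
models/model.json                # serialized ExportedProgram
data/aotinductor/model.so        # AOTInductor library plus generated sources
data/weights/weight_*            # parameters and buffers
data/constants/tensor_*          # tensor constants and custom objects
data/sample_inputs/model.pt      # example inputs
extra/module_info.json           # auxiliary metadata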

export_aoti.py

+9-7
@@ -19,6 +19,8 @@
 
 from model import Transformer
 
+from _package_aoti import aoti_compile
+
 default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -47,11 +49,11 @@ def export_model(model: nn.Module, device, output_path, args=None):
     # Specify that the first dimension of each input is the batch size
     dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
 
-    so = torch._export.aot_compile(
-        model,
-        args=input,
-        options={"aot_inductor.output_path": output_path},
-        dynamic_shapes=dynamic_shapes,
+    ep = torch.export.export(
+        model, args=input, dynamic_shapes=dynamic_shapes,
+    )
+    package_path = aoti_compile(
+        ep, input, options={"aot_inductor.output_path": output_path}
     )
-    print(f"The generated DSO model can be found at: {so}")
-    return so
+    print(f"The generated PT2 model can be found at: {package_path}")
+    return package_path

generate.py

+3-2
@@ -362,7 +362,8 @@ def main(
         # attributes will NOT be seen by the AOTI-compiled forward
         # function, e.g. calling model.setup_cache will NOT touch
         # AOTI compiled and maintained model buffers such as kv_cache.
-        model.forward = torch._export.aot_load(str(dso_path.absolute()), device)
+        from _package_aoti import aoti_load
+        model.forward = aoti_load(str(dso_path.absolute()), device)
     except:
         raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
 elif pte_path:
@@ -387,7 +388,7 @@ def main(
     # dtype:
     if model_dtype:
         model.to(dtype=model_dtype)
-
+
     if is_speculative:
         draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
     else:
