diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 38d0e33b2..7de72576f 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -571,9 +571,8 @@ def do_nothing(max_batch_size, max_seq_length):
         # attributes will NOT be seen on by AOTI-compiled forward
         # function, e.g. calling model.setup_cache will NOT touch
         # AOTI compiled and maintained model buffers such as kv_cache.
-        from torch._inductor.package import load_package

-        aoti_compiled_model = load_package(
+        aoti_compiled_model = torch._inductor.aoti_load_package(
             str(builder_args.aoti_package_path.absolute())
         )

diff --git a/torchchat/export.py b/torchchat/export.py
index 979778b7c..cf817e476 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -8,10 +8,10 @@
 from typing import Optional

 import torch
+import torch._inductor
 import torch.nn as nn
 from torch.export import Dim

-import torch._inductor

 from torchchat.cli.builder import (
     _initialize_model,
@@ -68,20 +68,24 @@ def export_for_server(

     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
         metadata = {}  # TODO: put more metadata here
-        options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
+        options = {"aot_inductor.metadata": metadata}
         if not package:
             options = {"aot_inductor.output_path": output_path}

-        path = torch._export.aot_compile(
+        ep = torch.export.export(
             model,
             example_inputs,
             dynamic_shapes=dynamic_shapes,
-            options=options,
         )

         if package:
-            from torch._inductor.package import package_aoti
-            path = package_aoti(output_path, path)
+            path = torch._inductor.aoti_compile_and_package(
+                ep, package_path=output_path, inductor_configs=options
+            )
+        else:
+            path = torch._inductor.aot_compile(
+                ep.module(), example_inputs, options=options
+            )

     print(f"The generated packaged model can be found at: {path}")
     return path
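
For reviewers who want to try the new export-side flow outside torchchat, a minimal standalone sketch: `torch.export.export` produces an ExportedProgram, and `torch._inductor.aoti_compile_and_package` compiles it and writes a .pt2 package. The `Toy` module and the `toy.pt2` path are illustrative stand-ins, not torchchat code.

```python
import torch
import torch._inductor

# Illustrative stand-in for the torchchat model (hypothetical).
class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

example_inputs = (torch.randn(4, 8),)

# Same two-step flow as the new export.py code: export to an
# ExportedProgram, then compile and package it into a .pt2 archive.
ep = torch.export.export(Toy(), example_inputs)
path = torch._inductor.aoti_compile_and_package(ep, package_path="toy.pt2")
print(path)  # aoti_compile_and_package returns the path it wrote
```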
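
And the consumer side matching the builder.py hunk, again a sketch using the same illustrative `toy.pt2` path. `torch._inductor.aoti_load_package` returns a callable, so the separate import from `torch._inductor.package` is no longer needed.

```python
import torch
import torch._inductor

# Load the packaged artifact and call it like a module; this mirrors the
# builder.py switch from load_package to aoti_load_package.
compiled = torch._inductor.aoti_load_package("toy.pt2")
out = compiled(torch.randn(4, 8))
print(out.shape)
```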