Skip to content

Commit

Permalink
bugfix: fix geneate_dispatch_inc args from parser (#870)
Browse files Browse the repository at this point in the history
head_dim args from setup.py and parser in
aot_build_utils/generate_dispatch_inc.py missmatch

Co-authored-by: baowending.bwd <[email protected]>
  • Loading branch information
baowendin and baowending.bwd authored Feb 18, 2025
1 parent 7e06dc0 commit 78dde79
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions aot_build_utils/generate_dispatch_inc.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
parser.add_argument(
"--path", type=str, required=True, help="Path to the dispatch inc file"
)
parser.add_argument(
"--head_dims_sm90", type=str, required=True, nargs="+", help="Head dimensions in format of 'head_dim_qk,head_dim_vo'",
)
parser.add_argument(
"--head_dims", type=int, required=True, nargs="+", help="Head dimensions"
)
Expand All @@ -124,6 +127,7 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
help="Mask modes",
)
args = parser.parse_args()
args.head_dims_sm90 = [tuple(map(int, x.split(","))) for x in args.head_dims_sm90]
print(args)
with open(Path(args.path), "w") as f:
f.write(get_dispatch_inc_str(args))

0 comments on commit 78dde79

Please sign in to comment.