add triton_group_norm#4

Open
nishirong wants to merge 19 commits into zhoutianzi666:jieru_triton from nishirong:group_norm_aot
Conversation

@nishirong
PR types

PR changes

Describe

Implement Triton group norm.


tune_and_invoke_part_with_two_kernels = tune_and_invoke_part.replace("${op_name}", "${first_kernel_name}").replace("run_triton_kernel", "run_triton_first_kernel") + tune_and_invoke_part.replace("${op_name}", "${second_kernel_name}").replace("run_triton_kernel", "run_triton_second_kernel")

# tune_and_invoke_part_with_two_kernels = """
Owner

If the content below really can be generated by the statement above, please delete all of the commented-out statements.
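One way to back up deleting the commented-out copy is to check that the derived string is identical to the handwritten one. The sketch below is only an illustration: the body of `tune_and_invoke_part` is not shown in this thread, so the template text (including the `lookup` call), the `handwritten` string, and the `derive` helper are hypothetical stand-ins. Only the two `.replace(...)` chains come from the diff.

```python
# Hypothetical stand-in for tune_and_invoke_part; the real template body
# is not shown in this review thread.
tune_and_invoke_part = """
auto ${op_name}_handle = lookup("${op_name}");
run_triton_kernel(${op_name}_handle);
"""


def derive(template, kernel_placeholder, runner_name):
    """Apply the same two replacements used in the diff (helper name is hypothetical)."""
    return (template
            .replace("${op_name}", kernel_placeholder)
            .replace("run_triton_kernel", runner_name))


# Same construction as in the diff: first-kernel copy followed by second-kernel copy.
tune_and_invoke_part_with_two_kernels = (
    derive(tune_and_invoke_part, "${first_kernel_name}", "run_triton_first_kernel")
    + derive(tune_and_invoke_part, "${second_kernel_name}", "run_triton_second_kernel")
)

# Hypothetical handwritten version, standing in for the commented-out block.
handwritten = """
auto ${first_kernel_name}_handle = lookup("${first_kernel_name}");
run_triton_first_kernel(${first_kernel_name}_handle);

auto ${second_kernel_name}_handle = lookup("${second_kernel_name}");
run_triton_second_kernel(${second_kernel_name}_handle);
"""

# If this is True for the real strings, the commented-out copy is redundant.
templates_match = tune_and_invoke_part_with_two_kernels == handwritten
```

Running the same comparison once against the real template and the real commented-out text would confirm whether the comment block can be removed safely.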

Comment thread python/paddle/incubate/tt/triton_ops.py Outdated
offset_block = tl.arange(0, BLOCK_SIZE_M)
data_start = batch_id * batch_stride + group_id * group_stride
sample_ptrs = sample_ptr + data_start + offset_channel[:, None] * channel_stride + offset_block[None, :] * hw_stride + block_id * BLOCK_SIZE_M
# 计算均值
Owner

Change this comment to English.
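For readers following the pointer arithmetic in the snippet above, here is a plain-Python sketch of the 2-D offset grid that `sample_ptrs` addresses. It mirrors the quoted expression term by term. The function name `tile_offsets`, the parameter `channels_per_group`, and the assumption that `offset_channel` ranges over the channels of one group are not from the diff. Note that the quoted line adds `block_id * BLOCK_SIZE_M` without multiplying by `hw_stride`, so the sketch assumes `hw_stride == 1` (a contiguous NCHW layout).

```python
def tile_offsets(batch_id, group_id, block_id,
                 batch_stride, group_stride, channel_stride, hw_stride,
                 channels_per_group, block_size_m):
    """Element offsets (relative to sample_ptr) for one tile, as nested lists.

    Mirrors: data_start + offset_channel[:, None] * channel_stride
             + offset_block[None, :] * hw_stride + block_id * BLOCK_SIZE_M
    """
    data_start = batch_id * batch_stride + group_id * group_stride
    offsets = []
    for c in range(channels_per_group):        # plays the role of offset_channel[:, None]
        row = []
        for m in range(block_size_m):          # plays the role of offset_block[None, :]
            row.append(data_start
                       + c * channel_stride
                       + m * hw_stride
                       + block_id * block_size_m)
        offsets.append(row)
    return offsets


# Assumed example: contiguous NCHW tensor with N=2, C=4, H*W=8, 2 groups,
# so channels_per_group=2, hw_stride=1, channel_stride=8,
# group_stride=16, batch_stride=32. Tile: batch 1, group 1, block 1 of width 4.
example = tile_offsets(1, 1, 1, 32, 16, 8, 1, 2, 4)
# example == [[52, 53, 54, 55], [60, 61, 62, 63]]
```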

Comment thread python/paddle/incubate/tt/triton_ops.py Outdated
_sum = tl.sum(sample)

_sum_squares = tl.sum(sample * sample)
output_start = batch_id * group_num + group_id + tl.arange(0,1)
Owner

Are you sure `tl.arange(0,1)` needs to be added here?
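As background for the question above, here is a plain-Python reference of what the first kernel appears to compute. It is a sketch under assumptions: the thread only shows `_sum` and `_sum_squares`, so turning them into a mean and a variance via Var = E[x²] − E[x]² is assumed, and the function name `group_stats` and the nested-list layout `[N][C][HW]` are made up for illustration.

```python
def group_stats(x, group_num):
    """Per-(batch, group) mean and variance for x given as nested lists [N][C][HW].

    Results are written to flat lists indexed by batch_id * group_num + group_id,
    matching the output_start expression in the diff (without tl.arange(0, 1)).
    """
    n = len(x)
    channels = len(x[0])
    channels_per_group = channels // group_num
    mean = [0.0] * (n * group_num)
    var = [0.0] * (n * group_num)
    for batch_id in range(n):
        for group_id in range(group_num):
            first = group_id * channels_per_group
            group_channels = x[batch_id][first:first + channels_per_group]
            values = [v for channel in group_channels for v in channel]
            _sum = sum(values)
            _sum_squares = sum(v * v for v in values)
            # One scalar slot per (batch, group): the index is a single integer.
            output_start = batch_id * group_num + group_id
            mean[output_start] = _sum / len(values)
            # Assumed finalization: Var = E[x^2] - E[x]^2
            var[output_start] = _sum_squares / len(values) - mean[output_start] ** 2
    return mean, var


# N=1, C=4, H*W=2, two groups of two channels each.
x = [[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [2.0, 2.0]]]
mean, var = group_stats(x, group_num=2)
# mean == [2.5, 1.0], var == [1.25, 1.0]
```

Each (batch, group) pair writes exactly one value, so the output index is a scalar. Triton's `tl.store` accepts a scalar pointer (the official layer-norm tutorial stores its per-row mean with `tl.store(Mean + row, mean)`), so the `+ tl.arange(0,1)` is likely unnecessary here, though that should be confirmed against the Triton version the PR targets.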

zhoutianzi666 pushed a commit that referenced this pull request Sep 28, 2025
2 participants