[WIP][GPU] Add dropout to jit path for matmul #4514
base: main
src/gpu/intel/gemm/with_post_ops.cpp
Outdated
const bool with_wei_scales
        = !attr_scales.has_default_values(DNNL_ARG_WEIGHTS);
const bool with_dst_scales = !attr_scales.has_default_values(DNNL_ARG_DST);
const bool with_dropout = !attr()->dropout_.has_defaul_values();
- const bool with_dropout = !attr()->dropout_.has_defaul_values();
+ const bool with_dropout = !attr()->dropout_.has_default_values();
src/gpu/intel/gemm/with_post_ops.cpp
Outdated
kernel_ctx.define_int("USE_OFFSET", use_offset);
kernel_ctx.define_int("USE_HOST_SCALARS", use_host_scalars);
kernel_ctx.define_int("HAS_OUTPUT_MASK", has_output_mask);
Can you name these DROPOUT_* so it's clearer in the kernel what's being controlled by these macros? Specifically, USE_HOST_SCALARS is confusing as there are other scalars that may be supplied on the host or device side.
They are all inside the WITH_DROPOUT macro, so it should already be pretty clear. But I added the DROPOUT_* prefix to HOST_SCALARS for clarity.
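For illustration, the naming suggestion above could look roughly like the sketch below. This is not the PR's code: `mock_kernel_ctx_t`, `dropout_conf_t`, and `define_dropout_macros` are hypothetical stand-ins, and the prefixed macro names are assumed spellings. It only shows the idea of defining every dropout macro with a `DROPOUT_` prefix, and only when `WITH_DROPOUT` is set.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical stand-in for a kernel build context: it only records
// integer macro definitions, much like -DNAME=value compiler options.
struct mock_kernel_ctx_t {
    std::map<std::string, int64_t> defines;

    void define_int(const std::string &name, int64_t value) {
        assert(!name.empty()); // a macro needs a name
        defines[name] = value;
    }
};

// Hypothetical dropout configuration; the field names are illustrative.
struct dropout_conf_t {
    bool enabled = false;
    bool use_offset = false;
    bool use_host_scalars = false;
    bool has_output_mask = false;
};

// Defines WITH_DROPOUT unconditionally and the DROPOUT_* macros only
// when dropout is enabled, so the unprefixed names are never emitted.
inline void define_dropout_macros(
        mock_kernel_ctx_t &ctx, const dropout_conf_t &conf) {
    ctx.define_int("WITH_DROPOUT", conf.enabled);
    if (!conf.enabled) return;

    ctx.define_int("DROPOUT_USE_OFFSET", conf.use_offset);
    ctx.define_int("DROPOUT_USE_HOST_SCALARS", conf.use_host_scalars);
    ctx.define_int("DROPOUT_HAS_OUTPUT_MASK", conf.has_output_mask);
}
```

With this shape, a kernel reader who sees `DROPOUT_USE_HOST_SCALARS` does not need to look for the enclosing `WITH_DROPOUT` guard to know which scalars are meant.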
src/gpu/intel/gemm/with_post_ops.cpp
Outdated
arg_list.set(idx++, GEMM_CTX_ARG_STORAGE(dropout_seed));
arg_list.set(idx++, GEMM_CTX_ARG_STORAGE(dropout_offset));
arg_list.set(idx, GEMM_CTX_ARG_STORAGE(dropout_prob));
I think these need to be set in gpu/intel/matmul/gemm.cpp, similarly to the other gemm::exec_args_t values.
131cfc2 to f4fabd4
3b947f2 to 83991a2
83991a2 to bc6ff50
This is the last part of enabling dropout, as requested by the Graph API. It will enable SDPA on the fast path.