XQA always sets dynamic smem size on rank 0 #2494

@zackangelo

flashinfer/csrc/xqa/mha.cu

Lines 2589 to 2594 in 9bf007d

static uint32_t configureKernel() {
  uint32_t size;
  cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize));
  cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size);
  return size;
}

I believe this always sets the dynamic shared memory size on GPU 0 when the library loads. That makes XQA difficult to integrate in multi-GPU environments where a single process has access to more than one GPU, because we can't control which device the value is applied to.

If the value is applied on the wrong GPU, all kernel launches fail with a CUDA invalid argument error.
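
One possible direction, sketched below under some assumptions, is to apply the attribute lazily, once per device, right before the first launch on that device rather than once at load time. `PerDeviceConfig` is a hypothetical helper, not something that exists in flashinfer. To keep the sketch compilable without the CUDA toolkit, the CUDA calls are passed in as a callback: in real code the callback would wrap the `cudaMemcpyFromSymbol` and `cudaFuncSetAttribute` calls from `configureKernel()`, and the device id would come from `cudaGetDevice()`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <mutex>
#include <unordered_map>
#include <utility>

// Hypothetical helper (not part of flashinfer): remembers which devices have
// already been configured so the per-device setup runs once on each GPU.
class PerDeviceConfig {
 public:
  // `configure` must apply the setting to the *current* device and return the
  // shared memory size. In real CUDA code it would wrap cudaMemcpyFromSymbol
  // and cudaFuncSetAttribute, as configureKernel() does today.
  explicit PerDeviceConfig(std::function<uint32_t()> configure)
      : configure_(std::move(configure)) {}

  // `device` is assumed to be the id returned by cudaGetDevice() immediately
  // before the kernel launch, so the attribute lands on the launching GPU.
  uint32_t ensure(int device) {
    assert(device >= 0);
    std::lock_guard<std::mutex> lock(mu_);
    auto it = sizes_.find(device);
    if (it != sizes_.end()) return it->second;  // already configured here
    uint32_t size = configure_();               // first use on this device
    sizes_.emplace(device, size);
    return size;
  }

 private:
  std::function<uint32_t()> configure_;
  std::mutex mu_;  // guards sizes_ when several host threads launch kernels
  std::unordered_map<int, uint32_t> sizes_;
};
```

With something like this, the launch path would call `ensure(currentDevice)` before launching `kernel_mha`, so each GPU a process uses is configured on first use instead of only the device that was current when the library loaded.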
