diff --git a/jqc/backend/cart2sph.py b/jqc/backend/cart2sph.py index cd06c09..1eb9faf 100644 --- a/jqc/backend/cart2sph.py +++ b/jqc/backend/cart2sph.py @@ -32,8 +32,12 @@ with open(f"{cuda_path}/common/sph2cart.cu") as f: sph2cart_scripts = f.read() +with open(f"{cuda_path}/common/cart2cart.cu") as f: + cart2cart_scripts = f.read() + _cart2sph_kernel_cache = {} _sph2cart_kernel_cache = {} +_cart2cart_kernel = None def cart2sph(dm_cart, angs, cart_offset, sph_offset, nao_sph, out=None): @@ -278,30 +282,49 @@ def cart2cart(dm_src, angs, src_offset, dst_offset, nao, out=None): else: dst_offset_np = np.asarray(dst_offset) - # Build comprehensive mapping for all AO pairs - nbas = len(src_offset_np) - 1 # offset arrays have nbas+1 elements + nbas = len(src_offset_np) - 1 + nao_src = dm_src_cp.shape[-1] - # First, build a mapping from each internal AO to its source AO - internal_to_source = [] + # Build index mapping arrays on CPU + src_indices = [] + dst_indices = [] for s in range(nbas): ang_s = int(angs[s]) nf = (ang_s + 1) * (ang_s + 2) // 2 - src_row0 = int(src_offset_np[s]) - dst_row0 = int(dst_offset_np[s]) + src_start = int(src_offset_np[s]) + dst_start = int(dst_offset_np[s]) for f in range(nf): - internal_to_source.append((dst_row0 + f, src_row0 + f)) + src_idx = src_start + f + dst_idx = dst_start + f + if src_idx < nao_src and dst_idx < nao: + src_indices.append(src_idx) + dst_indices.append(dst_idx) + + src_indices_cp = cp.asarray(src_indices, dtype=cp.int32) + dst_indices_cp = cp.asarray(dst_indices, dtype=cp.int32) + n_idx = len(dst_indices) + + # Load CUDA kernel (cached) + global _cart2cart_kernel + if _cart2cart_kernel is None: + mod = cp.RawModule(code=cart2cart_scripts, options=compile_options) + _cart2cart_kernel = mod.get_function("cart2cart") + + # Launch kernel for each density matrix + threads = (16, 16) + blocks = ((n_idx + threads[0] - 1) // threads[0], (n_idx + threads[1] - 1) // threads[1]) - # Now copy all matrix elements using the mapping for b in range(ndms): - for dst_i, src_i in internal_to_source: - for dst_j, src_j in internal_to_source: - if ( - dst_i < nao - and dst_j < nao - and src_i < dm_src_cp.shape[-1] - and src_j < dm_src_cp.shape[-1] - ): - dm_dst[b][dst_i, dst_j] += dm_src_cp[b][src_i, src_j] + args = ( + dm_dst[b], + dm_src_cp[b], + nao, + nao_src, + n_idx, + dst_indices_cp, + src_indices_cp, + ) + _cart2cart_kernel(blocks, threads, args) return dm_dst diff --git a/jqc/backend/common/cart2cart.cu b/jqc/backend/common/cart2cart.cu new file mode 100644 index 0000000..4506afe --- /dev/null +++ b/jqc/backend/common/cart2cart.cu @@ -0,0 +1,36 @@ +/* +# Copyright 2025 ByteDance Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +*/ + +extern "C" __global__ +void cart2cart(double *dst, const double* src, + const int nao_dst, const int nao_src, const int n_idx, + const int* __restrict__ dst_indices, + const int* __restrict__ src_indices){ + + const int i = blockIdx.x * blockDim.x + threadIdx.x; + const int j = blockIdx.y * blockDim.y + threadIdx.y; + + if (i >= n_idx || j >= n_idx) return; + + const int src_i = src_indices[i]; + const int src_j = src_indices[j]; + const int dst_i = dst_indices[i]; + const int dst_j = dst_indices[j]; + + const double val = src[src_i * nao_src + src_j]; + atomicAdd(&dst[dst_i * nao_dst + dst_j], val); +}