Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions jqc/backend/cart2sph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@
with open(f"{cuda_path}/common/sph2cart.cu") as f:
sph2cart_scripts = f.read()

with open(f"{cuda_path}/common/cart2cart.cu") as f:
cart2cart_scripts = f.read()

_cart2sph_kernel_cache = {}
_sph2cart_kernel_cache = {}
_cart2cart_kernel = None


def cart2sph(dm_cart, angs, cart_offset, sph_offset, nao_sph, out=None):
Expand Down Expand Up @@ -278,30 +282,49 @@ def cart2cart(dm_src, angs, src_offset, dst_offset, nao, out=None):
else:
dst_offset_np = np.asarray(dst_offset)

# Build comprehensive mapping for all AO pairs
nbas = len(src_offset_np) - 1 # offset arrays have nbas+1 elements
nbas = len(src_offset_np) - 1
nao_src = dm_src_cp.shape[-1]

# First, build a mapping from each internal AO to its source AO
internal_to_source = []
# Build index mapping arrays on CPU
src_indices = []
dst_indices = []
for s in range(nbas):
ang_s = int(angs[s])
nf = (ang_s + 1) * (ang_s + 2) // 2
src_row0 = int(src_offset_np[s])
dst_row0 = int(dst_offset_np[s])
src_start = int(src_offset_np[s])
dst_start = int(dst_offset_np[s])

for f in range(nf):
internal_to_source.append((dst_row0 + f, src_row0 + f))
src_idx = src_start + f
dst_idx = dst_start + f
if src_idx < nao_src and dst_idx < nao:
src_indices.append(src_idx)
dst_indices.append(dst_idx)

src_indices_cp = cp.asarray(src_indices, dtype=cp.int32)
dst_indices_cp = cp.asarray(dst_indices, dtype=cp.int32)
n_idx = len(dst_indices)

# Load CUDA kernel (cached)
global _cart2cart_kernel
if _cart2cart_kernel is None:
mod = cp.RawModule(code=cart2cart_scripts, options=compile_options)
_cart2cart_kernel = mod.get_function("cart2cart")

# Launch kernel for each density matrix
threads = (16, 16)
blocks = ((n_idx + threads[0] - 1) // threads[0], (n_idx + threads[1] - 1) // threads[1])

# Now copy all matrix elements using the mapping
for b in range(ndms):
for dst_i, src_i in internal_to_source:
for dst_j, src_j in internal_to_source:
if (
dst_i < nao
and dst_j < nao
and src_i < dm_src_cp.shape[-1]
and src_j < dm_src_cp.shape[-1]
):
dm_dst[b][dst_i, dst_j] += dm_src_cp[b][src_i, src_j]
args = (
dm_dst[b],
dm_src_cp[b],
nao,
nao_src,
n_idx,
dst_indices_cp,
src_indices_cp,
)
_cart2cart_kernel(blocks, threads, args)

return dm_dst
36 changes: 36 additions & 0 deletions jqc/backend/common/cart2cart.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
# Copyright 2025 ByteDance Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
*/

extern "C" __global__
void cart2cart(double *dst, const double* src,
const int nao_dst, const int nao_src, const int n_idx,
const int* __restrict__ dst_indices,
const int* __restrict__ src_indices){

const int i = blockIdx.x * blockDim.x + threadIdx.x;
const int j = blockIdx.y * blockDim.y + threadIdx.y;

if (i >= n_idx || j >= n_idx) return;

const int src_i = src_indices[i];
const int src_j = src_indices[j];
const int dst_i = dst_indices[i];
const int dst_j = dst_indices[j];

const double val = src[src_i * nao_src + src_j];
atomicAdd(&dst[dst_i * nao_dst + dst_j], val);
}