diff --git a/pysages/nlist/__init__.py b/pysages/nlist/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pysages/nlist/cell_list.py b/pysages/nlist/cell_list.py new file mode 100644 index 00000000..d411065a --- /dev/null +++ b/pysages/nlist/cell_list.py @@ -0,0 +1,167 @@ +from typing import Union + + +import jax +from jax import numpy as np + +def _tuple_to_idx(tup: jax.Array, cell_edge: jax.Array) -> Union[np.int32, jax.Array]: + """ + Covnert cell index from tuple to scalar. + + Args: + tup (jax.Array): [index in x, index in y, index in z] + cell_edge (jax.Array): Number of cells in each dimension + + Returns: + np.int32 or jax.Array: Scalar index of cell or (N, ) array of scalar indices + """ + return np.int32(tup[...,0]*cell_edge[1]*cell_edge[2] + tup[...,1]*cell_edge[2] + tup[...,2]) + +def _idx_to_tuple(idx: int, cell_edge: jax.Array) -> jax.Array: + """ + Convert cell index from scalar to tuple. + + Args: + idx (int): Scalar index of cell + cell_edge (jax.Array): Number of cells in each dimension + + Returns: + jax.Array: [index in x, index in y, index in z] (3,) or (N, 3) + """ + x: np.int32 = idx//(cell_edge[1]*cell_edge[2]) + y: np.int32 = (idx//cell_edge[2])%cell_edge[1] + z: np.int32 = idx%cell_edge[2] + return np.concatenate([x, y, z], axis=-1, dtype=np.int32) + +def get_cell_list(pos: jax.Array, box_size: jax.Array, cutoff: float) -> jax.Array: + """ + Initialize the cell list. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + box_size (Tuple): box size (3, ) + cutoff (float): cutoff distance for neighbor list (scalar) + + Returns: + cell_idx (jax.Array): cell index for each particle (N, ) + """ + #setup the box parameters + cell_edge = np.floor(box_size/cutoff) # (3, ) + cell_cut = box_size/cell_edge # (3, ) + # get the cell ids + cell_tuples = pos//cell_cut + cell_idx = _tuple_to_idx(cell_tuples, cell_edge) + return cell_idx + +def _wrap_cell_ids(cell_ids: jax.Array, cell_edge: np.int32) -> jax.Array: + """ + Wraps the cell ids of particles in edge cells. (single dimension) + + Args: + cell_ids (jax.Array): Array of tuple cell ids in the current dimension for each particle (N, 1) + cell_edge (np.int32): Number of cells in current dimension + + Returns: + jax.Array: Wrapped cell ids (tuple) for each particle (N, 3) + """ + out_of_bound_low = (cell_ids == -1) # if cell id is -1 (out of bound from below) + out_of_bound_high = (cell_ids == cell_edge) # if cell id equal to the number of cells in that dimension (out of bound from above) + cell_ids = np.where(out_of_bound_low, cell_edge-1, cell_ids) # if out of bound, then wrap around from below + cell_ids = np.where(out_of_bound_high, 0, cell_ids) # if out of bound, then wrap around from above + return cell_ids + +def _get_neighbor_box(ids: jax.Array, cell_edge: jax.Array) -> jax.Array: + """ + Wrap the tuple cell ids of particles for each neighbor. (helper function for get_neighbor_ids to use with vmap) + + Args: + ids (jax.Array): Array of tuple cell ids (N, 3) + cell_edge (jax.Array): Array of number of cells in each dimension (3, ) + + Returns: + jax.Array: Wrapped tuple cell ids (N, 3) + """ + i, j, k = ids + x = _wrap_cell_ids(i, cell_edge[0]) + y = _wrap_cell_ids(j, cell_edge[1]) + z = _wrap_cell_ids(k, cell_edge[2]) + return np.asarray([x, y, z]) + +def get_neighbor_ids(box_size: jax.Array, cutoff: float, cell_idx: jax.Array, idx: int, buffer_size_cell: int) -> jax.Array: + """ + Get neighbor ids for a single particle. + + Args: + box_size (Tuple): box size (3, ) + cutoff (float): cutoff distance for neighbor list (scalar) + cell_idx (jax.Array): cell index for each particle (N, ) + idx (int): index of the particle in the pos matrix (scalar) + buffer_size_cell (int): buffer size for the cell list (scalar) + + Raises: + ValueError: If the neighbor list overflows + + Returns: + jax.Array: Array of neighbor ids for the particle (N, ) + """ + cell_edge = np.floor(box_size/cutoff) # (3, ) + cell_id = cell_idx[idx] # index of the cell that the particle is in scalar + cell_id = np.expand_dims(cell_id, axis=0) # scalar to (1, ) + + cell_tuple = _idx_to_tuple(cell_id, cell_edge) # tuple of the cell that the particle is in (1, dim) + + neighbor_tuples = [] + for i in [-1, 0, 1]: # loop over cells behind and ahead of the current cell in each dimension + for j in [-1, 0, 1]: + for k in [-1, 0, 1]: + neighbor_tuples.append(np.asarray([cell_tuple[0]+i, cell_tuple[1]+j, cell_tuple[2]+k])) + + neighbor_tuples = np.asarray(neighbor_tuples) # list to jax.Array (27, dim) + neighbor_tuples_wrapped = jax.vmap(_get_neighbor_box, in_axes=(0, None), out_axes=0)(neighbor_tuples, cell_edge) # wrap the cell ids of the neighbors (27, dim) + + # get scalar ids for the neighboring cells + neighbor_cell_ids = jax.vmap(_tuple_to_idx, (0, None))(neighbor_tuples_wrapped, cell_edge) + + neighbor_ids = [] # get ids of the particles in the neighboring cells. -1 is used as a filler for empty cells. + for cidx in neighbor_cell_ids: + neighbor_ids.append(np.where(cell_idx == cidx, fill_value=-1, size=buffer_size_cell)[0]) + + + # concatenate the neighbor ids into a single array. + neighbor_ids = np.concatenate(neighbor_ids, axis=-1) + return neighbor_ids + +def get_neighbors_list(box_size: jax.Array, cutoff: float, cell_idx: jax.Array, idxs: jax.Array, buffer_size_cell: int, mask_self: bool= False) -> jax.Array: + """ + Get neighbor ids for a list of particles. Uses vmap to vectorize on get_neighbor_ids function. + + Args: + cell_idx (jax.Array): cell index for each particle (N, ) + idxs (jax.Array): Array of particle indices (n, ) + mask_self (bool, optional): Whether to exclude self from neighbor list. Defaults to False. + + Raises: + ValueError: If the cell list is not initialized before calling get_neighbor_ids() + + Returns: + jax.Array: Array of neighbor ids for each particle (n, buffer_size_nl)) + """ + + # get neighbor ids for the particles + n_ids = jax.vmap(get_neighbor_ids, in_axes=(None, None, None, 0, None))(box_size, cutoff, cell_idx, idxs, buffer_size_cell) + # check for overflow + min_buffer = np.min(np.count_nonzero(n_ids == -1, axis=-1)) + if min_buffer < 27: # if there are less than 27 -1s in a row of the neighbor list, there is an overflow from buffer_size_cell + raise ValueError("Neighbor list overflow. Increase buffer_size_cell.") + # remove self from neighbor list if mask_self is True + if mask_self: + # set the self index to -1 + n_ids = n_ids.at[..., n_ids == idxs[:, None]].set(-1) + # add one to the minimum buffer size to account for the -1 just added + min_buffer += 1 + # sort + n_ids = np.sort(n_ids, axis=-1)[:, ::-1] + # truncate. Remove the -1s from the end of the neighbor list so that the row with the least -1s will have none (smallest possible neighbor list). + n_ids = n_ids[:, :-min_buffer] + + return n_ids diff --git a/pysages/nlist/testbench.ipynb b/pysages/nlist/testbench.ipynb new file mode 100644 index 00000000..edab9bf7 --- /dev/null +++ b/pysages/nlist/testbench.ipynb @@ -0,0 +1,535 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%load_ext line_profiler\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + } + ], + "source": [ + "from cell_list import get_cell_list, get_neighbors_list, get_neighbor_ids\n", + "from verlet_list import get_neighbor_ids as get_verlet_neighbor_ids\n", + "from verlet_list import get_neighborhood, _pairwise_dist\n", + "from jax import random\n", + "import jax.numpy as jnp\n", + "from jax_md.partition import neighbor_list\n", + "from jax_md.space import periodic" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting up a simple box\n", + "\n", + "This will create a $5 \\times 5 \\times 5$ box and puts 100 random particles in it." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "box_edge = 5.0\n", + "box_size = jnp.array([box_edge, box_edge, box_edge])\n", + "positions = random.uniform(random.PRNGKey(0), (int(1e2), 3))*box_edge\n", + "cutoff_c = 1.0 # cutoff for cell list\n", + "cutoff_v = 1.0 # cutoff for verlet list\n", + "buffer = 30" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cell list method\n", + "This will create a cell list which breaks the box into 125 boxes of edge size 1 (`cutoff_c`)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "cell_list = get_cell_list(positions, box_size, cutoff_c)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method 1\n", + "\n", + "Using `get_neighbors_list` we can get the neighbors for a list of particles. Under the hood, this functions calls `vmap` on another function `get_neighbor_ids` which is implemented for a single particle (+ some postprocessing)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[99 95 93 85 78 77 75 71 70 69 68 63 62 56 38 32 31 30 12 2 0 -1 -1 -1\n", + " -1 -1 -1 -1]\n", + " [99 98 94 88 82 80 79 73 67 66 58 57 54 44 43 42 41 34 29 24 22 20 11 10\n", + " 8 7 4 1]\n", + " [96 94 93 89 76 71 65 55 47 39 37 25 21 14 9 2 0 -1 -1 -1 -1 -1 -1 -1\n", + " -1 -1 -1 -1]\n", + " [92 81 80 79 75 72 60 56 54 48 42 40 35 33 30 28 25 24 23 18 17 12 11 3\n", + " -1 -1 -1 -1]\n", + " [99 98 94 88 82 80 79 73 67 66 58 57 54 44 43 42 41 34 29 24 22 20 11 10\n", + " 8 7 4 1]\n", + " [96 91 86 83 69 68 61 59 55 52 50 45 38 32 27 26 19 16 15 13 12 10 7 5\n", + " -1 -1 -1 -1]\n", + " [94 89 88 87 69 62 45 41 39 38 33 28 22 8 6 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", + " -1 -1 -1 -1]\n", + " [99 95 88 85 82 78 70 65 58 57 51 50 45 41 34 31 22 20 10 8 7 5 4 1\n", + " -1 -1 -1 -1]\n", + " [94 90 89 88 87 82 69 67 58 51 45 41 39 38 34 28 22 20 10 8 7 6 4 1\n", + " -1 -1 -1 -1]\n", + " [99 92 86 84 78 74 71 70 64 63 60 52 37 36 35 15 14 9 2 -1 -1 -1 -1 -1\n", + " -1 -1 -1 -1]]\n" + ] + } + ], + "source": [ + "idxs = jnp.asarray([i for i in range(10)]) # only the first 10 particles\n", + "nbors = get_neighbors_list(box_size=box_size, cutoff=cutoff_c, cell_idx=cell_list, idxs=idxs, buffer_size_cell=buffer, mask_self=False)\n", + "print(nbors) # padded with -1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method 2\n", + "\n", + "Using `get_neighbor_ids` directly with an external for loop and the postprocessing outside of the neighbor list." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[99 38 69 85 62 70 78 31 95 32 68 63 71 2 77 93 30 56 75 12 43 11 24 42\n", + " 67 44 73 94 29 66 54 79 80 22 88 4 34 58 82 57 8 41 20 98 99 7 10 9\n", + " 14 37 47 65 0 76 71 21 93 55 96 25 89 39 94 30 56 75 12 60 35 92 40 33\n", + " 28 17 18 11 24 42 25 48 81 72 54 79 80 23 43 11 24 42 67 44 73 94 29 66\n", + " 54 79 80 22 88 1 34 58 82 57 8 41 20 98 99 7 10 38 69 7 10 45 15 16\n", + " 83 32 68 50 52 86 27 91 59 55 96 13 12 19 26 61 89 39 94 33 28 22 88 87\n", + " 8 41 38 69 62 45 22 88 1 4 34 58 82 57 8 41 20 51 99 10 45 85 65 70\n", + " 78 31 5 50 95 89 67 39 94 28 90 22 88 1 4 34 58 82 6 87 41 20 51 38\n", + " 69 7 10 45 64 84 36 15 99 74 52 86 14 37 70 78 60 35 92 63 71 2]\n" + ] + } + ], + "source": [ + "nbors_2 = []\n", + "for i in idxs:\n", + " n_i = get_neighbor_ids(box_size=box_size, cutoff=cutoff_c, cell_idx=cell_list, idx=i, buffer_size_cell=buffer)\n", + " n_i = n_i[n_i != i]\n", + " n_i = n_i[n_i != -1]\n", + " nbors_2.append(n_i)\n", + "print(jnp.concatenate(nbors_2, axis=0)) # no padding" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Verlet list\n", + "\n", + "Verlet list can be used together with a cell list to get exact cutoffs. Cell list will capture all particles within a `cutoff_c`$\\times$`cutoff_c`$\\times$`cutoff_c` box, regardless of exact cutoff selected. This hybrid method is more efficient than calculating pairwise distance for all particles and doesn't require a skin radius." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method 1\n", + "Here we can get the particle ids of the neighbors of particle $0$ and get its exact neighborhood (`cutoff_v`) with a verlet list using `get_neighborhood`. This functions returns a list of `bool`s that shows whether the particles in the cell list neighbor list are within the verlet cutoff as well." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[93 70 63 31 0]\n" + ] + } + ], + "source": [ + "atom0_n_ids = nbors[0][nbors[0] != -1] # get neighbor ids of atom 0 and remove padding\n", + "neighbors = get_neighborhood(positions[atom0_n_ids, :], positions[0, :], cutoff_v, box_size)\n", + "print(atom0_n_ids[jnp.where(neighbors)]) # note that 0 is in the list too." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method 2\n", + "We can also get the neighborhood for all particles in a position matrix. The result is a $N \\times N$ matrix (or a list of size $N$ if `sparse = True`)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ True False False ... False False False]\n", + " [False True False ... False False False]\n", + " [False False True ... False False False]\n", + " ...\n", + " [False False False ... True False False]\n", + " [False False False ... False True False]\n", + " [False False False ... False False True]]\n" + ] + } + ], + "source": [ + "n_ids_verlet = get_verlet_neighbor_ids(positions, cutoff_v, box_size, sparse=False)\n", + "print(n_ids_verlet)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[31 63 70 93]\n", + "[ 4 10 58 66 82 88]\n", + "[76]\n", + "[11 17 35 60 72 92]\n", + "[ 1 34 43 54 58 66 88]\n", + "[45 50 91]\n", + "[ 8 38 41 87]\n", + "[10]\n", + "[ 6 89]\n", + "[14 37]\n", + "[ 1 7 41]\n", + "[ 3 17 24 35 42 72 81 92]\n", + "[18]\n", + "[55 96]\n", + "[ 9 37]\n", + "[16 36]\n", + "[15]\n", + "[ 3 11 33 48 60 72 81]\n", + "[12]\n", + "[27 52 86]\n", + "[41]\n", + "[96]\n", + "[23 58 66 88]\n", + "[22 88]\n", + "[11 34 42 43 44 57 67 73 92]\n", + "[89]\n", + "[61 90 97]\n", + "[19 47 52 83 91]\n", + "[87]\n", + "[33 40 75]\n", + "[56 60 74 75 95]\n", + "[ 0 63 70 77]\n", + "[45 68]\n", + "[17 29 56 60 75]\n", + "[ 4 24 42 43 44 57 58 66 67 82]\n", + "[ 3 11 60 92]\n", + "[15]\n", + "[ 9 14]\n", + "[ 6 45 69]\n", + "[89]\n", + "[29]\n", + "[ 6 10 20]\n", + "[11 24 34 44 57 73 92]\n", + "[ 4 24 34 44 58 67]\n", + "[24 34 42 43 57 67 73 92]\n", + "[ 5 32 38 69]\n", + "[76]\n", + "[27 83 98]\n", + "[17 72 79 81 84]\n", + "[51 53]\n", + "[ 5 59]\n", + "[49]\n", + "[19 27 86]\n", + "[49]\n", + "[ 4 57 66 79 80]\n", + "[13 96]\n", + "[30 33 60 74 75 95]\n", + "[24 34 42 44 54 73]\n", + "[ 1 4 22 34 43 66 67 82 88]\n", + "[50 64 74]\n", + "[ 3 17 30 33 35 56]\n", + "[26 90 97]\n", + "[]\n", + "[ 0 31 70 77 78]\n", + "[59 74 84 95]\n", + "[98]\n", + "[ 1 4 22 34 54 58 88]\n", + "[24 34 43 44 58 82]\n", + "[32 86]\n", + "[38 45]\n", + "[ 0 31 63 78 99]\n", + "[92]\n", + "[ 3 11 17 48 81]\n", + "[24 42 44 57]\n", + "[30 56 59 64 95]\n", + "[29 30 33 56 77]\n", + "[ 2 46]\n", + "[31 63 75]\n", + "[63 70]\n", + "[48 54 80]\n", + "[54 79 81]\n", + "[11 17 48 72 80]\n", + "[ 1 34 58 67]\n", + "[27 47 91]\n", + "[48 64 85]\n", + "[84]\n", + "[19 52 68]\n", + "[ 6 28]\n", + "[ 1 4 22 23 58 66]\n", + "[ 8 25 39]\n", + "[26 61 97]\n", + "[ 5 27 83]\n", + "[ 3 11 24 35 42 44 71]\n", + "[0]\n", + "[]\n", + "[30 56 64 74]\n", + "[13 21 55]\n", + "[26 61 90]\n", + "[47 65]\n", + "[70]\n" + ] + } + ], + "source": [ + "n_ids_verlet_s = get_verlet_neighbor_ids(positions, cutoff_v, box_size, sparse=True, mask_self=True)\n", + "for particle in n_ids_verlet_s:\n", + " print(particle)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparison with `jax_md` neighbor list" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define a periodic box and create the neighbor list" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "disp_fn, shift_fn = periodic(box_size)\n", + "nb_fn = neighbor_list(disp_fn, box_size, cutoff_v)\n", + "nbs = nb_fn.allocate(positions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check that all the returned neighbors are the same from `verlet_list` and `jax_md`." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0: OK\n", + "1: OK\n", + "2: OK\n", + "3: OK\n", + "4: OK\n", + "5: OK\n", + "6: OK\n", + "7: OK\n", + "8: OK\n", + "9: OK\n", + "10: OK\n", + "11: OK\n", + "12: OK\n", + "13: OK\n", + "14: OK\n", + "15: OK\n", + "16: OK\n", + "17: OK\n", + "18: OK\n", + "19: OK\n", + "20: OK\n", + "21: OK\n", + "22: OK\n", + "23: OK\n", + "24: OK\n", + "25: OK\n", + "26: OK\n", + "27: OK\n", + "28: OK\n", + "29: OK\n", + "30: OK\n", + "31: OK\n", + "32: OK\n", + "33: OK\n", + "34: OK\n", + "35: OK\n", + "36: OK\n", + "37: OK\n", + "38: OK\n", + "39: OK\n", + "40: OK\n", + "41: OK\n", + "42: OK\n", + "43: OK\n", + "44: OK\n", + "45: OK\n", + "46: OK\n", + "47: OK\n", + "48: OK\n", + "49: OK\n", + "50: OK\n", + "51: OK\n", + "52: OK\n", + "53: OK\n", + "54: OK\n", + "55: OK\n", + "56: OK\n", + "57: OK\n", + "58: OK\n", + "59: OK\n", + "60: OK\n", + "61: OK\n", + "62: OK\n", + "63: OK\n", + "64: OK\n", + "65: OK\n", + "66: OK\n", + "67: OK\n", + "68: OK\n", + "69: OK\n", + "70: OK\n", + "71: OK\n", + "72: OK\n", + "73: OK\n", + "74: OK\n", + "75: OK\n", + "76: OK\n", + "77: OK\n", + "78: OK\n", + "79: OK\n", + "80: OK\n", + "81: OK\n", + "82: OK\n", + "83: OK\n", + "84: OK\n", + "85: OK\n", + "86: OK\n", + "87: OK\n", + "88: OK\n", + "89: OK\n", + "90: OK\n", + "91: OK\n", + "92: OK\n", + "93: OK\n", + "94: OK\n", + "95: OK\n", + "96: OK\n", + "97: OK\n", + "98: OK\n", + "99: OK\n", + "All OK\n" + ] + } + ], + "source": [ + "for idx, nbl in enumerate(n_ids_verlet_s):\n", + " for n in nbl:\n", + " assert n in nbs.idx[idx]\n", + " for n in nbs.idx[idx]:\n", + " if n != 100:\n", + " assert n in nbl\n", + " print(f\"{idx}: OK\")\n", + "print(\"All OK\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pysages/nlist/verlet_list.py b/pysages/nlist/verlet_list.py new file mode 100644 index 00000000..09e23f86 --- /dev/null +++ b/pysages/nlist/verlet_list.py @@ -0,0 +1,96 @@ +import jax +import jax.numpy as np + +def _get_dist(i: jax.Array, j: jax.Array, box_size: jax.Array) -> np.float32: + """ + Calculate the distance between two particles. (helper function for _pairwise_dist to use with vmap) + + Args: + i (jax.Array): Position of particle i (3, ) + j (jax.Array): Position of particle j (3, ) + + Returns: + np.float32: Distance between particles i and j (scalar) + """ + dx = i[0] - j[0] + dx = np.where(dx > box_size[0]/2, dx - box_size[0], dx) + dx = np.where(dx < -box_size[0]/2, dx + box_size[0], dx) + dy = i[1] - j[1] + dy = np.where(dy > box_size[1]/2, dy - box_size[1], dy) + dy = np.where(dy < -box_size[1]/2, dy + box_size[1], dy) + dz = i[2] - j[2] + dz = np.where(dz > box_size[2]/2, dz - box_size[2], dz) + dz = np.where(dz < -box_size[2]/2, dz + box_size[2], dz) + return np.sqrt(dx**2 + dy**2 + dz**2) + +def _pairwise_dist(pos: jax.Array, ref: jax.Array, box_size: jax.Array) -> jax.Array: + """ + Calculate the pairwise distance between particles in pos and a single reference particle. (helper function for get_neighbor_ids to use with vmap) + + Args: + pos (jax.Array): position of particles (N, 3) + ref (jax.Array): position of reference particle (3, ) + + Returns: + jax.Array: array of distances between particles and reference particle (N, ) + """ + return jax.vmap(_get_dist, (0, None, None))(pos, ref, box_size) + +def _is_neighbor(dist: jax.Array, cutoff: float) -> jax.Array: + """ + Check if a particle is a neighbor of the reference particle. (helper function for get_neighbor_ids to use with vmap) + + Args: + dist (jax.Array): Array of distances between particles and reference particle (N, ) + cutoff (float): cutoff distance for neighbor list (scalar) + + Returns: + jax.Array: Array of bools indicating whether a particle is a neighbor of the reference particle (N, ) + """ + return dist < cutoff + +def get_neighbor_ids(pos: jax.Array, cutoff: float, box_size: jax.Array, sparse: bool = False, mask_self: bool = False) -> jax.Array: + """ + Get neighbor ids for each particle in pos matrix. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + sparse (bool, optional): Whether to return the full (N, N) matrix of neighborhood or an Array. Defaults to False. + mask_self (bool, optional): Whether to exclude self from neighbor list. Defaults to False. + + Returns: + jax.Array: Array of neighbor ids for each particle (N, ) or (N, N) matrix of bools indicating whether a particle is a neighbor of another particle. + """ + # calculate the pairwise distances between all particles + pair_dists = jax.vmap(_pairwise_dist, (None, 0, None))(pos, pos, box_size) + # check if a particle is a neighbor of another particle based on the cutoff distance + is_neighbor = jax.vmap(_is_neighbor, (0, None))(pair_dists, cutoff) + # remove self from neighbor list if mask_self is True + if mask_self: + i, j = np.diag_indices(is_neighbor.shape[0]) + is_neighbor = is_neighbor.at[..., i, j].set(False) + # return a list of arrays if sparse is True + if sparse: # return a list of arrays + neighbor_list = [] + for row in is_neighbor: + neighbor_list.append(np.where(row)[0]) + return neighbor_list + + return is_neighbor # return a NxN array of bools + +def get_neighborhood(pos: jax.Array, ref: jax.Array, cutoff: float, box_size: jax.Array) -> jax.Array: + """ + Get the neighborhood of a reference particle. + + Args: + pos (jax.Array): Array of particle positions (N, 3) + ref (jax.Array): position of reference particle (3, ) + + Returns: + jax.Array: Array of bools indicating whether a particle is a neighbor of the reference particle (N, ) + """ + # calculate the pairwise distances between all particles and the reference particle + pair_dists = _pairwise_dist(pos, ref, box_size) + # check if a particle is a neighbor of the reference particle based on the cutoff distance + is_neighbor = _is_neighbor(pair_dists, cutoff) + return is_neighbor \ No newline at end of file