|
1 | 1 | import jax |
2 | 2 | import jax.numpy as jnp |
3 | 3 | from jax import jit, value_and_grad |
4 | | -from flax import linen as nn |
5 | 4 | from functools import partial |
6 | 5 | import numpy as np |
7 | | -import MDAnalysis as mda |
8 | 6 | from openmm import * |
9 | 7 | from openmm.unit import * |
10 | 8 | from openmm.app import * |
11 | 9 |
|
12 | 10 | from ase import Atoms |
13 | | -from ase.io import read, Trajectory, write |
| 11 | +from ase.io import read, write |
14 | 12 | from ase.calculators.calculator import Calculator, all_changes |
15 | | -from ase.md.velocitydistribution import MaxwellBoltzmannDistribution |
16 | | -from ase.md.verlet import VelocityVerlet |
17 | 13 | from ase.stress import full_3x3_to_voigt_6_stress |
18 | | -from ase.optimize import BFGS |
19 | 14 |
|
20 | 15 | from ase import units |
21 | | -from ase.md.npt import NPT |
22 | | -from ase.md.nptberendsen import NPTBerendsen |
23 | | -from ase.md.langevin import Langevin |
24 | | -from ase.md.nose_hoover_chain import NoseHooverChainNVT |
25 | | -from ase.md import MDLogger |
26 | | -from ase.io.trajectory import Trajectory |
27 | 16 | from dmff.api import Hamiltonian |
28 | 17 | from dmff.common import nblist |
29 | | -from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales |
30 | | -from dmff.admp.pairwise import distribute_scalar, distribute_v3 |
31 | | -from dmff.admp.spatial import pbc_shift |
32 | | -import pickle |
33 | | -import time |
34 | | - |
35 | | -from dmff.sgnn.gnn import MolGNNForce |
36 | | -# from gnn import MolGNNForce |
37 | | -from dmff.sgnn.graph import TopGraph, from_pdb |
38 | | -# from graph import TopGraph, from_pdb |
39 | | - |
40 | | -# from eapnn import * |
41 | | - |
42 | | -# from jax import config |
43 | | -# config.update("jax_enable_x64", True) |
44 | | -# config.update("jax_debug_nans", True) |
45 | 18 |
|
46 | 19 | def get_atoms_box(atoms): |
47 | 20 | box = atoms.get_cell() / 10.0 |
|
0 commit comments