Skip to content

Commit f410ce4

Browse files
authored
Merge pull request #6 from sxm13/main
add "keep_connect"
2 parents 6ea18e6 + ce8efac commit f410ce4

File tree

8 files changed

+3078
-39
lines changed

8 files changed

+3078
-39
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ pip install -r requirements.txt
3636

3737
```sh
3838
from PACMANCharge import pmcharge
39-
pmcharge.predict(cif_file="./test/Cu-BTC.cif",charge_type="DDEC6",digits=10,atom_type=True,neutral=True)
39+
pmcharge.predict(cif_file="./test/Cu-BTC.cif",charge_type="DDEC6",digits=10,atom_type=True,neutral=True,keep_connect=True)
4040

4141
```
4242
4343
**Terminal**
4444
```sh
45-
python pmcharge.py folder-name[path] --charge_type[DDEC6/Bader/CM5/REPEAT] --digits[int] --atom_type[bool] --neutral[bool]
45+
python pmcharge.py folder-name[path] --charge_type[DDEC6/Bader/CM5/REPEAT] --digits[int] --atom_type[bool] --neutral[bool] --keep_connect[bool]
4646
```
4747
**Example command:** ```python pmcharge.py test_file/test-1/ --charge_type DDEC6 --digits 10```
4848

@@ -53,6 +53,7 @@ python pmcharge.py folder-name[path] --charge_type[DDEC6/Bader/CM5/REPEAT] --dig
5353
* digits (default: 6): number of decimal places to print for partial atomic charges. ML models were trained on a 6-digit dataset
5454
* atom-type (default: True): Default is to keep the same partial atomic charge for the same atom types (based on the similarity of partial atomic charges up to 3 decimal places)
5555
* neutral (default: True): Default is to keep the net charge is zero. We use "mean" method to neuralize the system where the excess charges are equally distributed across all atoms
56+
* keep_connect (default: True): retain the atomic and connection information (such as _atom_site_adp_type, bond) for the structure.
5657

5758
# Website & Zenodo
5859
* Predict partial atomic charges using an online APP :point_right: [link](https://pacman-charge-mtap.streamlit.app/)

model/cif2data.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import os
22
import re
3+
import glob
34
import json
45
import warnings
56
import numpy as np
67
import pandas as pd
78
from tqdm import tqdm
9+
from CifFile import ReadCif
810
import pymatgen.core as mg
911
from ase.io import read
1012
from ase import neighborlist
@@ -207,6 +209,14 @@ def get_ddec_data(root_cif_dir,dataset_csv,save_ddec_dir):
207209
np.save(save_ddec_dir + mof + '.npy', ddec_data)
208210
f.close()
209211

212+
def get_ddec_data_from_qmof(root_cif_dir,dataset_csv,save_ddec_dir):
213+
mofs = pd.read_csv(dataset_csv)["name"]
214+
for mof in mofs:
215+
data_cif = ReadCif(root_cif_dir + mof + ".cif")
216+
charge = data_cif[data_cif.keys()[0]]["_atom_site_pbe_ddec_charge"]
217+
np.save(save_ddec_dir + mof + '.npy', charge)
218+
219+
210220
def get_bader_data(root_cif_dir,dataset_csv,save_bader_dir):
211221
mofs = pd.read_csv(dataset_csv)["name"]
212222
for mof in mofs:
@@ -234,7 +244,14 @@ def get_bader_data(root_cif_dir,dataset_csv,save_bader_dir):
234244
f.close()
235245
except:
236246
pass
237-
247+
248+
def get_bader_data_from_qmof(root_cif_dir,dataset_csv,save_bader_dir):
249+
mofs = pd.read_csv(dataset_csv)["name"]
250+
for mof in mofs:
251+
data_cif = ReadCif(root_cif_dir + mof + ".cif")
252+
charge = data_cif[data_cif.keys()[0]]["_atom_site_pbe_bader_charge"]
253+
np.save(save_bader_dir + mof + '.npy', charge)
254+
238255
def get_cm5_data(root_cif_dir,dataset_csv,save_cm5_dir):
239256
mofs = pd.read_csv(dataset_csv)["name"]
240257
for mof in mofs:
@@ -263,6 +280,13 @@ def get_cm5_data(root_cif_dir,dataset_csv,save_cm5_dir):
263280
except:
264281
pass
265282

283+
def get_cm5_data_from_qmof(root_cif_dir,dataset_csv,save_cm5_dir):
284+
mofs = pd.read_csv(dataset_csv)["name"]
285+
for mof in mofs:
286+
data_cif = ReadCif(root_cif_dir + mof + ".cif")
287+
charge = data_cif[data_cif.keys()[0]]["_atom_site_pbe_cm5_charge"]
288+
np.save(save_cm5_dir + mof + '.npy', charge)
289+
266290
def get_repeat_data(root_cif_dir,save_repeat_dir):
267291
mofs = glob.glob(os.path.join(root_cif_dir, '*.cif'))
268292
for mof in tqdm(mofs[:]):
@@ -279,6 +303,13 @@ def get_repeat_data(root_cif_dir,save_repeat_dir):
279303
repeat_data.append(repeat.replace("\n",""))
280304
np.save(save_repeat_dir + mof + '.npy', repeat_data)
281305
f.close()
282-
283306
except:
284-
pass
307+
pass
308+
309+
def get_repeat_data_from_arcmof(root_cif_dir,save_repeat_dir):
310+
mofs = glob.glob(os.path.join(root_cif_dir, '*.cif'))
311+
for mof in mofs:
312+
mof = mof.replace(".cif","").split("/")[-1]
313+
data_cif = ReadCif(root_cif_dir + mof + ".cif")
314+
charge = data_cif[data_cif.keys()[0]]["_atom_site_pbe_cm5_charge"]
315+
np.save(save_repeat_dir + mof + '.npy', charge)
524 Bytes
Binary file not shown.

model4pre/cif2data.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22
import numpy as np
33
import pymatgen.core as mg
4+
from CifFile import ReadCif
45
from ase.io import read,write
56
from pymatgen.io.ase import AseAtomsAdaptor
67
from pymatgen.io.cif import CifParser
@@ -146,7 +147,7 @@ def average_and_replace(numbers, di):
146147
numbers[i] = avg
147148
return numbers
148149

149-
def write4cif(mof,chg,digits,atom_type,neutral,charge_type):
150+
def write4cif(mof,chg,digits,atom_type,neutral,charge_type,keep_connect):
150151
name = mof.split('.cif')[0]
151152
chg = chg.numpy()
152153
dia = int(digits)
@@ -190,36 +191,43 @@ def write4cif(mof,chg,digits,atom_type,neutral,charge_type):
190191

191192
if neutral==False:
192193
print("net charge: "+str(sum(charges)))
194+
if keep_connect:
195+
mof = ReadCif(name + ".cif")
196+
mof.first_block().AddToLoop("_atom_site_type_symbol",{'_atom_site_charge':[str(q) for q in charges]})
197+
with open(name + "_pacman.cif", 'w') as f:
198+
f.write("# " + charge_type + "charges by PACMAN v1.3 (https://github.com/mtap-research/PACMAN-charge/)\n" +
199+
f"data_{name.split('/')[-1]}" + str(mof.first_block()))
200+
print("Compelete and save as "+ name + "_pacman.cif")
201+
else:
202+
with open(name + ".cif", 'r') as file:
203+
lines = file.readlines()
204+
lines[0] = "# "+charge_type+" charges by PACMAN v1.3 (https://github.com/mtap-research/PACMAN-charge/)\n"
205+
lines[1] = "data_" + name.split("/")[-1] + "_pacman\n"
206+
for i, line in enumerate(lines):
207+
if '_atom_site_occupancy' in line:
208+
lines.insert(i + 1, " _atom_site_charge\n")
209+
break
210+
charge_index = 0
211+
for j in range(i + 2, len(lines)):
212+
if charge_index < len(charges):
213+
lines[j] = lines[j].strip() + " " + str(charges[charge_index]) + "\n"
214+
charge_index += 1
215+
else:
216+
break
193217

194-
with open(name + ".cif", 'r') as file:
195-
lines = file.readlines()
196-
lines[0] = "# "+charge_type+" charges by PACMAN v1.1 (https://github.com/mtap-research/PACMAN-charge/)\n"
197-
lines[1] = "data_" + name.split("/")[-1] + "_pacman\n"
198-
for i, line in enumerate(lines):
199-
if '_atom_site_occupancy' in line:
200-
lines.insert(i + 1, " _atom_site_charge\n")
201-
break
202-
charge_index = 0
203-
for j in range(i + 2, len(lines)):
204-
if charge_index < len(charges):
205-
lines[j] = lines[j].strip() + " " + str(charges[charge_index]) + "\n"
206-
charge_index += 1
207-
else:
208-
break
209-
210-
with open(name + "_pacman.cif", 'w') as file:
211-
file.writelines(lines)
212-
file.close()
218+
with open(name + "_pacman.cif", 'w') as file:
219+
file.writelines(lines)
220+
file.close()
213221

214-
with open(name + "_pacman.cif", 'r') as file:
215-
content = file.read()
216-
file.close()
222+
with open(name + "_pacman.cif", 'r') as file:
223+
content = file.read()
224+
file.close()
217225

218-
new_content = content.replace('_space_group_name_H-M_alt', '_symmetry_space_group_name_H-M')
219-
new_content = new_content.replace('_space_group_IT_number', '_symmetry_Int_Tables_number')
220-
new_content = new_content.replace('_space_group_symop_operation_xyz', '_symmetry_equiv_pos_as_xyz')
226+
new_content = content.replace('_space_group_name_H-M_alt', '_symmetry_space_group_name_H-M')
227+
new_content = new_content.replace('_space_group_IT_number', '_symmetry_Int_Tables_number')
228+
new_content = new_content.replace('_space_group_symop_operation_xyz', '_symmetry_equiv_pos_as_xyz')
221229

222-
with open(name + "_pacman.cif", 'wb') as file:
223-
file.write(new_content.encode('utf-8'))
224-
file.close()
225-
print("Compelete and save as "+ name + "_pacman.cif")
230+
with open(name + "_pacman.cif", 'wb') as file:
231+
file.write(new_content.encode('utf-8'))
232+
file.close()
233+
print("Compelete and save as "+ name + "_pacman.cif")

pmcharge.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ def main():
1717
parser.add_argument('--digits', type=int, default=6, help='Number of decimal places to print for partial atomic charges')
1818
parser.add_argument('--atom_type', type=bool, default=True, help='Keep the same partial atomic charge for the same atom types')
1919
parser.add_argument('--neutral', type=bool, default=True, help='Keep the net charge is zero')
20+
parser.add_argument('--keep_connect', type=bool, default=True, help='Keep information from original CIF file')
2021
args = parser.parse_args()
2122

2223
path = args.folder_name
2324
charge_type = args.charge_type
2425
digits = args.digits
2526
atom_type = args.atom_type
2627
neutral = args.neutral
28+
keep_connect = args.keep_connect
2729
if os.path.isfile(path):
2830
print("please input a folder, not a file")
2931
elif os.path.isdir(path):
@@ -65,6 +67,7 @@ def main():
6567
print("Digits: " + str(digits))
6668
print("Atom Type:" + str(atom_type))
6769
print("Neutral: " + str(neutral))
70+
print("Keep Connect: " + str(keep_connect))
6871

6972
cif_files = glob.glob(os.path.join(path, '*.cif'))
7073
print("writing cif: ***_pacman.cif")
@@ -74,7 +77,10 @@ def main():
7477
i = 0
7578
for cif in tqdm(cif_files):
7679
try:
77-
ase_format(cif)
80+
if keep_connect:
81+
pass
82+
else:
83+
ase_format(cif)
7884
cif_data = CIF2json(cif)
7985
pos = pre4pre(cif)
8086
# num_atom = n_atom(cif)
@@ -114,7 +120,7 @@ def main():
114120
# model_bandgap.eval()
115121

116122
gcn = GCN(chg_1-3, chg_2, 128, 7, 256,5)
117-
chkpt = torch.load(model_charge_name, map_location=torch.device(device), weights_only=True)
123+
chkpt = torch.load(model_charge_name, map_location=torch.device(device)) #, weights_only=True
118124
model4chg = SemiFullGN(chg_1,chg_2,128,8,256)
119125
model4chg.to(device)
120126
model4chg.load_state_dict(chkpt['state_dict'])
@@ -146,7 +152,7 @@ def main():
146152

147153
chg = model4chg(*input_var2)
148154
chg = charge_nor.denorm(chg.data.cpu())
149-
write4cif(cif,chg,digits,atom_type,neutral,charge_type)
155+
write4cif(cif,chg,digits,atom_type,neutral,charge_type,keep_connect)
150156
except:
151157
print("Fail predict: " + cif)
152158
fail[str(i)]=[cif]

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ tqdm>=4.15
55
pandas>=0.20.3
66
scikit-learn>=0.19.1
77
joblib>= 0.13.2
8-
torch
8+
torch
9+
PyCifRW

0 commit comments

Comments
 (0)