From d684eb0ac58ffa9e52497bd8fffb9c5caa7b9643 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Tue, 24 Feb 2026 15:42:11 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E5=8E=9F?= =?UTF-8?q?=E8=AE=BA=E6=96=87=E4=B8=AD=E4=BD=BF=E7=94=A8=E7=9A=84=E4=B8=A4?= =?UTF-8?q?=E4=B8=AA=E6=95=B0=E6=8D=AE=E9=9B=86=E7=9B=B8=E5=85=B3=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/datasets/ECDFormerDataset/__init__.py | 407 +++++++++ .../datasets/ECDFormerDataset/colored_tqdm.py | 76 ++ .../ECDFormerDataset/compound_tools.py | 828 ++++++++++++++++++ ppmat/datasets/ECDFormerDataset/dataloader.py | 56 ++ ppmat/datasets/ECDFormerDataset/eval_func.py | 147 ++++ ppmat/datasets/ECDFormerDataset/place_env.py | 173 ++++ ppmat/datasets/ECDFormerDataset/util_func.py | 43 + ppmat/datasets/IRDataset/__init__.py | 439 ++++++++++ ppmat/datasets/IRDataset/colored_tqdm.py | 76 ++ ppmat/datasets/IRDataset/compound_tools.py | 828 ++++++++++++++++++ ppmat/datasets/IRDataset/place_env.py | 173 ++++ ppmat/datasets/__init__.py | 7 + 12 files changed, 3253 insertions(+) create mode 100644 ppmat/datasets/ECDFormerDataset/__init__.py create mode 100644 ppmat/datasets/ECDFormerDataset/colored_tqdm.py create mode 100644 ppmat/datasets/ECDFormerDataset/compound_tools.py create mode 100644 ppmat/datasets/ECDFormerDataset/dataloader.py create mode 100644 ppmat/datasets/ECDFormerDataset/eval_func.py create mode 100644 ppmat/datasets/ECDFormerDataset/place_env.py create mode 100644 ppmat/datasets/ECDFormerDataset/util_func.py create mode 100644 ppmat/datasets/IRDataset/__init__.py create mode 100644 ppmat/datasets/IRDataset/colored_tqdm.py create mode 100644 ppmat/datasets/IRDataset/compound_tools.py create mode 100644 ppmat/datasets/IRDataset/place_env.py diff --git a/ppmat/datasets/ECDFormerDataset/__init__.py b/ppmat/datasets/ECDFormerDataset/__init__.py new file mode 100644 index 
00000000..7934e0ee --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/__init__.py @@ -0,0 +1,407 @@ +# __init__.py +""" +ECDFormer数据集加载模块 +""" + +import os +import numpy as np +import pandas as pd +import paddle +from paddle.io import Dataset +from paddle_geometric.data import Data + +from .compound_tools import get_atom_feature_dims, get_bond_feature_dims +from .util_func import normalize_func +from .eval_func import get_sequence_peak +from .colored_tqdm import ColoredTqdm as tqdm +from .place_env import PlaceEnv +from .dataloader import ECDFormerDataset_DataLoader + + +# ----------------Commonly-used Parameters---------------- +atom_id_names = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", +] +bond_id_names = ["bond_dir", "bond_type", "is_in_ring"] +full_atom_feature_dims = get_atom_feature_dims(atom_id_names) +full_bond_feature_dims = get_bond_feature_dims(bond_id_names) +bond_angle_float_names = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] +column_specify={ + 'ADH':[1,5,0,0],'ODH':[1,5,0,1],'IC':[0,5,1,2],'IA':[0,5,1,3],'OJH':[1,5,0,4], + 'ASH':[1,5,0,5],'IC3':[0,3,1,6],'IE':[0,5,1,7],'ID':[0,5,1,8],'OD3':[1,3,0,9], + 'IB':[0,5,1,10],'AD':[1,10,0,11],'AD3':[1,3,0,12],'IF':[0,5,1,13],'OD':[1,10,0,14], + 'AS':[1,10,0,15],'OJ3':[1,3,0,16],'IG':[0,5,1,17],'AZ':[1,10,0,18],'IAH':[0,5,1,19], + 'OJ':[1,10,0,20],'ICH':[0,5,1,21],'OZ3':[1,3,0,22],'IF3':[0,3,1,23],'IAU':[0,1.6,1,24] +} +bond_float_names = [] + + +def get_key_padding_mask(tokens): + """生成query padding mask""" + key_padding_mask = paddle.zeros(tokens.shape) + key_padding_mask[tokens == -1] = -paddle.inf + return key_padding_mask + + +def Construct_dataset(dataset, data_index, path): + """ + 从原始特征构建图数据 + 完全复用原型程序的Construct_dataset逻辑 + """ + graph_atom_bond = [] + graph_bond_angle = [] + + all_descriptor = np.load(os.path.join(path, 'descriptor_all_column.npy')) # (25847, 1826) + + for i in 
tqdm(range(len(dataset)), desc="Constructing graphs"): + data = dataset[i] + + # 收集原子特征 + atom_feature = [] + for name in atom_id_names: + atom_feature.append(data[name]) + + # 收集键特征 + bond_feature = [] + for name in bond_id_names[0:3]: + bond_feature.append(data[name]) + + # 转换为Tensor + atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') + bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') + bond_float_feature = paddle.to_tensor(data['bond_length'].astype(np.float32)) + bond_angle_feature = paddle.to_tensor(data['bond_angle'].astype(np.float32)) + edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') + bond_index = paddle.to_tensor(data['BondAngleGraph_edges'].T, dtype='int64') + data_index_int = paddle.to_tensor(np.array(data_index[i]), dtype='int64') + + # 添加描述符特征(与原型程序完全一致) + TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 100 + RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] + RPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822] + MDEC = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568] + MATS = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457] + + # 合并特征 + bond_feature = paddle.concat( + [bond_feature.astype(bond_float_feature.dtype), + bond_float_feature.reshape([-1, 1])], + axis=1 + ) + + bond_angle_feature = paddle.concat( + [bond_angle_feature.reshape([-1, 1]), TPSA.reshape([-1, 1])], + axis=1 + ) + bond_angle_feature = paddle.concat([bond_angle_feature, RASA.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, RPSA.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, MDEC.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, MATS.reshape([-1, 1])], axis=1) + + # 创建Data对象 + data_atom_bond = Data( + x=atom_feature, + edge_index=edge_index, + edge_attr=bond_feature, + data_index=data_index_int, + ) + 
data_bond_angle = Data( + edge_index=bond_index, + edge_attr=bond_angle_feature, + num_nodes=atom_feature.shape[0] + ) + + graph_atom_bond.append(data_atom_bond) + graph_bond_angle.append(data_bond_angle) + + return graph_atom_bond, graph_bond_angle + + +def read_total_ecd(sample_path, fix_length=20): + """ + 读取所有ECD光谱文件,提取峰值信息 + 完全复用原型程序的read_total_ecd逻辑 + """ + filepaths = [ + os.path.join(sample_path, "500ECD/data/"), + os.path.join(sample_path, "501-2000ECD/data/"), + os.path.join(sample_path, "2k-6kECD/data/"), + os.path.join(sample_path, "6k-8kECD/data/"), + os.path.join(sample_path, "8k-11kECD/data/"), + ] + + ecd_dict = {} + ecd_original_dict = {} + + for filepath in filepaths: + if not os.path.exists(filepath): + continue + files = os.listdir(filepath) + for file in files: + if not file.endswith(".csv"): + continue + fileid = int(file[:-4]) + single_file_path = os.path.join(filepath, file) + ECD_info = pd.read_csv(single_file_path).to_dict(orient='list') + wavelengths_o, mdegs_o = ECD_info['Wavelength (nm)'], ECD_info['ECD (Mdeg)'] + + wavelengths = [int(i) for i in wavelengths_o] + # 将小值置零 + mdegs = [int(i) if abs(i) > 1 else 0 for i in mdegs_o] + + # 去除前后零值 + begin, end = 0, 0 + for i in range(len(mdegs)): + if mdegs[i] != 0: + begin = i + break + for i in range(len(mdegs) - 1, 0, -1): + if mdegs[i] != 0: + end = i + break + + ecd_dict[fileid] = { + 'wavelengths': wavelengths[begin: end + 1], + 'ecd': mdegs[begin: end + 1], + } + ecd_original_dict[fileid] = { + 'wavelengths': wavelengths, + 'ecd': mdegs, + } + + # 处理光谱序列,提取峰值 + ecd_final_list = [] + for key, itm in ecd_dict.items(): + # 等间隔采样 + distance = int(len(itm['ecd']) / (fix_length - 1)) + sequence_org = [itm['ecd'][i] for i in range(0, len(itm['ecd']), distance)][:fix_length] + + # 归一化 + sequence = normalize_func(sequence_org, norm_range=[-100, 100]) + + # padding到固定长度 + if len(sequence) < fix_length: + sequence.extend([0] * (fix_length - len(sequence))) + sequence_org.extend([0] * (fix_length - 
len(sequence_org))) + assert len(sequence) == fix_length + + # 生成峰值掩码 + peak_mask = [0] * len(sequence) + for i in range(1, len(sequence) - 1): + if sequence[i - 1] < sequence[i] and sequence[i] > sequence[i + 1]: + if peak_mask[i - 1] != 2: + peak_mask[i - 1] = 1 + peak_mask[i] = 2 + if peak_mask[i + 1] != 2: + peak_mask[i + 1] = 1 + if sequence[i - 1] > sequence[i] and sequence[i] < sequence[i + 1]: + if peak_mask[i - 1] != 2: + peak_mask[i - 1] = 1 + peak_mask[i] = 2 + if peak_mask[i + 1] != 2: + peak_mask[i + 1] = 1 + + # 提取峰值位置 + peak_position_list = get_sequence_peak(sequence) + peak_number = len(peak_position_list) + assert peak_number < 9, f"Peak number {peak_number} >= 9" + + # 峰值符号 + peak_height_list = [] + for i in peak_position_list: + peak_height_list.append(1 if sequence[i] >= 0 else 0) + + # padding到9个峰 + peak_position_list = peak_position_list + [-1] * (9 - peak_number) + peak_height_list = peak_height_list + [-1] * (9 - peak_number) + query_padding_mask = get_key_padding_mask(paddle.to_tensor(peak_position_list)) + + tmp_dict = { + 'id': key, + 'seq': [0] + sequence, + 'seq_original': sequence_org, + 'seq_mask': peak_mask, + 'peak_num': peak_number, + 'peak_position': peak_position_list, + 'peak_height': peak_height_list, + 'query_mask': query_padding_mask.unsqueeze(0), + } + ecd_final_list.append(tmp_dict) + + ecd_final_list.sort(key=lambda x: x['id']) + return ecd_final_list, ecd_original_dict + + +def GetAtomBondAngleDataset( + sample_path, + dataset_all, + index_all, + hand_idx_dict, + line_idx_dict +): + """ + 核心函数:构建并返回切好的图数据集 + + Args: + sample_path: ECD光谱文件路径 + dataset_all: 从npy加载的info列表 + index_all: 索引列表 + hand_idx_dict: 手性对映射 + line_idx_dict: 行号映射 + + Returns: + dataset_graph_atom_bond: atom-bond图列表 + dataset_graph_bond_angle: bond-angle图列表 + """ + # 1. 读取ECD光谱序列 + ecd_sequences, ecd_original_sequences = read_total_ecd(sample_path) + + # 2. 
构建图数据 + total_graph_atom_bond, total_graph_bond_angle = Construct_dataset( + dataset_all, index_all, sample_path + ) + print("Case Before Process = ", len(total_graph_atom_bond), len(total_graph_bond_angle)) + + # 3. 将光谱序列信息附加到图数据上 + dataset_graph_atom_bond, dataset_graph_bond_angle = [], [] + + for itm in ecd_sequences: + line_num = itm['id'] - 1 + atom_bond = total_graph_atom_bond[line_num] + + # 附加光谱信息 + atom_bond.sequence = paddle.to_tensor([itm['seq']]) + atom_bond.ecd_id = paddle.to_tensor(itm['id']) + atom_bond.seq_mask = paddle.to_tensor([itm['seq_mask']]) + atom_bond.seq_original = paddle.to_tensor([itm['seq_original']]) + atom_bond.peak_num = paddle.to_tensor([itm['peak_num']]) + atom_bond.peak_position = paddle.to_tensor([itm['peak_position']]) + atom_bond.peak_height = paddle.to_tensor([itm['peak_height']]) + atom_bond.query_mask = itm['query_mask'] + + dataset_graph_atom_bond.append(atom_bond) + dataset_graph_bond_angle.append(total_graph_bond_angle[line_num]) + + # 4. 对映体增强:添加对映体样本 + hand_id, unnamed_id = line_idx_dict[line_num]['hand_id'], line_idx_dict[line_num]['unnamed_id'] + another_line_num = -1 + + for alternative in hand_idx_dict[hand_id]: + if alternative['unnamed_id'] != unnamed_id: + another_line_num = alternative['line_number'] + break + + assert another_line_num != -1, f"cannot find the hand info of {line_num}" + + # 对映体:光谱取反 + atom_bond_oppo = total_graph_atom_bond[another_line_num] + atom_bond_oppo.sequence = paddle.neg(paddle.to_tensor([itm['seq']])) + atom_bond_oppo.ecd_id = paddle.to_tensor(another_line_num + 1) + atom_bond_oppo.seq_mask = paddle.to_tensor([itm['seq_mask']]) + atom_bond_oppo.seq_original = paddle.neg(paddle.to_tensor([itm['seq_original']])) + atom_bond_oppo.peak_num = paddle.to_tensor([itm['peak_num']]) + atom_bond_oppo.peak_position = paddle.to_tensor([itm['peak_position']]) + atom_bond_oppo.peak_height = paddle.to_tensor([itm['peak_height']]) + atom_bond_oppo.query_mask = itm['query_mask'] + + 
dataset_graph_atom_bond.append(atom_bond_oppo) + dataset_graph_bond_angle.append(total_graph_bond_angle[another_line_num]) + + total_num = len(dataset_graph_atom_bond) + print("Case After Process = ", len(dataset_graph_atom_bond), len(dataset_graph_bond_angle)) + print('=================== Data prepared ================\n') + + return dataset_graph_atom_bond, dataset_graph_bond_angle + + +_cache = None + +class ECDFormerDataset(Dataset): + """ + ECDFormer数据集类 + 返回 (atom_bond_graph, bond_angle_graph) + """ + @PlaceEnv(paddle.CPUPlace()) + def __init__(self, + path: str = "dataset/ECD", + Use_geometry_enhanced: bool = True, + Use_column_info: bool = False): + global _cache + + if _cache: + self.graph_atom_bond, self.graph_bond_angle = _cache + return + + # 保存参数 + self.path = path + self.Use_geometry_enhanced = Use_geometry_enhanced + self.Use_column_info = Use_column_info + + # 1. 加载npy文件 + print(f"Loading ECDFormer dataset from {path}") + self.ecd_dataset = np.load( + os.path.join(path, 'ecd_column_charity_new_smiles.npy'), + allow_pickle=True + ).tolist() + + # 2. 加载csv文件 + self.ecd_info = pd.read_csv( + os.path.join(path, 'ecd_info.csv'), + encoding='gbk' + ) + + # 3. 提取info列表和索引 + self.dataset_all = [item['info'] for item in self.ecd_dataset] + self.index_all = self.ecd_info['Unnamed: 0'].values + + # 4. 
构建手性对映射 + self.unnamed_idx_dict, self.hand_idx_dict, self.line_idx_dict = {}, {}, {} + for i, itm in enumerate(self.ecd_dataset): + self.line_idx_dict[i] = { + 'hand_id': itm['hand_id'], + 'unnamed_id': itm['id'], + 'smiles': itm['smiles'] + } + + if itm['id'] not in self.unnamed_idx_dict: + self.unnamed_idx_dict[itm['id']] = { + 'line_number': i, + 'hand_id': itm['hand_id'], + 'smiles': itm['smiles'] + } + else: + raise AssertionError(f"Duplicate unnamed id: {itm['id']}") + + if itm['hand_id'] not in self.hand_idx_dict: + self.hand_idx_dict[itm['hand_id']] = [] + self.hand_idx_dict[itm['hand_id']].append({ + 'line_number': i, + 'unnamed_id': itm['id'], + 'smiles': itm['smiles'] + }) + + # 5. 构建图数据集(核心调用) + self.graph_atom_bond, self.graph_bond_angle = GetAtomBondAngleDataset( + sample_path=path, + dataset_all=self.dataset_all, + index_all=self.index_all, + hand_idx_dict=self.hand_idx_dict, + line_idx_dict=self.line_idx_dict + ) + + _cache = (self.graph_atom_bond, self.graph_bond_angle) + assert len(self.graph_atom_bond) == len(self.graph_bond_angle), \ + "Mismatch between atom_bond and bond_angle graph lengths" + + def __len__(self): + return len(self.graph_atom_bond) + + def __getitem__(self, idx): + """ + 返回: + atom_bond_graph: paddle_geometric.data.Data + bond_angle_graph: paddle_geometric.data.Data + """ + return self.graph_atom_bond[idx], self.graph_bond_angle[idx] \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/colored_tqdm.py b/ppmat/datasets/ECDFormerDataset/colored_tqdm.py new file mode 100644 index 00000000..c6b9cff0 --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/colored_tqdm.py @@ -0,0 +1,76 @@ +from tqdm import tqdm +import time +import os;os.system("") #兼容windows + +def hex_to_ansi(hex_color: str, background: bool = False) -> str: + """ + 将十六进制颜色转换为ANSI转义序列 + + Args: + hex_color: 十六进制颜色,如 '#dda0a0' 或 'dda0a0' + background: True表示背景色,False表示前景色 + + Returns: + ANSI转义序列字符串,如 '\033[38;2;221;160;160m' + + Example: + >>> 
print(f"{hex_to_ansi('#dda0a0')}Hello{hex_to_ansi('#000000')} World") + >>> print(f"{hex_to_ansi('dda0a0', background=True)}背景色{hex_to_ansi.reset()}") + """ + # 移除#号并转换为小写 + hex_color = hex_color.lower().lstrip('#') + + # 处理简写形式 (#fff -> ffffff) + if len(hex_color) == 3: + hex_color = ''.join([c * 2 for c in hex_color]) + + # 转换为RGB值 + r = int(hex_color[0:2], 16) + g = int(hex_color[2:4], 16) + b = int(hex_color[4:6], 16) + + # ANSI真彩色序列 + # 38;2;R;G;B 为前景色,48;2;R;G;B 为背景色 + code = 48 if background else 38 + return f'\033[{code};2;{r};{g};{b}m' + +def rgb_to_ansi(r: int, g: int, b: int, background: bool = False) -> str: + """RGB值直接转ANSI""" + code = 48 if background else 38 + return f'\033[{code};2;{r};{g};{b}m' + +# 重置颜色的ANSI码 +hex_to_ansi.reset = '\033[0m' + +class ColoredTqdm(tqdm): + def __init__(self, *args, + start_color=(221, 160, 160), # RGB: #DDA0A0 + end_color=(160, 221, 160), # RGB: #A0DDA0 + **kwargs): + super().__init__(*args, **kwargs) + self.start_color = start_color + self.end_color = end_color + + def get_current_color(self): + progress = self.n / self.total if self.total > 0 else 0 + current_rgb = tuple( + int(start + (end - start) * progress) + for start, end in zip(self.start_color, self.end_color) + ) + result = current_rgb[0] * 16 ** 4 \ + + current_rgb[1] * 16 ** 2 \ + + current_rgb[2] * 16 ** 0 + return "%06x" % result + + def update(self, n=1): + super().update(n) + # 使用Rich的真彩色支持 + style = hex_to_ansi(self.get_current_color()) + self.bar_format = f'{{l_bar}}{style}{{bar}}{hex_to_ansi.reset}{{r_bar}}' + self.refresh() + + +if __name__ == "__main__": + # 使用示例 + for i in ColoredTqdm(range(100), desc="🌈 彩虹渐变"): + time.sleep(0.1) \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/compound_tools.py b/ppmat/datasets/ECDFormerDataset/compound_tools.py new file mode 100644 index 00000000..bf47cd28 --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/compound_tools.py @@ -0,0 +1,828 @@ +import numpy as np +from rdkit import Chem 
+from rdkit.Chem import AllChem +from rdkit.Chem import rdchem + +DAY_LIGHT_FG_SMARTS_LIST = [ + # C + "[CX4]", + "[$([CX2](=C)=C)]", + "[$([CX3]=[CX3])]", + "[$([CX2]#C)]", + # C & O + "[CX3]=[OX1]", + "[$([CX3]=[OX1]),$([CX3+]-[OX1-])]", + "[CX3](=[OX1])C", + "[OX1]=CN", + "[CX3](=[OX1])O", + "[CX3](=[OX1])[F,Cl,Br,I]", + "[CX3H1](=O)[#6]", + "[CX3](=[OX1])[OX2][CX3](=[OX1])", + "[NX3][CX3](=[OX1])[#6]", + "[NX3][CX3]=[NX3+]", + "[NX3,NX4+][CX3](=[OX1])[OX2,OX1-]", + "[NX3][CX3](=[OX1])[OX2H0]", + "[NX3,NX4+][CX3](=[OX1])[OX2H,OX1-]", + "[CX3](=O)[O-]", + "[CX3](=[OX1])(O)O", + "[CX3](=[OX1])([OX2])[OX2H,OX1H0-1]", + "C[OX2][CX3](=[OX1])[OX2]C", + "[CX3](=O)[OX2H1]", + "[CX3](=O)[OX1H0-,OX2H1]", + "[NX3][CX2]#[NX1]", + "[#6][CX3](=O)[OX2H0][#6]", + "[#6][CX3](=O)[#6]", + "[OD2]([#6])[#6]", + # H + "[H]", + "[!#1]", + "[H+]", + "[+H]", + "[!H]", + # N + "[NX3;H2,H1;!$(NC=O)]", + "[NX3][CX3]=[CX3]", + "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]", + "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]", + "[NX3][$(C=C),$(cc)]", + "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]", + "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]", + "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH3X4]", + "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]", + "[CH2X4][CX3](=[OX1])[NX3H2]", + "[CH2X4][CX3](=[OX1])[OH0-,OH]", + "[CH2X4][SX2H,SX1H0-]", + "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]", + "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]", + "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\ +[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1", + "[CHX4]([CH3X4])[CH2X4][CH3X4]", + "[CH2X4][CHX4]([CH3X4])[CH3X4]", + "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]", + "[CH2X4][CH2X4][SX2][CH3X4]", + "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1", + "[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH2X4][OX2H]", + 
"[NX3][CX3]=[SX1]", + "[CHX4]([CH3X4])[OX2H]", + "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12", + "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1", + "[CHX4]([CH3X4])[CH3X4]", + "N[CX4H2][CX3](=[OX1])[O,N]", + "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]", + "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]", + "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]", + "[#7]", + "[NX2]=N", + "[NX2]=[NX2]", + "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]", + "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]", + "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]", + "[NX3][NX3]", + "[NX3][NX2]=[*]", + "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]", + "[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]", + "[NX3+]=[CX3]", + "[CX3](=[OX1])[NX3H][CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])", + "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]", + "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]", + "[NX1]#[CX2]", + "[CX1-]#[NX2+]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[NX2]=[OX1]", + "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]", + # O + "[OX2H]", + "[#6][OX2H]", + "[OX2H][CX3]=[OX1]", + "[OX2H]P", + "[OX2H][#6X3]=[#6]", + "[OX2H][cX3]:[c]", + "[OX2H][$(C=C),$(cc)]", + "[$([OH]-*=[!#6])]", + "[OX2,OX1-][OX2,OX1-]", + # P + "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\ +$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\ +,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]", + "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\ +$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\ +$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]", + # S + "[S-][CX3](=S)[#6]", + 
"[#6X3](=[SX1])([!N])[!N]", + "[SX2]", + "[#16X2H]", + "[#16!H0]", + "[#16X2H0]", + "[#16X2H0][!#16]", + "[#16X2H0][#16X2H0]", + "[#16X2H0][!#16].[#16X2H0][!#16]", + "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]", + "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]", + "[SX4](C)(C)(=O)=N", + "[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]", + "[$([#16X3]=[OX1]),$([#16X3+][OX1-])]", + "[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]", + "[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]", + "[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]", + "[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]", + "[#16X2][OX2H,OX1H0-]", + "[#16X2][OX2H0]", + # X + "[#6][F,Cl,Br,I]", + "[F,Cl,Br,I]", + "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]", + ] + + +def get_gasteiger_partial_charges(mol, n_iter=12): + """ + Calculates list of gasteiger partial charges for each atom in mol object. + Args: + mol: rdkit mol object. + n_iter(int): number of iterations. Default 12. + Returns: + list of computed partial charges for each atom. 
+ """ + Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter, + throwOnParamFailure=True) + partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in + mol.GetAtoms()] + return partial_charges + + +def create_standardized_mol_id(smiles): + """ + Args: + smiles: smiles sequence. + Returns: + inchi. + """ + if check_smiles_validity(smiles): + # remove stereochemistry + smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles), + isomericSmiles=False) + mol = AllChem.MolFromSmiles(smiles) + if not mol is None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21 + if '.' in smiles: # if multiple species, pick largest molecule + mol_species_list = split_rdkit_mol_obj(mol) + largest_mol = get_largest_mol(mol_species_list) + inchi = AllChem.MolToInchi(largest_mol) + else: + inchi = AllChem.MolToInchi(mol) + return inchi + else: + return + else: + return + + +def check_smiles_validity(smiles): + """ + Check whether the smile can't be converted to rdkit mol object. + """ + try: + m = Chem.MolFromSmiles(smiles) + if m: + return True + else: + return False + except Exception as e: + return False + + +def split_rdkit_mol_obj(mol): + """ + Split rdkit mol object containing multiple species or one species into a + list of mol objects or a list containing a single object respectively. + Args: + mol: rdkit mol object. + """ + smiles = AllChem.MolToSmiles(mol, isomericSmiles=True) + smiles_list = smiles.split('.') + mol_species_list = [] + for s in smiles_list: + if check_smiles_validity(s): + mol_species_list.append(AllChem.MolFromSmiles(s)) + return mol_species_list + + +def get_largest_mol(mol_list): + """ + Given a list of rdkit mol objects, returns mol object containing the + largest num of atoms. If multiple containing largest num of atoms, + picks the first one. + Args: + mol_list(list): a list of rdkit mol object. + Returns: + the largest mol. 
+ """ + num_atoms_list = [len(m.GetAtoms()) for m in mol_list] + largest_mol_idx = num_atoms_list.index(max(num_atoms_list)) + return mol_list[largest_mol_idx] + + +def rdchem_enum_to_list(values): + """values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED, + 1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, + 2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, + 3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER} + """ + return [values[i] for i in range(len(values))] + + +def safe_index(alist, elem): + """ + Return index of element e in list l. If e is not present, return the last index + """ + try: + return alist.index(elem) + except ValueError: + return len(alist) - 1 + + +def get_atom_feature_dims(list_acquired_feature_names): + """ tbd + """ + return list(map(len, [CompoundKit.atom_vocab_dict[name] for name in list_acquired_feature_names])) + + +def get_bond_feature_dims(list_acquired_feature_names): + """ tbd + """ + list_bond_feat_dim = list(map(len, [CompoundKit.bond_vocab_dict[name] for name in list_acquired_feature_names])) + # +1 for self loop edges + return [_l + 1 for _l in list_bond_feat_dim] + + +class CompoundKit(object): + """ + CompoundKit + """ + atom_vocab_dict = { + "atomic_num": list(range(1, 119)) + ['misc'], + "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values), + "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], + "explicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'], + "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], + "hybridization": rdchem_enum_to_list(rdchem.HybridizationType.values), + "implicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'], + "is_aromatic": [0, 1], + "total_numHs": [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'num_radical_e': [0, 1, 2, 3, 4, 'misc'], + 'atom_is_in_ring': [0, 1], + 'valence_out_shell': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size4': [0, 1, 2, 3, 4, 5, 6, 
7, 8, 'misc'], + 'in_num_ring_with_size5': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + } + bond_vocab_dict = { + "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values), + "bond_type": rdchem_enum_to_list(rdchem.BondType.values), + "is_in_ring": [0, 1], + + 'bond_stereo': rdchem_enum_to_list(rdchem.BondStereo.values), + 'is_conjugated': [0, 1], + } + # float features + atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass'] + # bond_float_feats= ["bond_length", "bond_angle"] # optional + + ### functional groups + day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST + day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list] + + morgan_fp_N = 200 + morgan2048_fp_N = 2048 + maccs_fp_N = 167 + + period_table = Chem.GetPeriodicTable() + + ### atom + + @staticmethod + def get_atom_value(atom, name): + """get atom values""" + if name == 'atomic_num': + return atom.GetAtomicNum() + elif name == 'chiral_tag': + return atom.GetChiralTag() + elif name == 'degree': + return atom.GetDegree() + elif name == 'explicit_valence': + return atom.GetExplicitValence() + elif name == 'formal_charge': + return atom.GetFormalCharge() + elif name == 'hybridization': + return atom.GetHybridization() + elif name == 'implicit_valence': + return atom.GetImplicitValence() + elif name == 'is_aromatic': + return int(atom.GetIsAromatic()) + elif name == 'mass': + return int(atom.GetMass()) + elif name == 'total_numHs': + return atom.GetTotalNumHs() + elif name == 'num_radical_e': + return atom.GetNumRadicalElectrons() + elif name == 'atom_is_in_ring': + return int(atom.IsInRing()) + elif name == 'valence_out_shell': + return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum()) + else: + raise ValueError(name) + + @staticmethod + def get_atom_feature_id(atom, 
name): + """get atom features id""" + assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name + return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name)) + + @staticmethod + def get_atom_feature_size(name): + """get atom features size""" + assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name + return len(CompoundKit.atom_vocab_dict[name]) + + ### bond + + @staticmethod + def get_bond_value(bond, name): + """get bond values""" + if name == 'bond_dir': + return bond.GetBondDir() + elif name == 'bond_type': + return bond.GetBondType() + elif name == 'is_in_ring': + return int(bond.IsInRing()) + elif name == 'is_conjugated': + return int(bond.GetIsConjugated()) + elif name == 'bond_stereo': + return bond.GetStereo() + else: + raise ValueError(name) + + @staticmethod + def get_bond_feature_id(bond, name): + """get bond features id""" + assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name + return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name)) + + @staticmethod + def get_bond_feature_size(name): + """get bond features size""" + assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name + return len(CompoundKit.bond_vocab_dict[name]) + + ### fingerprint + + @staticmethod + def get_morgan_fingerprint(mol, radius=2): + """get morgan fingerprint""" + nBits = CompoundKit.morgan_fp_N + mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) + return [int(b) for b in mfp.ToBitString()] + + @staticmethod + def get_morgan2048_fingerprint(mol, radius=2): + """get morgan2048 fingerprint""" + nBits = CompoundKit.morgan2048_fp_N + mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) + return [int(b) for b in mfp.ToBitString()] + + @staticmethod + def get_maccs_fingerprint(mol): + """get maccs fingerprint""" + fp = AllChem.GetMACCSKeysFingerprint(mol) + return 
[int(b) for b in fp.ToBitString()] + + ### functional groups + + @staticmethod + def get_daylight_functional_group_counts(mol): + """get daylight functional group counts""" + fg_counts = [] + for fg_mol in CompoundKit.day_light_fg_mo_list: + sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True) + fg_counts.append(len(sub_structs)) + return fg_counts + + @staticmethod + def get_ring_size(mol): + """return (N,6) list""" + rings = mol.GetRingInfo() + rings_info = [] + for r in rings.AtomRings(): + rings_info.append(r) + ring_list = [] + for atom in mol.GetAtoms(): + atom_result = [] + for ringsize in range(3, 9): + num_of_ring_at_ringsize = 0 + for r in rings_info: + if len(r) == ringsize and atom.GetIdx() in r: + num_of_ring_at_ringsize += 1 + if num_of_ring_at_ringsize > 8: + num_of_ring_at_ringsize = 9 + atom_result.append(num_of_ring_at_ringsize) + + ring_list.append(atom_result) + return ring_list + + @staticmethod + def atom_to_feat_vector(atom): + """ tbd """ + atom_names = { + "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()), + "chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()), + "degree": safe_index(CompoundKit.atom_vocab_dict["degree"], atom.GetTotalDegree()), + "explicit_valence": safe_index(CompoundKit.atom_vocab_dict["explicit_valence"], atom.GetExplicitValence()), + "formal_charge": safe_index(CompoundKit.atom_vocab_dict["formal_charge"], atom.GetFormalCharge()), + "hybridization": safe_index(CompoundKit.atom_vocab_dict["hybridization"], atom.GetHybridization()), + "implicit_valence": safe_index(CompoundKit.atom_vocab_dict["implicit_valence"], atom.GetImplicitValence()), + "is_aromatic": safe_index(CompoundKit.atom_vocab_dict["is_aromatic"], int(atom.GetIsAromatic())), + "total_numHs": safe_index(CompoundKit.atom_vocab_dict["total_numHs"], atom.GetTotalNumHs()), + 'num_radical_e': safe_index(CompoundKit.atom_vocab_dict['num_radical_e'], 
atom.GetNumRadicalElectrons()), + 'atom_is_in_ring': safe_index(CompoundKit.atom_vocab_dict['atom_is_in_ring'], int(atom.IsInRing())), + 'valence_out_shell': safe_index(CompoundKit.atom_vocab_dict['valence_out_shell'], + CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())), + 'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()), + 'partial_charge': CompoundKit.check_partial_charge(atom), + 'mass': atom.GetMass(), + } + return atom_names + + @staticmethod + def get_atom_names(mol): + """get atom name list + TODO: to be remove in the future + """ + atom_features_dicts = [] + Chem.rdPartialCharges.ComputeGasteigerCharges(mol) + for i, atom in enumerate(mol.GetAtoms()): + atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom)) + + ring_list = CompoundKit.get_ring_size(mol) + for i, atom in enumerate(mol.GetAtoms()): + atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0]) + atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1]) + atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2]) + atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3]) + atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4]) + atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5]) + + return atom_features_dicts + + @staticmethod + def check_partial_charge(atom): + """tbd""" + pc = atom.GetDoubleProp('_GasteigerCharge') + if pc != pc: + # unsupported atom, replace nan with 0 + pc = 0 + if pc == float('inf'): + # max 4 for other atoms, set to 10 here if inf is get + pc = 10 + 
return pc + + +class Compound3DKit(object): + """the 3Dkit of Compound""" + + @staticmethod + def get_atom_poses(mol, conf): + """tbd""" + atom_poses = [] + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetAtomicNum() == 0: + return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms()) + pos = conf.GetAtomPosition(i) + atom_poses.append([pos.x, pos.y, pos.z]) + return atom_poses + + @staticmethod + def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False): + """the atoms of mol will be changed in some cases.""" + conf = mol.GetConformer() + atom_poses = Compound3DKit.get_atom_poses(mol, conf) + return mol,atom_poses + # try: + # new_mol = Chem.AddHs(mol) + # res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs) + # ### MMFF generates multiple conformations + # res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) + # new_mol = Chem.RemoveHs(new_mol) + # index = np.argmin([x[1] for x in res]) + # energy = res[index][1] + # conf = new_mol.GetConformer(id=int(index)) + # except: + # new_mol = mol + # AllChem.Compute2DCoords(new_mol) + # energy = 0 + # conf = new_mol.GetConformer() + # + # atom_poses = Compound3DKit.get_atom_poses(new_mol, conf) + # if return_energy: + # return new_mol, atom_poses, energy + # else: + # return new_mol, atom_poses + + @staticmethod + def get_2d_atom_poses(mol): + """get 2d atom poses""" + AllChem.Compute2DCoords(mol) + conf = mol.GetConformer() + atom_poses = Compound3DKit.get_atom_poses(mol, conf) + return atom_poses + + @staticmethod + def get_bond_lengths(edges, atom_poses): + """get bond lengths""" + bond_lengths = [] + for src_node_i, tar_node_j in edges: + bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i])) + bond_lengths = np.array(bond_lengths, 'float32') + return bond_lengths + + @staticmethod + def get_superedge_angles(edges, atom_poses, dir_type='HT'): + """get superedge angles""" + + def _get_vec(atom_poses, edge): + return atom_poses[edge[1]] - atom_poses[edge[0]] + + def _get_angle(vec1, 
vec2): + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) + if norm1 == 0 or norm2 == 0: + return 0 + vec1 = vec1 / (norm1 + 1e-5) # 1e-5: prevent numerical errors + vec2 = vec2 / (norm2 + 1e-5) + angle = np.arccos(np.dot(vec1, vec2)) + return angle + + E = len(edges) + edge_indices = np.arange(E) + super_edges = [] + bond_angles = [] + bond_angle_dirs = [] + for tar_edge_i in range(E): + tar_edge = edges[tar_edge_i] + if dir_type == 'HT': + src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]] + elif dir_type == 'HH': + src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]] + else: + raise ValueError(dir_type) + for src_edge_i in src_edge_indices: + if src_edge_i == tar_edge_i: + continue + src_edge = edges[src_edge_i] + src_vec = _get_vec(atom_poses, src_edge) + tar_vec = _get_vec(atom_poses, tar_edge) + super_edges.append([src_edge_i, tar_edge_i]) + angle = _get_angle(src_vec, tar_vec) + bond_angles.append(angle) + bond_angle_dirs.append(src_edge[1] == tar_edge[0]) # H -> H or H -> T + + if len(super_edges) == 0: + super_edges = np.zeros([0, 2], 'int64') + bond_angles = np.zeros([0, ], 'float32') + else: + super_edges = np.array(super_edges, 'int64') + bond_angles = np.array(bond_angles, 'float32') + return super_edges, bond_angles, bond_angle_dirs + + +def new_smiles_to_graph_data(smiles, **kwargs): + """ + Convert smiles to graph data. + """ + mol = AllChem.MolFromSmiles(smiles) + if mol is None: + return None + data = new_mol_to_graph_data(mol) + return data + + +def new_mol_to_graph_data(mol): + """ + mol_to_graph_data + Args: + atom_features: Atom features. + edge_features: Edge features. + morgan_fingerprint: Morgan fingerprint. + functional_groups: Functional groups. 
+ """ + if len(mol.GetAtoms()) == 0: + return None + + atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names + bond_id_names = list(CompoundKit.bond_vocab_dict.keys()) + + data = {} + + ### atom features + data = {name: [] for name in atom_id_names} + + raw_atom_feat_dicts = CompoundKit.get_atom_names(mol) + for atom_feat in raw_atom_feat_dicts: + for name in atom_id_names: + data[name].append(atom_feat[name]) + + ### bond and bond features + for name in bond_id_names: + data[name] = [] + data['edges'] = [] + + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + # i->j and j->i + data['edges'] += [(i, j), (j, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + data[name] += [bond_feature_id] * 2 + + #### self loop + N = len(data[atom_id_names[0]]) + for i in range(N): + data['edges'] += [(i, i)] + for name in bond_id_names: + bond_feature_id = get_bond_feature_dims([name])[0] - 1 # self loop: value = len - 1 + data[name] += [bond_feature_id] * N + + ### make ndarray and check length + for name in list(CompoundKit.atom_vocab_dict.keys()): + data[name] = np.array(data[name], 'int64') + for name in CompoundKit.atom_float_names: + data[name] = np.array(data[name], 'float32') + for name in bond_id_names: + data[name] = np.array(data[name], 'int64') + data['edges'] = np.array(data['edges'], 'int64') + + ### morgan fingerprint + data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') + # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') + data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') + data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') + return data + + +def mol_to_graph_data(mol): + """ + mol_to_graph_data + Args: + atom_features: Atom features. + edge_features: Edge features. + morgan_fingerprint: Morgan fingerprint. 
+ functional_groups: Functional groups. + """ + if len(mol.GetAtoms()) == 0: + return None + + atom_id_names = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", + ] + bond_id_names = [ + "bond_dir", "bond_type", "is_in_ring", + ] + + data = {} + for name in atom_id_names: + data[name] = [] + data['mass'] = [] + for name in bond_id_names: + data[name] = [] + data['edges'] = [] + + ### atom features + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetAtomicNum() == 0: + return None + for name in atom_id_names: + data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1) # 0: OOV + data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01) + + ### bond features + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + # i->j and j->i + data['edges'] += [(i, j), (j, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1 # 0: OOV + data[name] += [bond_feature_id] * 2 + + ### self loop (+2) + N = len(data[atom_id_names[0]]) + for i in range(N): + data['edges'] += [(i, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2 # N + 2: self loop + data[name] += [bond_feature_id] * N + + ### check whether edge exists + if len(data['edges']) == 0: # mol has no bonds + for name in bond_id_names: + data[name] = np.zeros((0,), dtype="int64") + data['edges'] = np.zeros((0, 2), dtype="int64") + + ### make ndarray and check length + for name in atom_id_names: + data[name] = np.array(data[name], 'int64') + data['mass'] = np.array(data['mass'], 'float32') + for name in bond_id_names: + data[name] = np.array(data[name], 'int64') + data['edges'] = np.array(data['edges'], 'int64') + + ### morgan fingerprint + data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') + # data['morgan2048_fp'] = 
np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') + data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') + data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') + return data + + +def mol_to_geognn_graph_data(mol, atom_poses, dir_type): + """ + mol: rdkit molecule + dir_type: direction type for bond_angle grpah + """ + if len(mol.GetAtoms()) == 0: + return None + + data = mol_to_graph_data(mol) + + data['atom_pos'] = np.array(atom_poses, 'float32') + data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos']) + BondAngleGraph_edges, bond_angles, bond_angle_dirs = \ + Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos']) + data['BondAngleGraph_edges'] = BondAngleGraph_edges + data['bond_angle'] = np.array(bond_angles, 'float32') + return data + + +def mol_to_geognn_graph_data_MMFF3d(mol): + """tbd""" + if len(mol.GetAtoms()) <= 400: + mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10) + else: + atom_poses = Compound3DKit.get_2d_atom_poses(mol) + return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') + + +def mol_to_geognn_graph_data_raw3d(mol): + """tbd""" + atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer()) + return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') + +def obtain_3D_mol(smiles,name): + mol = AllChem.MolFromSmiles(smiles) + new_mol = Chem.AddHs(mol) + res = AllChem.EmbedMultipleConfs(new_mol) + ### MMFF generates multiple conformations + res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) + new_mol = Chem.RemoveHs(new_mol) + Chem.MolToMolFile(new_mol, name+'.mol') + return new_mol + +def predict_SMILES_info(smiles): + # by lihao, input smiles, output dict + mol = AllChem.MolFromSmiles(smiles) + AllChem.EmbedMolecule(mol) + info_dict = mol_to_geognn_graph_data_MMFF3d(mol) + return info_dict + +if __name__ == "__main__": + # smiles = "OCc1ccccc1CN" + smiles = 
r"[H]/[NH+]=C(\N)C1=CC(=O)/C(=C\C=c2ccc(=C(N)[NH3+])cc2)C=C1" + # smiles = 'CC' + mol = AllChem.MolFromSmiles(smiles) + AllChem.EmbedMolecule(mol) + data = mol_to_geognn_graph_data_MMFF3d(mol) + for key, value in data.items(): + print(key, value.shape) \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/dataloader.py b/ppmat/datasets/ECDFormerDataset/dataloader.py new file mode 100644 index 00000000..848c9ae4 --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/dataloader.py @@ -0,0 +1,56 @@ +import paddle +from paddle.io import Dataset + +from typing import Any, List, Sequence, Union + +from paddle_geometric.data import Batch, Dataset +from paddle_geometric.data.data import BaseData +from paddle_geometric.data.datapipes import DatasetAdapter + +def call(batch: List[Any]) -> Any: + batch = [list(x) for x in zip(*batch)] # transpose + for i in range(len(batch)): # 组Batch + batch[i] = Batch.from_data_list(batch[i]) + + batch0 = batch[0] + batch1 = batch[1] + + # Data解包到Tensor字典 + batch_atom_bond, batch_bond_angle = batch0, batch1 + x, edge_index, edge_attr, query_mask =batch_atom_bond.x,batch_atom_bond.edge_index,batch_atom_bond.edge_attr,batch_atom_bond.query_mask + ba_edge_index, ba_edge_attr = batch_bond_angle.edge_index,batch_bond_angle.edge_attr + batch_data = batch_atom_bond.batch + pos_gt = batch_atom_bond.peak_position + height_gt = batch_atom_bond.peak_height + num_gt = batch_atom_bond.peak_num + return \ + { + "x" : x , + "edge_index" : edge_index , + "edge_attr" : edge_attr , + "batch_data" : batch_data , + "ba_edge_index" : ba_edge_index , + "ba_edge_attr" : ba_edge_attr , + "query_mask" : query_mask + }, \ + { + "peak_number_gt" : num_gt , + "peak_position_gt": pos_gt , + "peak_height_gt" : height_gt + } + +class ECDFormerDataset_DataLoader(paddle.io.DataLoader): + def __init__( + self, + dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter], + batch_size: int = 1, + shuffle: bool = False, + **kwargs, + ): + super().__init__( + 
dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            collate_fn=call,
+            **kwargs,
+        )
\ No newline at end of file
diff --git a/ppmat/datasets/ECDFormerDataset/eval_func.py b/ppmat/datasets/ECDFormerDataset/eval_func.py
new file mode 100644
index 00000000..5bd5e1ab
--- /dev/null
+++ b/ppmat/datasets/ECDFormerDataset/eval_func.py
@@ -0,0 +1,147 @@
+import math
+import json
+
+import paddle
+import paddle.nn as nn
+import numpy as np
+
+from sklearn.metrics import mean_squared_error
+
+def Accuracy(pred, gt):
+    # the implementation of good sample Accuracy. We calculate multi-range accuracy
+    # pred: paddle.Tensor(batch), the tensor of prediction
+    # gt: paddle.Tensor(batch), the tensor of groundtruth
+    pred, gt = pred.tolist(), gt.tolist()
+    assert len(pred) == len(gt)
+
+    acc, acc_1, acc_2, acc_3 = 0, 0, 0, 0
+    for i in range(len(pred)):
+        if pred[i] == gt[i]:
+            acc += 1
+        elif abs(pred[i]-gt[i]) == 1:
+            acc_1 += 1
+        elif abs(pred[i]-gt[i]) == 2:
+            acc_2 += 1
+        elif abs(pred[i]-gt[i]) == 3:
+            acc_3 += 1
+        else: continue
+
+    return acc, acc_1, acc_2, acc_3
+
+def MAE(pred, gt):
+    # the implementation of Mean Absolute Error
+    # MAE == L1 loss
+    # pred: paddle.Tensor(batch, seq_len), the tensor of prediction
+    # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth
+
+    pred = paddle.to_tensor(pred, dtype=paddle.float32)
+    gt = paddle.to_tensor(gt, dtype=paddle.float32)
+    assert pred.shape == gt.shape
+    L1Loss = nn.L1Loss(reduction='mean')  # size_average is deprecated in paddle; use reduction instead
+    mae = L1Loss(pred, gt)
+
+    return mae
+
+
+def MAPE(pred, gt):
+    # the implementation of Mean Absolute Percentage Error
+    # pred: paddle.Tensor(batch, seq_len), the tensor of prediction
+    # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth
+
+    pred = paddle.to_tensor(pred, dtype=paddle.float32)
+    gt = paddle.to_tensor(gt, dtype=paddle.float32)
+    pred, gt = pred.numpy(), gt.numpy()  # paddle tensors need no .cpu(); call .numpy() directly
+    mape_loss = np.mean(
+        np.abs(np.divide(pred-gt, gt, 
out=np.zeros_like(gt), where=gt!=0))
+    ) * 100
+    return mape_loss
+
+def get_sequence_peak(sequence):
+    # input- seq: List
+    # output- peak_list contains peak position
+    peak_list = []
+    for i in range(1, len(sequence)-1):
+        if sequence[i-1]<sequence[i] and sequence[i]>sequence[i+1]:
+            peak_list.append(i)
+        if sequence[i-1]>sequence[i] and sequence[i]<sequence[i+1]:
+            peak_list.append(i)
+    return peak_list
+
+def Peak(pred, gt):
+    # NOTE(review): this span was corrupted during extraction (text between a
+    # '<' and the next '>' was stripped); reconstructed from the surviving
+    # code below -- verify against the upstream ECDFormer eval_func.py.
+    # the implementation of RMSE for Peak Number / Position and symbol accuracy
+    # pred: paddle.Tensor(batch, seq_len), the tensor of prediction
+    # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth
+
+    pred = paddle.to_tensor(pred, dtype=paddle.float32)
+    gt = paddle.to_tensor(gt, dtype=paddle.float32)
+    pred, gt = pred.numpy().tolist(), gt.numpy().tolist()
+    batch_size = len(pred)
+
+    number_gt, number_pred = [], []
+    position_gt, position_pred = [], []
+    total_peaks, correct_peak_symbols = 0, 0
+    for i in range(batch_size):
+        peaks_gt = get_sequence_peak(gt[i])
+        peaks_pred = get_sequence_peak(pred[i])
+        min_peaks_len = min(len(peaks_gt), len(peaks_pred))
+        max_peaks_len = max(len(peaks_gt), len(peaks_pred))
+        ## peak symbol
+        for j in range(min_peaks_len):
+            total_peaks += 1
+            if gt[i][peaks_gt[j]] * pred[i][peaks_pred[j]] > 0: correct_peak_symbols += 1
+        ## peak number
+        number_gt.append(len(peaks_gt))
+        number_pred.append(len(peaks_pred))
+        ## peak position
+        if len(peaks_gt) == min_peaks_len:
+            peaks_gt.extend([len(gt[i])] * (max_peaks_len-min_peaks_len) )
+        else: peaks_pred.extend([len(pred[i])] * (max_peaks_len-min_peaks_len) )
+        position_gt.extend(peaks_gt)
+        position_pred.extend(peaks_pred)
+
+    rmse_position = np.sqrt(mean_squared_error(position_gt, position_pred))
+    rmse_number = np.sqrt(mean_squared_error(number_gt, number_pred))
+    symbol_acc = correct_peak_symbols / total_peaks if total_peaks != 0 else 0.0
+    return symbol_acc, rmse_position, rmse_number
+
+def Peak_for_draw(pred, gt):
+    # the implementation of RMSE for Peak Range, Number, Position
+    # return the Peak information
+    # pred: paddle.Tensor(batch, seq_len), the tensor of prediction
+    # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth
+
+    pred = paddle.to_tensor(pred, dtype=paddle.float32)
+    gt = paddle.to_tensor(gt, dtype=paddle.float32)
+    pred, gt = pred.numpy().tolist(), gt.numpy().tolist()  # paddle tensors need no .cpu()
+    batch_size = len(pred)
+
+    # calculate RMSE Range
+    range_gt, range_pred = [], []
+    for i in range(batch_size):
+        range_gt.append(max(gt[i]) - min(gt[i]))
+        range_pred.append(max(pred[i]) - min(pred[i]))
+
+    # calculate RMSE for Peak number
+    number_gt, number_pred = [], []
+    position_gt, position_pred = [], []
+    for i in range(batch_size):
+        peaks_gt = get_sequence_peak(gt[i])
+        peaks_pred = get_sequence_peak(pred[i])
+
+        number_gt.append(len(peaks_gt))
+        number_pred.append(len(peaks_pred))
+
+        min_peaks_len = min(len(peaks_gt), len(peaks_pred))
+        max_peaks_len = 
max(len(peaks_gt), len(peaks_pred)) + + if len(peaks_gt) == min_peaks_len: + peaks_gt.extend([len(gt[i])] * (max_peaks_len-min_peaks_len) ) + else: + peaks_pred.extend([len(pred[i])] * (max_peaks_len-min_peaks_len) ) + position_gt.append(sum(peaks_gt)) + position_pred.append(sum(peaks_pred)) + + return dict( + peak_range = (range_gt, range_pred), + peak_num = (number_gt, number_pred), + peak_pos = (position_gt, position_pred), + ) \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/place_env.py b/ppmat/datasets/ECDFormerDataset/place_env.py new file mode 100644 index 00000000..6d06a504 --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/place_env.py @@ -0,0 +1,173 @@ +import paddle +import functools +from contextlib import contextmanager + +@contextmanager +def place_env(place): + """ + 上下文管理器,用于临时设置PaddlePaddle的运行设备 + + Args: + place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 + + 用法: + with place_env(paddle.CPUPlace()): + # 这里的代码在CPU上运行 + x = paddle.rand([2, 3]) + print(x) + + @place_env(paddle.CUDAPlace(0)) + def train(): + # 这个函数在GPU上运行 + pass + """ + # 保存当前的设备设置 + current_device = paddle.get_device() + + # 根据place类型设置设备 + if isinstance(place, paddle.CPUPlace): + paddle.set_device('cpu') + elif isinstance(place, paddle.CUDAPlace): + # 获取GPU设备ID + device_id = place.get_device_id() + paddle.set_device(f'gpu:{device_id}') + else: + raise ValueError(f"不支持的place类型: {type(place)}") + + try: + yield + finally: + # 恢复原来的设备设置 + paddle.set_device(current_device) + + +class PlaceEnv: + """ + 类版本的上下文管理器,也支持装饰器功能 + """ + + def __init__(self, place): + """ + 初始化PlaceEnv + + Args: + place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 + """ + self.place = place + self.original_device = None + + def __enter__(self): + """进入上下文时调用""" + # 保存当前的设备设置 + self.original_device = paddle.get_device() + + # 根据place类型设置设备 + if isinstance(self.place, paddle.CPUPlace): + paddle.set_device('cpu') + elif isinstance(self.place, paddle.CUDAPlace): + device_id = 
self.place.get_device_id() + paddle.set_device(f'gpu:{device_id}') + else: + raise ValueError(f"不支持的place类型: {type(self.place)}") + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """退出上下文时调用""" + # 恢复原来的设备设置 + if self.original_device is not None: + paddle.set_device(self.original_device) + + def __call__(self, func): + """ + 使实例可以作为装饰器使用 + + Args: + func: 要装饰的函数 + + Returns: + 装饰后的函数 + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + # 使用with语句来临时改变设备设置 + with self: + return func(*args, **kwargs) + return wrapper + + +# 为了兼容性,也可以保留函数版本的上下文管理器 +@contextmanager +def with_place_env(place): + """ + with_place_env的别名,与place_env功能相同 + """ + with place_env(place): + yield + + +# 使用示例 +if __name__ == "__main__": + # 测试with语句 + print("=== 测试with语句 ===") + print(f"当前设备: {paddle.get_device()}") + + with place_env(paddle.CPUPlace()): + print(f"with块内设备: {paddle.get_device()}") + x = paddle.rand([2, 3]) + print(f"创建的张量: {x}") + + print(f"with块外设备: {paddle.get_device()}") + + print("\n=== 测试类版本with语句 ===") + with PlaceEnv(paddle.CPUPlace()): + print(f"with块内设备: {paddle.get_device()}") + y = paddle.ones([2, 3]) + print(f"创建的张量: {y}") + + print(f"with块外设备: {paddle.get_device()}") + + # 测试装饰器功能 + print("\n=== 测试装饰器功能 ===") + + @PlaceEnv(paddle.CPUPlace()) + def cpu_function(): + """这个函数会在CPU上运行""" + print(f"函数内设备: {paddle.get_device()}") + return paddle.rand([2, 2]) + + # 检查是否有GPU可用 + if paddle.device.cuda.device_count() > 0: + @PlaceEnv(paddle.CUDAPlace(0)) + def gpu_function(): + """这个函数会在GPU上运行""" + print(f"函数内设备: {paddle.get_device()}") + return paddle.rand([2, 2]) + + # 调用装饰后的函数 + print("调用cpu_function:") + result_cpu = cpu_function() + print(f"函数执行后设备: {paddle.get_device()}") + print(f"结果: {result_cpu}") + + if paddle.device.cuda.device_count() > 0: + print("\n调用gpu_function:") + result_gpu = gpu_function() + print(f"函数执行后设备: {paddle.get_device()}") + print(f"结果: {result_gpu}") + + print("\n=== 测试多层嵌套 ===") + print(f"初始设备: 
{paddle.get_device()}") + + with PlaceEnv(paddle.CPUPlace()): + print(f"第一层with内设备: {paddle.get_device()}") + + if paddle.device.cuda.device_count() > 0: + with PlaceEnv(paddle.CUDAPlace(0)): + print(f"第二层with内设备: {paddle.get_device()}") + z = paddle.rand([2, 2]) + print(f"创建的张量: {z}") + + print(f"回到第一层with设备: {paddle.get_device()}") + + print(f"最终设备: {paddle.get_device()}") \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/util_func.py b/ppmat/datasets/ECDFormerDataset/util_func.py new file mode 100644 index 00000000..010ef951 --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/util_func.py @@ -0,0 +1,43 @@ +def has_element_in_range(lst, lower_bound, upper_bound): + """ + 检查给定列表 lst 中是否存在元素在指定的区间 [lower_bound, upper_bound] 内。 + + 参数: + - lst: 输入的列表 + - lower_bound: 区间的下界 + - upper_bound: 区间的上界 + + 返回: + - 存在元素在指定区间内时返回 True, 否则返回 False + """ + for element in lst: + if lower_bound <= element <= upper_bound: + return True + return False + + +def normalize_func(src_list, norm_range=[-100, 100]): + # lihao implecation for list normalization + # input: src_list, normalization range + # output: tgt_list after normalization + + src_max, src_min = max(src_list), min(src_list) + norm_min, norm_max = norm_range[0], norm_range[1] + if src_max == 0: src_max = 1 + if src_min == 0: src_min = -1 + + tgt_list = [] + for i in range(len(src_list)): + if src_list[i] >= 0: + tgt_list.append(src_list[i] * norm_max / src_max) + else: + tgt_list.append(src_list[i] * norm_min / src_min) + + assert len(src_list) == len(tgt_list) + return tgt_list + +if __name__ == "__main__": + src = [-50, 0, 1, 50] + norm_range = [-100, 100] + tgt = normalize_func(src, norm_range) + print(tgt) diff --git a/ppmat/datasets/IRDataset/__init__.py b/ppmat/datasets/IRDataset/__init__.py new file mode 100644 index 00000000..0cb43d2d --- /dev/null +++ b/ppmat/datasets/IRDataset/__init__.py @@ -0,0 +1,439 @@ +# IRDataset.py +""" +IR光谱预测数据集模块 +支持预加载的npy文件,包含缓存机制,默认使用100样本的小数据集 +""" + +import 
os +import numpy as np +import pandas as pd +import paddle +from paddle.io import Dataset, DataLoader +from paddle_geometric.data import Data +from tqdm import tqdm +import pickle +import json +import warnings + +import rdkit +from rdkit import Chem +from rdkit.Chem import AllChem + +from .compound_tools import mol_to_geognn_graph_data_MMFF3d +from .compound_tools import get_atom_feature_dims, get_bond_feature_dims +from .colored_tqdm import ColoredTqdm as tqdm +from .place_env import PlaceEnv + +# ----------------常量定义---------------- +ATOM_ID_NAMES = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", +] + +BOND_ID_NAMES = ["bond_dir", "bond_type", "is_in_ring"] + +BOND_ANGLE_FLOAT_NAMES = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] + +# 获取特征维度 +FULL_ATOM_FEATURE_DIMS = get_atom_feature_dims(ATOM_ID_NAMES) +FULL_BOND_FEATURE_DIMS = get_bond_feature_dims(BOND_ID_NAMES) + +# IR光谱参数 +IR_WAVELENGTH_MIN = 500 +IR_WAVELENGTH_MAX = 4000 +IR_STEP = 100 # 波数步长,用于离散化 +IR_NUM_POSITION_CLASSES = (IR_WAVELENGTH_MAX - IR_WAVELENGTH_MIN) // IR_STEP # 36 + +# 默认最大峰数 +DEFAULT_MAX_PEAKS = 15 + + +def get_key_padding_mask(tokens): + """生成query padding mask""" + key_padding_mask = paddle.zeros(tokens.shape) + key_padding_mask[tokens == -1] = -paddle.inf + return key_padding_mask + + +def x_bin_position(real_x, distance=IR_STEP): + """将实际波数转换为箱ID""" + return int((real_x - IR_WAVELENGTH_MIN) / distance) + + +def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): + """ + 从原始特征构建IR图数据 + 类似于ECD的Construct_dataset,但针对IR任务 + + Args: + dataset: list of dict, 每个元素是分子的info字典 + data_index: list or array, 索引列表 + descriptor_path: str, 描述符文件路径(IR可能不需要,保留接口) + + Returns: + graph_atom_bond: list of Data, atom-bond图 + graph_bond_angle: list of Data, bond-angle图 + """ + graph_atom_bond = [] + graph_bond_angle = [] + + # IR任务可能不需要描述符,但如果需要可以加载 + all_descriptor = None + if 
descriptor_path and os.path.exists(descriptor_path): + all_descriptor = np.load(descriptor_path) + + for i in tqdm(range(len(dataset)), desc="Constructing IR graphs"): + data = dataset[i] + + # 收集原子特征 + atom_feature = [] + for name in ATOM_ID_NAMES: + if name in data: + atom_feature.append(data[name]) + else: + # 如果某些特征缺失,用0填充 + # 注意:根据实际数据调整 + if i == 0: # 只在第一次警告 + warnings.warn(f"Feature {name} not found in data, using zeros") + num_atoms = data.get('atomic_num', np.zeros(1)).shape[0] + atom_feature.append(np.zeros(num_atoms)) + + # 收集键特征 + bond_feature = [] + for name in BOND_ID_NAMES: + if name in data: + bond_feature.append(data[name]) + else: + if i == 0: + warnings.warn(f"Bond feature {name} not found, using zeros") + num_bonds = data.get('bond_dir', np.zeros(1)).shape[0] + bond_feature.append(np.zeros(num_bonds)) + + # 转换为Tensor + atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') + bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') + + # 键长特征(IR可能不需要,但保留) + bond_float_feature = paddle.to_tensor(data.get('bond_length', np.zeros(data['edges'].shape[0])).astype(np.float32)) + + # 键角特征(IR可能不需要,但保留) + bond_angle_feature = paddle.to_tensor(data.get('bond_angle', np.zeros(data.get('BondAngleGraph_edges', np.zeros((0,2))).shape[0])).astype(np.float32)) + + # 边索引 + edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') + bond_index = paddle.to_tensor(data.get('BondAngleGraph_edges', np.zeros((0,2))).T, dtype='int64') + + + data_index_int = paddle.to_tensor(np.array(int(data_index[i])), dtype='int64') + + # 获取原子数(键角图的节点数) + num_atoms = atom_feature.shape[0] + + # 合并键特征 + bond_feature = paddle.concat( + [bond_feature.astype(bond_float_feature.dtype), + bond_float_feature.reshape([-1, 1])], + axis=1 + ) + + # 处理键角特征(如果有) + if bond_angle_feature.shape[0] > 0: + # 如果有描述符,可以添加 + if all_descriptor is not None and i < all_descriptor.shape[0]: + TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 
100 + RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] + RPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822] + MDEC = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568] + MATS = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457] + + bond_angle_feature = paddle.concat( + [bond_angle_feature.reshape([-1, 1]), TPSA.reshape([-1, 1])], axis=1 + ) + bond_angle_feature = paddle.concat([bond_angle_feature, RASA.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, RPSA.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, MDEC.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, MATS.reshape([-1, 1])], axis=1) + else: + # 如果没有描述符,直接reshape + bond_angle_feature = bond_angle_feature.reshape([-1, 1]) + + # 创建Data对象 + data_atom_bond = Data( + x=atom_feature, + edge_index=edge_index, + edge_attr=bond_feature, + data_index=data_index_int, + ) + + data_bond_angle = Data( + edge_index=bond_index, + edge_attr=bond_angle_feature if bond_angle_feature.shape[0] > 0 else paddle.zeros([0, 1]), + num_nodes=num_atoms, # 键角图的节点数等于原子数 + ) + + graph_atom_bond.append(data_atom_bond) + graph_bond_angle.append(data_bond_angle) + + return graph_atom_bond, graph_bond_angle + + +def read_ir_spectra_by_ids(sample_path, index_all, max_peak=DEFAULT_MAX_PEAKS): + """ + 按需读取IR光谱文件 + + Args: + sample_path: str, IR光谱JSON文件目录 + index_all: list, 需要读取的文件ID列表 + max_peak: int, 最大峰数 + + Returns: + ir_final_list: list of dict, 包含峰值信息的字典列表 + """ + ir_final_list = [] + + for fileid in tqdm(index_all, desc="Reading IR spectra by ID"): + filepath = os.path.join(sample_path, f"{fileid}.json") + + try: + with open(filepath, 'r') as f: + raw_ir_info = json.load(f) + + ir_x = raw_ir_info['x'] + ir_y = raw_ir_info['y_40'] + + from scipy.signal import find_peaks + peaks_raw, _ = find_peaks(x=ir_y, height=0.1, distance=100) + peaks_raw = 
peaks_raw.tolist()
+
+            # 处理峰值(与原作完全一致)
+            peak_num = min(len(peaks_raw), max_peak)
+
+            if peak_num > 0:
+                if len(peaks_raw) > max_peak:
+                    peaks = peaks_raw[len(peaks_raw)-max_peak:]
+                else:
+                    peaks = peaks_raw
+
+                peak_position_list = [x_bin_position(ir_x[i]) for i in peaks]
+                peak_height_list = [ir_y[i] for i in peaks]
+            else:
+                peak_position_list = []
+                peak_height_list = []
+
+            peak_position_list = peak_position_list + [-1] * (max_peak - len(peak_position_list))
+            peak_height_list = peak_height_list + [-1] * (max_peak - len(peak_height_list))
+
+            query_padding_mask = get_key_padding_mask(paddle.to_tensor(peak_position_list))
+
+            tmp_dict = {
+                'id': fileid,
+                'seq_40': ir_y,
+                'peak_num': peak_num,
+                'peak_position': peak_position_list,
+                'peak_height': peak_height_list,
+                'query_mask': query_padding_mask.unsqueeze(0),
+            }
+            ir_final_list.append(tmp_dict)
+
+        except Exception as e:
+            warnings.warn(f"Error processing {fileid}.json: {e}")
+            continue
+
+    ir_final_list.sort(key=lambda x: x['id'])
+    return ir_final_list
+
+
+def load_ir_meta_file(mode='100'):
+    """
+    加载预生成的IR元数据文件
+
+    Args:
+        mode: str, 可选 '100', '10000', 'all'
+
+    Returns:
+        dataset_all: list, 图特征数据
+        smiles_all: list, SMILES列表
+        index_all: list, 索引列表
+    """
+    valid_modes = {'100', '10000', 'all'}
+    if mode not in valid_modes:
+        warnings.warn(f"Invalid mode {mode}, using '100'")
+        mode = '100'
+
+    filename = f'ir_column_charity_{mode}.npy'
+
+    if not os.path.exists(filename):
+        # 尝试在dataset/IR目录下查找
+        alt_path = os.path.join('dataset', 'IR', filename)
+        if os.path.exists(alt_path):
+            filename = alt_path
+        else:
+            raise FileNotFoundError(f"IR meta file {filename} not found")
+
+    print(f"Loading IR meta file: {filename}")
+    data = np.load(filename, allow_pickle=True).item()
+
+    return data['dataset_all'], data['smiles_all'], data['index_all']
+
+
+class IRDataset(Dataset):
+    """
+    IR光谱预测数据集类
+
+    支持三种预加载模式:
+    - '100': 100个样本的小数据集(默认,用于快速测试)
+    - '10000': 1万个样本的中等数据集
+    - 'all': 全部样本(可能很大)
+
+    
class IRDataset(Dataset):
    """IR spectrum prediction dataset.

    Supported preload modes:
      - '100'   : 100-sample subset (default; quick tests)
      - '10000' : 10k-sample subset
      - 'all'   : every sample (may be very large)

    Key optimisation: IR JSON spectra are read on demand — only the files
    named in the meta file's ``index_all`` are opened.
    """

    # Class-level cache shared by all instances, keyed by
    # (path, mode, use_geometry_enhanced).
    _cache = {}

    @PlaceEnv(paddle.CPUPlace())  # build all tensors on CPU regardless of global device
    def __init__(self,
                 path: str = "dataset/IR",
                 mode: str = '100',
                 use_geometry_enhanced: bool = True,
                 force_reload: bool = False,
                 cache: bool = True):
        """Load (or fetch from the class cache) the IR graph dataset.

        Args:
            path: dataset root directory.
            mode: preload mode, one of '100', '10000', 'all'.
            use_geometry_enhanced: recorded on the instance and used in the
                cache key; not otherwise read in this constructor.
            force_reload: rebuild even if a cached copy exists.
            cache: store the built dataset in the class-level cache.
        """
        self.path = path
        self.mode = mode
        self.use_geometry_enhanced = use_geometry_enhanced
        self.cache_enabled = cache

        cache_key = f"{path}_{mode}_{use_geometry_enhanced}"

        if cache and not force_reload and cache_key in self._cache:
            print(f"Loading IR dataset from cache: {mode}")
            cached_data = self._cache[cache_key]
            self.graph_atom_bond = cached_data['atom_bond']
            self.graph_bond_angle = cached_data['bond_angle']
            self.smiles_list = cached_data.get('smiles', [])
            return

        print(f"First-time loading IR dataset (mode={mode})...")

        # 1. Load the meta file (graph features, SMILES, sample ids).
        meta_path = os.path.join(path, f'ir_column_charity_{mode}.npy')
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"IR meta file {meta_path} not found")

        data = np.load(meta_path, allow_pickle=True).item()
        dataset_all = data['dataset_all']
        smiles_all = data['smiles_all']
        index_all = data['index_all']

        print(f"Loaded meta data: {len(index_all)} samples")

        # 2. Read only the IR spectra referenced by index_all.
        spectra_path = os.path.join(path, 'qm9_ir_spec')
        self.ir_sequences = read_ir_spectra_by_ids(spectra_path, index_all)

        print(f"Loaded {len(self.ir_sequences)} IR spectra")

        # 3. Build the atom-bond and bond-angle graphs.
        descriptor_path = os.path.join(path, 'descriptor_all_column.npy')
        if not os.path.exists(descriptor_path):
            descriptor_path = None

        total_graph_atom_bond, total_graph_bond_angle = Construct_IR_Dataset(
            dataset_all, index_all, descriptor_path
        )

        # 4. Attach spectrum / peak targets to each atom-bond graph.
        # NOTE(review): positional pairing assumes ir_sequences and
        # index_all are both sorted by sample id — confirm upstream.
        self.graph_atom_bond = []
        self.graph_bond_angle = []
        self.smiles_list = []

        for i, itm in enumerate(self.ir_sequences):
            atom_bond = total_graph_atom_bond[i]

            atom_bond.sequence = paddle.to_tensor([itm['seq_40']])
            atom_bond.ir_id = paddle.to_tensor(int(itm['id']))
            atom_bond.peak_num = paddle.to_tensor([itm['peak_num']])
            atom_bond.peak_position = paddle.to_tensor([itm['peak_position']])
            atom_bond.peak_height = paddle.to_tensor([itm['peak_height']])
            atom_bond.query_mask = itm['query_mask']

            self.graph_atom_bond.append(atom_bond)
            self.graph_bond_angle.append(total_graph_bond_angle[i])
            self.smiles_list.append(smiles_all[i])

        print(f"Final dataset size: {len(self.graph_atom_bond)}")

        if cache:
            self._cache[cache_key] = {
                'atom_bond': self.graph_atom_bond,
                'bond_angle': self.graph_bond_angle,
                'smiles': self.smiles_list,
            }

    def __len__(self):
        """Number of graphs in the dataset."""
        return len(self.graph_atom_bond)

    def __getitem__(self, idx):
        """Return the (atom_bond_graph, bond_angle_graph) pair at *idx*."""
        return (self.graph_atom_bond[idx], self.graph_bond_angle[idx])
edge_attr, query_mask =batch_atom_bond.x,batch_atom_bond.edge_index,batch_atom_bond.edge_attr,batch_atom_bond.query_mask + ba_edge_index, ba_edge_attr = batch_bond_angle.edge_index,batch_bond_angle.edge_attr + batch_data = batch_atom_bond.batch + pos_gt = batch_atom_bond.peak_position + height_gt = batch_atom_bond.peak_height + num_gt = batch_atom_bond.peak_num + return \ + { + "x" : x , + "edge_index" : edge_index , + "edge_attr" : edge_attr , + "batch_data" : batch_data , + "ba_edge_index" : ba_edge_index , + "ba_edge_attr" : ba_edge_attr , + "query_mask" : query_mask + }, \ + { + "peak_number_gt" : num_gt , + "peak_position_gt": pos_gt , + "peak_height_gt" : height_gt + } + return batched_atom_bond, batched_bond_angle + diff --git a/ppmat/datasets/IRDataset/colored_tqdm.py b/ppmat/datasets/IRDataset/colored_tqdm.py new file mode 100644 index 00000000..c6b9cff0 --- /dev/null +++ b/ppmat/datasets/IRDataset/colored_tqdm.py @@ -0,0 +1,76 @@ +from tqdm import tqdm +import time +import os;os.system("") #兼容windows + +def hex_to_ansi(hex_color: str, background: bool = False) -> str: + """ + 将十六进制颜色转换为ANSI转义序列 + + Args: + hex_color: 十六进制颜色,如 '#dda0a0' 或 'dda0a0' + background: True表示背景色,False表示前景色 + + Returns: + ANSI转义序列字符串,如 '\033[38;2;221;160;160m' + + Example: + >>> print(f"{hex_to_ansi('#dda0a0')}Hello{hex_to_ansi('#000000')} World") + >>> print(f"{hex_to_ansi('dda0a0', background=True)}背景色{hex_to_ansi.reset()}") + """ + # 移除#号并转换为小写 + hex_color = hex_color.lower().lstrip('#') + + # 处理简写形式 (#fff -> ffffff) + if len(hex_color) == 3: + hex_color = ''.join([c * 2 for c in hex_color]) + + # 转换为RGB值 + r = int(hex_color[0:2], 16) + g = int(hex_color[2:4], 16) + b = int(hex_color[4:6], 16) + + # ANSI真彩色序列 + # 38;2;R;G;B 为前景色,48;2;R;G;B 为背景色 + code = 48 if background else 38 + return f'\033[{code};2;{r};{g};{b}m' + +def rgb_to_ansi(r: int, g: int, b: int, background: bool = False) -> str: + """RGB值直接转ANSI""" + code = 48 if background else 38 + return 
class ColoredTqdm(tqdm):
    """tqdm subclass whose bar colour fades from start_color to end_color
    as progress advances (true-colour ANSI)."""

    def __init__(self, *args,
                 start_color=(221, 160, 160),  # RGB: #DDA0A0
                 end_color=(160, 221, 160),    # RGB: #A0DDA0
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.start_color = start_color
        self.end_color = end_color

    def get_current_color(self):
        """Return the colour interpolated for the current progress as a
        6-digit lowercase hex string (no leading '#')."""
        # BUG FIX: tqdm sets `total` to None for iterables of unknown
        # length, and `None > 0` raises TypeError on Python 3 — guard on
        # truthiness instead (None and 0 both map to progress 0).
        progress = self.n / self.total if self.total else 0
        current_rgb = tuple(
            int(start + (end - start) * progress)
            for start, end in zip(self.start_color, self.end_color)
        )
        # Pack (r, g, b) into a single 24-bit value, then format as hex.
        result = current_rgb[0] * 16 ** 4 \
               + current_rgb[1] * 16 ** 2 \
               + current_rgb[2] * 16 ** 0
        return "%06x" % result

    def update(self, n=1):
        """Advance the bar and re-style it with the interpolated colour."""
        super().update(n)
        style = hex_to_ansi(self.get_current_color())
        self.bar_format = f'{{l_bar}}{style}{{bar}}{hex_to_ansi.reset}{{r_bar}}'
        self.refresh()
"[NX3][CX2]#[NX1]", + "[#6][CX3](=O)[OX2H0][#6]", + "[#6][CX3](=O)[#6]", + "[OD2]([#6])[#6]", + # H + "[H]", + "[!#1]", + "[H+]", + "[+H]", + "[!H]", + # N + "[NX3;H2,H1;!$(NC=O)]", + "[NX3][CX3]=[CX3]", + "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]", + "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]", + "[NX3][$(C=C),$(cc)]", + "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]", + "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]", + "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH3X4]", + "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]", + "[CH2X4][CX3](=[OX1])[NX3H2]", + "[CH2X4][CX3](=[OX1])[OH0-,OH]", + "[CH2X4][SX2H,SX1H0-]", + "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]", + "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]", + "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\ +[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1", + "[CHX4]([CH3X4])[CH2X4][CH3X4]", + "[CH2X4][CHX4]([CH3X4])[CH3X4]", + "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]", + "[CH2X4][CH2X4][SX2][CH3X4]", + "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1", + "[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH2X4][OX2H]", + "[NX3][CX3]=[SX1]", + "[CHX4]([CH3X4])[OX2H]", + "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12", + "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1", + "[CHX4]([CH3X4])[CH3X4]", + "N[CX4H2][CX3](=[OX1])[O,N]", + "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]", + "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]", + "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]", + "[#7]", + "[NX2]=N", + "[NX2]=[NX2]", + "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]", + "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]", + "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]", + "[NX3][NX3]", + "[NX3][NX2]=[*]", + "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]", + 
"[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]", + "[NX3+]=[CX3]", + "[CX3](=[OX1])[NX3H][CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])", + "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]", + "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]", + "[NX1]#[CX2]", + "[CX1-]#[NX2+]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[NX2]=[OX1]", + "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]", + # O + "[OX2H]", + "[#6][OX2H]", + "[OX2H][CX3]=[OX1]", + "[OX2H]P", + "[OX2H][#6X3]=[#6]", + "[OX2H][cX3]:[c]", + "[OX2H][$(C=C),$(cc)]", + "[$([OH]-*=[!#6])]", + "[OX2,OX1-][OX2,OX1-]", + # P + "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\ +$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\ +,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]", + "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\ +$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\ +$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]", + # S + "[S-][CX3](=S)[#6]", + "[#6X3](=[SX1])([!N])[!N]", + "[SX2]", + "[#16X2H]", + "[#16!H0]", + "[#16X2H0]", + "[#16X2H0][!#16]", + "[#16X2H0][#16X2H0]", + "[#16X2H0][!#16].[#16X2H0][!#16]", + "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]", + "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]", + 
"[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]", + "[SX4](C)(C)(=O)=N", + "[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]", + "[$([#16X3]=[OX1]),$([#16X3+][OX1-])]", + "[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]", + "[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]", + "[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]", + "[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]", + "[#16X2][OX2H,OX1H0-]", + "[#16X2][OX2H0]", + # X + "[#6][F,Cl,Br,I]", + "[F,Cl,Br,I]", + "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]", + ] + + +def get_gasteiger_partial_charges(mol, n_iter=12): + """ + Calculates list of gasteiger partial charges for each atom in mol object. + Args: + mol: rdkit mol object. + n_iter(int): number of iterations. Default 12. + Returns: + list of computed partial charges for each atom. + """ + Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter, + throwOnParamFailure=True) + partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in + mol.GetAtoms()] + return partial_charges + + +def create_standardized_mol_id(smiles): + """ + Args: + smiles: smiles sequence. + Returns: + inchi. + """ + if check_smiles_validity(smiles): + # remove stereochemistry + smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles), + isomericSmiles=False) + mol = AllChem.MolFromSmiles(smiles) + if not mol is None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21 + if '.' 
def check_smiles_validity(smiles):
    """Return True iff *smiles* parses to an RDKit Mol object."""
    try:
        return Chem.MolFromSmiles(smiles) is not None
    except Exception:
        return False


def split_rdkit_mol_obj(mol):
    """Split a Mol that may contain several species into a list of Mols.

    The molecule is round-tripped through its isomeric SMILES, which
    separates disconnected species with '.'; each valid fragment is
    re-parsed into its own Mol.
    """
    fragments = AllChem.MolToSmiles(mol, isomericSmiles=True).split('.')
    return [
        AllChem.MolFromSmiles(frag)
        for frag in fragments
        if check_smiles_validity(frag)
    ]


def get_largest_mol(mol_list):
    """Return the Mol with the largest atom count.

    Ties are broken by the first occurrence, matching list.index semantics.

    Args:
        mol_list: list of RDKit Mol objects.
    """
    return max(mol_list, key=lambda m: len(m.GetAtoms()))
def safe_index(alist, elem):
    """Return the index of *elem* in *alist*; if absent, the last index.

    Feature vocabularies end with a 'misc' bucket, so unknown values are
    mapped to that final slot.
    """
    for idx, value in enumerate(alist):
        if value == elem:
            return idx
    return len(alist) - 1


def get_atom_feature_dims(list_acquired_feature_names):
    """Vocabulary size of each requested atom feature, in request order."""
    return [
        len(CompoundKit.atom_vocab_dict[name])
        for name in list_acquired_feature_names
    ]


def get_bond_feature_dims(list_acquired_feature_names):
    """Vocabulary size of each requested bond feature, plus one extra
    slot per feature reserved for self-loop edges."""
    return [
        len(CompoundKit.bond_vocab_dict[name]) + 1
        for name in list_acquired_feature_names
    ]
    # Float-valued (continuous) atom features, as opposed to the
    # categorical vocabularies above.
    atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
    # bond_float_feats= ["bond_length", "bond_angle"] # optional

    ### functional groups
    # Daylight functional-group SMARTS, pre-compiled once at class load.
    day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
    day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]

    # Fingerprint bit lengths.
    morgan_fp_N = 200
    morgan2048_fp_N = 2048
    maccs_fp_N = 167

    period_table = Chem.GetPeriodicTable()

    ### atom

    @staticmethod
    def get_atom_value(atom, name):
        """Return the raw RDKit value of feature *name* for *atom*.

        Raises:
            ValueError: if *name* is not a recognised feature name.
        """
        if name == 'atomic_num':
            return atom.GetAtomicNum()
        elif name == 'chiral_tag':
            return atom.GetChiralTag()
        elif name == 'degree':
            return atom.GetDegree()
        elif name == 'explicit_valence':
            return atom.GetExplicitValence()
        elif name == 'formal_charge':
            return atom.GetFormalCharge()
        elif name == 'hybridization':
            return atom.GetHybridization()
        elif name == 'implicit_valence':
            return atom.GetImplicitValence()
        elif name == 'is_aromatic':
            return int(atom.GetIsAromatic())
        elif name == 'mass':
            # Truncated to int here; atom_to_feat_vector uses the float
            # mass instead.
            return int(atom.GetMass())
        elif name == 'total_numHs':
            return atom.GetTotalNumHs()
        elif name == 'num_radical_e':
            return atom.GetNumRadicalElectrons()
        elif name == 'atom_is_in_ring':
            return int(atom.IsInRing())
        elif name == 'valence_out_shell':
            return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
        else:
            raise ValueError(name)

    @staticmethod
    def get_atom_feature_id(atom, name):
        """Map an atom's feature value to its vocabulary index (unknown
        values fall into the trailing 'misc' slot via safe_index)."""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))

    @staticmethod
    def get_atom_feature_size(name):
        """Vocabulary size of atom feature *name*."""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return len(CompoundKit.atom_vocab_dict[name])

    ### bond

    @staticmethod
    def get_bond_value(bond, name):
        """Return the raw RDKit value of feature *name* for *bond*.

        Raises:
            ValueError: if *name* is not a recognised feature name.
        """
        if name == 'bond_dir':
            return bond.GetBondDir()
        elif name == 'bond_type':
            return bond.GetBondType()
        elif name == 'is_in_ring':
            return int(bond.IsInRing())
        elif name == 'is_conjugated':
            return int(bond.GetIsConjugated())
        elif name == 'bond_stereo':
            return bond.GetStereo()
        else:
            raise ValueError(name)

    @staticmethod
    def get_bond_feature_id(bond, name):
        """Map a bond's feature value to its vocabulary index."""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))

    @staticmethod
    def get_bond_feature_size(name):
        """Vocabulary size of bond feature *name*."""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return len(CompoundKit.bond_vocab_dict[name])

    ### fingerprint

    @staticmethod
    def get_morgan_fingerprint(mol, radius=2):
        """Morgan fingerprint as a 0/1 int list of length morgan_fp_N."""
        nBits = CompoundKit.morgan_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_morgan2048_fingerprint(mol, radius=2):
        """Morgan fingerprint as a 0/1 int list of length morgan2048_fp_N."""
        nBits = CompoundKit.morgan2048_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_maccs_fingerprint(mol):
        """MACCS keys fingerprint as a 0/1 int list."""
        fp = AllChem.GetMACCSKeysFingerprint(mol)
        return [int(b) for b in fp.ToBitString()]

    ### functional groups

    @staticmethod
    def get_daylight_functional_group_counts(mol):
        """Count substructure matches of every Daylight functional group."""
        fg_counts = []
        for fg_mol in CompoundKit.day_light_fg_mo_list:
            sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True)
            fg_counts.append(len(sub_structs))
        return fg_counts

    @staticmethod
    def get_ring_size(mol):
        """For each atom, count its membership in rings of size 3..8.

        Returns:
            list: per-atom lists of 6 counts (ring sizes 3-8, each count
            capped at 9).
        """
        rings = mol.GetRingInfo()
        rings_info = []
        for r in rings.AtomRings():
            rings_info.append(r)
        ring_list = []
        for atom in mol.GetAtoms():
            atom_result = []
            for ringsize in range(3, 9):
                num_of_ring_at_ringsize = 0
                for r in rings_info:
                    if len(r) == ringsize and atom.GetIdx() in r:
                        num_of_ring_at_ringsize += 1
                if num_of_ring_at_ringsize > 8:
                    num_of_ring_at_ringsize = 9
                atom_result.append(num_of_ring_at_ringsize)

            ring_list.append(atom_result)
        return ring_list

    @staticmethod
    def atom_to_feat_vector(atom):
        """Build the per-atom feature dict: vocabulary indices for the
        categorical features (via safe_index; unknowns map to the trailing
        'misc' slot) plus three raw float features.

        Requires Gasteiger charges to have been computed on the parent
        molecule beforehand (see get_atom_names / check_partial_charge).
        """
        atom_names = {
            "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
            "chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()),
            "degree": safe_index(CompoundKit.atom_vocab_dict["degree"], atom.GetTotalDegree()),
            "explicit_valence": safe_index(CompoundKit.atom_vocab_dict["explicit_valence"], atom.GetExplicitValence()),
            "formal_charge": safe_index(CompoundKit.atom_vocab_dict["formal_charge"], atom.GetFormalCharge()),
            "hybridization": safe_index(CompoundKit.atom_vocab_dict["hybridization"], atom.GetHybridization()),
            "implicit_valence": safe_index(CompoundKit.atom_vocab_dict["implicit_valence"], atom.GetImplicitValence()),
            "is_aromatic": safe_index(CompoundKit.atom_vocab_dict["is_aromatic"], int(atom.GetIsAromatic())),
            "total_numHs": safe_index(CompoundKit.atom_vocab_dict["total_numHs"], atom.GetTotalNumHs()),
            'num_radical_e': safe_index(CompoundKit.atom_vocab_dict['num_radical_e'], atom.GetNumRadicalElectrons()),
            'atom_is_in_ring': safe_index(CompoundKit.atom_vocab_dict['atom_is_in_ring'], int(atom.IsInRing())),
            'valence_out_shell': safe_index(CompoundKit.atom_vocab_dict['valence_out_shell'],
                                            CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())),
            'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()),
            'partial_charge': CompoundKit.check_partial_charge(atom),
            'mass': atom.GetMass(),
        }
        return atom_names
+ """ + atom_features_dicts = [] + Chem.rdPartialCharges.ComputeGasteigerCharges(mol) + for i, atom in enumerate(mol.GetAtoms()): + atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom)) + + ring_list = CompoundKit.get_ring_size(mol) + for i, atom in enumerate(mol.GetAtoms()): + atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0]) + atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1]) + atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2]) + atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3]) + atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4]) + atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5]) + + return atom_features_dicts + + @staticmethod + def check_partial_charge(atom): + """tbd""" + pc = atom.GetDoubleProp('_GasteigerCharge') + if pc != pc: + # unsupported atom, replace nan with 0 + pc = 0 + if pc == float('inf'): + # max 4 for other atoms, set to 10 here if inf is get + pc = 10 + return pc + + +class Compound3DKit(object): + """the 3Dkit of Compound""" + + @staticmethod + def get_atom_poses(mol, conf): + """tbd""" + atom_poses = [] + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetAtomicNum() == 0: + return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms()) + pos = conf.GetAtomPosition(i) + atom_poses.append([pos.x, pos.y, pos.z]) + return atom_poses + + @staticmethod + def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False): + """the atoms of mol will be changed in some cases.""" + conf = mol.GetConformer() + atom_poses = 
Compound3DKit.get_atom_poses(mol, conf) + return mol,atom_poses + # try: + # new_mol = Chem.AddHs(mol) + # res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs) + # ### MMFF generates multiple conformations + # res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) + # new_mol = Chem.RemoveHs(new_mol) + # index = np.argmin([x[1] for x in res]) + # energy = res[index][1] + # conf = new_mol.GetConformer(id=int(index)) + # except: + # new_mol = mol + # AllChem.Compute2DCoords(new_mol) + # energy = 0 + # conf = new_mol.GetConformer() + # + # atom_poses = Compound3DKit.get_atom_poses(new_mol, conf) + # if return_energy: + # return new_mol, atom_poses, energy + # else: + # return new_mol, atom_poses + + @staticmethod + def get_2d_atom_poses(mol): + """get 2d atom poses""" + AllChem.Compute2DCoords(mol) + conf = mol.GetConformer() + atom_poses = Compound3DKit.get_atom_poses(mol, conf) + return atom_poses + + @staticmethod + def get_bond_lengths(edges, atom_poses): + """get bond lengths""" + bond_lengths = [] + for src_node_i, tar_node_j in edges: + bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i])) + bond_lengths = np.array(bond_lengths, 'float32') + return bond_lengths + + @staticmethod + def get_superedge_angles(edges, atom_poses, dir_type='HT'): + """get superedge angles""" + + def _get_vec(atom_poses, edge): + return atom_poses[edge[1]] - atom_poses[edge[0]] + + def _get_angle(vec1, vec2): + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) + if norm1 == 0 or norm2 == 0: + return 0 + vec1 = vec1 / (norm1 + 1e-5) # 1e-5: prevent numerical errors + vec2 = vec2 / (norm2 + 1e-5) + angle = np.arccos(np.dot(vec1, vec2)) + return angle + + E = len(edges) + edge_indices = np.arange(E) + super_edges = [] + bond_angles = [] + bond_angle_dirs = [] + for tar_edge_i in range(E): + tar_edge = edges[tar_edge_i] + if dir_type == 'HT': + src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]] + elif dir_type == 'HH': + src_edge_indices 
= edge_indices[edges[:, 1] == tar_edge[1]] + else: + raise ValueError(dir_type) + for src_edge_i in src_edge_indices: + if src_edge_i == tar_edge_i: + continue + src_edge = edges[src_edge_i] + src_vec = _get_vec(atom_poses, src_edge) + tar_vec = _get_vec(atom_poses, tar_edge) + super_edges.append([src_edge_i, tar_edge_i]) + angle = _get_angle(src_vec, tar_vec) + bond_angles.append(angle) + bond_angle_dirs.append(src_edge[1] == tar_edge[0]) # H -> H or H -> T + + if len(super_edges) == 0: + super_edges = np.zeros([0, 2], 'int64') + bond_angles = np.zeros([0, ], 'float32') + else: + super_edges = np.array(super_edges, 'int64') + bond_angles = np.array(bond_angles, 'float32') + return super_edges, bond_angles, bond_angle_dirs + + +def new_smiles_to_graph_data(smiles, **kwargs): + """ + Convert smiles to graph data. + """ + mol = AllChem.MolFromSmiles(smiles) + if mol is None: + return None + data = new_mol_to_graph_data(mol) + return data + + +def new_mol_to_graph_data(mol): + """ + mol_to_graph_data + Args: + atom_features: Atom features. + edge_features: Edge features. + morgan_fingerprint: Morgan fingerprint. + functional_groups: Functional groups. 
+ """ + if len(mol.GetAtoms()) == 0: + return None + + atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names + bond_id_names = list(CompoundKit.bond_vocab_dict.keys()) + + data = {} + + ### atom features + data = {name: [] for name in atom_id_names} + + raw_atom_feat_dicts = CompoundKit.get_atom_names(mol) + for atom_feat in raw_atom_feat_dicts: + for name in atom_id_names: + data[name].append(atom_feat[name]) + + ### bond and bond features + for name in bond_id_names: + data[name] = [] + data['edges'] = [] + + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + # i->j and j->i + data['edges'] += [(i, j), (j, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + data[name] += [bond_feature_id] * 2 + + #### self loop + N = len(data[atom_id_names[0]]) + for i in range(N): + data['edges'] += [(i, i)] + for name in bond_id_names: + bond_feature_id = get_bond_feature_dims([name])[0] - 1 # self loop: value = len - 1 + data[name] += [bond_feature_id] * N + + ### make ndarray and check length + for name in list(CompoundKit.atom_vocab_dict.keys()): + data[name] = np.array(data[name], 'int64') + for name in CompoundKit.atom_float_names: + data[name] = np.array(data[name], 'float32') + for name in bond_id_names: + data[name] = np.array(data[name], 'int64') + data['edges'] = np.array(data['edges'], 'int64') + + ### morgan fingerprint + data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') + # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') + data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') + data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') + return data + + +def mol_to_graph_data(mol): + """ + mol_to_graph_data + Args: + atom_features: Atom features. + edge_features: Edge features. + morgan_fingerprint: Morgan fingerprint. 
+ functional_groups: Functional groups. + """ + if len(mol.GetAtoms()) == 0: + return None + + atom_id_names = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", + ] + bond_id_names = [ + "bond_dir", "bond_type", "is_in_ring", + ] + + data = {} + for name in atom_id_names: + data[name] = [] + data['mass'] = [] + for name in bond_id_names: + data[name] = [] + data['edges'] = [] + + ### atom features + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetAtomicNum() == 0: + return None + for name in atom_id_names: + data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1) # 0: OOV + data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01) + + ### bond features + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + # i->j and j->i + data['edges'] += [(i, j), (j, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1 # 0: OOV + data[name] += [bond_feature_id] * 2 + + ### self loop (+2) + N = len(data[atom_id_names[0]]) + for i in range(N): + data['edges'] += [(i, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2 # N + 2: self loop + data[name] += [bond_feature_id] * N + + ### check whether edge exists + if len(data['edges']) == 0: # mol has no bonds + for name in bond_id_names: + data[name] = np.zeros((0,), dtype="int64") + data['edges'] = np.zeros((0, 2), dtype="int64") + + ### make ndarray and check length + for name in atom_id_names: + data[name] = np.array(data[name], 'int64') + data['mass'] = np.array(data['mass'], 'float32') + for name in bond_id_names: + data[name] = np.array(data[name], 'int64') + data['edges'] = np.array(data['edges'], 'int64') + + ### morgan fingerprint + data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') + # data['morgan2048_fp'] = 
np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') + data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') + data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') + return data + + +def mol_to_geognn_graph_data(mol, atom_poses, dir_type): + """ + mol: rdkit molecule + dir_type: direction type for bond_angle grpah + """ + if len(mol.GetAtoms()) == 0: + return None + + data = mol_to_graph_data(mol) + + data['atom_pos'] = np.array(atom_poses, 'float32') + data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos']) + BondAngleGraph_edges, bond_angles, bond_angle_dirs = \ + Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos']) + data['BondAngleGraph_edges'] = BondAngleGraph_edges + data['bond_angle'] = np.array(bond_angles, 'float32') + return data + + +def mol_to_geognn_graph_data_MMFF3d(mol): + """tbd""" + if len(mol.GetAtoms()) <= 400: + mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10) + else: + atom_poses = Compound3DKit.get_2d_atom_poses(mol) + return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') + + +def mol_to_geognn_graph_data_raw3d(mol): + """tbd""" + atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer()) + return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') + +def obtain_3D_mol(smiles,name): + mol = AllChem.MolFromSmiles(smiles) + new_mol = Chem.AddHs(mol) + res = AllChem.EmbedMultipleConfs(new_mol) + ### MMFF generates multiple conformations + res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) + new_mol = Chem.RemoveHs(new_mol) + Chem.MolToMolFile(new_mol, name+'.mol') + return new_mol + +def predict_SMILES_info(smiles): + # by lihao, input smiles, output dict + mol = AllChem.MolFromSmiles(smiles) + AllChem.EmbedMolecule(mol) + info_dict = mol_to_geognn_graph_data_MMFF3d(mol) + return info_dict + +if __name__ == "__main__": + # smiles = "OCc1ccccc1CN" + smiles = 
r"[H]/[NH+]=C(\N)C1=CC(=O)/C(=C\C=c2ccc(=C(N)[NH3+])cc2)C=C1" + # smiles = 'CC' + mol = AllChem.MolFromSmiles(smiles) + AllChem.EmbedMolecule(mol) + data = mol_to_geognn_graph_data_MMFF3d(mol) + for key, value in data.items(): + print(key, value.shape) \ No newline at end of file diff --git a/ppmat/datasets/IRDataset/place_env.py b/ppmat/datasets/IRDataset/place_env.py new file mode 100644 index 00000000..6d06a504 --- /dev/null +++ b/ppmat/datasets/IRDataset/place_env.py @@ -0,0 +1,173 @@ +import paddle +import functools +from contextlib import contextmanager + +@contextmanager +def place_env(place): + """ + 上下文管理器,用于临时设置PaddlePaddle的运行设备 + + Args: + place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 + + 用法: + with place_env(paddle.CPUPlace()): + # 这里的代码在CPU上运行 + x = paddle.rand([2, 3]) + print(x) + + @place_env(paddle.CUDAPlace(0)) + def train(): + # 这个函数在GPU上运行 + pass + """ + # 保存当前的设备设置 + current_device = paddle.get_device() + + # 根据place类型设置设备 + if isinstance(place, paddle.CPUPlace): + paddle.set_device('cpu') + elif isinstance(place, paddle.CUDAPlace): + # 获取GPU设备ID + device_id = place.get_device_id() + paddle.set_device(f'gpu:{device_id}') + else: + raise ValueError(f"不支持的place类型: {type(place)}") + + try: + yield + finally: + # 恢复原来的设备设置 + paddle.set_device(current_device) + + +class PlaceEnv: + """ + 类版本的上下文管理器,也支持装饰器功能 + """ + + def __init__(self, place): + """ + 初始化PlaceEnv + + Args: + place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 + """ + self.place = place + self.original_device = None + + def __enter__(self): + """进入上下文时调用""" + # 保存当前的设备设置 + self.original_device = paddle.get_device() + + # 根据place类型设置设备 + if isinstance(self.place, paddle.CPUPlace): + paddle.set_device('cpu') + elif isinstance(self.place, paddle.CUDAPlace): + device_id = self.place.get_device_id() + paddle.set_device(f'gpu:{device_id}') + else: + raise ValueError(f"不支持的place类型: {type(self.place)}") + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + 
        """Called when leaving the context; restores the previous device."""
        # Restore the device that was active before __enter__ ran.
        if self.original_device is not None:
            paddle.set_device(self.original_device)

    def __call__(self, func):
        """Allow the instance to be used as a decorator.

        Args:
            func: the function to decorate.

        Returns:
            The wrapped function, which runs with this instance's place active.
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Use the with-statement to switch devices for the call duration.
            with self:
                return func(*args, **kwargs)
        return wrapper


# Function-style alias kept for compatibility.
@contextmanager
def with_place_env(place):
    """Alias of place_env; identical behavior."""
    with place_env(place):
        yield


# Usage examples / manual smoke test.
if __name__ == "__main__":
    # Test the with-statement form.
    print("=== 测试with语句 ===")
    print(f"当前设备: {paddle.get_device()}")

    with place_env(paddle.CPUPlace()):
        print(f"with块内设备: {paddle.get_device()}")
        x = paddle.rand([2, 3])
        print(f"创建的张量: {x}")

    print(f"with块外设备: {paddle.get_device()}")

    print("\n=== 测试类版本with语句 ===")
    with PlaceEnv(paddle.CPUPlace()):
        print(f"with块内设备: {paddle.get_device()}")
        y = paddle.ones([2, 3])
        print(f"创建的张量: {y}")

    print(f"with块外设备: {paddle.get_device()}")

    # Test decorator usage.
    print("\n=== 测试装饰器功能 ===")

    @PlaceEnv(paddle.CPUPlace())
    def cpu_function():
        """Runs on the CPU."""
        print(f"函数内设备: {paddle.get_device()}")
        return paddle.rand([2, 2])

    # Only define the GPU variant when a GPU is available.
    if paddle.device.cuda.device_count() > 0:
        @PlaceEnv(paddle.CUDAPlace(0))
        def gpu_function():
            """Runs on the GPU."""
            print(f"函数内设备: {paddle.get_device()}")
            return paddle.rand([2, 2])

    # Call the decorated functions.
    print("调用cpu_function:")
    result_cpu = cpu_function()
    print(f"函数执行后设备: {paddle.get_device()}")
    print(f"结果: {result_cpu}")

    if paddle.device.cuda.device_count() > 0:
        print("\n调用gpu_function:")
        result_gpu = gpu_function()
        print(f"函数执行后设备: {paddle.get_device()}")
        print(f"结果: {result_gpu}")

    print("\n=== 测试多层嵌套 ===")
    print(f"初始设备: {paddle.get_device()}")

    with PlaceEnv(paddle.CPUPlace()):
        print(f"第一层with内设备: {paddle.get_device()}")

        if paddle.device.cuda.device_count() > 0:
            with PlaceEnv(paddle.CUDAPlace(0)):
                print(f"第二层with内设备: {paddle.get_device()}")
                z = paddle.rand([2, 2])
                print(f"创建的张量: {z}")

            print(f"回到第一层with设备: {paddle.get_device()}")

    print(f"最终设备: {paddle.get_device()}")
\ No newline at end of file
diff --git a/ppmat/datasets/__init__.py b/ppmat/datasets/__init__.py
index 98eec451..86d6eaad 100644
--- a/ppmat/datasets/__init__.py
+++ b/ppmat/datasets/__init__.py
@@ -47,6 +47,9 @@
 from ppmat.datasets.oc20_s2ef_dataset import OC20S2EFDataset # noqa
 from ppmat.datasets.qm9_dataset import QM9Dataset # noqa
 from ppmat.datasets.omol25_dataset import OMol25Dataset
+from ppmat.datasets.IRDataset import IRDataset, IRDataLoader
+from ppmat.datasets.ECDFormerDataset import ECDFormerDataset, ECDFormerDataset_DataLoader
+from ppmat.datasets.omol25_dataset import OMol25Dataset
 from ppmat.datasets.split_mptrj_data import none_to_zero
 from ppmat.datasets.transform import build_transforms
 from ppmat.utils import logger
@@ -67,6 +70,10 @@
     "DensityDataset",
     "SmallDensityDataset",
     "OMol25Dataset",
+    "IRDataset",
+    "ECDFormerDataset",
+    "IRDataLoader",
+    "ECDFormerDataset_DataLoader",
 ]
From e7abf4ca4ed03697a589381e43b456915239c967 Mon Sep 17 00:00:00 2001
From: PlumBlossomMaid <108660099+PlumBlossomMaid@users.noreply.github.com>
Date: Tue, 24 Feb 2026 17:32:35 +0800
Subject: [PATCH 2/2] Remove duplicate import

Removed duplicate import of OMol25Dataset.
--- ppmat/datasets/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ppmat/datasets/__init__.py b/ppmat/datasets/__init__.py index 86d6eaad..f95adcec 100644 --- a/ppmat/datasets/__init__.py +++ b/ppmat/datasets/__init__.py @@ -49,7 +49,6 @@ from ppmat.datasets.omol25_dataset import OMol25Dataset from ppmat.datasets.IRDataset import IRDataset, IRDataLoader from ppmat.datasets.ECDFormerDataset import ECDFormerDataset, ECDFormerDataset_DataLoader -from ppmat.datasets.omol25_dataset import OMol25Dataset from ppmat.datasets.split_mptrj_data import none_to_zero from ppmat.datasets.transform import build_transforms from ppmat.utils import logger