From d5a48df165ccb058729c2d4bc1370769730eef6a Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Tue, 24 Feb 2026 10:39:50 +0800 Subject: [PATCH 01/16] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=BA=86=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/models/__init__.py | 2 + ppmat/models/ecformer/__init__.py | 38 ++ ppmat/models/ecformer/encoders/__init__.py | 15 + .../ecformer/encoders/gin_node_embedding.py | 183 ++++++++ ppmat/models/ecformer/layers/__init__.py | 18 + ppmat/models/ecformer/layers/atom_encoder.py | 33 ++ ppmat/models/ecformer/layers/bond_encoder.py | 33 ++ ppmat/models/ecformer/layers/gin_conv.py | 46 ++ ppmat/models/ecformer/layers/rbf.py | 119 +++++ ppmat/models/ecformer/models/ECD.py | 154 +++++++ ppmat/models/ecformer/models/IR.py | 157 +++++++ ppmat/models/ecformer/models/__init__.py | 16 + ppmat/models/ecformer/models/base_ecformer.py | 259 +++++++++++ ppmat/models/ecformer/utils/__init__.py | 15 + ppmat/models/ecformer/utils/graph_utils.py | 59 +++ .../models/ecformer/utils/loss/dilate_loss.py | 25 + .../ecformer/utils/loss/path_soft_dtw.py | 134 ++++++ ppmat/models/ecformer/utils/loss/soft_dtw.py | 97 ++++ .../ecformer/utils/loss/soft_dtw_cuda.py | 427 ++++++++++++++++++ 19 files changed, 1830 insertions(+) create mode 100644 ppmat/models/ecformer/__init__.py create mode 100644 ppmat/models/ecformer/encoders/__init__.py create mode 100644 ppmat/models/ecformer/encoders/gin_node_embedding.py create mode 100644 ppmat/models/ecformer/layers/__init__.py create mode 100644 ppmat/models/ecformer/layers/atom_encoder.py create mode 100644 ppmat/models/ecformer/layers/bond_encoder.py create mode 100644 ppmat/models/ecformer/layers/gin_conv.py create mode 100644 ppmat/models/ecformer/layers/rbf.py create mode 100644 ppmat/models/ecformer/models/ECD.py create mode 100644 ppmat/models/ecformer/models/IR.py create mode 100644 
ppmat/models/ecformer/models/__init__.py create mode 100644 ppmat/models/ecformer/models/base_ecformer.py create mode 100644 ppmat/models/ecformer/utils/__init__.py create mode 100644 ppmat/models/ecformer/utils/graph_utils.py create mode 100644 ppmat/models/ecformer/utils/loss/dilate_loss.py create mode 100644 ppmat/models/ecformer/utils/loss/path_soft_dtw.py create mode 100644 ppmat/models/ecformer/utils/loss/soft_dtw.py create mode 100644 ppmat/models/ecformer/utils/loss/soft_dtw_cuda.py diff --git a/ppmat/models/__init__.py b/ppmat/models/__init__.py index 95d73232..93520a59 100644 --- a/ppmat/models/__init__.py +++ b/ppmat/models/__init__.py @@ -35,6 +35,8 @@ from ppmat.models.diffnmr.diffnmr import MolecularGraphFormer from ppmat.models.diffnmr.diffnmr import NMRNetCLIP from ppmat.models.dimenetpp.dimenetpp import DimeNetPlusPlus +from ppmat.models.ecformer import ECFormerECD +from ppmat.models.ecformer import ECFormerIR from ppmat.models.mattergen.mattergen import MatterGen from ppmat.models.mattergen.mattergen import MatterGenWithCondition from ppmat.models.mattersim.m3gnet import M3GNet diff --git a/ppmat/models/ecformer/__init__.py b/ppmat/models/ecformer/__init__.py new file mode 100644 index 00000000..08b481e5 --- /dev/null +++ b/ppmat/models/ecformer/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# 导出模型类 +from .models.ECD import ECFormerECD +from .models.IR import ECFormerIR + +# 导出编码器(如需直接使用) +from .encoders.gin_node_embedding import GINNodeEmbedding + +# 导出工具函数 +from .utils.graph_utils import ( + index_transform, + get_key_padding_mask, + feat_padding_mask, + pad_node_features +) + +__all__ = [ + 'ECFormerECD', + 'ECFormerIR', + 'GINNodeEmbedding', + 'index_transform', + 'get_key_padding_mask', + 'feat_padding_mask', + 'pad_node_features', +] \ No newline at end of file diff --git a/ppmat/models/ecformer/encoders/__init__.py b/ppmat/models/ecformer/encoders/__init__.py new file mode 100644 index 00000000..36adf7e7 --- /dev/null +++ b/ppmat/models/ecformer/encoders/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gin_node_embedding import GINNodeEmbedding \ No newline at end of file diff --git a/ppmat/models/ecformer/encoders/gin_node_embedding.py b/ppmat/models/ecformer/encoders/gin_node_embedding.py new file mode 100644 index 00000000..17df78e5 --- /dev/null +++ b/ppmat/models/ecformer/encoders/gin_node_embedding.py @@ -0,0 +1,183 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class GINNodeEmbedding(nn.Layer):
    """GIN node encoder with an optional geometry-enhanced bond-angle branch.

    Two forward paths share the layers built in ``__init__``:

    * simple path - message passing on the atom/bond graph, followed by batch
      normalisation after every convolution. Returns the node embeddings.
    * enhanced path - additionally runs a second GIN stack on the bond-angle
      graph and feeds its output to the next atom-level layer as edge
      embeddings. Returns ``(node_embeddings, edge_embeddings)``.

    ``edge_attr`` is split by column: the first ``len(bond_id_names)`` columns
    are cast to int64 and embedded as categorical ids, the remaining columns
    are cast to float32 and encoded with RBF layers.

    Args:
        full_atom_feature_dims: cardinalities of the categorical atom features.
        full_bond_feature_dims: cardinalities of the categorical bond features.
        bond_float_names: names of the continuous bond features.
        bond_angle_float_names: names of the continuous bond-angle features.
        bond_id_names: names of the categorical bond features; only its length
            is used, to locate the id/float split inside ``edge_attr``.
        num_layers: number of GIN layers (at least 2).
        emb_dim: embedding width.
        drop_ratio: dropout probability applied after every layer.
        JK: layer aggregation strategy, ``"last"`` or ``"sum"``.
        residual: add the layer input to the layer output when True.
        use_geometry_enhanced: allow the enhanced path when bond-angle inputs
            are supplied to ``forward``.

    Raises:
        ValueError: if ``num_layers < 2`` or ``JK`` is not supported.
    """

    _SUPPORTED_JK = ("last", "sum")

    def __init__(
        self,
        full_atom_feature_dims,
        full_bond_feature_dims,
        bond_float_names,
        bond_angle_float_names,
        bond_id_names,
        num_layers=5,
        emb_dim=128,
        drop_ratio=0.5,
        JK="last",
        residual=False,
        use_geometry_enhanced=True
    ):
        super().__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.residual = residual
        self.use_geometry_enhanced = use_geometry_enhanced
        self.bond_id_names = bond_id_names

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Fail early: an unknown strategy previously surfaced only inside
        # forward (UnboundLocalError on the enhanced path, None on the simple one).
        if self.JK not in self._SUPPORTED_JK:
            raise ValueError(
                f"Unsupported JK strategy: {self.JK!r}; "
                f"expected one of {self._SUPPORTED_JK}."
            )

        # Input encoders
        self.atom_encoder = AtomEncoder(full_atom_feature_dims, emb_dim)
        self.bond_encoder = BondEncoder(full_bond_feature_dims, emb_dim)
        self.bond_float_encoder = BondFloatRBF(bond_float_names, emb_dim)
        self.bond_angle_encoder = BondAngleFloatRBF(bond_angle_float_names, emb_dim)

        # Per-layer modules
        self.convs = nn.LayerList()
        self.convs_bond_angle = nn.LayerList()
        self.convs_bond_embedding = nn.LayerList()
        self.convs_bond_float = nn.LayerList()
        self.convs_angle_float = nn.LayerList()
        self.batch_norms = nn.LayerList()
        # NOTE(review): batch_norms_ba is not used by either forward path in
        # this class; it is kept so existing state dicts keep loading.
        self.batch_norms_ba = nn.LayerList()

        for _ in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.convs_bond_angle.append(GINConv(emb_dim))
            self.convs_bond_embedding.append(BondEncoder(full_bond_feature_dims, emb_dim))
            self.convs_bond_float.append(BondFloatRBF(bond_float_names, emb_dim))
            self.convs_angle_float.append(BondAngleFloatRBF(bond_angle_float_names, emb_dim))
            self.batch_norms.append(nn.BatchNorm1D(emb_dim))
            self.batch_norms_ba.append(nn.BatchNorm1D(emb_dim))

    def forward(
        self,
        x,                    # [N, F] atom features
        edge_index,           # [2, E] edge indices
        edge_attr,            # [E, D] edge features
        # geometry-enhanced inputs
        ba_edge_index=None,   # [2, E_ba] bond-angle graph edge indices
        ba_edge_attr=None,    # [E_ba, D_ba] bond-angle graph edge features
    ):
        """Encode atoms and dispatch to the enhanced or the simple path.

        Returns:
            ``(node_repr, edge_repr)`` when ``use_geometry_enhanced`` is set
            and ``ba_edge_index`` is given, otherwise ``node_repr`` alone.
        """
        # Embedding lookups need integer ids.
        if x.dtype != paddle.int64:
            x = x.astype(paddle.int64)
        h_list = [self.atom_encoder(x)]

        if self.use_geometry_enhanced and ba_edge_index is not None:
            return self._forward_enhanced(
                h_list, edge_index, edge_attr,
                ba_edge_index, ba_edge_attr
            )
        return self._forward_simple(h_list, edge_index, edge_attr)

    def _split_edge_attr(self, edge_attr):
        """Split ``edge_attr`` into (int64 categorical ids, float32 features)."""
        bond_id_len = len(self.bond_id_names)
        bond_ids = edge_attr[:, :bond_id_len].astype('int64')
        # Open-ended slice: everything after the id columns.
        bond_floats = edge_attr[:, bond_id_len:].astype('float32')
        return bond_ids, bond_floats

    def _jumping_knowledge(self, states):
        """Combine the per-layer states according to ``self.JK``."""
        if self.JK == "last":
            return states[-1]
        if self.JK == "sum":
            return sum(states)
        raise ValueError(f"Unsupported JK strategy: {self.JK!r}")

    def _forward_enhanced(self, h_list, edge_index, edge_attr,
                          ba_edge_index, ba_edge_attr):
        """Geometry-enhanced forward pass over the atom and bond-angle graphs."""
        if ba_edge_attr is None:
            raise ValueError(
                "ba_edge_attr is required when ba_edge_index is provided."
            )

        # The split is loop-invariant, so slice and cast only once.
        bond_ids, bond_floats = self._split_edge_attr(edge_attr)

        # Initial edge representation
        h_list_ba = [
            self.bond_float_encoder(bond_floats) + self.bond_encoder(bond_ids)
        ]

        last_layer = self.num_layers - 1
        for layer in range(self.num_layers):
            # Node update, conditioned on the current edge representation
            h = self.convs[layer](h_list[layer], edge_index, h_list_ba[layer])

            # Edge update on the bond-angle graph
            cur_h_ba = (
                self.convs_bond_embedding[layer](bond_ids)
                + self.convs_bond_float[layer](bond_floats)
            )
            cur_angle_hidden = self.convs_angle_float[layer](ba_edge_attr)
            h_ba = self.convs_bond_angle[layer](cur_h_ba, ba_edge_index, cur_angle_hidden)

            # No ReLU after the final layer
            if layer != last_layer:
                h = F.relu(h)
                h_ba = F.relu(h_ba)
            h = F.dropout(h, self.drop_ratio, training=self.training)
            h_ba = F.dropout(h_ba, self.drop_ratio, training=self.training)

            if self.residual:
                # Out-of-place on purpose, so no tensor is modified in place.
                h = h + h_list[layer]
                h_ba = h_ba + h_list_ba[layer]

            h_list.append(h)
            h_list_ba.append(h_ba)

        node_representation = self._jumping_knowledge(h_list)
        edge_representation = self._jumping_knowledge(h_list_ba)
        return node_representation, edge_representation

    def _forward_simple(self, h_list, edge_index, edge_attr):
        """Forward pass on the atom/bond graph only (with batch norm)."""
        bond_ids, bond_floats = self._split_edge_attr(edge_attr)

        last_layer = self.num_layers - 1
        for layer in range(self.num_layers):
            edge_embedding = (
                self.convs_bond_embedding[layer](bond_ids)
                + self.convs_bond_float[layer](bond_floats)
            )
            h = self.convs[layer](h_list[layer], edge_index, edge_embedding)
            h = self.batch_norms[layer](h)

            # No ReLU after the final layer
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

        return self._jumping_knowledge(h_list)
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .atom_encoder import AtomEncoder +from .bond_encoder import BondEncoder +from .rbf import RBF, BondFloatRBF +from .gin_conv import GINConv \ No newline at end of file diff --git a/ppmat/models/ecformer/layers/atom_encoder.py b/ppmat/models/ecformer/layers/atom_encoder.py new file mode 100644 index 00000000..34ef8d03 --- /dev/null +++ b/ppmat/models/ecformer/layers/atom_encoder.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class AtomEncoder(nn.Layer):
    """Embed categorical atom features and sum them into one vector per atom.

    One embedding table is created per entry of ``full_atom_feature_dims``;
    each table has ``dim + 5`` rows and is Xavier-uniform initialised.
    """

    def __init__(self, full_atom_feature_dims, emb_dim):
        super().__init__()
        self.atom_embedding_list = nn.LayerList()

        for feature_dim in full_atom_feature_dims:
            table = nn.Embedding(feature_dim + 5, emb_dim)
            nn.initializer.XavierUniform()(table.weight)
            self.atom_embedding_list.append(table)

    def forward(self, x):
        """Return the sum over columns of ``x`` of the per-column embeddings.

        Column ``i`` of ``x`` is looked up in the ``i``-th embedding table.
        """
        num_columns = x.shape[1]
        x_embedding = 0
        for col in range(num_columns):
            column_ids = x[:, col]
            x_embedding += self.atom_embedding_list[col](column_ids)
        return x_embedding
class BondEncoder(nn.Layer):
    """Embed categorical bond features and sum them into one vector per bond.

    One embedding table is created per entry of ``full_bond_feature_dims``;
    each table has ``dim + 5`` rows and is Xavier-uniform initialised.
    """

    def __init__(self, full_bond_feature_dims, emb_dim):
        super().__init__()
        self.bond_embedding_list = nn.LayerList()

        for feature_dim in full_bond_feature_dims:
            table = nn.Embedding(feature_dim + 5, emb_dim)
            nn.initializer.XavierUniform()(table.weight)
            self.bond_embedding_list.append(table)

    def forward(self, edge_attr):
        """Return the sum over columns of ``edge_attr`` of the per-column embeddings.

        Column ``i`` of ``edge_attr`` is looked up in the ``i``-th embedding table.
        """
        num_columns = edge_attr.shape[1]
        bond_embedding = 0
        for col in range(num_columns):
            column_ids = edge_attr[:, col]
            bond_embedding += self.bond_embedding_list[col](column_ids)
        return bond_embedding
class GINConv(MessagePassing):
    """Graph isomorphism convolution with edge features and a learnable epsilon.

    Messages are ``relu(x_j + edge_attr)``, aggregated by summation, added to
    ``(1 + eps) * x`` and passed through a two-layer MLP.
    """

    def __init__(self, emb_dim):
        super().__init__(aggr="add")

        mlp_layers = [
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1D(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

        # Learnable epsilon, initialised to zero.
        eps_init = nn.initializer.Assign(paddle.to_tensor([0.]))
        self.eps = paddle.create_parameter(
            shape=[1],
            dtype='float32',
            default_initializer=eps_init
        )

    def forward(self, x, edge_index, edge_attr):
        """Apply one GIN update to the node features ``x``."""
        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        return self.mlp(self_term + neighbour_term)

    def message(self, x_j, edge_attr):
        """Combine the source-node feature with the edge feature."""
        combined = x_j + edge_attr
        return F.relu(combined)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out
class RBF(nn.Layer):
    """Gaussian radial basis expansion: ``exp(-gamma * (x - centers) ** 2)``."""

    def __init__(self,
                 centers: paddle.nn.parameter.Parameter,
                 gamma: paddle.nn.parameter.Parameter):
        super().__init__()
        # Centres are kept as a [1, K] row so they broadcast against [M, 1] inputs.
        self.centers = centers.data.reshape([1, -1])
        self.gamma = gamma

    def forward(self, x):
        """Flatten ``x`` to a column and expand it over all centres."""
        column = x.reshape([-1, 1])
        squared_offset = paddle.square(column - self.centers)
        return paddle.exp(-self.gamma * squared_offset)


class BondFloatRBF(nn.Layer):
    """RBF encoder for continuous bond features.

    Each feature named in ``bond_float_names`` gets its own RBF expansion
    followed by a linear projection to ``embed_dim``; the projections are summed.
    ``rbf_params`` maps a feature name to a ``(centers, gamma)`` pair.
    """

    def __init__(self, bond_float_names, embed_dim, rbf_params=None):
        super().__init__()
        self.bond_float_names = bond_float_names

        self.rbf_params = (
            self._default_rbf_params() if rbf_params is None else rbf_params
        )

        self.linear_list = nn.LayerList()
        self.rbf_list = nn.LayerList()
        for feature_name in self.bond_float_names:
            centers, gamma = self.rbf_params[feature_name]
            self.rbf_list.append(RBF(centers, gamma))
            self.linear_list.append(nn.Linear(len(centers), embed_dim))

    def _default_rbf_params(self):
        """Default parameters: ``bond_length`` centres 0..2 (step 0.1), gamma 10."""
        center_values = paddle.arange(0, 2, 0.1)
        gamma_value = paddle.to_tensor([10.0])

        centers = paddle.create_parameter(
            shape=center_values.shape,
            dtype=center_values.dtype,
            default_initializer=paddle.nn.initializer.Assign(center_values))
        gamma = paddle.create_parameter(
            shape=gamma_value.shape,
            dtype=gamma_value.dtype,
            default_initializer=paddle.nn.initializer.Assign(gamma_value))
        return {
            'bond_length': (centers, gamma),
        }

    def forward(self, bond_float_features):
        """Encode column ``i`` with the ``i``-th RBF/linear pair and sum the results."""
        out_embed = 0
        for idx, _ in enumerate(self.bond_float_names):
            column = bond_float_features[:, idx].reshape([-1, 1])
            expanded = self.rbf_list[idx](column)
            out_embed += self.linear_list[idx](expanded)
        return out_embed


class BondAngleFloatRBF(nn.Layer):
    """RBF encoder for continuous bond-angle features.

    A feature named ``'bond_angle'`` is RBF-expanded and projected. The first
    feature with any other name triggers a single linear layer over columns
    ``1:`` of the input, after which processing stops.
    """

    def __init__(self, bond_angle_float_names, embed_dim, rbf_params=None):
        super().__init__()
        self.bond_angle_float_names = bond_angle_float_names

        if rbf_params is not None:
            self.rbf_params = rbf_params
        else:
            # Default: centres 0..pi (step 0.1), gamma 10.
            angle_centers = paddle.arange(0, np.pi, 0.1)
            gamma_value = paddle.to_tensor([10.0])
            self.rbf_params = {
                'bond_angle': (
                    paddle.create_parameter(
                        shape=angle_centers.shape,
                        dtype='float32',
                        default_initializer=paddle.nn.initializer.Assign(angle_centers)),
                    paddle.create_parameter(
                        shape=gamma_value.shape,
                        dtype=gamma_value.dtype,
                        default_initializer=paddle.nn.initializer.Assign(gamma_value)),
                ),
            }

        self.linear_list = nn.LayerList()
        self.rbf_list = nn.LayerList()

        for feature_name in self.bond_angle_float_names:
            if feature_name != 'bond_angle':
                # One projection covers all remaining (non-angle) columns.
                remaining = len(self.bond_angle_float_names) - 1
                self.linear_list.append(nn.Linear(remaining, embed_dim))
                break
            centers, gamma = self.rbf_params[feature_name]
            self.rbf_list.append(RBF(centers, gamma))
            self.linear_list.append(nn.Linear(len(centers), embed_dim))

    def forward(self, bond_angle_float_features):
        """Sum the angle RBF embedding and the projection of the other columns."""
        out_embed = 0
        for idx, feature_name in enumerate(self.bond_angle_float_names):
            if feature_name != 'bond_angle':
                rest = bond_angle_float_features[:, 1:]
                out_embed += self.linear_list[idx](rest)
                break
            column = bond_angle_float_features[:, idx].reshape([-1, 1])
            expanded = self.rbf_list[idx](column)
            out_embed += self.linear_list[idx](expanded)
        return out_embed
class ECFormerECD(ECFormerBase):
    """ECFormer for ECD spectrum prediction with decoupled peak attributes.

    Adds three classification heads on top of ``ECFormerBase``: peak count,
    peak position and peak sign ("height"). Remaining keyword arguments are
    forwarded to ``ECFormerBase``.
    """

    def __init__(
        self,
        num_position_classes = 20,
        height_classes = 2,
        loss_weight_height = 2.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.num_position_classes = num_position_classes
        self.height_classes = height_classes
        self.loss_weight_height = loss_weight_height

        wide_dim = self.emb_dim * 2
        narrow_dim = self.emb_dim // 4

        # Peak-count head
        self.pred_number_layer = nn.Sequential(
            nn.Linear(self.emb_dim, wide_dim),
            nn.ReLU(),
            nn.Linear(wide_dim, self.max_peaks)
        )

        # Peak-position head
        self.pred_position_layer = nn.Sequential(
            nn.Linear(self.emb_dim, narrow_dim),
            nn.ReLU(),
            nn.Linear(narrow_dim, num_position_classes)
        )

        # Peak-sign head
        self.pred_height_layer = nn.Sequential(
            nn.Linear(self.emb_dim, narrow_dim),
            nn.ReLU(),
            nn.Linear(narrow_dim, height_classes)
        )

        # Loss
        self.ce_loss = nn.CrossEntropyLoss()

    def get_loss(self, predictions, targets):
        """Return ``count_loss + loss_weight_height * sign_loss + position_loss``.

        Position and sign losses are computed per sample on the first
        ``peak_num`` peaks and averaged over samples with at least one peak.
        """
        true_counts = targets['peak_num']
        loss_num = self.ce_loss(predictions['peak_number'], true_counts)

        # Peak counts differ per sample, so each sample is sliced on its own.
        pos_sum = 0
        height_sum = 0
        counted = 0

        for idx in range(true_counts.shape[0]):
            n_peaks = int(true_counts[idx])
            if n_peaks == 0:
                continue

            # Position loss
            pos_logits = predictions['peak_position'][idx, :n_peaks, :]
            pos_logits = pos_logits.reshape([-1, self.num_position_classes])
            pos_labels = targets['peak_position'][idx, :n_peaks].reshape([-1])
            pos_sum += self.ce_loss(pos_logits, pos_labels)

            # Sign loss
            height_logits = predictions['peak_height'][idx, :n_peaks, :]
            height_logits = height_logits.reshape([-1, self.height_classes])
            height_labels = targets['peak_height'][idx, :n_peaks].reshape([-1])
            height_sum += self.ce_loss(height_logits, height_labels)

            counted += 1

        if counted > 0:
            loss_pos = pos_sum / counted
            loss_height = height_sum / counted
        else:
            loss_pos = paddle.to_tensor(0.0)
            loss_height = paddle.to_tensor(0.0)

        return loss_num + self.loss_weight_height * loss_height + loss_pos

    def get_metrics(self, predictions, targets):
        """Compute count RMSE, position RMSE, sign accuracy and first-peak sign accuracy.

        Per sample, only the first ``min(true_count, predicted_count)`` peaks
        are compared; samples where either count is zero are skipped for the
        position and sign metrics.
        """
        true_nums = targets['peak_num']
        pred_nums = predictions['peak_number'].argmax(axis=1)

        pos_errors = []
        symbol_correct = 0
        symbol_total = 0
        first_symbol_correct = 0
        first_symbol_total = 0

        for idx in range(true_nums.shape[0]):
            n_true = int(true_nums[idx])
            n_pred = int(pred_nums[idx])
            if not (n_true > 0 and n_pred > 0):
                continue

            n_match = min(n_true, n_pred)

            # Position error
            pos_true = targets['peak_position'][idx, :n_match].numpy()
            pos_pred = predictions['peak_position'][idx, :n_match, :].argmax(axis=1).numpy()
            pos_errors.extend(pos_pred - pos_true)

            # Sign accuracy
            height_true = targets['peak_height'][idx, :n_match].numpy()
            height_pred = predictions['peak_height'][idx, :n_match, :].argmax(axis=1).numpy()
            symbol_correct += np.sum(height_true == height_pred)
            symbol_total += n_match

            # First-peak sign accuracy
            if height_true[0] == height_pred[0]:
                first_symbol_correct += 1
            first_symbol_total += 1

        if pos_errors:
            pos_rmse = np.sqrt(np.mean(np.square(pos_errors)))
        else:
            pos_rmse = 0.0
        num_rmse = np.sqrt(np.mean(np.square((pred_nums - true_nums).numpy())))
        if symbol_total > 0:
            symbol_acc = symbol_correct / symbol_total
        else:
            symbol_acc = 0.0
        if first_symbol_total > 0:
            first_symbol_acc = first_symbol_correct / first_symbol_total
        else:
            first_symbol_acc = 0.0

        return {
            'num_rmse': num_rmse,
            'pos_rmse': pos_rmse,
            'symbol_acc': symbol_acc,
            'first_symbol_acc': first_symbol_acc
        }
class ECFormerIR(ECFormerBase):
    """ECFormer for IR spectrum prediction (sequence-regression variant).

    Adds a peak-count head, a peak-position classification head and, when
    ``use_height_prediction`` is set, a scalar peak-intensity regression head.
    ``max_peaks`` defaults to 15 unless the caller passes it explicitly;
    remaining keyword arguments are forwarded to ``ECFormerBase``.
    """

    def __init__(
        self,
        spectrum_length=1000,
        num_position_classes=36,
        use_height_prediction=True,
        dtw_gamma=0.1,
        **kwargs
    ):
        # The IR task uses a different default for the maximum peak count.
        kwargs.setdefault('max_peaks', 15)

        super().__init__(**kwargs)

        self.spectrum_length = spectrum_length
        self.num_position_classes = num_position_classes
        self.use_height_prediction = use_height_prediction

        wide_dim = self.emb_dim * 2
        narrow_dim = self.emb_dim // 4

        # Peak-count head (max_peaks + 1 classes)
        self.pred_number_layer = nn.Sequential(
            nn.Linear(self.emb_dim, wide_dim),
            nn.ReLU(),
            nn.Linear(wide_dim, self.max_peaks + 1)
        )

        # Peak-position head
        self.pred_position_layer = nn.Sequential(
            nn.Linear(self.emb_dim, narrow_dim),
            nn.ReLU(),
            nn.Linear(narrow_dim, num_position_classes)
        )

        # Peak-intensity head (regression)
        if use_height_prediction:
            self.pred_height_layer = nn.Sequential(
                nn.Linear(self.emb_dim, narrow_dim),
                nn.ReLU(),
                nn.Linear(narrow_dim, 1)
            )

        # Losses
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss(reduction='mean')
        use_cuda = "gpu" in paddle.device.get_device()
        self.dtw_loss = SoftDTW(use_cuda=use_cuda, gamma=dtw_gamma, normalize=True,)

    def get_loss(self, predictions, targets):
        """Return count loss + position loss (+ intensity loss when enabled).

        Position and intensity losses are computed per sample on the first
        ``peak_num`` peaks and averaged over samples with at least one peak.
        """
        true_counts = targets['peak_num']

        # Peak-count loss
        loss_num = self.ce_loss(predictions['peak_number'], true_counts)

        pos_sum = 0
        height_sum = 0
        counted = 0
        with_height = self.use_height_prediction and 'peak_height' in predictions

        for idx in range(true_counts.shape[0]):
            n_peaks = int(true_counts[idx])
            if n_peaks == 0:
                continue

            # Position loss
            pos_logits = predictions['peak_position'][idx, :n_peaks, :]
            pos_logits = pos_logits.reshape([-1, self.num_position_classes])
            pos_labels = targets['peak_position'][idx, :n_peaks].reshape([-1])
            pos_sum += self.ce_loss(pos_logits, pos_labels)

            # Intensity loss (regression)
            if with_height:
                height_pred = predictions['peak_height'][idx, :n_peaks].reshape([-1])
                height_gt = targets['peak_height'][idx, :n_peaks].reshape([-1])
                height_sum += self.mse_loss(height_pred, height_gt)

            counted += 1

        if counted > 0:
            loss_pos = pos_sum / counted
        else:
            loss_pos = paddle.to_tensor(0.0)
        loss = loss_num + loss_pos

        if self.use_height_prediction:
            if counted > 0:
                loss_height = height_sum / counted
            else:
                loss_height = paddle.to_tensor(0.0)
            loss += loss_height

        return loss

    def get_metrics(self, predictions, targets):
        """Compute count RMSE, position RMSE and (when enabled) intensity RMSE.

        Per sample, only the first ``min(true_count, predicted_count)`` peaks
        are compared; samples where either count is zero are skipped for the
        position and intensity metrics.
        """
        true_nums = targets['peak_num']
        pred_nums = predictions['peak_number'].argmax(axis=1)

        pos_errors = []
        height_errors = []

        for idx in range(true_nums.shape[0]):
            n_true = int(true_nums[idx])
            n_pred = int(pred_nums[idx])
            if not (n_true > 0 and n_pred > 0):
                continue

            n_match = min(n_true, n_pred)

            # Position error
            pos_true = targets['peak_position'][idx, :n_match].numpy()
            pos_pred = predictions['peak_position'][idx, :n_match, :].argmax(axis=1).numpy()
            pos_errors.extend(pos_pred - pos_true)

            # Intensity error
            if 'peak_height' in predictions:
                height_true = targets['peak_height'][idx, :n_match].numpy()
                height_pred = predictions['peak_height'][idx, :n_match].numpy().flatten()
                height_errors.extend(np.abs(height_true - height_pred))

        if pos_errors:
            pos_rmse = np.sqrt(np.mean(np.square(pos_errors)))
        else:
            pos_rmse = 0.0
        num_rmse = np.sqrt(np.mean(np.square((pred_nums - true_nums).numpy())))
        if height_errors:
            height_rmse = np.sqrt(np.mean(np.square(height_errors)))
        else:
            height_rmse = 0.0

        metrics = {
            'num_rmse': num_rmse,
            'pos_rmse': pos_rmse,
        }

        if self.use_height_prediction:
            metrics['height_rmse'] = height_rmse

        return metrics
def fix_mask_for_paddle(mask, n_head=None):
    """Expand a 2-D ``[batch_size, src_len]`` mask to a 4-D attention mask.

    Args:
        mask: 2-D mask tensor; any other rank fails the assertion.
        n_head: number of attention heads. When truthy the result has shape
            ``[batch_size, n_head, src_len, src_len]``; otherwise the head
            axis is left at size 1: ``[batch_size, 1, src_len, src_len]``.

    Returns:
        The mask with every row repeated ``src_len`` times along the query axis.
    """
    dims = mask.shape
    assert len(dims) == 2
    batch_size, s_len = dims

    if not n_head:
        # e.g. [32, 73] -> [32, 1, 73, 73]
        return mask.unsqueeze(1).unsqueeze(2).expand([-1, -1, s_len, -1])

    # e.g. [32, 73] -> [32, n_head, 73, 73]
    four_d = mask.reshape([batch_size, 1, 1, s_len])
    return four_d.expand([-1, n_head, s_len, -1])
self.use_geometry_enhanced = use_geometry_enhanced + + # 1. GNN节点编码器 + self.gnn_node = GINNodeEmbedding( + full_atom_feature_dims=full_atom_feature_dims, + full_bond_feature_dims=full_bond_feature_dims, + bond_float_names=bond_float_names, + bond_angle_float_names=bond_angle_float_names, + bond_id_names=bond_id_names, + num_layers=num_layers, + emb_dim=emb_dim, + drop_ratio=drop_ratio, + JK=JK, + residual=residual, + use_geometry_enhanced=use_geometry_enhanced + ) + + # 2. 图池化层 + self.pool = self._build_pooling(graph_pooling, emb_dim) + + # 3. Query嵌入(峰查询向量) + self.query_embed = nn.Embedding(max_peaks, emb_dim) + + # 4. Transformer编码器 + self.tf_encoder = self._build_transformer(emb_dim, num_heads, tf_layers, tf_dropout) + + def _build_pooling(self, graph_pooling, emb_dim): + """构建图池化层""" + if graph_pooling == "sum": + return global_add_pool + elif graph_pooling == "mean": + return global_mean_pool + elif graph_pooling == "max": + return global_max_pool + elif graph_pooling == "attention": + return GlobalAttention( + gate_nn=nn.Sequential( + nn.Linear(emb_dim, emb_dim), + nn.BatchNorm1D(emb_dim), + nn.ReLU(), + nn.Linear(emb_dim, 1) + ) + ) + elif graph_pooling == "set2set": + return Set2Set(emb_dim, processing_steps=2) + else: + raise ValueError(f"Invalid graph pooling type: {graph_pooling}") + + def _build_transformer(self, emb_dim, num_heads, num_layers, dropout): + """构建Transformer编码器""" + + assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads" + + encoder_layer = TransformerEncoderLayer( + d_model=emb_dim, + nhead=num_heads, + dim_feedforward=emb_dim, + dropout=dropout, + activation='relu', + ) + return TransformerEncoder(encoder_layer, num_layers=num_layers) + + def encode_molecule( + self, + x, # [N, F] 原子特征 + edge_index, # [2, E] 边索引 + edge_attr, # [E, D] 边特征 + batch_data, # [N] 批次信息 + # 几何增强相关 + ba_edge_index=None, # [2, E_ba] 键角图边索引 + ba_edge_attr=None, # [E_ba, D_ba] 键角图边特征 + ): + """分子编码器 - 纯Tensor输入""" + + # 1. 
GNN编码 + if self.use_geometry_enhanced and ba_edge_index is not None: + h_node, _ = self.gnn_node( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + ba_edge_index=ba_edge_index, + ba_edge_attr=ba_edge_attr + ) + else: + h_node = self.gnn_node( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + ) + + # 2. 节点特征padding(需要batch信息) + batch_size = batch_data[-1] + 1 + + node_feat, node_index = pad_node_features( + h_node, batch_data, batch_size, self.max_node_num, self.emb_dim + ) + + # 3. 图池化 + h_graph = self.pool(h_node, batch_data).unsqueeze(1) + + # 4. 拼接图特征和节点特征 + + total_node_feat = paddle.concat([h_graph, node_feat], axis=1) + + # 5. 生成padding mask + node_padding_mask = feat_padding_mask(node_index, self.max_node_num) + pooling_padding_mask = paddle.zeros([node_padding_mask.shape[0], 1], dtype='float32') + total_padding_mask = paddle.concat([pooling_padding_mask, node_padding_mask], axis=1) + + return total_node_feat, total_padding_mask, node_padding_mask + + def forward(self, + x: paddle.Tensor, + edge_index: paddle.Tensor, + edge_attr: paddle.Tensor, + batch_data: paddle.Tensor, + ba_edge_index: paddle.Tensor = None, + ba_edge_attr: paddle.Tensor = None, + query_mask: paddle.Tensor = None): + # 0. 数据类型检查 + if batch_data.dtype != paddle.int64: + batch_data = batch_data.astype(paddle.int64) + + # 1. 分子编码 + node_feat, padding_mask, node_padding_mask = self.encode_molecule(x, edge_index, edge_attr,batch_data, ba_edge_index, ba_edge_attr) + + # 2. 峰数预测(从图特征) + graph_feat = node_feat[:, 0, :] + pred_number = self.pred_number_layer(graph_feat) + + # 3. 
Query准备 + query_feat = self.query_embed.weight.unsqueeze(0).tile([node_feat.shape[0], 1, 1]) + + # 推理时根据预测峰数生成query mask + if not self.training: + pred_peak_num = pred_number.argmax(axis=1) + peak_position = [ + [1] * int(pred_peak_num[i]) + [-1] * (self.max_peaks - int(pred_peak_num[i])) + for i in range(len(pred_peak_num)) + ] + peak_position = paddle.to_tensor(peak_position) + query_mask = get_key_padding_mask(peak_position) + + # 4. Transformer编码 + encoder_input = paddle.concat([node_feat, query_feat], axis=1) + encoder_padding_mask = paddle.concat([padding_mask, query_mask], axis=1) + + encoder_output = self.tf_encoder(encoder_input, fix_mask_for_paddle(encoder_padding_mask)) + + # 5. 峰位置和符号预测 + query_output = encoder_output[:, node_feat.shape[1]:, :] + pred_position = self.pred_position_layer(query_output) + pred_height = self.pred_height_layer(query_output) + + # 6. 注意力权重(用于可视化) + node_feat_output = encoder_output[:, :node_feat.shape[1], :] + attn_weights = paddle.einsum("bid,bjd->bij", + node_feat_output, + query_output[:, 0, :].unsqueeze(1) + ) + attn_weights = attn_weights[:, 1:, :].squeeze() + attn_mask = node_padding_mask[:, 1:] + + return { + 'peak_number': pred_number, + 'peak_position': pred_position, + 'peak_height': pred_height, + 'attention': { + 'weights': attn_weights.cpu().tolist() if not self.training else None, + 'mask': attn_mask.cpu().tolist() if not self.training else None + } + } + + @abstractmethod + def get_loss(self, predictions, targets): + """损失函数 - 由子类实现""" + pass + + @abstractmethod + def get_metrics(self, predictions, targets): + """评估指标 - 由子类实现""" + pass + +def get_key_padding_mask(tokens): + key_padding_mask = paddle.zeros(tokens.shape) + key_padding_mask[tokens == -1] = -paddle.inf + return key_padding_mask \ No newline at end of file diff --git a/ppmat/models/ecformer/utils/__init__.py b/ppmat/models/ecformer/utils/__init__.py new file mode 100644 index 00000000..3157548d --- /dev/null +++ 
b/ppmat/models/ecformer/utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import graph_utils \ No newline at end of file diff --git a/ppmat/models/ecformer/utils/graph_utils.py b/ppmat/models/ecformer/utils/graph_utils.py new file mode 100644 index 00000000..b31c4284 --- /dev/null +++ b/ppmat/models/ecformer/utils/graph_utils.py @@ -0,0 +1,59 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def index_transform(raw_index, batch_size):
    """Group node positions by the graph they belong to.

    Args:
        raw_index: Object with ``.tolist()`` yielding one graph id per node.
        batch_size: Number of graphs; anything ``int()`` accepts.

    Returns:
        list[list[int]]: ``result[g]`` holds, in ascending order, the positions
        of the nodes whose graph id equals ``g``. Ids outside
        ``[0, batch_size)`` are ignored.
    """
    batch_size = int(batch_size)
    index_list = [[] for _ in range(batch_size)]
    # Single pass over the nodes instead of one full scan per graph.
    for node_pos, graph_id in enumerate(raw_index.tolist()):
        bucket = int(graph_id)
        if bucket == graph_id and 0 <= bucket < batch_size:
            index_list[bucket].append(node_pos)
    return index_list


def get_key_padding_mask(tokens):
    """Return an additive mask: ``-inf`` where ``tokens == -1``, else ``0``."""
    keep = paddle.zeros(tokens.shape)
    return paddle.where(tokens == -1, paddle.full_like(keep, float("-inf")), keep)


def feat_padding_mask(index, max_node_num):
    """Build the node padding mask from per-graph node index lists.

    Args:
        index: list of per-graph node index lists (see ``index_transform``).
        max_node_num: Padded number of node slots per graph.

    Returns:
        Tensor ``[len(index), max_node_num]`` that is 0 for real nodes and
        ``-inf`` for padding slots.

    Raises:
        ValueError: If a graph has more than ``max_node_num`` nodes.
    """
    padded_index = []
    for graph_id, node_ids in enumerate(index):
        missing = max_node_num - len(node_ids)
        if missing < 0:
            raise ValueError(
                f"graph {graph_id} has {len(node_ids)} nodes, "
                f"more than max_node_num={max_node_num}"
            )
        padded_index.append(list(node_ids) + [-1] * missing)
    return get_key_padding_mask(paddle.to_tensor(padded_index))


def pad_node_features(molecule_features, batch_index, this_batch_size, max_node_num, emb_dim):
    """Scatter flat node features into a zero-padded per-graph layout.

    Args:
        molecule_features: Tensor ``[num_nodes, emb_dim]``.
        batch_index: Graph id of every node (object with ``.tolist()``).
        this_batch_size: Number of graphs in the batch.
        max_node_num: Padded number of node slots per graph.
        emb_dim: Feature width of ``molecule_features``.

    Returns:
        tuple: ``(node_feature, index_list)`` where ``node_feature`` has shape
        ``[this_batch_size, max_node_num, emb_dim]`` and ``index_list`` is the
        output of ``index_transform``.

    Raises:
        ValueError: If a graph has more than ``max_node_num`` nodes.
    """
    index_list = index_transform(batch_index, this_batch_size)

    padded_graphs = []
    for graph_id, node_ids in enumerate(index_list):
        num_nodes = len(node_ids)
        if num_nodes > max_node_num:
            raise ValueError(
                f"graph {graph_id} has {num_nodes} nodes, "
                f"more than max_node_num={max_node_num}"
            )
        padding = paddle.zeros(
            [max_node_num - num_nodes, emb_dim], dtype=molecule_features.dtype
        )
        if num_nodes == 0:
            padded_graphs.append(padding)
            continue
        # Gather all rows of this graph in one op instead of row-by-row copies.
        rows = paddle.index_select(
            molecule_features, paddle.to_tensor(node_ids, dtype="int64"), axis=0
        )
        padded_graphs.append(paddle.concat([rows, padding], axis=0))

    if not padded_graphs:
        # Empty batch: paddle.stack cannot take an empty list.
        empty = paddle.zeros([0, max_node_num, emb_dim], dtype=molecule_features.dtype)
        return empty, index_list

    node_feature = paddle.stack(padded_graphs, axis=0)
    return node_feature, index_list
# --- ppmat/models/ecformer/utils/loss/dilate_loss.py ---

def dilate_loss(outputs, targets, alpha, gamma, device):
    """Combine a soft-DTW shape term and a temporal distortion term.

    Args:
        outputs: Predictions, shape ``(batch_size, N_output, 1)``.
        targets: Ground truth, same shape as ``outputs``.
        alpha: Weight of the shape term; ``1 - alpha`` weights the temporal term.
        gamma: Soft-DTW smoothing parameter.
        device: Unused by this implementation; kept for interface compatibility.

    Returns:
        ``alpha * loss_shape + (1 - alpha) * loss_temporal``.
    """
    batch_size, N_output = outputs.shape[0: 2]
    softdtw_batch = soft_dtw.SoftDTWBatch.apply

    # Pairwise squared distances per sample, stacked to [batch, N, N].
    # Shapes are passed as lists, the form paddle's reshape documents.
    D = paddle.stack(
        [
            soft_dtw.pairwise_distances(
                targets[k, :, :].reshape([-1, 1]),
                outputs[k, :, :].reshape([-1, 1]),
            )
            for k in range(batch_size)
        ],
        axis=0,
    )
    loss_shape = softdtw_batch(D, gamma)

    path_dtw = path_soft_dtw.PathDTWBatch.apply
    path = path_dtw(D, gamma)

    # Omega[i, j] = (i - j)^2 penalises alignments far from the diagonal.
    Omega = soft_dtw.pairwise_distances(
        paddle.arange(1.0, float(N_output + 1)).reshape([N_output, 1])
    )
    loss_temporal = paddle.sum(path * Omega) / (N_output * N_output)

    loss = alpha * loss_shape + (1 - alpha) * loss_temporal
    return loss


# --- ppmat/models/ecformer/utils/loss/path_soft_dtw.py ---

@jit(nopython = True)
def my_max(x, gamma):
    """Return the soft maximum of ``x`` and its softmax weights."""
    # use the log-sum-exp trick
    max_x = np.max(x)
    exp_x = np.exp((x - max_x) / gamma)
    Z = np.sum(exp_x)
    return gamma * np.log(Z) + max_x, exp_x / Z


@jit(nopython = True)
def my_min(x, gamma):
    """Return the soft minimum of ``x`` and the matching weights."""
    min_x, argmax_x = my_max(-x, gamma)
    return - min_x, argmax_x


@jit(nopython = True)
def my_max_hessian_product(p, z, gamma):
    """Return ``(p * z - p * sum(p * z)) / gamma``."""
    return (p * z - p * np.sum(p * z)) / gamma


@jit(nopython = True)
def my_min_hessian_product(p, z, gamma):
    """Return the negated ``my_max_hessian_product``."""
    return - my_max_hessian_product(p, z, gamma)


@jit(nopython = True)
def dtw_grad(theta, gamma):
    """Run the soft-DTW forward recursion and its backward accumulation.

    Args:
        theta: Cost matrix of shape ``(m, n)``.
        gamma: Smoothing parameter.

    Returns:
        tuple: ``(V[m, n], E[1:m+1, 1:n+1], Q, E)`` -- the final accumulated
        value, the interior of ``E``, the soft-min weights ``Q`` with shape
        ``(m+2, n+2, 3)`` and the full ``E`` with shape ``(m+2, n+2)``.
    """
    m = theta.shape[0]
    n = theta.shape[1]
    V = np.zeros((m + 1, n + 1))
    V[:, 0] = 1e10
    V[0, :] = 1e10
    V[0, 0] = 0

    Q = np.zeros((m + 2, n + 2, 3))

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            # theta is indexed starting from 0.
            v, Q[i, j] = my_min(np.array([V[i, j - 1],
                                          V[i - 1, j - 1],
                                          V[i - 1, j]]), gamma)
            V[i, j] = theta[i - 1, j - 1] + v

    E = np.zeros((m + 2, n + 2))
    E[m + 1, :] = 0
    E[:, n + 1] = 0
    E[m + 1, n + 1] = 1
    Q[m + 1, n + 1] = 1

    for i in range(m, 0, -1):
        for j in range(n, 0, -1):
            E[i, j] = Q[i, j + 1, 0] * E[i, j + 1] + \
                      Q[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                      Q[i + 1, j, 2] * E[i + 1, j]

    return V[m, n], E[1:m + 1, 1:n + 1], Q, E


@jit(nopython = True)
def dtw_hessian_prod(theta, Z, Q, E, gamma):
    """Propagate the direction ``Z`` through the recursions of ``dtw_grad``.

    Args:
        theta: Cost matrix; not read by this function.
        Z: Direction matrix of shape ``(m, n)``.
        Q: Soft-min weights returned by ``dtw_grad``.
        E: Full ``E`` matrix returned by ``dtw_grad``.
        gamma: Smoothing parameter.

    Returns:
        tuple: ``(V_dot[m, n], E_dot[1:m+1, 1:n+1])``.
    """
    m = Z.shape[0]
    n = Z.shape[1]

    V_dot = np.zeros((m + 1, n + 1))
    V_dot[0, 0] = 0

    Q_dot = np.zeros((m + 2, n + 2, 3))
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            # theta is indexed starting from 0.
            V_dot[i, j] = Z[i - 1, j - 1] + \
                          Q[i, j, 0] * V_dot[i, j - 1] + \
                          Q[i, j, 1] * V_dot[i - 1, j - 1] + \
                          Q[i, j, 2] * V_dot[i - 1, j]

            v = np.array([V_dot[i, j - 1], V_dot[i - 1, j - 1], V_dot[i - 1, j]])
            Q_dot[i, j] = my_min_hessian_product(Q[i, j], v, gamma)
    E_dot = np.zeros((m + 2, n + 2))

    for j in range(n, 0, -1):
        for i in range(m, 0, -1):
            E_dot[i, j] = Q_dot[i, j + 1, 0] * E[i, j + 1] + \
                          Q[i, j + 1, 0] * E_dot[i, j + 1] + \
                          Q_dot[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                          Q[i + 1, j + 1, 1] * E_dot[i + 1, j + 1] + \
                          Q_dot[i + 1, j, 2] * E[i + 1, j] + \
                          Q[i + 1, j, 2] * E_dot[i + 1, j]

    return V_dot[m, n], E_dot[1:m + 1, 1:n + 1]


class PathDTWBatch(PyLayer):
    """Batch-averaged soft alignment path, computed on CPU with numba."""

    @staticmethod
    def forward(ctx, D, gamma):  # D.shape: [batch_size, N, N]
        """Return the mean over the batch of the per-sample ``dtw_grad`` paths."""
        batch_size, N, _ = D.shape
        device = D.place
        D_cpu = D.detach().cpu().numpy()
        gamma_paddle = paddle.to_tensor([gamma], dtype='float32', place=device)

        # Fill numpy buffers and convert once; `place` is passed to to_tensor,
        # which documents that parameter.
        grad_np = np.zeros((batch_size, N, N), dtype=np.float32)
        Q_np = np.zeros((batch_size, N + 2, N + 2, 3), dtype=np.float32)
        E_np = np.zeros((batch_size, N + 2, N + 2), dtype=np.float32)

        for k in range(batch_size):  # loop over all D in the batch
            _, grad_np[k], Q_np[k], E_np[k] = dtw_grad(D_cpu[k, :, :], gamma)

        grad_paddle = paddle.to_tensor(grad_np, dtype='float32', place=device)
        Q_paddle = paddle.to_tensor(Q_np, dtype='float32', place=device)
        E_paddle = paddle.to_tensor(E_np, dtype='float32', place=device)

        ctx.save_for_backward(grad_paddle, D, Q_paddle, E_paddle, gamma_paddle)
        return paddle.mean(grad_paddle, axis=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``D`` via ``dtw_hessian_prod``."""
        device = grad_output.place
        grad_paddle, D_paddle, Q_paddle, E_paddle, gamma = ctx.saved_tensor()
        D_cpu = D_paddle.detach().cpu().numpy()
        Q_cpu = Q_paddle.detach().cpu().numpy()
        E_cpu = E_paddle.detach().cpu().numpy()
        gamma = gamma.detach().cpu().numpy()[0]
        Z = grad_output.detach().cpu().numpy()

        batch_size, N, _ = D_cpu.shape
        hessian_np = np.zeros((batch_size, N, N), dtype=np.float32)

        for k in range(batch_size):
            _, hessian_np[k] = dtw_hessian_prod(
                D_cpu[k, :, :], Z, Q_cpu[k, :, :, :], E_cpu[k, :, :], gamma
            )

        # Only D is a Tensor input of forward, so exactly one gradient is
        # returned (gamma is a plain Python number).
        return paddle.to_tensor(hessian_np, dtype='float32', place=device)


# --- ppmat/models/ecformer/utils/loss/soft_dtw.py ---

def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x**2).sum(1).reshape([-1, 1])
    if y is not None:
        # perm=[1, 0] swaps the two axes; perm=[0, 1] would be a no-op.
        y_t = paddle.transpose(y, perm=[1, 0])
        y_norm = (y**2).sum(1).reshape([1, -1])
    else:
        y_t = paddle.transpose(x, perm=[1, 0])
        y_norm = x_norm.reshape([1, -1])

    dist = x_norm + y_norm - 2.0 * paddle.mm(x, y_t)
    # Clamp tiny negatives caused by floating-point cancellation.
    return paddle.clip(dist, 0.0, float('inf'))


@jit(nopython = True)
def compute_softdtw(D, gamma):
    """Fill the soft-DTW accumulation matrix for one cost matrix.

    Returns:
        Array ``R`` of shape ``(N+2, M+2)``; ``R[N, M]`` (i.e. ``R[-2, -2]``)
        is the soft-DTW value.
    """
    N = D.shape[0]
    M = D.shape[1]
    R = np.zeros((N + 2, M + 2)) + 1e8
    R[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            r0 = -R[i - 1, j - 1] / gamma
            r1 = -R[i - 1, j] / gamma
            r2 = -R[i, j - 1] / gamma
            rmax = max(max(r0, r1), r2)
            rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
            softmin = - gamma * (np.log(rsum) + rmax)
            R[i, j] = D[i - 1, j - 1] + softmin
    return R


@jit(nopython = True)
def compute_softdtw_backward(D_, R, gamma):
    """Return the ``(N, M)`` matrix ``E`` of soft-DTW gradients w.r.t. ``D_``.

    Note: the border of ``R`` is overwritten in place.
    """
    N = D_.shape[0]
    M = D_.shape[1]
    D = np.zeros((N + 2, M + 2))
    E = np.zeros((N + 2, M + 2))
    D[1:N + 1, 1:M + 1] = D_
    E[-1, -1] = 1
    R[:, -1] = -1e8
    R[-1, :] = -1e8
    R[-1, -1] = R[-2, -2]
    for j in range(M, 0, -1):
        for i in range(N, 0, -1):
            a0 = (R[i + 1, j] - R[i, j] - D[i + 1, j]) / gamma
            b0 = (R[i, j + 1] - R[i, j] - D[i, j + 1]) / gamma
            c0 = (R[i + 1, j + 1] - R[i, j] - D[i + 1, j + 1]) / gamma
            a = np.exp(a0)
            b = np.exp(b0)
            c = np.exp(c0)
            E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c
    return E[1:N + 1, 1:M + 1]


class SoftDTWBatch(PyLayer):
    """Batch-averaged soft-DTW value, computed on CPU with numba."""

    @staticmethod
    def forward(ctx, D, gamma = 1.0):  # D.shape: [batch_size, N, N]
        """Return the mean soft-DTW value over the batch."""
        dev = D.place
        batch_size, N, _ = D.shape
        gamma = paddle.to_tensor([gamma], dtype='float32', place=dev)
        D_ = D.detach().cpu().numpy()
        g_ = gamma.item()

        R_np = np.zeros((batch_size, N + 2, N + 2), dtype=np.float32)
        for k in range(batch_size):  # loop over all D in the batch
            R_np[k] = compute_softdtw(D_[k, :, :], g_)
        R = paddle.to_tensor(R_np, dtype='float32', place=dev)

        total_loss = paddle.sum(R[:, -2, -2])
        ctx.save_for_backward(D, R, gamma)
        return total_loss / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * E`` as the gradient w.r.t. ``D``."""
        dev = grad_output.place
        D, R, gamma = ctx.saved_tensor()
        batch_size, N, _ = D.shape
        D_ = D.detach().cpu().numpy()
        R_ = R.detach().cpu().numpy()
        g_ = gamma.item()

        E_np = np.zeros((batch_size, N, N), dtype=np.float32)
        for k in range(batch_size):
            E_np[k] = compute_softdtw_backward(D_[k, :, :], R_[k, :, :], g_)
        E = paddle.to_tensor(E_np, dtype='float32', place=dev)

        # Only D is a Tensor input of forward, so exactly one gradient is
        # returned (gamma is a plain Python number).
        return grad_output * E
# ----------------------------------------------------------------------------------------------------------------------
@cuda.jit
def compute_softdtw_cuda(D, gamma, bandwidth, max_i, max_j, n_passes, R):
    """
    Forward soft-DTW kernel; one CUDA block handles one batch element.

    :param D: Cost matrices, indexed as D[b, i, j]
    :param gamma: Smoothing parameter
    :param bandwidth: Sakoe-Chiba bandwidth; 0 disables pruning
    :param max_i: Number of rows to process
    :param max_j: Number of columns to process
    :param n_passes: Number of anti-diagonals to sweep
    :param R: Output accumulation matrices, indexed as R[b, i + 1, j + 1]
    """
    # Each block processes one pair of examples
    b = cuda.blockIdx.x
    # The launch uses as many threads as the longer sequence, which is the size of the
    # largest anti-diagonal
    tid = cuda.threadIdx.x

    # The row index is always the same as tid
    I = tid

    inv_gamma = 1.0 / gamma

    # Go over each anti-diagonal. Only process threads that fall on the current anti-diagonal
    for p in range(n_passes):

        # The index is actually 'p - tid' but need to force it in-bounds
        J = max(0, min(p - tid, max_j - 1))

        # For simplicity, we define i, j which start from 1 (offset from I, J)
        i = I + 1
        j = J + 1

        # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds
        if I + J == p and (I < max_i and J < max_j):
            # Don't compute if outside bandwidth
            if not (abs(i - j) > bandwidth > 0):
                r0 = -R[b, i - 1, j - 1] * inv_gamma
                r1 = -R[b, i - 1, j] * inv_gamma
                r2 = -R[b, i, j - 1] * inv_gamma
                rmax = max(max(r0, r1), r2)
                rsum = math.exp(r0 - rmax) + math.exp(r1 - rmax) + math.exp(r2 - rmax)
                softmin = -gamma * (math.log(rsum) + rmax)
                R[b, i, j] = D[b, i - 1, j - 1] + softmin

        # Wait for other threads in this block
        cuda.syncthreads()

# ----------------------------------------------------------------------------------------------------------------------
@cuda.jit
def compute_softdtw_backward_cuda(D, R, inv_gamma, bandwidth, max_i, max_j, n_passes, E):
    """
    Backward soft-DTW kernel; sweeps the anti-diagonals in reverse and fills E.

    D, R and E are the padded (N + 2, M + 2) matrices prepared by the caller.
    """
    k = cuda.blockIdx.x
    tid = cuda.threadIdx.x

    # Indexing logic is the same as above, however, the anti-diagonal needs to
    # progress backwards
    I = tid

    for p in range(n_passes):
        # Reverse the order to make the loop go backward
        rev_p = n_passes - p - 1

        # convert tid to I, J, then i, j
        J = max(0, min(rev_p - tid, max_j - 1))

        i = I + 1
        j = J + 1

        # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds
        if I + J == rev_p and (I < max_i and J < max_j):

            if math.isinf(R[k, i, j]):
                R[k, i, j] = -math.inf

            # Don't compute if outside bandwidth
            if not (abs(i - j) > bandwidth > 0):
                a = math.exp((R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) * inv_gamma)
                b = math.exp((R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) * inv_gamma)
                c = math.exp((R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) * inv_gamma)
                E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c

        # Wait for other threads in this block
        cuda.syncthreads()

# ----------------------------------------------------------------------------------------------------------------------
class _SoftDTWCUDA(PyLayer):
    """
    CUDA implementation is inspired by the diagonal one proposed in https://ieeexplore.ieee.org/document/8400444:
    "Developing a pattern discovery method in time series data and its GPU acceleration"
    """

    @staticmethod
    def forward(ctx, D, gamma, bandwidth):
        """Return the soft-DTW value of every cost matrix in ``D`` ([B, N, M])."""
        dtype = D.dtype
        gamma = paddle.to_tensor([gamma])
        bandwidth = paddle.to_tensor([bandwidth])

        B, N, M = D.shape
        threads_per_block = max(N, M)
        n_passes = 2 * threads_per_block - 1

        # Prepare the output array
        R = paddle.full([B, N + 2, M + 2], math.inf, dtype=dtype)
        R[:, 0, 0] = 0

        # Run the CUDA kernel.
        # Set CUDA's grid size to be equal to the batch size (every CUDA block processes one sample pair)
        # Set the CUDA block size to be equal to the length of the longer sequence (equal to the size of the largest diagonal)
        compute_softdtw_cuda[B, threads_per_block](cuda.as_cuda_array(D.detach()),
                                                   gamma.item(), bandwidth.item(), N, M, n_passes,
                                                   cuda.as_cuda_array(R))
        ctx.save_for_backward(D, R.clone(), gamma, bandwidth)
        return R[:, -2, -2]

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``D``."""
        D, R, gamma, bandwidth = ctx.saved_tensor()
        dtype = grad_output.dtype

        B, N, M = D.shape
        threads_per_block = max(N, M)
        n_passes = 2 * threads_per_block - 1

        D_ = paddle.zeros((B, N + 2, M + 2), dtype=dtype)
        D_[:, 1:N + 1, 1:M + 1] = D

        R[:, :, -1] = -math.inf
        R[:, -1, :] = -math.inf
        R[:, -1, -1] = R[:, -2, -2]

        E = paddle.zeros((B, N + 2, M + 2), dtype=dtype)
        E[:, -1, -1] = 1

        # Grid and block sizes are set same as done above for the forward() call
        compute_softdtw_backward_cuda[B, threads_per_block](cuda.as_cuda_array(D_),
                                                            cuda.as_cuda_array(R),
                                                            1.0 / gamma.item(), bandwidth.item(), N, M, n_passes,
                                                            cuda.as_cuda_array(E))
        E = E[:, 1:N + 1, 1:M + 1]
        # Only D is a Tensor input of forward, so exactly one gradient is returned.
        return grad_output.reshape([-1, 1, 1]).expand_as(E) * E


# ----------------------------------------------------------------------------------------------------------------------
#
# The following is the CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw
# Credit goes to Kanru Hua.
# I've added support for batching and pruning.
#
# ----------------------------------------------------------------------------------------------------------------------
@jit(nopython=True, parallel=True)
def compute_softdtw(D, gamma, bandwidth):
    """Return the padded accumulation matrices R ([B, N + 2, M + 2]) for a batch of costs D."""
    B = D.shape[0]
    N = D.shape[1]
    M = D.shape[2]
    R = np.ones((B, N + 2, M + 2)) * np.inf
    R[:, 0, 0] = 0
    for b in prange(B):
        for j in range(1, M + 1):
            for i in range(1, N + 1):

                # Check the pruning condition
                if 0 < bandwidth < np.abs(i - j):
                    continue

                r0 = -R[b, i - 1, j - 1] / gamma
                r1 = -R[b, i - 1, j] / gamma
                r2 = -R[b, i, j - 1] / gamma
                rmax = max(max(r0, r1), r2)
                rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
                softmin = - gamma * (np.log(rsum) + rmax)
                R[b, i, j] = D[b, i - 1, j - 1] + softmin
    return R

# ----------------------------------------------------------------------------------------------------------------------
@jit(nopython=True, parallel=True)
def compute_softdtw_backward(D_, R, gamma, bandwidth):
    """Return E ([B, N, M]), the soft-DTW gradients w.r.t. D_. The border of R is overwritten in place."""
    B = D_.shape[0]
    N = D_.shape[1]
    M = D_.shape[2]
    D = np.zeros((B, N + 2, M + 2))
    E = np.zeros((B, N + 2, M + 2))
    D[:, 1:N + 1, 1:M + 1] = D_
    E[:, -1, -1] = 1
    R[:, :, -1] = -np.inf
    R[:, -1, :] = -np.inf
    R[:, -1, -1] = R[:, -2, -2]
    for k in prange(B):
        for j in range(M, 0, -1):
            for i in range(N, 0, -1):

                if np.isinf(R[k, i, j]):
                    R[k, i, j] = -np.inf

                # Check the pruning condition
                if 0 < bandwidth < np.abs(i - j):
                    continue

                a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma
                b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma
                c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma
                a = np.exp(a0)
                b = np.exp(b0)
                c = np.exp(c0)
                E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
    return E[:, 1:N + 1, 1:M + 1]

# ----------------------------------------------------------------------------------------------------------------------
class _SoftDTW(PyLayer):
    """
    CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw
    """

    @staticmethod
    def forward(ctx, D, gamma, bandwidth):
        """Return the soft-DTW value of every cost matrix in ``D`` ([B, N, M])."""
        dev = D.place
        dtype = D.dtype
        # paddle.to_tensor is the documented way to build a tensor from Python data.
        gamma = paddle.to_tensor([gamma], dtype=dtype, place=dev)
        bandwidth = paddle.to_tensor([bandwidth], dtype=dtype, place=dev)
        D_ = D.detach().cpu().numpy()
        g_ = gamma.item()
        b_ = bandwidth.item()
        R = paddle.to_tensor(compute_softdtw(D_, g_, b_), dtype=dtype, place=dev)
        ctx.save_for_backward(D, R, gamma, bandwidth)
        return R[:, -2, -2]

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``D``."""
        D, R, gamma, bandwidth = ctx.saved_tensor()
        dev = grad_output.place
        dtype = grad_output.dtype
        D_ = D.detach().cpu().numpy()
        R_ = R.detach().cpu().numpy()
        g_ = gamma.item()
        b_ = bandwidth.item()
        E = paddle.to_tensor(compute_softdtw_backward(D_, R_, g_, b_), dtype=dtype, place=dev)
        # Only D is a Tensor input of forward, so exactly one gradient is returned.
        return grad_output.reshape([-1, 1, 1]).expand_as(E) * E

# ----------------------------------------------------------------------------------------------------------------------
class SoftDTW(paddle.nn.Layer):
    """
    The soft DTW implementation that optionally supports CUDA
    """

    def __init__(self, use_cuda, gamma=1.0, normalize=False, bandwidth=None, dist_func=None):
        """
        Initializes a new instance using the supplied parameters
        :param use_cuda: Flag indicating whether the CUDA implementation should be used
        :param gamma: sDTW's gamma parameter
        :param normalize: Flag indicating whether to perform normalization
                          (as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790)
        :param bandwidth: Sakoe-Chiba bandwidth for pruning. Passing 'None' will disable pruning.
        :param dist_func: Optional point-wise distance function to use. If 'None', then a default Euclidean distance function will be used.
        """
        super(SoftDTW, self).__init__()
        self.normalize = normalize
        self.gamma = gamma
        self.bandwidth = 0 if bandwidth is None else float(bandwidth)
        self.use_cuda = use_cuda

        # Set the distance function
        if dist_func is not None:
            self.dist_func = dist_func
        else:
            self.dist_func = SoftDTW._euclidean_dist_func

    def _get_func_dtw(self, x, y):
        """
        Checks the inputs and selects the proper implementation to use.
        """
        bx, lx, dx = x.shape
        by, ly, dy = y.shape
        # Make sure the dimensions match
        assert bx == by  # Equal batch sizes
        assert dx == dy  # Equal feature dimensions

        use_cuda = self.use_cuda

        if use_cuda and (lx > 1024 or ly > 1024):  # We should be able to spawn enough threads in CUDA
            print("SoftDTW: Cannot use CUDA because the sequence length > 1024 (the maximum block size supported by CUDA)")
            use_cuda = False

        # Finally, return the correct function
        return _SoftDTWCUDA.apply if use_cuda else _SoftDTW.apply

    @staticmethod
    def _euclidean_dist_func(x, y):
        """
        Calculates the Euclidean distance between each element in x and y per timestep
        """
        n = x.shape[1]
        m = y.shape[1]
        d = x.shape[2]
        # paddle's expand takes the target shape as an explicit list
        x = x.unsqueeze(2).expand([-1, n, m, d])
        y = y.unsqueeze(1).expand([-1, n, m, d])
        return paddle.pow(x - y, 2).sum(3)

    def forward(self, X, Y):
        """
        Compute the soft-DTW value between X and Y
        :param X: One batch of examples, batch_size x seq_len x dims
        :param Y: The other batch of examples, batch_size x seq_len x dims
        :return: The computed results
        """

        # Check the inputs and get the correct implementation
        func_dtw = self._get_func_dtw(X, Y)

        if self.normalize:
            # Stack everything up and run
            x = paddle.concat([X, X, Y])
            y = paddle.concat([Y, X, Y])
            D = self.dist_func(x, y)
            out = func_dtw(D, self.gamma, self.bandwidth)
            # An int argument to paddle.split is the NUMBER of sections (unlike
            # torch.split, where it is the chunk size): three equal parts here.
            out_xy, out_xx, out_yy = paddle.split(out, 3)
            return out_xy - 1 / 2 * (out_xx + out_yy)
        else:
            D_xy = self.dist_func(X, Y)
            return func_dtw(D_xy, self.gamma, self.bandwidth)

# ----------------------------------------------------------------------------------------------------------------------
def timed_run(a, b, sdtw):
    """
    Runs a and b through sdtw, and times the forward and backward passes.
    Assumes that a requires gradients.
    :return: timing, forward result, backward result
    """
    from timeit import default_timer as timer

    # Forward pass
    start = timer()
    forward = sdtw(a, b)
    end = timer()
    t = end - start

    grad_outputs = paddle.ones_like(forward)

    # Backward
    start = timer()
    grads = paddle.grad(forward, a, grad_outputs=grad_outputs)[0]
    end = timer()

    # Total time
    t += end - start

    return t, forward, grads

# ----------------------------------------------------------------------------------------------------------------------
def profile(batch_size, seq_len_a, seq_len_b, dims, tol_backward):
    """Time the CPU and CUDA implementations against each other and check that they agree."""
    sdtw = SoftDTW(False, gamma=1.0, normalize=False)
    sdtw_cuda = SoftDTW(True, gamma=1.0, normalize=False)
    n_iters = 6

    print("Profiling forward() + backward() times for batch_size={}, seq_len_a={}, seq_len_b={}, dims={}...".format(batch_size, seq_len_a, seq_len_b, dims))

    times_cpu = []
    times_gpu = []

    for i in range(n_iters):
        # Gradient tracking is enabled through stop_gradient; paddle.rand does
        # not document a requires_grad argument.
        a_cpu = paddle.rand([batch_size, seq_len_a, dims]).cpu()
        a_cpu.stop_gradient = False
        b_cpu = paddle.rand([batch_size, seq_len_b, dims]).cpu()
        a_gpu = a_cpu.detach().cuda()
        a_gpu.stop_gradient = False
        b_gpu = b_cpu.cuda()

        # GPU
        t_gpu, forward_gpu, backward_gpu = timed_run(a_gpu, b_gpu, sdtw_cuda)

        # CPU
        t_cpu, forward_cpu, backward_cpu = timed_run(a_cpu, b_cpu, sdtw)

        # Verify the results
        assert paddle.allclose(forward_cpu, forward_gpu.cpu())
        assert paddle.allclose(backward_cpu, backward_gpu.cpu(), atol=tol_backward)

        if i > 0:  # Ignore the first time we run, in case this is a cold start (because timings are off at a cold start of the script)
            times_cpu += [t_cpu]
            times_gpu += [t_gpu]

    # Average and log
    avg_cpu = np.mean(times_cpu)
    avg_gpu = np.mean(times_gpu)
    print("  CPU:     ", avg_cpu)
    print("  GPU:     ", avg_gpu)
    print("  Speedup: ", avg_cpu / avg_gpu)
    print()

# ----------------------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    paddle.seed(1234)

    profile(128, 17, 15, 2, tol_backward=1e-6)
    profile(512, 64, 64, 2, tol_backward=1e-4)
    profile(512, 256, 256, 2, tol_backward=1e-3)
ppmat/datasets/IRDataset/place_env.py | 173 ++++ ppmat/datasets/__init__.py | 7 + 12 files changed, 3253 insertions(+) create mode 100644 ppmat/datasets/ECDFormerDataset/__init__.py create mode 100644 ppmat/datasets/ECDFormerDataset/colored_tqdm.py create mode 100644 ppmat/datasets/ECDFormerDataset/compound_tools.py create mode 100644 ppmat/datasets/ECDFormerDataset/dataloader.py create mode 100644 ppmat/datasets/ECDFormerDataset/eval_func.py create mode 100644 ppmat/datasets/ECDFormerDataset/place_env.py create mode 100644 ppmat/datasets/ECDFormerDataset/util_func.py create mode 100644 ppmat/datasets/IRDataset/__init__.py create mode 100644 ppmat/datasets/IRDataset/colored_tqdm.py create mode 100644 ppmat/datasets/IRDataset/compound_tools.py create mode 100644 ppmat/datasets/IRDataset/place_env.py diff --git a/ppmat/datasets/ECDFormerDataset/__init__.py b/ppmat/datasets/ECDFormerDataset/__init__.py new file mode 100644 index 00000000..7934e0ee --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/__init__.py @@ -0,0 +1,407 @@ +# __init__.py +""" +ECDFormer数据集加载模块 +""" + +import os +import numpy as np +import pandas as pd +import paddle +from paddle.io import Dataset +from paddle_geometric.data import Data + +from .compound_tools import get_atom_feature_dims, get_bond_feature_dims +from .util_func import normalize_func +from .eval_func import get_sequence_peak +from .colored_tqdm import ColoredTqdm as tqdm +from .place_env import PlaceEnv +from .dataloader import ECDFormerDataset_DataLoader + + +# ----------------Commonly-used Parameters---------------- +atom_id_names = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", +] +bond_id_names = ["bond_dir", "bond_type", "is_in_ring"] +full_atom_feature_dims = get_atom_feature_dims(atom_id_names) +full_bond_feature_dims = get_bond_feature_dims(bond_id_names) +bond_angle_float_names = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 
'MDEC', 'MATS'] +column_specify={ + 'ADH':[1,5,0,0],'ODH':[1,5,0,1],'IC':[0,5,1,2],'IA':[0,5,1,3],'OJH':[1,5,0,4], + 'ASH':[1,5,0,5],'IC3':[0,3,1,6],'IE':[0,5,1,7],'ID':[0,5,1,8],'OD3':[1,3,0,9], + 'IB':[0,5,1,10],'AD':[1,10,0,11],'AD3':[1,3,0,12],'IF':[0,5,1,13],'OD':[1,10,0,14], + 'AS':[1,10,0,15],'OJ3':[1,3,0,16],'IG':[0,5,1,17],'AZ':[1,10,0,18],'IAH':[0,5,1,19], + 'OJ':[1,10,0,20],'ICH':[0,5,1,21],'OZ3':[1,3,0,22],'IF3':[0,3,1,23],'IAU':[0,1.6,1,24] +} +bond_float_names = [] + + +def get_key_padding_mask(tokens): + """生成query padding mask""" + key_padding_mask = paddle.zeros(tokens.shape) + key_padding_mask[tokens == -1] = -paddle.inf + return key_padding_mask + + +def Construct_dataset(dataset, data_index, path): + """ + 从原始特征构建图数据 + 完全复用原型程序的Construct_dataset逻辑 + """ + graph_atom_bond = [] + graph_bond_angle = [] + + all_descriptor = np.load(os.path.join(path, 'descriptor_all_column.npy')) # (25847, 1826) + + for i in tqdm(range(len(dataset)), desc="Constructing graphs"): + data = dataset[i] + + # 收集原子特征 + atom_feature = [] + for name in atom_id_names: + atom_feature.append(data[name]) + + # 收集键特征 + bond_feature = [] + for name in bond_id_names[0:3]: + bond_feature.append(data[name]) + + # 转换为Tensor + atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') + bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') + bond_float_feature = paddle.to_tensor(data['bond_length'].astype(np.float32)) + bond_angle_feature = paddle.to_tensor(data['bond_angle'].astype(np.float32)) + edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') + bond_index = paddle.to_tensor(data['BondAngleGraph_edges'].T, dtype='int64') + data_index_int = paddle.to_tensor(np.array(data_index[i]), dtype='int64') + + # 添加描述符特征(与原型程序完全一致) + TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 100 + RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] + RPSA = paddle.ones([bond_angle_feature.shape[0]]) * 
def _nonzero_bounds(values):
    """Return ``(begin, end)`` indices of the first and last non-zero entry.

    ``begin`` is the first index holding a non-zero value (0 if there is
    none). ``end`` is the last such index, searched from the back down to
    index 1 (0 if none is found there).
    """
    begin = next((i for i, v in enumerate(values) if v != 0), 0)
    end = next((i for i in range(len(values) - 1, 0, -1) if values[i] != 0), 0)
    return begin, end


def _peak_mask(sequence):
    """Mark strict local extrema with 2 and their direct neighbours with 1.

    A neighbour that is itself an extremum keeps its 2.
    """
    mask = [0] * len(sequence)
    for i in range(1, len(sequence) - 1):
        left, mid, right = sequence[i - 1], sequence[i], sequence[i + 1]
        is_maximum = left < mid and mid > right
        is_minimum = left > mid and mid < right
        if not (is_maximum or is_minimum):
            continue
        if mask[i - 1] != 2:
            mask[i - 1] = 1
        mask[i] = 2
        if mask[i + 1] != 2:
            mask[i + 1] = 1
    return mask


def read_total_ecd(sample_path, fix_length=20):
    """Read every ECD spectrum CSV under ``sample_path`` and extract peaks.

    Args:
        sample_path: Directory containing the ``*ECD/data/`` sub-folders.
            Sub-folders that do not exist are skipped.
        fix_length: Number of samples each spectrum is reduced/padded to.

    Returns:
        A tuple ``(ecd_final_list, ecd_original_dict)``.
        ``ecd_final_list`` is a list of per-spectrum dicts (keys ``id``,
        ``seq``, ``seq_original``, ``seq_mask``, ``peak_num``,
        ``peak_position``, ``peak_height``, ``query_mask``) sorted by ``id``.
        ``ecd_original_dict`` maps the file id to the untrimmed
        ``wavelengths`` / ``ecd`` lists.

    Raises:
        AssertionError: If a spectrum yields 9 or more peaks.
    """
    max_peaks = 9
    sub_dirs = (
        "500ECD/data/",
        "501-2000ECD/data/",
        "2k-6kECD/data/",
        "6k-8kECD/data/",
        "8k-11kECD/data/",
    )

    ecd_dict = {}
    ecd_original_dict = {}

    for sub_dir in sub_dirs:
        filepath = os.path.join(sample_path, sub_dir)
        if not os.path.exists(filepath):
            continue
        for file in os.listdir(filepath):
            if not file.endswith(".csv"):
                continue
            fileid = int(file[:-4])
            ECD_info = pd.read_csv(os.path.join(filepath, file)).to_dict(orient='list')
            wavelengths_o = ECD_info['Wavelength (nm)']
            mdegs_o = ECD_info['ECD (Mdeg)']

            wavelengths = [int(w) for w in wavelengths_o]
            # Values with magnitude <= 1 are treated as zero.
            mdegs = [int(m) if abs(m) > 1 else 0 for m in mdegs_o]

            # Strip the leading / trailing zero runs.
            begin, end = _nonzero_bounds(mdegs)

            ecd_dict[fileid] = {
                'wavelengths': wavelengths[begin: end + 1],
                'ecd': mdegs[begin: end + 1],
            }
            ecd_original_dict[fileid] = {
                'wavelengths': wavelengths,
                'ecd': mdegs,
            }

    ecd_final_list = []
    for key, itm in ecd_dict.items():
        ecd = itm['ecd']
        # Equidistant sampling. The stride is clamped to >= 1: spectra shorter
        # than ``fix_length - 1`` points used to produce a stride of 0, which
        # made ``range`` raise ValueError; ``fix_length == 1`` divided by zero.
        stride = max(1, len(ecd) // max(fix_length - 1, 1))
        sequence_org = ecd[::stride][:fix_length]

        sequence = normalize_func(sequence_org, norm_range=[-100, 100])

        # Pad both variants with zeros up to the fixed length.
        missing = fix_length - len(sequence)
        if missing > 0:
            sequence.extend([0] * missing)
            sequence_org.extend([0] * (fix_length - len(sequence_org)))
        assert len(sequence) == fix_length

        peak_mask = _peak_mask(sequence)

        peak_position_list = get_sequence_peak(sequence)
        peak_number = len(peak_position_list)
        if peak_number >= max_peaks:
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError(f"Peak number {peak_number} >= 9")

        # Sign of every peak: 1 for non-negative, 0 for negative.
        peak_height_list = [1 if sequence[i] >= 0 else 0 for i in peak_position_list]

        # Pad positions and signs with -1 up to ``max_peaks`` entries.
        padding = [-1] * (max_peaks - peak_number)
        peak_position_list = peak_position_list + padding
        peak_height_list = peak_height_list + padding
        query_padding_mask = get_key_padding_mask(paddle.to_tensor(peak_position_list))

        ecd_final_list.append({
            'id': key,
            'seq': [0] + sequence,
            'seq_original': sequence_org,
            'seq_mask': peak_mask,
            'peak_num': peak_number,
            'peak_position': peak_position_list,
            'peak_height': peak_height_list,
            'query_mask': query_padding_mask.unsqueeze(0),
        })

    ecd_final_list.sort(key=lambda entry: entry['id'])
    return ecd_final_list, ecd_original_dict
_cache = None


class ECDFormerDataset(Dataset):
    """ECDFormer dataset.

    Every item is a pair ``(atom_bond_graph, bond_angle_graph)``.

    The graphs are built once and kept in the module-level ``_cache``;
    later instances created for the *same* ``path`` reuse them. On a cache
    hit only ``path``, the two flags and the two graph lists are set on the
    instance; the raw tables (``ecd_dataset``, ``ecd_info``, the index
    dictionaries) are not reloaded.
    """

    # Path whose data currently populates the module-level ``_cache``.
    _cache_path = None

    @PlaceEnv(paddle.CPUPlace())
    def __init__(self,
                 path: str = "dataset/ECD",
                 Use_geometry_enhanced: bool = True,
                 Use_column_info: bool = False):
        global _cache

        # Store the configuration first so it exists on cache hits as well.
        self.path = path
        self.Use_geometry_enhanced = Use_geometry_enhanced
        self.Use_column_info = Use_column_info

        # Reuse the cache only when it was built from the same directory;
        # previously any path received the first dataset's graphs.
        if _cache and ECDFormerDataset._cache_path == path:
            self.graph_atom_bond, self.graph_bond_angle = _cache
            return

        # 1. Load the npy file.
        # NOTE(security): allow_pickle=True unpickles arbitrary objects;
        # only load dataset files from a trusted source.
        print(f"Loading ECDFormer dataset from {path}")
        self.ecd_dataset = np.load(
            os.path.join(path, 'ecd_column_charity_new_smiles.npy'),
            allow_pickle=True
        ).tolist()

        # 2. Load the csv file.
        self.ecd_info = pd.read_csv(
            os.path.join(path, 'ecd_info.csv'),
            encoding='gbk'
        )

        # 3. Extract the info list and the index column.
        self.dataset_all = [item['info'] for item in self.ecd_dataset]
        self.index_all = self.ecd_info['Unnamed: 0'].values

        # 4. Build the lookup tables used to pair entries by ``hand_id``.
        (self.unnamed_idx_dict,
         self.hand_idx_dict,
         self.line_idx_dict) = self._build_index_maps(self.ecd_dataset)

        # 5. Build the graph dataset (core call).
        self.graph_atom_bond, self.graph_bond_angle = GetAtomBondAngleDataset(
            sample_path=path,
            dataset_all=self.dataset_all,
            index_all=self.index_all,
            hand_idx_dict=self.hand_idx_dict,
            line_idx_dict=self.line_idx_dict
        )

        # Validate before publishing to the cache, and raise explicitly so
        # the check is not stripped under ``python -O``.
        if len(self.graph_atom_bond) != len(self.graph_bond_angle):
            raise AssertionError(
                "Mismatch between atom_bond and bond_angle graph lengths"
            )

        _cache = (self.graph_atom_bond, self.graph_bond_angle)
        ECDFormerDataset._cache_path = path

    @staticmethod
    def _build_index_maps(ecd_dataset):
        """Build the by-id, by-hand-id and by-line lookup dictionaries.

        Returns:
            ``(unnamed_idx_dict, hand_idx_dict, line_idx_dict)``.

        Raises:
            AssertionError: If two entries share the same ``id``.
        """
        unnamed_idx_dict, hand_idx_dict, line_idx_dict = {}, {}, {}
        for i, itm in enumerate(ecd_dataset):
            line_idx_dict[i] = {
                'hand_id': itm['hand_id'],
                'unnamed_id': itm['id'],
                'smiles': itm['smiles']
            }

            if itm['id'] in unnamed_idx_dict:
                raise AssertionError(f"Duplicate unnamed id: {itm['id']}")
            unnamed_idx_dict[itm['id']] = {
                'line_number': i,
                'hand_id': itm['hand_id'],
                'smiles': itm['smiles']
            }

            hand_idx_dict.setdefault(itm['hand_id'], []).append({
                'line_number': i,
                'unnamed_id': itm['id'],
                'smiles': itm['smiles']
            })
        return unnamed_idx_dict, hand_idx_dict, line_idx_dict

    def __len__(self):
        """Return the number of graph pairs."""
        return len(self.graph_atom_bond)

    def __getitem__(self, idx):
        """Return ``(atom_bond_graph, bond_angle_graph)`` for ``idx``.

        Both elements are ``paddle_geometric.data.Data`` objects.
        """
        return self.graph_atom_bond[idx], self.graph_bond_angle[idx]
class ColoredTqdm(tqdm):
    """tqdm progress bar whose bar colour fades with the progress.

    The bar is drawn in ``start_color`` at 0 % and in ``end_color`` at
    100 %, linearly interpolated per RGB channel in between.

    Args:
        start_color: ``(r, g, b)`` colour used at the start.
        end_color: ``(r, g, b)`` colour used when finished.
        *args, **kwargs: Forwarded unchanged to ``tqdm``.
    """

    # ANSI sequence that resets all colour attributes.
    _ANSI_RESET = '\033[0m'

    def __init__(self, *args,
                 start_color=(221, 160, 160),  # RGB: #DDA0A0
                 end_color=(160, 221, 160),    # RGB: #A0DDA0
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.start_color = start_color
        self.end_color = end_color

    def _current_rgb(self):
        """Return the interpolated ``(r, g, b)`` for the current progress."""
        total = self.total
        # ``total`` is None for iterables without a length; comparing it with
        # 0 used to raise TypeError. Treat unknown or zero totals as 0 %.
        fraction = self.n / total if total else 0.0
        # Keep the fraction inside [0, 1] so that over-counting cannot push
        # a channel outside the range spanned by the two colours.
        fraction = min(max(fraction, 0.0), 1.0)
        channels = [
            int(start + (end - start) * fraction)
            for start, end in zip(self.start_color, self.end_color)
        ]
        return tuple(min(max(c, 0), 255) for c in channels[:3])

    def get_current_color(self):
        """Return the current colour as a six-digit lowercase hex string."""
        return "%02x%02x%02x" % self._current_rgb()

    def update(self, n=1):
        """Advance the bar by ``n`` and redraw it in the current colour.

        Returns whatever ``tqdm.update`` returns.

        Note: ``bar_format`` is overwritten on every call.
        """
        displayed = super().update(n)
        r, g, b = self._current_rgb()
        # 24-bit ("true colour") ANSI foreground sequence.
        style = f'\033[38;2;{r};{g};{b}m'
        self.bar_format = f'{{l_bar}}{style}{{bar}}{self._ANSI_RESET}{{r_bar}}'
        self.refresh()
        return displayed
"[CX3](=O)[OX1H0-,OX2H1]", + "[NX3][CX2]#[NX1]", + "[#6][CX3](=O)[OX2H0][#6]", + "[#6][CX3](=O)[#6]", + "[OD2]([#6])[#6]", + # H + "[H]", + "[!#1]", + "[H+]", + "[+H]", + "[!H]", + # N + "[NX3;H2,H1;!$(NC=O)]", + "[NX3][CX3]=[CX3]", + "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]", + "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]", + "[NX3][$(C=C),$(cc)]", + "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]", + "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]", + "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH3X4]", + "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]", + "[CH2X4][CX3](=[OX1])[NX3H2]", + "[CH2X4][CX3](=[OX1])[OH0-,OH]", + "[CH2X4][SX2H,SX1H0-]", + "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]", + "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]", + "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\ +[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1", + "[CHX4]([CH3X4])[CH2X4][CH3X4]", + "[CH2X4][CHX4]([CH3X4])[CH3X4]", + "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]", + "[CH2X4][CH2X4][SX2][CH3X4]", + "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1", + "[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH2X4][OX2H]", + "[NX3][CX3]=[SX1]", + "[CHX4]([CH3X4])[OX2H]", + "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12", + "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1", + "[CHX4]([CH3X4])[CH3X4]", + "N[CX4H2][CX3](=[OX1])[O,N]", + "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]", + "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]", + "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]", + "[#7]", + "[NX2]=N", + "[NX2]=[NX2]", + "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]", + "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]", + "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]", + "[NX3][NX3]", + "[NX3][NX2]=[*]", + "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]", + 
"[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]", + "[NX3+]=[CX3]", + "[CX3](=[OX1])[NX3H][CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])", + "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]", + "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]", + "[NX1]#[CX2]", + "[CX1-]#[NX2+]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[NX2]=[OX1]", + "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]", + # O + "[OX2H]", + "[#6][OX2H]", + "[OX2H][CX3]=[OX1]", + "[OX2H]P", + "[OX2H][#6X3]=[#6]", + "[OX2H][cX3]:[c]", + "[OX2H][$(C=C),$(cc)]", + "[$([OH]-*=[!#6])]", + "[OX2,OX1-][OX2,OX1-]", + # P + "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\ +$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\ +,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]", + "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\ +$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\ +$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]", + # S + "[S-][CX3](=S)[#6]", + "[#6X3](=[SX1])([!N])[!N]", + "[SX2]", + "[#16X2H]", + "[#16!H0]", + "[#16X2H0]", + "[#16X2H0][!#16]", + "[#16X2H0][#16X2H0]", + "[#16X2H0][!#16].[#16X2H0][!#16]", + "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]", + "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]", + 
"[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]", + "[SX4](C)(C)(=O)=N", + "[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]", + "[$([#16X3]=[OX1]),$([#16X3+][OX1-])]", + "[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]", + "[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]", + "[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]", + "[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]", + "[#16X2][OX2H,OX1H0-]", + "[#16X2][OX2H0]", + # X + "[#6][F,Cl,Br,I]", + "[F,Cl,Br,I]", + "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]", + ] + + +def get_gasteiger_partial_charges(mol, n_iter=12): + """ + Calculates list of gasteiger partial charges for each atom in mol object. + Args: + mol: rdkit mol object. + n_iter(int): number of iterations. Default 12. + Returns: + list of computed partial charges for each atom. + """ + Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter, + throwOnParamFailure=True) + partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in + mol.GetAtoms()] + return partial_charges + + +def create_standardized_mol_id(smiles): + """ + Args: + smiles: smiles sequence. + Returns: + inchi. + """ + if check_smiles_validity(smiles): + # remove stereochemistry + smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles), + isomericSmiles=False) + mol = AllChem.MolFromSmiles(smiles) + if not mol is None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21 + if '.' 
def check_smiles_validity(smiles):
    """Return True if ``smiles`` can be converted to an RDKit mol object.

    Any exception raised during parsing is treated as "not valid".
    """
    try:
        return bool(Chem.MolFromSmiles(smiles))
    except Exception:
        # Deliberate best-effort: a parser failure means "invalid".
        return False


def split_rdkit_mol_obj(mol):
    """Split a mol object into one mol object per species.

    The molecule is written out as isomeric SMILES, split on ``'.'`` and
    every fragment is parsed again; fragments that fail to parse are
    dropped.

    Args:
        mol: rdkit mol object holding one or several species.
    Returns:
        list of rdkit mol objects (a single-element list for one species).
    """
    smiles = AllChem.MolToSmiles(mol, isomericSmiles=True)
    mol_species_list = []
    for fragment in smiles.split('.'):
        # Parse each fragment once instead of validating and re-parsing.
        try:
            species = AllChem.MolFromSmiles(fragment)
        except Exception:
            continue
        if species:
            mol_species_list.append(species)
    return mol_species_list


def get_largest_mol(mol_list):
    """Return the mol object with the largest number of atoms.

    If several molecules share the largest atom count, the first one wins.

    Args:
        mol_list: iterable of rdkit mol objects.
    Returns:
        the largest mol.
    Raises:
        ValueError: If ``mol_list`` is empty.
    """
    # ``max`` returns the first maximal element, matching the tie rule.
    return max(mol_list, key=lambda m: len(m.GetAtoms()))


def rdchem_enum_to_list(values):
    """Convert an index-keyed enum mapping into a list ordered by index.

    ``values`` is expected to be keyed by consecutive integers from 0, e.g.
    ``{0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
       1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
       2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
       3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER}``.
    """
    return [values[i] for i in range(len(values))]
def get_atom_feature_dims(list_acquired_feature_names):
    """Return the vocabulary size of every requested atom feature."""
    return [len(CompoundKit.atom_vocab_dict[name])
            for name in list_acquired_feature_names]


def get_bond_feature_dims(list_acquired_feature_names):
    """Return the vocabulary size (+1) of every requested bond feature."""
    # +1 for self loop edges
    return [len(CompoundKit.bond_vocab_dict[name]) + 1
            for name in list_acquired_feature_names]


class CompoundKit(object):
    """Static helpers that turn RDKit atoms, bonds and molecules into features.

    Categorical features are encoded as the index of the raw value inside
    the matching vocabulary list; values that are not listed map to the
    last index of that list (see ``safe_index``).
    """
    atom_vocab_dict = {
        "atomic_num": list(range(1, 119)) + ['misc'],
        "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
        "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
        "explicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
        "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
        "hybridization": rdchem_enum_to_list(rdchem.HybridizationType.values),
        "implicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
        "is_aromatic": [0, 1],
        "total_numHs": [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'num_radical_e': [0, 1, 2, 3, 4, 'misc'],
        'atom_is_in_ring': [0, 1],
        'valence_out_shell': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size4': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size5': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
    }
    bond_vocab_dict = {
        "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
        "bond_type": rdchem_enum_to_list(rdchem.BondType.values),
        "is_in_ring": [0, 1],

        'bond_stereo': rdchem_enum_to_list(rdchem.BondStereo.values),
        'is_conjugated': [0, 1],
    }
    # float features
    atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
    # bond_float_feats= ["bond_length", "bond_angle"]  # optional

    ### functional groups
    day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
    day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]

    morgan_fp_N = 200
    morgan2048_fp_N = 2048
    maccs_fp_N = 167

    period_table = Chem.GetPeriodicTable()

    ### atom

    @staticmethod
    def get_atom_value(atom, name):
        """Return the raw value of the atom feature ``name``.

        Raises:
            ValueError: If ``name`` is not a supported atom feature.
        """
        if name == 'atomic_num':
            return atom.GetAtomicNum()
        if name == 'chiral_tag':
            return atom.GetChiralTag()
        if name == 'degree':
            return atom.GetDegree()
        if name == 'explicit_valence':
            return atom.GetExplicitValence()
        if name == 'formal_charge':
            return atom.GetFormalCharge()
        if name == 'hybridization':
            return atom.GetHybridization()
        if name == 'implicit_valence':
            return atom.GetImplicitValence()
        if name == 'is_aromatic':
            return int(atom.GetIsAromatic())
        if name == 'mass':
            # Truncated to an integer here.
            return int(atom.GetMass())
        if name == 'total_numHs':
            return atom.GetTotalNumHs()
        if name == 'num_radical_e':
            return atom.GetNumRadicalElectrons()
        if name == 'atom_is_in_ring':
            return int(atom.IsInRing())
        if name == 'valence_out_shell':
            return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
        raise ValueError(name)

    @staticmethod
    def get_atom_feature_id(atom, name):
        """Return the vocabulary index of the atom feature ``name``."""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))

    @staticmethod
    def get_atom_feature_size(name):
        """Return the vocabulary size of the atom feature ``name``."""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return len(CompoundKit.atom_vocab_dict[name])

    ### bond

    @staticmethod
    def get_bond_value(bond, name):
        """Return the raw value of the bond feature ``name``.

        Raises:
            ValueError: If ``name`` is not a supported bond feature.
        """
        if name == 'bond_dir':
            return bond.GetBondDir()
        if name == 'bond_type':
            return bond.GetBondType()
        if name == 'is_in_ring':
            return int(bond.IsInRing())
        if name == 'is_conjugated':
            return int(bond.GetIsConjugated())
        if name == 'bond_stereo':
            return bond.GetStereo()
        raise ValueError(name)

    @staticmethod
    def get_bond_feature_id(bond, name):
        """Return the vocabulary index of the bond feature ``name``."""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))

    @staticmethod
    def get_bond_feature_size(name):
        """Return the vocabulary size of the bond feature ``name``."""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return len(CompoundKit.bond_vocab_dict[name])

    ### fingerprint

    @staticmethod
    def _morgan_bits(mol, radius, n_bits):
        """Return the Morgan bit-vector fingerprint as a list of 0/1 ints."""
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_morgan_fingerprint(mol, radius=2):
        """Return the ``morgan_fp_N``-bit Morgan fingerprint as a 0/1 list."""
        return CompoundKit._morgan_bits(mol, radius, CompoundKit.morgan_fp_N)

    @staticmethod
    def get_morgan2048_fingerprint(mol, radius=2):
        """Return the ``morgan2048_fp_N``-bit Morgan fingerprint as a 0/1 list."""
        return CompoundKit._morgan_bits(mol, radius, CompoundKit.morgan2048_fp_N)

    @staticmethod
    def get_maccs_fingerprint(mol):
        """Return the MACCS keys fingerprint as a 0/1 list."""
        fp = AllChem.GetMACCSKeysFingerprint(mol)
        return [int(b) for b in fp.ToBitString()]

    ### functional groups

    @staticmethod
    def get_daylight_functional_group_counts(mol):
        """Return, per SMARTS pattern, the number of unique matches in ``mol``.

        The result is ordered like ``day_light_fg_smarts_list``.
        """
        return [
            len(mol.GetSubstructMatches(fg_mol, uniquify=True))
            for fg_mol in CompoundKit.day_light_fg_mo_list
        ]

    @staticmethod
    def get_ring_size(mol):
        """Return an (N, 6) list of ring-membership counts.

        Entry ``[a][k]`` is the number of rings of size ``k + 3`` (sizes 3
        to 8) containing atom ``a``; counts above 8 are stored as 9.
        """
        atom_rings = mol.GetRingInfo().AtomRings()
        ring_list = []
        for atom in mol.GetAtoms():
            atom_idx = atom.GetIdx()
            # Sizes of all rings this atom belongs to, collected once.
            sizes = [len(ring) for ring in atom_rings if atom_idx in ring]
            ring_list.append([min(sizes.count(size), 9) for size in range(3, 9)])
        return ring_list

    @staticmethod
    def atom_to_feat_vector(atom):
        """Return a dict of encoded categorical and raw float atom features.

        Requires the ``_GasteigerCharge`` property to be present on the atom
        (see ``check_partial_charge``).

        NOTE(review): "degree" is encoded from ``GetTotalDegree`` here while
        ``get_atom_value`` uses ``GetDegree`` -- confirm which one is meant.
        """
        vocab = CompoundKit.atom_vocab_dict
        atom_names = {
            "atomic_num": safe_index(vocab["atomic_num"], atom.GetAtomicNum()),
            "chiral_tag": safe_index(vocab["chiral_tag"], atom.GetChiralTag()),
            "degree": safe_index(vocab["degree"], atom.GetTotalDegree()),
            "explicit_valence": safe_index(vocab["explicit_valence"], atom.GetExplicitValence()),
            "formal_charge": safe_index(vocab["formal_charge"], atom.GetFormalCharge()),
            "hybridization": safe_index(vocab["hybridization"], atom.GetHybridization()),
            "implicit_valence": safe_index(vocab["implicit_valence"], atom.GetImplicitValence()),
            "is_aromatic": safe_index(vocab["is_aromatic"], int(atom.GetIsAromatic())),
            "total_numHs": safe_index(vocab["total_numHs"], atom.GetTotalNumHs()),
            'num_radical_e': safe_index(vocab['num_radical_e'], atom.GetNumRadicalElectrons()),
            'atom_is_in_ring': safe_index(vocab['atom_is_in_ring'], int(atom.IsInRing())),
            'valence_out_shell': safe_index(vocab['valence_out_shell'],
                                            CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())),
            'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()),
            'partial_charge': CompoundKit.check_partial_charge(atom),
            'mass': atom.GetMass(),
        }
        return atom_names

    @staticmethod
    def get_atom_names(mol):
        """Return one feature dict per atom, including ring-size counts.

        Gasteiger charges are computed on ``mol`` first.
        TODO: to be remove in the future
        """
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
        atom_features_dicts = [
            CompoundKit.atom_to_feat_vector(atom) for atom in mol.GetAtoms()
        ]

        ring_list = CompoundKit.get_ring_size(mol)
        for features, ring_counts in zip(atom_features_dicts, ring_list):
            # ring_counts[k] belongs to ring size k + 3.
            for offset, ring_size in enumerate(range(3, 9)):
                key = 'in_num_ring_with_size%d' % ring_size
                features[key] = safe_index(
                    CompoundKit.atom_vocab_dict[key], ring_counts[offset])

        return atom_features_dicts

    @staticmethod
    def check_partial_charge(atom):
        """Return the atom's Gasteiger charge with NaN / +inf replaced.

        NaN becomes 0 and positive infinity becomes 10; every other value
        (including negative infinity) is returned unchanged.
        """
        pc = atom.GetDoubleProp('_GasteigerCharge')
        if pc != pc:
            # unsupported atom, replace nan with 0
            pc = 0
        if pc == float('inf'):
            # max 4 for other atoms, set to 10 here if inf is get
            pc = 10
        return pc
Compound3DKit.get_atom_poses(mol, conf) + return mol,atom_poses + # try: + # new_mol = Chem.AddHs(mol) + # res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs) + # ### MMFF generates multiple conformations + # res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) + # new_mol = Chem.RemoveHs(new_mol) + # index = np.argmin([x[1] for x in res]) + # energy = res[index][1] + # conf = new_mol.GetConformer(id=int(index)) + # except: + # new_mol = mol + # AllChem.Compute2DCoords(new_mol) + # energy = 0 + # conf = new_mol.GetConformer() + # + # atom_poses = Compound3DKit.get_atom_poses(new_mol, conf) + # if return_energy: + # return new_mol, atom_poses, energy + # else: + # return new_mol, atom_poses + + @staticmethod + def get_2d_atom_poses(mol): + """get 2d atom poses""" + AllChem.Compute2DCoords(mol) + conf = mol.GetConformer() + atom_poses = Compound3DKit.get_atom_poses(mol, conf) + return atom_poses + + @staticmethod + def get_bond_lengths(edges, atom_poses): + """get bond lengths""" + bond_lengths = [] + for src_node_i, tar_node_j in edges: + bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i])) + bond_lengths = np.array(bond_lengths, 'float32') + return bond_lengths + + @staticmethod + def get_superedge_angles(edges, atom_poses, dir_type='HT'): + """get superedge angles""" + + def _get_vec(atom_poses, edge): + return atom_poses[edge[1]] - atom_poses[edge[0]] + + def _get_angle(vec1, vec2): + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) + if norm1 == 0 or norm2 == 0: + return 0 + vec1 = vec1 / (norm1 + 1e-5) # 1e-5: prevent numerical errors + vec2 = vec2 / (norm2 + 1e-5) + angle = np.arccos(np.dot(vec1, vec2)) + return angle + + E = len(edges) + edge_indices = np.arange(E) + super_edges = [] + bond_angles = [] + bond_angle_dirs = [] + for tar_edge_i in range(E): + tar_edge = edges[tar_edge_i] + if dir_type == 'HT': + src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]] + elif dir_type == 'HH': + src_edge_indices 
= edge_indices[edges[:, 1] == tar_edge[1]] + else: + raise ValueError(dir_type) + for src_edge_i in src_edge_indices: + if src_edge_i == tar_edge_i: + continue + src_edge = edges[src_edge_i] + src_vec = _get_vec(atom_poses, src_edge) + tar_vec = _get_vec(atom_poses, tar_edge) + super_edges.append([src_edge_i, tar_edge_i]) + angle = _get_angle(src_vec, tar_vec) + bond_angles.append(angle) + bond_angle_dirs.append(src_edge[1] == tar_edge[0]) # H -> H or H -> T + + if len(super_edges) == 0: + super_edges = np.zeros([0, 2], 'int64') + bond_angles = np.zeros([0, ], 'float32') + else: + super_edges = np.array(super_edges, 'int64') + bond_angles = np.array(bond_angles, 'float32') + return super_edges, bond_angles, bond_angle_dirs + + +def new_smiles_to_graph_data(smiles, **kwargs): + """ + Convert smiles to graph data. + """ + mol = AllChem.MolFromSmiles(smiles) + if mol is None: + return None + data = new_mol_to_graph_data(mol) + return data + + +def new_mol_to_graph_data(mol): + """ + mol_to_graph_data + Args: + atom_features: Atom features. + edge_features: Edge features. + morgan_fingerprint: Morgan fingerprint. + functional_groups: Functional groups. 
+ """ + if len(mol.GetAtoms()) == 0: + return None + + atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names + bond_id_names = list(CompoundKit.bond_vocab_dict.keys()) + + data = {} + + ### atom features + data = {name: [] for name in atom_id_names} + + raw_atom_feat_dicts = CompoundKit.get_atom_names(mol) + for atom_feat in raw_atom_feat_dicts: + for name in atom_id_names: + data[name].append(atom_feat[name]) + + ### bond and bond features + for name in bond_id_names: + data[name] = [] + data['edges'] = [] + + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + # i->j and j->i + data['edges'] += [(i, j), (j, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + data[name] += [bond_feature_id] * 2 + + #### self loop + N = len(data[atom_id_names[0]]) + for i in range(N): + data['edges'] += [(i, i)] + for name in bond_id_names: + bond_feature_id = get_bond_feature_dims([name])[0] - 1 # self loop: value = len - 1 + data[name] += [bond_feature_id] * N + + ### make ndarray and check length + for name in list(CompoundKit.atom_vocab_dict.keys()): + data[name] = np.array(data[name], 'int64') + for name in CompoundKit.atom_float_names: + data[name] = np.array(data[name], 'float32') + for name in bond_id_names: + data[name] = np.array(data[name], 'int64') + data['edges'] = np.array(data['edges'], 'int64') + + ### morgan fingerprint + data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') + # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') + data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') + data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') + return data + + +def mol_to_graph_data(mol): + """ + mol_to_graph_data + Args: + atom_features: Atom features. + edge_features: Edge features. + morgan_fingerprint: Morgan fingerprint. 
+ functional_groups: Functional groups. + """ + if len(mol.GetAtoms()) == 0: + return None + + atom_id_names = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", + ] + bond_id_names = [ + "bond_dir", "bond_type", "is_in_ring", + ] + + data = {} + for name in atom_id_names: + data[name] = [] + data['mass'] = [] + for name in bond_id_names: + data[name] = [] + data['edges'] = [] + + ### atom features + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetAtomicNum() == 0: + return None + for name in atom_id_names: + data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1) # 0: OOV + data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01) + + ### bond features + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + # i->j and j->i + data['edges'] += [(i, j), (j, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1 # 0: OOV + data[name] += [bond_feature_id] * 2 + + ### self loop (+2) + N = len(data[atom_id_names[0]]) + for i in range(N): + data['edges'] += [(i, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2 # N + 2: self loop + data[name] += [bond_feature_id] * N + + ### check whether edge exists + if len(data['edges']) == 0: # mol has no bonds + for name in bond_id_names: + data[name] = np.zeros((0,), dtype="int64") + data['edges'] = np.zeros((0, 2), dtype="int64") + + ### make ndarray and check length + for name in atom_id_names: + data[name] = np.array(data[name], 'int64') + data['mass'] = np.array(data['mass'], 'float32') + for name in bond_id_names: + data[name] = np.array(data[name], 'int64') + data['edges'] = np.array(data['edges'], 'int64') + + ### morgan fingerprint + data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') + # data['morgan2048_fp'] = 
np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') + data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') + data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') + return data + + +def mol_to_geognn_graph_data(mol, atom_poses, dir_type): + """ + mol: rdkit molecule + dir_type: direction type for bond_angle grpah + """ + if len(mol.GetAtoms()) == 0: + return None + + data = mol_to_graph_data(mol) + + data['atom_pos'] = np.array(atom_poses, 'float32') + data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos']) + BondAngleGraph_edges, bond_angles, bond_angle_dirs = \ + Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos']) + data['BondAngleGraph_edges'] = BondAngleGraph_edges + data['bond_angle'] = np.array(bond_angles, 'float32') + return data + + +def mol_to_geognn_graph_data_MMFF3d(mol): + """tbd""" + if len(mol.GetAtoms()) <= 400: + mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10) + else: + atom_poses = Compound3DKit.get_2d_atom_poses(mol) + return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') + + +def mol_to_geognn_graph_data_raw3d(mol): + """tbd""" + atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer()) + return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') + +def obtain_3D_mol(smiles,name): + mol = AllChem.MolFromSmiles(smiles) + new_mol = Chem.AddHs(mol) + res = AllChem.EmbedMultipleConfs(new_mol) + ### MMFF generates multiple conformations + res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) + new_mol = Chem.RemoveHs(new_mol) + Chem.MolToMolFile(new_mol, name+'.mol') + return new_mol + +def predict_SMILES_info(smiles): + # by lihao, input smiles, output dict + mol = AllChem.MolFromSmiles(smiles) + AllChem.EmbedMolecule(mol) + info_dict = mol_to_geognn_graph_data_MMFF3d(mol) + return info_dict + +if __name__ == "__main__": + # smiles = "OCc1ccccc1CN" + smiles = 
r"[H]/[NH+]=C(\N)C1=CC(=O)/C(=C\C=c2ccc(=C(N)[NH3+])cc2)C=C1" + # smiles = 'CC' + mol = AllChem.MolFromSmiles(smiles) + AllChem.EmbedMolecule(mol) + data = mol_to_geognn_graph_data_MMFF3d(mol) + for key, value in data.items(): + print(key, value.shape) \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/dataloader.py b/ppmat/datasets/ECDFormerDataset/dataloader.py new file mode 100644 index 00000000..848c9ae4 --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/dataloader.py @@ -0,0 +1,56 @@ +import paddle +from paddle.io import Dataset + +from typing import Any, List, Sequence, Union + +from paddle_geometric.data import Batch, Dataset +from paddle_geometric.data.data import BaseData +from paddle_geometric.data.datapipes import DatasetAdapter + +def call(batch: List[Any]) -> Any: + batch = [list(x) for x in zip(*batch)] # transpose + for i in range(len(batch)): # 组Batch + batch[i] = Batch.from_data_list(batch[i]) + + batch0 = batch[0] + batch1 = batch[1] + + # Data解包到Tensor字典 + batch_atom_bond, batch_bond_angle = batch0, batch1 + x, edge_index, edge_attr, query_mask =batch_atom_bond.x,batch_atom_bond.edge_index,batch_atom_bond.edge_attr,batch_atom_bond.query_mask + ba_edge_index, ba_edge_attr = batch_bond_angle.edge_index,batch_bond_angle.edge_attr + batch_data = batch_atom_bond.batch + pos_gt = batch_atom_bond.peak_position + height_gt = batch_atom_bond.peak_height + num_gt = batch_atom_bond.peak_num + return \ + { + "x" : x , + "edge_index" : edge_index , + "edge_attr" : edge_attr , + "batch_data" : batch_data , + "ba_edge_index" : ba_edge_index , + "ba_edge_attr" : ba_edge_attr , + "query_mask" : query_mask + }, \ + { + "peak_number_gt" : num_gt , + "peak_position_gt": pos_gt , + "peak_height_gt" : height_gt + } + +class ECDFormerDataset_DataLoader(paddle.io.DataLoader): + def __init__( + self, + dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter], + batch_size: int = 1, + shuffle: bool = False, + **kwargs, + ): + super().__init__( + 
dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=call, + **kwargs, + ) \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/eval_func.py b/ppmat/datasets/ECDFormerDataset/eval_func.py new file mode 100644 index 00000000..5bd5e1ab --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/eval_func.py @@ -0,0 +1,147 @@ +import math +import json + +import paddle +import paddle.nn as nn +import numpy as np + +from sklearn.metrics import mean_squared_error + +def Accuracy(pred, gt): + # the implimentation of good sample Accuracy. We calculate multi-range accuracy + # pred: paddle.Tensor(batch), the tensor of prediction + # gt: paddle.Tensor(batch), the tensor of groundtruth + pred, gt = pred.tolist(), gt.tolist() + assert len(pred) == len(gt) + + acc, acc_1, acc_2, acc_3 = 0, 0, 0, 0 + for i in range(len(pred)): + if pred[i] == gt[i]: + acc += 1 + elif abs(pred[i]-gt[i]) == 1: + acc_1 += 1 + elif abs(pred[i]-gt[i]) == 2: + acc_2 += 1 + elif abs(pred[i]-gt[i]) == 3: + acc_3 += 1 + else: continue + + return acc, acc_1, acc_2, acc_3 + +def MAE(pred, gt): + # the implimentation of Mean Absolute Error + # MAE == L1 loss + # pred: paddle.Tensor(batch, seq_len), the tensor of prediction + # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth + + pred = paddle.to_tensor(pred, dtype=paddle.float32) + gt = paddle.to_tensor(gt, dtype=paddle.float32) + assert pred.shape == gt.shape + L1Loss = nn.L1Loss(reduction='mean') # paddle��size_average�ѷ�����ͳһ��reduction + mae = L1Loss(pred, gt) + + return mae + + +def MAPE(pred, gt): + # the implimentation of Mean Absolute Percentage Error + # pred: paddle.Tensor(batch, seq_len), the tensor of prediction + # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth + + pred = paddle.to_tensor(pred, dtype=paddle.float32) + gt = paddle.to_tensor(gt, dtype=paddle.float32) + pred, gt = pred.numpy(), gt.numpy() # paddle����cpu()������ֱ����numpy() + mape_loss = np.mean( + np.abs(np.divide(pred-gt, gt, 
out=np.zeros_like(gt), where=gt!=0)) + ) * 100 + return mape_loss + +def get_sequence_peak(sequence): + # input- seq: List + # output- peak_list contains peak position + peak_list = [] + for i in range(1, len(sequence)-1): + if sequence[i-1]sequence[i+1]: + peak_list.append(i) + if sequence[i-1]>sequence[i] and sequence[i]0: correct_peak_symbols += 1 + ## peak number + number_gt.append(len(peaks_gt)) + number_pred.append(len(peaks_pred)) + ## peak position + if len(peaks_gt) == min_peaks_len: + peaks_gt.extend([len(gt[i])] * (max_peaks_len-min_peaks_len) ) + else: peaks_pred.extend([len(pred[i])] * (max_peaks_len-min_peaks_len) ) + position_gt.extend(peaks_gt) + position_pred.extend(peaks_pred) + + rmse_position = np.sqrt(mean_squared_error(position_gt, position_pred)) + rmse_number = np.sqrt(mean_squared_error(number_gt, number_pred)) + symbol_acc = correct_peak_symbols / total_peaks if total_peaks != 0 else 0.0 + return symbol_acc, rmse_position, rmse_number + +def Peak_for_draw(pred, gt): + # the implementation of RMSE for Peak Range, Number, Position + # return the Peak information + # pred: paddle.Tensor(batch, seq_len), the tensor of prediction + # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth + + pred = paddle.to_tensor(pred, dtype=paddle.float32) + gt = paddle.to_tensor(gt, dtype=paddle.float32) + pred, gt = pred.numpy().tolist(), gt.numpy().tolist() # paddle����cpu()���� + batch_size = len(pred) + + # calculate RMSE Range + range_gt, range_pred = [], [] + for i in range(batch_size): + range_gt.append(max(gt[i]) - min(gt[i])) + range_pred.append(max(pred[i]) - min(pred[i])) + + # calculate RMSE for Peak number + number_gt, number_pred = [], [] + position_gt, position_pred = [], [] + for i in range(batch_size): + peaks_gt = get_sequence_peak(gt[i]) + peaks_pred = get_sequence_peak(pred[i]) + + number_gt.append(len(peaks_gt)) + number_pred.append(len(peaks_pred)) + + min_peaks_len = min(len(peaks_gt), len(peaks_pred)) + max_peaks_len = 
max(len(peaks_gt), len(peaks_pred)) + + if len(peaks_gt) == min_peaks_len: + peaks_gt.extend([len(gt[i])] * (max_peaks_len-min_peaks_len) ) + else: + peaks_pred.extend([len(pred[i])] * (max_peaks_len-min_peaks_len) ) + position_gt.append(sum(peaks_gt)) + position_pred.append(sum(peaks_pred)) + + return dict( + peak_range = (range_gt, range_pred), + peak_num = (number_gt, number_pred), + peak_pos = (position_gt, position_pred), + ) \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/place_env.py b/ppmat/datasets/ECDFormerDataset/place_env.py new file mode 100644 index 00000000..6d06a504 --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/place_env.py @@ -0,0 +1,173 @@ +import paddle +import functools +from contextlib import contextmanager + +@contextmanager +def place_env(place): + """ + 上下文管理器,用于临时设置PaddlePaddle的运行设备 + + Args: + place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 + + 用法: + with place_env(paddle.CPUPlace()): + # 这里的代码在CPU上运行 + x = paddle.rand([2, 3]) + print(x) + + @place_env(paddle.CUDAPlace(0)) + def train(): + # 这个函数在GPU上运行 + pass + """ + # 保存当前的设备设置 + current_device = paddle.get_device() + + # 根据place类型设置设备 + if isinstance(place, paddle.CPUPlace): + paddle.set_device('cpu') + elif isinstance(place, paddle.CUDAPlace): + # 获取GPU设备ID + device_id = place.get_device_id() + paddle.set_device(f'gpu:{device_id}') + else: + raise ValueError(f"不支持的place类型: {type(place)}") + + try: + yield + finally: + # 恢复原来的设备设置 + paddle.set_device(current_device) + + +class PlaceEnv: + """ + 类版本的上下文管理器,也支持装饰器功能 + """ + + def __init__(self, place): + """ + 初始化PlaceEnv + + Args: + place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 + """ + self.place = place + self.original_device = None + + def __enter__(self): + """进入上下文时调用""" + # 保存当前的设备设置 + self.original_device = paddle.get_device() + + # 根据place类型设置设备 + if isinstance(self.place, paddle.CPUPlace): + paddle.set_device('cpu') + elif isinstance(self.place, paddle.CUDAPlace): + device_id = 
self.place.get_device_id() + paddle.set_device(f'gpu:{device_id}') + else: + raise ValueError(f"不支持的place类型: {type(self.place)}") + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """退出上下文时调用""" + # 恢复原来的设备设置 + if self.original_device is not None: + paddle.set_device(self.original_device) + + def __call__(self, func): + """ + 使实例可以作为装饰器使用 + + Args: + func: 要装饰的函数 + + Returns: + 装饰后的函数 + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + # 使用with语句来临时改变设备设置 + with self: + return func(*args, **kwargs) + return wrapper + + +# 为了兼容性,也可以保留函数版本的上下文管理器 +@contextmanager +def with_place_env(place): + """ + with_place_env的别名,与place_env功能相同 + """ + with place_env(place): + yield + + +# 使用示例 +if __name__ == "__main__": + # 测试with语句 + print("=== 测试with语句 ===") + print(f"当前设备: {paddle.get_device()}") + + with place_env(paddle.CPUPlace()): + print(f"with块内设备: {paddle.get_device()}") + x = paddle.rand([2, 3]) + print(f"创建的张量: {x}") + + print(f"with块外设备: {paddle.get_device()}") + + print("\n=== 测试类版本with语句 ===") + with PlaceEnv(paddle.CPUPlace()): + print(f"with块内设备: {paddle.get_device()}") + y = paddle.ones([2, 3]) + print(f"创建的张量: {y}") + + print(f"with块外设备: {paddle.get_device()}") + + # 测试装饰器功能 + print("\n=== 测试装饰器功能 ===") + + @PlaceEnv(paddle.CPUPlace()) + def cpu_function(): + """这个函数会在CPU上运行""" + print(f"函数内设备: {paddle.get_device()}") + return paddle.rand([2, 2]) + + # 检查是否有GPU可用 + if paddle.device.cuda.device_count() > 0: + @PlaceEnv(paddle.CUDAPlace(0)) + def gpu_function(): + """这个函数会在GPU上运行""" + print(f"函数内设备: {paddle.get_device()}") + return paddle.rand([2, 2]) + + # 调用装饰后的函数 + print("调用cpu_function:") + result_cpu = cpu_function() + print(f"函数执行后设备: {paddle.get_device()}") + print(f"结果: {result_cpu}") + + if paddle.device.cuda.device_count() > 0: + print("\n调用gpu_function:") + result_gpu = gpu_function() + print(f"函数执行后设备: {paddle.get_device()}") + print(f"结果: {result_gpu}") + + print("\n=== 测试多层嵌套 ===") + print(f"初始设备: 
{paddle.get_device()}") + + with PlaceEnv(paddle.CPUPlace()): + print(f"第一层with内设备: {paddle.get_device()}") + + if paddle.device.cuda.device_count() > 0: + with PlaceEnv(paddle.CUDAPlace(0)): + print(f"第二层with内设备: {paddle.get_device()}") + z = paddle.rand([2, 2]) + print(f"创建的张量: {z}") + + print(f"回到第一层with设备: {paddle.get_device()}") + + print(f"最终设备: {paddle.get_device()}") \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/util_func.py b/ppmat/datasets/ECDFormerDataset/util_func.py new file mode 100644 index 00000000..010ef951 --- /dev/null +++ b/ppmat/datasets/ECDFormerDataset/util_func.py @@ -0,0 +1,43 @@ +def has_element_in_range(lst, lower_bound, upper_bound): + """ + 检查给定列表 lst 中是否存在元素在指定的区间 [lower_bound, upper_bound] 内。 + + 参数: + - lst: 输入的列表 + - lower_bound: 区间的下界 + - upper_bound: 区间的上界 + + 返回: + - 存在元素在指定区间内时返回 True, 否则返回 False + """ + for element in lst: + if lower_bound <= element <= upper_bound: + return True + return False + + +def normalize_func(src_list, norm_range=[-100, 100]): + # lihao implecation for list normalization + # input: src_list, normalization range + # output: tgt_list after normalization + + src_max, src_min = max(src_list), min(src_list) + norm_min, norm_max = norm_range[0], norm_range[1] + if src_max == 0: src_max = 1 + if src_min == 0: src_min = -1 + + tgt_list = [] + for i in range(len(src_list)): + if src_list[i] >= 0: + tgt_list.append(src_list[i] * norm_max / src_max) + else: + tgt_list.append(src_list[i] * norm_min / src_min) + + assert len(src_list) == len(tgt_list) + return tgt_list + +if __name__ == "__main__": + src = [-50, 0, 1, 50] + norm_range = [-100, 100] + tgt = normalize_func(src, norm_range) + print(tgt) diff --git a/ppmat/datasets/IRDataset/__init__.py b/ppmat/datasets/IRDataset/__init__.py new file mode 100644 index 00000000..0cb43d2d --- /dev/null +++ b/ppmat/datasets/IRDataset/__init__.py @@ -0,0 +1,439 @@ +# IRDataset.py +""" +IR光谱预测数据集模块 +支持预加载的npy文件,包含缓存机制,默认使用100样本的小数据集 +""" + +import 
os +import numpy as np +import pandas as pd +import paddle +from paddle.io import Dataset, DataLoader +from paddle_geometric.data import Data +from tqdm import tqdm +import pickle +import json +import warnings + +import rdkit +from rdkit import Chem +from rdkit.Chem import AllChem + +from .compound_tools import mol_to_geognn_graph_data_MMFF3d +from .compound_tools import get_atom_feature_dims, get_bond_feature_dims +from .colored_tqdm import ColoredTqdm as tqdm +from .place_env import PlaceEnv + +# ----------------常量定义---------------- +ATOM_ID_NAMES = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", +] + +BOND_ID_NAMES = ["bond_dir", "bond_type", "is_in_ring"] + +BOND_ANGLE_FLOAT_NAMES = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] + +# 获取特征维度 +FULL_ATOM_FEATURE_DIMS = get_atom_feature_dims(ATOM_ID_NAMES) +FULL_BOND_FEATURE_DIMS = get_bond_feature_dims(BOND_ID_NAMES) + +# IR光谱参数 +IR_WAVELENGTH_MIN = 500 +IR_WAVELENGTH_MAX = 4000 +IR_STEP = 100 # 波数步长,用于离散化 +IR_NUM_POSITION_CLASSES = (IR_WAVELENGTH_MAX - IR_WAVELENGTH_MIN) // IR_STEP # 36 + +# 默认最大峰数 +DEFAULT_MAX_PEAKS = 15 + + +def get_key_padding_mask(tokens): + """生成query padding mask""" + key_padding_mask = paddle.zeros(tokens.shape) + key_padding_mask[tokens == -1] = -paddle.inf + return key_padding_mask + + +def x_bin_position(real_x, distance=IR_STEP): + """将实际波数转换为箱ID""" + return int((real_x - IR_WAVELENGTH_MIN) / distance) + + +def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): + """ + 从原始特征构建IR图数据 + 类似于ECD的Construct_dataset,但针对IR任务 + + Args: + dataset: list of dict, 每个元素是分子的info字典 + data_index: list or array, 索引列表 + descriptor_path: str, 描述符文件路径(IR可能不需要,保留接口) + + Returns: + graph_atom_bond: list of Data, atom-bond图 + graph_bond_angle: list of Data, bond-angle图 + """ + graph_atom_bond = [] + graph_bond_angle = [] + + # IR任务可能不需要描述符,但如果需要可以加载 + all_descriptor = None + if 
descriptor_path and os.path.exists(descriptor_path): + all_descriptor = np.load(descriptor_path) + + for i in tqdm(range(len(dataset)), desc="Constructing IR graphs"): + data = dataset[i] + + # 收集原子特征 + atom_feature = [] + for name in ATOM_ID_NAMES: + if name in data: + atom_feature.append(data[name]) + else: + # 如果某些特征缺失,用0填充 + # 注意:根据实际数据调整 + if i == 0: # 只在第一次警告 + warnings.warn(f"Feature {name} not found in data, using zeros") + num_atoms = data.get('atomic_num', np.zeros(1)).shape[0] + atom_feature.append(np.zeros(num_atoms)) + + # 收集键特征 + bond_feature = [] + for name in BOND_ID_NAMES: + if name in data: + bond_feature.append(data[name]) + else: + if i == 0: + warnings.warn(f"Bond feature {name} not found, using zeros") + num_bonds = data.get('bond_dir', np.zeros(1)).shape[0] + bond_feature.append(np.zeros(num_bonds)) + + # 转换为Tensor + atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') + bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') + + # 键长特征(IR可能不需要,但保留) + bond_float_feature = paddle.to_tensor(data.get('bond_length', np.zeros(data['edges'].shape[0])).astype(np.float32)) + + # 键角特征(IR可能不需要,但保留) + bond_angle_feature = paddle.to_tensor(data.get('bond_angle', np.zeros(data.get('BondAngleGraph_edges', np.zeros((0,2))).shape[0])).astype(np.float32)) + + # 边索引 + edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') + bond_index = paddle.to_tensor(data.get('BondAngleGraph_edges', np.zeros((0,2))).T, dtype='int64') + + + data_index_int = paddle.to_tensor(np.array(int(data_index[i])), dtype='int64') + + # 获取原子数(键角图的节点数) + num_atoms = atom_feature.shape[0] + + # 合并键特征 + bond_feature = paddle.concat( + [bond_feature.astype(bond_float_feature.dtype), + bond_float_feature.reshape([-1, 1])], + axis=1 + ) + + # 处理键角特征(如果有) + if bond_angle_feature.shape[0] > 0: + # 如果有描述符,可以添加 + if all_descriptor is not None and i < all_descriptor.shape[0]: + TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 
100 + RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] + RPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822] + MDEC = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568] + MATS = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457] + + bond_angle_feature = paddle.concat( + [bond_angle_feature.reshape([-1, 1]), TPSA.reshape([-1, 1])], axis=1 + ) + bond_angle_feature = paddle.concat([bond_angle_feature, RASA.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, RPSA.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, MDEC.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, MATS.reshape([-1, 1])], axis=1) + else: + # 如果没有描述符,直接reshape + bond_angle_feature = bond_angle_feature.reshape([-1, 1]) + + # 创建Data对象 + data_atom_bond = Data( + x=atom_feature, + edge_index=edge_index, + edge_attr=bond_feature, + data_index=data_index_int, + ) + + data_bond_angle = Data( + edge_index=bond_index, + edge_attr=bond_angle_feature if bond_angle_feature.shape[0] > 0 else paddle.zeros([0, 1]), + num_nodes=num_atoms, # 键角图的节点数等于原子数 + ) + + graph_atom_bond.append(data_atom_bond) + graph_bond_angle.append(data_bond_angle) + + return graph_atom_bond, graph_bond_angle + + +def read_ir_spectra_by_ids(sample_path, index_all, max_peak=DEFAULT_MAX_PEAKS): + """ + 按需读取IR光谱文件 + + Args: + sample_path: str, IR光谱JSON文件目录 + index_all: list, 需要读取的文件ID列表 + max_peak: int, 最大峰数 + + Returns: + ir_final_list: list of dict, 包含峰值信息的字典列表 + """ + ir_final_list = [] + + for fileid in tqdm(index_all, desc="Reading IR spectra by ID"): + filepath = os.path.join(sample_path, f"{fileid}.json") + + try: + with open(filepath, 'r') as f: + raw_ir_info = json.load(f) + + ir_x = raw_ir_info['x'] + ir_y = raw_ir_info['y_40'] + + from scipy.signal import find_peaks + peaks_raw, _ = find_peaks(x=ir_y, height=0.1, distance=100) + peaks_raw = 
peaks_raw.tolist() + + # 处理峰值(与原作完全一致) + peak_num = min(len(peaks_raw), max_peak) + + if peak_num > 0: + if len(peaks_raw) > max_peak: + peaks = peaks_raw[len(peaks_raw)-max_peak:] + else: + peaks = peaks_raw + + peak_position_list = [x_bin_position(ir_x[i]) for i in peaks] + peak_height_list = [ir_y[i] for i in peaks] + else: + peak_position_list = [] + peak_height_list = [] + + peak_position_list = peak_position_list + [-1] * (max_peak - len(peak_position_list)) + peak_height_list = peak_height_list + [-1] * (max_peak - len(peak_height_list)) + + query_padding_mask = get_key_padding_mask(paddle.to_tensor(peak_position_list)) + + tmp_dict = { + 'id': fileid, + 'seq_40': ir_y, + 'peak_num': peak_num, + 'peak_position': peak_position_list, + 'peak_height': peak_height_list, + 'query_mask': query_padding_mask.unsqueeze(0), + } + ir_final_list.append(tmp_dict) + + except Exception as e: + warnings.warn(f"Error processing {fileid}.json: {e}") + continue + + ir_final_list.sort(key=lambda x: x['id']) + return ir_final_list + + +def load_ir_meta_file(mode='100'): + """ + 加载预生成的IR元数据文件 + + Args: + mode: str, 可选 '100', '10000', 'all' + + Returns: + dataset_all: list, 图特征数据 + smiles_all: list, SMILES列表 + index_all: list, 索引列表 + """ + valid_modes = {'100', '10000', 'all'} + if mode not in valid_modes: + warnings.warn(f"Invalid mode {mode}, using '100'") + mode = '100' + + filename = f'ir_column_charity_{mode}.npy' + + if not os.path.exists(filename): + # 尝试在dataset/IR目录下查找 + alt_path = os.path.join('dataset', 'IR', filename) + if os.path.exists(alt_path): + filename = alt_path + else: + raise FileNotFoundError(f"IR meta file {filename} not found") + + print(f"Loading IR meta file: {filename}") + data = np.load(filename, allow_pickle=True).item() + + return data['dataset_all'], data['smiles_all'], data['index_all'] + + +class IRDataset(Dataset): + """ + IR光谱预测数据集类 + + 支持三种预加载模式: + - '100': 100个样本的小数据集(默认,用于快速测试) + - '10000': 1万个样本的中等数据集 + - 'all': 全部样本(可能很大) + + 
关键优化:按需读取JSON文件,只读取index_all中指定的文件! + """ + + _cache = {} + + @PlaceEnv(paddle.CPUPlace()) + def __init__(self, + path: str = "dataset/IR", + mode: str = '100', + use_geometry_enhanced: bool = True, + force_reload: bool = False, + cache: bool = True): + + self.path = path + self.mode = mode + self.use_geometry_enhanced = use_geometry_enhanced + self.cache_enabled = cache + + cache_key = f"{path}_{mode}_{use_geometry_enhanced}" + + if cache and not force_reload and cache_key in self._cache: + print(f"Loading IR dataset from cache: {mode}") + cached_data = self._cache[cache_key] + self.graph_atom_bond = cached_data['atom_bond'] + self.graph_bond_angle = cached_data['bond_angle'] + self.smiles_list = cached_data.get('smiles', []) + return + + print(f"First-time loading IR dataset (mode={mode})...") + + # 1. 加载元数据文件(获取index_all) + meta_path = os.path.join(path, f'ir_column_charity_{mode}.npy') + if not os.path.exists(meta_path): + raise FileNotFoundError(f"IR meta file {meta_path} not found") + + data = np.load(meta_path, allow_pickle=True).item() + dataset_all = data['dataset_all'] + smiles_all = data['smiles_all'] + index_all = data['index_all'] + + print(f"Loaded meta data: {len(index_all)} samples") + + # 2. 按需读取IR光谱(只读取index_all中的文件!) + spectra_path = os.path.join(path, 'qm9_ir_spec') + self.ir_sequences = read_ir_spectra_by_ids(spectra_path, index_all) + + print(f"Loaded {len(self.ir_sequences)} IR spectra") + + # 3. 构建图数据 + descriptor_path = os.path.join(path, 'descriptor_all_column.npy') + if not os.path.exists(descriptor_path): + descriptor_path = None + + total_graph_atom_bond, total_graph_bond_angle = Construct_IR_Dataset( + dataset_all, index_all, descriptor_path + ) + + # 4. 
# Custom DataLoader
class IRDataLoader(DataLoader):
    """DataLoader for the IR dataset that collates samples into tensor dicts.

    Each dataset item is expected to be an ``(atom_bond, bond_angle)`` pair of
    graph objects, which is what ``IRDataset.__getitem__`` returns.
    """

    def __init__(self, dataset, batch_size=64, shuffle=True, num_workers=0, **kwargs):
        """Build the loader.

        Args:
            dataset: dataset yielding ``(atom_bond, bond_angle)`` pairs.
            batch_size: number of samples per batch.
            shuffle: whether to shuffle the samples.
            num_workers: number of loader worker processes.
            **kwargs: forwarded unchanged to ``DataLoader``. A caller-supplied
                ``collate_fn`` overrides the default one instead of raising
                a duplicate-keyword ``TypeError``.
        """
        kwargs.setdefault('collate_fn', self._collate_fn)
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            **kwargs
        )

    @staticmethod
    def _collate_fn(batch):
        """
        Merge a list of samples into model inputs and ground-truth labels.

        Args:
            batch: list of ``(atom_bond, bond_angle)`` tuples.

        Returns:
            tuple: ``(inputs, labels)`` where

            * ``inputs`` holds ``x``, ``edge_index``, ``edge_attr``,
              ``batch_data`` and ``query_mask`` from the batched atom-bond
              graph, plus ``ba_edge_index`` / ``ba_edge_attr`` from the
              batched bond-angle graph;
            * ``labels`` holds ``peak_number_gt``, ``peak_position_gt`` and
              ``peak_height_gt`` from the batched atom-bond graph.
        """
        # Imported lazily, as in the original implementation.
        from paddle_geometric.data import Batch

        atom_bond_list = [item[0] for item in batch]
        bond_angle_list = [item[1] for item in batch]

        batch_atom_bond = Batch.from_data_list(atom_bond_list)
        batch_bond_angle = Batch.from_data_list(bond_angle_list)

        # Unpack the batched Data objects into plain tensor dictionaries.
        inputs = {
            "x": batch_atom_bond.x,
            "edge_index": batch_atom_bond.edge_index,
            "edge_attr": batch_atom_bond.edge_attr,
            "batch_data": batch_atom_bond.batch,
            "ba_edge_index": batch_bond_angle.edge_index,
            "ba_edge_attr": batch_bond_angle.edge_attr,
            "query_mask": batch_atom_bond.query_mask,
        }
        labels = {
            "peak_number_gt": batch_atom_bond.peak_num,
            "peak_position_gt": batch_atom_bond.peak_position,
            "peak_height_gt": batch_atom_bond.peak_height,
        }
        # The original had a second, unreachable `return` after this one that
        # referenced undefined names; it has been removed.
        return inputs, labels
class ColoredTqdm(tqdm):
    """tqdm progress bar whose bar colour fades from one RGB colour to another.

    Args:
        *args: forwarded to ``tqdm``.
        start_color: ``(r, g, b)`` tuple used at 0 % progress.
        end_color: ``(r, g, b)`` tuple used at 100 % progress.
        **kwargs: forwarded to ``tqdm``.
    """

    def __init__(self, *args,
                 start_color=(221, 160, 160),  # RGB: #DDA0A0
                 end_color=(160, 221, 160),    # RGB: #A0DDA0
                 **kwargs):
        # Set the colours first so they exist whenever the bar is drawn.
        self.start_color = start_color
        self.end_color = end_color
        super().__init__(*args, **kwargs)

    def get_current_color(self):
        """Return the interpolated colour as a 6-digit lowercase hex string."""
        total = self.total
        # `total` is None when the length of the iterable is unknown; treat
        # that (and 0) as "no progress" instead of comparing None with an int.
        progress = self.n / total if total else 0.0
        # Clamp so that overshooting `total` cannot push a channel outside
        # the range spanned by the two endpoint colours.
        progress = min(max(progress, 0.0), 1.0)
        channels = tuple(
            int(start + (end - start) * progress)
            for start, end in zip(self.start_color, self.end_color)
        )
        return "%02x%02x%02x" % channels

    def update(self, n=1):
        """Advance the bar by ``n`` and repaint it in the current colour.

        Returns:
            Whatever ``tqdm.update`` returns, so callers relying on the
            parent's return value keep working.
        """
        displayed = super().update(n)
        # Colour the bar with an ANSI 24-bit (true colour) escape sequence.
        style = hex_to_ansi(self.get_current_color())
        self.bar_format = f'{{l_bar}}{style}{{bar}}{hex_to_ansi.reset}{{r_bar}}'
        self.refresh()
        return displayed
"[NX3][CX2]#[NX1]", + "[#6][CX3](=O)[OX2H0][#6]", + "[#6][CX3](=O)[#6]", + "[OD2]([#6])[#6]", + # H + "[H]", + "[!#1]", + "[H+]", + "[+H]", + "[!H]", + # N + "[NX3;H2,H1;!$(NC=O)]", + "[NX3][CX3]=[CX3]", + "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]", + "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]", + "[NX3][$(C=C),$(cc)]", + "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]", + "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]", + "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH3X4]", + "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]", + "[CH2X4][CX3](=[OX1])[NX3H2]", + "[CH2X4][CX3](=[OX1])[OH0-,OH]", + "[CH2X4][SX2H,SX1H0-]", + "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]", + "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]", + "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\ +[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1", + "[CHX4]([CH3X4])[CH2X4][CH3X4]", + "[CH2X4][CHX4]([CH3X4])[CH3X4]", + "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]", + "[CH2X4][CH2X4][SX2][CH3X4]", + "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1", + "[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH2X4][OX2H]", + "[NX3][CX3]=[SX1]", + "[CHX4]([CH3X4])[OX2H]", + "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12", + "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1", + "[CHX4]([CH3X4])[CH3X4]", + "N[CX4H2][CX3](=[OX1])[O,N]", + "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]", + "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]", + "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]", + "[#7]", + "[NX2]=N", + "[NX2]=[NX2]", + "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]", + "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]", + "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]", + "[NX3][NX3]", + "[NX3][NX2]=[*]", + "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]", + 
"[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]", + "[NX3+]=[CX3]", + "[CX3](=[OX1])[NX3H][CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])", + "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]", + "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]", + "[NX1]#[CX2]", + "[CX1-]#[NX2+]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[NX2]=[OX1]", + "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]", + # O + "[OX2H]", + "[#6][OX2H]", + "[OX2H][CX3]=[OX1]", + "[OX2H]P", + "[OX2H][#6X3]=[#6]", + "[OX2H][cX3]:[c]", + "[OX2H][$(C=C),$(cc)]", + "[$([OH]-*=[!#6])]", + "[OX2,OX1-][OX2,OX1-]", + # P + "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\ +$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\ +,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]", + "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\ +$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\ +$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]", + # S + "[S-][CX3](=S)[#6]", + "[#6X3](=[SX1])([!N])[!N]", + "[SX2]", + "[#16X2H]", + "[#16!H0]", + "[#16X2H0]", + "[#16X2H0][!#16]", + "[#16X2H0][#16X2H0]", + "[#16X2H0][!#16].[#16X2H0][!#16]", + "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]", + "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]", + 
def get_gasteiger_partial_charges(mol, n_iter=12):
    """
    Compute the Gasteiger partial charge of every atom in ``mol``.

    Args:
        mol: rdkit mol object.
        n_iter(int): number of iterations passed to the charge computation.
            Default 12.

    Returns:
        list of float: one partial charge per atom, in atom order.
    """
    Chem.rdPartialCharges.ComputeGasteigerCharges(
        mol, nIter=n_iter, throwOnParamFailure=True)
    charges = []
    for atom in mol.GetAtoms():
        # The computation stores its result on each atom as a property.
        charges.append(float(atom.GetProp('_GasteigerCharge')))
    return charges
def get_largest_mol(mol_list):
    """
    Return the molecule with the most atoms.

    When several molecules share the largest atom count, the first one in
    ``mol_list`` is returned.

    Args:
        mol_list(list): a list of rdkit mol objects.

    Returns:
        the largest mol.

    Raises:
        ValueError: if ``mol_list`` is empty.
    """
    if not mol_list:
        raise ValueError("mol_list must contain at least one molecule")
    # max() keeps the first maximal element, which matches the previous
    # index(max(...)) behaviour without a second pass over the list.
    return max(mol_list, key=lambda m: len(m.GetAtoms()))
def get_bond_feature_dims(list_acquired_feature_names):
    """Return the embedding size needed for each requested bond feature.

    Each size is the length of the feature's vocabulary in
    ``CompoundKit.bond_vocab_dict`` plus one extra slot that is reserved for
    self-loop edges.
    """
    vocab = CompoundKit.bond_vocab_dict
    dims = []
    for feature_name in list_acquired_feature_names:
        # +1 for self loop edges
        dims.append(len(vocab[feature_name]) + 1)
    return dims
features + atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass'] + # bond_float_feats= ["bond_length", "bond_angle"] # optional + + ### functional groups + day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST + day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list] + + morgan_fp_N = 200 + morgan2048_fp_N = 2048 + maccs_fp_N = 167 + + period_table = Chem.GetPeriodicTable() + + ### atom + + @staticmethod + def get_atom_value(atom, name): + """get atom values""" + if name == 'atomic_num': + return atom.GetAtomicNum() + elif name == 'chiral_tag': + return atom.GetChiralTag() + elif name == 'degree': + return atom.GetDegree() + elif name == 'explicit_valence': + return atom.GetExplicitValence() + elif name == 'formal_charge': + return atom.GetFormalCharge() + elif name == 'hybridization': + return atom.GetHybridization() + elif name == 'implicit_valence': + return atom.GetImplicitValence() + elif name == 'is_aromatic': + return int(atom.GetIsAromatic()) + elif name == 'mass': + return int(atom.GetMass()) + elif name == 'total_numHs': + return atom.GetTotalNumHs() + elif name == 'num_radical_e': + return atom.GetNumRadicalElectrons() + elif name == 'atom_is_in_ring': + return int(atom.IsInRing()) + elif name == 'valence_out_shell': + return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum()) + else: + raise ValueError(name) + + @staticmethod + def get_atom_feature_id(atom, name): + """get atom features id""" + assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name + return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name)) + + @staticmethod + def get_atom_feature_size(name): + """get atom features size""" + assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name + return len(CompoundKit.atom_vocab_dict[name]) + + ### bond + + @staticmethod + def get_bond_value(bond, name): + """get bond values""" + if name == 'bond_dir': 
+ return bond.GetBondDir() + elif name == 'bond_type': + return bond.GetBondType() + elif name == 'is_in_ring': + return int(bond.IsInRing()) + elif name == 'is_conjugated': + return int(bond.GetIsConjugated()) + elif name == 'bond_stereo': + return bond.GetStereo() + else: + raise ValueError(name) + + @staticmethod + def get_bond_feature_id(bond, name): + """get bond features id""" + assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name + return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name)) + + @staticmethod + def get_bond_feature_size(name): + """get bond features size""" + assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name + return len(CompoundKit.bond_vocab_dict[name]) + + ### fingerprint + + @staticmethod + def get_morgan_fingerprint(mol, radius=2): + """get morgan fingerprint""" + nBits = CompoundKit.morgan_fp_N + mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) + return [int(b) for b in mfp.ToBitString()] + + @staticmethod + def get_morgan2048_fingerprint(mol, radius=2): + """get morgan2048 fingerprint""" + nBits = CompoundKit.morgan2048_fp_N + mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) + return [int(b) for b in mfp.ToBitString()] + + @staticmethod + def get_maccs_fingerprint(mol): + """get maccs fingerprint""" + fp = AllChem.GetMACCSKeysFingerprint(mol) + return [int(b) for b in fp.ToBitString()] + + ### functional groups + + @staticmethod + def get_daylight_functional_group_counts(mol): + """get daylight functional group counts""" + fg_counts = [] + for fg_mol in CompoundKit.day_light_fg_mo_list: + sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True) + fg_counts.append(len(sub_structs)) + return fg_counts + + @staticmethod + def get_ring_size(mol): + """return (N,6) list""" + rings = mol.GetRingInfo() + rings_info = [] + for r in rings.AtomRings(): + rings_info.append(r) + ring_list = [] 
+ for atom in mol.GetAtoms(): + atom_result = [] + for ringsize in range(3, 9): + num_of_ring_at_ringsize = 0 + for r in rings_info: + if len(r) == ringsize and atom.GetIdx() in r: + num_of_ring_at_ringsize += 1 + if num_of_ring_at_ringsize > 8: + num_of_ring_at_ringsize = 9 + atom_result.append(num_of_ring_at_ringsize) + + ring_list.append(atom_result) + return ring_list + + @staticmethod + def atom_to_feat_vector(atom): + """ tbd """ + atom_names = { + "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()), + "chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()), + "degree": safe_index(CompoundKit.atom_vocab_dict["degree"], atom.GetTotalDegree()), + "explicit_valence": safe_index(CompoundKit.atom_vocab_dict["explicit_valence"], atom.GetExplicitValence()), + "formal_charge": safe_index(CompoundKit.atom_vocab_dict["formal_charge"], atom.GetFormalCharge()), + "hybridization": safe_index(CompoundKit.atom_vocab_dict["hybridization"], atom.GetHybridization()), + "implicit_valence": safe_index(CompoundKit.atom_vocab_dict["implicit_valence"], atom.GetImplicitValence()), + "is_aromatic": safe_index(CompoundKit.atom_vocab_dict["is_aromatic"], int(atom.GetIsAromatic())), + "total_numHs": safe_index(CompoundKit.atom_vocab_dict["total_numHs"], atom.GetTotalNumHs()), + 'num_radical_e': safe_index(CompoundKit.atom_vocab_dict['num_radical_e'], atom.GetNumRadicalElectrons()), + 'atom_is_in_ring': safe_index(CompoundKit.atom_vocab_dict['atom_is_in_ring'], int(atom.IsInRing())), + 'valence_out_shell': safe_index(CompoundKit.atom_vocab_dict['valence_out_shell'], + CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())), + 'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()), + 'partial_charge': CompoundKit.check_partial_charge(atom), + 'mass': atom.GetMass(), + } + return atom_names + + @staticmethod + def get_atom_names(mol): + """get atom name list + TODO: to be remove in the future 
+ """ + atom_features_dicts = [] + Chem.rdPartialCharges.ComputeGasteigerCharges(mol) + for i, atom in enumerate(mol.GetAtoms()): + atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom)) + + ring_list = CompoundKit.get_ring_size(mol) + for i, atom in enumerate(mol.GetAtoms()): + atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0]) + atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1]) + atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2]) + atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3]) + atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4]) + atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5]) + + return atom_features_dicts + + @staticmethod + def check_partial_charge(atom): + """tbd""" + pc = atom.GetDoubleProp('_GasteigerCharge') + if pc != pc: + # unsupported atom, replace nan with 0 + pc = 0 + if pc == float('inf'): + # max 4 for other atoms, set to 10 here if inf is get + pc = 10 + return pc + + +class Compound3DKit(object): + """the 3Dkit of Compound""" + + @staticmethod + def get_atom_poses(mol, conf): + """tbd""" + atom_poses = [] + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetAtomicNum() == 0: + return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms()) + pos = conf.GetAtomPosition(i) + atom_poses.append([pos.x, pos.y, pos.z]) + return atom_poses + + @staticmethod + def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False): + """the atoms of mol will be changed in some cases.""" + conf = mol.GetConformer() + atom_poses = 
Compound3DKit.get_atom_poses(mol, conf) + return mol,atom_poses + # try: + # new_mol = Chem.AddHs(mol) + # res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs) + # ### MMFF generates multiple conformations + # res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) + # new_mol = Chem.RemoveHs(new_mol) + # index = np.argmin([x[1] for x in res]) + # energy = res[index][1] + # conf = new_mol.GetConformer(id=int(index)) + # except: + # new_mol = mol + # AllChem.Compute2DCoords(new_mol) + # energy = 0 + # conf = new_mol.GetConformer() + # + # atom_poses = Compound3DKit.get_atom_poses(new_mol, conf) + # if return_energy: + # return new_mol, atom_poses, energy + # else: + # return new_mol, atom_poses + + @staticmethod + def get_2d_atom_poses(mol): + """get 2d atom poses""" + AllChem.Compute2DCoords(mol) + conf = mol.GetConformer() + atom_poses = Compound3DKit.get_atom_poses(mol, conf) + return atom_poses + + @staticmethod + def get_bond_lengths(edges, atom_poses): + """get bond lengths""" + bond_lengths = [] + for src_node_i, tar_node_j in edges: + bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i])) + bond_lengths = np.array(bond_lengths, 'float32') + return bond_lengths + + @staticmethod + def get_superedge_angles(edges, atom_poses, dir_type='HT'): + """get superedge angles""" + + def _get_vec(atom_poses, edge): + return atom_poses[edge[1]] - atom_poses[edge[0]] + + def _get_angle(vec1, vec2): + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) + if norm1 == 0 or norm2 == 0: + return 0 + vec1 = vec1 / (norm1 + 1e-5) # 1e-5: prevent numerical errors + vec2 = vec2 / (norm2 + 1e-5) + angle = np.arccos(np.dot(vec1, vec2)) + return angle + + E = len(edges) + edge_indices = np.arange(E) + super_edges = [] + bond_angles = [] + bond_angle_dirs = [] + for tar_edge_i in range(E): + tar_edge = edges[tar_edge_i] + if dir_type == 'HT': + src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]] + elif dir_type == 'HH': + src_edge_indices 
def new_mol_to_graph_data(mol):
    """
    Convert an rdkit molecule into a dictionary of graph feature arrays.

    Args:
        mol: rdkit molecule.

    Returns:
        dict or None: ``None`` when the molecule has no atoms. Otherwise a
        dict containing

        * one int64 array per name in ``CompoundKit.atom_vocab_dict`` and one
          float32 array per name in ``CompoundKit.atom_float_names``;
        * one int64 array per name in ``CompoundKit.bond_vocab_dict``;
        * ``edges``: int64 array of ``(src, tar)`` pairs, holding both
          directions of every bond followed by one self loop per atom;
        * ``morgan_fp``, ``maccs_fp`` and ``daylight_fg_counts`` int64 arrays.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names
    bond_id_names = list(CompoundKit.bond_vocab_dict.keys())

    ### atom features
    data = {name: [] for name in atom_id_names}
    for atom_feat in CompoundKit.get_atom_names(mol):
        for name in atom_id_names:
            data[name].append(atom_feat[name])

    ### bond and bond features
    for name in bond_id_names:
        data[name] = []
    data['edges'] = []

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # i->j and j->i
        data['edges'] += [(i, j), (j, i)]
        for name in bond_id_names:
            bond_feature_id = CompoundKit.get_bond_feature_id(bond, name)
            data[name] += [bond_feature_id] * 2

    #### self loop
    N = len(data[atom_id_names[0]])
    data['edges'] += [(i, i) for i in range(N)]
    for name in bond_id_names:
        # self loop: value = (feature dim) - 1, i.e. the extra slot that
        # get_bond_feature_dims reserves for self-loop edges
        bond_feature_id = get_bond_feature_dims([name])[0] - 1
        data[name] += [bond_feature_id] * N

    ### make ndarray and check length
    for name in list(CompoundKit.atom_vocab_dict.keys()):
        data[name] = np.array(data[name], 'int64')
    for name in CompoundKit.atom_float_names:
        data[name] = np.array(data[name], 'float32')
    for name in bond_id_names:
        data[name] = np.array(data[name], 'int64')
    data['edges'] = np.array(data['edges'], 'int64')

    ### fingerprints and functional-group counts
    data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
    data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
    data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
    return data
+ functional_groups: Functional groups. + """ + if len(mol.GetAtoms()) == 0: + return None + + atom_id_names = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", + ] + bond_id_names = [ + "bond_dir", "bond_type", "is_in_ring", + ] + + data = {} + for name in atom_id_names: + data[name] = [] + data['mass'] = [] + for name in bond_id_names: + data[name] = [] + data['edges'] = [] + + ### atom features + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetAtomicNum() == 0: + return None + for name in atom_id_names: + data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1) # 0: OOV + data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01) + + ### bond features + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + # i->j and j->i + data['edges'] += [(i, j), (j, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1 # 0: OOV + data[name] += [bond_feature_id] * 2 + + ### self loop (+2) + N = len(data[atom_id_names[0]]) + for i in range(N): + data['edges'] += [(i, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2 # N + 2: self loop + data[name] += [bond_feature_id] * N + + ### check whether edge exists + if len(data['edges']) == 0: # mol has no bonds + for name in bond_id_names: + data[name] = np.zeros((0,), dtype="int64") + data['edges'] = np.zeros((0, 2), dtype="int64") + + ### make ndarray and check length + for name in atom_id_names: + data[name] = np.array(data[name], 'int64') + data['mass'] = np.array(data['mass'], 'float32') + for name in bond_id_names: + data[name] = np.array(data[name], 'int64') + data['edges'] = np.array(data['edges'], 'int64') + + ### morgan fingerprint + data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') + # data['morgan2048_fp'] = 
def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
    """
    Build the 2D graph features of ``mol`` and add the geometry features.

    Args:
        mol: rdkit molecule.
        atom_poses: per-atom coordinates; converted to a float32 array.
        dir_type: direction type for the bond-angle graph, forwarded to
            ``Compound3DKit.get_superedge_angles`` ('HT' or 'HH').

    Returns:
        dict or None: the dict from ``mol_to_graph_data`` extended with
        ``atom_pos``, ``bond_length``, ``BondAngleGraph_edges`` and
        ``bond_angle``. ``None`` when the molecule has no atoms or when
        ``mol_to_graph_data`` itself returns ``None``.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    data = mol_to_graph_data(mol)
    if data is None:
        # mol_to_graph_data rejects molecules containing an atom with atomic
        # number 0; propagate that instead of failing on the subscript below.
        return None

    data['atom_pos'] = np.array(atom_poses, 'float32')
    data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
    # dir_type used to be dropped here, so the default 'HT' was always used.
    BondAngleGraph_edges, bond_angles, _bond_angle_dirs = \
        Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos'], dir_type=dir_type)
    data['BondAngleGraph_edges'] = BondAngleGraph_edges
    data['bond_angle'] = np.array(bond_angles, 'float32')
    return data
import functools
from contextlib import contextmanager

import paddle


def _device_name_for_place(place):
    """Return the device string for ``paddle.set_device`` that matches ``place``.

    Raises:
        ValueError: If ``place`` is neither ``paddle.CPUPlace`` nor
            ``paddle.CUDAPlace``.
    """
    if isinstance(place, paddle.CPUPlace):
        return 'cpu'
    if isinstance(place, paddle.CUDAPlace):
        return f'gpu:{place.get_device_id()}'
    raise ValueError(f"不支持的place类型: {type(place)}")


@contextmanager
def place_env(place):
    """Temporarily switch the active Paddle device.

    Args:
        place: A device object such as ``paddle.CPUPlace()`` or
            ``paddle.CUDAPlace(0)``.

    Raises:
        ValueError: If ``place`` is of an unsupported type.

    Usage:
        with place_env(paddle.CPUPlace()):
            # code in this block runs on the CPU
            x = paddle.rand([2, 3])
            print(x)

        @place_env(paddle.CUDAPlace(0))
        def train():
            # this function runs on the GPU
            pass
    """
    # Remember the current device so it can be restored afterwards.
    previous_device = paddle.get_device()
    paddle.set_device(_device_name_for_place(place))
    try:
        yield
    finally:
        # Restore the device even if the body raised.
        paddle.set_device(previous_device)


class PlaceEnv:
    """Class-based device context manager that can also be used as a decorator.

    The same instance may be entered repeatedly (for example when it decorates
    a recursive function): every ``__enter__`` pushes the device that was
    active at that moment and the matching ``__exit__`` restores it.
    """

    def __init__(self, place):
        """
        Args:
            place: A device object such as ``paddle.CPUPlace()`` or
                ``paddle.CUDAPlace(0)``.
        """
        self.place = place
        # Device saved by the most recent __enter__; kept for compatibility
        # with code that inspects this attribute.
        self.original_device = None
        # One saved device per active __enter__, innermost last.
        self._saved_devices = []

    def __enter__(self):
        """Save the current device and switch to ``self.place``."""
        previous_device = paddle.get_device()
        target = _device_name_for_place(self.place)
        self._saved_devices.append(previous_device)
        self.original_device = previous_device
        paddle.set_device(target)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the device saved by the matching ``__enter__``."""
        if self._saved_devices:
            paddle.set_device(self._saved_devices.pop())

    def __call__(self, func):
        """Allow an instance to be used as a decorator.

        Args:
            func: Function to wrap.

        Returns:
            A wrapper that runs ``func`` with the device switched to
            ``self.place``.
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper


# Function-style alias kept for compatibility.
@contextmanager
def with_place_env(place):
    """Alias of ``place_env`` with identical behavior."""
    with place_env(place):
        yield
{paddle.get_device()}") + z = paddle.rand([2, 2]) + print(f"创建的张量: {z}") + + print(f"回到第一层with设备: {paddle.get_device()}") + + print(f"最终设备: {paddle.get_device()}") \ No newline at end of file diff --git a/ppmat/datasets/__init__.py b/ppmat/datasets/__init__.py index 98eec451..86d6eaad 100644 --- a/ppmat/datasets/__init__.py +++ b/ppmat/datasets/__init__.py @@ -47,6 +47,9 @@ from ppmat.datasets.oc20_s2ef_dataset import OC20S2EFDataset # noqa from ppmat.datasets.qm9_dataset import QM9Dataset # noqa from ppmat.datasets.omol25_dataset import OMol25Dataset +from ppmat.datasets.IRDataset import IRDataset, IRDataLoader +from ppmat.datasets.ECDFormerDataset import ECDFormerDataset, ECDFormerDataset_DataLoader +from ppmat.datasets.omol25_dataset import OMol25Dataset from ppmat.datasets.split_mptrj_data import none_to_zero from ppmat.datasets.transform import build_transforms from ppmat.utils import logger @@ -67,6 +70,10 @@ "DensityDataset", "SmallDensityDataset", "OMol25Dataset", + "IRDataset", + "ECDFormerDataset", + "IRDataLoader", + "ECDFormerDataset_DataLoader", ] INFO_CLASS_REGISTRY: Dict[str, type] = { From e7abf4ca4ed03697a589381e43b456915239c967 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <108660099+PlumBlossomMaid@users.noreply.github.com> Date: Tue, 24 Feb 2026 17:32:35 +0800 Subject: [PATCH 03/16] Remove duplicate import Removed duplicate import of OMol25Dataset. 
--- ppmat/datasets/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ppmat/datasets/__init__.py b/ppmat/datasets/__init__.py index 86d6eaad..f95adcec 100644 --- a/ppmat/datasets/__init__.py +++ b/ppmat/datasets/__init__.py @@ -49,7 +49,6 @@ from ppmat.datasets.omol25_dataset import OMol25Dataset from ppmat.datasets.IRDataset import IRDataset, IRDataLoader from ppmat.datasets.ECDFormerDataset import ECDFormerDataset, ECDFormerDataset_DataLoader -from ppmat.datasets.omol25_dataset import OMol25Dataset from ppmat.datasets.split_mptrj_data import none_to_zero from ppmat.datasets.transform import build_transforms from ppmat.utils import logger From 8b114ff7fe478fb1a8fd6d0e292f7971d08c9799 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Fri, 27 Feb 2026 23:22:41 +0800 Subject: [PATCH 04/16] =?UTF-8?q?=E6=A0=B9=E6=8D=AEReview=E8=A6=81?= =?UTF-8?q?=E6=B1=82=EF=BC=8C=E6=B7=BB=E5=8A=A0=E7=89=88=E6=9D=83=E5=A3=B0?= =?UTF-8?q?=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/datasets/ECDFormerDataset/__init__.py | 15 ++++++++++++++- ppmat/datasets/ECDFormerDataset/colored_tqdm.py | 14 ++++++++++++++ ppmat/datasets/ECDFormerDataset/compound_tools.py | 14 ++++++++++++++ ppmat/datasets/ECDFormerDataset/dataloader.py | 14 ++++++++++++++ ppmat/datasets/ECDFormerDataset/eval_func.py | 14 ++++++++++++++ ppmat/datasets/ECDFormerDataset/place_env.py | 14 ++++++++++++++ ppmat/datasets/ECDFormerDataset/util_func.py | 14 ++++++++++++++ ppmat/datasets/IRDataset/__init__.py | 15 ++++++++++++++- ppmat/datasets/IRDataset/colored_tqdm.py | 14 ++++++++++++++ ppmat/datasets/IRDataset/compound_tools.py | 14 ++++++++++++++ ppmat/datasets/IRDataset/place_env.py | 14 ++++++++++++++ 11 files changed, 154 insertions(+), 2 deletions(-) diff --git a/ppmat/datasets/ECDFormerDataset/__init__.py b/ppmat/datasets/ECDFormerDataset/__init__.py index 7934e0ee..e84dfb8a 100644 --- 
a/ppmat/datasets/ECDFormerDataset/__init__.py +++ b/ppmat/datasets/ECDFormerDataset/__init__.py @@ -1,4 +1,17 @@ -# __init__.py +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ ECDFormer数据集加载模块 """ diff --git a/ppmat/datasets/ECDFormerDataset/colored_tqdm.py b/ppmat/datasets/ECDFormerDataset/colored_tqdm.py index c6b9cff0..8176778b 100644 --- a/ppmat/datasets/ECDFormerDataset/colored_tqdm.py +++ b/ppmat/datasets/ECDFormerDataset/colored_tqdm.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from tqdm import tqdm import time import os;os.system("") #兼容windows diff --git a/ppmat/datasets/ECDFormerDataset/compound_tools.py b/ppmat/datasets/ECDFormerDataset/compound_tools.py index bf47cd28..cee0c635 100644 --- a/ppmat/datasets/ECDFormerDataset/compound_tools.py +++ b/ppmat/datasets/ECDFormerDataset/compound_tools.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np from rdkit import Chem from rdkit.Chem import AllChem diff --git a/ppmat/datasets/ECDFormerDataset/dataloader.py b/ppmat/datasets/ECDFormerDataset/dataloader.py index 848c9ae4..cfbb15b8 100644 --- a/ppmat/datasets/ECDFormerDataset/dataloader.py +++ b/ppmat/datasets/ECDFormerDataset/dataloader.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import paddle from paddle.io import Dataset diff --git a/ppmat/datasets/ECDFormerDataset/eval_func.py b/ppmat/datasets/ECDFormerDataset/eval_func.py index 5bd5e1ab..72f772ba 100644 --- a/ppmat/datasets/ECDFormerDataset/eval_func.py +++ b/ppmat/datasets/ECDFormerDataset/eval_func.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import json diff --git a/ppmat/datasets/ECDFormerDataset/place_env.py b/ppmat/datasets/ECDFormerDataset/place_env.py index 6d06a504..62acf945 100644 --- a/ppmat/datasets/ECDFormerDataset/place_env.py +++ b/ppmat/datasets/ECDFormerDataset/place_env.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import paddle import functools from contextlib import contextmanager diff --git a/ppmat/datasets/ECDFormerDataset/util_func.py b/ppmat/datasets/ECDFormerDataset/util_func.py index 010ef951..49bac40a 100644 --- a/ppmat/datasets/ECDFormerDataset/util_func.py +++ b/ppmat/datasets/ECDFormerDataset/util_func.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + def has_element_in_range(lst, lower_bound, upper_bound): """ 检查给定列表 lst 中是否存在元素在指定的区间 [lower_bound, upper_bound] 内。 diff --git a/ppmat/datasets/IRDataset/__init__.py b/ppmat/datasets/IRDataset/__init__.py index 0cb43d2d..b0e4cd71 100644 --- a/ppmat/datasets/IRDataset/__init__.py +++ b/ppmat/datasets/IRDataset/__init__.py @@ -1,4 +1,17 @@ -# IRDataset.py +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ """ IR光谱预测数据集模块 支持预加载的npy文件,包含缓存机制,默认使用100样本的小数据集 diff --git a/ppmat/datasets/IRDataset/colored_tqdm.py b/ppmat/datasets/IRDataset/colored_tqdm.py index c6b9cff0..8176778b 100644 --- a/ppmat/datasets/IRDataset/colored_tqdm.py +++ b/ppmat/datasets/IRDataset/colored_tqdm.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from tqdm import tqdm import time import os;os.system("") #兼容windows diff --git a/ppmat/datasets/IRDataset/compound_tools.py b/ppmat/datasets/IRDataset/compound_tools.py index bf47cd28..cee0c635 100644 --- a/ppmat/datasets/IRDataset/compound_tools.py +++ b/ppmat/datasets/IRDataset/compound_tools.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import numpy as np from rdkit import Chem from rdkit.Chem import AllChem diff --git a/ppmat/datasets/IRDataset/place_env.py b/ppmat/datasets/IRDataset/place_env.py index 6d06a504..62acf945 100644 --- a/ppmat/datasets/IRDataset/place_env.py +++ b/ppmat/datasets/IRDataset/place_env.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle import functools from contextlib import contextmanager From 33964a0e29ea42ee81b49ca0679f8179921ba475 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Sun, 1 Mar 2026 19:13:03 +0800 Subject: [PATCH 05/16] =?UTF-8?q?=E7=A7=BB=E5=8A=A8loss=E5=92=8Cmetrics?= =?UTF-8?q?=E5=88=B0=E5=AF=B9=E5=BA=94=E7=9A=84=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/losses/__init__.py | 4 + ppmat/losses/ecd_loss.py | 132 ++++++ ppmat/losses/ir_loss.py | 132 ++++++ ppmat/metrics/__init__.py | 4 + ppmat/metrics/ecd_metric.py | 318 +++++++++++++ ppmat/metrics/ir_metric.py | 178 ++++++++ ppmat/models/ecformer/models/ECD.py | 106 +---- ppmat/models/ecformer/models/IR.py | 105 +---- ppmat/models/ecformer/models/base_ecformer.py | 12 +- .../models/ecformer/utils/loss/dilate_loss.py | 25 - .../ecformer/utils/loss/path_soft_dtw.py | 134 ------ ppmat/models/ecformer/utils/loss/soft_dtw.py | 97 ---- .../ecformer/utils/loss/soft_dtw_cuda.py | 427 ------------------ 13 files changed, 
import paddle
import paddle.nn as nn


class ECDLoss(nn.Layer):
    """Loss function for ECFormer ECD task.

    Combines three cross-entropy losses for peak number, position, and symbol,
    with symbol loss weighted as in the original paper.
    """

    def __init__(self, loss_weight_height=2.0, num_position_classes=20, height_classes=2):
        """
        Args:
            loss_weight_height (float): Weight for peak symbol loss (2.0 in paper)
            num_position_classes (int): Number of position classes (default: 20)
            height_classes (int): Number of symbol classes (default: 2: positive/negative)
        """
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.loss_weight_height = loss_weight_height
        self.num_position_classes = num_position_classes
        self.height_classes = height_classes

        # Accumulators for epoch-level statistics
        self.reset()

    def forward(self, predictions, targets):
        """
        Compute ECD task losses.

        Args:
            predictions (dict): Model outputs containing:
                - peak_number (Tensor): [batch_size, max_peaks] logits for peak count
                - peak_position (Tensor): [batch_size, max_peaks, num_position_classes] logits for positions
                - peak_height (Tensor): [batch_size, max_peaks, height_classes] logits for symbols
            targets (dict): Ground truth containing:
                - peak_num (Tensor): [batch_size] true peak counts
                - peak_position (Tensor): [batch_size, max_peaks] true position labels
                - peak_height (Tensor): [batch_size, max_peaks] true symbol labels

        Returns:
            dict: ``loss`` (total), ``loss_num``, ``loss_pos`` and ``loss_height``.
        """
        # Peak number loss over the whole batch.
        loss_num = self.ce_loss(predictions['peak_number'], targets['peak_num'])

        batch_size = targets['peak_num'].shape[0]

        loss_pos_total = 0.0
        loss_height_total = 0.0
        valid_samples = 0

        for i in range(batch_size):
            n_peaks = int(targets['peak_num'][i])
            if n_peaks == 0:
                # Samples without peaks contribute no position/symbol loss.
                continue

            # Position loss (only for valid peaks)
            pos_pred = predictions['peak_position'][i, :n_peaks, :].reshape([-1, self.num_position_classes])
            pos_gt = targets['peak_position'][i, :n_peaks].reshape([-1])
            loss_pos_total += self.ce_loss(pos_pred, pos_gt)

            # Symbol loss
            height_pred = predictions['peak_height'][i, :n_peaks, :].reshape([-1, self.height_classes])
            height_gt = targets['peak_height'][i, :n_peaks].reshape([-1])
            loss_height_total += self.ce_loss(height_pred, height_gt)

            valid_samples += 1

        if valid_samples > 0:
            loss_pos = loss_pos_total / valid_samples
            loss_height = loss_height_total / valid_samples
        else:
            loss_pos = paddle.to_tensor(0.0)
            loss_height = paddle.to_tensor(0.0)

        # Total loss with weighted symbol term
        total_loss = loss_num + self.loss_weight_height * loss_height + loss_pos

        # Update accumulators for epoch statistics
        self._accumulate(loss_num, loss_pos, loss_height, valid_samples, batch_size=batch_size)

        return {
            "loss": total_loss,
            "loss_num": loss_num,
            "loss_pos": loss_pos,
            "loss_height": loss_height,
        }

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a Python number, unwrapping tensors via ``item()``."""
        return value.item() if hasattr(value, 'item') else value

    def _accumulate(self, loss_num, loss_pos, loss_height, valid_samples, batch_size=None):
        """Accumulate sample-weighted losses for epoch-level statistics.

        Each argument is a per-batch mean, so it is multiplied by the number of
        samples it was averaged over before being summed; dividing by the
        matching sample count later then yields a true per-sample epoch mean.

        Args:
            loss_num: Mean peak-number loss over the batch.
            loss_pos: Mean position loss over the valid samples.
            loss_height: Mean symbol loss over the valid samples.
            valid_samples (int): Number of samples with at least one peak.
            batch_size (int, optional): Number of samples ``loss_num`` was
                averaged over. Falls back to ``valid_samples`` when omitted.
        """
        if batch_size is None:
            batch_size = valid_samples
        self.loss_num_sum += self._as_float(loss_num) * batch_size
        self.loss_pos_sum += self._as_float(loss_pos) * valid_samples
        self.loss_height_sum += self._as_float(loss_height) * valid_samples
        self.total_samples += valid_samples
        self.total_batch_samples += batch_size

    def reset(self):
        """Reset accumulated statistics."""
        self.loss_num_sum = 0.0
        self.loss_pos_sum = 0.0
        self.loss_height_sum = 0.0
        # Samples with at least one peak (denominator for pos/height losses).
        self.total_samples = 0
        # All samples seen (denominator for the peak-number loss).
        self.total_batch_samples = 0

    def log_epoch_metrics(self):
        """Return epoch-level per-sample mean losses (-1.0 where nothing was accumulated)."""
        def _mean(total, count):
            return total / count if count > 0 else -1.0

        return {
            "train_epoch/loss_num": _mean(self.loss_num_sum, self.total_batch_samples),
            "train_epoch/loss_pos": _mean(self.loss_pos_sum, self.total_samples),
            "train_epoch/loss_height": _mean(self.loss_height_sum, self.total_samples),
        }
import paddle
import paddle.nn as nn


class IRLoss(nn.Layer):
    """Loss function for ECFormer IR task.

    Combines cross-entropy loss for peak position and peak number,
    and MSE loss for peak intensity (height) regression.
    """

    def __init__(self, num_position_classes=36, use_height_prediction=True):
        """
        Args:
            num_position_classes (int): Number of position classes (default: 36 for IR)
            use_height_prediction (bool): Whether to use height (intensity) regression loss
        """
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss(reduction='mean')
        self.num_position_classes = num_position_classes
        self.use_height_prediction = use_height_prediction

        # Accumulators for epoch-level statistics
        self.reset()

    def forward(self, predictions, targets):
        """
        Compute IR task losses.

        Args:
            predictions (dict): Model outputs containing:
                - peak_number (Tensor): [batch_size, max_peaks+1] logits for peak count
                - peak_position (Tensor): [batch_size, max_peaks, num_position_classes] logits for positions
                - peak_height (Tensor, optional): [batch_size, max_peaks] predicted intensity values
            targets (dict): Ground truth containing:
                - peak_num (Tensor): [batch_size] true peak counts
                - peak_position (Tensor): [batch_size, max_peaks] true position labels
                - peak_height (Tensor): [batch_size, max_peaks] true intensity values

        Returns:
            dict: ``loss`` (total), ``loss_num``, ``loss_pos`` and ``loss_height``.
        """
        # Peak number loss over the whole batch.
        loss_num = self.ce_loss(predictions['peak_number'], targets['peak_num'])

        batch_size = targets['peak_num'].shape[0]
        # The height term is only computed when enabled AND the model produced it.
        compute_height = self.use_height_prediction and 'peak_height' in predictions

        loss_pos_total = 0.0
        loss_height_total = 0.0
        valid_samples = 0

        for i in range(batch_size):
            n_peaks = int(targets['peak_num'][i])
            if n_peaks == 0:
                # Samples without peaks contribute no position/height loss.
                continue

            # Position loss (cross-entropy)
            pos_pred = predictions['peak_position'][i, :n_peaks, :].reshape([-1, self.num_position_classes])
            pos_gt = targets['peak_position'][i, :n_peaks].reshape([-1])
            loss_pos_total += self.ce_loss(pos_pred, pos_gt)

            # Height loss (MSE regression) if enabled
            if compute_height:
                height_pred = predictions['peak_height'][i, :n_peaks].reshape([-1])
                height_gt = targets['peak_height'][i, :n_peaks].reshape([-1])
                loss_height_total += self.mse_loss(height_pred, height_gt)

            valid_samples += 1

        # Average losses over valid samples
        loss_pos = loss_pos_total / valid_samples if valid_samples > 0 else paddle.to_tensor(0.0)
        total_loss = loss_num + loss_pos

        if self.use_height_prediction:
            loss_height = loss_height_total / valid_samples if valid_samples > 0 else paddle.to_tensor(0.0)
            total_loss = total_loss + loss_height
        else:
            loss_height = paddle.to_tensor(0.0)

        # Update accumulators for epoch statistics
        self._accumulate(loss_num, loss_pos, loss_height, valid_samples, batch_size=batch_size)

        return {
            "loss": total_loss,
            "loss_num": loss_num,
            "loss_pos": loss_pos,
            "loss_height": loss_height,
        }

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a Python number, unwrapping tensors via ``item()``."""
        return value.item() if hasattr(value, 'item') else value

    def _accumulate(self, loss_num, loss_pos, loss_height, valid_samples, batch_size=None):
        """Accumulate sample-weighted losses for epoch-level statistics.

        Each argument is a per-batch mean, so it is multiplied by the number of
        samples it was averaged over before being summed; dividing by the
        matching sample count later then yields a true per-sample epoch mean.

        Args:
            loss_num: Mean peak-number loss over the batch.
            loss_pos: Mean position loss over the valid samples.
            loss_height: Mean height loss over the valid samples.
            valid_samples (int): Number of samples with at least one peak.
            batch_size (int, optional): Number of samples ``loss_num`` was
                averaged over. Falls back to ``valid_samples`` when omitted.
        """
        if batch_size is None:
            batch_size = valid_samples
        self.loss_num_sum += self._as_float(loss_num) * batch_size
        self.loss_pos_sum += self._as_float(loss_pos) * valid_samples
        self.loss_height_sum += self._as_float(loss_height) * valid_samples
        self.total_samples += valid_samples
        self.total_batch_samples += batch_size

    def reset(self):
        """Reset accumulated statistics."""
        self.loss_num_sum = 0.0
        self.loss_pos_sum = 0.0
        self.loss_height_sum = 0.0
        # Samples with at least one peak (denominator for pos/height losses).
        self.total_samples = 0
        # All samples seen (denominator for the peak-number loss).
        self.total_batch_samples = 0

    def log_epoch_metrics(self):
        """Return epoch-level per-sample mean losses (-1.0 where nothing was accumulated)."""
        def _mean(total, count):
            return total / count if count > 0 else -1.0

        return {
            "train_epoch/loss_num": _mean(self.loss_num_sum, self.total_batch_samples),
            "train_epoch/loss_pos": _mean(self.loss_pos_sum, self.total_samples),
            "train_epoch/loss_height": _mean(self.loss_height_sum, self.total_samples),
        }
import paddle
import paddle.nn as nn
import numpy as np
from sklearn.metrics import mean_squared_error
from typing import Dict, List, Optional, Any, Tuple


# =========================
# Utilities
# =========================

def _is_dist():
    """Return True when distributed training is initialized with more than one rank.

    Any failure while querying ``paddle.distributed`` is treated as
    "not distributed" (deliberate best effort).
    """
    try:
        import paddle.distributed as dist
        return dist.is_initialized() and dist.get_world_size() > 1
    except Exception:
        return False


def _all_reduce_sum_(t: paddle.Tensor) -> paddle.Tensor:
    """In-place SUM all_reduce if distributed; returns t."""
    if _is_dist():
        import paddle.distributed as dist
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t


def _to_f32(x) -> paddle.Tensor:
    """Convert ``x`` to a float32 tensor (casting if it already is a tensor)."""
    if isinstance(x, paddle.Tensor):
        return x.astype("float32")
    return paddle.to_tensor(x, dtype="float32")


# =========================
# ECD Metrics
# =========================

class ECDMetrics(nn.Layer):
    """Evaluation metrics for ECFormer ECD task.

    Computes:
        - Number-RMSE: RMSE of predicted vs true peak count
        - Position-RMSE: RMSE of predicted vs true peak positions (class indices)
        - Symbol-Acc: Accuracy of predicted peak symbols (positive/negative)
        - First-Symbol-Acc: Accuracy of the first peak's symbol
    """

    # Names of the streaming accumulators. The order is also the order in which
    # ``accumulate`` issues its all_reduce calls, so it must be identical on
    # every rank.
    _STATE_FIELDS = (
        "num_rmse_sum",
        "pos_rmse_sum",
        "symbol_correct",
        "symbol_total",
        "first_symbol_correct",
        "first_symbol_total",
        "num_samples",
    )

    def __init__(self, num_position_classes=20, max_peaks=9):
        """
        Args:
            num_position_classes (int): Number of position classes (default: 20)
            max_peaks (int): Maximum number of peaks (default: 9)
        """
        super().__init__()
        self.num_position_classes = num_position_classes
        self.max_peaks = max_peaks

        # Accumulators for streaming metrics
        self.reset()

    def reset(self):
        """Reset all accumulated statistics."""
        for field in self._STATE_FIELDS:
            setattr(self, field, _to_f32(0.0))

    def update(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]):
        """
        Update metrics with a batch of predictions and targets.

        Args:
            predictions: dict from model forward
                - peak_number: [batch_size, max_peaks] logits for peak count
                - peak_position: [batch_size, max_peaks, num_position_classes] logits for positions
                - peak_height: [batch_size, max_peaks, 2] logits for symbols
            targets: dict from dataloader
                - peak_num: [batch_size] true peak counts
                - peak_position: [batch_size, max_peaks] true position labels
                - peak_height: [batch_size, max_peaks] true symbol labels
        """
        batch_size = targets['peak_num'].shape[0]

        # Peak number predictions
        pred_nums = predictions['peak_number'].argmax(axis=1)
        true_nums = targets['peak_num']

        # Cast both sides before subtracting so the result does not depend on
        # the label dtype delivered by the dataloader.
        num_errors = pred_nums.astype('float32') - true_nums.astype('float32')
        self.num_rmse_sum += paddle.sum(paddle.square(num_errors))

        # Position and symbol metrics are computed per sample on matched peaks.
        for i in range(batch_size):
            n_true = int(true_nums[i])
            n_pred = int(pred_nums[i])
            if n_true <= 0 or n_pred <= 0:
                # Nothing to match when either side has no peaks.
                continue

            n_match = min(n_true, n_pred)

            # Position errors (only on matched peaks)
            pos_true = targets['peak_position'][i, :n_match].astype('int64')
            pos_pred = predictions['peak_position'][i, :n_match, :].argmax(axis=1)
            pos_errors = (pos_pred - pos_true).astype('float32')
            self.pos_rmse_sum += paddle.sum(paddle.square(pos_errors))

            # Symbol accuracy; labels are cast to int64 like the position
            # labels so they compare cleanly against the argmax output.
            height_true = targets['peak_height'][i, :n_match].astype('int64')
            height_pred = predictions['peak_height'][i, :n_match, :].argmax(axis=1)
            correct = (height_true == height_pred).astype('float32')
            self.symbol_correct += paddle.sum(correct)
            self.symbol_total += _to_f32(n_match)

            # First peak symbol accuracy
            if height_true[0] == height_pred[0]:
                self.first_symbol_correct += _to_f32(1.0)
            self.first_symbol_total += _to_f32(1.0)

        self.num_samples += _to_f32(batch_size)

    def accumulate(self) -> Dict[str, float]:
        """
        Compute accumulated metrics.

        Sums are all-reduced across ranks when distributed training is active.
        Every denominator is clamped to at least 1, so metrics are 0.0 before
        any data has been seen.

        Returns:
            dict: ``num_rmse``, ``pos_rmse``, ``symbol_acc``, ``first_symbol_acc``
        """
        # Reduce clones so the local accumulators stay per-rank.
        totals = {
            field: _all_reduce_sum_(getattr(self, field).clone())
            for field in self._STATE_FIELDS
        }
        one = _to_f32(1.0)

        def _ratio(numerator, denominator):
            return numerator / paddle.maximum(denominator, one)

        num_rmse = paddle.sqrt(_ratio(totals["num_rmse_sum"], totals["num_samples"])).item()
        # Position RMSE is averaged over matched peaks, which is exactly the
        # count tracked in symbol_total.
        pos_rmse = paddle.sqrt(_ratio(totals["pos_rmse_sum"], totals["symbol_total"])).item()
        symbol_acc = _ratio(totals["symbol_correct"], totals["symbol_total"]).item()
        first_symbol_acc = _ratio(
            totals["first_symbol_correct"], totals["first_symbol_total"]
        ).item()

        return {
            'num_rmse': num_rmse,
            'pos_rmse': pos_rmse,
            'symbol_acc': symbol_acc,
            'first_symbol_acc': first_symbol_acc,
        }

    def forward(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]) -> Dict[str, float]:
        """
        Compute metrics for a single batch (non-streaming version).

        The streaming accumulators built up through ``update`` are saved before
        the batch is evaluated and restored afterwards, so calling the layer
        does not discard an in-progress streaming evaluation.

        Args:
            predictions: dict from model forward
            targets: dict from dataloader

        Returns:
            dict: Dictionary containing all metrics for this batch
        """
        saved_state = {field: getattr(self, field) for field in self._STATE_FIELDS}
        try:
            self.reset()
            self.update(predictions, targets)
            return self.accumulate()
        finally:
            for field, value in saved_state.items():
                setattr(self, field, value)
+ + Computes: + - Number-RMSE: RMSE of predicted vs true peak count + - Position-RMSE: RMSE of predicted vs true peak positions (class indices) + - Height-RMSE: RMSE of predicted vs true peak intensities + """ + + def __init__(self, num_position_classes=36, max_peaks=15, use_height_prediction=True): + """ + Args: + num_position_classes (int): Number of position classes (default: 36 for IR) + max_peaks (int): Maximum number of peaks (default: 15) + use_height_prediction (bool): Whether height prediction is used + """ + super().__init__() + self.num_position_classes = num_position_classes + self.max_peaks = max_peaks + self.use_height_prediction = use_height_prediction + + # Accumulators for streaming metrics + self.reset() + + def reset(self): + """Reset all accumulated statistics.""" + self.num_rmse_sum = _to_f32(0.0) + self.pos_rmse_sum = _to_f32(0.0) + self.height_rmse_sum = _to_f32(0.0) + self.pos_count = _to_f32(0.0) + self.height_count = _to_f32(0.0) + self.num_samples = _to_f32(0.0) + + def update(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]): + """ + Update metrics with a batch of predictions and targets. 
+ + Args: + predictions: dict from model forward + - peak_number: [batch_size, max_peaks+1] logits for peak count + - peak_position: [batch_size, max_peaks, num_position_classes] logits for positions + - peak_height (optional): [batch_size, max_peaks] predicted intensity values + targets: dict from dataloader + - peak_num: [batch_size] true peak counts + - peak_position: [batch_size, max_peaks] true position labels + - peak_height: [batch_size, max_peaks] true intensity values + """ + batch_size = targets['peak_num'].shape[0] + + # Peak number predictions + pred_nums = predictions['peak_number'].argmax(axis=1) + true_nums = targets['peak_num'] + + # Number RMSE accumulation + num_errors = (pred_nums - true_nums).astype('float32') + self.num_rmse_sum += paddle.sum(paddle.square(num_errors)) + + # Process each sample for position and height metrics + for i in range(batch_size): + n_true = int(true_nums[i]) + n_pred = int(pred_nums[i]) + + if n_true > 0 and n_pred > 0: + n_match = min(n_true, n_pred) + + # Position errors (only on matched peaks) + pos_true = targets['peak_position'][i, :n_match].astype('int64') + pos_pred = predictions['peak_position'][i, :n_match, :].argmax(axis=1) + pos_errors = (pos_pred - pos_true).astype('float32') + self.pos_rmse_sum += paddle.sum(paddle.square(pos_errors)) + self.pos_count += _to_f32(n_match) + + # Height errors if enabled + if self.use_height_prediction and 'peak_height' in predictions: + height_true = targets['peak_height'][i, :n_match].astype('float32') + height_pred = predictions['peak_height'][i, :n_match].reshape([-1]) + height_errors = height_true - height_pred + self.height_rmse_sum += paddle.sum(paddle.square(height_errors)) + self.height_count += _to_f32(n_match) + + self.num_samples += _to_f32(batch_size) + + def accumulate(self) -> Dict[str, float]: + """ + Compute accumulated metrics. 
+ + Returns: + dict: Dictionary containing all metrics + """ + # Distributed reduction + num_rmse_sum = _all_reduce_sum_(self.num_rmse_sum.clone()) + pos_rmse_sum = _all_reduce_sum_(self.pos_rmse_sum.clone()) + height_rmse_sum = _all_reduce_sum_(self.height_rmse_sum.clone()) + pos_count = _all_reduce_sum_(self.pos_count.clone()) + height_count = _all_reduce_sum_(self.height_count.clone()) + num_samples = _all_reduce_sum_(self.num_samples.clone()) + + # Compute final metrics + num_rmse = paddle.sqrt(num_rmse_sum / paddle.maximum(num_samples, _to_f32(1.0))).item() + pos_rmse = paddle.sqrt(pos_rmse_sum / paddle.maximum(pos_count, _to_f32(1.0))).item() + + metrics = { + 'num_rmse': num_rmse, + 'pos_rmse': pos_rmse, + } + + if self.use_height_prediction: + height_rmse = paddle.sqrt( + height_rmse_sum / paddle.maximum(height_count, _to_f32(1.0)) + ).item() + metrics['height_rmse'] = height_rmse + + return metrics + + def forward(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]) -> Dict[str, float]: + """ + Compute metrics for a single batch (non-streaming version). + + Args: + predictions: dict from model forward + targets: dict from dataloader + + Returns: + dict: Dictionary containing all metrics for this batch + """ + self.reset() + self.update(predictions, targets) + return self.accumulate() \ No newline at end of file diff --git a/ppmat/metrics/ir_metric.py b/ppmat/metrics/ir_metric.py new file mode 100644 index 00000000..8754ad85 --- /dev/null +++ b/ppmat/metrics/ir_metric.py @@ -0,0 +1,178 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +from typing import Dict, Optional, Any + + +# ========================= +# Utilities +# ========================= + +def _is_dist(): + """Check if distributed training is initialized.""" + try: + import paddle.distributed as dist + return dist.is_initialized() and dist.get_world_size() > 1 + except Exception: + return False + + +def _all_reduce_sum_(t: paddle.Tensor) -> paddle.Tensor: + """In-place SUM all_reduce if distributed; returns t.""" + if _is_dist(): + import paddle.distributed as dist + dist.all_reduce(t, op=dist.ReduceOp.SUM) + return t + + +def _to_f32(x) -> paddle.Tensor: + """Convert to float32 tensor.""" + return ( + paddle.to_tensor(x, dtype="float32") + if not isinstance(x, paddle.Tensor) + else x.astype("float32") + ) + + +# ========================= +# IR Metrics +# ========================= + +class IRMetrics(nn.Layer): + """Evaluation metrics for ECFormer IR task. 
+ + Computes: + - Number-RMSE: RMSE of predicted vs true peak count + - Position-RMSE: RMSE of predicted vs true peak positions (class indices) + - Height-RMSE: RMSE of predicted vs true peak intensities (if enabled) + """ + + def __init__(self, use_height_prediction=True): + """ + Args: + use_height_prediction (bool): Whether height prediction is used + """ + super().__init__() + self.use_height_prediction = use_height_prediction + + # Accumulators for streaming metrics + self.reset() + + def reset(self): + """Reset all accumulated statistics.""" + self.num_rmse_sum = _to_f32(0.0) + self.pos_rmse_sum = _to_f32(0.0) + self.height_rmse_sum = _to_f32(0.0) + self.pos_count = _to_f32(0.0) + self.height_count = _to_f32(0.0) + self.num_samples = _to_f32(0.0) + + def update(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]): + """ + Update metrics with a batch of predictions and targets. + + Args: + predictions: dict from model forward + - peak_number: [batch_size, max_peaks+1] logits for peak count + - peak_position: [batch_size, max_peaks, num_position_classes] logits for positions + - peak_height (optional): [batch_size, max_peaks] predicted intensity values + targets: dict from dataloader + - peak_num: [batch_size] true peak counts + - peak_position: [batch_size, max_peaks] true position labels + - peak_height: [batch_size, max_peaks] true intensity values + """ + batch_size = targets['peak_num'].shape[0] + + # Peak number predictions + pred_nums = predictions['peak_number'].argmax(axis=1) + true_nums = targets['peak_num'] + + # Number RMSE accumulation + num_errors = (pred_nums - true_nums).astype('float32') + self.num_rmse_sum += paddle.sum(paddle.square(num_errors)) + + # Process each sample for position and height metrics + for i in range(batch_size): + n_true = int(true_nums[i]) + n_pred = int(pred_nums[i]) + + if n_true > 0 and n_pred > 0: + n_match = min(n_true, n_pred) + + # Position errors (only on matched peaks) + pos_true = 
targets['peak_position'][i, :n_match].astype('int64') + pos_pred = predictions['peak_position'][i, :n_match, :].argmax(axis=1) + pos_errors = (pos_pred - pos_true).astype('float32') + self.pos_rmse_sum += paddle.sum(paddle.square(pos_errors)) + self.pos_count += _to_f32(n_match) + + # Height errors if enabled + if self.use_height_prediction and 'peak_height' in predictions: + height_true = targets['peak_height'][i, :n_match].astype('float32') + height_pred = predictions['peak_height'][i, :n_match].reshape([-1]) + height_errors = height_true - height_pred + self.height_rmse_sum += paddle.sum(paddle.square(height_errors)) + self.height_count += _to_f32(n_match) + + self.num_samples += _to_f32(batch_size) + + def accumulate(self) -> Dict[str, float]: + """ + Compute accumulated metrics. + + Returns: + dict: Dictionary containing all metrics + """ + # Distributed reduction + num_rmse_sum = _all_reduce_sum_(self.num_rmse_sum.clone()) + pos_rmse_sum = _all_reduce_sum_(self.pos_rmse_sum.clone()) + height_rmse_sum = _all_reduce_sum_(self.height_rmse_sum.clone()) + pos_count = _all_reduce_sum_(self.pos_count.clone()) + height_count = _all_reduce_sum_(self.height_count.clone()) + num_samples = _all_reduce_sum_(self.num_samples.clone()) + + # Compute final metrics + num_rmse = paddle.sqrt(num_rmse_sum / paddle.maximum(num_samples, _to_f32(1.0))).item() + pos_rmse = paddle.sqrt(pos_rmse_sum / paddle.maximum(pos_count, _to_f32(1.0))).item() + + metrics = { + 'num_rmse': num_rmse, + 'pos_rmse': pos_rmse, + } + + if self.use_height_prediction: + height_rmse = paddle.sqrt( + height_rmse_sum / paddle.maximum(height_count, _to_f32(1.0)) + ).item() + metrics['height_rmse'] = height_rmse + + return metrics + + def forward(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]) -> Dict[str, float]: + """ + Compute metrics for a single batch (non-streaming version). 
+ + Args: + predictions: dict from model forward + targets: dict from dataloader + + Returns: + dict: Dictionary containing all metrics for this batch + """ + self.reset() + self.update(predictions, targets) + return self.accumulate() \ No newline at end of file diff --git a/ppmat/models/ecformer/models/ECD.py b/ppmat/models/ecformer/models/ECD.py index c0a5d351..fb585675 100644 --- a/ppmat/models/ecformer/models/ECD.py +++ b/ppmat/models/ecformer/models/ECD.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle import paddle.nn as nn -import numpy as np from .base_ecformer import ECFormerBase -from ..utils.graph_utils import get_key_padding_mask class ECFormerECD(ECFormerBase): @@ -27,15 +24,10 @@ def __init__( self, num_position_classes = 20, height_classes = 2, - loss_weight_height = 2.0, **kwargs ): super().__init__(**kwargs) - self.num_position_classes = num_position_classes - self.height_classes = height_classes - self.loss_weight_height = loss_weight_height - # 峰数预测头 self.pred_number_layer = nn.Sequential( nn.Linear(self.emb_dim, self.emb_dim * 2), @@ -55,100 +47,4 @@ def __init__( nn.Linear(self.emb_dim, self.emb_dim // 4), nn.ReLU(), nn.Linear(self.emb_dim // 4, height_classes) - ) - - # 损失函数 - self.ce_loss = nn.CrossEntropyLoss() - - def get_loss(self, predictions, targets): - """ECD任务损失:峰数 + 位置 + 符号""" - # 峰数损失 - loss_num = self.ce_loss(predictions['peak_number'], targets['peak_num']) - - # 由于每个样本峰数不同,需要动态处理 - batch_size = targets['peak_num'].shape[0] - - loss_pos_total = 0 - loss_height_total = 0 - valid_samples = 0 - - for i in range(batch_size): - n_peaks = int(targets['peak_num'][i]) - if n_peaks == 0: - continue - - # 位置损失 - pos_pred = predictions['peak_position'][i, :n_peaks, :].reshape([-1, self.num_position_classes]) - pos_gt = targets['peak_position'][i, :n_peaks].reshape([-1]) - loss_pos_total += self.ce_loss(pos_pred, pos_gt) - - # 符号损失 - height_pred = 
predictions['peak_height'][i, :n_peaks, :].reshape([-1, self.height_classes]) - height_gt = targets['peak_height'][i, :n_peaks].reshape([-1]) - loss_height_total += self.ce_loss(height_pred, height_gt) - - valid_samples += 1 - - if valid_samples > 0: - loss_pos = loss_pos_total / valid_samples - loss_height = loss_height_total / valid_samples - else: - loss_pos = paddle.to_tensor(0.0) - loss_height = paddle.to_tensor(0.0) - - return loss_num + self.loss_weight_height * loss_height + loss_pos - - def get_metrics(self, predictions, targets): - """ECD任务评估指标:Number-RMSE, Position-RMSE, Symbol-Acc""" - - batch_size = targets['peak_num'].shape[0] - - # 峰数预测 - pred_nums = predictions['peak_number'].argmax(axis=1) - true_nums = targets['peak_num'] - - # 位置误差 - pos_errors = [] - # 符号准确率 - symbol_correct = 0 - symbol_total = 0 - # 首峰符号准确率 - first_symbol_correct = 0 - first_symbol_total = 0 - - for i in range(batch_size): - n_true = int(true_nums[i]) - n_pred = int(pred_nums[i]) - - if n_true > 0 and n_pred > 0: - n_match = min(n_true, n_pred) - - # 位置误差 - pos_true = targets['peak_position'][i, :n_match].numpy() - pos_pred = predictions['peak_position'][i, :n_match, :].argmax(axis=1).numpy() - pos_errors.extend(pos_pred - pos_true) - - # 符号准确率 - height_true = targets['peak_height'][i, :n_match].numpy() - height_pred = predictions['peak_height'][i, :n_match, :].argmax(axis=1).numpy() - - symbol_correct += np.sum(height_true == height_pred) - symbol_total += n_match - - # 首峰符号准确率 - if height_true[0] == height_pred[0]: - first_symbol_correct += 1 - first_symbol_total += 1 - - # 计算指标 - pos_rmse = np.sqrt(np.mean(np.square(pos_errors))) if pos_errors else 0.0 - num_rmse = np.sqrt(np.mean(np.square((pred_nums - true_nums).numpy()))) - symbol_acc = symbol_correct / symbol_total if symbol_total > 0 else 0.0 - first_symbol_acc = first_symbol_correct / first_symbol_total if first_symbol_total > 0 else 0.0 - - return { - 'num_rmse': num_rmse, - 'pos_rmse': pos_rmse, - 'symbol_acc': 
symbol_acc, - 'first_symbol_acc': first_symbol_acc - } \ No newline at end of file + ) \ No newline at end of file diff --git a/ppmat/models/ecformer/models/IR.py b/ppmat/models/ecformer/models/IR.py index 34c4b607..122c6caf 100644 --- a/ppmat/models/ecformer/models/IR.py +++ b/ppmat/models/ecformer/models/IR.py @@ -12,14 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle import paddle.nn as nn -import numpy as np -from sklearn.metrics import mean_squared_error from .base_ecformer import ECFormerBase -from ..utils.graph_utils import get_key_padding_mask -from ..utils.loss.soft_dtw_cuda import SoftDTW class ECFormerIR(ECFormerBase): @@ -27,10 +22,8 @@ class ECFormerIR(ECFormerBase): def __init__( self, - spectrum_length=1000, num_position_classes=36, use_height_prediction=True, - dtw_gamma=0.1, **kwargs ): # IR任务最大峰数不同 @@ -38,10 +31,6 @@ def __init__( super().__init__(**kwargs) - self.spectrum_length = spectrum_length - self.num_position_classes = num_position_classes - self.use_height_prediction = use_height_prediction - # 峰数预测头(IR最多15个峰) self.pred_number_layer = nn.Sequential( nn.Linear(self.emb_dim, self.emb_dim * 2), @@ -62,96 +51,4 @@ def __init__( nn.Linear(self.emb_dim, self.emb_dim // 4), nn.ReLU(), nn.Linear(self.emb_dim // 4, 1) - ) - - # 损失函数 - self.ce_loss = nn.CrossEntropyLoss() - self.mse_loss = nn.MSELoss(reduction='mean') - use_cuda = True if "gpu" in paddle.device.get_device() else False - self.dtw_loss = SoftDTW(use_cuda=use_cuda, gamma=dtw_gamma, normalize=True,) - - def get_loss(self, predictions, targets): - """IR任务损失:峰数 + 位置 + 强度""" - - # 峰数损失 - loss_num = self.ce_loss(predictions['peak_number'], targets['peak_num']) - - batch_size = targets['peak_num'].shape[0] - - loss_pos_total = 0 - loss_height_total = 0 - valid_samples = 0 - - for i in range(batch_size): - n_peaks = int(targets['peak_num'][i]) - if n_peaks == 0: - continue - - # 位置损失 - pos_pred = 
predictions['peak_position'][i, :n_peaks, :].reshape([-1, self.num_position_classes]) - pos_gt = targets['peak_position'][i, :n_peaks].reshape([-1]) - loss_pos_total += self.ce_loss(pos_pred, pos_gt) - - # 强度损失(回归) - if self.use_height_prediction and 'peak_height' in predictions: - height_pred = predictions['peak_height'][i, :n_peaks].reshape([-1]) - height_gt = targets['peak_height'][i, :n_peaks].reshape([-1]) - loss_height_total += self.mse_loss(height_pred, height_gt) - - valid_samples += 1 - - loss_pos = loss_pos_total / valid_samples if valid_samples > 0 else paddle.to_tensor(0.0) - loss = loss_num + loss_pos - - if self.use_height_prediction: - loss_height = loss_height_total / valid_samples if valid_samples > 0 else paddle.to_tensor(0.0) - loss += loss_height - - return loss - - def get_metrics(self, predictions, targets): - """IR任务评估指标""" - - batch_size = targets['peak_num'].shape[0] - - # 峰数预测 - pred_nums = predictions['peak_number'].argmax(axis=1) - true_nums = targets['peak_num'] - - # 位置误差 - pos_errors = [] - # 高度误差 - height_errors = [] - - for i in range(batch_size): - n_true = int(true_nums[i]) - n_pred = int(pred_nums[i]) - - if n_true > 0 and n_pred > 0: - n_match = min(n_true, n_pred) - - # 位置误差 - pos_true = targets['peak_position'][i, :n_match].numpy() - pos_pred = predictions['peak_position'][i, :n_match, :].argmax(axis=1).numpy() - pos_errors.extend(pos_pred - pos_true) - - # 强度误差 - if 'peak_height' in predictions: - height_true = targets['peak_height'][i, :n_match].numpy() - height_pred = predictions['peak_height'][i, :n_match].numpy().flatten() - height_errors.extend(np.abs(height_true - height_pred)) - - # 计算指标 - pos_rmse = np.sqrt(np.mean(np.square(pos_errors))) if pos_errors else 0.0 - num_rmse = np.sqrt(np.mean(np.square((pred_nums - true_nums).numpy()))) - height_rmse = np.sqrt(np.mean(np.square(height_errors))) if height_errors else 0.0 - - metrics = { - 'num_rmse': num_rmse, - 'pos_rmse': pos_rmse, - } - - if self.use_height_prediction: 
- metrics['height_rmse'] = height_rmse - - return metrics \ No newline at end of file + ) \ No newline at end of file diff --git a/ppmat/models/ecformer/models/base_ecformer.py b/ppmat/models/ecformer/models/base_ecformer.py index 42049dea..f6e2527e 100644 --- a/ppmat/models/ecformer/models/base_ecformer.py +++ b/ppmat/models/ecformer/models/base_ecformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod +from abc import ABC import paddle import paddle.nn as nn from paddle.nn import TransformerEncoder, TransformerEncoderLayer @@ -243,16 +243,6 @@ def forward(self, } } - @abstractmethod - def get_loss(self, predictions, targets): - """损失函数 - 由子类实现""" - pass - - @abstractmethod - def get_metrics(self, predictions, targets): - """评估指标 - 由子类实现""" - pass - def get_key_padding_mask(tokens): key_padding_mask = paddle.zeros(tokens.shape) key_padding_mask[tokens == -1] = -paddle.inf diff --git a/ppmat/models/ecformer/utils/loss/dilate_loss.py b/ppmat/models/ecformer/utils/loss/dilate_loss.py deleted file mode 100644 index 1edaedd4..00000000 --- a/ppmat/models/ecformer/utils/loss/dilate_loss.py +++ /dev/null @@ -1,25 +0,0 @@ -import paddle -from . import soft_dtw -from . 
import path_soft_dtw - -def dilate_loss(outputs, targets, alpha, gamma, device): - # outputs, targets: shape (batch_size, N_output, 1) - batch_size, N_output = outputs.shape[0: 2] - loss_shape = 0 - softdtw_batch = soft_dtw.SoftDTWBatch.apply - D = paddle.zeros((batch_size, N_output, N_output)) - for k in range(batch_size): - Dk = soft_dtw.pairwise_distances(targets[k,:,:].reshape(-1,1), outputs[k,:,:].reshape(-1,1)) - D[k:k+1,:,:] = Dk - loss_shape = softdtw_batch(D, gamma) - - path_dtw = path_soft_dtw.PathDTWBatch.apply - path = path_dtw(D, gamma) - - Omega = soft_dtw.pairwise_distances( - paddle.arange(1.0, float(N_output+1)).reshape(N_output,1) - ) - loss_temporal = paddle.sum(path * Omega) / (N_output * N_output) - - loss = alpha*loss_shape + (1-alpha)*loss_temporal - return loss \ No newline at end of file diff --git a/ppmat/models/ecformer/utils/loss/path_soft_dtw.py b/ppmat/models/ecformer/utils/loss/path_soft_dtw.py deleted file mode 100644 index 98849383..00000000 --- a/ppmat/models/ecformer/utils/loss/path_soft_dtw.py +++ /dev/null @@ -1,134 +0,0 @@ -import numpy as np -import paddle -from paddle.autograd import PyLayer -from numba import jit - - -@jit(nopython = True) -def my_max(x, gamma): - # use the log-sum-exp trick - max_x = np.max(x) - exp_x = np.exp((x - max_x) / gamma) - Z = np.sum(exp_x) - return gamma * np.log(Z) + max_x, exp_x / Z - -@jit(nopython = True) -def my_min(x,gamma) : - min_x, argmax_x = my_max(-x, gamma) - return - min_x, argmax_x - -@jit(nopython = True) -def my_max_hessian_product(p, z, gamma): - return ( p * z - p * np.sum(p * z) ) /gamma - -@jit(nopython = True) -def my_min_hessian_product(p, z, gamma): - return - my_max_hessian_product(p, z, gamma) - - -@jit(nopython = True) -def dtw_grad(theta, gamma): - m = theta.shape[0] - n = theta.shape[1] - V = np.zeros((m + 1, n + 1)) - V[:, 0] = 1e10 - V[0, :] = 1e10 - V[0, 0] = 0 - - Q = np.zeros((m + 2, n + 2, 3)) - - for i in range(1, m + 1): - for j in range(1, n + 1): - # theta is 
indexed starting from 0. - v, Q[i, j] = my_min(np.array([V[i, j - 1], - V[i - 1, j - 1], - V[i - 1, j]]) , gamma) - V[i, j] = theta[i - 1, j - 1] + v - - E = np.zeros((m + 2, n + 2)) - E[m + 1, :] = 0 - E[:, n + 1] = 0 - E[m + 1, n + 1] = 1 - Q[m + 1, n + 1] = 1 - - for i in range(m,0,-1): - for j in range(n,0,-1): - E[i, j] = Q[i, j + 1, 0] * E[i, j + 1] + \ - Q[i + 1, j + 1, 1] * E[i + 1, j + 1] + \ - Q[i + 1, j, 2] * E[i + 1, j] - - return V[m, n], E[1:m + 1, 1:n + 1], Q, E - - -@jit(nopython = True) -def dtw_hessian_prod(theta, Z, Q, E, gamma): - m = Z.shape[0] - n = Z.shape[1] - - V_dot = np.zeros((m + 1, n + 1)) - V_dot[0, 0] = 0 - - Q_dot = np.zeros((m + 2, n + 2, 3)) - for i in range(1, m + 1): - for j in range(1, n + 1): - # theta is indexed starting from 0. - V_dot[i, j] = Z[i - 1, j - 1] + \ - Q[i, j, 0] * V_dot[i, j - 1] + \ - Q[i, j, 1] * V_dot[i - 1, j - 1] + \ - Q[i, j, 2] * V_dot[i - 1, j] - - v = np.array([V_dot[i, j - 1], V_dot[i - 1, j - 1], V_dot[i - 1, j]]) - Q_dot[i, j] = my_min_hessian_product(Q[i, j], v, gamma) - E_dot = np.zeros((m + 2, n + 2)) - - for j in range(n,0,-1): - for i in range(m,0,-1): - E_dot[i, j] = Q_dot[i, j + 1, 0] * E[i, j + 1] + \ - Q[i, j + 1, 0] * E_dot[i, j + 1] + \ - Q_dot[i + 1, j + 1, 1] * E[i + 1, j + 1] + \ - Q[i + 1, j + 1, 1] * E_dot[i + 1, j + 1] + \ - Q_dot[i + 1, j, 2] * E[i + 1, j] + \ - Q[i + 1, j, 2] * E_dot[i + 1, j] - - return V_dot[m, n], E_dot[1:m + 1, 1:n + 1] - - -class PathDTWBatch(PyLayer): - @staticmethod - def forward(ctx, D, gamma): # D.shape: [batch_size, N , N] - batch_size, N, N = D.shape - device = D.place - D_cpu = D.detach().cpu().numpy() - gamma_paddle = paddle.to_tensor([gamma], dtype='float32').to(device) - - grad_paddle = paddle.zeros((batch_size, N ,N), place=device) - Q_paddle = paddle.zeros((batch_size, N+2 ,N+2, 3), place=device) - E_paddle = paddle.zeros((batch_size, N+2 ,N+2), place=device) - - for k in range(0, batch_size): # loop over all D in the batch - _, grad_cpu_k, 
Q_cpu_k, E_cpu_k = dtw_grad(D_cpu[k,:,:], gamma) - grad_paddle[k,:,:] = paddle.to_tensor(grad_cpu_k, dtype='float32').to(device) - Q_paddle[k,:,:,:] = paddle.to_tensor(Q_cpu_k, dtype='float32').to(device) - E_paddle[k,:,:] = paddle.to_tensor(E_cpu_k, dtype='float32').to(device) - - ctx.save_for_backward(grad_paddle, D, Q_paddle, E_paddle, gamma_paddle) - return paddle.mean(grad_paddle, axis=0) - - @staticmethod - def backward(ctx, grad_output): - device = grad_output.place - grad_paddle, D_paddle, Q_paddle, E_paddle, gamma = ctx.saved_tensor() - D_cpu = D_paddle.detach().cpu().numpy() - Q_cpu = Q_paddle.detach().cpu().numpy() - E_cpu = E_paddle.detach().cpu().numpy() - gamma = gamma.detach().cpu().numpy()[0] - Z = grad_output.detach().cpu().numpy() - - batch_size, N, N = D_cpu.shape - Hessian = paddle.zeros((batch_size, N ,N), place=device) - - for k in range(0, batch_size): - _, hess_k = dtw_hessian_prod(D_cpu[k,:,:], Z, Q_cpu[k,:,:,:], E_cpu[k,:,:], gamma) - Hessian[k:k+1,:,:] = paddle.to_tensor(hess_k, dtype='float32').to(device) - - return Hessian, None \ No newline at end of file diff --git a/ppmat/models/ecformer/utils/loss/soft_dtw.py b/ppmat/models/ecformer/utils/loss/soft_dtw.py deleted file mode 100644 index cc787802..00000000 --- a/ppmat/models/ecformer/utils/loss/soft_dtw.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np -import paddle -from numba import jit -from paddle.autograd import PyLayer - -def pairwise_distances(x, y=None): - ''' - Input: x is a Nxd matrix - y is an optional Mxd matirx - Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:] - if y is not given then use 'y=x'. - i.e. 
dist[i,j] = ||x[i,:]-y[j,:]||^2 - ''' - x_norm = (x**2).sum(1).reshape([-1, 1]) - if y is not None: - y_t = paddle.transpose(y, perm=[0, 1]) - y_norm = (y**2).sum(1).reshape([1, -1]) - else: - y_t = paddle.transpose(x, perm=[0, 1]) - y_norm = x_norm.reshape([1, -1]) - - dist = x_norm + y_norm - 2.0 * paddle.mm(x, y_t) - return paddle.clip(dist, 0.0, float('inf')) - -@jit(nopython = True) -def compute_softdtw(D, gamma): - N = D.shape[0] - M = D.shape[1] - R = np.zeros((N + 2, M + 2)) + 1e8 - R[0, 0] = 0 - for j in range(1, M + 1): - for i in range(1, N + 1): - r0 = -R[i - 1, j - 1] / gamma - r1 = -R[i - 1, j] / gamma - r2 = -R[i, j - 1] / gamma - rmax = max(max(r0, r1), r2) - rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax) - softmin = - gamma * (np.log(rsum) + rmax) - R[i, j] = D[i - 1, j - 1] + softmin - return R - -@jit(nopython = True) -def compute_softdtw_backward(D_, R, gamma): - N = D_.shape[0] - M = D_.shape[1] - D = np.zeros((N + 2, M + 2)) - E = np.zeros((N + 2, M + 2)) - D[1:N + 1, 1:M + 1] = D_ - E[-1, -1] = 1 - R[:, -1] = -1e8 - R[-1, :] = -1e8 - R[-1, -1] = R[-2, -2] - for j in range(M, 0, -1): - for i in range(N, 0, -1): - a0 = (R[i + 1, j] - R[i, j] - D[i + 1, j]) / gamma - b0 = (R[i, j + 1] - R[i, j] - D[i, j + 1]) / gamma - c0 = (R[i + 1, j + 1] - R[i, j] - D[i + 1, j + 1]) / gamma - a = np.exp(a0) - b = np.exp(b0) - c = np.exp(c0) - E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c - return E[1:N + 1, 1:M + 1] - - -class SoftDTWBatch(PyLayer): - @staticmethod - def forward(ctx, D, gamma = 1.0): # D.shape: [batch_size, N , N] - dev = D.place - batch_size, N, N = D.shape - gamma = paddle.to_tensor([gamma], dtype='float32').to(dev) - D_ = D.detach().cpu().numpy() - g_ = gamma.item() - - total_loss = 0 - R = paddle.zeros((batch_size, N+2, N+2)) - for k in range(0, batch_size): # loop over all D in the batch - Rk = paddle.to_tensor(compute_softdtw(D_[k,:,:], g_), dtype='float32').to(dev) - R[k:k+1,:,:] = Rk - total_loss 
= total_loss + Rk[-2,-2] - ctx.save_for_backward(D, R, gamma) - return total_loss / batch_size - - @staticmethod - def backward(ctx, grad_output): - dev = grad_output.place - D, R, gamma = ctx.saved_tensor() - batch_size, N, N = D.shape - D_ = D.detach().cpu().numpy() - R_ = R.detach().cpu().numpy() - g_ = gamma.item() - - E = paddle.zeros((batch_size, N, N)) - for k in range(batch_size): - Ek = paddle.to_tensor(compute_softdtw_backward(D_[k,:,:], R_[k,:,:], g_), dtype='float32').to(dev) - E[k:k+1,:,:] = Ek - - return grad_output * E, None \ No newline at end of file diff --git a/ppmat/models/ecformer/utils/loss/soft_dtw_cuda.py b/ppmat/models/ecformer/utils/loss/soft_dtw_cuda.py deleted file mode 100644 index d2a545b7..00000000 --- a/ppmat/models/ecformer/utils/loss/soft_dtw_cuda.py +++ /dev/null @@ -1,427 +0,0 @@ -# MIT License -# -# Copyright (c) 2020 Mehran Maghoumi -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ---------------------------------------------------------------------------------------------------------------------- - -import numpy as np -import paddle -from numba import jit, prange -from paddle.autograd import PyLayer -from numba import cuda -import math - -# ---------------------------------------------------------------------------------------------------------------------- -@cuda.jit -def compute_softdtw_cuda(D, gamma, bandwidth, max_i, max_j, n_passes, R): - """ - :param seq_len: The length of the sequence (both inputs are assumed to be of the same size) - :param n_passes: 2 * seq_len - 1 (The number of anti-diagonals) - """ - # Each block processes one pair of examples - b = cuda.blockIdx.x - # We have as many threads as seq_len, because the most number of threads we need - # is equal to the number of elements on the largest anti-diagonal - tid = cuda.threadIdx.x - - # Compute I, J, the indices from [0, seq_len) - - # The row index is always the same as tid - I = tid - - inv_gamma = 1.0 / gamma - - # Go over each anti-diagonal. 
Only process threads that fall on the current on the anti-diagonal - for p in range(n_passes): - - # The index is actually 'p - tid' but need to force it in-bounds - J = max(0, min(p - tid, max_j - 1)) - - # For simplicity, we define i, j which start from 1 (offset from I, J) - i = I + 1 - j = J + 1 - - # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds - if I + J == p and (I < max_i and J < max_j): - # Don't compute if outside bandwidth - if not (abs(i - j) > bandwidth > 0): - r0 = -R[b, i - 1, j - 1] * inv_gamma - r1 = -R[b, i - 1, j] * inv_gamma - r2 = -R[b, i, j - 1] * inv_gamma - rmax = max(max(r0, r1), r2) - rsum = math.exp(r0 - rmax) + math.exp(r1 - rmax) + math.exp(r2 - rmax) - softmin = -gamma * (math.log(rsum) + rmax) - R[b, i, j] = D[b, i - 1, j - 1] + softmin - - # Wait for other threads in this block - cuda.syncthreads() - -# ---------------------------------------------------------------------------------------------------------------------- -@cuda.jit -def compute_softdtw_backward_cuda(D, R, inv_gamma, bandwidth, max_i, max_j, n_passes, E): - k = cuda.blockIdx.x - tid = cuda.threadIdx.x - - # Indexing logic is the same as above, however, the anti-diagonal needs to - # progress backwards - I = tid - - for p in range(n_passes): - # Reverse the order to make the loop go backward - rev_p = n_passes - p - 1 - - # convert tid to I, J, then i, j - J = max(0, min(rev_p - tid, max_j - 1)) - - i = I + 1 - j = J + 1 - - # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds - if I + J == rev_p and (I < max_i and J < max_j): - - if math.isinf(R[k, i, j]): - R[k, i, j] = -math.inf - - # Don't compute if outside bandwidth - if not (abs(i - j) > bandwidth > 0): - a = math.exp((R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) * inv_gamma) - b = math.exp((R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) * inv_gamma) - c = math.exp((R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) * inv_gamma) 
- E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c - - # Wait for other threads in this block - cuda.syncthreads() - -# ---------------------------------------------------------------------------------------------------------------------- -class _SoftDTWCUDA(PyLayer): - """ - CUDA implementation is inspired by the diagonal one proposed in https://ieeexplore.ieee.org/document/8400444: - "Developing a pattern discovery method in time series data and its GPU acceleration" - """ - - @staticmethod - def forward(ctx, D, gamma, bandwidth): - dev = D.place - dtype = D.dtype - gamma = paddle.to_tensor([gamma]) - bandwidth = paddle.to_tensor([bandwidth]) - - B = D.shape[0] - N = D.shape[1] - M = D.shape[2] - threads_per_block = max(N, M) - n_passes = 2 * threads_per_block - 1 - - # Prepare the output array - R = paddle.ones((B, N + 2, M + 2), dtype=dtype) * math.inf - R[:, 0, 0] = 0 - - # Run the CUDA kernel. - # Set CUDA's grid size to be equal to the batch size (every CUDA block processes one sample pair) - # Set the CUDA block size to be equal to the length of the longer sequence (equal to the size of the largest diagonal) - compute_softdtw_cuda[B, threads_per_block](cuda.as_cuda_array(D.detach()), - gamma.item(), bandwidth.item(), N, M, n_passes, - cuda.as_cuda_array(R)) - ctx.save_for_backward(D, R.clone(), gamma, bandwidth) - return R[:, -2, -2] - - @staticmethod - def backward(ctx, grad_output): - D, R, gamma, bandwidth = ctx.saved_tensor() - dev = grad_output.place - dtype = grad_output.dtype - - B = D.shape[0] - N = D.shape[1] - M = D.shape[2] - threads_per_block = max(N, M) - n_passes = 2 * threads_per_block - 1 - - D_ = paddle.zeros((B, N + 2, M + 2), dtype=dtype) - D_[:, 1:N + 1, 1:M + 1] = D - - R[:, :, -1] = -math.inf - R[:, -1, :] = -math.inf - R[:, -1, -1] = R[:, -2, -2] - - E = paddle.zeros((B, N + 2, M + 2), dtype=dtype) - E[:, -1, -1] = 1 - - # Grid and block sizes are set same as done above for the forward() call - 
compute_softdtw_backward_cuda[B, threads_per_block](cuda.as_cuda_array(D_), - cuda.as_cuda_array(R), - 1.0 / gamma.item(), bandwidth.item(), N, M, n_passes, - cuda.as_cuda_array(E)) - E = E[:, 1:N + 1, 1:M + 1] - return grad_output.reshape([-1, 1, 1]).expand_as(E) * E, None, None - - -# ---------------------------------------------------------------------------------------------------------------------- -# -# The following is the CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw -# Credit goes to Kanru Hua. -# I've added support for batching and pruning. -# -# ---------------------------------------------------------------------------------------------------------------------- -@jit(nopython=True, parallel=True) -def compute_softdtw(D, gamma, bandwidth): - B = D.shape[0] - N = D.shape[1] - M = D.shape[2] - R = np.ones((B, N + 2, M + 2)) * np.inf - R[:, 0, 0] = 0 - for b in prange(B): - for j in range(1, M + 1): - for i in range(1, N + 1): - - # Check the pruning condition - if 0 < bandwidth < np.abs(i - j): - continue - - r0 = -R[b, i - 1, j - 1] / gamma - r1 = -R[b, i - 1, j] / gamma - r2 = -R[b, i, j - 1] / gamma - rmax = max(max(r0, r1), r2) - rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax) - softmin = - gamma * (np.log(rsum) + rmax) - R[b, i, j] = D[b, i - 1, j - 1] + softmin - return R - -# ---------------------------------------------------------------------------------------------------------------------- -@jit(nopython=True, parallel=True) -def compute_softdtw_backward(D_, R, gamma, bandwidth): - B = D_.shape[0] - N = D_.shape[1] - M = D_.shape[2] - D = np.zeros((B, N + 2, M + 2)) - E = np.zeros((B, N + 2, M + 2)) - D[:, 1:N + 1, 1:M + 1] = D_ - E[:, -1, -1] = 1 - R[:, :, -1] = -np.inf - R[:, -1, :] = -np.inf - R[:, -1, -1] = R[:, -2, -2] - for k in prange(B): - for j in range(M, 0, -1): - for i in range(N, 0, -1): - - if np.isinf(R[k, i, j]): - R[k, i, j] = -np.inf - - # Check the pruning condition - if 0 < 
bandwidth < np.abs(i - j): - continue - - a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma - b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma - c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma - a = np.exp(a0) - b = np.exp(b0) - c = np.exp(c0) - E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c - return E[:, 1:N + 1, 1:M + 1] - -# ---------------------------------------------------------------------------------------------------------------------- -class _SoftDTW(PyLayer): - """ - CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw - """ - - @staticmethod - def forward(ctx, D, gamma, bandwidth): - dev = D.place - dtype = D.dtype - gamma = paddle.Tensor([gamma]).to(dev).astype(dtype) - bandwidth = paddle.Tensor([bandwidth]).to(dev).astype(dtype) - D_ = D.detach().cpu().numpy() - g_ = gamma.item() - b_ = bandwidth.item() - R = paddle.Tensor(compute_softdtw(D_, g_, b_)).to(dev).astype(dtype) - ctx.save_for_backward(D, R, gamma, bandwidth) - return R[:, -2, -2] - - @staticmethod - def backward(ctx, grad_output): - D, R, gamma, bandwidth = ctx.saved_tensor() - dev = grad_output.place - dtype = grad_output.dtype - D_ = D.detach().cpu().numpy() - R_ = R.detach().cpu().numpy() - g_ = gamma.item() - b_ = bandwidth.item() - E = paddle.Tensor(compute_softdtw_backward(D_, R_, g_, b_)).to(dev).astype(dtype) - return grad_output.reshape([-1, 1, 1]).expand_as(E) * E, None, None - -# ---------------------------------------------------------------------------------------------------------------------- -class SoftDTW(paddle.nn.Layer): - """ - The soft DTW implementation that optionally supports CUDA - """ - - def __init__(self, use_cuda, gamma=1.0, normalize=False, bandwidth=None, dist_func=None): - """ - Initializes a new instance using the supplied parameters - :param use_cuda: Flag indicating whether the CUDA implementation should be used - :param gamma: sDTW's gamma parameter - :param 
normalize: Flag indicating whether to perform normalization - (as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790) - :param bandwidth: Sakoe-Chiba bandwidth for pruning. Passing 'None' will disable pruning. - :param dist_func: Optional point-wise distance function to use. If 'None', then a default Euclidean distance function will be used. - """ - super(SoftDTW, self).__init__() # ���ָ����ʼ���߼� - self.normalize = normalize - self.gamma = gamma - self.bandwidth = 0 if bandwidth is None else float(bandwidth) - self.use_cuda = use_cuda - - # Set the distance function - if dist_func is not None: - self.dist_func = dist_func - else: - self.dist_func = SoftDTW._euclidean_dist_func - - def _get_func_dtw(self, x, y): - """ - Checks the inputs and selects the proper implementation to use. - """ - bx, lx, dx = x.shape - by, ly, dy = y.shape - # Make sure the dimensions match - assert bx == by # Equal batch sizes - assert dx == dy # Equal feature dimensions - - use_cuda = self.use_cuda - - if use_cuda and (lx > 1024 or ly > 1024): # We should be able to spawn enough threads in CUDA - print("SoftDTW: Cannot use CUDA because the sequence length > 1024 (the maximum block size supported by CUDA)") - use_cuda = False - - # Finally, return the correct function - return _SoftDTWCUDA.apply if use_cuda else _SoftDTW.apply - - @staticmethod - def _euclidean_dist_func(x, y): - """ - Calculates the Euclidean distance between each element in x and y per timestep - """ - n = x.shape[1] - m = y.shape[1] - d = x.shape[2] - x = x.unsqueeze(2).expand([-1, n, m, d]) # paddle expand ��Ҫ��ʽ�б� - y = y.unsqueeze(1).expand([-1, n, m, d]) # paddle expand ��Ҫ��ʽ�б� - return paddle.pow(x - y, 2).sum(3) - - def forward(self, X, Y): - """ - Compute the soft-DTW value between X and Y - :param X: One batch of examples, batch_size x seq_len x dims - :param Y: The other batch of examples, batch_size x seq_len x dims - :return: The computed results - """ - - # Check the 
inputs and get the correct implementation - func_dtw = self._get_func_dtw(X, Y) - - if self.normalize: - # Stack everything up and run - x = paddle.concat([X, X, Y]) - y = paddle.concat([Y, X, Y]) - D = self.dist_func(x, y) - out = func_dtw(D, self.gamma, self.bandwidth) - out_xy, out_xx, out_yy = paddle.split(out, X.shape[0]) - return out_xy - 1 / 2 * (out_xx + out_yy) - else: - D_xy = self.dist_func(X, Y) - return func_dtw(D_xy, self.gamma, self.bandwidth) - -# ---------------------------------------------------------------------------------------------------------------------- -def timed_run(a, b, sdtw): - """ - Runs a and b through sdtw, and times the forward and backward passes. - Assumes that a requires gradients. - :return: timing, forward result, backward result - """ - from timeit import default_timer as timer - - # Forward pass - start = timer() - forward = sdtw(a, b) - end = timer() - t = end - start - - grad_outputs = paddle.ones_like(forward) - - # Backward - start = timer() - grads = paddle.autograd.grad(forward, a, grad_outputs=grad_outputs)[0] - end = timer() - - # Total time - t += end - start - - return t, forward, grads - -# ---------------------------------------------------------------------------------------------------------------------- -def profile(batch_size, seq_len_a, seq_len_b, dims, tol_backward): - sdtw = SoftDTW(False, gamma=1.0, normalize=False) - sdtw_cuda = SoftDTW(True, gamma=1.0, normalize=False) - n_iters = 6 - - print("Profiling forward() + backward() times for batch_size={}, seq_len_a={}, seq_len_b={}, dims={}...".format(batch_size, seq_len_a, seq_len_b, dims)) - - times_cpu = [] - times_gpu = [] - - for i in range(n_iters): - a_cpu = paddle.rand((batch_size, seq_len_a, dims), requires_grad=True) - b_cpu = paddle.rand((batch_size, seq_len_b, dims)) - a_gpu = a_cpu.cuda() - b_gpu = b_cpu.cuda() - - # GPU - t_gpu, forward_gpu, backward_gpu = timed_run(a_gpu, b_gpu, sdtw_cuda) - - # CPU - t_cpu, forward_cpu, backward_cpu = 
timed_run(a_cpu, b_cpu, sdtw) - - # Verify the results - assert paddle.allclose(forward_cpu, forward_gpu.cpu()) - assert paddle.allclose(backward_cpu, backward_gpu.cpu(), atol=tol_backward) - - if i > 0: # Ignore the first time we run, in case this is a cold start (because timings are off at a cold start of the script) - times_cpu += [t_cpu] - times_gpu += [t_gpu] - - # Average and log - avg_cpu = np.mean(times_cpu) - avg_gpu = np.mean(times_gpu) - print(" CPU: ", avg_cpu) - print(" GPU: ", avg_gpu) - print(" Speedup: ", avg_cpu / avg_gpu) - print() - -# ---------------------------------------------------------------------------------------------------------------------- -if __name__ == "__main__": - from timeit import default_timer as timer - - paddle.seed(1234) - - profile(128, 17, 15, 2, tol_backward=1e-6) - profile(512, 64, 64, 2, tol_backward=1e-4) - profile(512, 256, 256, 2, tol_backward=1e-3) \ No newline at end of file From 55e472fa01b83a2b4a883bb7731935aff0a76894 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Sun, 1 Mar 2026 20:55:38 +0800 Subject: [PATCH 06/16] =?UTF-8?q?=E5=AE=8C=E6=88=90utils=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/models/ecformer/__init__.py | 12 ------------ ppmat/models/ecformer/models/base_ecformer.py | 2 +- ppmat/models/ecformer/utils/__init__.py | 15 --------------- ppmat/{models/ecformer => }/utils/graph_utils.py | 1 - 4 files changed, 1 insertion(+), 29 deletions(-) delete mode 100644 ppmat/models/ecformer/utils/__init__.py rename ppmat/{models/ecformer => }/utils/graph_utils.py (99%) diff --git a/ppmat/models/ecformer/__init__.py b/ppmat/models/ecformer/__init__.py index 08b481e5..176592d5 100644 --- a/ppmat/models/ecformer/__init__.py +++ b/ppmat/models/ecformer/__init__.py @@ -19,20 +19,8 @@ # 导出编码器(如需直接使用) from .encoders.gin_node_embedding import GINNodeEmbedding -# 导出工具函数 -from 
.utils.graph_utils import ( - index_transform, - get_key_padding_mask, - feat_padding_mask, - pad_node_features -) - __all__ = [ 'ECFormerECD', 'ECFormerIR', 'GINNodeEmbedding', - 'index_transform', - 'get_key_padding_mask', - 'feat_padding_mask', - 'pad_node_features', ] \ No newline at end of file diff --git a/ppmat/models/ecformer/models/base_ecformer.py b/ppmat/models/ecformer/models/base_ecformer.py index f6e2527e..41ab15e0 100644 --- a/ppmat/models/ecformer/models/base_ecformer.py +++ b/ppmat/models/ecformer/models/base_ecformer.py @@ -18,7 +18,7 @@ from paddle.nn import TransformerEncoder, TransformerEncoderLayer from ..encoders.gin_node_embedding import GINNodeEmbedding -from ..utils.graph_utils import pad_node_features, feat_padding_mask +from ppmat.utils.graph_utils import pad_node_features, feat_padding_mask from paddle_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set def fix_mask_for_paddle(mask, n_head=None): diff --git a/ppmat/models/ecformer/utils/__init__.py b/ppmat/models/ecformer/utils/__init__.py deleted file mode 100644 index 3157548d..00000000 --- a/ppmat/models/ecformer/utils/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . 
import graph_utils \ No newline at end of file diff --git a/ppmat/models/ecformer/utils/graph_utils.py b/ppmat/utils/graph_utils.py similarity index 99% rename from ppmat/models/ecformer/utils/graph_utils.py rename to ppmat/utils/graph_utils.py index b31c4284..b012340b 100644 --- a/ppmat/models/ecformer/utils/graph_utils.py +++ b/ppmat/utils/graph_utils.py @@ -13,7 +13,6 @@ # limitations under the License. import paddle -import numpy as np def index_transform(raw_index, batch_size): """将压缩的批次索引还原为每个样本的节点索引列表""" From 247cf95bf7e40ddbd67416a5eee62540f3031528 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Sun, 1 Mar 2026 22:42:55 +0800 Subject: [PATCH 07/16] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=BF=9B=E4=B8=80?= =?UTF-8?q?=E6=AD=A5=E5=AF=B9=E9=BD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/models/ecformer/layers/rbf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppmat/models/ecformer/layers/rbf.py b/ppmat/models/ecformer/layers/rbf.py index 73617b52..8fc04823 100644 --- a/ppmat/models/ecformer/layers/rbf.py +++ b/ppmat/models/ecformer/layers/rbf.py @@ -24,7 +24,7 @@ def __init__(self, gamma: paddle.nn.parameter.Parameter): super(RBF, self).__init__() self.centers = centers.data.reshape([1, -1]) - self.gamma = gamma + self.gamma = gamma.data def forward(self, x): x = x.reshape([-1, 1]) From 6c10c647fd3de792aeb0e77fc7da0a9e8c799b7b Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Sun, 1 Mar 2026 23:04:44 +0800 Subject: [PATCH 08/16] =?UTF-8?q?=E5=AF=B9=E9=BD=90=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=9D=83=E9=87=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/models/ecformer/models/base_ecformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppmat/models/ecformer/models/base_ecformer.py b/ppmat/models/ecformer/models/base_ecformer.py index 41ab15e0..550bb5af 100644 --- 
a/ppmat/models/ecformer/models/base_ecformer.py +++ b/ppmat/models/ecformer/models/base_ecformer.py @@ -123,14 +123,14 @@ def _build_transformer(self, emb_dim, num_heads, num_layers, dropout): assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads" - encoder_layer = TransformerEncoderLayer( + self.tf_enc_layer = TransformerEncoderLayer( d_model=emb_dim, nhead=num_heads, dim_feedforward=emb_dim, dropout=dropout, activation='relu', ) - return TransformerEncoder(encoder_layer, num_layers=num_layers) + return TransformerEncoder(self.tf_enc_layer, num_layers=num_layers) def encode_molecule( self, From 2d1a259317675240fc2f9ad55bcbb876564242fa Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Mon, 9 Mar 2026 00:05:19 +0800 Subject: [PATCH 09/16] =?UTF-8?q?=E8=A7=84=E8=8C=83=E5=8C=96=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E9=9B=86=E6=A0=BC=E5=BC=8F=E5=B9=B6=E5=85=A8=E9=83=A8?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E9=80=9A=E8=BF=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ECDFormerDataset/compound_tools.py | 842 ------------------ ppmat/datasets/ECDFormerDataset/dataloader.py | 70 -- ppmat/datasets/ECDFormerDataset/eval_func.py | 161 ---- ppmat/datasets/ECDFormerDataset/util_func.py | 57 -- ppmat/datasets/IRDataset/__init__.py | 452 ---------- ppmat/datasets/IRDataset/colored_tqdm.py | 90 -- ppmat/datasets/IRDataset/place_env.py | 187 ---- ppmat/datasets/__init__.py | 10 +- .../__init__.py => build_ecd.py} | 411 ++++----- ppmat/datasets/build_ir.py | 355 ++++++++ ppmat/datasets/collate_fn.py | 77 ++ ppmat/datasets/ecd_dataset.py | 164 ++++ ppmat/datasets/geometric_data_type/batch.py | 2 +- ppmat/datasets/ir_dataset.py | 172 ++++ ppmat/utils/__init__.py | 4 + .../colored_tqdm.py | 24 +- .../IRDataset => utils}/compound_tools.py | 34 +- .../ECDFormerDataset => utils}/place_env.py | 84 +- 18 files changed, 1018 insertions(+), 2178 deletions(-) delete mode 100644 
ppmat/datasets/ECDFormerDataset/compound_tools.py delete mode 100644 ppmat/datasets/ECDFormerDataset/dataloader.py delete mode 100644 ppmat/datasets/ECDFormerDataset/eval_func.py delete mode 100644 ppmat/datasets/ECDFormerDataset/util_func.py delete mode 100644 ppmat/datasets/IRDataset/__init__.py delete mode 100644 ppmat/datasets/IRDataset/colored_tqdm.py delete mode 100644 ppmat/datasets/IRDataset/place_env.py rename ppmat/datasets/{ECDFormerDataset/__init__.py => build_ecd.py} (70%) create mode 100644 ppmat/datasets/build_ir.py create mode 100644 ppmat/datasets/ecd_dataset.py create mode 100644 ppmat/datasets/ir_dataset.py rename ppmat/{datasets/ECDFormerDataset => utils}/colored_tqdm.py (76%) rename ppmat/{datasets/IRDataset => utils}/compound_tools.py (97%) rename ppmat/{datasets/ECDFormerDataset => utils}/place_env.py (55%) diff --git a/ppmat/datasets/ECDFormerDataset/compound_tools.py b/ppmat/datasets/ECDFormerDataset/compound_tools.py deleted file mode 100644 index cee0c635..00000000 --- a/ppmat/datasets/ECDFormerDataset/compound_tools.py +++ /dev/null @@ -1,842 +0,0 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np -from rdkit import Chem -from rdkit.Chem import AllChem -from rdkit.Chem import rdchem - -DAY_LIGHT_FG_SMARTS_LIST = [ - # C - "[CX4]", - "[$([CX2](=C)=C)]", - "[$([CX3]=[CX3])]", - "[$([CX2]#C)]", - # C & O - "[CX3]=[OX1]", - "[$([CX3]=[OX1]),$([CX3+]-[OX1-])]", - "[CX3](=[OX1])C", - "[OX1]=CN", - "[CX3](=[OX1])O", - "[CX3](=[OX1])[F,Cl,Br,I]", - "[CX3H1](=O)[#6]", - "[CX3](=[OX1])[OX2][CX3](=[OX1])", - "[NX3][CX3](=[OX1])[#6]", - "[NX3][CX3]=[NX3+]", - "[NX3,NX4+][CX3](=[OX1])[OX2,OX1-]", - "[NX3][CX3](=[OX1])[OX2H0]", - "[NX3,NX4+][CX3](=[OX1])[OX2H,OX1-]", - "[CX3](=O)[O-]", - "[CX3](=[OX1])(O)O", - "[CX3](=[OX1])([OX2])[OX2H,OX1H0-1]", - "C[OX2][CX3](=[OX1])[OX2]C", - "[CX3](=O)[OX2H1]", - "[CX3](=O)[OX1H0-,OX2H1]", - "[NX3][CX2]#[NX1]", - "[#6][CX3](=O)[OX2H0][#6]", - "[#6][CX3](=O)[#6]", - "[OD2]([#6])[#6]", - # H - "[H]", - "[!#1]", - "[H+]", - "[+H]", - "[!H]", - # N - "[NX3;H2,H1;!$(NC=O)]", - "[NX3][CX3]=[CX3]", - "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]", - "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]", - "[NX3][$(C=C),$(cc)]", - "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]", - "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]", - "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]", - "[CH3X4]", - "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]", - "[CH2X4][CX3](=[OX1])[NX3H2]", - "[CH2X4][CX3](=[OX1])[OH0-,OH]", - "[CH2X4][SX2H,SX1H0-]", - "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]", - "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]", - "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\ -[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1", - "[CHX4]([CH3X4])[CH2X4][CH3X4]", - "[CH2X4][CHX4]([CH3X4])[CH3X4]", - "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]", - "[CH2X4][CH2X4][SX2][CH3X4]", - "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1", - 
"[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]", - "[CH2X4][OX2H]", - "[NX3][CX3]=[SX1]", - "[CHX4]([CH3X4])[OX2H]", - "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12", - "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1", - "[CHX4]([CH3X4])[CH3X4]", - "N[CX4H2][CX3](=[OX1])[O,N]", - "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]", - "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]", - "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]", - "[#7]", - "[NX2]=N", - "[NX2]=[NX2]", - "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]", - "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]", - "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]", - "[NX3][NX3]", - "[NX3][NX2]=[*]", - "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]", - "[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]", - "[NX3+]=[CX3]", - "[CX3](=[OX1])[NX3H][CX3](=[OX1])", - "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])", - "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])", - "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]", - "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]", - "[NX1]#[CX2]", - "[CX1-]#[NX2+]", - "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", - "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", - "[NX2]=[OX1]", - "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]", - # O - "[OX2H]", - "[#6][OX2H]", - "[OX2H][CX3]=[OX1]", - "[OX2H]P", - "[OX2H][#6X3]=[#6]", - "[OX2H][cX3]:[c]", - "[OX2H][$(C=C),$(cc)]", - "[$([OH]-*=[!#6])]", - "[OX2,OX1-][OX2,OX1-]", - # P - "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\ -$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\ -,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]", - "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\ 
-$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\ -$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]", - # S - "[S-][CX3](=S)[#6]", - "[#6X3](=[SX1])([!N])[!N]", - "[SX2]", - "[#16X2H]", - "[#16!H0]", - "[#16X2H0]", - "[#16X2H0][!#16]", - "[#16X2H0][#16X2H0]", - "[#16X2H0][!#16].[#16X2H0][!#16]", - "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]", - "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]", - "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]", - "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]", - "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]", - "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]", - "[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]", - "[SX4](C)(C)(=O)=N", - "[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]", - "[$([#16X3]=[OX1]),$([#16X3+][OX1-])]", - "[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]", - "[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]", - "[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]", - "[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]", - "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]", - "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]", - "[#16X2][OX2H,OX1H0-]", - "[#16X2][OX2H0]", - # X - "[#6][F,Cl,Br,I]", - "[F,Cl,Br,I]", - "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]", - ] - - -def get_gasteiger_partial_charges(mol, n_iter=12): - """ - Calculates list of gasteiger partial charges for each atom in mol object. - Args: - mol: rdkit mol object. - n_iter(int): number of iterations. Default 12. - Returns: - list of computed partial charges for each atom. 
- """ - Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter, - throwOnParamFailure=True) - partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in - mol.GetAtoms()] - return partial_charges - - -def create_standardized_mol_id(smiles): - """ - Args: - smiles: smiles sequence. - Returns: - inchi. - """ - if check_smiles_validity(smiles): - # remove stereochemistry - smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles), - isomericSmiles=False) - mol = AllChem.MolFromSmiles(smiles) - if not mol is None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21 - if '.' in smiles: # if multiple species, pick largest molecule - mol_species_list = split_rdkit_mol_obj(mol) - largest_mol = get_largest_mol(mol_species_list) - inchi = AllChem.MolToInchi(largest_mol) - else: - inchi = AllChem.MolToInchi(mol) - return inchi - else: - return - else: - return - - -def check_smiles_validity(smiles): - """ - Check whether the smile can't be converted to rdkit mol object. - """ - try: - m = Chem.MolFromSmiles(smiles) - if m: - return True - else: - return False - except Exception as e: - return False - - -def split_rdkit_mol_obj(mol): - """ - Split rdkit mol object containing multiple species or one species into a - list of mol objects or a list containing a single object respectively. - Args: - mol: rdkit mol object. - """ - smiles = AllChem.MolToSmiles(mol, isomericSmiles=True) - smiles_list = smiles.split('.') - mol_species_list = [] - for s in smiles_list: - if check_smiles_validity(s): - mol_species_list.append(AllChem.MolFromSmiles(s)) - return mol_species_list - - -def get_largest_mol(mol_list): - """ - Given a list of rdkit mol objects, returns mol object containing the - largest num of atoms. If multiple containing largest num of atoms, - picks the first one. - Args: - mol_list(list): a list of rdkit mol object. - Returns: - the largest mol. 
- """ - num_atoms_list = [len(m.GetAtoms()) for m in mol_list] - largest_mol_idx = num_atoms_list.index(max(num_atoms_list)) - return mol_list[largest_mol_idx] - - -def rdchem_enum_to_list(values): - """values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED, - 1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, - 2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, - 3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER} - """ - return [values[i] for i in range(len(values))] - - -def safe_index(alist, elem): - """ - Return index of element e in list l. If e is not present, return the last index - """ - try: - return alist.index(elem) - except ValueError: - return len(alist) - 1 - - -def get_atom_feature_dims(list_acquired_feature_names): - """ tbd - """ - return list(map(len, [CompoundKit.atom_vocab_dict[name] for name in list_acquired_feature_names])) - - -def get_bond_feature_dims(list_acquired_feature_names): - """ tbd - """ - list_bond_feat_dim = list(map(len, [CompoundKit.bond_vocab_dict[name] for name in list_acquired_feature_names])) - # +1 for self loop edges - return [_l + 1 for _l in list_bond_feat_dim] - - -class CompoundKit(object): - """ - CompoundKit - """ - atom_vocab_dict = { - "atomic_num": list(range(1, 119)) + ['misc'], - "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values), - "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], - "explicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'], - "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], - "hybridization": rdchem_enum_to_list(rdchem.HybridizationType.values), - "implicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'], - "is_aromatic": [0, 1], - "total_numHs": [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], - 'num_radical_e': [0, 1, 2, 3, 4, 'misc'], - 'atom_is_in_ring': [0, 1], - 'valence_out_shell': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], - 'in_num_ring_with_size3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], - 'in_num_ring_with_size4': [0, 1, 2, 3, 4, 5, 6, 
7, 8, 'misc'], - 'in_num_ring_with_size5': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], - 'in_num_ring_with_size6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], - 'in_num_ring_with_size7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], - 'in_num_ring_with_size8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], - } - bond_vocab_dict = { - "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values), - "bond_type": rdchem_enum_to_list(rdchem.BondType.values), - "is_in_ring": [0, 1], - - 'bond_stereo': rdchem_enum_to_list(rdchem.BondStereo.values), - 'is_conjugated': [0, 1], - } - # float features - atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass'] - # bond_float_feats= ["bond_length", "bond_angle"] # optional - - ### functional groups - day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST - day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list] - - morgan_fp_N = 200 - morgan2048_fp_N = 2048 - maccs_fp_N = 167 - - period_table = Chem.GetPeriodicTable() - - ### atom - - @staticmethod - def get_atom_value(atom, name): - """get atom values""" - if name == 'atomic_num': - return atom.GetAtomicNum() - elif name == 'chiral_tag': - return atom.GetChiralTag() - elif name == 'degree': - return atom.GetDegree() - elif name == 'explicit_valence': - return atom.GetExplicitValence() - elif name == 'formal_charge': - return atom.GetFormalCharge() - elif name == 'hybridization': - return atom.GetHybridization() - elif name == 'implicit_valence': - return atom.GetImplicitValence() - elif name == 'is_aromatic': - return int(atom.GetIsAromatic()) - elif name == 'mass': - return int(atom.GetMass()) - elif name == 'total_numHs': - return atom.GetTotalNumHs() - elif name == 'num_radical_e': - return atom.GetNumRadicalElectrons() - elif name == 'atom_is_in_ring': - return int(atom.IsInRing()) - elif name == 'valence_out_shell': - return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum()) - else: - raise ValueError(name) - - @staticmethod - def get_atom_feature_id(atom, 
name): - """get atom features id""" - assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name - return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name)) - - @staticmethod - def get_atom_feature_size(name): - """get atom features size""" - assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name - return len(CompoundKit.atom_vocab_dict[name]) - - ### bond - - @staticmethod - def get_bond_value(bond, name): - """get bond values""" - if name == 'bond_dir': - return bond.GetBondDir() - elif name == 'bond_type': - return bond.GetBondType() - elif name == 'is_in_ring': - return int(bond.IsInRing()) - elif name == 'is_conjugated': - return int(bond.GetIsConjugated()) - elif name == 'bond_stereo': - return bond.GetStereo() - else: - raise ValueError(name) - - @staticmethod - def get_bond_feature_id(bond, name): - """get bond features id""" - assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name - return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name)) - - @staticmethod - def get_bond_feature_size(name): - """get bond features size""" - assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name - return len(CompoundKit.bond_vocab_dict[name]) - - ### fingerprint - - @staticmethod - def get_morgan_fingerprint(mol, radius=2): - """get morgan fingerprint""" - nBits = CompoundKit.morgan_fp_N - mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) - return [int(b) for b in mfp.ToBitString()] - - @staticmethod - def get_morgan2048_fingerprint(mol, radius=2): - """get morgan2048 fingerprint""" - nBits = CompoundKit.morgan2048_fp_N - mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) - return [int(b) for b in mfp.ToBitString()] - - @staticmethod - def get_maccs_fingerprint(mol): - """get maccs fingerprint""" - fp = AllChem.GetMACCSKeysFingerprint(mol) - return 
[int(b) for b in fp.ToBitString()] - - ### functional groups - - @staticmethod - def get_daylight_functional_group_counts(mol): - """get daylight functional group counts""" - fg_counts = [] - for fg_mol in CompoundKit.day_light_fg_mo_list: - sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True) - fg_counts.append(len(sub_structs)) - return fg_counts - - @staticmethod - def get_ring_size(mol): - """return (N,6) list""" - rings = mol.GetRingInfo() - rings_info = [] - for r in rings.AtomRings(): - rings_info.append(r) - ring_list = [] - for atom in mol.GetAtoms(): - atom_result = [] - for ringsize in range(3, 9): - num_of_ring_at_ringsize = 0 - for r in rings_info: - if len(r) == ringsize and atom.GetIdx() in r: - num_of_ring_at_ringsize += 1 - if num_of_ring_at_ringsize > 8: - num_of_ring_at_ringsize = 9 - atom_result.append(num_of_ring_at_ringsize) - - ring_list.append(atom_result) - return ring_list - - @staticmethod - def atom_to_feat_vector(atom): - """ tbd """ - atom_names = { - "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()), - "chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()), - "degree": safe_index(CompoundKit.atom_vocab_dict["degree"], atom.GetTotalDegree()), - "explicit_valence": safe_index(CompoundKit.atom_vocab_dict["explicit_valence"], atom.GetExplicitValence()), - "formal_charge": safe_index(CompoundKit.atom_vocab_dict["formal_charge"], atom.GetFormalCharge()), - "hybridization": safe_index(CompoundKit.atom_vocab_dict["hybridization"], atom.GetHybridization()), - "implicit_valence": safe_index(CompoundKit.atom_vocab_dict["implicit_valence"], atom.GetImplicitValence()), - "is_aromatic": safe_index(CompoundKit.atom_vocab_dict["is_aromatic"], int(atom.GetIsAromatic())), - "total_numHs": safe_index(CompoundKit.atom_vocab_dict["total_numHs"], atom.GetTotalNumHs()), - 'num_radical_e': safe_index(CompoundKit.atom_vocab_dict['num_radical_e'], 
atom.GetNumRadicalElectrons()), - 'atom_is_in_ring': safe_index(CompoundKit.atom_vocab_dict['atom_is_in_ring'], int(atom.IsInRing())), - 'valence_out_shell': safe_index(CompoundKit.atom_vocab_dict['valence_out_shell'], - CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())), - 'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()), - 'partial_charge': CompoundKit.check_partial_charge(atom), - 'mass': atom.GetMass(), - } - return atom_names - - @staticmethod - def get_atom_names(mol): - """get atom name list - TODO: to be remove in the future - """ - atom_features_dicts = [] - Chem.rdPartialCharges.ComputeGasteigerCharges(mol) - for i, atom in enumerate(mol.GetAtoms()): - atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom)) - - ring_list = CompoundKit.get_ring_size(mol) - for i, atom in enumerate(mol.GetAtoms()): - atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index( - CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0]) - atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index( - CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1]) - atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index( - CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2]) - atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index( - CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3]) - atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index( - CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4]) - atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index( - CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5]) - - return atom_features_dicts - - @staticmethod - def check_partial_charge(atom): - """tbd""" - pc = atom.GetDoubleProp('_GasteigerCharge') - if pc != pc: - # unsupported atom, replace nan with 0 - pc = 0 - if pc == float('inf'): - # max 4 for other atoms, set to 10 here if inf is get - pc = 10 - 
return pc - - -class Compound3DKit(object): - """the 3Dkit of Compound""" - - @staticmethod - def get_atom_poses(mol, conf): - """tbd""" - atom_poses = [] - for i, atom in enumerate(mol.GetAtoms()): - if atom.GetAtomicNum() == 0: - return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms()) - pos = conf.GetAtomPosition(i) - atom_poses.append([pos.x, pos.y, pos.z]) - return atom_poses - - @staticmethod - def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False): - """the atoms of mol will be changed in some cases.""" - conf = mol.GetConformer() - atom_poses = Compound3DKit.get_atom_poses(mol, conf) - return mol,atom_poses - # try: - # new_mol = Chem.AddHs(mol) - # res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs) - # ### MMFF generates multiple conformations - # res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) - # new_mol = Chem.RemoveHs(new_mol) - # index = np.argmin([x[1] for x in res]) - # energy = res[index][1] - # conf = new_mol.GetConformer(id=int(index)) - # except: - # new_mol = mol - # AllChem.Compute2DCoords(new_mol) - # energy = 0 - # conf = new_mol.GetConformer() - # - # atom_poses = Compound3DKit.get_atom_poses(new_mol, conf) - # if return_energy: - # return new_mol, atom_poses, energy - # else: - # return new_mol, atom_poses - - @staticmethod - def get_2d_atom_poses(mol): - """get 2d atom poses""" - AllChem.Compute2DCoords(mol) - conf = mol.GetConformer() - atom_poses = Compound3DKit.get_atom_poses(mol, conf) - return atom_poses - - @staticmethod - def get_bond_lengths(edges, atom_poses): - """get bond lengths""" - bond_lengths = [] - for src_node_i, tar_node_j in edges: - bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i])) - bond_lengths = np.array(bond_lengths, 'float32') - return bond_lengths - - @staticmethod - def get_superedge_angles(edges, atom_poses, dir_type='HT'): - """get superedge angles""" - - def _get_vec(atom_poses, edge): - return atom_poses[edge[1]] - atom_poses[edge[0]] - - def _get_angle(vec1, 
vec2): - norm1 = np.linalg.norm(vec1) - norm2 = np.linalg.norm(vec2) - if norm1 == 0 or norm2 == 0: - return 0 - vec1 = vec1 / (norm1 + 1e-5) # 1e-5: prevent numerical errors - vec2 = vec2 / (norm2 + 1e-5) - angle = np.arccos(np.dot(vec1, vec2)) - return angle - - E = len(edges) - edge_indices = np.arange(E) - super_edges = [] - bond_angles = [] - bond_angle_dirs = [] - for tar_edge_i in range(E): - tar_edge = edges[tar_edge_i] - if dir_type == 'HT': - src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]] - elif dir_type == 'HH': - src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]] - else: - raise ValueError(dir_type) - for src_edge_i in src_edge_indices: - if src_edge_i == tar_edge_i: - continue - src_edge = edges[src_edge_i] - src_vec = _get_vec(atom_poses, src_edge) - tar_vec = _get_vec(atom_poses, tar_edge) - super_edges.append([src_edge_i, tar_edge_i]) - angle = _get_angle(src_vec, tar_vec) - bond_angles.append(angle) - bond_angle_dirs.append(src_edge[1] == tar_edge[0]) # H -> H or H -> T - - if len(super_edges) == 0: - super_edges = np.zeros([0, 2], 'int64') - bond_angles = np.zeros([0, ], 'float32') - else: - super_edges = np.array(super_edges, 'int64') - bond_angles = np.array(bond_angles, 'float32') - return super_edges, bond_angles, bond_angle_dirs - - -def new_smiles_to_graph_data(smiles, **kwargs): - """ - Convert smiles to graph data. - """ - mol = AllChem.MolFromSmiles(smiles) - if mol is None: - return None - data = new_mol_to_graph_data(mol) - return data - - -def new_mol_to_graph_data(mol): - """ - mol_to_graph_data - Args: - atom_features: Atom features. - edge_features: Edge features. - morgan_fingerprint: Morgan fingerprint. - functional_groups: Functional groups. 
- """ - if len(mol.GetAtoms()) == 0: - return None - - atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names - bond_id_names = list(CompoundKit.bond_vocab_dict.keys()) - - data = {} - - ### atom features - data = {name: [] for name in atom_id_names} - - raw_atom_feat_dicts = CompoundKit.get_atom_names(mol) - for atom_feat in raw_atom_feat_dicts: - for name in atom_id_names: - data[name].append(atom_feat[name]) - - ### bond and bond features - for name in bond_id_names: - data[name] = [] - data['edges'] = [] - - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - # i->j and j->i - data['edges'] += [(i, j), (j, i)] - for name in bond_id_names: - bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) - data[name] += [bond_feature_id] * 2 - - #### self loop - N = len(data[atom_id_names[0]]) - for i in range(N): - data['edges'] += [(i, i)] - for name in bond_id_names: - bond_feature_id = get_bond_feature_dims([name])[0] - 1 # self loop: value = len - 1 - data[name] += [bond_feature_id] * N - - ### make ndarray and check length - for name in list(CompoundKit.atom_vocab_dict.keys()): - data[name] = np.array(data[name], 'int64') - for name in CompoundKit.atom_float_names: - data[name] = np.array(data[name], 'float32') - for name in bond_id_names: - data[name] = np.array(data[name], 'int64') - data['edges'] = np.array(data['edges'], 'int64') - - ### morgan fingerprint - data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') - # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') - data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') - data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') - return data - - -def mol_to_graph_data(mol): - """ - mol_to_graph_data - Args: - atom_features: Atom features. - edge_features: Edge features. - morgan_fingerprint: Morgan fingerprint. 
- functional_groups: Functional groups. - """ - if len(mol.GetAtoms()) == 0: - return None - - atom_id_names = [ - "atomic_num", "chiral_tag", "degree", "explicit_valence", - "formal_charge", "hybridization", "implicit_valence", - "is_aromatic", "total_numHs", - ] - bond_id_names = [ - "bond_dir", "bond_type", "is_in_ring", - ] - - data = {} - for name in atom_id_names: - data[name] = [] - data['mass'] = [] - for name in bond_id_names: - data[name] = [] - data['edges'] = [] - - ### atom features - for i, atom in enumerate(mol.GetAtoms()): - if atom.GetAtomicNum() == 0: - return None - for name in atom_id_names: - data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1) # 0: OOV - data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01) - - ### bond features - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - # i->j and j->i - data['edges'] += [(i, j), (j, i)] - for name in bond_id_names: - bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1 # 0: OOV - data[name] += [bond_feature_id] * 2 - - ### self loop (+2) - N = len(data[atom_id_names[0]]) - for i in range(N): - data['edges'] += [(i, i)] - for name in bond_id_names: - bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2 # N + 2: self loop - data[name] += [bond_feature_id] * N - - ### check whether edge exists - if len(data['edges']) == 0: # mol has no bonds - for name in bond_id_names: - data[name] = np.zeros((0,), dtype="int64") - data['edges'] = np.zeros((0, 2), dtype="int64") - - ### make ndarray and check length - for name in atom_id_names: - data[name] = np.array(data[name], 'int64') - data['mass'] = np.array(data['mass'], 'float32') - for name in bond_id_names: - data[name] = np.array(data[name], 'int64') - data['edges'] = np.array(data['edges'], 'int64') - - ### morgan fingerprint - data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') - # data['morgan2048_fp'] = 
np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') - data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') - data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') - return data - - -def mol_to_geognn_graph_data(mol, atom_poses, dir_type): - """ - mol: rdkit molecule - dir_type: direction type for bond_angle grpah - """ - if len(mol.GetAtoms()) == 0: - return None - - data = mol_to_graph_data(mol) - - data['atom_pos'] = np.array(atom_poses, 'float32') - data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos']) - BondAngleGraph_edges, bond_angles, bond_angle_dirs = \ - Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos']) - data['BondAngleGraph_edges'] = BondAngleGraph_edges - data['bond_angle'] = np.array(bond_angles, 'float32') - return data - - -def mol_to_geognn_graph_data_MMFF3d(mol): - """tbd""" - if len(mol.GetAtoms()) <= 400: - mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10) - else: - atom_poses = Compound3DKit.get_2d_atom_poses(mol) - return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') - - -def mol_to_geognn_graph_data_raw3d(mol): - """tbd""" - atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer()) - return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') - -def obtain_3D_mol(smiles,name): - mol = AllChem.MolFromSmiles(smiles) - new_mol = Chem.AddHs(mol) - res = AllChem.EmbedMultipleConfs(new_mol) - ### MMFF generates multiple conformations - res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) - new_mol = Chem.RemoveHs(new_mol) - Chem.MolToMolFile(new_mol, name+'.mol') - return new_mol - -def predict_SMILES_info(smiles): - # by lihao, input smiles, output dict - mol = AllChem.MolFromSmiles(smiles) - AllChem.EmbedMolecule(mol) - info_dict = mol_to_geognn_graph_data_MMFF3d(mol) - return info_dict - -if __name__ == "__main__": - # smiles = "OCc1ccccc1CN" - smiles = 
r"[H]/[NH+]=C(\N)C1=CC(=O)/C(=C\C=c2ccc(=C(N)[NH3+])cc2)C=C1" - # smiles = 'CC' - mol = AllChem.MolFromSmiles(smiles) - AllChem.EmbedMolecule(mol) - data = mol_to_geognn_graph_data_MMFF3d(mol) - for key, value in data.items(): - print(key, value.shape) \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/dataloader.py b/ppmat/datasets/ECDFormerDataset/dataloader.py deleted file mode 100644 index cfbb15b8..00000000 --- a/ppmat/datasets/ECDFormerDataset/dataloader.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import paddle -from paddle.io import Dataset - -from typing import Any, List, Sequence, Union - -from paddle_geometric.data import Batch, Dataset -from paddle_geometric.data.data import BaseData -from paddle_geometric.data.datapipes import DatasetAdapter - -def call(batch: List[Any]) -> Any: - batch = [list(x) for x in zip(*batch)] # transpose - for i in range(len(batch)): # 组Batch - batch[i] = Batch.from_data_list(batch[i]) - - batch0 = batch[0] - batch1 = batch[1] - - # Data解包到Tensor字典 - batch_atom_bond, batch_bond_angle = batch0, batch1 - x, edge_index, edge_attr, query_mask =batch_atom_bond.x,batch_atom_bond.edge_index,batch_atom_bond.edge_attr,batch_atom_bond.query_mask - ba_edge_index, ba_edge_attr = batch_bond_angle.edge_index,batch_bond_angle.edge_attr - batch_data = batch_atom_bond.batch - pos_gt = batch_atom_bond.peak_position - height_gt = batch_atom_bond.peak_height - num_gt = batch_atom_bond.peak_num - return \ - { - "x" : x , - "edge_index" : edge_index , - "edge_attr" : edge_attr , - "batch_data" : batch_data , - "ba_edge_index" : ba_edge_index , - "ba_edge_attr" : ba_edge_attr , - "query_mask" : query_mask - }, \ - { - "peak_number_gt" : num_gt , - "peak_position_gt": pos_gt , - "peak_height_gt" : height_gt - } - -class ECDFormerDataset_DataLoader(paddle.io.DataLoader): - def __init__( - self, - dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter], - batch_size: int = 1, - shuffle: bool = False, - **kwargs, - ): - super().__init__( - dataset, - batch_size=batch_size, - shuffle=shuffle, - collate_fn=call, - **kwargs, - ) \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/eval_func.py b/ppmat/datasets/ECDFormerDataset/eval_func.py deleted file mode 100644 index 72f772ba..00000000 --- a/ppmat/datasets/ECDFormerDataset/eval_func.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
- -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import json - -import paddle -import paddle.nn as nn -import numpy as np - -from sklearn.metrics import mean_squared_error - -def Accuracy(pred, gt): - # the implimentation of good sample Accuracy. We calculate multi-range accuracy - # pred: paddle.Tensor(batch), the tensor of prediction - # gt: paddle.Tensor(batch), the tensor of groundtruth - pred, gt = pred.tolist(), gt.tolist() - assert len(pred) == len(gt) - - acc, acc_1, acc_2, acc_3 = 0, 0, 0, 0 - for i in range(len(pred)): - if pred[i] == gt[i]: - acc += 1 - elif abs(pred[i]-gt[i]) == 1: - acc_1 += 1 - elif abs(pred[i]-gt[i]) == 2: - acc_2 += 1 - elif abs(pred[i]-gt[i]) == 3: - acc_3 += 1 - else: continue - - return acc, acc_1, acc_2, acc_3 - -def MAE(pred, gt): - # the implimentation of Mean Absolute Error - # MAE == L1 loss - # pred: paddle.Tensor(batch, seq_len), the tensor of prediction - # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth - - pred = paddle.to_tensor(pred, dtype=paddle.float32) - gt = paddle.to_tensor(gt, dtype=paddle.float32) - assert pred.shape == gt.shape - L1Loss = nn.L1Loss(reduction='mean') # paddle��size_average�ѷ�����ͳһ��reduction - mae = L1Loss(pred, gt) - - return mae - - -def MAPE(pred, gt): - # the implimentation of Mean Absolute Percentage Error - # pred: paddle.Tensor(batch, seq_len), the tensor of prediction - # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth - - pred = paddle.to_tensor(pred, 
dtype=paddle.float32) - gt = paddle.to_tensor(gt, dtype=paddle.float32) - pred, gt = pred.numpy(), gt.numpy() # paddle����cpu()������ֱ����numpy() - mape_loss = np.mean( - np.abs(np.divide(pred-gt, gt, out=np.zeros_like(gt), where=gt!=0)) - ) * 100 - return mape_loss - -def get_sequence_peak(sequence): - # input- seq: List - # output- peak_list contains peak position - peak_list = [] - for i in range(1, len(sequence)-1): - if sequence[i-1]sequence[i+1]: - peak_list.append(i) - if sequence[i-1]>sequence[i] and sequence[i]0: correct_peak_symbols += 1 - ## peak number - number_gt.append(len(peaks_gt)) - number_pred.append(len(peaks_pred)) - ## peak position - if len(peaks_gt) == min_peaks_len: - peaks_gt.extend([len(gt[i])] * (max_peaks_len-min_peaks_len) ) - else: peaks_pred.extend([len(pred[i])] * (max_peaks_len-min_peaks_len) ) - position_gt.extend(peaks_gt) - position_pred.extend(peaks_pred) - - rmse_position = np.sqrt(mean_squared_error(position_gt, position_pred)) - rmse_number = np.sqrt(mean_squared_error(number_gt, number_pred)) - symbol_acc = correct_peak_symbols / total_peaks if total_peaks != 0 else 0.0 - return symbol_acc, rmse_position, rmse_number - -def Peak_for_draw(pred, gt): - # the implementation of RMSE for Peak Range, Number, Position - # return the Peak information - # pred: paddle.Tensor(batch, seq_len), the tensor of prediction - # gt: paddle.Tensor(batch, seq_len), the tensor of groundtruth - - pred = paddle.to_tensor(pred, dtype=paddle.float32) - gt = paddle.to_tensor(gt, dtype=paddle.float32) - pred, gt = pred.numpy().tolist(), gt.numpy().tolist() # paddle����cpu()���� - batch_size = len(pred) - - # calculate RMSE Range - range_gt, range_pred = [], [] - for i in range(batch_size): - range_gt.append(max(gt[i]) - min(gt[i])) - range_pred.append(max(pred[i]) - min(pred[i])) - - # calculate RMSE for Peak number - number_gt, number_pred = [], [] - position_gt, position_pred = [], [] - for i in range(batch_size): - peaks_gt = 
get_sequence_peak(gt[i]) - peaks_pred = get_sequence_peak(pred[i]) - - number_gt.append(len(peaks_gt)) - number_pred.append(len(peaks_pred)) - - min_peaks_len = min(len(peaks_gt), len(peaks_pred)) - max_peaks_len = max(len(peaks_gt), len(peaks_pred)) - - if len(peaks_gt) == min_peaks_len: - peaks_gt.extend([len(gt[i])] * (max_peaks_len-min_peaks_len) ) - else: - peaks_pred.extend([len(pred[i])] * (max_peaks_len-min_peaks_len) ) - position_gt.append(sum(peaks_gt)) - position_pred.append(sum(peaks_pred)) - - return dict( - peak_range = (range_gt, range_pred), - peak_num = (number_gt, number_pred), - peak_pos = (position_gt, position_pred), - ) \ No newline at end of file diff --git a/ppmat/datasets/ECDFormerDataset/util_func.py b/ppmat/datasets/ECDFormerDataset/util_func.py deleted file mode 100644 index 49bac40a..00000000 --- a/ppmat/datasets/ECDFormerDataset/util_func.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -def has_element_in_range(lst, lower_bound, upper_bound): - """ - 检查给定列表 lst 中是否存在元素在指定的区间 [lower_bound, upper_bound] 内。 - - 参数: - - lst: 输入的列表 - - lower_bound: 区间的下界 - - upper_bound: 区间的上界 - - 返回: - - 存在元素在指定区间内时返回 True, 否则返回 False - """ - for element in lst: - if lower_bound <= element <= upper_bound: - return True - return False - - -def normalize_func(src_list, norm_range=[-100, 100]): - # lihao implecation for list normalization - # input: src_list, normalization range - # output: tgt_list after normalization - - src_max, src_min = max(src_list), min(src_list) - norm_min, norm_max = norm_range[0], norm_range[1] - if src_max == 0: src_max = 1 - if src_min == 0: src_min = -1 - - tgt_list = [] - for i in range(len(src_list)): - if src_list[i] >= 0: - tgt_list.append(src_list[i] * norm_max / src_max) - else: - tgt_list.append(src_list[i] * norm_min / src_min) - - assert len(src_list) == len(tgt_list) - return tgt_list - -if __name__ == "__main__": - src = [-50, 0, 1, 50] - norm_range = [-100, 100] - tgt = normalize_func(src, norm_range) - print(tgt) diff --git a/ppmat/datasets/IRDataset/__init__.py b/ppmat/datasets/IRDataset/__init__.py deleted file mode 100644 index b0e4cd71..00000000 --- a/ppmat/datasets/IRDataset/__init__.py +++ /dev/null @@ -1,452 +0,0 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -IR光谱预测数据集模块 -支持预加载的npy文件,包含缓存机制,默认使用100样本的小数据集 -""" - -import os -import numpy as np -import pandas as pd -import paddle -from paddle.io import Dataset, DataLoader -from paddle_geometric.data import Data -from tqdm import tqdm -import pickle -import json -import warnings - -import rdkit -from rdkit import Chem -from rdkit.Chem import AllChem - -from .compound_tools import mol_to_geognn_graph_data_MMFF3d -from .compound_tools import get_atom_feature_dims, get_bond_feature_dims -from .colored_tqdm import ColoredTqdm as tqdm -from .place_env import PlaceEnv - -# ----------------常量定义---------------- -ATOM_ID_NAMES = [ - "atomic_num", "chiral_tag", "degree", "explicit_valence", - "formal_charge", "hybridization", "implicit_valence", - "is_aromatic", "total_numHs", -] - -BOND_ID_NAMES = ["bond_dir", "bond_type", "is_in_ring"] - -BOND_ANGLE_FLOAT_NAMES = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] - -# 获取特征维度 -FULL_ATOM_FEATURE_DIMS = get_atom_feature_dims(ATOM_ID_NAMES) -FULL_BOND_FEATURE_DIMS = get_bond_feature_dims(BOND_ID_NAMES) - -# IR光谱参数 -IR_WAVELENGTH_MIN = 500 -IR_WAVELENGTH_MAX = 4000 -IR_STEP = 100 # 波数步长,用于离散化 -IR_NUM_POSITION_CLASSES = (IR_WAVELENGTH_MAX - IR_WAVELENGTH_MIN) // IR_STEP # 36 - -# 默认最大峰数 -DEFAULT_MAX_PEAKS = 15 - - -def get_key_padding_mask(tokens): - """生成query padding mask""" - key_padding_mask = paddle.zeros(tokens.shape) - key_padding_mask[tokens == -1] = -paddle.inf - return key_padding_mask - - -def x_bin_position(real_x, distance=IR_STEP): - """将实际波数转换为箱ID""" - return int((real_x - IR_WAVELENGTH_MIN) / distance) - - -def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): - """ - 从原始特征构建IR图数据 - 类似于ECD的Construct_dataset,但针对IR任务 - - Args: - dataset: list of dict, 每个元素是分子的info字典 - data_index: list or array, 索引列表 - descriptor_path: str, 描述符文件路径(IR可能不需要,保留接口) - - Returns: - graph_atom_bond: list of Data, atom-bond图 - graph_bond_angle: list of Data, bond-angle图 - """ - graph_atom_bond = [] - graph_bond_angle = 
[] - - # IR任务可能不需要描述符,但如果需要可以加载 - all_descriptor = None - if descriptor_path and os.path.exists(descriptor_path): - all_descriptor = np.load(descriptor_path) - - for i in tqdm(range(len(dataset)), desc="Constructing IR graphs"): - data = dataset[i] - - # 收集原子特征 - atom_feature = [] - for name in ATOM_ID_NAMES: - if name in data: - atom_feature.append(data[name]) - else: - # 如果某些特征缺失,用0填充 - # 注意:根据实际数据调整 - if i == 0: # 只在第一次警告 - warnings.warn(f"Feature {name} not found in data, using zeros") - num_atoms = data.get('atomic_num', np.zeros(1)).shape[0] - atom_feature.append(np.zeros(num_atoms)) - - # 收集键特征 - bond_feature = [] - for name in BOND_ID_NAMES: - if name in data: - bond_feature.append(data[name]) - else: - if i == 0: - warnings.warn(f"Bond feature {name} not found, using zeros") - num_bonds = data.get('bond_dir', np.zeros(1)).shape[0] - bond_feature.append(np.zeros(num_bonds)) - - # 转换为Tensor - atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') - bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') - - # 键长特征(IR可能不需要,但保留) - bond_float_feature = paddle.to_tensor(data.get('bond_length', np.zeros(data['edges'].shape[0])).astype(np.float32)) - - # 键角特征(IR可能不需要,但保留) - bond_angle_feature = paddle.to_tensor(data.get('bond_angle', np.zeros(data.get('BondAngleGraph_edges', np.zeros((0,2))).shape[0])).astype(np.float32)) - - # 边索引 - edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') - bond_index = paddle.to_tensor(data.get('BondAngleGraph_edges', np.zeros((0,2))).T, dtype='int64') - - - data_index_int = paddle.to_tensor(np.array(int(data_index[i])), dtype='int64') - - # 获取原子数(键角图的节点数) - num_atoms = atom_feature.shape[0] - - # 合并键特征 - bond_feature = paddle.concat( - [bond_feature.astype(bond_float_feature.dtype), - bond_float_feature.reshape([-1, 1])], - axis=1 - ) - - # 处理键角特征(如果有) - if bond_angle_feature.shape[0] > 0: - # 如果有描述符,可以添加 - if all_descriptor is not None and i < all_descriptor.shape[0]: - TPSA = 
paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 100 - RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] - RPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822] - MDEC = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568] - MATS = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457] - - bond_angle_feature = paddle.concat( - [bond_angle_feature.reshape([-1, 1]), TPSA.reshape([-1, 1])], axis=1 - ) - bond_angle_feature = paddle.concat([bond_angle_feature, RASA.reshape([-1, 1])], axis=1) - bond_angle_feature = paddle.concat([bond_angle_feature, RPSA.reshape([-1, 1])], axis=1) - bond_angle_feature = paddle.concat([bond_angle_feature, MDEC.reshape([-1, 1])], axis=1) - bond_angle_feature = paddle.concat([bond_angle_feature, MATS.reshape([-1, 1])], axis=1) - else: - # 如果没有描述符,直接reshape - bond_angle_feature = bond_angle_feature.reshape([-1, 1]) - - # 创建Data对象 - data_atom_bond = Data( - x=atom_feature, - edge_index=edge_index, - edge_attr=bond_feature, - data_index=data_index_int, - ) - - data_bond_angle = Data( - edge_index=bond_index, - edge_attr=bond_angle_feature if bond_angle_feature.shape[0] > 0 else paddle.zeros([0, 1]), - num_nodes=num_atoms, # 键角图的节点数等于原子数 - ) - - graph_atom_bond.append(data_atom_bond) - graph_bond_angle.append(data_bond_angle) - - return graph_atom_bond, graph_bond_angle - - -def read_ir_spectra_by_ids(sample_path, index_all, max_peak=DEFAULT_MAX_PEAKS): - """ - 按需读取IR光谱文件 - - Args: - sample_path: str, IR光谱JSON文件目录 - index_all: list, 需要读取的文件ID列表 - max_peak: int, 最大峰数 - - Returns: - ir_final_list: list of dict, 包含峰值信息的字典列表 - """ - ir_final_list = [] - - for fileid in tqdm(index_all, desc="Reading IR spectra by ID"): - filepath = os.path.join(sample_path, f"{fileid}.json") - - try: - with open(filepath, 'r') as f: - raw_ir_info = json.load(f) - - ir_x = raw_ir_info['x'] - ir_y = raw_ir_info['y_40'] - - from scipy.signal import find_peaks - peaks_raw, _ 
= find_peaks(x=ir_y, height=0.1, distance=100) - peaks_raw = peaks_raw.tolist() - - # 处理峰值(与原作完全一致) - peak_num = min(len(peaks_raw), max_peak) - - if peak_num > 0: - if len(peaks_raw) > max_peak: - peaks = peaks_raw[len(peaks_raw)-max_peak:] - else: - peaks = peaks_raw - - peak_position_list = [x_bin_position(ir_x[i]) for i in peaks] - peak_height_list = [ir_y[i] for i in peaks] - else: - peak_position_list = [] - peak_height_list = [] - - peak_position_list = peak_position_list + [-1] * (max_peak - len(peak_position_list)) - peak_height_list = peak_height_list + [-1] * (max_peak - len(peak_height_list)) - - query_padding_mask = get_key_padding_mask(paddle.to_tensor(peak_position_list)) - - tmp_dict = { - 'id': fileid, - 'seq_40': ir_y, - 'peak_num': peak_num, - 'peak_position': peak_position_list, - 'peak_height': peak_height_list, - 'query_mask': query_padding_mask.unsqueeze(0), - } - ir_final_list.append(tmp_dict) - - except Exception as e: - warnings.warn(f"Error processing {fileid}.json: {e}") - continue - - ir_final_list.sort(key=lambda x: x['id']) - return ir_final_list - - -def load_ir_meta_file(mode='100'): - """ - 加载预生成的IR元数据文件 - - Args: - mode: str, 可选 '100', '10000', 'all' - - Returns: - dataset_all: list, 图特征数据 - smiles_all: list, SMILES列表 - index_all: list, 索引列表 - """ - valid_modes = {'100', '10000', 'all'} - if mode not in valid_modes: - warnings.warn(f"Invalid mode {mode}, using '100'") - mode = '100' - - filename = f'ir_column_charity_{mode}.npy' - - if not os.path.exists(filename): - # 尝试在dataset/IR目录下查找 - alt_path = os.path.join('dataset', 'IR', filename) - if os.path.exists(alt_path): - filename = alt_path - else: - raise FileNotFoundError(f"IR meta file {filename} not found") - - print(f"Loading IR meta file: {filename}") - data = np.load(filename, allow_pickle=True).item() - - return data['dataset_all'], data['smiles_all'], data['index_all'] - - -class IRDataset(Dataset): - """ - IR光谱预测数据集类 - - 支持三种预加载模式: - - '100': 100个样本的小数据集(默认,用于快速测试) - - 
'10000': 1万个样本的中等数据集 - - 'all': 全部样本(可能很大) - - 关键优化:按需读取JSON文件,只读取index_all中指定的文件! - """ - - _cache = {} - - @PlaceEnv(paddle.CPUPlace()) - def __init__(self, - path: str = "dataset/IR", - mode: str = '100', - use_geometry_enhanced: bool = True, - force_reload: bool = False, - cache: bool = True): - - self.path = path - self.mode = mode - self.use_geometry_enhanced = use_geometry_enhanced - self.cache_enabled = cache - - cache_key = f"{path}_{mode}_{use_geometry_enhanced}" - - if cache and not force_reload and cache_key in self._cache: - print(f"Loading IR dataset from cache: {mode}") - cached_data = self._cache[cache_key] - self.graph_atom_bond = cached_data['atom_bond'] - self.graph_bond_angle = cached_data['bond_angle'] - self.smiles_list = cached_data.get('smiles', []) - return - - print(f"First-time loading IR dataset (mode={mode})...") - - # 1. 加载元数据文件(获取index_all) - meta_path = os.path.join(path, f'ir_column_charity_{mode}.npy') - if not os.path.exists(meta_path): - raise FileNotFoundError(f"IR meta file {meta_path} not found") - - data = np.load(meta_path, allow_pickle=True).item() - dataset_all = data['dataset_all'] - smiles_all = data['smiles_all'] - index_all = data['index_all'] - - print(f"Loaded meta data: {len(index_all)} samples") - - # 2. 按需读取IR光谱(只读取index_all中的文件!) - spectra_path = os.path.join(path, 'qm9_ir_spec') - self.ir_sequences = read_ir_spectra_by_ids(spectra_path, index_all) - - print(f"Loaded {len(self.ir_sequences)} IR spectra") - - # 3. 构建图数据 - descriptor_path = os.path.join(path, 'descriptor_all_column.npy') - if not os.path.exists(descriptor_path): - descriptor_path = None - - total_graph_atom_bond, total_graph_bond_angle = Construct_IR_Dataset( - dataset_all, index_all, descriptor_path - ) - - # 4. 
将光谱信息附加到图数据 - self.graph_atom_bond = [] - self.graph_bond_angle = [] - self.smiles_list = [] - - # 注意:ir_sequences已经按id排序,index_all也是按id排序的 - for i, itm in enumerate(self.ir_sequences): - atom_bond = total_graph_atom_bond[i] - - atom_bond.sequence = paddle.to_tensor([itm['seq_40']]) - atom_bond.ir_id = paddle.to_tensor(int(itm['id'])) - atom_bond.peak_num = paddle.to_tensor([itm['peak_num']]) - atom_bond.peak_position = paddle.to_tensor([itm['peak_position']]) - atom_bond.peak_height = paddle.to_tensor([itm['peak_height']]) - atom_bond.query_mask = itm['query_mask'] - - self.graph_atom_bond.append(atom_bond) - self.graph_bond_angle.append(total_graph_bond_angle[i]) - self.smiles_list.append(smiles_all[i]) - - print(f"Final dataset size: {len(self.graph_atom_bond)}") - - if cache: - self._cache[cache_key] = { - 'atom_bond': self.graph_atom_bond, - 'bond_angle': self.graph_bond_angle, - 'smiles': self.smiles_list - } - def __len__(self): - """返回数据集大小""" - return len(self.graph_atom_bond) - - def __getitem__(self, idx): - return (self.graph_atom_bond[idx], self.graph_bond_angle[idx]) - -# 自定义DataLoader -class IRDataLoader(DataLoader): - """IR数据集专用DataLoader""" - - def __init__(self, dataset, batch_size=64, shuffle=True, num_workers=0, **kwargs): - super().__init__( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - collate_fn=self._collate_fn, - **kwargs - ) - - @staticmethod - def _collate_fn(batch): - """ - 自定义collate函数 - - Args: - batch: list of (atom_bond, bond_angle, smiles) tuples - - Returns: - batched_atom_bond: Batch - batched_bond_angle: Batch - smiles_list: list of str - """ - from paddle_geometric.data import Batch - - atom_bond_list = [item[0] for item in batch] - bond_angle_list = [item[1] for item in batch] - #smiles_list = [item[2] for item in batch] - - batch_atom_bond = Batch.from_data_list(atom_bond_list) - batch_bond_angle = Batch.from_data_list(bond_angle_list) - # Data解包到Tensor字典 - x, edge_index, 
edge_attr, query_mask =batch_atom_bond.x,batch_atom_bond.edge_index,batch_atom_bond.edge_attr,batch_atom_bond.query_mask - ba_edge_index, ba_edge_attr = batch_bond_angle.edge_index,batch_bond_angle.edge_attr - batch_data = batch_atom_bond.batch - pos_gt = batch_atom_bond.peak_position - height_gt = batch_atom_bond.peak_height - num_gt = batch_atom_bond.peak_num - return \ - { - "x" : x , - "edge_index" : edge_index , - "edge_attr" : edge_attr , - "batch_data" : batch_data , - "ba_edge_index" : ba_edge_index , - "ba_edge_attr" : ba_edge_attr , - "query_mask" : query_mask - }, \ - { - "peak_number_gt" : num_gt , - "peak_position_gt": pos_gt , - "peak_height_gt" : height_gt - } - return batched_atom_bond, batched_bond_angle - diff --git a/ppmat/datasets/IRDataset/colored_tqdm.py b/ppmat/datasets/IRDataset/colored_tqdm.py deleted file mode 100644 index 8176778b..00000000 --- a/ppmat/datasets/IRDataset/colored_tqdm.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from tqdm import tqdm -import time -import os;os.system("") #兼容windows - -def hex_to_ansi(hex_color: str, background: bool = False) -> str: - """ - 将十六进制颜色转换为ANSI转义序列 - - Args: - hex_color: 十六进制颜色,如 '#dda0a0' 或 'dda0a0' - background: True表示背景色,False表示前景色 - - Returns: - ANSI转义序列字符串,如 '\033[38;2;221;160;160m' - - Example: - >>> print(f"{hex_to_ansi('#dda0a0')}Hello{hex_to_ansi('#000000')} World") - >>> print(f"{hex_to_ansi('dda0a0', background=True)}背景色{hex_to_ansi.reset()}") - """ - # 移除#号并转换为小写 - hex_color = hex_color.lower().lstrip('#') - - # 处理简写形式 (#fff -> ffffff) - if len(hex_color) == 3: - hex_color = ''.join([c * 2 for c in hex_color]) - - # 转换为RGB值 - r = int(hex_color[0:2], 16) - g = int(hex_color[2:4], 16) - b = int(hex_color[4:6], 16) - - # ANSI真彩色序列 - # 38;2;R;G;B 为前景色,48;2;R;G;B 为背景色 - code = 48 if background else 38 - return f'\033[{code};2;{r};{g};{b}m' - -def rgb_to_ansi(r: int, g: int, b: int, background: bool = False) -> str: - """RGB值直接转ANSI""" - code = 48 if background else 38 - return f'\033[{code};2;{r};{g};{b}m' - -# 重置颜色的ANSI码 -hex_to_ansi.reset = '\033[0m' - -class ColoredTqdm(tqdm): - def __init__(self, *args, - start_color=(221, 160, 160), # RGB: #DDA0A0 - end_color=(160, 221, 160), # RGB: #A0DDA0 - **kwargs): - super().__init__(*args, **kwargs) - self.start_color = start_color - self.end_color = end_color - - def get_current_color(self): - progress = self.n / self.total if self.total > 0 else 0 - current_rgb = tuple( - int(start + (end - start) * progress) - for start, end in zip(self.start_color, self.end_color) - ) - result = current_rgb[0] * 16 ** 4 \ - + current_rgb[1] * 16 ** 2 \ - + current_rgb[2] * 16 ** 0 - return "%06x" % result - - def update(self, n=1): - super().update(n) - # 使用Rich的真彩色支持 - style = hex_to_ansi(self.get_current_color()) - self.bar_format = f'{{l_bar}}{style}{{bar}}{hex_to_ansi.reset}{{r_bar}}' - self.refresh() - - -if __name__ == "__main__": - # 使用示例 - for i in ColoredTqdm(range(100), desc="🌈 彩虹渐变"): - 
time.sleep(0.1) \ No newline at end of file diff --git a/ppmat/datasets/IRDataset/place_env.py b/ppmat/datasets/IRDataset/place_env.py deleted file mode 100644 index 62acf945..00000000 --- a/ppmat/datasets/IRDataset/place_env.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -import functools -from contextlib import contextmanager - -@contextmanager -def place_env(place): - """ - 上下文管理器,用于临时设置PaddlePaddle的运行设备 - - Args: - place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 - - 用法: - with place_env(paddle.CPUPlace()): - # 这里的代码在CPU上运行 - x = paddle.rand([2, 3]) - print(x) - - @place_env(paddle.CUDAPlace(0)) - def train(): - # 这个函数在GPU上运行 - pass - """ - # 保存当前的设备设置 - current_device = paddle.get_device() - - # 根据place类型设置设备 - if isinstance(place, paddle.CPUPlace): - paddle.set_device('cpu') - elif isinstance(place, paddle.CUDAPlace): - # 获取GPU设备ID - device_id = place.get_device_id() - paddle.set_device(f'gpu:{device_id}') - else: - raise ValueError(f"不支持的place类型: {type(place)}") - - try: - yield - finally: - # 恢复原来的设备设置 - paddle.set_device(current_device) - - -class PlaceEnv: - """ - 类版本的上下文管理器,也支持装饰器功能 - """ - - def __init__(self, place): - """ - 初始化PlaceEnv - - Args: - place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 - """ - self.place = place - self.original_device = None - - def __enter__(self): - """进入上下文时调用""" - # 保存当前的设备设置 - self.original_device = 
paddle.get_device() - - # 根据place类型设置设备 - if isinstance(self.place, paddle.CPUPlace): - paddle.set_device('cpu') - elif isinstance(self.place, paddle.CUDAPlace): - device_id = self.place.get_device_id() - paddle.set_device(f'gpu:{device_id}') - else: - raise ValueError(f"不支持的place类型: {type(self.place)}") - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """退出上下文时调用""" - # 恢复原来的设备设置 - if self.original_device is not None: - paddle.set_device(self.original_device) - - def __call__(self, func): - """ - 使实例可以作为装饰器使用 - - Args: - func: 要装饰的函数 - - Returns: - 装饰后的函数 - """ - @functools.wraps(func) - def wrapper(*args, **kwargs): - # 使用with语句来临时改变设备设置 - with self: - return func(*args, **kwargs) - return wrapper - - -# 为了兼容性,也可以保留函数版本的上下文管理器 -@contextmanager -def with_place_env(place): - """ - with_place_env的别名,与place_env功能相同 - """ - with place_env(place): - yield - - -# 使用示例 -if __name__ == "__main__": - # 测试with语句 - print("=== 测试with语句 ===") - print(f"当前设备: {paddle.get_device()}") - - with place_env(paddle.CPUPlace()): - print(f"with块内设备: {paddle.get_device()}") - x = paddle.rand([2, 3]) - print(f"创建的张量: {x}") - - print(f"with块外设备: {paddle.get_device()}") - - print("\n=== 测试类版本with语句 ===") - with PlaceEnv(paddle.CPUPlace()): - print(f"with块内设备: {paddle.get_device()}") - y = paddle.ones([2, 3]) - print(f"创建的张量: {y}") - - print(f"with块外设备: {paddle.get_device()}") - - # 测试装饰器功能 - print("\n=== 测试装饰器功能 ===") - - @PlaceEnv(paddle.CPUPlace()) - def cpu_function(): - """这个函数会在CPU上运行""" - print(f"函数内设备: {paddle.get_device()}") - return paddle.rand([2, 2]) - - # 检查是否有GPU可用 - if paddle.device.cuda.device_count() > 0: - @PlaceEnv(paddle.CUDAPlace(0)) - def gpu_function(): - """这个函数会在GPU上运行""" - print(f"函数内设备: {paddle.get_device()}") - return paddle.rand([2, 2]) - - # 调用装饰后的函数 - print("调用cpu_function:") - result_cpu = cpu_function() - print(f"函数执行后设备: {paddle.get_device()}") - print(f"结果: {result_cpu}") - - if paddle.device.cuda.device_count() > 0: - 
print("\n调用gpu_function:") - result_gpu = gpu_function() - print(f"函数执行后设备: {paddle.get_device()}") - print(f"结果: {result_gpu}") - - print("\n=== 测试多层嵌套 ===") - print(f"初始设备: {paddle.get_device()}") - - with PlaceEnv(paddle.CPUPlace()): - print(f"第一层with内设备: {paddle.get_device()}") - - if paddle.device.cuda.device_count() > 0: - with PlaceEnv(paddle.CUDAPlace(0)): - print(f"第二层with内设备: {paddle.get_device()}") - z = paddle.rand([2, 2]) - print(f"创建的张量: {z}") - - print(f"回到第一层with设备: {paddle.get_device()}") - - print(f"最终设备: {paddle.get_device()}") \ No newline at end of file diff --git a/ppmat/datasets/__init__.py b/ppmat/datasets/__init__.py index f95adcec..93eb0115 100644 --- a/ppmat/datasets/__init__.py +++ b/ppmat/datasets/__init__.py @@ -47,8 +47,8 @@ from ppmat.datasets.oc20_s2ef_dataset import OC20S2EFDataset # noqa from ppmat.datasets.qm9_dataset import QM9Dataset # noqa from ppmat.datasets.omol25_dataset import OMol25Dataset -from ppmat.datasets.IRDataset import IRDataset, IRDataLoader -from ppmat.datasets.ECDFormerDataset import ECDFormerDataset, ECDFormerDataset_DataLoader +from ppmat.datasets.ir_dataset import IRDataset +from ppmat.datasets.ecd_dataset import ECDDataset from ppmat.datasets.split_mptrj_data import none_to_zero from ppmat.datasets.transform import build_transforms from ppmat.utils import logger @@ -70,9 +70,7 @@ "SmallDensityDataset", "OMol25Dataset", "IRDataset", - "ECDFormerDataset", - "IRDataLoader", - "ECDFormerDataset_DataLoader", + "ECDDataset", ] INFO_CLASS_REGISTRY: Dict[str, type] = { @@ -283,7 +281,7 @@ def set_build_sample(sampler_cfg, world_size, dataset): ) batch_sampler = getattr(io, batch_sampler_cls)( dataset, - batch_size=init_params["batch_size"], + batch_size=2, # use default batch_size=2 to avoid error when batch_sampler is not specified shuffle=False, drop_last=False, ) diff --git a/ppmat/datasets/ECDFormerDataset/__init__.py b/ppmat/datasets/build_ecd.py similarity index 70% rename from 
ppmat/datasets/ECDFormerDataset/__init__.py rename to ppmat/datasets/build_ecd.py index e84dfb8a..6b67b93a 100644 --- a/ppmat/datasets/ECDFormerDataset/__init__.py +++ b/ppmat/datasets/build_ecd.py @@ -1,55 +1,116 @@ # Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. - +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -ECDFormer数据集加载模块 -""" +from __future__ import annotations import os -import numpy as np -import pandas as pd +import copy +import importlib +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, List import paddle -from paddle.io import Dataset -from paddle_geometric.data import Data +import pandas as pd +import numpy as np -from .compound_tools import get_atom_feature_dims, get_bond_feature_dims -from .util_func import normalize_func -from .eval_func import get_sequence_peak -from .colored_tqdm import ColoredTqdm as tqdm -from .place_env import PlaceEnv -from .dataloader import ECDFormerDataset_DataLoader - - -# ----------------Commonly-used Parameters---------------- -atom_id_names = [ - "atomic_num", "chiral_tag", "degree", "explicit_valence", - "formal_charge", "hybridization", "implicit_valence", - "is_aromatic", "total_numHs", -] -bond_id_names = ["bond_dir", "bond_type", "is_in_ring"] -full_atom_feature_dims = get_atom_feature_dims(atom_id_names) -full_bond_feature_dims = get_bond_feature_dims(bond_id_names) -bond_angle_float_names = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] -column_specify={ - 
'ADH':[1,5,0,0],'ODH':[1,5,0,1],'IC':[0,5,1,2],'IA':[0,5,1,3],'OJH':[1,5,0,4], - 'ASH':[1,5,0,5],'IC3':[0,3,1,6],'IE':[0,5,1,7],'ID':[0,5,1,8],'OD3':[1,3,0,9], - 'IB':[0,5,1,10],'AD':[1,10,0,11],'AD3':[1,3,0,12],'IF':[0,5,1,13],'OD':[1,10,0,14], - 'AS':[1,10,0,15],'OJ3':[1,3,0,16],'IG':[0,5,1,17],'AZ':[1,10,0,18],'IAH':[0,5,1,19], - 'OJ':[1,10,0,20],'ICH':[0,5,1,21],'OZ3':[1,3,0,22],'IF3':[0,3,1,23],'IAU':[0,1.6,1,24] -} -bond_float_names = [] +from paddle_geometric.data import Data +from ppmat.utils import download as download_utils +from ppmat.utils import logger +from ppmat.utils import ColoredTqdm as tqdm +from ppmat.utils.compound_tools import ( + atom_id_names, bond_id_names, bond_angle_float_names +) + +def _locate_class(class_name: str): + if "." in class_name: + mod, cls = class_name.rsplit(".", 1) + return getattr(importlib.import_module(mod), cls) + return globals()[class_name] + + +def _parse_factory_cfg( + cfg: Optional[Dict[str, Any] | str], + *, + default_class_name: str, +) -> Tuple[str, Dict[str, Any]]: + """解析工厂配置,兼容多种格式""" + if cfg is None: + return default_class_name, {} + + if isinstance(cfg, str): + return cfg, {} + + if not isinstance(cfg, dict): + raise TypeError(f"cfg must be None, str, or dict, got {type(cfg).__name__}") + + cfg = copy.deepcopy(cfg) + class_name = cfg.pop("__class_name__", None) or cfg.pop("class_name", None) or cfg.pop("type", None) + if not class_name: + raise ValueError("Factory cfg must include class name key") + + init_params = ( + cfg.pop("__init_params__", None) or + cfg.pop("init_params", None) or + cfg.pop("params", None) or {} + ) + if not isinstance(init_params, dict): + raise TypeError(f"init_params must be dict, got {type(init_params).__name__}") + + if cfg: + raise ValueError(f"Unsupported keys in cfg: {list(cfg.keys())}") + return class_name, init_params + + +class StrictIndexSampleBuilder: + """按严格索引构建样本(适用于 ECD 数据集)""" + def build(self, data_dir: Path, index_file: str, sample_path: str, data_count: 
Optional[int] = None): + import pandas as pd + samples = [] + df = pd.read_csv(data_dir / index_file, encoding='gbk') + ids = df['Unnamed: 0'].values[:data_count] if data_count else df['Unnamed: 0'].values + for idx in ids: + samples.append({ + 'id': int(idx), + 'smiles': df[df['Unnamed: 0'] == idx]['SMILES'].values[0], + 'spectrum_path': str(Path(sample_path) / f"{idx}.csv") + }) + return samples + + +class DefaultECDDatasetDownloader: + """ECD 数据集下载器""" + def __init__(self, datasets_home: Optional[str] = None): + self.datasets_home = datasets_home or download_utils.DATASETS_HOME + + def download(self, url: str, md5: Optional[str] = None, force_download: bool = False) -> Path: + if force_download: + downloaded_root = download_utils.get_path_from_url( + url, self.datasets_home, md5sum=md5, check_exist=False, decompress=True + ) + else: + downloaded_root = download_utils.get_datasets_path_from_url(url, md5) + return Path(downloaded_root) + +def build_ecformer_downloader(cfg: Optional[Dict[str, Any] | str]): + """构建下载器""" + class_name, init_params = _parse_factory_cfg(cfg, default_class_name="DefaultECDDatasetDownloader") + cls = _locate_class(class_name) + downloader = cls(**init_params) + if not hasattr(downloader, 'download'): + raise TypeError(f"Downloader {class_name} must implement 'download' method") + logger.debug(f"Use downloader: {class_name}") + return downloader def get_key_padding_mask(tokens): """生成query padding mask""" @@ -57,80 +118,36 @@ def get_key_padding_mask(tokens): key_padding_mask[tokens == -1] = -paddle.inf return key_padding_mask - -def Construct_dataset(dataset, data_index, path): - """ - 从原始特征构建图数据 - 完全复用原型程序的Construct_dataset逻辑 - """ - graph_atom_bond = [] - graph_bond_angle = [] - - all_descriptor = np.load(os.path.join(path, 'descriptor_all_column.npy')) # (25847, 1826) - - for i in tqdm(range(len(dataset)), desc="Constructing graphs"): - data = dataset[i] - - # 收集原子特征 - atom_feature = [] - for name in atom_id_names: - 
atom_feature.append(data[name]) - - # 收集键特征 - bond_feature = [] - for name in bond_id_names[0:3]: - bond_feature.append(data[name]) - - # 转换为Tensor - atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') - bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') - bond_float_feature = paddle.to_tensor(data['bond_length'].astype(np.float32)) - bond_angle_feature = paddle.to_tensor(data['bond_angle'].astype(np.float32)) - edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') - bond_index = paddle.to_tensor(data['BondAngleGraph_edges'].T, dtype='int64') - data_index_int = paddle.to_tensor(np.array(data_index[i]), dtype='int64') - - # 添加描述符特征(与原型程序完全一致) - TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 100 - RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] - RPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822] - MDEC = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568] - MATS = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457] - - # 合并特征 - bond_feature = paddle.concat( - [bond_feature.astype(bond_float_feature.dtype), - bond_float_feature.reshape([-1, 1])], - axis=1 - ) - - bond_angle_feature = paddle.concat( - [bond_angle_feature.reshape([-1, 1]), TPSA.reshape([-1, 1])], - axis=1 - ) - bond_angle_feature = paddle.concat([bond_angle_feature, RASA.reshape([-1, 1])], axis=1) - bond_angle_feature = paddle.concat([bond_angle_feature, RPSA.reshape([-1, 1])], axis=1) - bond_angle_feature = paddle.concat([bond_angle_feature, MDEC.reshape([-1, 1])], axis=1) - bond_angle_feature = paddle.concat([bond_angle_feature, MATS.reshape([-1, 1])], axis=1) - - # 创建Data对象 - data_atom_bond = Data( - x=atom_feature, - edge_index=edge_index, - edge_attr=bond_feature, - data_index=data_index_int, - ) - data_bond_angle = Data( - edge_index=bond_index, - edge_attr=bond_angle_feature, - num_nodes=atom_feature.shape[0] - ) - - 
graph_atom_bond.append(data_atom_bond) - graph_bond_angle.append(data_bond_angle) - - return graph_atom_bond, graph_bond_angle - +def normalize_func(src_list, norm_range=[-100, 100]): + # lihao implecation for list normalization + # input: src_list, normalization range + # output: tgt_list after normalization + + src_max, src_min = max(src_list), min(src_list) + norm_min, norm_max = norm_range[0], norm_range[1] + if src_max == 0: src_max = 1 + if src_min == 0: src_min = -1 + + tgt_list = [] + for i in range(len(src_list)): + if src_list[i] >= 0: + tgt_list.append(src_list[i] * norm_max / src_max) + else: + tgt_list.append(src_list[i] * norm_min / src_min) + + assert len(src_list) == len(tgt_list) + return tgt_list + +def get_sequence_peak(sequence): + # input- seq: List + # output- peak_list contains peak position + peak_list = [] + for i in range(1, len(sequence)-1): + if sequence[i-1]sequence[i+1]: + peak_list.append(i) + if sequence[i-1]>sequence[i] and sequence[i] Tuple[str, Dict[str, Any]]: + """解析工厂配置,兼容多种格式""" + if cfg is None: + return default_class_name, {} + + if isinstance(cfg, str): + return cfg, {} + + if not isinstance(cfg, dict): + raise TypeError(f"cfg must be None, str, or dict, got {type(cfg).__name__}") + + cfg = copy.deepcopy(cfg) + class_name = cfg.pop("__class_name__", None) or cfg.pop("class_name", None) or cfg.pop("type", None) + if not class_name: + raise ValueError("Factory cfg must include class name key") + + init_params = ( + cfg.pop("__init_params__", None) or + cfg.pop("init_params", None) or + cfg.pop("params", None) or {} + ) + if not isinstance(init_params, dict): + raise TypeError(f"init_params must be dict, got {type(init_params).__name__}") + + if cfg: + raise ValueError(f"Unsupported keys in cfg: {list(cfg.keys())}") + return class_name, init_params + + +class IRStrictIndexSampleBuilder: + """按严格索引构建IR样本""" + def build(self, data_dir: Path, meta_file: str, spectra_dir: str, data_count: Optional[int] = None): + """构建样本列表""" + 
samples = [] + meta_path = data_dir / meta_file + data = np.load(meta_path, allow_pickle=True).item() + + index_all = data['index_all'][:data_count] if data_count else data['index_all'] + + for idx in index_all: + samples.append({ + 'id': int(idx), + 'smiles': data['smiles_all'][data['index_all'].index(idx)] if hasattr(data['index_all'], 'index') else None, + 'spectrum_path': str(Path(spectra_dir) / f"{idx}.json") + }) + return samples + + +class DefaultIRDatasetDownloader: + """IR 数据集下载器""" + def __init__(self, datasets_home: Optional[str] = None): + self.datasets_home = datasets_home or download_utils.DATASETS_HOME + + def download(self, url: str, md5: Optional[str] = None, force_download: bool = False) -> Path: + if force_download: + downloaded_root = download_utils.get_path_from_url( + url, self.datasets_home, md5sum=md5, check_exist=False, decompress=True + ) + else: + downloaded_root = download_utils.get_datasets_path_from_url(url, md5) + return Path(downloaded_root) + + +def build_ir_downloader(cfg: Optional[Dict[str, Any] | str]): + """构建下载器""" + class_name, init_params = _parse_factory_cfg(cfg, default_class_name="DefaultIRDatasetDownloader") + cls = _locate_class(class_name) + downloader = cls(**init_params) + if not hasattr(downloader, 'download'): + raise TypeError(f"Downloader {class_name} must implement 'download' method") + logger.debug(f"Use downloader: {class_name}") + return downloader + + +def build_ir_sample_builder(cfg: Optional[Dict[str, Any] | str]): + """构建样本构建器""" + class_name, init_params = _parse_factory_cfg(cfg, default_class_name="IRStrictIndexSampleBuilder") + cls = _locate_class(class_name) + builder = cls(**init_params) + if not hasattr(builder, 'build'): + raise TypeError(f"Sample builder {class_name} must implement 'build' method") + logger.debug(f"Use sample builder: {class_name}") + return builder + + +# ==================== IR 特定工具函数 ==================== + +IR_WAVELENGTH_MIN = 500 +IR_WAVELENGTH_MAX = 4000 +IR_STEP = 100 
+DEFAULT_MAX_PEAKS = 15 + + +def get_key_padding_mask(tokens): + """生成query padding mask""" + key_padding_mask = paddle.zeros(tokens.shape) + key_padding_mask[tokens == -1] = -paddle.inf + return key_padding_mask + + +def x_bin_position(real_x, distance=IR_STEP): + """将实际波数转换为箱ID""" + return int((real_x - IR_WAVELENGTH_MIN) / distance) + + +def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): + """ + 从原始特征构建IR图数据 + """ + graph_atom_bond = [] + graph_bond_angle = [] + + all_descriptor = None + if descriptor_path and os.path.exists(descriptor_path): + all_descriptor = np.load(descriptor_path) + + for i in tqdm(range(len(dataset)), desc="Constructing IR graphs"): + data = dataset[i] + + # 收集原子特征 + atom_feature = [] + for name in atom_id_names: + if name in data: + atom_feature.append(data[name]) + else: + if i == 0: + warnings.warn(f"Feature {name} not found in data, using zeros") + num_atoms = data.get('atomic_num', np.zeros(1)).shape[0] + atom_feature.append(np.zeros(num_atoms)) + + # 收集键特征 + bond_feature = [] + for name in bond_id_names: + if name in data: + bond_feature.append(data[name]) + else: + if i == 0: + warnings.warn(f"Bond feature {name} not found, using zeros") + num_bonds = data.get('bond_dir', np.zeros(1)).shape[0] + bond_feature.append(np.zeros(num_bonds)) + + # 转换为Tensor + atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') + bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') + + bond_float_feature = paddle.to_tensor(data.get('bond_length', np.zeros(data['edges'].shape[0])).astype(np.float32)) + bond_angle_feature = paddle.to_tensor(data.get('bond_angle', np.zeros(data.get('BondAngleGraph_edges', np.zeros((0,2))).shape[0])).astype(np.float32)) + + edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') + bond_index = paddle.to_tensor(data.get('BondAngleGraph_edges', np.zeros((0,2))).T, dtype='int64') + + data_index_int = paddle.to_tensor(np.array(int(data_index[i])), dtype='int64') + 
num_atoms = atom_feature.shape[0] + + # 合并键特征 + bond_feature = paddle.concat( + [bond_feature.astype(bond_float_feature.dtype), + bond_float_feature.reshape([-1, 1])], + axis=1 + ) + + # 处理键角特征 - 确保输出6维! + if bond_angle_feature.shape[0] > 0: + # 基础特征:bond_angle + features = [bond_angle_feature.reshape([-1, 1])] + + # 如果有描述符,添加5个描述符特征 + if all_descriptor is not None and i < all_descriptor.shape[0]: + TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 100 + RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] + RPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822] + MDEC = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568] + MATS = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457] + + features.extend([ + TPSA.reshape([-1, 1]), + RASA.reshape([-1, 1]), + RPSA.reshape([-1, 1]), + MDEC.reshape([-1, 1]), + MATS.reshape([-1, 1]) + ]) + else: + # 如果没有描述符,用0填充剩下的5维 + for _ in range(5): + features.append(paddle.zeros([bond_angle_feature.shape[0], 1])) + + # 拼接成 [E_ba, 6] + bond_angle_feature = paddle.concat(features, axis=1) + else: + # 如果没有键角,创建全0的 [0, 6] + bond_angle_feature = paddle.zeros([0, 6]) + + data_atom_bond = Data( + x=atom_feature, + edge_index=edge_index, + edge_attr=bond_feature, + data_index=data_index_int, + ) + + data_bond_angle = Data( + edge_index=bond_index, + edge_attr=bond_angle_feature if bond_angle_feature.shape[0] > 0 else paddle.zeros([0, 1]), + num_nodes=num_atoms, + ) + + graph_atom_bond.append(data_atom_bond) + graph_bond_angle.append(data_bond_angle) + + return graph_atom_bond, graph_bond_angle + + +def read_ir_spectra_by_ids(sample_path, index_all, max_peak=DEFAULT_MAX_PEAKS): + """ + 按需读取IR光谱文件 + """ + ir_final_list = [] + + for fileid in tqdm(index_all, desc="Reading IR spectra by ID"): + filepath = os.path.join(sample_path, f"{fileid}.json") + + try: + with open(filepath, 'r') as f: + raw_ir_info = json.load(f) + + ir_x = raw_ir_info['x'] 
+ ir_y = raw_ir_info['y_40'] + + peaks_raw, _ = find_peaks(x=ir_y, height=0.1, distance=100) + peaks_raw = peaks_raw.tolist() + + peak_num = min(len(peaks_raw), max_peak) + + if peak_num > 0: + if len(peaks_raw) > max_peak: + peaks = peaks_raw[len(peaks_raw)-max_peak:] + else: + peaks = peaks_raw + + peak_position_list = [x_bin_position(ir_x[i]) for i in peaks] + peak_height_list = [ir_y[i] for i in peaks] + else: + peak_position_list = [] + peak_height_list = [] + + peak_position_list = peak_position_list + [-1] * (max_peak - len(peak_position_list)) + peak_height_list = peak_height_list + [-1] * (max_peak - len(peak_height_list)) + + query_padding_mask = get_key_padding_mask(paddle.to_tensor(peak_position_list)) + + tmp_dict = { + 'id': fileid, + 'seq_40': ir_y, + 'peak_num': peak_num, + 'peak_position': peak_position_list, + 'peak_height': peak_height_list, + 'query_mask': query_padding_mask.unsqueeze(0), + } + ir_final_list.append(tmp_dict) + + except Exception as e: + warnings.warn(f"Error processing {fileid}.json: {e}") + continue + + ir_final_list.sort(key=lambda x: x['id']) + return ir_final_list + + +def GetIRDataset( + sample_path, + dataset_all, + index_all, +): + """ + 核心函数:构建并返回IR图数据集 + """ + # 1. 读取IR光谱序列 + ir_sequences = read_ir_spectra_by_ids(sample_path, index_all) + + # 2. 构建图数据 + total_graph_atom_bond, total_graph_bond_angle = Construct_IR_Dataset( + dataset_all, index_all, sample_path + ) + print("Case Before Process = ", len(total_graph_atom_bond), len(total_graph_bond_angle)) + + # 3. 
将光谱信息附加到图数据上 + dataset_graph_atom_bond, dataset_graph_bond_angle = [], [] + + for i, itm in enumerate(ir_sequences): + atom_bond = total_graph_atom_bond[i] + + # 附加光谱信息 + atom_bond.sequence = paddle.to_tensor([itm['seq_40']]) + atom_bond.ir_id = paddle.to_tensor(int(itm['id'])) + atom_bond.peak_num = paddle.to_tensor([itm['peak_num']]) + atom_bond.peak_position = paddle.to_tensor([itm['peak_position']]) + atom_bond.peak_height = paddle.to_tensor([itm['peak_height']]) + atom_bond.query_mask = itm['query_mask'] + + dataset_graph_atom_bond.append(atom_bond) + dataset_graph_bond_angle.append(total_graph_bond_angle[i]) + + total_num = len(dataset_graph_atom_bond) + print("Case After Process = ", len(dataset_graph_atom_bond), len(dataset_graph_bond_angle)) + print('=================== Data prepared ================\n') + + return dataset_graph_atom_bond, dataset_graph_bond_angle \ No newline at end of file diff --git a/ppmat/datasets/collate_fn.py b/ppmat/datasets/collate_fn.py index 9073af4c..cb010ee8 100644 --- a/ppmat/datasets/collate_fn.py +++ b/ppmat/datasets/collate_fn.py @@ -302,3 +302,80 @@ def pad_sequence(sequences, batch_first=False, padding_value=0): out_tensor[:length, i, ...] 
= tensor return out_tensor + + +class ECDCollator(DefaultCollator): + def __call__(self, batch: List[Any]) -> Any: + batch = [list(x) for x in zip(*batch)] # transpose + for i in range(len(batch)): # 组Batch + batch[i] = Batch.from_data_list(batch[i]) + + batch0 = batch[0] + batch1 = batch[1] + + # Data解包到Tensor字典 + batch_atom_bond, batch_bond_angle = batch0, batch1 + x, edge_index, edge_attr, query_mask =batch_atom_bond.x,batch_atom_bond.edge_index,batch_atom_bond.edge_attr,batch_atom_bond.query_mask + ba_edge_index, ba_edge_attr = batch_bond_angle.edge_index,batch_bond_angle.edge_attr + batch_data = batch_atom_bond.batch + pos_gt = batch_atom_bond.peak_position + height_gt = batch_atom_bond.peak_height + num_gt = batch_atom_bond.peak_num + return \ + { + "x" : x , + "edge_index" : edge_index , + "edge_attr" : edge_attr , + "batch_data" : batch_data , + "ba_edge_index" : ba_edge_index , + "ba_edge_attr" : ba_edge_attr , + "query_mask" : query_mask + }, \ + { + "peak_number_gt" : num_gt , + "peak_position_gt": pos_gt , + "peak_height_gt" : height_gt + } + + +class IRCollator(DefaultCollator): + """IR 数据集专用 collator,返回 Tensor 字典""" + + def __call__(self, batch: List[Any]) -> Any: + batch = [list(x) for x in zip(*batch)] # transpose + for i in range(len(batch)): + batch[i] = Batch.from_data_list(batch[i]) + + batch_atom_bond, batch_bond_angle = batch[0], batch[1] + + x, edge_index, edge_attr, query_mask = ( + batch_atom_bond.x, + batch_atom_bond.edge_index, + batch_atom_bond.edge_attr, + batch_atom_bond.query_mask, + ) + ba_edge_index, ba_edge_attr = ( + batch_bond_angle.edge_index, + batch_bond_angle.edge_attr, + ) + batch_data = batch_atom_bond.batch + pos_gt = batch_atom_bond.peak_position + height_gt = batch_atom_bond.peak_height + num_gt = batch_atom_bond.peak_num + + return ( + { + "x": x, + "edge_index": edge_index, + "edge_attr": edge_attr, + "batch_data": batch_data, + "ba_edge_index": ba_edge_index, + "ba_edge_attr": ba_edge_attr, + "query_mask": query_mask, 
+ }, + { + "peak_number_gt": num_gt, + "peak_position_gt": pos_gt, + "peak_height_gt": height_gt, + } + ) \ No newline at end of file diff --git a/ppmat/datasets/ecd_dataset.py b/ppmat/datasets/ecd_dataset.py new file mode 100644 index 00000000..c67651bf --- /dev/null +++ b/ppmat/datasets/ecd_dataset.py @@ -0,0 +1,164 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import numpy as np +import pandas as pd +import paddle +from paddle.io import Dataset +from paddle_geometric.data import Data +from pathlib import Path +from typing import Dict, List, Optional, Any + +from ppmat.utils import ColoredTqdm as tqdm +from ppmat.utils import PlaceEnv +from ppmat.utils.compound_tools import get_atom_feature_dims, get_bond_feature_dims +from ppmat.datasets.build_ecd import build_ecformer_sample_builder +from ppmat.datasets.build_ecd import build_ecformer_downloader +from ppmat.datasets.build_ecd import GetAtomBondAngleDataset + +_cache = () + +class ECDDataset(Dataset): + """ + ECDFormer ECD 光谱预测数据集 + + 数据来源:https://paddle-org.bj.bcebos.com/paddlematerials/datasets/ECD/ECD.tar.gz + """ + + url = "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/ECD/ECD.tar.gz" + md5 = "aa86eddee2397dbc37c4b7b9a45b1e27" + + def __init__( + self, + data_path: str, + split: Optional[str] = None, # 'train'/'val'/'test' + data_count: Optional[int] = None, + sample_builder_cfg: 
Optional[Dict] = None, + downloader_cfg: Optional[Dict] = None, + download: bool = True, + force_download: bool = False, + use_geometry_enhanced: bool = True, + use_column_info: bool = False, + ): + super().__init__() + + self.data_path = Path(data_path) + self.split = split + self.data_count = data_count + self.use_geometry_enhanced = use_geometry_enhanced + self.use_column_info = use_column_info + + # 构建组件 + self.sample_builder = build_ecformer_sample_builder(sample_builder_cfg) + self.downloader = build_ecformer_downloader(downloader_cfg) + + # 处理下载 + if force_download or (not self._check_files() and download): + self.downloaded_root = self.downloader.download( + self.url, self.md5, force_download=force_download + ) + self.data_path = self.downloaded_root + + # 加载数据 + self._load_data() + + def _check_files(self): + """检查必要的文件是否存在""" + npy_path = self.data_path / 'ecd_column_charity_new_smiles.npy' + csv_path = self.data_path / 'ecd_info.csv' + + if not npy_path.exists(): + return False + if not csv_path.exists(): + return False + return True + + def _load_data(self): + """加载所有数据""" + # 1. 加载 npy 文件 + npy_path = self.data_path / 'ecd_column_charity_new_smiles.npy' + if not npy_path.exists(): + raise FileNotFoundError(f"npy file not found: {npy_path}") + + self.ecd_dataset = np.load(npy_path, allow_pickle=True).tolist() + + # 2. 加载 csv 文件 + csv_path = self.data_path / 'ecd_info.csv' + if not csv_path.exists(): + raise FileNotFoundError(f"csv file not found: {csv_path}") + + self.ecd_info = pd.read_csv(csv_path, encoding='gbk') + + # 3. 提取数据 + self.dataset_all = [item['info'] for item in self.ecd_dataset] + self.smiles_all = [item['smiles'] for item in self.ecd_dataset] + self.index_all = self.ecd_info['Unnamed: 0'].values + + # 4. 构建手性对映射 + self._build_chiral_mapping() + + # 5. 
构建图数据集 + self._build_graph_dataset() + + def _build_chiral_mapping(self): + """构建手性对映体映射""" + self.hand_idx_dict = {} + self.line_idx_dict = {} + + for i, itm in enumerate(self.ecd_dataset): + self.line_idx_dict[i] = { + 'hand_id': itm['hand_id'], + 'unnamed_id': itm['id'], + 'smiles': itm['smiles'] + } + + if itm['hand_id'] not in self.hand_idx_dict: + self.hand_idx_dict[itm['hand_id']] = [] + self.hand_idx_dict[itm['hand_id']].append({ + 'line_number': i, + 'unnamed_id': itm['id'], + 'smiles': itm['smiles'] + }) + + @PlaceEnv(paddle.CPUPlace()) + def _build_graph_dataset(self): + """构建图数据集""" + global _cache + + if len(_cache) > 0: + self.graph_atom_bond, self.graph_bond_angle = _cache + return + + self.graph_atom_bond, self.graph_bond_angle = GetAtomBondAngleDataset( + sample_path=str(self.data_path), + dataset_all=self.dataset_all, + index_all=self.index_all, + hand_idx_dict=self.hand_idx_dict, + line_idx_dict=self.line_idx_dict + ) + + _cache = (self.graph_atom_bond, self.graph_bond_angle) + + assert len(self.graph_atom_bond) == len(self.graph_bond_angle) + + def __len__(self): + return len(self.graph_atom_bond) + + @PlaceEnv(paddle.CPUPlace()) + def __getitem__(self, idx): + """返回 (atom_bond_graph, bond_angle_graph)""" + return self.graph_atom_bond[idx], self.graph_bond_angle[idx] \ No newline at end of file diff --git a/ppmat/datasets/geometric_data_type/batch.py b/ppmat/datasets/geometric_data_type/batch.py index 26a9d311..59c02bab 100644 --- a/ppmat/datasets/geometric_data_type/batch.py +++ b/ppmat/datasets/geometric_data_type/batch.py @@ -56,7 +56,7 @@ def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): Additionally, creates assignment batch vectors for each key in :obj:`follow_batch`. 
Will exclude any keys given in :obj:`exclude_keys`.""" - keys = list(set(data_list[0].keys) - set(exclude_keys)) + keys = list(set(data_list[0].keys()) - set(exclude_keys)) assert "batch" not in keys and "ptr" not in keys batch = cls() for key in data_list[0].__dict__.keys(): diff --git a/ppmat/datasets/ir_dataset.py b/ppmat/datasets/ir_dataset.py new file mode 100644 index 00000000..0e3d5f34 --- /dev/null +++ b/ppmat/datasets/ir_dataset.py @@ -0,0 +1,172 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +import numpy as np +import paddle +from paddle.io import Dataset +from paddle_geometric.data import Data +from pathlib import Path +from typing import Dict, List, Optional, Any + +from ppmat.utils import ColoredTqdm as tqdm +from ppmat.utils import PlaceEnv +from ppmat.datasets.build_ir import ( + build_ir_sample_builder, + build_ir_downloader, + GetIRDataset, + read_ir_spectra_by_ids, + Construct_IR_Dataset, +) + +_cache = {} + + +class IRDataset(Dataset): + """ + ECFormer IR 光谱预测数据集 + + 支持三种预加载模式: + - '100': 100个样本的小数据集(默认,用于快速测试) + - '10000': 1万个样本的中等数据集 + - 'all': 全部样本(可能很大) + """ + + url = "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/IR/IR.tar.gz" + md5 = "e1ea5624cf9b92b3657933245196f5dc" + + def __init__( + self, + data_path: str, + mode: str = '100', + split: Optional[str] = None, + data_count: Optional[int] = None, + sample_builder_cfg: Optional[Dict] = None, + downloader_cfg: Optional[Dict] = None, + download: bool = True, + force_download: bool = False, + use_geometry_enhanced: bool = True, + use_cache: bool = True, + ): + super().__init__() + + self.data_path = Path(data_path) + self.mode = mode + self.split = split + self.data_count = data_count + self.use_geometry_enhanced = use_geometry_enhanced + self.use_cache = use_cache + + cache_key = f"{data_path}_{mode}_{use_geometry_enhanced}" + + # 如果启用缓存且命中,直接返回 + if use_cache and cache_key in _cache: + cached_data = _cache[cache_key] + self.graph_atom_bond = cached_data['atom_bond'] + self.graph_bond_angle = cached_data['bond_angle'] + self.smiles_list = cached_data.get('smiles', []) + return + + # 构建组件 + self.sample_builder = build_ir_sample_builder(sample_builder_cfg) + self.downloader = build_ir_downloader(downloader_cfg) + + # 处理下载 + if force_download or (not self._check_files() and download): + self.downloaded_root = self.downloader.download( + self.url, self.md5, force_download=force_download + ) + self.data_path = self.downloaded_root 
+ + # 加载数据 + self._load_data() + + # 存入缓存 + if use_cache: + _cache[cache_key] = { + 'atom_bond': self.graph_atom_bond, + 'bond_angle': self.graph_bond_angle, + 'smiles': self.smiles_list + } + + def _check_files(self): + """检查必要的文件是否存在""" + meta_path = self.data_path / f'ir_column_charity_{self.mode}.npy' + spectra_path = self.data_path / 'qm9_ir_spec' + + if not meta_path.exists(): + return False + if not spectra_path.exists(): + return False + return True + + def _load_data(self): + """加载所有数据""" + # 1. 加载元数据文件 + meta_path = self.data_path / f'ir_column_charity_{self.mode}.npy' + if not meta_path.exists(): + raise FileNotFoundError(f"IR meta file {meta_path} not found") + + data = np.load(meta_path, allow_pickle=True).item() + dataset_all = data['dataset_all'] + smiles_all = data['smiles_all'] + index_all = data['index_all'] + + print(f"Loaded meta data: {len(index_all)} samples") + + # 2. 按需读取IR光谱 + spectra_path = self.data_path / 'qm9_ir_spec' + self.ir_sequences = read_ir_spectra_by_ids(str(spectra_path), index_all) + + print(f"Loaded {len(self.ir_sequences)} IR spectra") + + # 3. 构建图数据 + descriptor_path = self.data_path / 'descriptor_all_column.npy' + if not descriptor_path.exists(): + descriptor_path = None + + total_graph_atom_bond, total_graph_bond_angle = Construct_IR_Dataset( + dataset_all, index_all, descriptor_path + ) + + # 4. 
将光谱信息附加到图数据 + self.graph_atom_bond = [] + self.graph_bond_angle = [] + self.smiles_list = [] + + for i, itm in enumerate(self.ir_sequences): + atom_bond = total_graph_atom_bond[i] + + atom_bond.sequence = paddle.to_tensor([itm['seq_40']]) + atom_bond.ir_id = paddle.to_tensor(int(itm['id'])) + atom_bond.peak_num = paddle.to_tensor([itm['peak_num']]) + atom_bond.peak_position = paddle.to_tensor([itm['peak_position']]) + atom_bond.peak_height = paddle.to_tensor([itm['peak_height']]) + atom_bond.query_mask = itm['query_mask'] + + self.graph_atom_bond.append(atom_bond) + self.graph_bond_angle.append(total_graph_bond_angle[i]) + self.smiles_list.append(smiles_all[i]) + + print(f"Final dataset size: {len(self.graph_atom_bond)}") + + def __len__(self): + return len(self.graph_atom_bond) + + @PlaceEnv(paddle.CPUPlace()) + def __getitem__(self, idx): + """返回 (atom_bond_graph, bond_angle_graph)""" + return self.graph_atom_bond[idx], self.graph_bond_angle[idx] \ No newline at end of file diff --git a/ppmat/utils/__init__.py b/ppmat/utils/__init__.py index 8b5fb924..913a51b1 100644 --- a/ppmat/utils/__init__.py +++ b/ppmat/utils/__init__.py @@ -22,6 +22,8 @@ from ppmat.utils.save_load import load_checkpoint from ppmat.utils.save_load import load_pretrain from ppmat.utils.save_load import save_checkpoint +from ppmat.utils.place_env import PlaceEnv +from ppmat.utils.colored_tqdm import ColoredTqdm __all__ = [ logger, @@ -33,4 +35,6 @@ load_checkpoint, load_pretrain, save_checkpoint, + PlaceEnv, + ColoredTqdm ] diff --git a/ppmat/datasets/ECDFormerDataset/colored_tqdm.py b/ppmat/utils/colored_tqdm.py similarity index 76% rename from ppmat/datasets/ECDFormerDataset/colored_tqdm.py rename to ppmat/utils/colored_tqdm.py index 8176778b..e108256a 100644 --- a/ppmat/datasets/ECDFormerDataset/colored_tqdm.py +++ b/ppmat/utils/colored_tqdm.py @@ -1,17 +1,3 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
- -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from tqdm import tqdm import time import os;os.system("") #兼容windows @@ -66,6 +52,10 @@ def __init__(self, *args, self.end_color = end_color def get_current_color(self): + + if self.total is None: + return "#FFFFFF" + progress = self.n / self.total if self.total > 0 else 0 current_rgb = tuple( int(start + (end - start) * progress) @@ -78,7 +68,6 @@ def get_current_color(self): def update(self, n=1): super().update(n) - # 使用Rich的真彩色支持 style = hex_to_ansi(self.get_current_color()) self.bar_format = f'{{l_bar}}{style}{{bar}}{hex_to_ansi.reset}{{r_bar}}' self.refresh() @@ -86,5 +75,6 @@ def update(self, n=1): if __name__ == "__main__": # 使用示例 - for i in ColoredTqdm(range(100), desc="🌈 彩虹渐变"): - time.sleep(0.1) \ No newline at end of file + for i in ColoredTqdm(range(10), desc="🌈 彩虹渐变", leave = False): + for j in ColoredTqdm(range(100), desc="🌈 彩虹渐变", leave = False): + time.sleep(0.01) \ No newline at end of file diff --git a/ppmat/datasets/IRDataset/compound_tools.py b/ppmat/utils/compound_tools.py similarity index 97% rename from ppmat/datasets/IRDataset/compound_tools.py rename to ppmat/utils/compound_tools.py index cee0c635..03461f0b 100644 --- a/ppmat/datasets/IRDataset/compound_tools.py +++ b/ppmat/utils/compound_tools.py @@ -1,17 +1,3 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
- -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import numpy as np from rdkit import Chem from rdkit.Chem import AllChem @@ -160,7 +146,6 @@ "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]", ] - def get_gasteiger_partial_charges(mol, n_iter=12): """ Calculates list of gasteiger partial charges for each atom in mol object. @@ -831,6 +816,25 @@ def predict_SMILES_info(smiles): info_dict = mol_to_geognn_graph_data_MMFF3d(mol) return info_dict +# ----------------Commonly-used Parameters---------------- +atom_id_names = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", +] +bond_id_names = ["bond_dir", "bond_type", "is_in_ring"] +full_atom_feature_dims = get_atom_feature_dims(atom_id_names) +full_bond_feature_dims = get_bond_feature_dims(bond_id_names) +bond_angle_float_names = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] +column_specify={ + 'ADH':[1,5,0,0],'ODH':[1,5,0,1],'IC':[0,5,1,2],'IA':[0,5,1,3],'OJH':[1,5,0,4], + 'ASH':[1,5,0,5],'IC3':[0,3,1,6],'IE':[0,5,1,7],'ID':[0,5,1,8],'OD3':[1,3,0,9], + 'IB':[0,5,1,10],'AD':[1,10,0,11],'AD3':[1,3,0,12],'IF':[0,5,1,13],'OD':[1,10,0,14], + 'AS':[1,10,0,15],'OJ3':[1,3,0,16],'IG':[0,5,1,17],'AZ':[1,10,0,18],'IAH':[0,5,1,19], + 'OJ':[1,10,0,20],'ICH':[0,5,1,21],'OZ3':[1,3,0,22],'IF3':[0,3,1,23],'IAU':[0,1.6,1,24] +} +bond_float_names = [] + if __name__ == "__main__": # smiles = "OCc1ccccc1CN" smiles = 
r"[H]/[NH+]=C(\N)C1=CC(=O)/C(=C\C=c2ccc(=C(N)[NH3+])cc2)C=C1" diff --git a/ppmat/datasets/ECDFormerDataset/place_env.py b/ppmat/utils/place_env.py similarity index 55% rename from ppmat/datasets/ECDFormerDataset/place_env.py rename to ppmat/utils/place_env.py index 62acf945..347e3e07 100644 --- a/ppmat/datasets/ECDFormerDataset/place_env.py +++ b/ppmat/utils/place_env.py @@ -1,66 +1,13 @@ -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- import paddle import functools -from contextlib import contextmanager - -@contextmanager -def place_env(place): - """ - 上下文管理器,用于临时设置PaddlePaddle的运行设备 - - Args: - place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 - - 用法: - with place_env(paddle.CPUPlace()): - # 这里的代码在CPU上运行 - x = paddle.rand([2, 3]) - print(x) - - @place_env(paddle.CUDAPlace(0)) - def train(): - # 这个函数在GPU上运行 - pass - """ - # 保存当前的设备设置 - current_device = paddle.get_device() - - # 根据place类型设置设备 - if isinstance(place, paddle.CPUPlace): - paddle.set_device('cpu') - elif isinstance(place, paddle.CUDAPlace): - # 获取GPU设备ID - device_id = place.get_device_id() - paddle.set_device(f'gpu:{device_id}') - else: - raise ValueError(f"不支持的place类型: {type(place)}") - - try: - yield - finally: - # 恢复原来的设备设置 - paddle.set_device(current_device) - +from paddle._typing.device_like import PlaceLike class PlaceEnv: """ 类版本的上下文管理器,也支持装饰器功能 """ - def __init__(self, place): + def __init__(self, place: PlaceLike): """ 初始化PlaceEnv @@ -74,16 +21,7 @@ def __enter__(self): """进入上下文时调用""" # 保存当前的设备设置 self.original_device = paddle.get_device() - - # 根据place类型设置设备 - if isinstance(self.place, paddle.CPUPlace): - paddle.set_device('cpu') - elif isinstance(self.place, paddle.CUDAPlace): - device_id = self.place.get_device_id() - paddle.set_device(f'gpu:{device_id}') - else: - raise ValueError(f"不支持的place类型: {type(self.place)}") - + paddle.set_device(self.place) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -110,15 +48,6 @@ def wrapper(*args, **kwargs): return wrapper -# 为了兼容性,也可以保留函数版本的上下文管理器 -@contextmanager -def with_place_env(place): - """ - with_place_env的别名,与place_env功能相同 - """ - with place_env(place): - yield - # 使用示例 if __name__ == "__main__": @@ -126,13 +55,6 @@ def with_place_env(place): print("=== 测试with语句 ===") print(f"当前设备: {paddle.get_device()}") - with place_env(paddle.CPUPlace()): - print(f"with块内设备: {paddle.get_device()}") - x = paddle.rand([2, 3]) - print(f"创建的张量: {x}") - - print(f"with块外设备: 
{paddle.get_device()}") - print("\n=== 测试类版本with语句 ===") with PlaceEnv(paddle.CPUPlace()): print(f"with块内设备: {paddle.get_device()}") From 3b162900e43c20bd519bcbc0a8afab637dafa4fd Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Mon, 9 Mar 2026 19:45:06 +0800 Subject: [PATCH 10/16] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E9=9B=86=E5=AF=B9fp64=E7=9A=84=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/datasets/build_ecd.py | 4 ++-- ppmat/datasets/build_ir.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ppmat/datasets/build_ecd.py b/ppmat/datasets/build_ecd.py index 6b67b93a..8b619b7e 100644 --- a/ppmat/datasets/build_ecd.py +++ b/ppmat/datasets/build_ecd.py @@ -290,8 +290,8 @@ def Construct_dataset(dataset, data_index, path): # 转换为Tensor atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') - bond_float_feature = paddle.to_tensor(data['bond_length'].astype(np.float32)) - bond_angle_feature = paddle.to_tensor(data['bond_angle'].astype(np.float32)) + bond_float_feature = paddle.to_tensor(data['bond_length'].astype(paddle.get_default_dtype())) + bond_angle_feature = paddle.to_tensor(data['bond_angle'].astype(paddle.get_default_dtype())) edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') bond_index = paddle.to_tensor(data['BondAngleGraph_edges'].T, dtype='int64') data_index_int = paddle.to_tensor(np.array(data_index[i]), dtype='int64') diff --git a/ppmat/datasets/build_ir.py b/ppmat/datasets/build_ir.py index 8eb1d366..d207c872 100644 --- a/ppmat/datasets/build_ir.py +++ b/ppmat/datasets/build_ir.py @@ -192,8 +192,8 @@ def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') bond_feature = paddle.to_tensor(np.array(bond_feature).T, 
dtype='int64') - bond_float_feature = paddle.to_tensor(data.get('bond_length', np.zeros(data['edges'].shape[0])).astype(np.float32)) - bond_angle_feature = paddle.to_tensor(data.get('bond_angle', np.zeros(data.get('BondAngleGraph_edges', np.zeros((0,2))).shape[0])).astype(np.float32)) + bond_float_feature = paddle.to_tensor(data.get('bond_length', np.zeros(data['edges'].shape[0])).astype(paddle.get_default_dtype())) + bond_angle_feature = paddle.to_tensor(data.get('bond_angle', np.zeros(data.get('BondAngleGraph_edges', np.zeros((0,2))).shape[0])).astype(paddle.get_default_dtype())) edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') bond_index = paddle.to_tensor(data.get('BondAngleGraph_edges', np.zeros((0,2))).T, dtype='int64') From 5c15e0af1e107c5b6ce74c56dfe53f465ef7382f Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Mon, 9 Mar 2026 22:27:45 +0800 Subject: [PATCH 11/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=94=AF=E6=8C=81fp64?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/models/ecformer/models/base_ecformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppmat/models/ecformer/models/base_ecformer.py b/ppmat/models/ecformer/models/base_ecformer.py index 550bb5af..15c4bf16 100644 --- a/ppmat/models/ecformer/models/base_ecformer.py +++ b/ppmat/models/ecformer/models/base_ecformer.py @@ -176,7 +176,7 @@ def encode_molecule( # 5. 
生成padding mask node_padding_mask = feat_padding_mask(node_index, self.max_node_num) - pooling_padding_mask = paddle.zeros([node_padding_mask.shape[0], 1], dtype='float32') + pooling_padding_mask = paddle.zeros([node_padding_mask.shape[0], 1], dtype=paddle.get_default_dtype()) total_padding_mask = paddle.concat([pooling_padding_mask, node_padding_mask], axis=1) return total_node_feat, total_padding_mask, node_padding_mask From f0e1ee2555ba1b2ae2b9307f5dad82f1be0421f6 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Tue, 10 Mar 2026 21:45:59 +0800 Subject: [PATCH 12/16] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86fp32=E6=BC=8F?= =?UTF-8?q?=E7=BD=91=E4=B9=8B=E9=B1=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppmat/models/ecformer/layers/gin_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppmat/models/ecformer/layers/gin_conv.py b/ppmat/models/ecformer/layers/gin_conv.py index 7a15945a..bad11e6b 100644 --- a/ppmat/models/ecformer/layers/gin_conv.py +++ b/ppmat/models/ecformer/layers/gin_conv.py @@ -31,7 +31,7 @@ def __init__(self, emb_dim): ) self.eps = paddle.create_parameter( shape=[1], - dtype='float32', + dtype=paddle.get_default_dtype(), default_initializer=nn.initializer.Assign(paddle.to_tensor([0.])) ) From 69e554cffd7e3835ff7cd94247e02b4d92c50c5e Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Fri, 13 Mar 2026 15:38:55 +0800 Subject: [PATCH 13/16] fix some import --- ppmat/datasets/build_ecd.py | 2 +- ppmat/datasets/build_ir.py | 2 +- ppmat/datasets/geometric_data_type/batch.py | 2 +- ppmat/models/ecformer/layers/gin_conv.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ppmat/datasets/build_ecd.py b/ppmat/datasets/build_ecd.py index 8b619b7e..1db8442b 100644 --- a/ppmat/datasets/build_ecd.py +++ b/ppmat/datasets/build_ecd.py @@ -23,7 +23,7 @@ import pandas as pd import numpy as np -from paddle_geometric.data import 
Data +from ppmat.datasets.geometric_data_type.data import Data from ppmat.utils import download as download_utils from ppmat.utils import logger diff --git a/ppmat/datasets/build_ir.py b/ppmat/datasets/build_ir.py index d207c872..be3a4fe0 100644 --- a/ppmat/datasets/build_ir.py +++ b/ppmat/datasets/build_ir.py @@ -26,7 +26,7 @@ import numpy as np import pandas as pd from scipy.signal import find_peaks -from paddle_geometric.data import Data +from ppmat.datasets.geometric_data_type.data import Data from ppmat.utils import download as download_utils from ppmat.utils import logger diff --git a/ppmat/datasets/geometric_data_type/batch.py b/ppmat/datasets/geometric_data_type/batch.py index 59c02bab..26a9d311 100644 --- a/ppmat/datasets/geometric_data_type/batch.py +++ b/ppmat/datasets/geometric_data_type/batch.py @@ -56,7 +56,7 @@ def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): Additionally, creates assignment batch vectors for each key in :obj:`follow_batch`. Will exclude any keys given in :obj:`exclude_keys`.""" - keys = list(set(data_list[0].keys()) - set(exclude_keys)) + keys = list(set(data_list[0].keys) - set(exclude_keys)) assert "batch" not in keys and "ptr" not in keys batch = cls() for key in data_list[0].__dict__.keys(): diff --git a/ppmat/models/ecformer/layers/gin_conv.py b/ppmat/models/ecformer/layers/gin_conv.py index bad11e6b..73d0ab7f 100644 --- a/ppmat/models/ecformer/layers/gin_conv.py +++ b/ppmat/models/ecformer/layers/gin_conv.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from paddle_geometric.nn import MessagePassing +from ppmat.models.common.message_passing.message_passing import MessagePassing import paddle import paddle.nn as nn import paddle.nn.functional as F From 1b8e97de8595b65c3f091e63bdab658f84dca3a1 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Fri, 13 Mar 2026 22:45:09 +0800 Subject: [PATCH 14/16] add ECFormer train code --- ppmat/datasets/collate_fn.py | 12 +- ppmat/losses/ecd_loss.py | 6 +- ppmat/losses/ir_loss.py | 8 +- ppmat/metrics/ecd_metric.py | 8 +- ppmat/metrics/ir_metric.py | 3 +- .../configs/ecformer/ecd.yaml | 113 ++++ spectrum_elucidation/configs/ecformer/ir.yaml | 123 +++++ spectrum_elucidation/{ => diffnmr}/sample.py | 0 spectrum_elucidation/{ => diffnmr}/train.py | 0 spectrum_elucidation/ecformer/train.py | 232 ++++++++ spectrum_elucidation/ecformer/trainer.py | 497 ++++++++++++++++++ 11 files changed, 982 insertions(+), 20 deletions(-) create mode 100644 spectrum_elucidation/configs/ecformer/ecd.yaml create mode 100644 spectrum_elucidation/configs/ecformer/ir.yaml rename spectrum_elucidation/{ => diffnmr}/sample.py (100%) rename spectrum_elucidation/{ => diffnmr}/train.py (100%) create mode 100644 spectrum_elucidation/ecformer/train.py create mode 100644 spectrum_elucidation/ecformer/trainer.py diff --git a/ppmat/datasets/collate_fn.py b/ppmat/datasets/collate_fn.py index cb010ee8..28189e46 100644 --- a/ppmat/datasets/collate_fn.py +++ b/ppmat/datasets/collate_fn.py @@ -332,9 +332,9 @@ def __call__(self, batch: List[Any]) -> Any: "query_mask" : query_mask }, \ { - "peak_number_gt" : num_gt , - "peak_position_gt": pos_gt , - "peak_height_gt" : height_gt + "peak_number" : num_gt , + "peak_position": pos_gt , + "peak_height" : height_gt } @@ -374,8 +374,8 @@ def __call__(self, batch: List[Any]) -> Any: "query_mask": query_mask, }, { - "peak_number_gt": num_gt, - "peak_position_gt": pos_gt, - "peak_height_gt": height_gt, + "peak_number": num_gt, + "peak_position": pos_gt, + 
"peak_height": height_gt, } ) \ No newline at end of file diff --git a/ppmat/losses/ecd_loss.py b/ppmat/losses/ecd_loss.py index 05e00dd0..e659417d 100644 --- a/ppmat/losses/ecd_loss.py +++ b/ppmat/losses/ecd_loss.py @@ -57,16 +57,16 @@ def forward(self, predictions, targets): dict: Loss components and total loss """ # Peak number loss - loss_num = self.ce_loss(predictions['peak_number'], targets['peak_num']) + loss_num = self.ce_loss(predictions['peak_number'], targets['peak_number']) - batch_size = targets['peak_num'].shape[0] + batch_size = targets['peak_number'].shape[0] loss_pos_total = 0.0 loss_height_total = 0.0 valid_samples = 0 for i in range(batch_size): - n_peaks = int(targets['peak_num'][i]) + n_peaks = int(targets['peak_number'][i]) if n_peaks == 0: continue diff --git a/ppmat/losses/ir_loss.py b/ppmat/losses/ir_loss.py index ea31114a..97cffe25 100644 --- a/ppmat/losses/ir_loss.py +++ b/ppmat/losses/ir_loss.py @@ -48,7 +48,7 @@ def forward(self, predictions, targets): - peak_position (Tensor): [batch_size, max_peaks, num_position_classes] logits for positions - peak_height (Tensor, optional): [batch_size, max_peaks] predicted intensity values targets (dict): Ground truth containing: - - peak_num (Tensor): [batch_size] true peak counts + - peak_number (Tensor): [batch_size] true peak counts - peak_position (Tensor): [batch_size, max_peaks] true position labels - peak_height (Tensor): [batch_size, max_peaks] true intensity values @@ -56,16 +56,16 @@ def forward(self, predictions, targets): dict: Loss components and total loss """ # Peak number loss - loss_num = self.ce_loss(predictions['peak_number'], targets['peak_num']) + loss_num = self.ce_loss(predictions['peak_number'], targets['peak_number']) - batch_size = targets['peak_num'].shape[0] + batch_size = targets['peak_number'].shape[0] loss_pos_total = 0.0 loss_height_total = 0.0 valid_samples = 0 for i in range(batch_size): - n_peaks = int(targets['peak_num'][i]) + n_peaks = 
int(targets['peak_number'][i]) if n_peaks == 0: continue diff --git a/ppmat/metrics/ecd_metric.py b/ppmat/metrics/ecd_metric.py index 8611e134..54b74041 100644 --- a/ppmat/metrics/ecd_metric.py +++ b/ppmat/metrics/ecd_metric.py @@ -1,22 +1,20 @@ # Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. -# Licensed under the Apache License, Version 2.0 (the License); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import paddle import paddle.nn as nn -import numpy as np -from sklearn.metrics import mean_squared_error -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict # ========================= diff --git a/ppmat/metrics/ir_metric.py b/ppmat/metrics/ir_metric.py index 8754ad85..5a7c3dcc 100644 --- a/ppmat/metrics/ir_metric.py +++ b/ppmat/metrics/ir_metric.py @@ -14,8 +14,7 @@ import paddle import paddle.nn as nn -import numpy as np -from typing import Dict, Optional, Any +from typing import Dict # ========================= diff --git a/spectrum_elucidation/configs/ecformer/ecd.yaml b/spectrum_elucidation/configs/ecformer/ecd.yaml new file mode 100644 index 00000000..53c9d583 --- /dev/null +++ b/spectrum_elucidation/configs/ecformer/ecd.yaml @@ -0,0 +1,113 @@ +# ECFormer ECD Task Configuration +# Task: Electronic Circular Dichroism Spectrum Prediction + +Global: + do_train: True + do_eval: True + do_test: True + label_names: ["peak_number", "peak_position", "peak_height"] + +Dataset: + 
train: + dataset: + __class_name__: ECDDataset + __init_params__: + data_path: ./datasets/ECD + data_count: null # null means use all data + use_geometry_enhanced: True + use_column_info: False + loader: + num_workers: 4 + collate_fn: ECDCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 128 + shuffle: True + + val: + dataset: + __class_name__: ECDDataset + __init_params__: + data_path: ./datasets/ECD + data_count: null + use_geometry_enhanced: True + use_column_info: False + loader: + num_workers: 4 + collate_fn: ECDCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 128 + shuffle: False + + test: + dataset: + __class_name__: ECDDataset + __init_params__: + data_path: ./datasets/ECD + data_count: null + use_geometry_enhanced: True + use_column_info: False + loader: + num_workers: 4 + collate_fn: ECDCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 128 + shuffle: False + +Model: + __class_name__: ECFormerECD + __init_params__: + full_atom_feature_dims: [119, 9, 12, 14, 17, 9, 14, 2, 10] + full_bond_feature_dims: [8, 23, 3] + bond_float_names: ['bond_length'] + bond_angle_float_names: ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] + bond_id_names: ['bond_dir', 'bond_type', 'is_in_ring'] + num_layers: 5 + emb_dim: 256 + drop_ratio: 0.0 + graph_pooling: 'sum' + use_geometry_enhanced: True + max_node_num: 63 + num_heads: 4 + tf_layers: 2 + tf_dropout: 0.1 + max_peaks: 9 + num_position_classes: 20 + height_classes: 2 + +Optimizer: + __class_name__: AdamW + __init_params__: + lr: 0.001 + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.01 + +Metric: + __class_name__: ECDMetrics + __init_params__: + num_position_classes: 20 + max_peaks: 9 + +Trainer: + output_dir: ./output/ecformer_ecd + max_epochs: 100 + log_freq: 10 + save_freq: 5 + eval_freq: 1 + seed: 42 + use_amp: False + start_eval_epoch: 1 + amp_level: 'O1' + eval_with_no_grad: True + 
compute_metric_during_train: False + metric_strategy_during_eval: 'step' + use_visualdl: False + use_wandb: False + use_tensorboard: False \ No newline at end of file diff --git a/spectrum_elucidation/configs/ecformer/ir.yaml b/spectrum_elucidation/configs/ecformer/ir.yaml new file mode 100644 index 00000000..383b5787 --- /dev/null +++ b/spectrum_elucidation/configs/ecformer/ir.yaml @@ -0,0 +1,123 @@ +# ECFormer IR Task Configuration +# Task: Infrared Spectrum Prediction + +Global: + do_train: True + do_eval: True + do_test: True + label_names: ["peak_number", "peak_position", "peak_height"] + +Dataset: + train: + dataset: + __class_name__: IRDataset + __init_params__: + data_path: ./datasets/IR + data_count: null + use_geometry_enhanced: True + loader: + num_workers: 4 + collate_fn: IRCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 32 + shuffle: True + + val: + dataset: + __class_name__: IRDataset + __init_params__: + data_path: ./datasets/IR + data_count: null + use_geometry_enhanced: True + loader: + num_workers: 4 + collate_fn: IRCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 32 + shuffle: False + + test: + dataset: + __class_name__: IRDataset + __init_params__: + data_path: ./datasets/IR + data_count: null + use_geometry_enhanced: True + loader: + num_workers: 4 + collate_fn: IRCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 32 + shuffle: False + +Model: + __class_name__: ECFormerIR + __init_params__: + # Atom feature dimensions (consistent with ECD) + full_atom_feature_dims: [119, 9, 12, 14, 17, 9, 14, 2, 10] + # Bond feature dimensions (consistent with ECD) + full_bond_feature_dims: [8, 23, 3] + bond_float_names: ['bond_length'] + bond_angle_float_names: ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] + bond_id_names: ['bond_dir', 'bond_type', 'is_in_ring'] + + # GNN parameters + num_layers: 5 + emb_dim: 128 # IR task can use smaller embedding 
dimension + drop_ratio: 0.0 + graph_pooling: 'sum' + use_geometry_enhanced: True + max_node_num: 63 # Maximum number of atoms (consistent with ECD) + + # Transformer parameters + num_heads: 4 + tf_layers: 2 + tf_dropout: 0.1 + + # IR-specific parameters + max_peaks: 15 # IR spectra have up to 15 peaks + num_position_classes: 36 # IR position classes (wavenumber range) + use_height_prediction: True # IR uses intensity prediction (regression) + +Optimizer: + __class_name__: AdamW + __init_params__: + lr: 0.001 + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.01 + +Metric: + __class_name__: IRMetrics + __init_params__: + use_height_prediction: True # Consistent with model + +Trainer: + output_dir: ./output/ecformer_ir + max_epochs: 100 + log_freq: 10 + save_freq: 5 + eval_freq: 1 + seed: 42 + use_amp: False + start_eval_epoch: 1 + amp_level: 'O1' + eval_with_no_grad: True + compute_metric_during_train: False + metric_strategy_during_eval: 'epoch' # 'epoch' recommended for streaming metrics + use_visualdl: False + use_wandb: False + use_tensorboard: False + + # The following parameters can be passed to trainer (override model settings if needed) + loss_weight_height: 1.0 # Optional, override in trainer layer + max_peaks: 15 # Optional + num_position_classes: 36 # Optional + height_classes: 1 # IR intensity is regression task, output dimension is 1 \ No newline at end of file diff --git a/spectrum_elucidation/sample.py b/spectrum_elucidation/diffnmr/sample.py similarity index 100% rename from spectrum_elucidation/sample.py rename to spectrum_elucidation/diffnmr/sample.py diff --git a/spectrum_elucidation/train.py b/spectrum_elucidation/diffnmr/train.py similarity index 100% rename from spectrum_elucidation/train.py rename to spectrum_elucidation/diffnmr/train.py diff --git a/spectrum_elucidation/ecformer/train.py b/spectrum_elucidation/ecformer/train.py new file mode 100644 index 00000000..56b1db88 --- /dev/null +++ b/spectrum_elucidation/ecformer/train.py 
@@ -0,0 +1,232 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import os.path as osp + +import paddle.distributed as dist +from omegaconf import OmegaConf + +from ppmat.datasets import build_dataloader +from ppmat.datasets import set_signal_handlers +from ppmat.models import build_model +from ppmat.optimizer import build_optimizer +from ppmat.utils import logger +from ppmat.utils import misc + +from spectrum_elucidation.ecformer.trainer import ECDFormerTrainer + + +def main(): + # 解析参数 + parser = argparse.ArgumentParser(description="ECDFormer for ECD Spectrum Prediction") + parser.add_argument( + "-c", "--config", + type=str, + default="./spectrum_elucidation/ecformer/configs/ecd.yaml", + help="Path to config file", + ) + parser.add_argument( + "--resume", + type=str, + default=None, + help="Resume from checkpoint path", + ) + parser.add_argument( + "--eval-only", + action="store_true", + help="Only run evaluation on validation set", + ) + parser.add_argument( + "--test-only", + action="store_true", + help="Only run evaluation on test set", + ) + parser.add_argument( + "--predict", + type=str, + default=None, + help="Path to data for prediction (inference mode)", + ) + + args, dynamic_args = parser.parse_known_args() + + # 加载配置 + config = OmegaConf.load(args.config) + cli_config = OmegaConf.from_dotlist(dynamic_args) + config = OmegaConf.merge(config, cli_config) + + # 
根据命令行参数覆盖Global配置 + if args.eval_only: + config.Global.do_train = False + config.Global.do_eval = True + config.Global.do_test = False + elif args.test_only: + config.Global.do_train = False + config.Global.do_eval = False + config.Global.do_test = True + elif args.predict is not None: + config.Global.do_train = False + config.Global.do_eval = False + config.Global.do_test = False + config.Global.do_predict = True + config.Dataset.predict.data_path = args.predict + + # 保存配置 + if dist.get_rank() == 0: + os.makedirs(config.Trainer.output_dir, exist_ok=True) + config_name = os.path.basename(args.config) + OmegaConf.save(config, osp.join(config.Trainer.output_dir, config_name)) + + # 转换为字典 + config = OmegaConf.to_container(config, resolve=True) + + # 初始化日志 + logger_path = osp.join(config["Trainer"]["output_dir"], "run.log") + logger.init_logger(log_file=logger_path) + logger.info(f"Logger saved to {logger_path}") + logger.info(f"Config: {config}") + + # 设置随机种子 + seed = config["Trainer"].get("seed", 42) + misc.set_random_seed(seed) + logger.info(f"Set random seed to {seed}") + + # 设置信号处理 + set_signal_handlers() + + # 构建数据加载器 + dataloaders = {} + + if config["Global"].get("do_train", True): + train_cfg = config["Dataset"].get("train") + assert train_cfg is not None, "train dataset must be defined when do_train is True" + dataloaders["train"] = build_dataloader(train_cfg) + logger.info(f"Train dataset loaded, size: {len(dataloaders['train'].dataset)}") + + if config["Global"].get("do_eval", False) or config["Global"].get("do_train", True): + val_cfg = config["Dataset"].get("val") + if val_cfg is not None: + dataloaders["val"] = build_dataloader(val_cfg) + logger.info(f"Validation dataset loaded, size: {len(dataloaders['val'].dataset)}") + else: + logger.info("No validation dataset defined.") + + if config["Global"].get("do_test", False): + test_cfg = config["Dataset"].get("test") + assert test_cfg is not None, "test dataset must be defined when do_test is True" + 
dataloaders["test"] = build_dataloader(test_cfg) + logger.info(f"Test dataset loaded, size: {len(dataloaders['test'].dataset)}") + + if config["Global"].get("do_predict", False): + predict_cfg = config["Dataset"].get("predict") + assert predict_cfg is not None, "predict dataset must be defined when do_predict is True" + dataloaders["predict"] = build_dataloader(predict_cfg) + logger.info(f"Prediction dataset loaded, size: {len(dataloaders['predict'].dataset)}") + + # 构建模型 + model_cfg = config["Model"] + model = build_model(model_cfg) + logger.info(f"Model built: {model_cfg['__class_name__']}") + + # 打印模型参数量 + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if not p.stop_gradient) + logger.info(f"Total parameters: {total_params / 1e6:.2f}M") + logger.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M") + + # 构建优化器和学习率调度器 + optimizer = None + lr_scheduler = None + + if config.get("Optimizer") is not None and config["Global"].get("do_train", True): + assert dataloaders.get("train") is not None, "train_loader must be defined when optimizer is defined" + assert config["Trainer"].get("max_epochs") is not None, "max_epochs must be defined" + + optimizer, lr_scheduler = build_optimizer( + config["Optimizer"], + model, + config["Trainer"]["max_epochs"], + len(dataloaders["train"]), + ) + logger.info(f"Optimizer built: {config['Optimizer']['__class_name__']}") + + # 构建训练器 + trainer = ECDFormerTrainer( + config=config["Trainer"], + model=model, + train_dataloader=dataloaders.get("train"), + val_dataloader=dataloaders.get("val"), + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) + + # 恢复检查点 + if args.resume is not None: + logger.info(f"Resuming from checkpoint: {args.resume}") + save_load.load_checkpoint( + args.resume, + model, + optimizer, + trainer.scaler, + ) + + # 执行训练/评估/预测 + if config["Global"].get("do_train", True): + logger.info("Starting training...") + trainer.train() + + if 
config["Global"].get("do_eval", False): + logger.info("Evaluating on validation set...") + if "val" in dataloaders: + time_info, loss_info, metric_info = trainer.eval(dataloaders["val"]) + + # 打印详细指标 + msg = "Validation Results:" + for key, meter in metric_info.items(): + msg += f" | {key}: {meter.avg:.6f}" + logger.info(msg) + else: + logger.warning("No validation dataloader found, skipping evaluation.") + + if config["Global"].get("do_test", False): + logger.info("Evaluating on test set...") + if "test" in dataloaders: + time_info, loss_info, metric_info = trainer.eval(dataloaders["test"]) + + msg = "Test Results:" + for key, meter in metric_info.items(): + msg += f" | {key}: {meter.avg:.6f}" + logger.info(msg) + else: + logger.warning("No test dataloader found, skipping test evaluation.") + + if config["Global"].get("do_predict", False): + logger.info("Running prediction...") + if "predict" in dataloaders: + results = trainer.predict(dataloaders["predict"]) + + # 保存预测结果 + import json + output_path = osp.join(config["Trainer"]["output_dir"], "predictions.json") + with open(output_path, "w") as f: + json.dump(results, f, indent=2) + logger.info(f"Predictions saved to {output_path}") + else: + logger.warning("No prediction dataloader found, skipping prediction.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/spectrum_elucidation/ecformer/trainer.py b/spectrum_elucidation/ecformer/trainer.py new file mode 100644 index 00000000..4291375d --- /dev/null +++ b/spectrum_elucidation/ecformer/trainer.py @@ -0,0 +1,497 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from collections import OrderedDict +from typing import Dict, Optional, List, Any, Union + +import numpy as np +import paddle +from paddle import nn +from paddle import optimizer as optim +from paddle.distributed import fleet + +from ppmat.trainer.base_trainer import BaseTrainer +from ppmat.utils import logger +from ppmat.utils import AverageMeter +from ppmat.utils import save_load +from ppmat.metrics.ecd_metric import ECDMetrics +from ppmat.metrics.ir_metric import IRMetrics +from ppmat.losses.ecd_loss import ECDLoss +from ppmat.losses.ir_loss import IRLoss + + +class ECDFormerTrainer(BaseTrainer): + """ + ECDFormer trainer supporting both ECD and IR tasks with dedicated metrics. 
+ + Features: + - Automatic task detection from model class name + - Task-specific loss functions (ECDLoss for classification, IRLoss for regression) + - Task-specific streaming metrics (ECDMetrics, IRMetrics) + - Attention visualization during inference + - Compatible with BaseTrainer training loop + """ + + def __init__( + self, + config: Dict, + model: nn.Layer, + train_dataloader: Optional[paddle.io.DataLoader] = None, + val_dataloader: Optional[paddle.io.DataLoader] = None, + optimizer: Optional[optim.Optimizer] = None, + lr_scheduler: Optional[optim.lr.LRScheduler] = None, + compute_metric_func_dict: Optional[Dict] = None, + ): + # Initialize parent class + super().__init__( + config=config, + model=model, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + compute_metric_func_dict=compute_metric_func_dict, + ) + + # Task detection from model class name + model_class_name = model.__class__.__name__ + self.is_ir_task = "IR" in model_class_name + self.is_ecd_task = "ECD" in model_class_name + + logger.info(f"Task type detected: {'IR' if self.is_ir_task else 'ECD' if self.is_ecd_task else 'Unknown'}") + + # Get task-specific parameters from config + self.max_peaks = config.get("max_peaks", 15 if self.is_ir_task else 9) + self.num_position_classes = config.get("num_position_classes", 36 if self.is_ir_task else 20) + + # Initialize task-specific loss function + if self.is_ecd_task: + self.loss_fn = ECDLoss( + loss_weight_height=config.get("loss_weight_height", 2.0), + num_position_classes=self.num_position_classes, + height_classes=config.get("height_classes", 2) + ) + logger.info("Using ECDLoss for ECD task") + elif self.is_ir_task: + self.loss_fn = IRLoss( + num_position_classes=self.num_position_classes, + use_height_prediction=config.get("use_height_prediction", True) + ) + logger.info("Using IRLoss for IR task") + else: + # Fallback to simple cross-entropy + self.ce_loss = 
nn.CrossEntropyLoss() + logger.warning("Unknown task type, using fallback CrossEntropyLoss") + + # Initialize task-specific metrics (will be attached via attach_metrics) + self.train_metrics = None + self.eval_metrics = None + + def attach_metrics(self, metric_cfg=None, **runtime_objs): + """ + Attach task-specific metrics to the trainer. + + Args: + metric_cfg: Metric configuration from config file + **runtime_objs: Additional runtime objects + """ + super().attach_metrics(metric_cfg, **runtime_objs) + + # Create task-specific metric instances if not already in metric_modules + if self.is_ecd_task and 'ECDMetrics' not in str(self.metric_modules): + self.metric_modules['ecd_metrics'] = ECDMetrics( + num_position_classes=self.num_position_classes, + max_peaks=self.max_peaks + ) + logger.info("ECDMetrics attached") + elif self.is_ir_task and 'IRMetrics' not in str(self.metric_modules): + self.metric_modules['ir_metrics'] = IRMetrics( + use_height_prediction=config.get("use_height_prediction", True) + ) + logger.info("IRMetrics attached") + + def train_epoch(self, dataloader: paddle.io.DataLoader): + """ + Train for one epoch using task-specific loss functions. 
+ + Args: + dataloader: Training data loader + + Returns: + tuple: time_info, loss_info, metric_info + """ + self.model.train() + + # Initialize statistics + loss_info = {} + metric_info = {} + time_info = { + "reader_cost": AverageMeter(name="reader_cost", postfix="s"), + "batch_cost": AverageMeter(name="batch_cost", postfix="s"), + } + + # Update training state + self.state.max_steps_in_train_epoch = len(dataloader) + self.state.step_in_train_epoch = 0 + + # Timers + reader_tic = time.perf_counter() + batch_tic = time.perf_counter() + + for iter_id, batch in enumerate(dataloader): + # Parse batch data (adapting to ECDCollator/IRCollator format) + model_inputs, targets = batch + + reader_cost = time.perf_counter() - reader_tic + time_info["reader_cost"].update(reader_cost) + + # Calculate batch size + batch_size = model_inputs['x'].shape[0] if hasattr(model_inputs['x'], 'shape') else 1 + + # Forward pass + with self.autocast_context_manager(self.use_amp, self.amp_level): + predictions = self.model( + x=model_inputs['x'], + edge_index=model_inputs['edge_index'], + edge_attr=model_inputs['edge_attr'], + batch_data=model_inputs['batch_data'], + ba_edge_index=model_inputs.get('ba_edge_index', None), + ba_edge_attr=model_inputs.get('ba_edge_attr', None), + query_mask=model_inputs.get('query_mask', None) + ) + + # Compute loss using task-specific loss function + loss_dict = self.loss_fn(predictions, targets) + loss = loss_dict["loss"] + + # Backward pass + if self.use_amp: + loss_scaled = self.scaler.scale(loss) + loss_scaled.backward() + else: + loss.backward() + + # Update parameters + if self.use_amp: + self.scaler.minimize(self.optimizer, loss_scaled) + else: + self.optimizer.step() + self.optimizer.clear_grad() + + # Update loss statistics + for key, value in loss_dict.items(): + if key not in loss_info: + loss_info[key] = AverageMeter(key) + loss_info[key].update(float(value), batch_size) + + # Update streaming metrics + 
self._update_streaming_metrics(result={'predictions': predictions, 'loss_dict': loss_dict}, + batch=targets, stage='train') + + batch_cost = time.perf_counter() - batch_tic + time_info["batch_cost"].update(batch_cost) + + # Update state + self.state.step_in_train_epoch += 1 + self.state.global_step += 1 + + # Update learning rate (step-based) + if self.lr_scheduler is not None and not self.lr_scheduler.by_epoch: + self.lr_scheduler.step() + + # Logging + if (self.state.step_in_train_epoch % self.log_freq == 0 or + self.state.step_in_train_epoch == self.state.max_steps_in_train_epoch): + + logs = OrderedDict() + logs["lr"] = self.optimizer.get_lr() + for name, meter in time_info.items(): + logs[name] = meter.val + for name, meter in loss_info.items(): + logs[name] = meter.val + + # Add streaming metrics if available + stream_metrics = self._compute_streaming_metrics(stage='train') + for name, value in stream_metrics.items(): + if isinstance(value, (int, float)): + logs[f"{name}"] = value + if name not in metric_info: + metric_info[name] = AverageMeter(name) + metric_info[name].update(float(value), 1) + + display_logs = self._filter_out_dict(logs, stage="train") + + msg = f"Train: Epoch [{self.state.epoch}/{self.max_epochs}]" + msg += f" | Step: [{self.state.step_in_train_epoch}/{self.state.max_steps_in_train_epoch}]" + for key, val in display_logs.items(): + msg += f" | {key}: {val:.6f}" + logger.info(msg) + + # Write to visualization tools + logger.scalar( + tag="train(step)", + metric_dict=logs, + step=self.state.global_step, + visualdl_writer=self.visualdl_writer, + wandb_writer=self.wandb_writer, + tensorboard_writer=self.tensorboard_writer, + ) + + batch_tic = time.perf_counter() + reader_tic = time.perf_counter() + + # Compute epoch-level streaming metrics + epoch_stream_metrics = self._compute_streaming_metrics(stage='train') + for name, value in epoch_stream_metrics.items(): + if isinstance(value, (int, float)): + if name not in metric_info: + 
metric_info[name] = AverageMeter(name) + metric_info[name].update(float(value), 1) + + return time_info, loss_info, metric_info + + def eval_epoch(self, dataloader: paddle.io.DataLoader): + """ + Evaluate for one epoch using task-specific metrics. + + Args: + dataloader: Validation data loader + + Returns: + tuple: time_info, loss_info, metric_info + """ + self.model.eval() + + loss_info = {} + metric_info = {} + time_info = { + "reader_cost": AverageMeter(name="reader_cost", postfix="s"), + "batch_cost": AverageMeter(name="batch_cost", postfix="s"), + } + + self.state.max_steps_in_eval_epoch = len(dataloader) + self.state.step_in_eval_epoch = 0 + + # Reset streaming metrics for evaluation + for _, m in self.metric_modules.items(): + if hasattr(m, 'reset'): + m.reset() + + reader_tic = time.perf_counter() + batch_tic = time.perf_counter() + + with paddle.no_grad(): + for iter_id, batch in enumerate(dataloader): + model_inputs, targets = batch + + reader_cost = time.perf_counter() - reader_tic + time_info["reader_cost"].update(reader_cost) + + batch_size = model_inputs['x'].shape[0] if hasattr(model_inputs['x'], 'shape') else 1 + + # Forward pass + with self.autocast_context_manager(self.use_amp, self.amp_level): + predictions = self.model( + x=model_inputs['x'], + edge_index=model_inputs['edge_index'], + edge_attr=model_inputs['edge_attr'], + batch_data=model_inputs['batch_data'], + ba_edge_index=model_inputs.get('ba_edge_index', None), + ba_edge_attr=model_inputs.get('ba_edge_attr', None), + query_mask=model_inputs.get('query_mask', None) + ) + + # Compute loss + loss_dict = self.loss_fn(predictions, targets) + + # Update loss statistics + for key, value in loss_dict.items(): + if key not in loss_info: + loss_info[key] = AverageMeter(key) + loss_info[key].update(float(value), batch_size) + + # Update streaming metrics + self._update_streaming_metrics(result={'predictions': predictions, 'loss_dict': loss_dict}, + batch=targets, stage='eval') + + # Step-wise metric 
computation (if configured) + if self.metric_strategy_during_eval == "step": + step_metrics = self._compute_streaming_metrics(stage='eval') + for name, value in step_metrics.items(): + if isinstance(value, (int, float)): + if name not in metric_info: + metric_info[name] = AverageMeter(name) + metric_info[name].update(float(value), batch_size) + + batch_cost = time.perf_counter() - batch_tic + time_info["batch_cost"].update(batch_cost) + + self.state.step_in_eval_epoch += 1 + + # Logging + if (self.state.step_in_eval_epoch % self.log_freq == 0 or + self.state.step_in_eval_epoch == self.state.max_steps_in_eval_epoch): + + logs = OrderedDict() + for name, meter in time_info.items(): + logs[name] = meter.val + for name, meter in loss_info.items(): + logs[name] = meter.val + + display_logs = self._filter_out_dict(logs, stage="eval") + + msg = f"Eval: Epoch [{self.state.epoch}/{self.max_epochs}]" + msg += f" | Step: [{self.state.step_in_eval_epoch}/{self.state.max_steps_in_eval_epoch}]" + for key, val in display_logs.items(): + msg += f" | {key}: {val:.6f}" + logger.info(msg) + + batch_tic = time.perf_counter() + reader_tic = time.perf_counter() + + # Compute epoch-level metrics from streaming accumulators + epoch_metrics = self._compute_streaming_metrics(stage='eval') + for name, value in epoch_metrics.items(): + if isinstance(value, (int, float)): + if name not in metric_info: + metric_info[name] = AverageMeter(name) + metric_info[name].update(float(value), len(dataloader.dataset)) + + return time_info, loss_info, metric_info + + def predict(self, dataloader: paddle.io.DataLoader) -> Dict[str, Any]: + """ + Run inference and return predictions with attention visualization. 
+ + Args: + dataloader: Data loader for prediction + + Returns: + dict: Predictions including peak positions, heights, and attention weights + """ + self.model.eval() + + all_pos_pred = [] + all_height_pred = [] + all_attn_weights = [] + all_peak_nums = [] + + with paddle.no_grad(): + for batch in dataloader: + model_inputs, _ = batch # No targets needed for inference + + predictions = self.model( + x=model_inputs['x'], + edge_index=model_inputs['edge_index'], + edge_attr=model_inputs['edge_attr'], + batch_data=model_inputs['batch_data'], + ba_edge_index=model_inputs.get('ba_edge_index', None), + ba_edge_attr=model_inputs.get('ba_edge_attr', None), + query_mask=model_inputs.get('query_mask', None) + ) + + # Get predicted peak numbers + prob_num = paddle.nn.functional.softmax(predictions['peak_number'], axis=1) + pred_peak_num = paddle.argmax(prob_num, axis=1) + all_peak_nums.extend(pred_peak_num.cpu().numpy().tolist()) + + for i in range(pred_peak_num.shape[0]): + n_pred = int(pred_peak_num[i]) + + # Position predictions + pos_pred = paddle.argmax( + predictions['peak_position'][i, :n_pred, :], axis=1 + ).cpu().numpy().tolist() + + # Height predictions (classification or regression) + if 'peak_height' in predictions: + if len(predictions['peak_height'].shape) == 3: # Classification (ECD) + height_pred = paddle.argmax( + predictions['peak_height'][i, :n_pred, :], axis=1 + ).cpu().numpy().tolist() + else: # Regression (IR) + height_pred = predictions['peak_height'][i, :n_pred].reshape([-1]).cpu().numpy().tolist() + else: + height_pred = [] + + all_pos_pred.append(pos_pred) + all_height_pred.append(height_pred) + + # Attention weights for visualization + if predictions.get('attention', {}).get('weights'): + all_attn_weights.append({ + 'weights': predictions['attention']['weights'][i], + 'mask': predictions['attention']['mask'][i] if predictions['attention']['mask'] else None + }) + + return { + 'peak_number': all_peak_nums, + 'peak_position': all_pos_pred, + 
'peak_height': all_height_pred, + 'attention': all_attn_weights if all_attn_weights else None + } + + def _update_streaming_metrics(self, *, result, batch, stage: str): + """ + Update streaming metrics with predictions and targets. + + Args: + result: dict containing 'predictions' from model + batch: target batch + stage: 'train' or 'eval' + """ + predictions = result.get('predictions', {}) + + for name, metric in self.metric_modules.items(): + if hasattr(metric, 'update'): + try: + metric.update(predictions, batch) + except Exception as e: + logger.debug(f"Error updating metric {name}: {e}") + + def _compute_streaming_metrics(self, *, stage: str) -> Dict[str, float]: + """ + Compute and reset streaming metrics. + + Args: + stage: 'train' or 'eval' + + Returns: + dict: Computed metrics + """ + all_metrics = {} + + for name, metric in self.metric_modules.items(): + if hasattr(metric, 'accumulate'): + try: + metrics = metric.accumulate() + if isinstance(metrics, dict): + # Add prefix for clarity + for k, v in metrics.items(): + all_metrics[f"{name}/{k}"] = v + else: + all_metrics[name] = metrics + except Exception as e: + logger.debug(f"Error computing metric {name}: {e}") + + if hasattr(metric, 'reset'): + try: + metric.reset() + except Exception: + pass + + return all_metrics \ No newline at end of file From 5ea9d431bfe9fc6d6460497a5a53e59dbeae3175 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Sat, 14 Mar 2026 19:09:21 +0800 Subject: [PATCH 15/16] translate Chinese to English --- ppmat/datasets/build_ecd.py | 82 +++++++++---------- ppmat/datasets/build_ir.py | 52 ++++++------ ppmat/datasets/collate_fn.py | 67 ++++++++------- ppmat/datasets/ecd_dataset.py | 34 ++++---- ppmat/datasets/ir_dataset.py | 40 ++++----- ppmat/models/ecformer/__init__.py | 4 +- .../ecformer/encoders/gin_node_embedding.py | 36 ++++---- ppmat/models/ecformer/layers/atom_encoder.py | 2 +- ppmat/models/ecformer/layers/bond_encoder.py | 2 +- 
ppmat/models/ecformer/layers/gin_conv.py | 2 +- ppmat/models/ecformer/layers/rbf.py | 6 +- ppmat/models/ecformer/models/ECD.py | 8 +- ppmat/models/ecformer/models/IR.py | 10 +-- ppmat/models/ecformer/models/base_ecformer.py | 68 +++++++-------- ppmat/utils/colored_tqdm.py | 32 ++++---- ppmat/utils/graph_utils.py | 8 +- ppmat/utils/place_env.py | 80 +++++++++--------- spectrum_elucidation/ecformer/train.py | 34 ++++---- spectrum_elucidation/ecformer/trainer.py | 2 +- 19 files changed, 286 insertions(+), 283 deletions(-) diff --git a/ppmat/datasets/build_ecd.py b/ppmat/datasets/build_ecd.py index 1db8442b..f7b44b64 100644 --- a/ppmat/datasets/build_ecd.py +++ b/ppmat/datasets/build_ecd.py @@ -44,7 +44,7 @@ def _parse_factory_cfg( *, default_class_name: str, ) -> Tuple[str, Dict[str, Any]]: - """解析工厂配置,兼容多种格式""" + """Parse factory configuration, compatible with multiple formats""" if cfg is None: return default_class_name, {} @@ -73,7 +73,7 @@ def _parse_factory_cfg( class StrictIndexSampleBuilder: - """按严格索引构建样本(适用于 ECD 数据集)""" + """Build samples by strict index (for ECD dataset)""" def build(self, data_dir: Path, index_file: str, sample_path: str, data_count: Optional[int] = None): import pandas as pd samples = [] @@ -89,7 +89,7 @@ def build(self, data_dir: Path, index_file: str, sample_path: str, data_count: O class DefaultECDDatasetDownloader: - """ECD 数据集下载器""" + """ECD dataset downloader""" def __init__(self, datasets_home: Optional[str] = None): self.datasets_home = datasets_home or download_utils.DATASETS_HOME @@ -103,7 +103,7 @@ def download(self, url: str, md5: Optional[str] = None, force_download: bool = F return Path(downloaded_root) def build_ecformer_downloader(cfg: Optional[Dict[str, Any] | str]): - """构建下载器""" + """Build downloader""" class_name, init_params = _parse_factory_cfg(cfg, default_class_name="DefaultECDDatasetDownloader") cls = _locate_class(class_name) downloader = cls(**init_params) @@ -113,7 +113,7 @@ def build_ecformer_downloader(cfg: 
Optional[Dict[str, Any] | str]): return downloader def get_key_padding_mask(tokens): - """生成query padding mask""" + """Generate query padding mask""" key_padding_mask = paddle.zeros(tokens.shape) key_padding_mask[tokens == -1] = -paddle.inf return key_padding_mask @@ -151,8 +151,8 @@ def get_sequence_peak(sequence): def read_total_ecd(sample_path, fix_length=20): """ - 读取所有ECD光谱文件,提取峰值信息 - 完全复用原型程序的read_total_ecd逻辑 + Read all ECD spectrum files, extract peak information + Completely reuse the read_total_ecd logic from the prototype program """ filepaths = [ os.path.join(sample_path, "500ECD/data/"), @@ -178,10 +178,10 @@ def read_total_ecd(sample_path, fix_length=20): wavelengths_o, mdegs_o = ECD_info['Wavelength (nm)'], ECD_info['ECD (Mdeg)'] wavelengths = [int(i) for i in wavelengths_o] - # 将小值置零 + # Set small values to zero mdegs = [int(i) if abs(i) > 1 else 0 for i in mdegs_o] - # 去除前后零值 + # Remove leading and trailing zeros begin, end = 0, 0 for i in range(len(mdegs)): if mdegs[i] != 0: @@ -201,23 +201,23 @@ def read_total_ecd(sample_path, fix_length=20): 'ecd': mdegs, } - # 处理光谱序列,提取峰值 + # Process spectrum sequences, extract peaks ecd_final_list = [] for key, itm in ecd_dict.items(): - # 等间隔采样 + # Uniform sampling distance = int(len(itm['ecd']) / (fix_length - 1)) sequence_org = [itm['ecd'][i] for i in range(0, len(itm['ecd']), distance)][:fix_length] - # 归一化 + # Normalize sequence = normalize_func(sequence_org, norm_range=[-100, 100]) - # padding到固定长度 + # Pad to fixed length if len(sequence) < fix_length: sequence.extend([0] * (fix_length - len(sequence))) sequence_org.extend([0] * (fix_length - len(sequence_org))) assert len(sequence) == fix_length - # 生成峰值掩码 + # Generate peak mask peak_mask = [0] * len(sequence) for i in range(1, len(sequence) - 1): if sequence[i - 1] < sequence[i] and sequence[i] > sequence[i + 1]: @@ -233,17 +233,17 @@ def read_total_ecd(sample_path, fix_length=20): if peak_mask[i + 1] != 2: peak_mask[i + 1] = 1 - # 提取峰值位置 + # Extract 
peak positions peak_position_list = get_sequence_peak(sequence) peak_number = len(peak_position_list) assert peak_number < 9, f"Peak number {peak_number} >= 9" - # 峰值符号 + # Peak signs peak_height_list = [] for i in peak_position_list: peak_height_list.append(1 if sequence[i] >= 0 else 0) - # padding到9个峰 + # Pad to 9 peaks peak_position_list = peak_position_list + [-1] * (9 - peak_number) peak_height_list = peak_height_list + [-1] * (9 - peak_number) query_padding_mask = get_key_padding_mask(paddle.to_tensor(peak_position_list)) @@ -266,8 +266,8 @@ def read_total_ecd(sample_path, fix_length=20): def Construct_dataset(dataset, data_index, path): """ - 从原始特征构建图数据 - 完全复用原型程序的Construct_dataset逻辑 + Construct graph data from raw features + Completely reuse the Construct_dataset logic from the prototype program """ graph_atom_bond = [] graph_bond_angle = [] @@ -277,17 +277,17 @@ def Construct_dataset(dataset, data_index, path): for i in tqdm(range(len(dataset)), desc="Constructing graphs"): data = dataset[i] - # 收集原子特征 + # Collect atom features atom_feature = [] for name in atom_id_names: atom_feature.append(data[name]) - # 收集键特征 + # Collect bond features bond_feature = [] for name in bond_id_names[0:3]: bond_feature.append(data[name]) - # 转换为Tensor + # Convert to Tensor atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') bond_float_feature = paddle.to_tensor(data['bond_length'].astype(paddle.get_default_dtype())) @@ -296,14 +296,14 @@ def Construct_dataset(dataset, data_index, path): bond_index = paddle.to_tensor(data['BondAngleGraph_edges'].T, dtype='int64') data_index_int = paddle.to_tensor(np.array(data_index[i]), dtype='int64') - # 添加描述符特征(与原型程序完全一致) + # Add descriptor features (exactly the same as prototype program) TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 100 RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] RPSA 
= paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822] MDEC = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568] MATS = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457] - # 合并特征 + # Merge features bond_feature = paddle.concat( [bond_feature.astype(bond_float_feature.dtype), bond_float_feature.reshape([-1, 1])], @@ -319,7 +319,7 @@ def Construct_dataset(dataset, data_index, path): bond_angle_feature = paddle.concat([bond_angle_feature, MDEC.reshape([-1, 1])], axis=1) bond_angle_feature = paddle.concat([bond_angle_feature, MATS.reshape([-1, 1])], axis=1) - # 创建Data对象 + # Create Data objects data_atom_bond = Data( x=atom_feature, edge_index=edge_index, @@ -348,36 +348,36 @@ def GetAtomBondAngleDataset( line_idx_dict ): """ - 核心函数:构建并返回切好的图数据集 + Core function: build and return the processed graph dataset Args: - sample_path: ECD光谱文件路径 - dataset_all: 从npy加载的info列表 - index_all: 索引列表 - hand_idx_dict: 手性对映射 - line_idx_dict: 行号映射 + sample_path: ECD spectrum file path + dataset_all: info list loaded from npy + index_all: index list + hand_idx_dict: chiral pair mapping + line_idx_dict: line number mapping Returns: - dataset_graph_atom_bond: atom-bond图列表 - dataset_graph_bond_angle: bond-angle图列表 + dataset_graph_atom_bond: atom-bond graph list + dataset_graph_bond_angle: bond-angle graph list """ - # 1. 读取ECD光谱序列 + # 1. Read ECD spectrum sequences ecd_sequences, ecd_original_sequences = read_total_ecd(sample_path) - # 2. 构建图数据 + # 2. Construct graph data total_graph_atom_bond, total_graph_bond_angle = Construct_dataset( dataset_all, index_all, sample_path ) print("Case Before Process = ", len(total_graph_atom_bond), len(total_graph_bond_angle)) - # 3. 将光谱序列信息附加到图数据上 + # 3. 
Attach spectrum sequence information to graph data dataset_graph_atom_bond, dataset_graph_bond_angle = [], [] for itm in ecd_sequences: line_num = itm['id'] - 1 atom_bond = total_graph_atom_bond[line_num] - # 附加光谱信息 + # Attach spectrum information atom_bond.sequence = paddle.to_tensor([itm['seq']]) atom_bond.ecd_id = paddle.to_tensor(itm['id']) atom_bond.seq_mask = paddle.to_tensor([itm['seq_mask']]) @@ -390,7 +390,7 @@ def GetAtomBondAngleDataset( dataset_graph_atom_bond.append(atom_bond) dataset_graph_bond_angle.append(total_graph_bond_angle[line_num]) - # 4. 对映体增强:添加对映体样本 + # 4. Enantiomer enhancement: add enantiomer samples hand_id, unnamed_id = line_idx_dict[line_num]['hand_id'], line_idx_dict[line_num]['unnamed_id'] another_line_num = -1 @@ -401,7 +401,7 @@ def GetAtomBondAngleDataset( assert another_line_num != -1, f"cannot find the hand info of {line_num}" - # 对映体:光谱取反 + # Enantiomer: invert the spectrum atom_bond_oppo = total_graph_atom_bond[another_line_num] atom_bond_oppo.sequence = paddle.neg(paddle.to_tensor([itm['seq']])) atom_bond_oppo.ecd_id = paddle.to_tensor(another_line_num + 1) @@ -423,11 +423,11 @@ def GetAtomBondAngleDataset( def build_ecformer_sample_builder(cfg: Optional[Dict[str, Any] | str]): - """构建样本构建器""" + """Build sample builder""" class_name, init_params = _parse_factory_cfg(cfg, default_class_name="StrictIndexSampleBuilder") cls = _locate_class(class_name) builder = cls(**init_params) if not hasattr(builder, 'build'): raise TypeError(f"Sample builder {class_name} must implement 'build' method") logger.debug(f"Use sample builder: {class_name}") - return builder + return builder \ No newline at end of file diff --git a/ppmat/datasets/build_ir.py b/ppmat/datasets/build_ir.py index be3a4fe0..f04139ad 100644 --- a/ppmat/datasets/build_ir.py +++ b/ppmat/datasets/build_ir.py @@ -48,7 +48,7 @@ def _parse_factory_cfg( *, default_class_name: str, ) -> Tuple[str, Dict[str, Any]]: - """解析工厂配置,兼容多种格式""" + """Parse factory configuration, 
compatible with multiple formats""" if cfg is None: return default_class_name, {} @@ -77,9 +77,9 @@ def _parse_factory_cfg( class IRStrictIndexSampleBuilder: - """按严格索引构建IR样本""" + """Build IR samples by strict index""" def build(self, data_dir: Path, meta_file: str, spectra_dir: str, data_count: Optional[int] = None): - """构建样本列表""" + """Build sample list""" samples = [] meta_path = data_dir / meta_file data = np.load(meta_path, allow_pickle=True).item() @@ -96,7 +96,7 @@ def build(self, data_dir: Path, meta_file: str, spectra_dir: str, data_count: Op class DefaultIRDatasetDownloader: - """IR 数据集下载器""" + """IR dataset downloader""" def __init__(self, datasets_home: Optional[str] = None): self.datasets_home = datasets_home or download_utils.DATASETS_HOME @@ -111,7 +111,7 @@ def download(self, url: str, md5: Optional[str] = None, force_download: bool = F def build_ir_downloader(cfg: Optional[Dict[str, Any] | str]): - """构建下载器""" + """Build downloader""" class_name, init_params = _parse_factory_cfg(cfg, default_class_name="DefaultIRDatasetDownloader") cls = _locate_class(class_name) downloader = cls(**init_params) @@ -122,7 +122,7 @@ def build_ir_downloader(cfg: Optional[Dict[str, Any] | str]): def build_ir_sample_builder(cfg: Optional[Dict[str, Any] | str]): - """构建样本构建器""" + """Build sample builder""" class_name, init_params = _parse_factory_cfg(cfg, default_class_name="IRStrictIndexSampleBuilder") cls = _locate_class(class_name) builder = cls(**init_params) @@ -132,7 +132,7 @@ def build_ir_sample_builder(cfg: Optional[Dict[str, Any] | str]): return builder -# ==================== IR 特定工具函数 ==================== +# ==================== IR Specific Utility Functions ==================== IR_WAVELENGTH_MIN = 500 IR_WAVELENGTH_MAX = 4000 @@ -141,20 +141,20 @@ def build_ir_sample_builder(cfg: Optional[Dict[str, Any] | str]): def get_key_padding_mask(tokens): - """生成query padding mask""" + """Generate query padding mask""" key_padding_mask = paddle.zeros(tokens.shape) 
key_padding_mask[tokens == -1] = -paddle.inf return key_padding_mask def x_bin_position(real_x, distance=IR_STEP): - """将实际波数转换为箱ID""" + """Convert actual wavenumber to bin ID""" return int((real_x - IR_WAVELENGTH_MIN) / distance) def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): """ - 从原始特征构建IR图数据 + Construct IR graph data from raw features """ graph_atom_bond = [] graph_bond_angle = [] @@ -166,7 +166,7 @@ def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): for i in tqdm(range(len(dataset)), desc="Constructing IR graphs"): data = dataset[i] - # 收集原子特征 + # Collect atom features atom_feature = [] for name in atom_id_names: if name in data: @@ -177,7 +177,7 @@ def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): num_atoms = data.get('atomic_num', np.zeros(1)).shape[0] atom_feature.append(np.zeros(num_atoms)) - # 收集键特征 + # Collect bond features bond_feature = [] for name in bond_id_names: if name in data: @@ -188,7 +188,7 @@ def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): num_bonds = data.get('bond_dir', np.zeros(1)).shape[0] bond_feature.append(np.zeros(num_bonds)) - # 转换为Tensor + # Convert to Tensor atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') @@ -201,19 +201,19 @@ def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): data_index_int = paddle.to_tensor(np.array(int(data_index[i])), dtype='int64') num_atoms = atom_feature.shape[0] - # 合并键特征 + # Merge bond features bond_feature = paddle.concat( [bond_feature.astype(bond_float_feature.dtype), bond_float_feature.reshape([-1, 1])], axis=1 ) - # 处理键角特征 - 确保输出6维! 
+ # Process bond angle features - ensure output is 6-dim if bond_angle_feature.shape[0] > 0: - # 基础特征:bond_angle + # Base feature: bond_angle features = [bond_angle_feature.reshape([-1, 1])] - # 如果有描述符,添加5个描述符特征 + # If descriptors exist, add 5 descriptor features if all_descriptor is not None and i < all_descriptor.shape[0]: TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 100 RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] @@ -229,14 +229,14 @@ def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): MATS.reshape([-1, 1]) ]) else: - # 如果没有描述符,用0填充剩下的5维 + # If no descriptors, fill the remaining 5 dimensions with zeros for _ in range(5): features.append(paddle.zeros([bond_angle_feature.shape[0], 1])) - # 拼接成 [E_ba, 6] + # Concatenate to [E_ba, 6] bond_angle_feature = paddle.concat(features, axis=1) else: - # 如果没有键角,创建全0的 [0, 6] + # If no bond angles, create all-zero [0, 6] bond_angle_feature = paddle.zeros([0, 6]) data_atom_bond = Data( @@ -260,7 +260,7 @@ def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): def read_ir_spectra_by_ids(sample_path, index_all, max_peak=DEFAULT_MAX_PEAKS): """ - 按需读取IR光谱文件 + Read IR spectrum files on demand """ ir_final_list = [] @@ -320,24 +320,24 @@ def GetIRDataset( index_all, ): """ - 核心函数:构建并返回IR图数据集 + Core function: build and return IR graph dataset """ - # 1. 读取IR光谱序列 + # 1. Read IR spectrum sequences ir_sequences = read_ir_spectra_by_ids(sample_path, index_all) - # 2. 构建图数据 + # 2. Construct graph data total_graph_atom_bond, total_graph_bond_angle = Construct_IR_Dataset( dataset_all, index_all, sample_path ) print("Case Before Process = ", len(total_graph_atom_bond), len(total_graph_bond_angle)) - # 3. 将光谱信息附加到图数据上 + # 3.
Attach spectrum information to graph data dataset_graph_atom_bond, dataset_graph_bond_angle = [], [] for i, itm in enumerate(ir_sequences): atom_bond = total_graph_atom_bond[i] - # 附加光谱信息 + # Attach spectrum information atom_bond.sequence = paddle.to_tensor([itm['seq_40']]) atom_bond.ir_id = paddle.to_tensor(int(itm['id'])) atom_bond.peak_num = paddle.to_tensor([itm['peak_num']]) diff --git a/ppmat/datasets/collate_fn.py b/ppmat/datasets/collate_fn.py index 28189e46..c53ee7b3 100644 --- a/ppmat/datasets/collate_fn.py +++ b/ppmat/datasets/collate_fn.py @@ -306,48 +306,57 @@ def pad_sequence(sequences, batch_first=False, padding_value=0): class ECDCollator(DefaultCollator): def __call__(self, batch: List[Any]) -> Any: - batch = [list(x) for x in zip(*batch)] # transpose - for i in range(len(batch)): # 组Batch + batch = [list(x) for x in zip(*batch)] # transpose + for i in range(len(batch)): # Group into batches batch[i] = Batch.from_data_list(batch[i]) batch0 = batch[0] batch1 = batch[1] - # Data解包到Tensor字典 + # Unpack Data to Tensor dictionary batch_atom_bond, batch_bond_angle = batch0, batch1 - x, edge_index, edge_attr, query_mask =batch_atom_bond.x,batch_atom_bond.edge_index,batch_atom_bond.edge_attr,batch_atom_bond.query_mask - ba_edge_index, ba_edge_attr = batch_bond_angle.edge_index,batch_bond_angle.edge_attr + x, edge_index, edge_attr, query_mask = ( + batch_atom_bond.x, + batch_atom_bond.edge_index, + batch_atom_bond.edge_attr, + batch_atom_bond.query_mask, + ) + ba_edge_index, ba_edge_attr = ( + batch_bond_angle.edge_index, + batch_bond_angle.edge_attr, + ) batch_data = batch_atom_bond.batch - pos_gt = batch_atom_bond.peak_position + pos_gt = batch_atom_bond.peak_position height_gt = batch_atom_bond.peak_height - num_gt = batch_atom_bond.peak_num - return \ - { - "x" : x , - "edge_index" : edge_index , - "edge_attr" : edge_attr , - "batch_data" : batch_data , - "ba_edge_index" : ba_edge_index , - "ba_edge_attr" : ba_edge_attr , - "query_mask" : query_mask - }, 
\ - { - "peak_number" : num_gt , - "peak_position": pos_gt , - "peak_height" : height_gt - } + num_gt = batch_atom_bond.peak_num + return ( + { + "x": x, + "edge_index": edge_index, + "edge_attr": edge_attr, + "batch_data": batch_data, + "ba_edge_index": ba_edge_index, + "ba_edge_attr": ba_edge_attr, + "query_mask": query_mask, + }, + { + "peak_number": num_gt, + "peak_position": pos_gt, + "peak_height": height_gt, + }, + ) class IRCollator(DefaultCollator): - """IR 数据集专用 collator,返回 Tensor 字典""" - - def __call__(self, batch: List[Any]) -> Any: + """IR dataset specific collator, returns Tensor dictionary""" + + def __call__(self, batch: List[Any]) -> Any: batch = [list(x) for x in zip(*batch)] # transpose for i in range(len(batch)): batch[i] = Batch.from_data_list(batch[i]) - + batch_atom_bond, batch_bond_angle = batch[0], batch[1] - + x, edge_index, edge_attr, query_mask = ( batch_atom_bond.x, batch_atom_bond.edge_index, @@ -362,7 +371,7 @@ def __call__(self, batch: List[Any]) -> Any: pos_gt = batch_atom_bond.peak_position height_gt = batch_atom_bond.peak_height num_gt = batch_atom_bond.peak_num - + return ( { "x": x, @@ -377,5 +386,5 @@ def __call__(self, batch: List[Any]) -> Any: "peak_number": num_gt, "peak_position": pos_gt, "peak_height": height_gt, - } + }, ) \ No newline at end of file diff --git a/ppmat/datasets/ecd_dataset.py b/ppmat/datasets/ecd_dataset.py index c67651bf..f3174c96 100644 --- a/ppmat/datasets/ecd_dataset.py +++ b/ppmat/datasets/ecd_dataset.py @@ -19,11 +19,9 @@ import pandas as pd import paddle from paddle.io import Dataset -from paddle_geometric.data import Data from pathlib import Path -from typing import Dict, List, Optional, Any +from typing import Dict, Optional -from ppmat.utils import ColoredTqdm as tqdm from ppmat.utils import PlaceEnv from ppmat.utils.compound_tools import get_atom_feature_dims, get_bond_feature_dims from ppmat.datasets.build_ecd import build_ecformer_sample_builder @@ -34,9 +32,9 @@ class ECDDataset(Dataset): 
""" - ECDFormer ECD 光谱预测数据集 + ECDFormer ECD Spectrum Prediction Dataset - 数据来源:https://paddle-org.bj.bcebos.com/paddlematerials/datasets/ECD/ECD.tar.gz + Data source: https://paddle-org.bj.bcebos.com/paddlematerials/datasets/ECD/ECD.tar.gz """ url = "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/ECD/ECD.tar.gz" @@ -62,22 +60,22 @@ def __init__( self.use_geometry_enhanced = use_geometry_enhanced self.use_column_info = use_column_info - # 构建组件 + # Build components self.sample_builder = build_ecformer_sample_builder(sample_builder_cfg) self.downloader = build_ecformer_downloader(downloader_cfg) - # 处理下载 + # Handle download if force_download or (not self._check_files() and download): self.downloaded_root = self.downloader.download( self.url, self.md5, force_download=force_download ) self.data_path = self.downloaded_root - # 加载数据 + # Load data self._load_data() def _check_files(self): - """检查必要的文件是否存在""" + """Check if necessary files exist""" npy_path = self.data_path / 'ecd_column_charity_new_smiles.npy' csv_path = self.data_path / 'ecd_info.csv' @@ -88,34 +86,34 @@ def _check_files(self): return True def _load_data(self): - """加载所有数据""" - # 1. 加载 npy 文件 + """Load all data""" + # 1. Load npy file npy_path = self.data_path / 'ecd_column_charity_new_smiles.npy' if not npy_path.exists(): raise FileNotFoundError(f"npy file not found: {npy_path}") self.ecd_dataset = np.load(npy_path, allow_pickle=True).tolist() - # 2. 加载 csv 文件 + # 2. Load csv file csv_path = self.data_path / 'ecd_info.csv' if not csv_path.exists(): raise FileNotFoundError(f"csv file not found: {csv_path}") self.ecd_info = pd.read_csv(csv_path, encoding='gbk') - # 3. 提取数据 + # 3. Extract data self.dataset_all = [item['info'] for item in self.ecd_dataset] self.smiles_all = [item['smiles'] for item in self.ecd_dataset] self.index_all = self.ecd_info['Unnamed: 0'].values - # 4. 构建手性对映射 + # 4. Build chiral pair mapping self._build_chiral_mapping() - # 5. 构建图数据集 + # 5. 
Build graph dataset self._build_graph_dataset() def _build_chiral_mapping(self): - """构建手性对映体映射""" + """Build chiral enantiomer mapping""" self.hand_idx_dict = {} self.line_idx_dict = {} @@ -136,7 +134,7 @@ def _build_chiral_mapping(self): @PlaceEnv(paddle.CPUPlace()) def _build_graph_dataset(self): - """构建图数据集""" + """Build graph dataset""" global _cache if len(_cache) > 0: @@ -160,5 +158,5 @@ def __len__(self): @PlaceEnv(paddle.CPUPlace()) def __getitem__(self, idx): - """返回 (atom_bond_graph, bond_angle_graph)""" + """Returns (atom_bond_graph, bond_angle_graph)""" return self.graph_atom_bond[idx], self.graph_bond_angle[idx] \ No newline at end of file diff --git a/ppmat/datasets/ir_dataset.py b/ppmat/datasets/ir_dataset.py index 0e3d5f34..5d73db47 100644 --- a/ppmat/datasets/ir_dataset.py +++ b/ppmat/datasets/ir_dataset.py @@ -14,20 +14,16 @@ from __future__ import annotations -import os import numpy as np import paddle from paddle.io import Dataset -from paddle_geometric.data import Data from pathlib import Path -from typing import Dict, List, Optional, Any +from typing import Dict, Optional -from ppmat.utils import ColoredTqdm as tqdm from ppmat.utils import PlaceEnv from ppmat.datasets.build_ir import ( build_ir_sample_builder, build_ir_downloader, - GetIRDataset, read_ir_spectra_by_ids, Construct_IR_Dataset, ) @@ -37,12 +33,12 @@ class IRDataset(Dataset): """ - ECFormer IR 光谱预测数据集 + ECFormer IR Spectrum Prediction Dataset - 支持三种预加载模式: - - '100': 100个样本的小数据集(默认,用于快速测试) - - '10000': 1万个样本的中等数据集 - - 'all': 全部样本(可能很大) + Supports three preloading modes: + - '100': Small dataset with 100 samples (default, for quick testing) + - '10000': Medium dataset with 10,000 samples + - 'all': All samples (may be very large) """ url = "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/IR/IR.tar.gz" @@ -72,7 +68,7 @@ def __init__( cache_key = f"{data_path}_{mode}_{use_geometry_enhanced}" - # 如果启用缓存且命中,直接返回 + # If cache is enabled and hit, return directly if use_cache 
and cache_key in _cache: cached_data = _cache[cache_key] self.graph_atom_bond = cached_data['atom_bond'] @@ -80,21 +76,21 @@ def __init__( self.smiles_list = cached_data.get('smiles', []) return - # 构建组件 + # Build components self.sample_builder = build_ir_sample_builder(sample_builder_cfg) self.downloader = build_ir_downloader(downloader_cfg) - # 处理下载 + # Handle download if force_download or (not self._check_files() and download): self.downloaded_root = self.downloader.download( self.url, self.md5, force_download=force_download ) self.data_path = self.downloaded_root - # 加载数据 + # Load data self._load_data() - # 存入缓存 + # Store in cache if use_cache: _cache[cache_key] = { 'atom_bond': self.graph_atom_bond, @@ -103,7 +99,7 @@ def __init__( } def _check_files(self): - """检查必要的文件是否存在""" + """Check if necessary files exist""" meta_path = self.data_path / f'ir_column_charity_{self.mode}.npy' spectra_path = self.data_path / 'qm9_ir_spec' @@ -114,8 +110,8 @@ def _check_files(self): return True def _load_data(self): - """加载所有数据""" - # 1. 加载元数据文件 + """Load all data""" + # 1. Load metadata file meta_path = self.data_path / f'ir_column_charity_{self.mode}.npy' if not meta_path.exists(): raise FileNotFoundError(f"IR meta file {meta_path} not found") @@ -127,13 +123,13 @@ def _load_data(self): print(f"Loaded meta data: {len(index_all)} samples") - # 2. 按需读取IR光谱 + # 2. Read IR spectra on demand spectra_path = self.data_path / 'qm9_ir_spec' self.ir_sequences = read_ir_spectra_by_ids(str(spectra_path), index_all) print(f"Loaded {len(self.ir_sequences)} IR spectra") - # 3. 构建图数据 + # 3. Construct graph data descriptor_path = self.data_path / 'descriptor_all_column.npy' if not descriptor_path.exists(): descriptor_path = None @@ -142,7 +138,7 @@ def _load_data(self): dataset_all, index_all, descriptor_path ) - # 4. 将光谱信息附加到图数据 + # 4. 
Attach spectrum information to graph data self.graph_atom_bond = [] self.graph_bond_angle = [] self.smiles_list = [] @@ -168,5 +164,5 @@ def __len__(self): @PlaceEnv(paddle.CPUPlace()) def __getitem__(self, idx): - """返回 (atom_bond_graph, bond_angle_graph)""" + """Returns (atom_bond_graph, bond_angle_graph)""" return self.graph_atom_bond[idx], self.graph_bond_angle[idx] \ No newline at end of file diff --git a/ppmat/models/ecformer/__init__.py b/ppmat/models/ecformer/__init__.py index 176592d5..0e90b295 100644 --- a/ppmat/models/ecformer/__init__.py +++ b/ppmat/models/ecformer/__init__.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -# 导出模型类 +# export model class from .models.ECD import ECFormerECD from .models.IR import ECFormerIR -# 导出编码器(如需直接使用) +# export encoder (if you want to use it directly) from .encoders.gin_node_embedding import GINNodeEmbedding __all__ = [ diff --git a/ppmat/models/ecformer/encoders/gin_node_embedding.py b/ppmat/models/ecformer/encoders/gin_node_embedding.py index 17df78e5..ec3f1fb3 100644 --- a/ppmat/models/ecformer/encoders/gin_node_embedding.py +++ b/ppmat/models/ecformer/encoders/gin_node_embedding.py @@ -23,7 +23,7 @@ class GINNodeEmbedding(nn.Layer): - """GIN节点嵌入模块 - 支持几何增强的双图结构""" + """GIN node embedding module - supports geometry-enhanced dual graph structure""" def __init__( self, @@ -51,13 +51,13 @@ def __init__( if self.num_layers < 2: raise ValueError("Number of GNN layers must be greater than 1.") - # 编码器 + # Encoders self.atom_encoder = AtomEncoder(full_atom_feature_dims, emb_dim) self.bond_encoder = BondEncoder(full_bond_feature_dims, emb_dim) self.bond_float_encoder = BondFloatRBF(bond_float_names, emb_dim) self.bond_angle_encoder = BondAngleFloatRBF(bond_angle_float_names, emb_dim) - # GNN层列表 + # GNN layer lists self.convs = nn.LayerList() self.convs_bond_angle = nn.LayerList() self.convs_bond_embedding = nn.LayerList() @@ -78,17 +78,17 @@ def
__init__( def forward( self, - x, # [N, F] 原子特征 - edge_index, # [2, E] 边索引 - edge_attr, # [E, D] 边特征 - # 几何增强相关输入 - ba_edge_index=None, # [2, E_ba] 键角图边索引 - ba_edge_attr=None, # [E_ba, D_ba] 键角图边特征 + x, # [N, F] atom features + edge_index, # [2, E] edge indices + edge_attr, # [E, D] edge features + # Geometry enhancement related inputs + ba_edge_index=None, # [2, E_ba] bond-angle graph edge indices + ba_edge_attr=None, # [E_ba, D_ba] bond-angle graph edge features ): """ - 前向传播 + Forward pass """ - # 1. 原子特征编码 + # 1. Atom feature encoding if x.dtype != paddle.int64: x = x.astype(paddle.int64) h_list = [self.atom_encoder(x)] @@ -105,11 +105,11 @@ def forward( def _forward_enhanced(self, h_list, edge_index, edge_attr, ba_edge_index, ba_edge_attr): - """几何增强前向传播""" + """Geometry-enhanced forward pass""" bond_id_len = len(self.bond_id_names) - # 初始化边表示 + # Initialize edge representations h_list_ba = [self.bond_float_encoder( edge_attr[:, bond_id_len:edge_attr.shape[1]+1].astype('float32') ) + self.bond_encoder( @@ -117,10 +117,10 @@ def _forward_enhanced(self, h_list, edge_index, edge_attr, )] for layer in range(self.num_layers): - # 节点更新 + # Node update h = self.convs[layer](h_list[layer], edge_index, h_list_ba[layer]) - # 边更新 + # Edge update cur_h_ba = self.convs_bond_embedding[layer]( edge_attr[:, 0:bond_id_len].astype('int64') ) + self.convs_bond_float[layer]( @@ -129,7 +129,7 @@ def _forward_enhanced(self, h_list, edge_index, edge_attr, cur_angle_hidden = self.convs_angle_float[layer](ba_edge_attr) h_ba = self.convs_bond_angle[layer](cur_h_ba, ba_edge_index, cur_angle_hidden) - # Dropout和残差 + # Dropout and residual if layer == self.num_layers - 1: h = F.dropout(h, self.drop_ratio, training=self.training) h_ba = F.dropout(h_ba, self.drop_ratio, training=self.training) @@ -144,7 +144,7 @@ def _forward_enhanced(self, h_list, edge_index, edge_attr, h_list.append(h) h_list_ba.append(h_ba) - # JK连接策略 + # JK connection strategy if self.JK == "last": node_representation = 
h_list[-1] edge_representation = h_list_ba[-1] @@ -155,7 +155,7 @@ def _forward_enhanced(self, h_list, edge_index, edge_attr, return node_representation, edge_representation def _forward_simple(self, h_list, edge_index, edge_attr): - """简化前向传播""" + """Simplified forward pass""" bond_id_len = len(self.bond_id_names) for layer in range(self.num_layers): diff --git a/ppmat/models/ecformer/layers/atom_encoder.py b/ppmat/models/ecformer/layers/atom_encoder.py index 34ef8d03..44d5c322 100644 --- a/ppmat/models/ecformer/layers/atom_encoder.py +++ b/ppmat/models/ecformer/layers/atom_encoder.py @@ -15,7 +15,7 @@ import paddle.nn as nn class AtomEncoder(nn.Layer): - """原子特征编码器 - 将离散原子特征映射为连续向量""" + """Atomic Feature Encoder - Maps discrete atomic features to continuous vectors""" def __init__(self, full_atom_feature_dims, emb_dim): super(AtomEncoder, self).__init__() diff --git a/ppmat/models/ecformer/layers/bond_encoder.py b/ppmat/models/ecformer/layers/bond_encoder.py index ccd7207b..46d56fdd 100644 --- a/ppmat/models/ecformer/layers/bond_encoder.py +++ b/ppmat/models/ecformer/layers/bond_encoder.py @@ -15,7 +15,7 @@ import paddle.nn as nn class BondEncoder(nn.Layer): - """键特征编码器 - 将离散键特征映射为连续向量""" + """Bond feature encoder - maps discrete bond features to continuous vectors""" def __init__(self, full_bond_feature_dims, emb_dim): super(BondEncoder, self).__init__() diff --git a/ppmat/models/ecformer/layers/gin_conv.py b/ppmat/models/ecformer/layers/gin_conv.py index 73d0ab7f..77c18209 100644 --- a/ppmat/models/ecformer/layers/gin_conv.py +++ b/ppmat/models/ecformer/layers/gin_conv.py @@ -18,7 +18,7 @@ import paddle.nn.functional as F class GINConv(MessagePassing): - """图同构卷积层""" + """Graph Isomorphism Convolution Layer""" def __init__(self, emb_dim): super(GINConv, self).__init__(aggr="add") diff --git a/ppmat/models/ecformer/layers/rbf.py b/ppmat/models/ecformer/layers/rbf.py index 8fc04823..9330c5c6 100644 --- a/ppmat/models/ecformer/layers/rbf.py +++ 
b/ppmat/models/ecformer/layers/rbf.py @@ -17,7 +17,7 @@ import numpy as np class RBF(nn.Layer): - """径向基函数""" + """Radial Basis Function""" def __init__(self, centers: paddle.nn.parameter.Parameter, @@ -32,7 +32,7 @@ def forward(self, x): class BondFloatRBF(nn.Layer): - """连续键特征RBF编码器""" + """RBF encoder for continuous bond features""" def __init__(self, bond_float_names, embed_dim, rbf_params=None): super(BondFloatRBF, self).__init__() @@ -72,7 +72,7 @@ def forward(self, bond_float_features): class BondAngleFloatRBF(nn.Layer): - """键角连续特征RBF编码器""" + """RBF encoder for continuous bond angle features""" def __init__(self, bond_angle_float_names, embed_dim, rbf_params=None): super(BondAngleFloatRBF, self).__init__() diff --git a/ppmat/models/ecformer/models/ECD.py b/ppmat/models/ecformer/models/ECD.py index fb585675..18f53190 100644 --- a/ppmat/models/ecformer/models/ECD.py +++ b/ppmat/models/ecformer/models/ECD.py @@ -18,7 +18,7 @@ class ECFormerECD(ECFormerBase): - """ECFormer for ECD光谱预测 - 峰属性解耦版本""" + """ECFormer for ECD spectrum prediction - peak attribute decoupling version""" def __init__( self, @@ -28,21 +28,21 @@ def __init__( ): super().__init__(**kwargs) - # 峰数预测头 + # Peak number prediction head self.pred_number_layer = nn.Sequential( nn.Linear(self.emb_dim, self.emb_dim * 2), nn.ReLU(), nn.Linear(self.emb_dim * 2, self.max_peaks) ) - # 峰位置预测头 + # Peak position prediction head self.pred_position_layer = nn.Sequential( nn.Linear(self.emb_dim, self.emb_dim // 4), nn.ReLU(), nn.Linear(self.emb_dim // 4, num_position_classes) ) - # 峰符号预测头 + # Peak sign prediction head self.pred_height_layer = nn.Sequential( nn.Linear(self.emb_dim, self.emb_dim // 4), nn.ReLU(), diff --git a/ppmat/models/ecformer/models/IR.py b/ppmat/models/ecformer/models/IR.py index 122c6caf..111283da 100644 --- a/ppmat/models/ecformer/models/IR.py +++ b/ppmat/models/ecformer/models/IR.py @@ -18,7 +18,7 @@ class ECFormerIR(ECFormerBase): - """ECDFormer for IR光谱预测 - 序列回归版本""" + """ECDFormer 
for IR spectrum prediction - sequence regression version""" def __init__( self, @@ -26,26 +26,26 @@ def __init__( use_height_prediction=True, **kwargs ): - # IR任务最大峰数不同 + # IR task has different maximum number of peaks kwargs['max_peaks'] = kwargs.get('max_peaks', 15) super().__init__(**kwargs) - # 峰数预测头(IR最多15个峰) + # Peak number prediction head (IR has at most 15 peaks) self.pred_number_layer = nn.Sequential( nn.Linear(self.emb_dim, self.emb_dim * 2), nn.ReLU(), nn.Linear(self.emb_dim * 2, self.max_peaks + 1) ) - # 峰位置预测头(IR位置分类更多) + # Peak position prediction head (IR has more position classes) self.pred_position_layer = nn.Sequential( nn.Linear(self.emb_dim, self.emb_dim // 4), nn.ReLU(), nn.Linear(self.emb_dim // 4, num_position_classes) ) - # 峰强度预测头(IR回归) + # Peak intensity prediction head (IR regression) if use_height_prediction: self.pred_height_layer = nn.Sequential( nn.Linear(self.emb_dim, self.emb_dim // 4), diff --git a/ppmat/models/ecformer/models/base_ecformer.py b/ppmat/models/ecformer/models/base_ecformer.py index 15c4bf16..e913d6d1 100644 --- a/ppmat/models/ecformer/models/base_ecformer.py +++ b/ppmat/models/ecformer/models/base_ecformer.py @@ -23,15 +23,15 @@ def fix_mask_for_paddle(mask, n_head=None): """ - 简单直接的掩码修复函数 + Simple and direct mask repair function Args: - mask: 输入掩码 - n_head: 注意力头数 (attention mask 时需要) + mask: input mask + n_head: number of attention heads (needed for attention mask) """ shape = mask.shape assert len(shape) == 2 - # 如果是 [batch_size, src_len] 但想用作 attention mask + # If [batch_size, src_len] but intended to be used as attention mask batch_size, s_len = shape # [32, 73] -> [32, 1, 73, 73] if n_head: @@ -41,11 +41,11 @@ def fix_mask_for_paddle(mask, n_head=None): class ECFormerBase(nn.Layer, ABC): - """ECFormer基类 - 所有谱图预测模型的抽象接口""" + """ECFormer Base Class - Abstract interface for all spectrum prediction models""" def __init__( self, - # GNN参数 + # GNN parameters full_atom_feature_dims, full_bond_feature_dims, 
bond_float_names, @@ -59,7 +59,7 @@ def __init__( graph_pooling="attention", use_geometry_enhanced=True, max_node_num=63, - # Transformer参数 + # Transformer parameters num_heads=4, tf_layers=2, tf_dropout=0.1, @@ -72,7 +72,7 @@ def __init__( self.max_peaks = max_peaks self.use_geometry_enhanced = use_geometry_enhanced - # 1. GNN节点编码器 + # 1. GNN node encoder self.gnn_node = GINNodeEmbedding( full_atom_feature_dims=full_atom_feature_dims, full_bond_feature_dims=full_bond_feature_dims, @@ -87,17 +87,17 @@ def __init__( use_geometry_enhanced=use_geometry_enhanced ) - # 2. 图池化层 + # 2. Graph pooling layer self.pool = self._build_pooling(graph_pooling, emb_dim) - # 3. Query嵌入(峰查询向量) + # 3. Query embedding (peak query vectors) self.query_embed = nn.Embedding(max_peaks, emb_dim) - # 4. Transformer编码器 + # 4. Transformer encoder self.tf_encoder = self._build_transformer(emb_dim, num_heads, tf_layers, tf_dropout) def _build_pooling(self, graph_pooling, emb_dim): - """构建图池化层""" + """Build graph pooling layer""" if graph_pooling == "sum": return global_add_pool elif graph_pooling == "mean": @@ -119,7 +119,7 @@ def _build_pooling(self, graph_pooling, emb_dim): raise ValueError(f"Invalid graph pooling type: {graph_pooling}") def _build_transformer(self, emb_dim, num_heads, num_layers, dropout): - """构建Transformer编码器""" + """Build Transformer encoder""" assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads" @@ -134,17 +134,17 @@ def _build_transformer(self, emb_dim, num_heads, num_layers, dropout): def encode_molecule( self, - x, # [N, F] 原子特征 - edge_index, # [2, E] 边索引 - edge_attr, # [E, D] 边特征 - batch_data, # [N] 批次信息 - # 几何增强相关 - ba_edge_index=None, # [2, E_ba] 键角图边索引 - ba_edge_attr=None, # [E_ba, D_ba] 键角图边特征 + x, # [N, F] atom features + edge_index, # [2, E] edge indices + edge_attr, # [E, D] edge features + batch_data, # [N] batch information + # Geometry enhancement related + ba_edge_index=None, # [2, E_ba] bond-angle graph edge indices + ba_edge_attr=None, 
# [E_ba, D_ba] bond-angle graph edge features ): - """分子编码器 - 纯Tensor输入""" + """Molecule encoder - pure Tensor input""" - # 1. GNN编码 + # 1. GNN encoding if self.use_geometry_enhanced and ba_edge_index is not None: h_node, _ = self.gnn_node( x=x, @@ -160,21 +160,21 @@ def encode_molecule( edge_attr=edge_attr, ) - # 2. 节点特征padding(需要batch信息) + # 2. Node feature padding (requires batch information) batch_size = batch_data[-1] + 1 node_feat, node_index = pad_node_features( h_node, batch_data, batch_size, self.max_node_num, self.emb_dim ) - # 3. 图池化 + # 3. Graph pooling h_graph = self.pool(h_node, batch_data).unsqueeze(1) - # 4. 拼接图特征和节点特征 + # 4. Concatenate graph features and node features total_node_feat = paddle.concat([h_graph, node_feat], axis=1) - # 5. 生成padding mask + # 5. Generate padding mask node_padding_mask = feat_padding_mask(node_index, self.max_node_num) pooling_padding_mask = paddle.zeros([node_padding_mask.shape[0], 1], dtype=paddle.get_default_dtype()) total_padding_mask = paddle.concat([pooling_padding_mask, node_padding_mask], axis=1) @@ -189,21 +189,21 @@ def forward(self, ba_edge_index: paddle.Tensor = None, ba_edge_attr: paddle.Tensor = None, query_mask: paddle.Tensor = None): - # 0. 数据类型检查 + # 0. Data type check if batch_data.dtype != paddle.int64: batch_data = batch_data.astype(paddle.int64) - # 1. 分子编码 + # 1. Molecule encoding node_feat, padding_mask, node_padding_mask = self.encode_molecule(x, edge_index, edge_attr,batch_data, ba_edge_index, ba_edge_attr) - # 2. 峰数预测(从图特征) + # 2. Peak number prediction (from graph features) graph_feat = node_feat[:, 0, :] pred_number = self.pred_number_layer(graph_feat) - # 3. Query准备 + # 3. 
Query preparation query_feat = self.query_embed.weight.unsqueeze(0).tile([node_feat.shape[0], 1, 1]) - # 推理时根据预测峰数生成query mask + # Generate query mask based on predicted peak number during inference if not self.training: pred_peak_num = pred_number.argmax(axis=1) peak_position = [ @@ -213,18 +213,18 @@ def forward(self, peak_position = paddle.to_tensor(peak_position) query_mask = get_key_padding_mask(peak_position) - # 4. Transformer编码 + # 4. Transformer encoding encoder_input = paddle.concat([node_feat, query_feat], axis=1) encoder_padding_mask = paddle.concat([padding_mask, query_mask], axis=1) encoder_output = self.tf_encoder(encoder_input, fix_mask_for_paddle(encoder_padding_mask)) - # 5. 峰位置和符号预测 + # 5. Peak position and sign prediction query_output = encoder_output[:, node_feat.shape[1]:, :] pred_position = self.pred_position_layer(query_output) pred_height = self.pred_height_layer(query_output) - # 6. 注意力权重(用于可视化) + # 6. Attention weights (for visualization) node_feat_output = encoder_output[:, :node_feat.shape[1], :] attn_weights = paddle.einsum("bid,bjd->bij", node_feat_output, diff --git a/ppmat/utils/colored_tqdm.py b/ppmat/utils/colored_tqdm.py index e108256a..f4762c4f 100644 --- a/ppmat/utils/colored_tqdm.py +++ b/ppmat/utils/colored_tqdm.py @@ -1,45 +1,45 @@ from tqdm import tqdm import time -import os;os.system("") #兼容windows +import os;os.system("") # Compatible with Windows def hex_to_ansi(hex_color: str, background: bool = False) -> str: """ - 将十六进制颜色转换为ANSI转义序列 + Convert hexadecimal color to ANSI escape sequence Args: - hex_color: 十六进制颜色,如 '#dda0a0' 或 'dda0a0' - background: True表示背景色,False表示前景色 + hex_color: Hexadecimal color, e.g., '#dda0a0' or 'dda0a0' + background: True for background color, False for foreground color Returns: - ANSI转义序列字符串,如 '\033[38;2;221;160;160m' + ANSI escape sequence string, e.g., '\033[38;2;221;160;160m' Example: >>> print(f"{hex_to_ansi('#dda0a0')}Hello{hex_to_ansi('#000000')} World") - >>> 
print(f"{hex_to_ansi('dda0a0', background=True)}背景色{hex_to_ansi.reset()}") + >>> print(f"{hex_to_ansi('dda0a0', background=True)}Background color{hex_to_ansi.reset()}") """ - # 移除#号并转换为小写 + # Remove # symbol and convert to lowercase hex_color = hex_color.lower().lstrip('#') - # 处理简写形式 (#fff -> ffffff) + # Handle shorthand form (#fff -> ffffff) if len(hex_color) == 3: hex_color = ''.join([c * 2 for c in hex_color]) - # 转换为RGB值 + # Convert to RGB values r = int(hex_color[0:2], 16) g = int(hex_color[2:4], 16) b = int(hex_color[4:6], 16) - # ANSI真彩色序列 - # 38;2;R;G;B 为前景色,48;2;R;G;B 为背景色 + # ANSI true color sequence + # 38;2;R;G;B for foreground, 48;2;R;G;B for background code = 48 if background else 38 return f'\033[{code};2;{r};{g};{b}m' def rgb_to_ansi(r: int, g: int, b: int, background: bool = False) -> str: - """RGB值直接转ANSI""" + """Convert RGB values directly to ANSI""" code = 48 if background else 38 return f'\033[{code};2;{r};{g};{b}m' -# 重置颜色的ANSI码 +# ANSI code to reset color hex_to_ansi.reset = '\033[0m' class ColoredTqdm(tqdm): @@ -74,7 +74,7 @@ def update(self, n=1): if __name__ == "__main__": - # 使用示例 - for i in ColoredTqdm(range(10), desc="🌈 彩虹渐变", leave = False): - for j in ColoredTqdm(range(100), desc="🌈 彩虹渐变", leave = False): + # Usage example + for i in ColoredTqdm(range(10), desc="🌈 Rainbow gradient", leave = False): + for j in ColoredTqdm(range(100), desc="🌈 Rainbow gradient", leave = False): time.sleep(0.01) \ No newline at end of file diff --git a/ppmat/utils/graph_utils.py b/ppmat/utils/graph_utils.py index b012340b..394328d7 100644 --- a/ppmat/utils/graph_utils.py +++ b/ppmat/utils/graph_utils.py @@ -15,7 +15,7 @@ import paddle def index_transform(raw_index, batch_size): - """将压缩的批次索引还原为每个样本的节点索引列表""" + """Convert compressed batch indices to a list of node indices for each sample""" def get_index1(lst=None, batch_num=-1): return [index for (index, value) in enumerate(lst) if value == batch_num] @@ -28,14 +28,14 @@ def get_index1(lst=None, 
batch_num=-1): def get_key_padding_mask(tokens): - """生成key padding mask""" + """Generate key padding mask""" key_padding_mask = paddle.zeros(tokens.shape) key_padding_mask[tokens == -1] = -paddle.inf return key_padding_mask def feat_padding_mask(index, max_node_num): - """根据节点索引生成特征padding mask""" + """Generate feature padding mask based on node indices""" new_index = [] for itm_list in index: new_index.append(itm_list + [-1] * (max_node_num - len(itm_list))) @@ -44,7 +44,7 @@ def feat_padding_mask(index, max_node_num): def pad_node_features(molecule_features, batch_index, this_batch_size, max_node_num, emb_dim): - """将压缩的节点特征padding为 [batch, max_node, emb_dim] 格式""" + """Pad compressed node features to [batch, max_node, emb_dim] format""" index_list = index_transform(batch_index, this_batch_size) new_batch_list = [] diff --git a/ppmat/utils/place_env.py b/ppmat/utils/place_env.py index 347e3e07..83274396 100644 --- a/ppmat/utils/place_env.py +++ b/ppmat/utils/place_env.py @@ -4,106 +4,106 @@ class PlaceEnv: """ - 类版本的上下文管理器,也支持装饰器功能 + Class version of context manager, also supports decorator functionality """ def __init__(self, place: PlaceLike): """ - 初始化PlaceEnv + Initialize PlaceEnv Args: - place: paddle.CPUPlace() 或 paddle.CUDAPlace(0) 等设备对象 + place: device objects like paddle.CPUPlace() or paddle.CUDAPlace(0) """ self.place = place self.original_device = None def __enter__(self): - """进入上下文时调用""" - # 保存当前的设备设置 + """Called when entering the context""" + # Save current device setting self.original_device = paddle.get_device() paddle.set_device(self.place) return self def __exit__(self, exc_type, exc_val, exc_tb): - """退出上下文时调用""" - # 恢复原来的设备设置 + """Called when exiting the context""" + # Restore original device setting if self.original_device is not None: paddle.set_device(self.original_device) def __call__(self, func): """ - 使实例可以作为装饰器使用 + Allows instance to be used as a decorator Args: - func: 要装饰的函数 + func: function to be decorated Returns: - 装饰后的函数 + 
Decorated function """ @functools.wraps(func) def wrapper(*args, **kwargs): - # 使用with语句来临时改变设备设置 + # Use with statement to temporarily change device setting with self: return func(*args, **kwargs) return wrapper -# 使用示例 +# Usage example if __name__ == "__main__": - # 测试with语句 - print("=== 测试with语句 ===") - print(f"当前设备: {paddle.get_device()}") + # Test with statement + print("=== Testing with statement ===") + print(f"Current device: {paddle.get_device()}") - print("\n=== 测试类版本with语句 ===") + print("\n=== Testing class version with statement ===") with PlaceEnv(paddle.CPUPlace()): - print(f"with块内设备: {paddle.get_device()}") + print(f"Device inside with block: {paddle.get_device()}") y = paddle.ones([2, 3]) - print(f"创建的张量: {y}") + print(f"Created tensor: {y}") - print(f"with块外设备: {paddle.get_device()}") + print(f"Device outside with block: {paddle.get_device()}") - # 测试装饰器功能 - print("\n=== 测试装饰器功能 ===") + # Test decorator functionality + print("\n=== Testing decorator functionality ===") @PlaceEnv(paddle.CPUPlace()) def cpu_function(): - """这个函数会在CPU上运行""" - print(f"函数内设备: {paddle.get_device()}") + """This function will run on CPU""" + print(f"Device inside function: {paddle.get_device()}") return paddle.rand([2, 2]) - # 检查是否有GPU可用 + # Check if GPU is available if paddle.device.cuda.device_count() > 0: @PlaceEnv(paddle.CUDAPlace(0)) def gpu_function(): - """这个函数会在GPU上运行""" - print(f"函数内设备: {paddle.get_device()}") + """This function will run on GPU""" + print(f"Device inside function: {paddle.get_device()}") return paddle.rand([2, 2]) - # 调用装饰后的函数 - print("调用cpu_function:") + # Call decorated functions + print("Calling cpu_function:") result_cpu = cpu_function() - print(f"函数执行后设备: {paddle.get_device()}") - print(f"结果: {result_cpu}") + print(f"Device after function execution: {paddle.get_device()}") + print(f"Result: {result_cpu}") if paddle.device.cuda.device_count() > 0: - print("\n调用gpu_function:") + print("\nCalling gpu_function:") result_gpu = gpu_function() - 
print(f"函数执行后设备: {paddle.get_device()}") - print(f"结果: {result_gpu}") + print(f"Device after function execution: {paddle.get_device()}") + print(f"Result: {result_gpu}") - print("\n=== 测试多层嵌套 ===") - print(f"初始设备: {paddle.get_device()}") + print("\n=== Testing multiple nesting ===") + print(f"Initial device: {paddle.get_device()}") with PlaceEnv(paddle.CPUPlace()): - print(f"第一层with内设备: {paddle.get_device()}") + print(f"Device inside first with block: {paddle.get_device()}") if paddle.device.cuda.device_count() > 0: with PlaceEnv(paddle.CUDAPlace(0)): - print(f"第二层with内设备: {paddle.get_device()}") + print(f"Device inside second with block: {paddle.get_device()}") z = paddle.rand([2, 2]) - print(f"创建的张量: {z}") + print(f"Created tensor: {z}") - print(f"回到第一层with设备: {paddle.get_device()}") + print(f"Back to first with block device: {paddle.get_device()}") - print(f"最终设备: {paddle.get_device()}") \ No newline at end of file + print(f"Final device: {paddle.get_device()}") \ No newline at end of file diff --git a/spectrum_elucidation/ecformer/train.py b/spectrum_elucidation/ecformer/train.py index 56b1db88..5df1a677 100644 --- a/spectrum_elucidation/ecformer/train.py +++ b/spectrum_elucidation/ecformer/train.py @@ -30,7 +30,7 @@ def main(): - # 解析参数 + # Parse arguments parser = argparse.ArgumentParser(description="ECDFormer for ECD Spectrum Prediction") parser.add_argument( "-c", "--config", @@ -63,12 +63,12 @@ def main(): args, dynamic_args = parser.parse_known_args() - # 加载配置 + # Load configuration config = OmegaConf.load(args.config) cli_config = OmegaConf.from_dotlist(dynamic_args) config = OmegaConf.merge(config, cli_config) - # 根据命令行参数覆盖Global配置 + # Override Global configuration based on command line arguments if args.eval_only: config.Global.do_train = False config.Global.do_eval = True @@ -84,30 +84,30 @@ def main(): config.Global.do_predict = True config.Dataset.predict.data_path = args.predict - # 保存配置 + # Save configuration if dist.get_rank() == 0: 
os.makedirs(config.Trainer.output_dir, exist_ok=True) config_name = os.path.basename(args.config) OmegaConf.save(config, osp.join(config.Trainer.output_dir, config_name)) - # 转换为字典 + # Convert to dictionary config = OmegaConf.to_container(config, resolve=True) - # 初始化日志 + # Initialize logging logger_path = osp.join(config["Trainer"]["output_dir"], "run.log") logger.init_logger(log_file=logger_path) logger.info(f"Logger saved to {logger_path}") logger.info(f"Config: {config}") - # 设置随机种子 + # Set random seed seed = config["Trainer"].get("seed", 42) misc.set_random_seed(seed) logger.info(f"Set random seed to {seed}") - # 设置信号处理 + # Set signal handlers set_signal_handlers() - # 构建数据加载器 + # Build data loaders dataloaders = {} if config["Global"].get("do_train", True): @@ -136,18 +136,18 @@ def main(): dataloaders["predict"] = build_dataloader(predict_cfg) logger.info(f"Prediction dataset loaded, size: {len(dataloaders['predict'].dataset)}") - # 构建模型 + # Build model model_cfg = config["Model"] model = build_model(model_cfg) logger.info(f"Model built: {model_cfg['__class_name__']}") - # 打印模型参数量 + # Print model parameters count total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if not p.stop_gradient) logger.info(f"Total parameters: {total_params / 1e6:.2f}M") logger.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M") - # 构建优化器和学习率调度器 + # Build optimizer and learning rate scheduler optimizer = None lr_scheduler = None @@ -163,7 +163,7 @@ def main(): ) logger.info(f"Optimizer built: {config['Optimizer']['__class_name__']}") - # 构建训练器 + # Build trainer trainer = ECDFormerTrainer( config=config["Trainer"], model=model, @@ -173,7 +173,7 @@ def main(): lr_scheduler=lr_scheduler, ) - # 恢复检查点 + # Resume from checkpoint if args.resume is not None: logger.info(f"Resuming from checkpoint: {args.resume}") save_load.load_checkpoint( @@ -183,7 +183,7 @@ def main(): trainer.scaler, ) - # 执行训练/评估/预测 + # Execute 
training/evaluation/prediction if config["Global"].get("do_train", True): logger.info("Starting training...") trainer.train() @@ -193,7 +193,7 @@ def main(): if "val" in dataloaders: time_info, loss_info, metric_info = trainer.eval(dataloaders["val"]) - # 打印详细指标 + # Print detailed metrics msg = "Validation Results:" for key, meter in metric_info.items(): msg += f" | {key}: {meter.avg:.6f}" @@ -218,7 +218,7 @@ def main(): if "predict" in dataloaders: results = trainer.predict(dataloaders["predict"]) - # 保存预测结果 + # Save prediction results import json output_path = osp.join(config["Trainer"]["output_dir"], "predictions.json") with open(output_path, "w") as f: diff --git a/spectrum_elucidation/ecformer/trainer.py b/spectrum_elucidation/ecformer/trainer.py index 4291375d..19a9845b 100644 --- a/spectrum_elucidation/ecformer/trainer.py +++ b/spectrum_elucidation/ecformer/trainer.py @@ -16,7 +16,7 @@ import time from collections import OrderedDict -from typing import Dict, Optional, List, Any, Union +from typing import Dict, Optional, Any import numpy as np import paddle From 4a8c2730d971c4eaaf1c6aa630aaa4a8eb5e30a2 Mon Sep 17 00:00:00 2001 From: PlumBlossomMaid <1589524335@qq.com> Date: Mon, 23 Mar 2026 20:15:19 +0800 Subject: [PATCH 16/16] Optimize training algorithm --- ppmat/trainer/__init__.py | 3 +- .../trainer/ecformer_trainer.py | 132 +++++++++++++----- spectrum_elucidation/ecformer/train.py | 11 +- 3 files changed, 107 insertions(+), 39 deletions(-) rename spectrum_elucidation/ecformer/trainer.py => ppmat/trainer/ecformer_trainer.py (81%) diff --git a/ppmat/trainer/__init__.py b/ppmat/trainer/__init__.py index 41a848b7..49b7793a 100644 --- a/ppmat/trainer/__init__.py +++ b/ppmat/trainer/__init__.py @@ -1,6 +1,7 @@ from ppmat.trainer.base_trainer import BaseTrainer +from ppmat.trainer.ecformer_trainer import ECFormerTrainer -__all__ = ["BaseTrainer", "build_trainer"] +__all__ = ["BaseTrainer", "build_trainer", "ECFormerTrainer"] def build_trainer(cfg, **kwargs): 
diff --git a/spectrum_elucidation/ecformer/trainer.py b/ppmat/trainer/ecformer_trainer.py similarity index 81% rename from spectrum_elucidation/ecformer/trainer.py rename to ppmat/trainer/ecformer_trainer.py index 19a9845b..142a006c 100644 --- a/spectrum_elucidation/ecformer/trainer.py +++ b/ppmat/trainer/ecformer_trainer.py @@ -16,13 +16,14 @@ import time from collections import OrderedDict -from typing import Dict, Optional, Any +from typing import Dict, Optional, List, Any, Union import numpy as np import paddle from paddle import nn from paddle import optimizer as optim from paddle.distributed import fleet +from tqdm import tqdm from ppmat.trainer.base_trainer import BaseTrainer from ppmat.utils import logger @@ -34,9 +35,9 @@ from ppmat.losses.ir_loss import IRLoss -class ECDFormerTrainer(BaseTrainer): +class ECFormerTrainer(BaseTrainer): """ - ECDFormer trainer supporting both ECD and IR tasks with dedicated metrics. + ECFormer trainer supporting both ECD and IR tasks with dedicated metrics. 
Features: - Automatic task detection from model class name @@ -100,6 +101,9 @@ def __init__( # Initialize task-specific metrics (will be attached via attach_metrics) self.train_metrics = None self.eval_metrics = None + + # Cache for dataset building to avoid repeated decompression + self._dataset_cache = {} def attach_metrics(self, metric_cfg=None, **runtime_objs): """ @@ -148,6 +152,19 @@ def train_epoch(self, dataloader: paddle.io.DataLoader): self.state.max_steps_in_train_epoch = len(dataloader) self.state.step_in_train_epoch = 0 + # Determine if this is the main process for progress bar + is_main_process = paddle.distributed.get_rank() == 0 if paddle.distributed.is_initialized() else True + + # Create progress bar only on main process + if is_main_process: + pbar = tqdm( + total=len(dataloader), + desc=f"Epoch {self.state.epoch}/{self.max_epochs}", + unit="batch", + ncols=100, + leave=True + ) + # Timers reader_tic = time.perf_counter() batch_tic = time.perf_counter() @@ -213,10 +230,31 @@ def train_epoch(self, dataloader: paddle.io.DataLoader): if self.lr_scheduler is not None and not self.lr_scheduler.by_epoch: self.lr_scheduler.step() - # Logging - if (self.state.step_in_train_epoch % self.log_freq == 0 or - self.state.step_in_train_epoch == self.state.max_steps_in_train_epoch): + # Update progress bar on main process + if is_main_process: + # Prepare current metrics for display + current_metrics = {} + current_metrics["lr"] = f"{self.optimizer.get_lr():.2e}" + for name, meter in loss_info.items(): + # Show only the most important metrics + if "loss" == name.lower() or "acc" in name.lower(): + current_metrics[name] = f"{meter.val:.4f}" + # Add streaming metrics if available (only show a few key metrics to avoid clutter) + stream_metrics = self._compute_streaming_metrics(stage='train') + for name, value in stream_metrics.items(): + if isinstance(value, (int, float)): + # Show only the most important metrics + if "loss" == name.lower() or "acc" in 
name.lower(): + short_name = name.split('/')[-1] if '/' in name else name + current_metrics[short_name] = f"{value:.4f}" + + # Update progress bar postfix + pbar.set_postfix(current_metrics, refresh=False) + pbar.update(1) + + # Write to visualization tools (every log_freq steps) + if self.state.step_in_train_epoch % self.log_freq == 0: logs = OrderedDict() logs["lr"] = self.optimizer.get_lr() for name, meter in time_info.items(): @@ -224,7 +262,7 @@ def train_epoch(self, dataloader: paddle.io.DataLoader): for name, meter in loss_info.items(): logs[name] = meter.val - # Add streaming metrics if available + # Add streaming metrics stream_metrics = self._compute_streaming_metrics(stage='train') for name, value in stream_metrics.items(): if isinstance(value, (int, float)): @@ -233,15 +271,7 @@ def train_epoch(self, dataloader: paddle.io.DataLoader): metric_info[name] = AverageMeter(name) metric_info[name].update(float(value), 1) - display_logs = self._filter_out_dict(logs, stage="train") - - msg = f"Train: Epoch [{self.state.epoch}/{self.max_epochs}]" - msg += f" | Step: [{self.state.step_in_train_epoch}/{self.state.max_steps_in_train_epoch}]" - for key, val in display_logs.items(): - msg += f" | {key}: {val:.6f}" - logger.info(msg) - - # Write to visualization tools + # Write to visualization tools (not to console) logger.scalar( tag="train(step)", metric_dict=logs, @@ -254,6 +284,10 @@ def train_epoch(self, dataloader: paddle.io.DataLoader): batch_tic = time.perf_counter() reader_tic = time.perf_counter() + # Close progress bar + if is_main_process: + pbar.close() + # Compute epoch-level streaming metrics epoch_stream_metrics = self._compute_streaming_metrics(stage='train') for name, value in epoch_stream_metrics.items(): @@ -262,6 +296,9 @@ def train_epoch(self, dataloader: paddle.io.DataLoader): metric_info[name] = AverageMeter(name) metric_info[name].update(float(value), 1) + # Log epoch summary to file (not to console) + logger.info(f"Epoch {self.state.epoch} 
completed. Avg Loss: {loss_info.get('loss', AverageMeter('loss')).avg:.4f}") + return time_info, loss_info, metric_info def eval_epoch(self, dataloader: paddle.io.DataLoader): @@ -291,6 +328,19 @@ def eval_epoch(self, dataloader: paddle.io.DataLoader): if hasattr(m, 'reset'): m.reset() + # Determine if this is the main process for progress bar + is_main_process = paddle.distributed.get_rank() == 0 if paddle.distributed.is_initialized() else True + + # Create progress bar only on main process + if is_main_process: + pbar = tqdm( + total=len(dataloader), + desc=f"Eval Epoch {self.state.epoch}/{self.max_epochs}", + unit="batch", + ncols=80, + leave=False + ) + reader_tic = time.perf_counter() batch_tic = time.perf_counter() @@ -342,27 +392,23 @@ def eval_epoch(self, dataloader: paddle.io.DataLoader): self.state.step_in_eval_epoch += 1 - # Logging - if (self.state.step_in_eval_epoch % self.log_freq == 0 or - self.state.step_in_eval_epoch == self.state.max_steps_in_eval_epoch): - - logs = OrderedDict() - for name, meter in time_info.items(): - logs[name] = meter.val + # Update progress bar on main process + if is_main_process: + current_metrics = {} for name, meter in loss_info.items(): - logs[name] = meter.val - - display_logs = self._filter_out_dict(logs, stage="eval") - - msg = f"Eval: Epoch [{self.state.epoch}/{self.max_epochs}]" - msg += f" | Step: [{self.state.step_in_eval_epoch}/{self.state.max_steps_in_eval_epoch}]" - for key, val in display_logs.items(): - msg += f" | {key}: {val:.6f}" - logger.info(msg) + # Show only the most important metrics + if "loss" == name.lower() or "acc" in name.lower(): + current_metrics[name] = f"{meter.val:.4f}" + pbar.set_postfix(current_metrics, refresh=False) + pbar.update(1) batch_tic = time.perf_counter() reader_tic = time.perf_counter() + # Close progress bar + if is_main_process: + pbar.close() + # Compute epoch-level metrics from streaming accumulators epoch_metrics = self._compute_streaming_metrics(stage='eval') for name, 
value in epoch_metrics.items(): @@ -371,6 +417,9 @@ def eval_epoch(self, dataloader: paddle.io.DataLoader): metric_info[name] = AverageMeter(name) metric_info[name].update(float(value), len(dataloader.dataset)) + # Log evaluation summary to file (not to console) + logger.info(f"Eval Epoch {self.state.epoch} completed. Avg Loss: {loss_info.get('loss', AverageMeter('loss')).avg:.4f}") + return time_info, loss_info, metric_info def predict(self, dataloader: paddle.io.DataLoader) -> Dict[str, Any]: @@ -390,8 +439,21 @@ def predict(self, dataloader: paddle.io.DataLoader) -> Dict[str, Any]: all_attn_weights = [] all_peak_nums = [] + # Determine if this is the main process for progress bar + is_main_process = paddle.distributed.get_rank() == 0 if paddle.distributed.is_initialized() else True + + # Create progress bar only on main process + if is_main_process: + pbar = tqdm( + total=len(dataloader), + desc="Predicting", + unit="batch", + ncols=80, + leave=True + ) + with paddle.no_grad(): - for batch in dataloader: + for batch in pbar if is_main_process else dataloader: model_inputs, _ = batch # No targets needed for inference predictions = self.model( @@ -438,6 +500,10 @@ def predict(self, dataloader: paddle.io.DataLoader) -> Dict[str, Any]: 'mask': predictions['attention']['mask'][i] if predictions['attention']['mask'] else None }) + # Close progress bar + if is_main_process: + pbar.close() + return { 'peak_number': all_peak_nums, 'peak_position': all_pos_pred, diff --git a/spectrum_elucidation/ecformer/train.py b/spectrum_elucidation/ecformer/train.py index 5df1a677..58b332ea 100644 --- a/spectrum_elucidation/ecformer/train.py +++ b/spectrum_elucidation/ecformer/train.py @@ -25,8 +25,9 @@ from ppmat.optimizer import build_optimizer from ppmat.utils import logger from ppmat.utils import misc +from ppmat.utils import save_load -from spectrum_elucidation.ecformer.trainer import ECDFormerTrainer +from ppmat.trainer import ECFormerTrainer def main(): @@ -35,7 +36,7 @@ def 
main(): parser.add_argument( "-c", "--config", type=str, - default="./spectrum_elucidation/ecformer/configs/ecd.yaml", + default="../configs/ecformer/ecd.yaml", help="Path to config file", ) parser.add_argument( @@ -141,7 +142,7 @@ def main(): model = build_model(model_cfg) logger.info(f"Model built: {model_cfg['__class_name__']}") - # Print model parameters count + # Print model parameter count total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if not p.stop_gradient) logger.info(f"Total parameters: {total_params / 1e6:.2f}M") @@ -164,7 +165,7 @@ def main(): logger.info(f"Optimizer built: {config['Optimizer']['__class_name__']}") # Build trainer - trainer = ECDFormerTrainer( + trainer = ECFormerTrainer( config=config["Trainer"], model=model, train_dataloader=dataloaders.get("train"), @@ -183,7 +184,7 @@ def main(): trainer.scaler, ) - # Execute training/evaluation/prediction + # Execute training / evaluation / prediction if config["Global"].get("do_train", True): logger.info("Starting training...") trainer.train()