Skip to content

Commit 4bc3b27

Browse files
authored
fix single precision error (#1212)
1 parent 7045ba5 commit 4bc3b27

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

deepmd/descriptor/se_a.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Tuple, List, Dict, Any
44

55
from deepmd.env import tf
6-
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, get_np_precision
6+
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
77
from deepmd.utils.argcheck import list_to_doc
88
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
99
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
@@ -13,7 +13,7 @@
1313
from deepmd.utils.tabulate import DPTabulate
1414
from deepmd.utils.type_embed import embed_atom_type
1515
from deepmd.utils.sess import run_sess
16-
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables
16+
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
1717
from .descriptor import Descriptor
1818
from .se import DescrptSe
1919

@@ -133,7 +133,6 @@ def __init__ (self,
133133
self.compress_activation_fn = get_activation_func(activation_function)
134134
self.filter_activation_fn = get_activation_func(activation_function)
135135
self.filter_precision = get_precision(precision)
136-
self.filter_np_precision = get_np_precision(precision)
137136
self.exclude_types = set()
138137
for tt in exclude_types:
139138
assert(len(tt) == 2)
@@ -687,7 +686,7 @@ def _filter_lower(
687686
net = 'filter_-1_net_' + str(type_i)
688687
else:
689688
net = 'filter_' + str(type_input) + '_net_' + str(type_i)
690-
return op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
689+
return op_module.tabulate_fusion(tf.cast(self.table.data[net], self.filter_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
691690
else:
692691
if (not is_exclude):
693692
xyz_scatter = embedding_net(

0 commit comments

Comments
 (0)