Skip to content

Commit d20942c

Browse files
authored
Merge pull request #9 from vanderhe/mixedTargetsFnetout
Support for data extraction when trained on mixed target types
2 parents 1b0e991 + 5cb20a3 commit d20942c

10 files changed

+496
-99
lines changed

src/fortformat/fnetout.py

Lines changed: 148 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,8 @@ def __init__(self, fname):
4040
raise FnetoutError('Invalid running mode specification.')
4141

4242
output = fnetoutfile['fnetout']['output']
43+
44+
# read number of datapoints
4345
self._ndatapoints = output.attrs.get('ndatapoints')
4446
if len(self._ndatapoints) == 1:
4547
# number of datapoints stored in array of size 1
@@ -48,6 +50,28 @@ def __init__(self, fname):
4850
msg = "Error while reading fnetout file '" + self._fname + \
4951
"'. Unrecognized number of datapoints obtained."
5052
raise FnetoutError(msg)
53+
54+
# read number of system-wide targets
55+
self._nglobaltargets = output.attrs.get('nglobaltargets')
56+
if len(self._nglobaltargets) == 1:
57+
# number of system-wide targets stored in array of size 1
58+
self._nglobaltargets = self._nglobaltargets[0]
59+
else:
60+
msg = "Error while reading fnetout file '" + self._fname + \
61+
"'. Unrecognized number of global targets obtained."
62+
raise FnetoutError(msg)
63+
64+
# read number of atomic targets
65+
self._natomictargets = output.attrs.get('natomictargets')
66+
if len(self._natomictargets) == 1:
67+
# number of atomic targets stored in array of size 1
68+
self._natomictargets = self._natomictargets[0]
69+
else:
70+
msg = "Error while reading fnetout file '" + self._fname + \
71+
"'. Unrecognized number of atomic targets obtained."
72+
raise FnetoutError(msg)
73+
74+
# read force specification
5175
self._tforces = output.attrs.get('tforces')
5276
# account for legacy files where no force entry is present
5377
if self._tforces is None:
@@ -60,18 +84,6 @@ def __init__(self, fname):
6084
"'. Unrecognized force specification obtained."
6185
raise FnetoutError(msg)
6286

63-
self._targettype = \
64-
output.attrs.get('targettype').decode('UTF-8').strip()
65-
if not self._targettype in ('atomic', 'global'):
66-
raise FnetoutError('Invalid running mode obtained.')
67-
68-
# get number of atomic or global predictions/targets
69-
self._npredictions = np.shape(
70-
np.array(output['datapoint1']['output']))[1]
71-
72-
if self._mode == 'validate':
73-
self._npredictions = int(self._npredictions / 2)
74-
7587

7688
@property
7789
def mode(self):
@@ -100,16 +112,29 @@ def ndatapoints(self):
100112

101113

102114
@property
103-
def targettype(self):
104-
'''Defines property, providing the target type.
115+
def nglobaltargets(self):
116+
'''Defines property, providing the number of system-wide targets.
117+
118+
Returns:
119+
120+
nglobaltargets (int): number of global targets per datapoint
121+
122+
'''
123+
124+
return self._nglobaltargets
125+
126+
127+
@property
128+
def natomictargets(self):
129+
'''Defines property, providing the number of atomic targets.
105130
106131
Returns:
107132
108-
targettype (str): type of targets the network was trained on
133+
natomictargets (int): number of atomic targets per datapoint
109134
110135
'''
111136

112-
return self._targettype
137+
return self._natomictargets
113138

114139

115140
@property
@@ -126,75 +151,134 @@ def tforces(self):
126151

127152

128153
@property
129-
def predictions(self):
130-
'''Defines property, providing the predictions of Fortnet.
154+
def globalpredictions(self):
155+
'''Defines property, providing the system-wide predictions of Fortnet.
131156
132157
Returns:
133158
134-
predictions (list or 2darray): predictions of the network
159+
predictions (2darray): predictions of the network
135160
136161
'''
137162

163+
if not self._nglobaltargets > 0:
164+
return None
165+
166+
predictions = np.empty((self._ndatapoints, self._nglobaltargets),
167+
dtype=float)
168+
138169
with h5py.File(self._fname, 'r') as fnetoutfile:
139170
output = fnetoutfile['fnetout']['output']
140-
if self._targettype == 'atomic':
141-
predictions = []
142-
for idata in range(self._ndatapoints):
143-
dataname = 'datapoint' + str(idata + 1)
144-
if self._mode == 'validate':
145-
predictions.append(
146-
np.array(output[dataname]['output'],
147-
dtype=float)[:, :self._npredictions])
148-
else:
149-
predictions.append(
150-
np.array(output[dataname]['output'], dtype=float))
151-
else:
152-
predictions = np.empty(
153-
(self._ndatapoints, self._npredictions), dtype=float)
154-
for idata in range(self._ndatapoints):
155-
dataname = 'datapoint' + str(idata + 1)
156-
if self._mode == 'validate':
157-
predictions[idata, :] = \
158-
np.array(output[dataname]['output'],
159-
dtype=float)[0, :self._npredictions]
160-
else:
161-
predictions[idata, :] = \
162-
np.array(output[dataname]['output'],
163-
dtype=float)[0, :]
171+
for idata in range(self._ndatapoints):
172+
dataname = 'datapoint' + str(idata + 1)
173+
predictions[idata, :] = np.array(
174+
output[dataname]['globalpredictions'],
175+
dtype=float)
164176

165177
return predictions
166178

167179

168180
@property
169-
def targets(self):
170-
'''Defines property, providing the targets during training.
181+
def globalpredictions_atomic(self):
182+
'''Defines property, providing the (atom-resolved) system-wide
183+
predictions of Fortnet.
171184
172185
Returns:
173186
174-
targets (list or 2darray): targets during training
187+
predictions (list): predictions of the network
175188
176189
'''
177190

178-
if self._mode == 'predict':
191+
if not self._nglobaltargets > 0:
179192
return None
180193

194+
predictions = []
195+
181196
with h5py.File(self._fname, 'r') as fnetoutfile:
182197
output = fnetoutfile['fnetout']['output']
183-
if self._targettype == 'atomic':
184-
targets = []
185-
for idata in range(self._ndatapoints):
186-
dataname = 'datapoint' + str(idata + 1)
187-
targets.append(
188-
np.array(output[dataname]['output'],
189-
dtype=float)[:, self._npredictions:])
190-
else:
191-
targets = np.empty(
192-
(self._ndatapoints, self._npredictions), dtype=float)
193-
for idata in range(self._ndatapoints):
194-
dataname = 'datapoint' + str(idata + 1)
195-
targets[idata, :] = \
196-
np.array(output[dataname]['output'],
197-
dtype=float)[0, self._npredictions:]
198+
for idata in range(self._ndatapoints):
199+
dataname = 'datapoint' + str(idata + 1)
200+
predictions.append(
201+
np.array(output[dataname]['rawpredictions'],
202+
dtype=float)[:, 0:self._nglobaltargets])
203+
204+
return predictions
205+
206+
207+
@property
208+
def atomicpredictions(self):
209+
'''Defines property, providing the atomic predictions of Fortnet.
210+
211+
Returns:
212+
213+
predictions (list): predictions of the network
214+
215+
'''
216+
217+
if not self._natomictargets > 0:
218+
return None
219+
220+
predictions = []
221+
222+
with h5py.File(self._fname, 'r') as fnetoutfile:
223+
output = fnetoutfile['fnetout']['output']
224+
for idata in range(self._ndatapoints):
225+
dataname = 'datapoint' + str(idata + 1)
226+
predictions.append(
227+
np.array(output[dataname]
228+
['rawpredictions'], dtype=float)
229+
[:, self._nglobaltargets:])
230+
231+
return predictions
232+
233+
234+
@property
235+
def globaltargets(self):
236+
'''Defines property, providing the system-wide targets during training.
237+
238+
Returns:
239+
240+
targets (2darray): system-wide targets during training
241+
242+
'''
243+
244+
if self._mode == 'predict' or self._nglobaltargets == 0:
245+
return None
246+
247+
targets = np.empty((self._ndatapoints, self._nglobaltargets),
248+
dtype=float)
249+
250+
with h5py.File(self._fname, 'r') as fnetoutfile:
251+
output = fnetoutfile['fnetout']['output']
252+
for idata in range(self._ndatapoints):
253+
dataname = 'datapoint' + str(idata + 1)
254+
targets[idata, :] = np.array(
255+
output[dataname]['globaltargets'],
256+
dtype=float)
257+
258+
return targets
259+
260+
261+
@property
262+
def atomictargets(self):
263+
'''Defines property, providing the atomic targets during training.
264+
265+
Returns:
266+
267+
targets (list): atomic targets during training
268+
269+
'''
270+
271+
if self._mode == 'predict' or self._natomictargets == 0:
272+
return None
273+
274+
targets = []
275+
276+
with h5py.File(self._fname, 'r') as fnetoutfile:
277+
output = fnetoutfile['fnetout']['output']
278+
for idata in range(self._ndatapoints):
279+
dataname = 'datapoint' + str(idata + 1)
280+
targets.append(np.array(output[dataname]
281+
['atomictargets'], dtype=float))
198282

199283
return targets
200284

@@ -214,10 +298,10 @@ def forces(self):
214298

215299
tmp1 = []
216300

217-
if self._targettype == 'atomic':
301+
if self._natomictargets > 0:
218302
msg = "Error while extracting forces from fnetout file '" \
219303
+ self._fname + \
220-
"'. Forces only supplied for global property targets."
304+
"'. Forces supplied for global property targets only."
221305
raise FnetoutError(msg)
222306

223307
with h5py.File(self._fname, 'r') as fnetoutfile:

test/common.py

Lines changed: 58 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -80,11 +80,15 @@ def compare_fnetout_references(ref, fname, atol=ATOL, rtol=RTOL):
8080

8181
mode = fnetout.mode
8282
ndatapoints = fnetout.ndatapoints
83-
targettype = fnetout.targettype
84-
targets = fnetout.targets
83+
nglobaltargets = fnetout.nglobaltargets
84+
natomictargets = fnetout.natomictargets
85+
globaltargets = fnetout.globaltargets
86+
atomictargets = fnetout.atomictargets
8587
tforces = fnetout.tforces
8688
forces = fnetout.forces
87-
predictions = fnetout.predictions
89+
atomicpredictions = fnetout.atomicpredictions
90+
globalpredictions = fnetout.globalpredictions
91+
globalpredictions_atomic = fnetout.globalpredictions_atomic
8892

8993
equal = mode == ref['mode']
9094

@@ -98,19 +102,35 @@ def compare_fnetout_references(ref, fname, atol=ATOL, rtol=RTOL):
98102
warnings.warn('Mismatch in number of training datapoints.')
99103
return False
100104

101-
equal = targettype == ref['targettype']
105+
equal = nglobaltargets == ref['nglobaltargets']
102106

103107
if not equal:
104-
warnings.warn('Mismatch in target type specification.')
108+
warnings.warn('Mismatch in number of system-wide targets.')
105109
return False
106110

107-
if ref['targets'] is not None:
108-
for ii, target in enumerate(targets):
109-
equal = np.allclose(target, ref['targets'][ii],
111+
equal = natomictargets == ref['natomictargets']
112+
113+
if not equal:
114+
warnings.warn('Mismatch in number of atomic targets.')
115+
return False
116+
117+
if ref['globaltargets'] is not None:
118+
for ii, target in enumerate(globaltargets):
119+
equal = np.allclose(target, ref['globaltargets'][ii],
110120
rtol=rtol, atol=atol)
111121

112122
if not equal:
113-
warnings.warn('Mismatch in targets of datapoint ' \
123+
warnings.warn('Mismatch in global targets of datapoint ' \
124+
+ str(ii + 1) + '.')
125+
return False
126+
127+
if ref['atomictargets'] is not None:
128+
for ii, target in enumerate(atomictargets):
129+
equal = np.allclose(target, ref['atomictargets'][ii],
130+
rtol=rtol, atol=atol)
131+
132+
if not equal:
133+
warnings.warn('Mismatch in atomic targets of datapoint ' \
114134
+ str(ii + 1) + '.')
115135
return False
116136

@@ -133,17 +153,39 @@ def compare_fnetout_references(ref, fname, atol=ATOL, rtol=RTOL):
133153
str(itarget + 1) + '.')
134154
return False
135155

136-
for ii, prediction in enumerate(predictions):
137-
equal = np.allclose(prediction, ref['predictions'][ii],
138-
rtol=rtol, atol=atol)
156+
if ref['atomicpredictions'] is not None:
157+
for ii, prediction in enumerate(atomicpredictions):
158+
equal = np.allclose(prediction, ref['atomicpredictions'][ii],
159+
rtol=rtol, atol=atol)
139160

140-
if not equal:
141-
warnings.warn('Mismatch in predictions of datapoint ' \
142-
+ str(ii + 1) + '.')
143-
return False
161+
if not equal:
162+
warnings.warn('Mismatch in atomic predictions of datapoint ' \
163+
+ str(ii + 1) + '.')
164+
return False
165+
166+
if ref['globalpredictions'] is not None:
167+
for ii, target in enumerate(globalpredictions):
168+
equal = np.allclose(target, ref['globalpredictions'][ii],
169+
rtol=rtol, atol=atol)
170+
171+
if not equal:
172+
warnings.warn('Mismatch in global predictions' \
173+
+ ' of datapoint ' + str(ii + 1) + '.')
174+
return False
175+
176+
if ref['globalpredictions_atomic'] is not None:
177+
for ii, target in enumerate(globalpredictions_atomic):
178+
equal = np.allclose(target, ref['globalpredictions_atomic'][ii],
179+
rtol=rtol, atol=atol)
180+
181+
if not equal:
182+
warnings.warn('Mismatch in (atom-resolved) global predictions' \
183+
+ ' of datapoint ' + str(ii + 1) + '.')
184+
return False
144185

145186
return True
146187

188+
147189
def get_mixed_geometries():
148190
'''Generates six geometries with(out) periodic boundary conditions.'''
149191

176 Bytes
Binary file not shown.
1.5 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
1.5 KB
Binary file not shown.
Binary file not shown.
2.83 KB
Binary file not shown.

0 commit comments

Comments
 (0)