Skip to content

Commit 4a97e59

Browse files
committed
merge
1 parent f8a204c commit 4a97e59

2 files changed

Lines changed: 1 addition & 270 deletions

File tree

opennn/adaptive_moment_estimation.cpp

Lines changed: 0 additions & 269 deletions
Original file line numberDiff line numberDiff line change
@@ -111,275 +111,6 @@ void AdaptiveMomentEstimation::set_maximum_time(const type new_maximum_time)
111111
maximum_time = new_maximum_time;
112112
}
113113

114-
/*
115-
TrainingResults AdaptiveMomentEstimation::train()
116-
{
117-
if(!loss || !loss->has_neural_network() || !loss->has_dataset())
118-
return TrainingResults();
119-
120-
TrainingResults results(maximum_epochs + 1);
121-
122-
check();
123-
124-
if(display) cout << "Training with adaptive moment estimation \"Adam\" ..." << endl;
125-
126-
// Dataset
127-
128-
Dataset* dataset = loss->get_dataset();
129-
130-
if(!dataset)
131-
throw runtime_error("Dataset is null.");
132-
133-
const bool has_validation = dataset->has_validation();
134-
135-
const bool is_text_classification_model = is_instance_of<CrossEntropyError3d>(loss);
136-
137-
const vector<Index> input_feature_indices = dataset->get_feature_indices("Input");
138-
const vector<Index> target_feature_indices = dataset->get_feature_indices("Target");
139-
const vector<Index> decoder_feature_indices = dataset->get_feature_indices("Decoder");
140-
141-
const vector<Index> training_sample_indices = dataset->get_sample_indices("Training");
142-
const vector<Index> validation_sample_indices = dataset->get_sample_indices("Validation");
143-
144-
const Index training_samples_number = dataset->get_samples_number("Training");
145-
const Index validation_samples_number = dataset->get_samples_number("Validation");
146-
147-
const Index training_batch_size = min(training_samples_number, batch_size);
148-
149-
const Index validation_batch_size = (validation_samples_number != 0)
150-
? min(validation_samples_number, batch_size)
151-
: 0;
152-
153-
const Index training_batches_number = (training_batch_size != 0)
154-
? training_samples_number / training_batch_size
155-
: 0;
156-
157-
const Index validation_batches_number = (validation_batch_size != 0)
158-
? validation_samples_number / validation_batch_size
159-
: 0;
160-
161-
vector<vector<Index>> training_batches(training_batches_number);
162-
vector<vector<Index>> validation_batches(validation_batches_number);
163-
164-
// Neural network
165-
166-
NeuralNetwork* neural_network = loss->get_neural_network();
167-
168-
set_names();
169-
set_scaling();
170-
171-
Batch training_batch(training_batch_size, dataset);
172-
unique_ptr<Batch> validation_batch;
173-
174-
ForwardPropagation training_forward_propagation(training_batch_size, neural_network);
175-
unique_ptr<ForwardPropagation> validation_forward_propagation;
176-
177-
// Loss index
178-
179-
loss->set_normalization_coefficient();
180-
181-
BackPropagation training_back_propagation(training_batch_size, loss);
182-
unique_ptr<BackPropagation> validation_back_propagation;
183-
184-
if (has_validation)
185-
{
186-
validation_batch = make_unique<Batch>(validation_batch_size, dataset);
187-
validation_forward_propagation = make_unique<ForwardPropagation>(validation_batch_size, neural_network);
188-
validation_back_propagation = make_unique<BackPropagation>(validation_batch_size, loss);
189-
}
190-
191-
type training_error = type(0);
192-
type training_accuracy = type(0);
193-
194-
type validation_error = type(0);
195-
type validation_accuracy = type(0);
196-
197-
Index validation_failures = 0;
198-
199-
// Optimization algorithm
200-
201-
AdaptiveMomentEstimationData optimization_data(this);
202-
203-
bool stop_training = false;
204-
bool is_training = true;
205-
206-
time_t beginning_time;
207-
time(&beginning_time);
208-
209-
type elapsed_time = type(0);
210-
211-
bool shuffle = true;
212-
213-
if(neural_network->has("Recurrent"))
214-
shuffle = false;
215-
216-
// Main loop
217-
optimization_data.iteration = 1;
218-
219-
for(Index epoch = 0; epoch <= maximum_epochs; epoch++)
220-
{
221-
if(display && epoch%display_period == 0) cout << "Epoch: " << epoch << endl;
222-
223-
training_batches = dataset->get_batches(training_sample_indices, training_batch_size, shuffle);
224-
225-
training_error = type(0);
226-
227-
if(is_text_classification_model) training_accuracy = type(0);
228-
229-
for(Index iteration = 0; iteration < training_batches_number; iteration++)
230-
{
231-
training_back_propagation.neural_network.gradient.setZero();
232-
233-
// Dataset
234-
235-
training_batch.fill(training_batches[iteration],
236-
input_feature_indices,
237-
decoder_feature_indices,
238-
target_feature_indices);
239-
240-
// Neural network
241-
242-
neural_network->forward_propagate(training_batch.get_inputs(),
243-
training_forward_propagation,
244-
is_training);
245-
246-
// Loss index
247-
248-
loss->back_propagate(training_batch,
249-
training_forward_propagation,
250-
training_back_propagation);
251-
252-
training_error += training_back_propagation.error;
253-
254-
if(is_text_classification_model) training_accuracy += training_back_propagation.accuracy(0);
255-
256-
update_parameters(training_back_propagation, optimization_data);
257-
}
258-
259-
// Loss
260-
261-
training_error /= type(training_batches_number);
262-
if(is_text_classification_model)
263-
training_accuracy /= type(training_batches_number);
264-
265-
results.training_error_history(epoch) = training_error;
266-
267-
if(has_validation)
268-
{
269-
validation_batches = dataset->get_batches(validation_sample_indices, validation_batch_size, shuffle);
270-
271-
validation_error = type(0);
272-
273-
if(is_text_classification_model)
274-
validation_accuracy = type(0);
275-
276-
for(Index iteration = 0; iteration < validation_batches_number; iteration++)
277-
{
278-
// Dataset
279-
280-
validation_batch->fill(validation_batches[iteration],
281-
input_feature_indices,
282-
decoder_feature_indices,
283-
target_feature_indices);
284-
285-
// Neural network
286-
287-
neural_network->forward_propagate(validation_batch->get_inputs(),
288-
*validation_forward_propagation,
289-
is_training);
290-
291-
// Loss
292-
293-
loss->calculate_error(*validation_batch,
294-
*validation_forward_propagation,
295-
*validation_back_propagation);
296-
297-
validation_error += validation_back_propagation->error;
298-
299-
if(is_text_classification_model)
300-
validation_accuracy += validation_back_propagation->accuracy(0);
301-
}
302-
303-
validation_error /= type(validation_batches_number);
304-
if(is_text_classification_model) validation_accuracy /= type(validation_batches_number);
305-
306-
results.validation_error_history(epoch) = validation_error;
307-
308-
if(epoch != 0 && results.validation_error_history(epoch) > results.validation_error_history(epoch-1)) validation_failures++;
309-
}
310-
311-
// Elapsed time
312-
313-
elapsed_time = get_elapsed_time(beginning_time);
314-
315-
if(display && epoch%display_period == 0)
316-
{
317-
cout << "Training error: " << training_error << endl;
318-
if(is_text_classification_model) cout << "Training accuracy: " << training_accuracy << endl;
319-
if(has_validation) cout << "Validation error: " << validation_error << endl;
320-
if(has_validation && is_text_classification_model) cout << "Validation accuracy: " << validation_accuracy << endl;
321-
cout << "Elapsed time: " << write_time(elapsed_time) << endl;
322-
}
323-
324-
stop_training = true;
325-
326-
if(epoch == maximum_epochs)
327-
{
328-
if(display) cout << "Epoch " << epoch << "\nMaximum epochs number reached: " << epoch << endl;
329-
results.stopping_condition = StoppingCondition::MaximumEpochsNumber;
330-
}
331-
else if(elapsed_time >= maximum_time)
332-
{
333-
if(display) cout << "Epoch " << epoch << "\nMaximum training time reached: " << write_time(elapsed_time) << endl;
334-
results.stopping_condition = StoppingCondition::MaximumTime;
335-
}
336-
else if(results.training_error_history(epoch) < training_loss_goal)
337-
{
338-
results.stopping_condition = StoppingCondition::LossGoal;
339-
if(display) cout << "Epoch " << epoch << "\nLoss goal reached: " << results.training_error_history(epoch) << endl;
340-
}
341-
else if(training_accuracy >= training_accuracy_goal)
342-
{
343-
results.stopping_condition = StoppingCondition::LossGoal;
344-
if(display) cout << "Epoch " << epoch << "\nAccuracy goal reached: " << training_accuracy << endl;
345-
}
346-
else if(validation_failures >= maximum_validation_failures)
347-
{
348-
if(display) cout << "Epoch " << epoch << "\nMaximum selection failures reached: " << validation_failures << endl;
349-
results.stopping_condition = StoppingCondition::MaximumSelectionErrorIncreases;
350-
}
351-
else
352-
{
353-
stop_training = false;
354-
}
355-
356-
if(stop_training)
357-
{
358-
results.loss = training_back_propagation.loss_value;
359-
360-
results.validation_failures = validation_failures;
361-
362-
results.resize_training_error_history(epoch+1);
363-
364-
results.resize_validation_error_history(has_validation ? epoch + 1 : 0);
365-
366-
results.elapsed_time = write_time(elapsed_time);
367-
368-
break;
369-
}
370-
371-
if(epoch != 0 && epoch % save_period == 0) neural_network->save(neural_network_file_name);
372-
}
373-
374-
set_unscaling();
375-
376-
if(display) results.print();
377-
378-
return results;
379-
}
380-
381-
*/
382-
383114

384115
TrainingResults AdaptiveMomentEstimation::train()
385116
{

opennn/pch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
#include "../eigen/unsupported/Eigen/CXX11/Tensor"
5656
#include "../eigen/Eigen/src/Core/util/DisableStupidWarnings.h"
5757

58-
//#define OPENNN_CUDA // Comment this line to disable cuda files
58+
#define OPENNN_CUDA // Comment this line to disable cuda files
5959

6060
#ifdef OPENNN_CUDA
6161

0 commit comments

Comments (0)