Commit dbfd03a

fix process batch and restore batches
1 parent b25cc94 commit dbfd03a

1 file changed: +36 -9 lines changed

optimizer/mopso.py

@@ -368,6 +368,35 @@ def load_checkpoint(self):
                 individual_states[i][3*self.num_params:], dtype=float)
             )
             self.particles.append(particle)
+
+        #restore batches
+        self.particles_batch = []
+        if (self.num_batch == 1):
+            self.particles_batch.append(self.particles)
+            self.batch_size = len(self.particles)
+        else:
+            # Calculate the approximate batch size
+            self.batch_size = len(self.particles) // self.num_batch
+
+            # Check if the division leaves some elements unallocated
+            remaining_elements = len(self.particles) % self.batch_size
+
+            if remaining_elements > 0:
+                # Warn the user and suggest adjusting the number of particles or batches
+                warning_message = (
+                    f"{bcolors.WARNING}The specified number of batches ({self.num_batch}) does not evenly divide the number of particles ({len(self.particles)}).{bcolors.ENDC}"
+                )
+                warnings.warn(warning_message)
+
+            # Use list comprehension to create batches
+            self.particles_batch = [self.particles[i:i + self.batch_size]
+                                    for i in range(0, len(self.particles), self.batch_size)]
+
+            # If the division leaves some elements unallocated, add them to the last batch
+            if remaining_elements > 0:
+                last_batch = self.particles_batch.pop()
+                last_batch.extend(self.particles[len(self.particles_batch) * self.batch_size:])
+                self.particles_batch.append(last_batch)
 
         # restore pareto front
         self.pareto_front = []
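For reference (not part of the commit), a minimal standalone sketch of the batching intent restored above. The name split_into_batches is illustrative, not the module's API, and the sketch assumes len(particles) >= num_batch; unlike the diff, it simply folds the short tail slice into the previous batch.

import warnings

def split_into_batches(particles, num_batch):
    # Illustrative only: split a list of particles into num_batch batches.
    if num_batch == 1:
        return [list(particles)]
    # Approximate batch size via integer division (assumes len(particles) >= num_batch).
    batch_size = len(particles) // num_batch
    remaining = len(particles) % batch_size
    if remaining > 0:
        warnings.warn(f"{num_batch} batches do not evenly divide {len(particles)} particles")
    # Consecutive slices of batch_size particles each.
    batches = [particles[i:i + batch_size]
               for i in range(0, len(particles), batch_size)]
    # Fold a trailing short slice into the last full batch.
    if remaining > 0:
        last = batches.pop()
        batches[-1].extend(last)
    return batches

print(split_into_batches(list(range(10)), 3))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]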
@@ -381,18 +410,15 @@ def load_checkpoint(self):
                                  best_position=None,
                                  best_fitness=None)
             self.pareto_front.append(particle)
+
 
     def process_batch(self, worker_id, batch):
         # Launch a program for this batch using objective_function
         print(f"Worker ID {worker_id}")
         params = [particle.position for particle in batch]
         optimization_output = self.objective.evaluate(params, worker_id )
-        for p_id, output in enumerate(optimization_output[0]):
-            particle = batch[p_id]
-            if self.optimization_mode == 'individual':
-                particle.evaluate_fitness(self.objective_functions)
-            if self.optimization_mode == 'global':
-                particle.set_fitness(output)
+        for p_id, particle in enumerate(batch):
+            particle.set_fitness(optimization_output[:,p_id])
             batch[p_id] = particle
         return batch
 
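The rewritten loop indexes the evaluator output by column. Assuming objective.evaluate returns a 2-D array of shape (num_objectives, num_particles) (an assumption implied by the slicing, not stated in the diff), optimization_output[:, p_id] collects every objective value for one particle. A tiny illustration with made-up numbers:

import numpy as np

# Made-up evaluator output: rows are objectives, columns are particles.
optimization_output = np.array([[0.1, 0.2, 0.3],    # objective 1 for particles 0..2
                                [1.0, 2.0, 3.0]])   # objective 2 for particles 0..2

for p_id in range(optimization_output.shape[1]):
    fitness = optimization_output[:, p_id]   # all objective values for particle p_id
    print(p_id, fitness)
# 0 [0.1 1. ]
# 1 [0.2 2. ]
# 2 [0.3 3. ]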
@@ -413,23 +439,24 @@ def optimize(self, num_iterations = 100, max_iter_no_improv = None):
         for _ in range(num_iterations):
             with ProcessPoolExecutor(max_workers=self.num_batch) as executor:
                 futures = [executor.submit(self.process_batch, worker_id, batch)
-                           for worker_id, batch in enumerate(self.particles_batch)]
+                           for worker_id, batch in enumerate(self.particles_batch )]
 
             new_batches = []
             for future in futures:
                 batch = future.result()
                 new_batches.append(batch)
             self.particles_batch = new_batches
             save_particles = []
+            updated_particles = []
             for batch in self.particles_batch:
                 for particle in batch:
                     l = np.concatenate([particle.position, np.ravel(particle.fitness)])
-                    print(l)
                     save_particles.append(l)
+                    updated_particles.append(particle)
             FileManager.save_csv(save_particles,
                                  'history/iteration' + str(self.iteration) + '.csv')
 
-
+            self.particles = updated_particles
             self.update_pareto_front()
 
             for batch in self.particles_batch:
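For context (again outside the commit), the dispatch pattern in optimize reduces to one submitted task per batch, with results gathered in submission order so batch indices stay aligned. A self-contained toy version, where process_batch is a stand-in rather than the class method:

from concurrent.futures import ProcessPoolExecutor

def process_batch(worker_id, batch):
    # Stand-in for the real evaluation: square every value in the batch.
    return [x * x for x in batch]

if __name__ == "__main__":
    particles_batch = [[1, 2], [3, 4], [5, 6]]          # three toy batches
    with ProcessPoolExecutor(max_workers=len(particles_batch)) as executor:
        futures = [executor.submit(process_batch, worker_id, batch)
                   for worker_id, batch in enumerate(particles_batch)]
        # Collect results in submission order so batch indices stay aligned.
        new_batches = [future.result() for future in futures]
    print(new_batches)   # [[1, 4], [9, 16], [25, 36]]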
