Skip to content

Commit 9b5d6db

Browse files
HevagogdJaniga
andauthored
Optimization should stop after n iteration (#34)
Co-authored-by: Damian Janiga <[email protected]>
1 parent 00c8f04 commit 9b5d6db

File tree

8 files changed

+124
-62
lines changed

8 files changed

+124
-62
lines changed

src/orchestration/risk_management_service/core/mappers/control_vector_mapper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
from services.problem_dispatcher_service import ControlVector as PDControlVector
2-
from services.solution_updater_service import ControlVector as SUControlVector
1+
from services.problem_dispatcher_service import (
2+
ControlVector as PDControlVector,
3+
)
4+
from services.solution_updater_service import (
5+
ControlVector as SUControlVector,
6+
)
37

48

59
class ControlVectorMapper:

src/orchestration/risk_management_service/core/service/risk_management_service.py

Lines changed: 59 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,18 @@
33
from logger.u_logger import get_logger
44
from orchestration.risk_management_service.core.mappers import ControlVectorMapper
55
from services.problem_dispatcher_service import (
6-
ControlVector,
76
ProblemDispatcherService,
87
ServiceType,
98
)
109
from services.simulation_service import (
1110
SimulationService,
12-
simulation_cluster_contex_manager,
11+
simulation_cluster_context_manager,
1312
)
1413
from services.solution_updater_service import (
1514
OptimizationEngine,
1615
SolutionUpdaterService,
17-
ensure_not_none,
1816
)
17+
from services.solution_updater_service.core.utils import ensure_not_none
1918
from services.well_management_service import WellManagementService
2019

2120
logger = get_logger(__name__)
@@ -25,11 +24,13 @@ def run_risk_management(
2524
problem_definition: dict[str, Any],
2625
simulation_model_archive: bytes | str,
2726
n_size: int = 10,
27+
max_generations: int = 10,
2828
):
2929
"""
3030
Main entry point for running risk management.
3131
3232
Args:
33+
max_generations: Maximum number of generations for the optimization process.
3334
problem_definition (dict[str, Any]): The problem definition used by the dispatcher.
3435
simulation_model_archive (bytes | str): The simulation model archive to transfer.
3536
n_size (int, optional): Number of samples for the dispatcher. Defaults to 10.
@@ -42,7 +43,7 @@ def run_risk_management(
4243
n_size,
4344
)
4445

45-
with simulation_cluster_contex_manager():
46+
with simulation_cluster_context_manager():
4647
try:
4748
logger.info("Transferring simulation model archive to the cluster.")
4849
SimulationService.transfer_simulation_model(
@@ -53,7 +54,8 @@ def run_risk_management(
5354
"Initializing SolutionUpdaterService and ProblemDispatcherService."
5455
)
5556
solution_updater = SolutionUpdaterService(
56-
optimization_engine=OptimizationEngine.PSO
57+
optimization_engine=OptimizationEngine.PSO,
58+
max_generations=max_generations,
5759
)
5860
dispatcher = ProblemDispatcherService(
5961
problem_definition=problem_definition, n_size=n_size
@@ -64,56 +66,66 @@ def run_risk_management(
6466
logger.debug("Boundaries retrieved: %s", boundaries)
6567

6668
# Initialize solutions
67-
next_solutions: list[ControlVector] | None = None
68-
69-
for iteration in range(5):
70-
logger.info("Starting iteration %d for risk management.", iteration + 1)
69+
next_solutions = None
70+
71+
loop_controller = solution_updater.loop_controller
72+
try:
73+
while loop_controller.running():
74+
logger.info(
75+
"Starting generation %d for risk management.",
76+
loop_controller.current_generation,
77+
)
7178

72-
# Generate or update solutions
73-
solutions = dispatcher.process_iteration(next_solutions)
74-
logger.debug("Generated solutions: %s", solutions)
79+
# Generate or update solutions
80+
solutions = dispatcher.process_iteration(next_solutions)
81+
logger.debug("Generated solutions: %s", solutions)
7582

76-
# Prepare simulation cases
77-
sim_cases = _prepare_simulation_cases(solutions)
78-
logger.debug("Prepared simulation cases: %s", sim_cases)
83+
# Prepare simulation cases
84+
sim_cases = _prepare_simulation_cases(solutions)
85+
logger.debug("Prepared simulation cases: %s", sim_cases)
7986

80-
# Process simulation with the simulation service
81-
logger.info("Submitting simulation cases to SimulationService.")
82-
completed_cases = SimulationService.process_request(
83-
{"simulation_cases": sim_cases}
84-
)
85-
logger.debug("Completed simulation cases: %s", completed_cases)
86-
87-
# Update solutions based on simulation results
88-
updated_solutions = [
89-
{
90-
"control_vector": {"items": simulation_case.control_vector},
91-
"cost_function_results": {
92-
"values": ensure_not_none(
93-
simulation_case.results
94-
).model_dump()
95-
},
96-
}
97-
for simulation_case in completed_cases.simulation_cases
98-
]
99-
logger.debug(
100-
"Updated solutions for next iteration: %s", updated_solutions
101-
)
87+
# Process simulation with the simulation service
88+
logger.info("Submitting simulation cases to SimulationService.")
89+
completed_cases = SimulationService.process_request(
90+
{"simulation_cases": sim_cases}
91+
)
92+
logger.debug("Completed simulation cases: %s", completed_cases)
10293

103-
# Map simulation service solutions to the ProblemDispatcherService format
104-
next_solutions = ControlVectorMapper.convert_su_to_pd(
105-
solution_updater.process_request(
94+
# Update solutions based on simulation results
95+
updated_solutions = [
10696
{
107-
"solution_candidates": updated_solutions,
108-
"optimization_constraints": {"boundaries": boundaries},
97+
"control_vector": {"items": simulation_case.control_vector},
98+
"cost_function_results": {
99+
"values": ensure_not_none(
100+
simulation_case.results
101+
).model_dump()
102+
},
109103
}
110-
).next_iter_solutions
111-
)
104+
for simulation_case in completed_cases.simulation_cases
105+
]
106+
logger.debug(
107+
"Updated solutions for next iteration: %s", updated_solutions
108+
)
109+
110+
# Map simulation service solutions to the ProblemDispatcherService format
111+
next_solutions = ControlVectorMapper.convert_su_to_pd(
112+
solution_updater.process_request(
113+
{
114+
"solution_candidates": updated_solutions,
115+
"optimization_constraints": {"boundaries": boundaries},
116+
}
117+
).next_iter_solutions
118+
)
119+
logger.info(
120+
"Generation %d successfully completed for risk management.",
121+
loop_controller.current_generation,
122+
)
123+
except StopIteration as e:
112124
logger.info(
113-
"Iteration %d successfully completed for risk management.",
114-
iteration + 1,
125+
"Loop controller stopped at generation %d: %s",
126+
loop_controller.current_generation,
127+
str(e),
115128
)
116-
117129
except Exception as e:
118130
logger.error("Error in risk management process: %s", str(e), exc_info=True)
119131
raise

src/services/simulation_service/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from services.simulation_service.core.api import ( # noqa: F401, E402
1717
SimulationService,
18-
simulation_cluster_contex_manager,
18+
simulation_cluster_context_manager,
1919
)
2020

21-
__all__ = ["SimulationService", "simulation_cluster_contex_manager"]
21+
__all__ = ["SimulationService", "simulation_cluster_context_manager"]
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from services.simulation_service.core.service import (
22
SimulationService,
3-
simulation_cluster_contex_manager,
3+
simulation_cluster_context_manager,
44
)
55

6-
__all__ = ["SimulationService", "simulation_cluster_contex_manager"]
6+
__all__ = ["SimulationService", "simulation_cluster_context_manager"]

src/services/simulation_service/core/service/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from logger.u_logger import configure_logger, get_logger
66
from services.simulation_service.core.service.simulation_service import (
77
SimulationService,
8-
simulation_cluster_contex_manager,
8+
simulation_cluster_context_manager,
99
)
1010

1111
configure_logger()
@@ -68,4 +68,4 @@ def _initialize_images() -> None:
6868

6969
_initialize_images()
7070

71-
__all__ = ["SimulationService", "simulation_cluster_contex_manager"]
71+
__all__ = ["SimulationService", "simulation_cluster_context_manager"]

src/services/simulation_service/core/service/simulation_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def _from_grpc(simulation: sm.Simulation) -> SimulationCase:
283283

284284

285285
@contextmanager
286-
def simulation_cluster_contex_manager():
286+
def simulation_cluster_context_manager():
287287
"""
288288
Context manager for managing the simulation cluster lifecycle.
289289
"""

src/services/solution_updater_service/core/service/solution_updater_service.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,39 @@ def _initiate_mapper_on_first_call(
321321
return _MapperState(control_vector_mapping, results_mapping, population_size)
322322

323323

324+
class _SolutionUpdaterServiceLoopController:
325+
def __init__(self, max_generations: int) -> None:
326+
"""
327+
Helper class to control the loop of the solution updater service.
328+
This class manages the number of iterations for the optimization process, and raises StopIteration exception when convergence fails.
329+
"""
330+
self._max_generations = max_generations
331+
self.current_generation = 0
332+
self._is_running = True
333+
334+
def running(self) -> bool:
335+
"""
336+
Checks if the loop controller should run
337+
338+
Returns:
339+
bool: True if the loop controller is running, False otherwise.
340+
"""
341+
if self.current_generation >= self._max_generations:
342+
self._is_running = False
343+
self.current_generation += 1
344+
return self._is_running
345+
346+
324347
class SolutionUpdaterService:
325-
def __init__(self, optimization_engine: OptimizationEngine) -> None:
348+
def __init__(
349+
self, optimization_engine: OptimizationEngine, max_generations: int
350+
) -> None:
326351
self._mapper: _Mapper = _Mapper()
327352
self._engine: OptimizationEngineInterface = (
328353
OptimizationEngineFactory.get_engine(optimization_engine)
329354
)
330355
self._logger = get_logger(__name__)
356+
self.loop_controller = _SolutionUpdaterServiceLoopController(max_generations)
331357

332358
def process_request(
333359
self, request_dict: dict[str, Any]
@@ -385,6 +411,8 @@ def process_request(
385411
if not config.solution_candidates:
386412
raise RuntimeError("Nothing to optimize")
387413

414+
self._check_convergence(config.solution_candidates)
415+
388416
control_vector, cost_function_values = self._mapper.to_numpy(
389417
config.solution_candidates
390418
)
@@ -399,4 +427,19 @@ def process_request(
399427

400428
next_iter_solutions = self._mapper.to_control_vectors(updated_params)
401429
self._logger.info("Control vectors update request processed successfully.")
430+
402431
return SolutionUpdaterServiceResponse(next_iter_solutions=next_iter_solutions)
432+
433+
def _check_convergence(
434+
self, solution: list[SolutionCandidate], tol: float = 1e-4
435+
) -> None:
436+
"""
437+
Should raise StopIteration exception when convergence reach desired value.
438+
Args:
439+
tol: function convergence tolerance
440+
solution: list of SolutionCandidate
441+
442+
Returns:
443+
444+
"""
445+
pass

tests/solution_updater_service_test/test_solution_updater_service.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_update_solution_for_next_iteration_single_call( # type: ignore
8484
config_json, expected_result_parameters, mocked_engine, monkeypatch
8585
):
8686
# Arrange
87-
service = SolutionUpdaterService(optimization_engine=engine)
87+
service = SolutionUpdaterService(optimization_engine=engine, max_generations=100)
8888

8989
# Monkeypatch engine
9090
monkeypatch.setattr(service, "_engine", mocked_engine)
@@ -140,7 +140,7 @@ def test_update_solution_for_next_iteration_multiple_calls( # type: ignore
140140
monkeypatch,
141141
):
142142
# Arrange
143-
service = SolutionUpdaterService(optimization_engine=engine)
143+
service = SolutionUpdaterService(optimization_engine=engine, max_generations=100)
144144

145145
# Monkeypatch engine
146146
monkeypatch.setattr(service, "_engine", mocked_engine)
@@ -210,7 +210,7 @@ def test_update_solution_with_boundaries_np( # type: ignore
210210
config_json, expected_result_parameters, mocked_engine_with_bnb, monkeypatch
211211
):
212212
# Arrange
213-
service = SolutionUpdaterService(optimization_engine=engine)
213+
service = SolutionUpdaterService(optimization_engine=engine, max_generations=100)
214214

215215
# Monkeypatch engine to use mocked behavior
216216
monkeypatch.setattr(service, "_engine", mocked_engine_with_bnb)
@@ -312,9 +312,12 @@ def test_optimization_service_full_round(test_case):
312312
"optimization_constraints": {"boundaries": {k: [lb, ub] for k in param_names}},
313313
}
314314

315-
service = SolutionUpdaterService(optimization_engine=engine)
315+
service = SolutionUpdaterService(
316+
optimization_engine=engine, max_generations=iterations
317+
)
318+
loop_controller = service.loop_controller
316319

317-
for _ in range(iterations):
320+
while loop_controller.running():
318321
result = service.process_request(config)
319322
positions = np.array(
320323
[get_numpy_values(vec.items) for vec in result.next_iter_solutions]

0 commit comments

Comments
 (0)