Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 17 additions & 20 deletions Wrappers/Python/cil/optimisation/algorithms/Algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,17 @@ class Algorithm:
r"""Base class providing minimal infrastructure for iterative algorithms.

An iterative algorithm is designed to solve an optimization problem by repeatedly refining a solution. In CIL, we use iterative algorithms to minimize an objective function, often referred to as a loss. The process begins with an initial guess, and with each iteration, the algorithm updates the current solution based on the results of previous iterations (previous iterates). Iterative algorithms typically continue until a stopping criterion is met, indicating that an optimal or sufficiently good solution has been found. In CIL, stopping criteria can be implemented using a callback function (`cil.optimisation.utilities.callbacks`).

The user is required to implement the :code:`set_up`, :code:`__init__`, :code:`update` and :code:`update_objective` methods.

The method :code:`run` is available to run :code:`n` iterations. The method accepts :code:`callbacks`: a list of callables, each of which receive the current Algorithm object (which in turn contains the iteration number and the actual objective value) and can be used to trigger print to screens and other user interactions. The :code:`run` method will stop when the stopping criterion is met or `StopIteration` is raised.

Parameters
----------
update_objective_interval: int, optional, default 1
The objective (or loss) is calculated and saved every `update_objective_interval`. 1 means every iteration, 2 every 2 iterations and so forth. This is by default 1 and should be increased when evaluating the objective is computationally expensive.
"""

def __init__(self, update_objective_interval=1, max_iteration=None, log_file=None):

def __init__(self, update_objective_interval=None, max_iteration=None, log_file=None):
if update_objective_interval is None:
update_objective_interval = 1
else:
warn("use `Algorithm.run(update_objective_interval)` instead of `update_objective_interval`", DeprecationWarning, stacklevel=2)
self.iteration = -1
self.__max_iteration = 1
if max_iteration is not None:
Expand Down Expand Up @@ -223,31 +221,30 @@ def update_objective_interval(self):
@update_objective_interval.setter
def update_objective_interval(self, value):
'''sets the update_objective_interval'''
if not isinstance(value, Integral) or value < 0:
if not ((isinstance(value, Integral) and value >= 0) or np.isposinf(value)):
raise ValueError('interval must be an integer >= 0')
self.__update_objective_interval = value

def run(self, iterations=None, callbacks: Optional[List[Callback]]=None, verbose=1, **kwargs):
r"""run upto :code:`iterations` with callbacks/logging.
def run(self, iterations: int, update_objective_interval=None, callbacks: Optional[List[Callback]]=None, verbose=1, **kwargs):
"""run upto :code:`iterations` with callbacks/logging.

For a demonstration of callbacks see https://github.com/TomographicImaging/CIL-Demos/blob/main/misc/callback_demonstration.ipynb

Parameters
-----------
iterations: int, default is None
iterations: int
Number of iterations to run. If a positive infinity is passed, the algorithm will run indefinitely until a callback raises `StopIteration`.
update_objective_interval: int, optional, default 1
The objective (or loss) is calculated and saved every `update_objective_interval`. 1 means every iteration, 2 every 2 iterations and so forth. This is by default 1 and should be increased when evaluating the objective is computationally expensive.
callbacks: list of callables, default is Defaults to :code:`[ProgressCallback(verbose)]`
List of callables which are passed the current Algorithm object each iteration. Defaults to :code:`[ProgressCallback(verbose)]`.
verbose: 0=quiet, 1=info, 2=debug
Passed to the default callback to determine the verbosity of the printed output.
Passed to the default callback to determine the verbosity of the printed output.
"""

if iterations is None:
raise ValueError("`run()` missing number of `iterations`")

if update_objective_interval is not None:
self.update_objective_interval = update_objective_interval
if 'print_interval' in kwargs:
warn("use `TextProgressCallback(miniters)` instead of `run(print_interval)`",
DeprecationWarning, stacklevel=2)
warn("use `TextProgressCallback(miniters)` instead of `run(print_interval)`", DeprecationWarning, stacklevel=2)
if np.isposinf(iterations):
if callbacks is None:
raise ValueError("Infinite iterations require a callback with a stopping criterion that raises `StopIteration`")
Expand Down
Loading