diff --git a/examples/schedules.py b/examples/schedules.py index 1b230f3..8409aff 100644 --- a/examples/schedules.py +++ b/examples/schedules.py @@ -93,7 +93,15 @@ def fixed_schedule( count: Numeric, value: Numeric, ) -> Numeric: - """Fixed/constant schedule.""" + """Fixed/constant schedule. + + Args: + count: The current count. + value: The value to return. + + Returns: + The value. + """ del count return value @@ -101,7 +109,14 @@ def fixed_schedule( def kfac_resnet50_schedule( count: Numeric, ) -> Numeric: - """Custom schedule for KFAC ResNet50 experiment.""" + """Custom schedule for KFAC ResNet50 experiment. + + Args: + count: The current count. + + Returns: + The value of the schedule at the current count. + """ # We linearly interpolate in log space exponent = piecewise_interpolated_schedule( @@ -115,11 +130,21 @@ def kfac_resnet50_schedule( def cosine_schedule( count: Numeric, - total: Numeric, + total: int, peak_value: float, end_value: float = 0.0, ) -> Numeric: - """Cosine schedule.""" + """Cosine schedule. + + Args: + count: The current count. + total: The value of count at the end of the schedule. + peak_value: The initial value of the schedule (at count=0). + end_value: The value at the end of the schedule (at count=total). + + Returns: + The value of the schedule at the current count. + """ val = optax.cosine_decay_schedule( init_value=peak_value, @@ -133,9 +158,9 @@ def cosine_schedule( def stepwise_schedule( count: Numeric, - boundaries: Array, - decay_factors: Sequence[float], - init_value: float, + boundaries: Sequence[Numeric], + decay_factors: Sequence[Numeric], + init_value: Numeric, ) -> Numeric: """A basic stepwise schedule. 
@@ -166,13 +191,24 @@ def stepwise_schedule( def exponential_decay_schedule( - count: int, - start: Numeric, - total: Numeric, + count: Numeric, + start: int, + total: int, init_value: float, end_value: float, ) -> Numeric: - """Exponential decay schedule, similar to Optax.""" + """Exponential decay schedule, similar to Optax. + + Args: + count: The current count. + start: The count at which to start the decay. + total: The value of count at the end of the schedule. + init_value: The initial value. + end_value: The final value. + + Returns: + The value of the schedule at the current count. + """ val = optax.exponential_decay( init_value=init_value, @@ -187,13 +223,26 @@ def exponential_decay_schedule( def _custom_polynomial_schedule( - init_value: Numeric, + init_value: float, end_value: float, power: Numeric, - transition_steps: int, - transition_begin: int = 0 + transition_steps: Numeric, + transition_begin: Numeric = 0 ) -> GenericSchedule: - """Polynomial schedule similar to Optax, but works even when init_value < end_value.""" + """Polynomial schedule similar to Optax, but works even when init_value < end_value. + + Args: + init_value: The initial value. + end_value: The final value. + power: The power of the polynomial. + transition_steps: The length of the schedule after decay/growth begins (as + determined by transition_begin). + transition_begin: The count value at which to begin the decay/growth. Will + return init_value for counts less than this value. + + Returns: + A function that takes the current count and returns the schedule value. + """ def schedule(count): @@ -210,14 +259,26 @@ def schedule(count): def polynomial_schedule( - count: int, - start: int, - total: int, + count: Numeric, + start: Numeric, + total: Numeric, init_value: float, end_value: float, power: Numeric = 1, ) -> Numeric: - """Polynomial schedule (defaults to linear), similar to Optax.""" + """Polynomial schedule similar to Optax, but works even when init_value < end_value. 
+ + Args: + count: The current count. + start: The count at which to start the decay/growth. + total: The value of count at the end of the schedule. + init_value: The initial value. + end_value: The final value. + power: The power of the polynomial. + + Returns: + The value of the schedule at the current count. + """ val = _custom_polynomial_schedule( init_value=init_value, @@ -231,6 +292,57 @@ def polynomial_schedule( return val +def inverse_time_decay_schedule( + count: Numeric, + start: Numeric, + total: Numeric, + init_value: float, + end_value: float, + power: Numeric = 1, + offset_before_power: bool = False, +) -> Numeric: + """Inverse time decay schedule with 'power' option. + + Returns init_value * offset / (count ** power + offset), where offset + is chosen such that the schedule starts at init_value and ends at end_value. + + If offset_before_power is True, the offset is added before the power + operation, i.e. it returns + init_value * offset**power / ((count + offset)**power), where again, offset is + chosen such that the schedule starts at init_value and ends at end_value. + + Args: + count: The current count. + start: The count at which to start the decay. + total: The value of count at the end of the schedule. + init_value: The initial value. + end_value: The final value. + power: The exponent used in the formulas above. + offset_before_power: If True, the offset is added before the power + operation. + + Returns: + The value of the schedule at the current count. 
+ """ + + if init_value < end_value: + raise ValueError("Inverse time decay schedule requires init_value >= " + "end_value.") + + duration = total - start + count = jnp.clip(count - start, 0, duration) + + end_factor = end_value / init_value if init_value != 0 else 0.0 + + if offset_before_power: + offset = end_factor**(1/power) * duration / (1 - end_factor**(1/power)) + return init_value * offset**power / ((count + offset)**power) + + else: + offset = end_factor * duration**power / (1 - end_factor) + return init_value * offset / (count**power + offset) + + # For each schedule we specify: # - "params_to_convert": list of parameters to convert (excluding # warmup-related ones) @@ -281,6 +393,12 @@ def polynomial_schedule( "include_total": False, "warmup_end_value_key": "vals", }, + "inverse_time_decay": { + "ctor": inverse_time_decay_schedule, + "params_to_convert": ["start"], + "include_total": True, + "warmup_end_value_key": "init_value", + }, } @@ -290,7 +408,18 @@ def with_warmup( warmup_start_value: float, warmup_end_value: float ) -> GenericSchedule: - """Wraps a base schedule with a linear warmup phase.""" + """Wraps a base schedule with a linear warmup phase. + + Args: + base_schedule_fn: The schedule function to wrap. + warmup_duration: The duration of the warmup phase. + warmup_start_value: The value at the beginning of the warmup. + warmup_end_value: The value at the end of the warmup (which is also the + input to the base schedule). + + Returns: + A new schedule function with warmup. + """ warmup_sched = optax.linear_schedule( init_value=warmup_start_value,