Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 149 additions & 20 deletions examples/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,30 @@ def fixed_schedule(
count: Numeric,
value: Numeric,
) -> Numeric:
"""Fixed/constant schedule."""
"""Fixed/constant schedule.

Args:
count: The current count.
value: The value to return.

Returns:
The value.
"""
del count
return value


def kfac_resnet50_schedule(
count: Numeric,
) -> Numeric:
"""Custom schedule for KFAC ResNet50 experiment."""
"""Custom schedule for KFAC ResNet50 experiment.

Args:
count: The current count.

Returns:
The value of the schedule at the current count.
"""

# We linearly interpolate in log space
exponent = piecewise_interpolated_schedule(
Expand All @@ -115,11 +130,21 @@ def kfac_resnet50_schedule(

def cosine_schedule(
count: Numeric,
total: Numeric,
total: int,
peak_value: float,
end_value: float = 0.0,
) -> Numeric:
"""Cosine schedule."""
"""Cosine schedule.

Args:
count: The current count.
total: The value of count at the end of the schedule.
peak_value: The initial value of the schedule (at count=0).
end_value: The value at the end of the schedule (at count=total).

Returns:
The value of the schedule at the current count.
"""

val = optax.cosine_decay_schedule(
init_value=peak_value,
Expand All @@ -133,9 +158,9 @@ def cosine_schedule(

def stepwise_schedule(
count: Numeric,
boundaries: Array,
decay_factors: Sequence[float],
init_value: float,
boundaries: Sequence[Numeric],
decay_factors: Sequence[Numeric],
init_value: Numeric,
) -> Numeric:
"""A basic stepwise schedule.

Expand Down Expand Up @@ -166,13 +191,24 @@ def stepwise_schedule(


def exponential_decay_schedule(
count: int,
start: Numeric,
total: Numeric,
count: Numeric,
start: int,
total: int,
init_value: float,
end_value: float,
) -> Numeric:
"""Exponential decay schedule, similar to Optax."""
"""Exponential decay schedule, similar to Optax.

Args:
count: The current count.
start: The count at which to start the decay.
total: The value of count at the end of the schedule.
init_value: The initial value.
end_value: The final value.

Returns:
The value of the schedule at the current count.
"""

val = optax.exponential_decay(
init_value=init_value,
Expand All @@ -187,13 +223,26 @@ def exponential_decay_schedule(


def _custom_polynomial_schedule(
init_value: Numeric,
init_value: float,
end_value: float,
power: Numeric,
transition_steps: int,
transition_begin: int = 0
transition_steps: Numeric,
transition_begin: Numeric = 0
) -> GenericSchedule:
"""Polynomial schedule similar to Optax, but works even when init_value < end_value."""
"""Polynomial schedule similar to Optax, but works even when init_value < end_value.

Args:
init_value: The initial value.
end_value: The final value.
power: The power of the polynomial.
transition_steps: The length of the schedule after decay/growth begins (as
determined by transition_begin).
transition_begin: The count value at which to begin the decay/growth. Will
return init_value for counts less than this value.

Returns:
A function that takes the current count and returns the schedule value.
"""

def schedule(count):

Expand All @@ -210,14 +259,26 @@ def schedule(count):


def polynomial_schedule(
count: int,
start: int,
total: int,
count: Numeric,
start: Numeric,
total: Numeric,
init_value: float,
end_value: float,
power: Numeric = 1,
) -> Numeric:
"""Polynomial schedule (defaults to linear), similar to Optax."""
"""Polynomial schedule similar to Optax, but works even when init_value < end_value.

Args:
count: The current count.
start: The count at which to start the decay/growth.
total: The value of count at the end of the schedule.
init_value: The initial value.
end_value: The final value.
power: The power of the polynomial.

Returns:
The value of the schedule at the current count.
"""

val = _custom_polynomial_schedule(
init_value=init_value,
Expand All @@ -231,6 +292,57 @@ def polynomial_schedule(
return val


def inverse_time_decay_schedule(
count: Numeric,
start: Numeric,
total: Numeric,
init_value: float,
end_value: float,
power: Numeric = 1,
offset_before_power: bool = False,
) -> Numeric:
"""Inverse time decay schedule with 'power' option.

Return returns init_value * offset / (count ** power + offset), where offset
is chosen such that the schedule starts at init_value and ends at end_value.

If offset_before_power is True, the offset is added before the power
operation, i.e. it returns
init_value * offset**power / ((count + offset)**power), where again, offset is
chosen such that the schedule starts at init_value and ends at end_value.

Args:
count: The current count.
start: The count at which to start the decay.
total: The value of count at the end of the schedule.
init_value: The initial value.
end_value: The final value.
power: The value of p (see above).
offset_before_power: If True, the offset is added before the power
operation.

Returns:
The value of the schedule at the current count.
"""

if init_value < end_value:
raise ValueError("Inverse time decay schedule requires init_value >= "
"end_value.")

duration = total - start
count = jnp.clip(count - start, 0, duration)

end_factor = end_value / init_value if init_value != 0 else 0.0

if offset_before_power:
offset = end_factor**(1/power) * duration / (1 - end_factor**(1/power))
return init_value * offset**power / ((count + offset)**power)

else:
offset = end_factor * duration**power / (1 - end_factor)
return init_value * offset / (count**power + offset)


# For each schedule we specify:
# - "params_to_convert": list of parameters to convert (excluding
# warmup-related ones)
Expand Down Expand Up @@ -281,6 +393,12 @@ def polynomial_schedule(
"include_total": False,
"warmup_end_value_key": "vals",
},
"inverse_time_decay": {
"ctor": inverse_time_decay_schedule,
"params_to_convert": ["start"],
"include_total": True,
"warmup_end_value_key": "init_value",
},
}


Expand All @@ -290,7 +408,18 @@ def with_warmup(
warmup_start_value: float,
warmup_end_value: float
) -> GenericSchedule:
"""Wraps a base schedule with a linear warmup phase."""
"""Wraps a base schedule with a linear warmup phase.

Args:
base_schedule_fn: The schedule function to wrap.
warmup_duration: The duration of the warmup phase.
warmup_start_value: The value at the beginning of the warmup.
warmup_end_value: The value at the end of the warmup (which is also the
input to the base schedule).

Returns:
A new schedule function with warmup.
"""

warmup_sched = optax.linear_schedule(
init_value=warmup_start_value,
Expand Down
Loading