Skip to content
Snippets Groups Projects
Commit 11effdd4 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by tensorflow-copybara
Browse files

add linear warmup option for learning rate schedules

PiperOrigin-RevId: 321812938
parent bb4e5f40
No related branches found
No related tags found
No related merge requests found
......@@ -87,7 +87,7 @@ def define_optimizer_flags(prefix: str) -> None:
flags.DEFINE_float(
name='{!s}_learning_rate'.format(prefix),
default=None,
help='Learning rate for optimizer `{!s}`'.format(prefix))
help='Base learning rate for optimizer `{!s}`'.format(prefix))
logging.info('Defined new flag: [%s]', '{!s}_learning_rate'.format(prefix))
for optimizer_name, optimizer_cls in _SUPPORTED_OPTIMIZERS.items():
......@@ -243,6 +243,7 @@ def define_lr_schedule_flags(prefix: str) -> None:
This creates four new flags:
* `--<prefix>_lr_schedule`
* `--<prefix>_lr_warmup_steps`
* `--<prefix>_lr_decay_step`
* `--<prefix>_lr_decay_rate`
* `--<prefix>_lr_staircase`
......@@ -258,11 +259,11 @@ def define_lr_schedule_flags(prefix: str) -> None:
def prefixed(basename):
return '{}_{}'.format(prefix, basename) if prefix else basename
initial_lr_flag_name = prefixed('learning_rate')
if flags.FLAGS[initial_lr_flag_name] is None:
base_lr_flag_name = prefixed('learning_rate')
if flags.FLAGS[base_lr_flag_name] is None:
logging.warning(
'The flag %s is not set. This must be set before calling '
'`create_lr_schedule_from_flags`.', initial_lr_flag_name)
'`create_lr_schedule_from_flags`.', base_lr_flag_name)
flags.DEFINE_enum(
'{!s}_lr_schedule'.format(prefix),
......@@ -270,6 +271,11 @@ def define_lr_schedule_flags(prefix: str) -> None:
enum_values=['constant', 'exp_decay', 'inv_lin_decay', 'inv_sqrt_decay'],
help='Type of learning rate decay schedule to use for `{!s}`.'.format(
prefix))
flags.DEFINE_integer(
'{!s}_lr_warmup_steps'.format(prefix),
default=None,
help='An int number of steps to warm up the `{!s}` learning rate (e.g. '
'increase linearly from 0 to the base value).'.format(prefix))
flags.DEFINE_integer(
'{!s}_lr_decay_steps'.format(prefix),
default=None,
......@@ -287,11 +293,38 @@ def define_lr_schedule_flags(prefix: str) -> None:
.format(prefix))
def warmup_and_decay_schedule_builder(base_value, warmup_steps, decay_fn):
"""Creates a learning rate schedule with warmup and decay.
Args:
base_value: The base value of the quantity to warm up to, then decay from,
over time.
warmup_steps: A scalar for the number of steps to linearly increase the
value (from base_value/warmup_steps to base_value) prior to decaying. No
warmup if 0 or negative.
decay_fn: A 1-arg callable producing a decayed version of the base value
when passed the current round_num (adjusted for warmup_steps if relevant).
Returns:
A 1-arg callable that produces a warmed up then decayed version of the base
value when passed the (unadjusted) current round_num.
"""
def warmup_and_decay_fn(round_num):
if warmup_steps and warmup_steps > 0:
if round_num < warmup_steps:
return base_value * (round_num + 1) / warmup_steps
round_num = round_num - warmup_steps
return decay_fn(round_num)
return warmup_and_decay_fn
def exp_decay_schedule_builder(base_value, decay_steps, decay_rate, staircase):
"""Creates a learning rate schedule with exponential root decay.
Args:
base_value: The initial value of the quantity to decay over time.
base_value: The base value of the quantity to decay over time.
decay_steps: A positive scalar that governs how much the value decays at a
given round number.
decay_rate: A float between 0 and 1 that governs how quickly the decay
......@@ -317,7 +350,7 @@ def inv_lin_schedule_builder(base_value, decay_steps, decay_rate, staircase):
"""Creates a learning rate schedule with inverse linear decay.
Args:
base_value: The initial value of the quantity to decay over time.
base_value: The base value of the quantity to decay over time.
decay_steps: A positive scalar that governs how much the value decays at a
given round number.
decay_rate: A positive scalar that governs how quickly the decay occurs.
......@@ -342,7 +375,7 @@ def inv_sqrt_schedule_builder(base_value, decay_steps, decay_rate, staircase):
"""Creates a learning rate schedule with inverse square root decay.
Args:
base_value: The initial value of the quantity to decay over time.
base_value: The base value of the quantity to decay over time.
decay_steps: A positive scalar that governs how much the value decays at a
given round number.
decay_rate: A positive scalar that governs how quickly the decay occurs.
......@@ -374,9 +407,7 @@ def create_lr_schedule_from_flags(
This method expects the following flags to have been defined and set:
* `--<prefix>_learning_rate`
* `--<prefix>_lr_schedule`
If `<prefix>_lr_schedule` is set to `constant`, then this function will
return a callable that always outputs <prefix>_learning_rate`.
* `--<prefix>_lr_warmup_steps`
If <prefix>_lr_schedule is not `constant`, then this method expects the
following flags to be defined as well:
......@@ -400,25 +431,38 @@ def create_lr_schedule_from_flags(
lr_schedule_flag_name = prefixed('lr_schedule')
if flags.FLAGS[lr_schedule_flag_name] is None:
raise ValueError('Must specify flag --{!s}'.format(lr_schedule_flag_name))
lr_warmup_steps_flag_name = prefixed('lr_warmup_steps')
if flags.FLAGS[lr_warmup_steps_flag_name] is None:
raise ValueError(
'Must specify flag --{!s}'.format(lr_warmup_steps_flag_name))
base_lr = flags.FLAGS[lr_flag_name].value
lr_schedule_type = flags.FLAGS[lr_schedule_flag_name].value
lr_warmup_steps = flags.FLAGS[lr_warmup_steps_flag_name].value
if lr_schedule_type == 'constant':
return lambda round_num: base_lr
return warmup_and_decay_schedule_builder(base_lr, lr_warmup_steps,
lambda _: base_lr)
lr_decay_steps = flags.FLAGS[prefixed('lr_decay_steps')].value
lr_decay_rate = flags.FLAGS[prefixed('lr_decay_rate')].value
lr_staircase = flags.FLAGS[prefixed('lr_staircase')].value
if lr_schedule_type == 'exp_decay':
return exp_decay_schedule_builder(
base_lr, lr_decay_steps, lr_decay_rate, lr_staircase)
return warmup_and_decay_schedule_builder(
base_lr, lr_warmup_steps,
exp_decay_schedule_builder(base_lr, lr_decay_steps, lr_decay_rate,
lr_staircase))
elif lr_schedule_type == 'inv_lin_decay':
return inv_lin_schedule_builder(
base_lr, lr_decay_steps, lr_decay_rate, lr_staircase)
return warmup_and_decay_schedule_builder(
base_lr, lr_warmup_steps,
inv_lin_schedule_builder(base_lr, lr_decay_steps, lr_decay_rate,
lr_staircase))
elif lr_schedule_type == 'inv_sqrt_decay':
return inv_sqrt_schedule_builder(
base_lr, lr_decay_steps, lr_decay_rate, lr_staircase)
return warmup_and_decay_schedule_builder(
base_lr, lr_warmup_steps,
inv_sqrt_schedule_builder(base_lr, lr_decay_steps, lr_decay_rate,
lr_staircase))
else:
raise ValueError(
'Unrecognized schedule type {!s}'.format(lr_schedule_type))
......@@ -135,7 +135,7 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
def test_create_constant_client_lr_schedule_from_flags(self):
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
'{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'constant'
'{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'constant',
}):
lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
TEST_CLIENT_FLAG_PREFIX)
......@@ -143,6 +143,19 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(lr_schedule(1), 3.0, err=1e-5)
self.assertNear(lr_schedule(105), 3.0, err=1e-5)
self.assertNear(lr_schedule(1042), 3.0, err=1e-5)
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
'{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'constant',
'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10
}):
lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
TEST_CLIENT_FLAG_PREFIX)
self.assertNear(lr_schedule(0), 0.3, err=1e-5)
self.assertNear(lr_schedule(1), 0.6, err=1e-5)
self.assertNear(lr_schedule(10), 3.0, err=1e-5)
self.assertNear(lr_schedule(11), 3.0, err=1e-5)
self.assertNear(lr_schedule(115), 3.0, err=1e-5)
self.assertNear(lr_schedule(1052), 3.0, err=1e-5)
def test_create_exp_decay_client_lr_schedule_from_flags(self):
with flag_sandbox({
......@@ -163,6 +176,7 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
'{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'exp_decay',
'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 0,
'{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
'{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 0.1,
'{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
......@@ -174,6 +188,23 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(lr_schedule(10), 0.3, err=1e-5)
self.assertNear(lr_schedule(25), 0.00948683298, err=1e-5)
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
'{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'exp_decay',
'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
'{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
'{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 0.1,
'{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
}):
lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
TEST_CLIENT_FLAG_PREFIX)
self.assertNear(lr_schedule(0), 0.3, err=1e-5)
self.assertNear(lr_schedule(1), 0.6, err=1e-5)
self.assertNear(lr_schedule(10), 3.0, err=1e-5)
self.assertNear(lr_schedule(11), 2.38298470417, err=1e-5)
self.assertNear(lr_schedule(20), 0.3, err=1e-5)
self.assertNear(lr_schedule(35), 0.00948683298, err=1e-5)
def test_create_inv_lin_client_lr_schedule_from_flags(self):
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 5.0,
......@@ -193,6 +224,7 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 5.0,
'{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_lin_decay',
'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 0,
'{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
'{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
'{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
......@@ -204,6 +236,23 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(lr_schedule(9), 0.5, err=1e-5)
self.assertNear(lr_schedule(19), 0.25, err=1e-5)
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 5.0,
'{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_lin_decay',
'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
'{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
'{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
'{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
}):
lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
TEST_CLIENT_FLAG_PREFIX)
self.assertNear(lr_schedule(0), 0.5, err=1e-5)
self.assertNear(lr_schedule(1), 1.0, err=1e-5)
self.assertNear(lr_schedule(10), 5.0, err=1e-5)
self.assertNear(lr_schedule(11), 2.5, err=1e-5)
self.assertNear(lr_schedule(19), 0.5, err=1e-5)
self.assertNear(lr_schedule(29), 0.25, err=1e-5)
def test_create_inv_sqrt_client_lr_schedule_from_flags(self):
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 2.0,
......@@ -223,6 +272,7 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 2.0,
'{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_sqrt_decay',
'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 0,
'{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
'{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
'{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
......@@ -234,6 +284,23 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(lr_schedule(99), 0.2, err=1e-5)
self.assertNear(lr_schedule(399), 0.1, err=1e-5)
with flag_sandbox({
'{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 2.0,
'{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_sqrt_decay',
'{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
'{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
'{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
'{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
}):
lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
TEST_CLIENT_FLAG_PREFIX)
self.assertNear(lr_schedule(0), 0.2, err=1e-5)
self.assertNear(lr_schedule(1), 0.4, err=1e-5)
self.assertNear(lr_schedule(10), 2.0, err=1e-5)
self.assertNear(lr_schedule(13), 1.0, err=1e-5)
self.assertNear(lr_schedule(109), 0.2, err=1e-5)
self.assertNear(lr_schedule(409), 0.1, err=1e-5)
if __name__ == '__main__':
tf.test.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment