Skip to content
Snippets Groups Projects
Commit 4da4cf14 authored by Sergei Shmulyian's avatar Sergei Shmulyian Committed by tensorflow-copybara
Browse files

add LAMB to the list of supported optimizers

PiperOrigin-RevId: 322895995
parent 79f99345
No related branches found
No related tags found
No related merge requests found
......@@ -14,12 +14,12 @@
"""Optimizer utilities supporting federated averaging experiments."""
import inspect
from typing import Callable
from typing import Optional
from typing import Callable, List, Optional
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_addons.optimizers as tfao
from tensorflow_federated.python.research.optimization.shared import yogi
......@@ -33,7 +33,7 @@ def _optimizer_canonical_name(optimizer_cls):
_SUPPORTED_OPTIMIZERS = {
_optimizer_canonical_name(cls): cls for cls in [
tf.keras.optimizers.SGD, tf.keras.optimizers.Adagrad,
tf.keras.optimizers.Adam, yogi.Yogi
tf.keras.optimizers.Adam, yogi.Yogi, tfao.lamb.LAMB
]
}
......@@ -117,11 +117,16 @@ def define_optimizer_flags(prefix: str) -> None:
define_flag_fn = flags.DEFINE_integer
elif is_param_of_type(param, str):
define_flag_fn = flags.DEFINE_string
elif is_param_of_type(param, List[str]):
define_flag_fn = flags.DEFINE_multi_string
else:
raise NotImplementedError('Cannot handle flag [{!s}] of type [{!s}] on '
'optimizers [{!s}]'.format(
param.name, type(param.default),
optimizer_name))
raise NotImplementedError('Cannot define flag [{!s}] '
'for parameter [{!s}] of type [{!s}] '
'(default value type [{!s}]) '
'on optimizer [{!s}]'.format(
prefixed(param.name),
param.name, param.annotation,
type(param.default), optimizer_name))
define_flag_fn(
name=prefixed(param.name),
default=param.default,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment