From 4da4cf1453a118087580b7d739c53a71e4e63bd4 Mon Sep 17 00:00:00 2001 From: Sergei Shmulyian <sshmulyian@google.com> Date: Thu, 23 Jul 2020 16:52:51 -0700 Subject: [PATCH] add LAMB to the list of supported optimizers PiperOrigin-RevId: 322895995 --- .../optimization/shared/optimizer_utils.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow_federated/python/research/optimization/shared/optimizer_utils.py b/tensorflow_federated/python/research/optimization/shared/optimizer_utils.py index 5ad3c1308..6a35a33dc 100644 --- a/tensorflow_federated/python/research/optimization/shared/optimizer_utils.py +++ b/tensorflow_federated/python/research/optimization/shared/optimizer_utils.py @@ -14,12 +14,12 @@ """Optimizer utilities supporting federated averaging experiments.""" import inspect -from typing import Callable -from typing import Optional +from typing import Callable, List, Optional from absl import flags from absl import logging import tensorflow as tf +import tensorflow_addons.optimizers as tfao from tensorflow_federated.python.research.optimization.shared import yogi @@ -33,7 +33,7 @@ def _optimizer_canonical_name(optimizer_cls): _SUPPORTED_OPTIMIZERS = { _optimizer_canonical_name(cls): cls for cls in [ tf.keras.optimizers.SGD, tf.keras.optimizers.Adagrad, - tf.keras.optimizers.Adam, yogi.Yogi + tf.keras.optimizers.Adam, yogi.Yogi, tfao.lamb.LAMB ] } @@ -117,11 +117,16 @@ def define_optimizer_flags(prefix: str) -> None: define_flag_fn = flags.DEFINE_integer elif is_param_of_type(param, str): define_flag_fn = flags.DEFINE_string + elif is_param_of_type(param, List[str]): + define_flag_fn = flags.DEFINE_multi_string else: - raise NotImplementedError('Cannot handle flag [{!s}] of type [{!s}] on ' - 'optimizers [{!s}]'.format( - param.name, type(param.default), - optimizer_name)) + raise NotImplementedError('Cannot define flag [{!s}] ' + 'for parameter [{!s}] of type [{!s}] ' + '(default value type [{!s}]) ' + 'on optimizer [{!s}]'.format( + prefixed(param.name), + param.name, param.annotation, + type(param.default), optimizer_name)) define_flag_fn( name=prefixed(param.name), default=param.default, -- GitLab