From 4da4cf1453a118087580b7d739c53a71e4e63bd4 Mon Sep 17 00:00:00 2001
From: Sergei Shmulyian <sshmulyian@google.com>
Date: Thu, 23 Jul 2020 16:52:51 -0700
Subject: [PATCH] add LAMB to the list of supported optimizers

PiperOrigin-RevId: 322895995
---
 .../optimization/shared/optimizer_utils.py    | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/tensorflow_federated/python/research/optimization/shared/optimizer_utils.py b/tensorflow_federated/python/research/optimization/shared/optimizer_utils.py
index 5ad3c1308..6a35a33dc 100644
--- a/tensorflow_federated/python/research/optimization/shared/optimizer_utils.py
+++ b/tensorflow_federated/python/research/optimization/shared/optimizer_utils.py
@@ -14,12 +14,12 @@
 """Optimizer utilities supporting federated averaging experiments."""
 
 import inspect
-from typing import Callable
-from typing import Optional
+from typing import Callable, List, Optional
 
 from absl import flags
 from absl import logging
 import tensorflow as tf
+import tensorflow_addons.optimizers as tfao
 
 from tensorflow_federated.python.research.optimization.shared import yogi
 
@@ -33,7 +33,7 @@ def _optimizer_canonical_name(optimizer_cls):
 _SUPPORTED_OPTIMIZERS = {
     _optimizer_canonical_name(cls): cls for cls in [
         tf.keras.optimizers.SGD, tf.keras.optimizers.Adagrad,
-        tf.keras.optimizers.Adam, yogi.Yogi
+        tf.keras.optimizers.Adam, yogi.Yogi, tfao.lamb.LAMB
     ]
 }
 
@@ -117,11 +117,16 @@ def define_optimizer_flags(prefix: str) -> None:
         define_flag_fn = flags.DEFINE_integer
       elif is_param_of_type(param, str):
         define_flag_fn = flags.DEFINE_string
+      elif is_param_of_type(param, List[str]):
+        define_flag_fn = flags.DEFINE_multi_string
       else:
-        raise NotImplementedError('Cannot handle flag [{!s}] of type [{!s}] on '
-                                  'optimizers [{!s}]'.format(
-                                      param.name, type(param.default),
-                                      optimizer_name))
+        raise NotImplementedError('Cannot define flag [{!s}] '
+                                  'for parameter [{!s}] of type [{!s}] '
+                                  '(default value type [{!s}]) '
+                                  'on optimizer [{!s}]'.format(
+                                      prefixed(param.name),
+                                      param.name, param.annotation,
+                                      type(param.default), optimizer_name))
       define_flag_fn(
           name=prefixed(param.name),
           default=param.default,
-- 
GitLab