Commit 2cec0ed7 authored by Yu Xiao, committed by tensorflow-copybara

Automated rollback of commit 9aee155c

PiperOrigin-RevId: 410312882
parent 9aee155c
@@ -242,8 +242,21 @@ def federated_aggregate_keras_metric(
     # If type(metric) is subclass of another tf.keras.metric arguments passed
     # to __init__ must include arguments expected by the superclass and
     # specified in superclass get_config().
-    finalizer.check_keras_metric_config_constructable(metric)
-    keras_metric = type(metric).from_config(metric.get_config())
+    # TODO(b/197746608): finds a safer way of reconstructing the metric,
+    # default argument values in Metric constructors can cause problems here.
+    keras_metric = None
+    try:
+      # This is some trickery to reconstruct a metric object in the current
+      # scope, so that the `tf.Variable`s get created when we desire.
+      keras_metric = type(metric).from_config(metric.get_config())
+    except TypeError as e:
+      # Re-raise the error with a more helpful message, but the previous stack
+      # trace.
+      raise TypeError(
+          'Caught exception trying to call `{t}.from_config()` with '
+          'config {c}. Confirm that {t}.__init__() has an argument for '
+          'each member of the config.\nException: {e}'.format(
+              t=type(metric), c=metric.get_config(), e=e))
     assignments = []
     for v, a in zip(keras_metric.variables, values):
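A note on why the restored `try`/`except` matters: `type(metric).from_config(metric.get_config())` rebuilds a metric from its config, but any `__init__` argument that `get_config()` does not record silently falls back to its default, and a config key the constructor does not accept raises the `TypeError` handled above. The sketch below is illustrative only and not part of this commit; `ScaledSum` is a hypothetical metric.

```python
import tensorflow as tf


class ScaledSum(tf.keras.metrics.Sum):
  """A hypothetical metric that sums values after scaling them."""

  def __init__(self, name='scaled_sum', scale=1.0, dtype=tf.float32):
    super().__init__(name=name, dtype=dtype)
    self._scale = scale

  def update_state(self, values, sample_weight=None):
    return super().update_state(values * self._scale, sample_weight)


metric = ScaledSum(scale=2.0)
# The inherited `get_config()` only records `name` and `dtype`, so the
# reconstruction succeeds but quietly loses the caller's `scale=2.0`.
rebuilt = type(metric).from_config(metric.get_config())
assert rebuilt._scale == 1.0
```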
@@ -1020,57 +1020,6 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
         metrics=[NumBatchesCounter(),
                  NumExamplesCounter()])
 
-  def test_custom_keras_metric_with_extra_init_args_raises(self):
-
-    class CustomCounter(tf.keras.metrics.Sum):
-      """A custom `tf.keras.metrics.Metric` with extra args in `__init__`."""
-
-      def __init__(self, name='new_counter', arg1=0, dtype=tf.int64):
-        super().__init__(name, dtype)
-        self._arg1 = arg1
-
-      def update_state(self, y_true, y_pred, sample_weight=None):
-        return super().update_state(1, sample_weight)
-
-    feature_dims = 3
-    keras_model = model_examples.build_linear_regression_keras_functional_model(
-        feature_dims)
-    with self.assertRaisesRegex(TypeError, 'extra arguments'):
-      keras_utils.from_keras_model(
-          keras_model=keras_model,
-          input_spec=_create_whimsy_types(feature_dims),
-          loss=tf.keras.losses.MeanSquaredError(),
-          metrics=[CustomCounter(arg1=1)])
-
-  def test_custom_keras_metric_no_extra_init_args_builds(self):
-
-    class CustomCounter(tf.keras.metrics.Sum):
-      """A custom `tf.keras.metrics.Metric` without extra args in `__init__`."""
-
-      def __init__(self, name='new_counter', arg1=0, dtype=tf.int64):
-        super().__init__(name, dtype)
-        self._arg1 = arg1
-
-      def update_state(self, y_true, y_pred, sample_weight=None):
-        return super().update_state(1, sample_weight)
-
-      def get_config(self):
-        config = super().get_config()
-        config['arg1'] = self._arg1
-        return config
-
-    feature_dims = 3
-    keras_model = model_examples.build_linear_regression_keras_functional_model(
-        feature_dims)
-    tff_model = keras_utils.from_keras_model(
-        keras_model=keras_model,
-        input_spec=_create_whimsy_types(feature_dims),
-        loss=tf.keras.losses.MeanSquaredError(),
-        metrics=[CustomCounter(arg1=1)])
-    self.assertIsInstance(tff_model, model_lib.Model)
-
 
 if __name__ == '__main__':
   execution_contexts.set_local_python_execution_context()
@@ -13,8 +13,6 @@
 # limitations under the License.
 """Helper functions for creating metric finalizers."""
 
-import inspect
-
 from typing import Any, Callable, List, Union
 
 import tensorflow as tf
@@ -58,7 +56,6 @@ def create_keras_metric_finalizer(
     # use `keras_metric.result()`.
     with tf.init_scope():
       if isinstance(metric, tf.keras.metrics.Metric):
-        check_keras_metric_config_constructable(metric)
         keras_metric = type(metric).from_config(metric.get_config())
       elif callable(metric):
         keras_metric = metric()
@@ -92,47 +89,3 @@ def create_keras_metric_finalizer(
     return keras_metric.result()
 
   return finalizer
-
-
-def check_keras_metric_config_constructable(
-    metric: tf.keras.metrics.Metric) -> tf.keras.metrics.Metric:
-  """Checks that a Keras metric is constructable from the `get_config()` method.
-
-  Args:
-    metric: A single `tf.keras.metrics.Metric`.
-
-  Returns:
-    The metric.
-
-  Raises:
-    TypeError: If the metric is not an instance of `tf.keras.metrics.Metric`, if
-      the metric is not constructable from the `get_config()` method.
-  """
-  if not isinstance(metric, tf.keras.metrics.Metric):
-    raise TypeError(f'Metric {type(metric)} is not a `tf.keras.metrics.Metric` '
-                    'to be constructable from the `get_config()` method.')
-
-  metric_type_str = type(metric).__name__
-  if hasattr(tf.keras.metrics, metric_type_str):
-    return metric
-
-  init_args = inspect.getfullargspec(metric.__init__).args
-  init_args.remove('self')
-  get_config_args = metric.get_config().keys()
-  extra_args = [arg for arg in init_args if arg not in get_config_args]
-  if extra_args:
-    raise TypeError(f'Metric {metric_type_str} is not constructable from '
-                    'the `get_config()` method, because `__init__` takes extra '
-                    'arguments that are not included in the `get_config()`: '
-                    f'{extra_args}. Override or update the `get_config()` in '
-                    'the metric class to include these extra arguments.\n'
-                    'Example:\n'
-                    'class CustomMetric(tf.keras.metrics.Metric):\n'
-                    '  def __init__(self, arg1):\n'
-                    '    self._arg1 = arg1\n\n'
-                    '  def get_config(self)\n'
-                    '    config = super().get_config()\n'
-                    '    config[\'arg1\'] = self._arg1\n'
-                    '    return config')
-  return metric
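For reference, the rolled-back helper above boils down to comparing a metric's `__init__` parameters against its `get_config()` keys with `inspect`. Below is a hedged, standalone sketch of that comparison; `extra_init_args` and `Counter` are illustrative names only, not part of the TFF API.

```python
import inspect

import tensorflow as tf


def extra_init_args(metric):
  """Returns `__init__` arguments that `get_config()` does not record."""
  init_args = inspect.getfullargspec(metric.__init__).args
  init_args.remove('self')
  config_keys = metric.get_config().keys()
  return [arg for arg in init_args if arg not in config_keys]


class Counter(tf.keras.metrics.Sum):
  """A custom metric whose `arg1` is not captured by `get_config()`."""

  def __init__(self, name='counter', arg1=0, dtype=tf.int64):
    super().__init__(name, dtype)
    self._arg1 = arg1


print(extra_init_args(Counter(arg1=1)))         # ['arg1'] -> not reconstructable
print(extra_init_args(tf.keras.metrics.Sum()))  # []       -> safe to rebuild
```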
@@ -82,17 +82,6 @@ class CustomSumMetric(tf.keras.metrics.Sum):
     return config
 
 
-class CustomCounter(tf.keras.metrics.Sum):
-  """A custom `tf.keras.metrics.Metric` with extra arguments in `__init__`."""
-
-  def __init__(self, name='new_metric', arg1=0, dtype=tf.int64):
-    super().__init__(name, dtype)
-    self._arg1 = arg1
-
-  def update_state(self, y_true, y_pred, sample_weight=None):
-    return super().update_state(1, sample_weight)
-
-
 class FinalizerTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.named_parameters(
@@ -126,10 +115,8 @@ class FinalizerTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.named_parameters(
       ('tensor', tf.constant(1.0), 'found a non-callable'),
-      ('loss_constructor', tf.keras.losses.MeanSquaredError,
-       'found a callable'),  # go/pyformat-break
-      ('custom_metric_with_extra_init_args', CustomCounter(arg1=1),
-       'extra arguments'))
+      ('loss_constructor', tf.keras.losses.MeanSquaredError, 'found a callable')
+  )
   def test_create_keras_metric_finalizer_fails_with_invalid_input(
       self, invalid_metric, error_message):
     unused_type = [tf.TensorSpec(shape=[], dtype=tf.float32)]
@@ -40,11 +40,6 @@ class NumTokensCounter(tf.keras.metrics.Sum):
     sample_weight = tf.reshape(sample_weight, [-1])
     super().update_state(sample_weight)
 
-  def get_config(self):
-    config = super().get_config()
-    config['masked_tokens'] = tuple(self._masked_tokens)
-    return config
-
 
 class MaskedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
   """An accuracy metric that masks some tokens."""