Commit 90feb1a0, authored by Yu Xiao, committed by tensorflow-copybara

Automated rollback of commit 2cec0ed7

PiperOrigin-RevId: 410631896
Parent 238bf340
@@ -242,21 +242,8 @@ def federated_aggregate_keras_metric(
       # If type(metric) is subclass of another tf.keras.metric arguments passed
       # to __init__ must include arguments expected by the superclass and
       # specified in superclass get_config().
-      # TODO(b/197746608): finds a safer way of reconstructing the metric,
-      # default argument values in Metric constructors can cause problems here.
-      keras_metric = None
-      try:
-        # This is some trickery to reconstruct a metric object in the current
-        # scope, so that the `tf.Variable`s get created when we desire.
-        keras_metric = type(metric).from_config(metric.get_config())
-      except TypeError as e:
-        # Re-raise the error with a more helpful message, but the previous stack
-        # trace.
-        raise TypeError(
-            'Caught exception trying to call `{t}.from_config()` with '
-            'config {c}. Confirm that {t}.__init__() has an argument for '
-            'each member of the config.\nException: {e}'.format(
-                t=type(metric), c=metric.get_config(), e=e))
+      finalizer.check_keras_metric_config_constructable(metric)
+      keras_metric = type(metric).from_config(metric.get_config())
       assignments = []
       for v, a in zip(keras_metric.variables, values):
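For context, a minimal standalone sketch (plain Keras, no TFF) of the failure mode this hunk guards against. `BadCounter` is a hypothetical metric whose required `__init__` argument is not recorded in `get_config()`, so the reconstruction that the removed try/except used to wrap fails with a `TypeError`. If the extra argument instead had a default value, reconstruction would succeed but silently drop the configured value, which is why an explicit check is preferred over catching the exception.

import tensorflow as tf


class BadCounter(tf.keras.metrics.Sum):
  """Hypothetical metric: `arg1` is required but absent from `get_config()`."""

  def __init__(self, arg1, name='bad_counter', dtype=tf.int64):
    super().__init__(name, dtype)
    self._arg1 = arg1


metric = BadCounter(arg1=1)
config = metric.get_config()  # Only the inherited {'name': ..., 'dtype': ...}.
try:
  # `from_config` is roughly `cls(**config)`, so the required `arg1` is missing.
  type(metric).from_config(config)
except TypeError as e:
  print(f'Reconstruction failed: {e}')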
@@ -1020,6 +1020,57 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
         metrics=[NumBatchesCounter(),
                  NumExamplesCounter()])
 
+  def test_custom_keras_metric_with_extra_init_args_raises(self):
+
+    class CustomCounter(tf.keras.metrics.Sum):
+      """A custom `tf.keras.metrics.Metric` with extra args in `__init__`."""
+
+      def __init__(self, name='new_counter', arg1=0, dtype=tf.int64):
+        super().__init__(name, dtype)
+        self._arg1 = arg1
+
+      def update_state(self, y_true, y_pred, sample_weight=None):
+        return super().update_state(1, sample_weight)
+
+    feature_dims = 3
+    keras_model = model_examples.build_linear_regression_keras_functional_model(
+        feature_dims)
+    with self.assertRaisesRegex(TypeError, 'extra arguments'):
+      keras_utils.from_keras_model(
+          keras_model=keras_model,
+          input_spec=_create_whimsy_types(feature_dims),
+          loss=tf.keras.losses.MeanSquaredError(),
+          metrics=[CustomCounter(arg1=1)])
+
+  def test_custom_keras_metric_no_extra_init_args_builds(self):
+
+    class CustomCounter(tf.keras.metrics.Sum):
+      """A custom `tf.keras.metrics.Metric` with extra args in `get_config`."""
+
+      def __init__(self, name='new_counter', arg1=0, dtype=tf.int64):
+        super().__init__(name, dtype)
+        self._arg1 = arg1
+
+      def update_state(self, y_true, y_pred, sample_weight=None):
+        return super().update_state(1, sample_weight)
+
+      def get_config(self):
+        config = super().get_config()
+        config['arg1'] = self._arg1
+        return config
+
+    feature_dims = 3
+    keras_model = model_examples.build_linear_regression_keras_functional_model(
+        feature_dims)
+    tff_model = keras_utils.from_keras_model(
+        keras_model=keras_model,
+        input_spec=_create_whimsy_types(feature_dims),
+        loss=tf.keras.losses.MeanSquaredError(),
+        metrics=[CustomCounter(arg1=1)])
+    self.assertIsInstance(tff_model, model_lib.Model)
+
 
 if __name__ == '__main__':
   execution_contexts.set_local_python_execution_context()
@@ -13,6 +13,8 @@
 # limitations under the License.
 """Helper functions for creating metric finalizers."""
 
+import inspect
+
 from typing import Any, Callable, List, Union
 
 import tensorflow as tf
@@ -56,6 +58,7 @@ def create_keras_metric_finalizer(
     # use `keras_metric.result()`.
     with tf.init_scope():
       if isinstance(metric, tf.keras.metrics.Metric):
+        check_keras_metric_config_constructable(metric)
         keras_metric = type(metric).from_config(metric.get_config())
       elif callable(metric):
         keras_metric = metric()
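As an aside, the finalizer being built in this function can be read as the following minimal sketch (standalone Keras, not the TFF implementation): rebuild the metric from its config so that its `tf.Variable`s are created in the current scope, load the aggregated (unfinalized) values into those variables, and return `result()`. The `finalize_metric` helper and the example values are illustrative only.

import tensorflow as tf


def finalize_metric(metric: tf.keras.metrics.Metric, unfinalized_values):
  # Reconstruct a fresh copy so the metric's variables live in this scope.
  keras_metric = type(metric).from_config(metric.get_config())
  for variable, value in zip(keras_metric.variables, unfinalized_values):
    variable.assign(value)
  return keras_metric.result()


# An accuracy metric's unfinalized state is [total, count]: 3 correct out of 4.
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
result = finalize_metric(accuracy, [tf.constant(3.0), tf.constant(4.0)])
print(float(result))  # 0.75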
@@ -89,3 +92,50 @@ def create_keras_metric_finalizer(
       return keras_metric.result()
 
   return finalizer
+
+
+def check_keras_metric_config_constructable(
+    metric: tf.keras.metrics.Metric) -> tf.keras.metrics.Metric:
+  """Checks that a Keras metric is constructable from its `get_config()`.
+
+  Args:
+    metric: A single `tf.keras.metrics.Metric`.
+
+  Returns:
+    The metric.
+
+  Raises:
+    TypeError: If the metric is not an instance of `tf.keras.metrics.Metric`,
+      or if the metric is not constructable from the `get_config()` method.
+  """
+  if not isinstance(metric, tf.keras.metrics.Metric):
+    raise TypeError(f'Metric {type(metric)} is not a `tf.keras.metrics.Metric` '
+                    'and so cannot be constructed from `get_config()`.')
+
+  metric_type_str = type(metric).__name__
+  if hasattr(tf.keras.metrics, metric_type_str):
+    return metric
+
+  init_args = inspect.getfullargspec(metric.__init__).args
+  init_args.remove('self')
+  get_config_args = metric.get_config().keys()
+  extra_args = [arg for arg in init_args if arg not in get_config_args]
+  if extra_args:
+    # TODO(b/197746608): Update the error message to redirect users to metric
+    # constructors instead of constructed metrics once both cases are
+    # supported in `from_keras_model`.
+    raise TypeError(f'Metric {metric_type_str} is not constructable from '
+                    'the `get_config()` method, because `__init__` takes extra '
+                    'arguments that are not included in the `get_config()`: '
+                    f'{extra_args}. Override or update the `get_config()` in '
+                    'the metric class to include these extra arguments.\n'
+                    'Example:\n'
+                    'class CustomMetric(tf.keras.metrics.Metric):\n'
+                    '  def __init__(self, arg1):\n'
+                    '    self._arg1 = arg1\n\n'
+                    '  def get_config(self):\n'
+                    '    config = super().get_config()\n'
+                    '    config[\'arg1\'] = self._arg1\n'
+                    '    return config')
+  return metric
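To make the behaviour of this check concrete, here is a small self-contained illustration of the introspection it performs (comparing `__init__` argument names with `get_config()` keys) and of the fix its error message recommends. `GoodCounter` is a hypothetical metric, not part of this change.

import inspect

import tensorflow as tf


class GoodCounter(tf.keras.metrics.Sum):
  """Hypothetical metric that records its extra `arg1` in `get_config()`."""

  def __init__(self, name='good_counter', arg1=0, dtype=tf.int64):
    super().__init__(name, dtype)
    self._arg1 = arg1

  def get_config(self):
    config = super().get_config()
    config['arg1'] = self._arg1
    return config


metric = GoodCounter(arg1=1)
init_args = inspect.getfullargspec(metric.__init__).args
init_args.remove('self')
extra_args = [arg for arg in init_args if arg not in metric.get_config()]
assert not extra_args  # Every `__init__` argument appears in the config.

# Because of that, `from_config` can rebuild the metric without losing state.
rebuilt = type(metric).from_config(metric.get_config())
assert rebuilt._arg1 == 1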
@@ -82,6 +82,17 @@ class CustomSumMetric(tf.keras.metrics.Sum):
     return config
 
 
+class CustomCounter(tf.keras.metrics.Sum):
+  """A custom `tf.keras.metrics.Metric` with extra arguments in `__init__`."""
+
+  def __init__(self, name='new_metric', arg1=0, dtype=tf.int64):
+    super().__init__(name, dtype)
+    self._arg1 = arg1
+
+  def update_state(self, y_true, y_pred, sample_weight=None):
+    return super().update_state(1, sample_weight)
+
+
 class FinalizerTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.named_parameters(

@@ -115,8 +126,10 @@ class FinalizerTest(parameterized.TestCase, tf.test.TestCase):
   @parameterized.named_parameters(
       ('tensor', tf.constant(1.0), 'found a non-callable'),
-      ('loss_constructor', tf.keras.losses.MeanSquaredError, 'found a callable')
-  )
+      ('loss_constructor', tf.keras.losses.MeanSquaredError,
+       'found a callable'),  # go/pyformat-break
+      ('custom_metric_with_extra_init_args', CustomCounter(arg1=1),
+       'extra arguments'))
   def test_create_keras_metric_finalizer_fails_with_invalid_input(
       self, invalid_metric, error_message):
     unused_type = [tf.TensorSpec(shape=[], dtype=tf.float32)]
@@ -40,6 +40,11 @@ class NumTokensCounter(tf.keras.metrics.Sum):
     sample_weight = tf.reshape(sample_weight, [-1])
     super().update_state(sample_weight)
 
+  def get_config(self):
+    config = super().get_config()
+    config['masked_tokens'] = tuple(self._masked_tokens)
+    return config
+
 
 class MaskedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
   """An accuracy metric that masks some tokens."""
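The `get_config()` override added here applies the same pattern to a tuple-valued constructor argument. A hypothetical round-trip check is sketched below; `TokenCounter` only mirrors the pattern and is not the TFF `NumTokensCounter`, whose constructor is not shown in this diff.

import tensorflow as tf


class TokenCounter(tf.keras.metrics.Sum):
  """Counts tokens, ignoring a configurable set of masked token ids."""

  def __init__(self, masked_tokens=(), name='token_counter', dtype=tf.int64):
    super().__init__(name, dtype)
    self._masked_tokens = tuple(masked_tokens)

  def update_state(self, y_true, y_pred, sample_weight=None):
    mask = tf.ones_like(y_true, dtype=self.dtype)
    for token in self._masked_tokens:
      mask *= tf.cast(tf.not_equal(y_true, token), self.dtype)
    return super().update_state(mask)

  def get_config(self):
    config = super().get_config()
    # Store as a tuple so the config round-trips through `from_config`.
    config['masked_tokens'] = tuple(self._masked_tokens)
    return config


metric = TokenCounter(masked_tokens=(0,))
rebuilt = type(metric).from_config(metric.get_config())
rebuilt.update_state(tf.constant([0, 2, 3]), y_pred=None)
print(int(rebuilt.result()))  # 2: the masked token id 0 is not counted.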