Commit 621f8edd authored by Yu Xiao, committed by tensorflow-copybara

Make `from_keras_model` able to take metric constructors.

PiperOrigin-RevId: 411842768
parent 0f2d1f21
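
For context, this commit lets `from_keras_model` accept a list of no-arg metric constructors in place of constructed Keras metric objects. A minimal before/after sketch of the call pattern (the two-feature linear model and input spec below are illustrative placeholders, not part of this commit):

import collections
import tensorflow as tf
import tensorflow_federated as tff

def build_keras_model():
  # A stand-in uncompiled Keras model; any model accepted by
  # `from_keras_model` works here.
  return tf.keras.Sequential(
      [tf.keras.layers.Dense(1, kernel_initializer='zeros', input_shape=(2,))])

input_spec = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32))

# Previously supported: a list of constructed `tf.keras.metrics.Metric`s.
model_from_instances = tff.learning.from_keras_model(
    keras_model=build_keras_model(),
    loss=tf.keras.losses.MeanSquaredError(),
    input_spec=input_spec,
    metrics=[tf.keras.metrics.MeanAbsoluteError()])

# Newly supported: a list of no-arg metric constructors. Mixing constructed
# metrics and constructors in one list raises a TypeError.
model_from_constructors = tff.learning.from_keras_model(
    keras_model=build_keras_model(),
    loss=tf.keras.losses.MeanSquaredError(),
    input_spec=input_spec,
    metrics=[tf.keras.metrics.MeanAbsoluteError])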
@@ -14,7 +14,7 @@
 """Utility methods for working with Keras in TensorFlow Federated."""
 import collections
-from typing import List, Optional, OrderedDict, Sequence, Union
+from typing import Callable, List, Optional, OrderedDict, Sequence, Union
 import warnings
 import tensorflow as tf
@@ -33,12 +33,16 @@ from tensorflow_federated.python.learning.metrics import finalizer
 Loss = Union[tf.keras.losses.Loss, List[tf.keras.losses.Loss]]
+# TODO(b/197746608): Remove the code path that takes in constructed Keras
+# metrics, because reconstructing metrics via `from_config` can cause problems.
 def from_keras_model(
     keras_model: tf.keras.Model,
     loss: Loss,
     input_spec,
     loss_weights: Optional[List[float]] = None,
-    metrics: Optional[List[tf.keras.metrics.Metric]] = None) -> model_lib.Model:
+    metrics: Optional[Union[List[tf.keras.metrics.Metric],
+                            List[Callable[[], tf.keras.metrics.Metric]]]] = None
+) -> model_lib.Model:
   """Builds a `tff.learning.Model` from a `tf.keras.Model`.
   The `tff.learning.Model` returned by this function uses `keras_model` for
@@ -83,7 +87,8 @@ def from_keras_model(
     loss_weights: (Optional) A list of Python floats used to weight the loss
       contribution of each model output (when providing a list of losses for the
      `loss` argument).
-    metrics: (Optional) a list of `tf.keras.metrics.Metric` objects.
+    metrics: (Optional) a list of `tf.keras.metrics.Metric` objects or a list of
+      no-arg callables that each constructs a `tf.keras.metrics.Metric`.
   Returns:
     A `tff.learning.Model` object.
@@ -171,8 +176,6 @@ def from_keras_model(
     metrics = []
   else:
     py_typecheck.check_type(metrics, list)
-    for metric in metrics:
-      py_typecheck.check_type(metric, tf.keras.metrics.Metric)
   for layer in keras_model.layers:
     if isinstance(layer, tf.keras.layers.BatchNormalization):
@@ -191,8 +194,10 @@
 def federated_aggregate_keras_metric(
-    metrics: Union[tf.keras.metrics.Metric,
-                   Sequence[tf.keras.metrics.Metric]], federated_values):
+    metrics: Union[tf.keras.metrics.Metric, Sequence[tf.keras.metrics.Metric],
+                   Callable[[], tf.keras.metrics.Metric],
+                   Sequence[Callable[[], tf.keras.metrics.Metric]]],
+    federated_values):
   """Aggregates the variables of a Keras metric placed at CLIENTS to SERVER.
   Args:
@@ -232,8 +237,11 @@ def federated_aggregate_keras_metric(
   def report(accumulators):
     """Insert `accumulators` back into the keras metric to obtain result."""
-    def finalize_metric(metric: tf.keras.metrics.Metric, values):
-      # Note: the following call requires that `type(metric)` have a no argument
+    def finalize_metric(metric: Union[tf.keras.metrics.Metric,
+                                      Callable[[], tf.keras.metrics.Metric]],
+                        values):
+      # Note: if the input metric is an instance of `tf.keras.metrics.Metric`,
+      # the following call requires that `type(metric)` have a no argument
       # __init__ method, which will restrict the types of metrics that can be
      # used. This is somewhat limiting, but the pattern to use default
      # arguments and export the values in `get_config()` (see
@@ -242,8 +250,7 @@ def federated_aggregate_keras_metric(
      # If type(metric) is subclass of another tf.keras.metric arguments passed
      # to __init__ must include arguments expected by the superclass and
      # specified in superclass get_config().
-      finalizer.check_keras_metric_config_constructable(metric)
-      keras_metric = type(metric).from_config(metric.get_config())
+      keras_metric = finalizer.create_keras_metric(metric)
      assignments = []
      for v, a in zip(keras_metric.variables, values):
@@ -270,12 +277,45 @@ class _KerasModel(model_lib.Model):
   def __init__(self, keras_model: tf.keras.Model, input_spec,
                loss_fns: List[tf.keras.losses.Loss], loss_weights: List[float],
-               metrics: List[tf.keras.metrics.Metric]):
+               metrics: Union[List[tf.keras.metrics.Metric],
+                              List[Callable[[], tf.keras.metrics.Metric]]]):
     self._keras_model = keras_model
     self._input_spec = input_spec
     self._loss_fns = loss_fns
     self._loss_weights = loss_weights
-    self._metrics = metrics
+    self._metrics = []
+    self._metric_constructors = []
+    if metrics:
+      has_keras_metric = False
+      has_keras_metric_constructor = False
+      for metric in metrics:
+        if isinstance(metric, tf.keras.metrics.Metric):
+          self._metrics.append(metric)
+          has_keras_metric = True
+        elif callable(metric):
+          constructed_metric = metric()
+          if not isinstance(constructed_metric, tf.keras.metrics.Metric):
+            raise TypeError(
+                f'Metric constructor {metric} is not a no-arg callable that '
+                'creates a `tf.keras.metrics.Metric`.')
+          self._metric_constructors.append(metric)
+          self._metrics.append(constructed_metric)
+          has_keras_metric_constructor = True
+        else:
+          raise TypeError(
+              'Expected the input metric to be either a '
+              '`tf.keras.metrics.Metric` or a no-arg callable that constructs '
+              'a `tf.keras.metrics.Metric`, found a non-callable '
+              f'{py_typecheck.type_string(type(metric))}.')
+      if has_keras_metric and has_keras_metric_constructor:
+        raise TypeError(
+            'Expected the input `metrics` to be either a list of '
+            '`tf.keras.metrics.Metric` objects or a list of no-arg callables '
+            'that each constructs a `tf.keras.metrics.Metric`, '
+            f'found both types in the `metrics`: {metrics}.')
     # This is defined here so that it closes over the `loss_fn`.
     class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
@@ -302,7 +342,9 @@ class _KerasModel(model_lib.Model):
         return super().update_state(batch_loss, batch_size)
-    self._loss_metric = _WeightedMeanLossMetric()
+    self._metrics.append(_WeightedMeanLossMetric())
+    if not metrics or self._metric_constructors:
+      self._metric_constructors.append(_WeightedMeanLossMetric)
     metric_variable_type_dict = tf.nest.map_structure(
         tf.TensorSpec.from_tensor, self.report_local_outputs())
@@ -310,6 +352,9 @@ class _KerasModel(model_lib.Model):
         metric_variable_type_dict, placements.CLIENTS)
     def federated_output(local_outputs):
+      if self._metric_constructors:
+        return federated_aggregate_keras_metric(self._metric_constructors,
+                                                local_outputs)
       return federated_aggregate_keras_metric(self.get_metrics(), local_outputs)
     self._federated_output_computation = computations.federated_computation(
@@ -331,7 +376,7 @@ class _KerasModel(model_lib.Model):
     return local_variables
   def get_metrics(self):
-    return self._metrics + [self._loss_metric]
+    return self._metrics
   @property
   def input_spec(self):
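
The remaining hunks update the Keras utils tests and the metrics finalizer. The new tests below reference `NumBatchesCounter` and `NumExamplesCounter`, which are defined elsewhere in the test module and are not part of this diff; a plausible sketch consistent with the assertions below (each metric holds a single integer sum variable starting at zero) is:

import tensorflow as tf

class NumBatchesCounter(tf.keras.metrics.Sum):
  """A `tf.keras.metrics.Metric` that counts the number of batches seen."""

  def __init__(self, name='num_batches', dtype=tf.int64):
    super().__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    # Each call to `update_state` corresponds to one batch.
    return super().update_state(1, sample_weight)

class NumExamplesCounter(tf.keras.metrics.Sum):
  """A `tf.keras.metrics.Metric` that counts the number of examples seen."""

  def __init__(self, name='num_examples', dtype=tf.int64):
    super().__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    # The leading dimension of the predictions is the batch size.
    return super().update_state(tf.shape(y_pred)[0], sample_weight)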
@@ -1070,6 +1070,128 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
     self.assertIsInstance(tff_model, model_lib.Model)
+
+  @parameterized.named_parameters(
+      # Test cases for the cartesian product of all parameter values.
+      *_create_tff_model_from_keras_model_tuples())
+  def test_keras_model_with_metric_constructors(self, feature_dims, model_fn):
+    keras_model = model_fn(feature_dims)
+    tff_model = keras_utils.from_keras_model(
+        keras_model=keras_model,
+        input_spec=_create_whimsy_types(feature_dims),
+        loss=tf.keras.losses.MeanSquaredError(),
+        metrics=[NumBatchesCounter, NumExamplesCounter])
+    self.assertIsInstance(tff_model, model_lib.Model)
+
+    # Metrics should be zero, though the model wrapper internally executes the
+    # forward pass once.
+    self.assertSequenceEqual(tff_model.local_variables, [0, 0, 0.0, 0.0])
+
+    batch = collections.OrderedDict(
+        x=np.stack([
+            np.zeros(feature_dims, np.float32),
+            np.ones(feature_dims, np.float32)
+        ]),
+        y=[[0.0], [1.0]])
+    # `from_keras_model()` was called without an optimizer, which creates a
+    # tff.Model. There is no train_on_batch() method available in tff.Model.
+    with self.assertRaisesRegex(AttributeError,
+                                'no attribute \'train_on_batch\''):
+      tff_model.train_on_batch(batch)
+
+    output = tff_model.forward_pass(batch)
+    # Since the model initializes all weights and biases to zero, we expect
+    # all predictions to be zero:
+    #    0*x1 + 0*x2 + ... + 0 = 0
+    self.assertAllEqual(output.predictions, [[0.0], [0.0]])
+    # For the single batch:
+    #
+    # Example | Prediction | Label | Residual | Loss
+    # --------+------------+-------+----------+------
+    #    1    |    0.0     |  0.0  |   0.0    |  0.0
+    #    2    |    0.0     |  1.0  |   1.0    |  1.0
+    #
+    # Note that though regularization might be applied, this has no effect on
+    # the loss since all weights are 0.
+    # Total loss: 1.0
+    # Batch average loss: 0.5
+    self.assertEqual(output.loss, 0.5)
+
+    self.assertAllEqual(tff_model.report_local_outputs(),
+                        tff_model.report_local_unfinalized_metrics())
+    metrics = tff_model.report_local_unfinalized_metrics()
+    self.assertEqual(metrics['num_batches'], [1])
+    self.assertEqual(metrics['num_examples'], [2])
+    self.assertGreater(metrics['loss'][0], 0)
+    self.assertEqual(metrics['loss'][1], 2)
+
+  @parameterized.named_parameters(
+      # Test cases for the cartesian product of all parameter values.
+      *_create_tff_model_from_keras_model_tuples())
+  def test_keras_model_without_input_metrics(self, feature_dims, model_fn):
+    keras_model = model_fn(feature_dims)
+    tff_model = keras_utils.from_keras_model(
+        keras_model=keras_model,
+        input_spec=_create_whimsy_types(feature_dims),
+        loss=tf.keras.losses.MeanSquaredError())
+    self.assertIsInstance(tff_model, model_lib.Model)
+
+    # Metrics should be zero, though the model wrapper internally executes the
+    # forward pass once.
+    self.assertSequenceEqual(tff_model.local_variables, [0, 0])
+
+    batch = collections.OrderedDict(
+        x=np.stack([
+            np.zeros(feature_dims, np.float32),
+            np.ones(feature_dims, np.float32)
+        ]),
+        y=[[0.0], [1.0]])
+    # `from_keras_model()` was called without an optimizer, which creates a
+    # tff.Model. There is no train_on_batch() method available in tff.Model.
+    with self.assertRaisesRegex(AttributeError,
+                                'no attribute \'train_on_batch\''):
+      tff_model.train_on_batch(batch)
+
+    output = tff_model.forward_pass(batch)
+    # Since the model initializes all weights and biases to zero, we expect
+    # all predictions to be zero:
+    #    0*x1 + 0*x2 + ... + 0 = 0
+    self.assertAllEqual(output.predictions, [[0.0], [0.0]])
+    # For the single batch:
+    #
+    # Example | Prediction | Label | Residual | Loss
+    # --------+------------+-------+----------+------
+    #    1    |    0.0     |  0.0  |   0.0    |  0.0
+    #    2    |    0.0     |  1.0  |   1.0    |  1.0
+    #
+    # Note that though regularization might be applied, this has no effect on
+    # the loss since all weights are 0.
+    # Total loss: 1.0
+    # Batch average loss: 0.5
+    self.assertEqual(output.loss, 0.5)
+
+    self.assertAllEqual(tff_model.report_local_outputs(),
+                        tff_model.report_local_unfinalized_metrics())
+    metrics = tff_model.report_local_unfinalized_metrics()
+    self.assertGreater(metrics['loss'][0], 0)
+    self.assertEqual(metrics['loss'][1], 2)
+
+  @parameterized.named_parameters(
+      ('both_metrics_and_constructors',
+       [NumExamplesCounter, NumBatchesCounter()], 'found both types'),
+      ('non_callable', [tf.constant(1.0)], 'found a non-callable'),
+      ('non_keras_metric_constructor',
+       [tf.keras.losses.MeanSquaredError], 'not a no-arg callable'))
+  def test_keras_model_provided_invalid_metrics_raises(self, metrics,
+                                                       error_message):
+    feature_dims = 3
+    keras_model = model_examples.build_linear_regression_keras_functional_model(
+        feature_dims)
+
+    with self.assertRaisesRegex(TypeError, error_message):
+      keras_utils.from_keras_model(
+          keras_model=keras_model,
+          input_spec=_create_whimsy_types(feature_dims),
+          loss=tf.keras.losses.MeanSquaredError(),
+          metrics=metrics)
+
 if __name__ == '__main__':
   execution_contexts.set_local_python_execution_context()
@@ -57,22 +57,7 @@ def create_keras_metric_finalizer(
    # we need the `tf.Variable`s to be created in the current scope in order to
    # use `keras_metric.result()`.
    with tf.init_scope():
-      if isinstance(metric, tf.keras.metrics.Metric):
-        check_keras_metric_config_constructable(metric)
-        keras_metric = type(metric).from_config(metric.get_config())
-      elif callable(metric):
-        keras_metric = metric()
-        if not isinstance(keras_metric, tf.keras.metrics.Metric):
-          raise TypeError(
-              'Expected input `metric` to be either a `tf.keras.metrics.Metric`'
-              ' or a no-arg callable that creates a `tf.keras.metrics.Metric`, '
-              'found a callable that returns a '
-              f'{py_typecheck.type_string(type(keras_metric))}.')
-      else:
-        raise TypeError(
-            'Expected input `metric` to be either a `tf.keras.metrics.Metric` '
-            'or a no-arg callable that constructs a `tf.keras.metrics.Metric`, '
-            f'found a non-callable {py_typecheck.type_string(type(metric))}.')
+      keras_metric = create_keras_metric(metric)
    py_typecheck.check_type(unfinalized_metric_values, list)
    if len(keras_metric.variables) != len(unfinalized_metric_values):
      raise ValueError(
@@ -94,16 +79,12 @@ def create_keras_metric_finalizer(
   return finalizer
-def check_keras_metric_config_constructable(
-    metric: tf.keras.metrics.Metric) -> tf.keras.metrics.Metric:
+def _check_keras_metric_config_constructable(metric: tf.keras.metrics.Metric):
   """Checks that a Keras metric is constructable from the `get_config()` method.
   Args:
     metric: A single `tf.keras.metrics.Metric`.
-  Returns:
-    The metric.
   Raises:
     TypeError: If the metric is not an instance of `tf.keras.metrics.Metric`, or
       if the metric is not constructable from the `get_config()` method.
@@ -114,28 +95,64 @@ def check_keras_metric_config_constructable(
   metric_type_str = type(metric).__name__
-  if hasattr(tf.keras.metrics, metric_type_str):
-    return metric
-  init_args = inspect.getfullargspec(metric.__init__).args
-  init_args.remove('self')
-  get_config_args = metric.get_config().keys()
-  extra_args = [arg for arg in init_args if arg not in get_config_args]
-  if extra_args:
-    # TODO(b/197746608): Updates the error message to redirect users to use
-    # metric constructors instead of constructed metrics when we support both
-    # cases in `from_keras_model`.
-    raise TypeError(f'Metric {metric_type_str} is not constructable from '
-                    'the `get_config()` method, because `__init__` takes extra '
-                    'arguments that are not included in the `get_config()`: '
-                    f'{extra_args}. Override or update the `get_config()` in '
-                    'the metric class to include these extra arguments.\n'
-                    'Example:\n'
-                    'class CustomMetric(tf.keras.metrics.Metric):\n'
-                    '  def __init__(self, arg1):\n'
-                    '    self._arg1 = arg1\n\n'
-                    '  def get_config(self)\n'
-                    '    config = super().get_config()\n'
-                    '    config[\'arg1\'] = self._arg1\n'
-                    '    return config')
-  return metric
+  if not hasattr(tf.keras.metrics, metric_type_str):
+    init_args = inspect.getfullargspec(metric.__init__).args
+    init_args.remove('self')
+    get_config_args = metric.get_config().keys()
+    extra_args = [arg for arg in init_args if arg not in get_config_args]
+    if extra_args:
+      # TODO(b/197746608): Remove the suggestion of updating `get_config` if
+      # that code path is removed.
+      raise TypeError(
+          f'Metric {metric_type_str} is not constructable from the '
+          '`get_config()` method, because `__init__` takes extra arguments '
+          f'that are not included in the `get_config()`: {extra_args}. '
+          'Pass the metric constructor instead, or update the `get_config()` '
+          'in the metric class to include these extra arguments.\n'
+          'Example:\n'
+          'class CustomMetric(tf.keras.metrics.Metric):\n'
+          '  def __init__(self, arg1):\n'
+          '    self._arg1 = arg1\n\n'
+          '  def get_config(self)\n'
+          '    config = super().get_config()\n'
+          '    config[\'arg1\'] = self._arg1\n'
+          '    return config')
+def create_keras_metric(
+    metric: Union[tf.keras.metrics.Metric,
+                  Callable[[], tf.keras.metrics.Metric]]
+) -> tf.keras.metrics.Metric:
+  """Creates a `tf.keras.metrics.Metric` from a metric or a metric constructor.
+
+  This ensures the `tf.Variable`s in the metric are created in the right scope
+  in TFF.
+
+  Args:
+    metric: A single `tf.keras.metrics.Metric` or a no-arg callable that
+      creates a `tf.keras.metrics.Metric`.
+
+  Returns:
+    A `tf.keras.metrics.Metric` object.
+
+  Raises:
+    TypeError: If the input metric is neither a `tf.keras.metrics.Metric` nor a
+      no-arg callable that creates a `tf.keras.metrics.Metric`.
+  """
+  keras_metric = None
+  if isinstance(metric, tf.keras.metrics.Metric):
+    _check_keras_metric_config_constructable(metric)
+    keras_metric = type(metric).from_config(metric.get_config())
+  elif callable(metric):
+    keras_metric = metric()
+    if not isinstance(keras_metric, tf.keras.metrics.Metric):
+      raise TypeError(
+          'Expected input `metric` to be either a `tf.keras.metrics.Metric` '
+          'or a no-arg callable that creates a `tf.keras.metrics.Metric`, '
+          'found a callable that returns a '
+          f'{py_typecheck.type_string(type(keras_metric))}.')
+  else:
+    raise TypeError(
+        'Expected input `metric` to be either a `tf.keras.metrics.Metric` '
+        'or a no-arg callable that constructs a `tf.keras.metrics.Metric`, '
+        f'found a non-callable {py_typecheck.type_string(type(metric))}.')
+  return keras_metric
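
To illustrate the two code paths through `create_keras_metric`, a short usage sketch (the `ThresholdedCounter` metric is hypothetical, invented here to show the failure mode that constructors avoid):

import tensorflow as tf
from tensorflow_federated.python.learning.metrics import finalizer

# Path 1: a constructed metric is rebuilt via `from_config()`, after checking
# that every `__init__` argument round-trips through `get_config()`.
mae = finalizer.create_keras_metric(tf.keras.metrics.MeanAbsoluteError())

# Path 2: a no-arg callable is simply invoked; no config round-trip is needed.
mae_from_ctor = finalizer.create_keras_metric(tf.keras.metrics.MeanAbsoluteError)

# Hypothetical custom metric whose extra `__init__` argument is not exported
# in `get_config()`:
class ThresholdedCounter(tf.keras.metrics.Sum):

  def __init__(self, threshold, name='thresholded_counter', dtype=tf.float32):
    super().__init__(name=name, dtype=dtype)
    self._threshold = threshold

# Path 1 would raise a TypeError for `ThresholdedCounter(0.5)`, because
# `threshold` is missing from `get_config()`. Wrapping the construction in a
# no-arg callable avoids `from_config()` entirely:
counter = finalizer.create_keras_metric(lambda: ThresholdedCounter(0.5))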