Commit 5721e8f9 authored by Sean Augenstein, committed by tensorflow-copybara

Tweaks keras_utils to allow specifying a single loss with multiple model...

Tweaks keras_utils to allow specifying a single loss with multiple model outputs and/or labels, where the outputs/labels are pushed directly into the (probably custom user-defined) loss.

PiperOrigin-RevId: 347030237
Parent f49e5267
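As a rough illustration of what this change enables, the sketch below pairs a single custom loss with a multi-output Keras model. The toy `build_two_output_model` helper and the `input_spec` here are stand-ins invented for the example (the tests in this diff use `model_examples.build_multiple_outputs_keras_model` instead), and the public `tff.learning.from_keras_model` entry point is assumed to expose the `keras_utils.from_keras_model` function edited below.

```python
import collections
import tensorflow as tf
import tensorflow_federated as tff


class CustomLoss(tf.keras.losses.Loss):
  """A single loss that receives the full label/prediction structures."""

  def __init__(self):
    super().__init__(name='custom_loss')

  def call(self, y_true, y_pred):
    # `y_true` and `y_pred` are the (possibly multi-element) structures of
    # labels and model outputs; combine them however the application requires.
    loss = tf.constant(0.0)
    for label, prediction in zip(y_true, y_pred):
      loss += tf.keras.losses.MeanSquaredError()(label, prediction)
    return loss


def build_two_output_model():
  # A toy uncompiled Keras model with two inputs and two outputs, standing in
  # for the multi-output example models used in the tests.
  inp_a = tf.keras.Input(shape=(1,))
  inp_b = tf.keras.Input(shape=(1,))
  out_a = tf.keras.layers.Dense(1, kernel_initializer='ones', use_bias=False)(inp_a)
  out_b = tf.keras.layers.Dense(1, kernel_initializer='ones', use_bias=False)(inp_b)
  return tf.keras.Model(inputs=[inp_a, inp_b], outputs=[out_a, out_b])


input_spec = collections.OrderedDict(
    x=[tf.TensorSpec([None, 1], tf.float32)] * 2,
    y=[tf.TensorSpec([None, 1], tf.float32)] * 2)

# With this change, a single (possibly custom) loss may be paired with a
# multi-output model; previously a list of per-output losses was required.
tff_model = tff.learning.from_keras_model(
    keras_model=build_two_output_model(),
    input_spec=input_spec,
    loss=CustomLoss())
```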
@@ -46,16 +46,20 @@ def from_keras_model(
TFF needs a slightly different notion of "fully specified type" than
pure Keras does. That is, the model `M` takes inputs of type `x` and
produces predictions of type `p`; the loss function `L` takes inputs of type
`<p, y>` and produces a scalar. Therefore in order to fully specify the type
signatures for computations in which the generated `tff.learning.Model` will
appear, TFF needs the type `y` in addition to the type `x`.
`<p, y>` (where `y` is the ground truth label type) and produces a scalar.
Therefore in order to fully specify the type signatures for computations in
which the generated `tff.learning.Model` will appear, TFF needs the type `y`
in addition to the type `x`.
Args:
keras_model: A `tf.keras.Model` object that is not compiled.
loss: A `tf.keras.losses.Loss`, or a list of losses-per-output if the model
has multiple outputs. If multiple outputs are present, the model will
attempt to minimize the sum of all individual losses (optionally weighted
using the `loss_weights` argument).
loss: A single `tf.keras.losses.Loss` or a list of losses-per-output. If a
single loss is provided, then all model output (as well as all prediction
information) is passed to the loss; this includes situations of multiple
model outputs and/or predictions. If multiple losses are provided as a
list, then each loss is expected to correspond to a model output; the
model will attempt to minimize the sum of all individual losses
(optionally weighted using the `loss_weights` argument).
input_spec: A structure of `tf.TensorSpec`s or `tff.Type` specifying the
type of arguments the model expects. Notice this must be a compound
structure of two elements, specifying both the data fed into the model (x)
@@ -63,21 +67,24 @@ def from_keras_model(
(y). If provided as a list, it must be in the order [x, y]. If provided as
a dictionary, the keys must explicitly be named `'x'` and `'y'`.
loss_weights: (Optional) A list of Python floats used to weight the loss
contribution of each model output.
contribution of each model output (when providing a list of losses for the
`loss` argument).
metrics: (Optional) a list of `tf.keras.metrics.Metric` objects.
Returns:
A `tff.learning.Model` object.
Raises:
TypeError: If `keras_model` is not instance of `tf.keras.Model`, if
`keras_model` has a single output and `loss` is not instance of
`tf.keras.losses.Loss`, or if `keras_model` has multiple outputs and
`loss` is not a list of instances of `tf.keras.losses.Loss`.
ValueError: If `keras_model` was compiled, if `keras_model` has multiple
outputs and `loss` is not list of equal length, if `input_spec` does not
contain exactly two elements, or if `input_spec` is a dictionary and does
not contain keys `'x'` and `'y'`.
TypeError: If `keras_model` is not an instance of `tf.keras.Model`, if
`loss` is not an instance of `tf.keras.losses.Loss` nor a list of
instances of `tf.keras.losses.Loss`, if `loss_weights` is provided but is
not a list of floats, or if `metrics` is provided but is not a list of
instances of `tf.keras.metrics.Metric`.
ValueError: If `keras_model` was compiled, if `loss` is a list of unequal
length to the number of outputs of `keras_model`, if `loss_weights` is
specified but `loss` is not a list, if `input_spec` does not contain
exactly two elements, or if `input_spec` is a dictionary and does not
contain keys `'x'` and `'y'`.
"""
# Validate `keras_model`
py_typecheck.check_type(keras_model, tf.keras.Model)
@@ -85,19 +92,17 @@ def from_keras_model(
raise ValueError('`keras_model` must not be compiled')
# Validate and normalize `loss` and `loss_weights`
if len(keras_model.outputs) == 1:
if not isinstance(loss, list):
py_typecheck.check_type(loss, tf.keras.losses.Loss)
if loss_weights is not None:
raise ValueError('`loss_weights` cannot be used if `keras_model` has '
'only one output.')
raise ValueError('`loss_weights` cannot be used if `loss` is not a list.')
loss = [loss]
loss_weights = [1.0]
else:
py_typecheck.check_type(loss, list)
if len(loss) != len(keras_model.outputs):
raise ValueError('`keras_model` must have equal number of '
'outputs and losses.\nloss: {}\nof length: {}.'
'\noutputs: {}\nof length: {}.'.format(
raise ValueError('If a loss list is provided, `keras_model` must have '
'equal number of outputs to the losses.\nloss: {}\nof '
'length: {}.\noutputs: {}\nof length: {}.'.format(
loss, len(loss), keras_model.outputs,
len(keras_model.outputs)))
for loss_fn in loss:
@@ -263,11 +268,14 @@ class _KerasModel(model_lib.Model):
self._loss_weights = loss_weights
def update_state(self, y_true, y_pred, sample_weight=None):
if len(self._loss_fns) == 1:
if isinstance(y_pred, list):
batch_size = tf.shape(y_pred[0])[0]
else:
batch_size = tf.shape(y_pred)[0]
if len(self._loss_fns) == 1:
batch_loss = self._loss_fns[0](y_true, y_pred)
else:
batch_size = tf.shape(y_pred[0])[0]
batch_loss = tf.zeros(())
for i in range(len(self._loss_fns)):
batch_loss += self._loss_weights[i] * self._loss_fns[i](y_true[i],
......
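For reference, the branching edited in the hunk above can be summarized as a standalone sketch: a single loss receives the full label and prediction structures, while a list of losses is applied per output and combined using the loss weights. This paraphrases the `update_state` logic and completes the line truncated above; it is not the literal committed code.

```python
import tensorflow as tf


def aggregate_batch_loss(loss_fns, loss_weights, y_true, y_pred):
  """Paraphrase of the per-batch loss aggregation in `_KerasModel.update_state`."""
  if len(loss_fns) == 1:
    # A single (possibly custom) loss sees the complete structures of labels
    # and predictions, even when the model has several outputs.
    return loss_fns[0](y_true, y_pred)
  # Otherwise each loss is matched with one model output and weighted.
  batch_loss = tf.zeros(())
  for i in range(len(loss_fns)):
    batch_loss += loss_weights[i] * loss_fns[i](y_true[i], y_pred[i])
  return batch_loss
```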
@@ -584,7 +584,7 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
])
self.assertIsInstance(tff_model, model_utils.EnhancedModel)
dummy_batch = collections.OrderedDict(
example_batch = collections.OrderedDict(
x=[
np.zeros([1, 1], dtype=np.float32),
np.zeros([1, 1], dtype=np.float32)
@@ -594,7 +594,26 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
np.ones([1, 1], dtype=np.float32),
np.ones([1, 1], dtype=np.float32)
])
output = tff_model.forward_pass(dummy_batch)
output = tff_model.forward_pass(example_batch)
self.assertAllClose(output.loss, 2.0)
class CustomLoss(tf.keras.losses.Loss):
def __init__(self):
super().__init__(name='custom_loss')
def call(self, y_true, y_pred):
loss = tf.constant(0.0)
for label, prediction in zip(y_true, y_pred):
loss += tf.keras.losses.MeanSquaredError()(label, prediction)
return loss
keras_model = model_examples.build_multiple_outputs_keras_model()
with self.subTest('single_custom_loss_can_work_with_multiple_outputs'):
tff_model = keras_utils.from_keras_model(
keras_model=keras_model, input_spec=input_spec, loss=CustomLoss())
output = tff_model.forward_pass(example_batch)
self.assertAllClose(output.loss, 2.0)
keras_model = model_examples.build_multiple_outputs_keras_model()
@@ -609,10 +628,10 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
],
loss_weights=[0.1, 0.2, 0.3])
output = tff_model.forward_pass(dummy_batch)
output = tff_model.forward_pass(example_batch)
self.assertAllClose(output.loss, 0.5)
output = tff_model.forward_pass(dummy_batch)
output = tff_model.forward_pass(example_batch)
self.assertAllClose(output.loss, 0.5)
with self.subTest('loss_weights_assert_fail_list'):
@@ -683,7 +702,7 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
])
self.assertIsInstance(tff_model, model_utils.EnhancedModel)
dummy_batch = collections.OrderedDict(
example_batch = collections.OrderedDict(
x=[
np.zeros([1, 1], dtype=np.float32),
np.zeros([1, 1], dtype=np.float32)
@@ -693,7 +712,7 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
np.ones([1, 1], dtype=np.float32),
np.ones([1, 1], dtype=np.float32)
])
output = tff_model.forward_pass(dummy_batch)
output = tff_model.forward_pass(example_batch)
# Labels are (0, 1, 1), preds are (1, 1, 3).
# Total MSE is 1**2 + 0**2 + 2**2 = 5.
@@ -719,14 +738,14 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
],
loss_weights=[0.1, 0.2, 0.3])
output = tff_model.forward_pass(dummy_batch)
output = tff_model.forward_pass(example_batch)
# Labels are (0, 1, 1), preds are (1, 1, 3).
# Weighted MSE is 0.1 * 1**2 + 0.2 * 0**2 + 0.3 * 2**2 = 1.3.
# Regularization loss is 0.11 as before, for a total loss of 1.41.
self.assertAllClose(output.loss, 1.41)
output = tff_model.forward_pass(dummy_batch)
output = tff_model.forward_pass(example_batch)
self.assertAllClose(output.loss, 1.41)
with self.subTest('loss_weights_assert_fail_list'):
......