Commit 470b85ea authored by Karan Singhal's avatar Karan Singhal Committed by tensorflow-copybara

Include layer-wise losses in loss used for training step for TFF Keras models.

PiperOrigin-RevId: 321785542
parent 3c077d82
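For context, here is a minimal sketch (not part of the commit; the model and shapes are illustrative) of the Keras mechanism this change relies on: per-layer regularization losses accumulate in `model.losses`, and `tf.add_n` folds them into the loss used for the training step, much as the diff below does.

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(
        1,
        kernel_regularizer=tf.keras.regularizers.l2(0.01),
        input_shape=(3,))
])
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.ones([2, 3])
y = tf.zeros([2, 1])
predictions = model(x)
# The main loss plus every per-layer regularization loss.
batch_loss = tf.add_n([loss_fn(y_true=y, y_pred=predictions)] + model.losses)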
@@ -276,7 +276,14 @@ class _KerasModel(model_lib.Model):
lf=loss_fns,
llf=len(loss_fns)))
self._loss_weights = loss_weights
# Ensure Keras model isn't compiled, possibly with a different loss or
# optimizer.
if inner_model._is_compiled: # pylint: disable=protected-access
raise ValueError('Keras model must be uncompiled, but got compiled Keras '
'model.')
self._keras_model = inner_model
self._metrics = metrics if metrics is not None else []
# This is defined here so that it closes over the `loss_fn`.
@@ -363,10 +370,21 @@ class _KerasModel(model_lib.Model):
if y_true is not None:
if len(self._loss_fns) == 1:
loss_fn = self._loss_fns[0]
batch_loss = loss_fn(y_true=y_true, y_pred=predictions)
# Note: we add each of the per-layer regularization losses to the loss
# that we use to update trainable parameters, in addition to the
# user-provided loss function. Keras does the same in the
# `tf.keras.Model` training step. This is expected to have no effect if
# no per-layer losses are added to the model.
batch_loss = tf.add_n([loss_fn(y_true=y_true, y_pred=predictions)] +
self._keras_model.losses)
else:
batch_loss = tf.zeros(())
# Note: we add each of the per-layer regularization losses to the losses
# that we use to update trainable parameters, in addition to the
# user-provided loss functions. Keras does the same in the
# `tf.keras.Model` training step. This is expected to have no effect if
# no per-layer losses are added to the model.
batch_loss = tf.add_n([tf.zeros(())] + self._keras_model.losses)
for i in range(len(self._loss_fns)):
loss_fn = self._loss_fns[i]
loss_wt = self._loss_weights[i]
......
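The diff above is truncated before the end of the multi-loss loop, so here is an illustrative sketch of that path (the stand-in values below are hypothetical, not from the commit): each per-output loss is weighted and accumulated on top of the regularization terms.

import tensorflow as tf

loss_fns = [tf.keras.losses.MeanSquaredError(),
            tf.keras.losses.MeanSquaredError()]
loss_weights = [0.5, 0.5]
regularization_losses = [tf.constant(0.01)]  # stand-in for model.losses

y_true = [tf.zeros([2, 1]), tf.ones([2, 1])]
predictions = [tf.ones([2, 1]), tf.ones([2, 1])]

# Start from the regularization terms, then add each weighted output loss.
batch_loss = tf.add_n([tf.zeros(())] + regularization_losses)
for loss_fn, loss_wt, yt, yp in zip(loss_fns, loss_weights, y_true,
                                    predictions):
  batch_loss += loss_wt * loss_fn(y_true=yt, y_pred=yp)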
@@ -64,6 +64,8 @@ def _create_tff_model_from_keras_model_tuples():
model_examples.build_linear_regression_keras_functional_model),
('sequential',
model_examples.build_linear_regression_keras_sequential_model),
('sequential_regularized', model_examples
.build_linear_regression_regularized_keras_sequential_model)
]:
tuples.append(('{}_model_{}_dims'.format(name, n_dims), n_dims, model_fn))
return tuples
@@ -149,6 +151,8 @@ class KerasUtilsTest(test.TestCase, parameterized.TestCase):
# 1 | 0.0 | 0.0 | 0.0 | 0.0
# 2 | 0.0 | 1.0 | 1.0 | 1.0
#
# Note that although regularization may be applied, it has no effect on
# the loss since all weights are 0.
# Total loss: 1.0
# Batch average loss: 0.5
self.assertEqual(self.evaluate(output.loss), 0.5)
@@ -158,6 +162,56 @@ class KerasUtilsTest(test.TestCase, parameterized.TestCase):
self.assertGreater(metrics['loss'][0], 0)
self.assertEqual(metrics['loss'][1], 2)
def test_tff_model_from_keras_model_regularization(self):
keras_model = model_examples.build_linear_regression_ones_regularized_keras_sequential_model(
3)
tff_model = keras_utils.from_keras_model(
keras_model=keras_model,
input_spec=_create_dummy_types(3),
loss=tf.keras.losses.MeanSquaredError(),
metrics=[NumBatchesCounter(), NumExamplesCounter()])
self.assertIsInstance(tff_model, model_utils.EnhancedModel)
# Metrics should be zero, though the model wrapper internally executes the
# forward pass once.
self.assertSequenceEqual(
self.evaluate(tff_model.local_variables), [0, 0, 0.0, 0.0])
batch = collections.OrderedDict(
x=np.stack([np.zeros(3, np.float32),
np.ones(3, np.float32)]),
y=[[0.0], [1.0]])
# from_keras_model() was called without an optimizer, which creates a
# tff.Model. There is no train_on_batch() method available on a tff.Model.
with self.assertRaisesRegex(AttributeError,
'no attribute \'train_on_batch\''):
tff_model.train_on_batch(batch)
output = tff_model.forward_pass(batch)
# Since the model initializes all weights and biases to one, we expect
# predictions of x1 + x2 + x3 + 1 for each example:
# 0 + 0 + 0 + 1 = 1 and 1 + 1 + 1 + 1 = 4.
self.assertAllEqual(output.predictions, [[1.0], [4.0]])
# For the single batch:
#
# Example | Prediction | Label | Residual | Loss
# --------+------------+-------+----------+------
# 1 | 1.0 | 0.0 | 1.0 | 1.0
# 2 | 4.0 | 1.0 | 3.0 | 9.0
#
# Regularization loss: with an L2 regularization constant of 0.01, the
# kernel regularizer loss is (3 * 1**2) * 0.01 = 0.03 and the bias
# regularizer loss is 1**2 * 0.01 = 0.01, so the total regularization
# loss is 0.04.
# Total loss: 10.0
# Batch average loss: 5.0
# Total batch loss with regularization: 5.04
self.assertAlmostEqual(self.evaluate(output.loss), 5.04)
metrics = self.evaluate(tff_model.report_local_outputs())
self.assertEqual(metrics['num_batches'], [1])
self.assertEqual(metrics['num_examples'], [2])
self.assertGreater(metrics['loss'][0], 0)
self.assertEqual(metrics['loss'][1], 2)
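As a sanity check on the 0.04 figure in the comments above, a standalone sketch (this relies on `tf.keras.regularizers.l2(c)` computing `c * sum(w**2)`, which is how Keras defines it):

import tensorflow as tf

reg = tf.keras.regularizers.l2(0.01)
kernel = tf.ones([3, 1])  # three input features, one output unit
bias = tf.ones([1])
# 0.01 * 3 (kernel) + 0.01 * 1 (bias) = 0.04
total_regularization = reg(kernel) + reg(bias)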
@parameterized.named_parameters(*_create_tff_model_from_keras_model_tuples())
def test_tff_model_from_keras_model_input_spec(self, feature_dims, model_fn):
keras_model = model_fn(feature_dims)
@@ -515,6 +569,120 @@ class KerasUtilsTest(test.TestCase, parameterized.TestCase):
'dummy': 0.4
})
def test_regularized_keras_model_multiple_outputs(self):
keras_model = model_examples.build_multiple_outputs_regularized_keras_model(
)
input_spec = collections.OrderedDict(
x=[
tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
tf.TensorSpec(shape=[None, 1], dtype=tf.float32)
],
y=[
tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
tf.TensorSpec(shape=[None, 1], dtype=tf.float32)
])
with self.subTest('loss_output_len_mismatch'):
with self.assertRaises(ValueError):
_ = keras_utils.from_keras_model(
keras_model=keras_model,
input_spec=input_spec,
loss=[
tf.keras.losses.MeanSquaredError(),
tf.keras.losses.MeanSquaredError()
])
with self.subTest('invalid_loss'):
with self.assertRaises(TypeError):
_ = keras_utils.from_keras_model(
keras_model=keras_model, input_spec=input_spec, loss=3)
with self.subTest('loss_list_no_opt'):
tff_model = keras_utils.from_keras_model(
keras_model=keras_model,
input_spec=input_spec,
loss=[
tf.keras.losses.MeanSquaredError(),
tf.keras.losses.MeanSquaredError(),
tf.keras.losses.MeanSquaredError()
])
self.assertIsInstance(tff_model, model_utils.EnhancedModel)
dummy_batch = collections.OrderedDict(
x=[
np.zeros([1, 1], dtype=np.float32),
np.zeros([1, 1], dtype=np.float32)
],
y=[
np.zeros([1, 1], dtype=np.float32),
np.ones([1, 1], dtype=np.float32),
np.ones([1, 1], dtype=np.float32)
])
output = tff_model.forward_pass(dummy_batch)
# Labels are (0, 1, 1), preds are (1, 1, 3).
# Total MSE is 1**2 + 0**2 + 2**2 = 5.
# Since all weights are initialized to ones and regularization constant is
# 0.01, regularization loss is 0.01 * (num_params). There are 4 dense
# layers that take in one input and produce one output, and these each
# have a single weight and a single bias. There is one dense layer with
# two inputs and one output, so it has two weights and a single bias.
# So there are 11 params total and regularization loss is 0.11, for a
# total batch loss of 5.11.
self.assertAllClose(output.loss, 5.11)
keras_model = model_examples.build_multiple_outputs_regularized_keras_model(
)
with self.subTest('loss_weights_as_list'):
tff_model = keras_utils.from_keras_model(
keras_model=keras_model,
input_spec=input_spec,
loss=[
tf.keras.losses.MeanSquaredError(),
tf.keras.losses.MeanSquaredError(),
tf.keras.losses.MeanSquaredError()
],
loss_weights=[0.1, 0.2, 0.3])
output = tff_model.forward_pass(dummy_batch)
# Labels are (0, 1, 1), preds are (1, 1, 3).
# Weighted MSE is 0.1 * 1**2 + 0.2 * 0**2 + 0.3 * 2**2 = 1.3.
# Regularization loss is 0.11 as before, for a total loss of 1.41.
self.assertAllClose(output.loss, 1.41)
output = tff_model.forward_pass(dummy_batch)
self.assertAllClose(output.loss, 1.41)
with self.subTest('loss_weights_assert_fail_list'):
with self.assertRaises(ValueError):
_ = keras_utils.from_keras_model(
keras_model=keras_model,
input_spec=input_spec,
loss=[
tf.keras.losses.MeanSquaredError(),
tf.keras.losses.MeanSquaredError(),
tf.keras.losses.MeanSquaredError()
],
loss_weights=[0.1, 0.2])
with self.subTest('loss_weights_assert_fail_dict'):
with self.assertRaises(TypeError):
_ = keras_utils.from_keras_model(
keras_model=keras_model,
input_spec=input_spec,
loss=[
tf.keras.losses.MeanSquaredError(),
tf.keras.losses.MeanSquaredError(),
tf.keras.losses.MeanSquaredError()
],
loss_weights={
'dense_5': 0.1,
'dense_6': 0.2,
'dummy': 0.4
})
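The 0.11 regularization term used in this test follows from a simple parameter count; a short illustrative check (layer shapes taken from the comments above):

# Four 1-in/1-out dense layers contribute 1 weight + 1 bias each; the
# final 2-in/1-out dense layer contributes 2 weights + 1 bias.
num_params = 4 * (1 + 1) + (2 + 1)  # 11
regularization_loss = 0.01 * num_params  # 0.11
total_loss = 5.0 + regularization_loss  # 5.11, matching the assertion above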
def test_keras_model_lookup_table(self):
model = model_examples.build_lookup_table_keras_model()
input_spec = collections.OrderedDict(
@@ -557,6 +725,27 @@ class KerasUtilsTest(test.TestCase, parameterized.TestCase):
self.evaluate(orig_model_output.loss),
self.evaluate(loaded_model_output.loss))
def test_keras_model_fails_compiled(self):
feature_dims = 3
keras_model = model_examples.build_linear_regression_keras_functional_model(
feature_dims)
keras_model.compile(loss=tf.keras.losses.MeanSquaredError())
with self.assertRaisesRegex(ValueError, 'compile'):
keras_utils.from_keras_model(
keras_model=keras_model,
input_spec=_create_dummy_types(feature_dims),
loss=tf.keras.losses.MeanSquaredError(),
metrics=[NumBatchesCounter(),
NumExamplesCounter()])
with self.assertRaisesRegex(ValueError, 'compile'):
keras_utils._KerasModel(
keras_model,
input_spec=_create_dummy_types(feature_dims),
loss_fns=[tf.keras.losses.MeanSquaredError()])
if __name__ == '__main__':
test.main()
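For orientation, a hedged usage sketch of the public entry point these tests exercise, assuming `tff.learning.from_keras_model` mirrors the `keras_utils.from_keras_model` signature used above: the model must be passed uncompiled, with the loss handed to TFF directly.

import collections

import tensorflow as tf
import tensorflow_federated as tff

keras_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, input_shape=(3,))])
# Do NOT call keras_model.compile(); the wrapper rejects compiled models,
# as the new test above verifies.
tff_model = tff.learning.from_keras_model(
    keras_model=keras_model,
    input_spec=collections.OrderedDict(
        x=tf.TensorSpec([None, 3], tf.float32),
        y=tf.TensorSpec([None, 1], tf.float32)),
    loss=tf.keras.losses.MeanSquaredError())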
@@ -151,6 +151,80 @@ def _dense_all_zeros_layer(input_dims=None, output_dim=1):
return build_keras_dense_layer()
def _dense_all_zeros_regularized_layer(input_dims=None,
output_dim=1,
regularization_constant=0.01):
"""Create a layer that can be used in isolation for linear regression.
Constructs a Keras dense layer with a single output, using biases and weights
that are initialized to zero. No activation function is applied. When this is
the only layer in a model, the model is effectively a linear regression model.
The regularization constant is used to scale L2 regularization on the weights
and bias.
Args:
input_dims: the integer length of the input to this layers. Maybe None if
the layer input size does not need to be specified.
output_dim: the integer length of the flattened output tensor. Defaults to
one, effectively making the layer perform linear regression.
regularization_constant: the float scaling magnitude (lambda) for L2
regularization on the layer's weights and bias.
Returns:
a `tf.keras.layers.Dense` object.
"""
regularizer = tf.keras.regularizers.l2(regularization_constant)
build_keras_dense_layer = functools.partial(
tf.keras.layers.Dense,
units=output_dim,
use_bias=True,
kernel_initializer='zeros',
bias_initializer='zeros',
kernel_regularizer=regularizer,
bias_regularizer=regularizer,
activation=None)
if input_dims is not None:
return build_keras_dense_layer(input_shape=(input_dims,))
return build_keras_dense_layer()
def _dense_all_ones_regularized_layer(input_dims=None,
output_dim=1,
regularization_constant=0.01):
"""Create a layer that can be used in isolation for linear regression.
Constructs a Keras dense layer with a single output, using biases and weights
that are initialized to ones. No activation function is applied. When this is
the only layer in a model, the model is effectively a linear regression model.
The regularization constant is used to scale L2 regularization on the weights
and bias.
Args:
input_dims: the integer length of the input to this layers. Maybe None if
the layer input size does not need to be specified.
output_dim: the integer length of the flattened output tensor. Defaults to
one, effectively making the layer perform linear regression.
regularization_constant: the float scaling magnitude (lambda) for L2
regularization on the layer's weights and bias.
Returns:
a `tf.keras.layers.Dense` object.
"""
regularizer = tf.keras.regularizers.l2(regularization_constant)
build_keras_dense_layer = functools.partial(
tf.keras.layers.Dense,
units=output_dim,
use_bias=True,
kernel_initializer='ones',
bias_initializer='ones',
kernel_regularizer=regularizer,
bias_regularizer=regularizer,
activation=None)
if input_dims is not None:
return build_keras_dense_layer(input_shape=(input_dims,))
return build_keras_dense_layer()
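A brief, hypothetical use of the helpers above: a single regularized dense layer is, on its own, a linear regression model, and its L2 terms surface in `model.losses` (which the forward pass now folds into the training loss).

layer = _dense_all_ones_regularized_layer(
    input_dims=3, regularization_constant=0.01)
model = tf.keras.Sequential([layer])
# With all-ones parameters: kernel term 0.01 * 3, bias term 0.01 * 1.
print(model.losses)  # two scalar regularization tensors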
def build_linear_regression_keras_sequential_model(feature_dims=2):
"""Build a linear regression `tf.keras.Model` using the Sequential API."""
keras_model = tf.keras.models.Sequential()
@@ -158,6 +232,26 @@ def build_linear_regression_keras_sequential_model(feature_dims=2):
return keras_model
def build_linear_regression_regularized_keras_sequential_model(
feature_dims=2, regularization_constant=0.01):
"""Build a linear regression `tf.keras.Model` using the Sequential API."""
keras_model = tf.keras.models.Sequential()
keras_model.add(
_dense_all_zeros_regularized_layer(
feature_dims, regularization_constant=regularization_constant))
return keras_model
def build_linear_regression_ones_regularized_keras_sequential_model(
feature_dims=2, regularization_constant=0.01):
"""Build a linear regression `tf.keras.Model` using the Sequential API."""
keras_model = tf.keras.models.Sequential()
keras_model.add(
_dense_all_ones_regularized_layer(
feature_dims, regularization_constant=regularization_constant))
return keras_model
def build_linear_regression_keras_functional_model(feature_dims=2):
"""Build a linear regression `tf.keras.Model` using the functional API."""
a = tf.keras.layers.Input(shape=(feature_dims,), dtype=tf.float32)
@@ -260,6 +354,33 @@ def build_multiple_outputs_keras_model():
return tf.keras.Model(inputs=[a, b], outputs=[output_a, output_b, output_c])
def build_multiple_outputs_regularized_keras_model(
regularization_constant=0.01):
"""Builds a test model with three outputs.
All weights are initialized to ones.
Args:
regularization_constant: L2 scaling constant (lambda) for all weights and
biases.
Returns:
a `tf.keras.Model` object.
"""
dense = functools.partial(
_dense_all_ones_regularized_layer,
output_dim=1,
regularization_constant=regularization_constant)
a = tf.keras.layers.Input((1,))
b = tf.keras.layers.Input((1,))
output_a = dense()(a)
output_b = dense()(b)
output_c = dense()(tf.keras.layers.concatenate([dense()(a), dense()(b)]))
return tf.keras.Model(inputs=[a, b], outputs=[output_a, output_b, output_c])
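For the record, the forward pass of this model on zero inputs mirrors the arithmetic in the test comments above (illustrative snippet, assuming the builder above is in scope):

import tensorflow as tf

model = build_multiple_outputs_regularized_keras_model()
a = tf.zeros([1, 1])
b = tf.zeros([1, 1])
output_a, output_b, output_c = model([a, b])
# All parameters are one, so output_a = output_b = 1*0 + 1 = 1, and
# output_c = 1*1 + 1*1 + 1 = 3 (a dense layer over the concat of two 1s).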
def build_lookup_table_keras_model():
"""Builds a test model with three outputs."""
l = tf.keras.layers
......