Commit f9bab81f authored by Zachary Garrett, committed by tensorflow-copybara

Introduce API `tff.learning.models.functional_model_from_keras`.

This is a method to create a `tff.learning.models.FunctionalModel` from a `tf.keras.Model`.

PiperOrigin-RevId: 411169398
Parent 9db55cdf
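A minimal usage sketch of the new API (illustrative; the Keras model mirrors `create_test_keras_model` in the tests below, and the float32 label spec is an assumption for this sketch):

import tensorflow as tf
import tensorflow_federated as tff

# An uncompiled Keras model; if compiled, its metrics/optimizer/loss are ignored.
keras_model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(1,)),
    tf.keras.layers.Dense(1, kernel_initializer='zeros', bias_initializer='zeros'),
])
functional_model = tff.learning.models.functional_model_from_keras(
    keras_model=keras_model,
    loss_fn=tf.keras.losses.MeanSquaredError(),
    input_spec=(tf.TensorSpec([None, 1], dtype=tf.float32),
                tf.TensorSpec([None, 1], dtype=tf.float32)))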
@@ -25,6 +25,7 @@ py_library(
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning/metrics:finalizer",
"//tensorflow_federated/python/tensorflow_libs:variable_utils",
],
)
@@ -37,6 +38,7 @@ py_test(
":functional",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/tensorflow_libs:variable_utils",
],
)
@@ -13,6 +13,7 @@
# limitations under the License.
"""Libraries for working with models in Federated Learning algorithms."""
from tensorflow_federated.python.learning.models.functional import functional_model_from_keras
from tensorflow_federated.python.learning.models.functional import FunctionalModel
from tensorflow_federated.python.learning.models.functional import model_from_functional
from tensorflow_federated.python.learning.models.serialization import load
@@ -35,6 +35,7 @@ from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning.metrics import finalizer
from tensorflow_federated.python.tensorflow_libs import variable_utils
Weight = Union[np.ndarray, int, float]
WeightStruct = Union[Sequence[Weight], Mapping[str, Weight]]
@@ -297,3 +298,202 @@ def model_from_functional(
) -> model_lib.Model:
"""Converts a `FunctionalModel` to a `tff.learning.Model`."""
return _ModelFromFunctional(functional_model, metric_constructors)
class KerasFunctionalModelError(Exception):
"""An error raised when a FunctionalModel backed by Keras is used outside TFF."""
def functional_model_from_keras(
keras_model: tf.keras.Model,
loss_fn: tf.keras.losses.Loss,
input_spec: Union[Sequence[Any], Mapping[str, Any]],
) -> FunctionalModel:
"""Converts a `tf.keras.Model` to a `tff.learning.models.FunctionalModel`.
NOTE: This method only supports models where calling the model with
`training=True` and `training=False` produces the same graph. Keras layers
such as batch normalization will fail because they require updating internal
state when `training=True`, which is not supported.
IMPORTANT: The returned model must only be used in a graph context (for
example inside a `tff.tf_computation` decorated callable). It will raise an
error otherwise.
Args:
keras_model: A `tf.keras.Model` object; it should be uncompiled. If compiled,
the metrics, optimizer, and loss function will be ignored. Note: models
that have multiple outputs will send all outputs to `loss_fn`.
loss_fn: A `tf.keras.losses.Loss` object.
input_spec: A structure of `tf.TensorSpec` defining the input to the model.
Returns:
A `tff.learning.models.FunctionalModel`.
Raises:
KerasFunctionalModelError: If the model has a batch normalization layer.
"""
# We're going to do something fancy here:
#
# 1. Get a copy of all the variables, in the order they are created during
# model construction, when in a graph context.
# 2. Use this ordering to construct a type signature of the model weights in
# such a way that we can inject TENSORS (those that are coming in as
# arguments) in place of variable creation during a call to
# `tf.keras.models.clone_model()`, which gives us a newly constructed Keras
# model in the context we want.
# 3. Profit by having variableless graphs!
#
# **WARNING** Caveats:
#
# 1. This model _must_ be used inside a graph context (e.g. a
# `tff.tf_computation` decorated callable, aka a `tff.Computation`). Keras
# appears to create extra variables in the eager context that are not part
# of the user-specified model, and end up not being compatible.
#
# 2. We have found that this trick does NOT work with non-trainable variables
# that are updated during training. Namely layers such as
# BatchNormalization try to update means/variances during training and are
# not compatible with this approach. We generally recommend
# GroupNormalization in place of BatchNormalization at the current time.
#
# 3. This does not support multiple outputs with different loss functions, or
#    layer-wise regularization losses. TODO(b/156629927).
for layer in keras_model.layers:
# There may be other layers that are problematic; at this time, updating the
# mean/variance in a batchnorm layer is the only known such instance.
if isinstance(layer, tf.keras.layers.BatchNormalization):
raise KerasFunctionalModelError(
'Keras model contains a batch normalization layer, which is '
'incompatible with `tff.learning.models.FunctionalModel`. Consider '
'using group normalization instead.')
if keras_model.non_trainable_variables:
raise KerasFunctionalModelError(
'Received a Keras model with non-trainable variables. Keras models with '
'non-trainable variables are currently not supported by FunctionalModel'
'. Most training algorithms (e.g. Federated Averaging) will not '
'aggregate them, and they are not updated locally by the optimizer. '
'We can relax this in the future if we have APIs that support updating '
'non-trainable variables.')
# Clone the keras model inside a graph context so that we only get the
# variables for the layers (otherwise keras adds other non-user variables). We
# also set up ops to inject the current model weights, because the cloned model
# will be re-initialized from scratch.
with tf.Graph().as_default() as g:
with variable_utils.record_variable_creation_scope() as captured_variables:
cloned_model = tf.keras.models.clone_model(keras_model)
if len(cloned_model.variables) != len(keras_model.variables):
raise KerasFunctionalModelError(
'The input Keras model is likely sharing variables across layers '
'which is unsupported. Cloning the model will duplicate these '
'variables and result in unexpected training gradients.')
# Ensure our cloned model has the same weights as the current model.
# We'll feed the current model weights into the placeholders for
# assignment in a session below.
def assign_placeholder(v):
p = tf.compat.v1.placeholder(dtype=v.dtype)
return v.assign(p), p
assign_ops, placeholders = zip(
*(assign_placeholder(v) for v in cloned_model.variables))
trainable_variables = tuple(v for v in captured_variables if v.trainable)
non_trainable_variables = tuple(
v for v in captured_variables if not v.trainable)
# Here we get the initial weights from the incoming keras model in the order
# they are constructed; and also ensure that the values are set to the
# incoming model weights rather than their fresh initialization.
current_model_weights = tf.nest.map_structure(
lambda v: v.read_value().numpy(), keras_model.variables)
with tf.compat.v1.Session(graph=g) as sess:
sess.run(tf.compat.v1.initializers.variables(captured_variables))
sess.run(
fetches=assign_ops,
feed_dict=dict(zip(placeholders, current_model_weights)))
initial_weights = sess.run(
fetches=(trainable_variables, non_trainable_variables))
@tf.function
def predict_on_batch(model_weights: ModelWeights,
x: Any,
training: bool = True) -> Any:
with tf.init_scope():
if tf.executing_eagerly():
raise KerasFunctionalModelError(
'tf.keras.Model used as a FunctionalModel is only usable inside a '
'tff.tf_computation decorated callable or a graph context.')
# Make a copy of the weights container; can't mutate Python containers
# inside a tf.function.
trainable, non_trainable = (list(w) for w in model_weights)
# Here we intercept variable creation requests during the
# `tf.keras.models.clone_model()` call.
#
# Instead of forwarding the variable request to TF core and getting a
# `tf.Variable` back, we skip that and return only the `tf.Tensor` that
# corresponds to the `tf.Variable` recreation request (avoiding any variable
# creation). This works because TF operations that accept `tf.Variable`
# inputs automatically call `variable.read_value()` and then operate on that
# resulting tensor. We're relying on shortcutting that and providing the
# tensor straight away.
#
# For example, `tf.matmul` doesn't notice whether its input is a `tf.Variable`
# or a `tf.Tensor`:
#
# v = tf.Variable([[1], [2], [3]])
# tf.matmul(v, [[4, 5, 6]])
#
# and
#
# v = tf.constant([[1], [2], [3]])
# tf.matmul(v, [[4, 5, 6]])
#
# both result in:
#
# <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
# array([[ 4, 5, 6],
# [ 8, 10, 12],
# [12, 15, 18]], dtype=int32)>
def swap_tensor_parameter_for_variable(_, **kwargs):
if kwargs.get('trainable', True):
return trainable.pop(0)
else:
return non_trainable.pop(0)
with tf.variable_creator_scope(swap_tensor_parameter_for_variable):
variableless_model = tf.keras.models.clone_model(keras_model)
return variableless_model(x, training)
@tf.function
def forward_pass(model_weights: ModelWeights,
batch_input: Any,
training: bool = True) -> model_lib.BatchOutput:
if isinstance(batch_input, collections.abc.Mapping):
x = batch_input['x']
y = batch_input['y']
elif isinstance(batch_input, collections.abc.Sequence):
x, y = batch_input
else:
raise ValueError(
'`batch_input` must be either a mapping with keys `x` '
f'and `y` or a sequence of `(x, y)`. Got: {batch_input!r}')
predictions = predict_on_batch(model_weights, x, training)
batch_loss = loss_fn(y_true=y, y_pred=predictions)
# TODO(b/207033265): more work needed to support models with multiple loss
# functions.
def nrows(t):
return t.nrows() if isinstance(t, tf.RaggedTensor) else tf.shape(t)[0]
return model_lib.BatchOutput(
loss=batch_loss,
predictions=predictions,
num_examples=nrows(tf.nest.flatten(batch_input)[0]))
return FunctionalModel(
initial_weights=initial_weights,
forward_pass_fn=forward_pass,
predict_on_batch_fn=predict_on_batch,
input_spec=input_spec)
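Because the returned model is only usable in a graph context, calls are typically wrapped in a `tff.tf_computation`; a brief sketch (the `predict` wrapper and its input spec are illustrative, reusing `functional_model` from the sketch above):

import tensorflow as tf
import tensorflow_federated as tff

@tff.tf_computation(tf.TensorSpec([None, 1], dtype=tf.float32))
def predict(batch_x):
  # Calling predict_on_batch eagerly instead would raise KerasFunctionalModelError.
  return functional_model.predict_on_batch(
      functional_model.initial_weights, batch_x)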
@@ -20,6 +20,7 @@ import tensorflow as tf
from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning.models import functional
from tensorflow_federated.python.tensorflow_libs import variable_utils
def initial_weights():
@@ -56,6 +57,14 @@ def forward_pass(model_weights, batch_input, training):
loss=average_loss, predictions=predictions, num_examples=num_examples)
def create_test_keras_model():
return tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(1,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros')
])
def create_test_dataset():
"""Create a test dataset."""
@@ -113,6 +122,19 @@ class FunctionalTest(tf.test.TestCase):
functional_model.predict_on_batch(functional_model.initial_weights,
example_batch[0]), [[0.]] * 5)
def test_predict_on_batch_keras_outside_graph_fails(self):
dataset = create_test_dataset()
example_batch = next(iter(dataset))
functional_model = functional.functional_model_from_keras(
keras_model=create_test_keras_model(),
loss_fn=tf.keras.losses.MeanSquaredError(),
input_spec=(tf.TensorSpec([None, 1], dtype=tf.float32),
tf.TensorSpec([None, 1], dtype=tf.int32)))
with self.assertRaisesRegex(functional.KerasFunctionalModelError,
'only usable inside a tff.tf_computation'):
functional_model.predict_on_batch(functional_model.initial_weights,
example_batch[0])
def test_forward_pass(self):
dataset = create_test_dataset()
example_batch = next(iter(dataset))
@@ -124,6 +146,70 @@ class FunctionalTest(tf.test.TestCase):
functional_model.predict_on_batch(functional_model.initial_weights,
example_batch[0]), [[0.]] * 5)
def test_construct_from_keras(self):
keras_model = create_test_keras_model()
# Assign some variables after initialization so we can assert that they
# were cloned into the FunctionalModel.
tf.nest.map_structure(lambda v: v.assign(tf.ones_like(v)),
keras_model.variables)
functional_model = functional.functional_model_from_keras(
keras_model=keras_model,
loss_fn=tf.keras.losses.MeanSquaredError(),
input_spec=(tf.TensorSpec([None, 1], dtype=tf.float32),
tf.TensorSpec([None, 1], dtype=tf.int32)))
self.assertIsInstance(functional_model, functional.FunctionalModel)
# Assert all ones, instead of zeros from a newly initialized model.
tf.nest.map_structure(lambda v: self.assertAllClose(v, tf.ones_like(v)),
functional_model.initial_weights)
def test_construct_from_keras_converges(self):
functional_model = functional.functional_model_from_keras(
keras_model=create_test_keras_model(),
loss_fn=tf.keras.losses.MeanSquaredError(),
input_spec=(tf.TensorSpec([None, 1], dtype=tf.float32),
tf.TensorSpec([None, 1], dtype=tf.int32)))
with tf.Graph().as_default() as test_graph:
# Capture all the variables for later initialization in the session,
# otherwise it's hard to get our hands on the Keras-owned variables.
with variable_utils.record_variable_creation_scope(
) as captured_variables:
# Create data satisfying y = 2*x + 1
dataset = tf.data.Dataset.from_tensor_slices((
# Features
[[1.0], [2.0], [3.0]],
# Labels.
[[3.0], [5.0], [7.0]],
)).batch(1)
variables = tf.nest.map_structure(tf.Variable,
functional_model.initial_weights)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
@tf.function
def train():
weights = tf.nest.map_structure(lambda v: v.read_value(), variables)
initial_loss = loss = functional_model.forward_pass(
weights, next(iter(dataset)), training=True).loss
trainable = variables[0]
for batch in dataset.repeat(30):
with tf.GradientTape() as tape:
weights = tf.nest.map_structure(lambda v: v.read_value(),
variables)
tape.watch(weights[0])
batch_output = functional_model.forward_pass(
weights, batch, training=True)
gradients = tape.gradient(batch_output.loss, weights[0])
optimizer.apply_gradients(zip(gradients, trainable))
loss = batch_output.loss
return initial_loss, loss
initial_loss, final_loss = train()
with tf.compat.v1.Session(graph=test_graph) as sess:
sess.run(tf.compat.v1.initializers.variables(captured_variables))
initial_loss, final_loss = sess.run([initial_loss, final_loss])
# Expect some amount of convergence after a few epochs of the dataset.
self.assertGreater(initial_loss, 2.0)
self.assertLess(final_loss, 0.2)
def test_tff_model_from_functional_same_result(self):
dataset = create_test_dataset()
input_spec = dataset.element_spec
@@ -276,6 +362,65 @@ class FunctionalTest(tf.test.TestCase):
# mae = (1.0+1.0)/(2.0+6.0) = 0.25
collections.OrderedDict(loss=0.5, mse=0.75, mae=0.25))
def test_keras_model_with_non_trainable_variables_fails(self):
inputs = tf.keras.layers.Input(shape=[1])
d = tf.keras.layers.Dense(1)
d.trainable = False
outputs = d(inputs)
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
with self.assertRaisesRegex(functional.KerasFunctionalModelError,
'non-trainable variables'):
functional.functional_model_from_keras(
keras_model,
tf.keras.losses.MeanSquaredError(),
input_spec=(tf.TensorSpec(shape=[None, 1]),
tf.TensorSpec(shape=[None, 1])))
def test_keras_model_with_batch_normalization_fails(self):
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=[10]),
tf.keras.layers.BatchNormalization(),
])
with self.assertRaisesRegex(functional.KerasFunctionalModelError,
'batch normalization'):
functional.functional_model_from_keras(
model,
tf.keras.losses.MeanSquaredError(),
input_spec=(tf.TensorSpec(shape=[None, 10]),
tf.TensorSpec(shape=[None, 1])))
def test_keras_model_with_shared_variables_fails(self):
class SharedLayer(tf.keras.layers.Layer):
def __init__(self, dense_layer: tf.keras.layers.Dense, **kwargs):
super().__init__()
self._dense_layer = dense_layer
self.kernel = dense_layer.kernel
self.bias = dense_layer.bias
def call(self, inputs):
return inputs
def get_config(self):
config = super().get_config()
config['dense_layer'] = self._dense_layer
return config
inputs = tf.keras.layers.Input(shape=[1])
layer1 = tf.keras.layers.Dense(1)
y = layer1(inputs)
layer2 = SharedLayer(layer1)
outputs = layer2(y)
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
with self.assertRaisesRegex(functional.KerasFunctionalModelError,
'sharing variables across layers'):
functional.functional_model_from_keras(
keras_model,
tf.keras.losses.MeanSquaredError(),
input_spec=(tf.TensorSpec(shape=[None, 1]),
tf.TensorSpec(shape=[None, 1])))
if __name__ == '__main__':
execution_contexts.set_local_python_execution_context()
@@ -348,8 +348,9 @@ def save_functional_model(functional_model: functional.FunctionalModel,
# structure of tensors with the initial weights. This way we can add it to the
# tf.SavedModel and call it to create initial weights after deserialization.
create_initial_weights = lambda: functional_model.initial_weights
concrete_structured_fn = tf.function(
create_initial_weights).get_concrete_function()
with tf.Graph().as_default():
concrete_structured_fn = tf.function(
create_initial_weights).get_concrete_function()
model_weights_tensor_specs = tf.nest.map_structure(
tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
initial_weights_result_type_spec = type_serialization.serialize_type(
@@ -360,8 +361,9 @@ def save_functional_model(functional_model: functional.FunctionalModel,
def flat_initial_weights():
return tf.nest.flatten(create_initial_weights())
m.create_initial_weights = tf.function(
flat_initial_weights).get_concrete_function()
with tf.Graph().as_default():
m.create_initial_weights = tf.function(
flat_initial_weights).get_concrete_function()
# Serialize forward pass concretely, once for training and once for
# non-training.
@@ -384,11 +386,12 @@ def save_functional_model(functional_model: functional.FunctionalModel,
# during function concretization. The resulting concrete function only has
# parameters for `model_weights` and `batch_input`, which are
# `tf.TensorSpec` structures here.
concrete_structured_fn = functional_model.forward_pass.get_concrete_function(
model_weights_tensor_specs,
functional_model.input_spec,
# Note: training does not appear in the resulting concrete function.
training=training)
with tf.Graph().as_default():
concrete_structured_fn = functional_model.forward_pass.get_concrete_function(
model_weights_tensor_specs,
functional_model.input_spec,
# Note: training does not appear in the resulting concrete function.
training=training)
output_tensor_spec_structure = tf.nest.map_structure(
tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
result_type_spec = type_serialization.serialize_type(
@@ -399,11 +402,12 @@ def save_functional_model(functional_model: functional.FunctionalModel,
return tf.nest.flatten(
functional_model.forward_pass(model_weights, batch_input, training))
flat_concrete_fn = flat_forward_pass.get_concrete_function(
model_weights_tensor_specs,
functional_model.input_spec,
# Note: training does not appear in the resulting concrete function.
training=training)
with tf.Graph().as_default():
flat_concrete_fn = flat_forward_pass.get_concrete_function(
model_weights_tensor_specs,
functional_model.input_spec,
# Note: training does not appear in the resulting concrete function.
training=training)
return flat_concrete_fn, result_type_spec
fw_pass_training, fw_pass_training_type_spec = make_concrete_flat_forward_pass(
@@ -466,15 +470,17 @@ def save_functional_model(functional_model: functional.FunctionalModel,
training=training)
return flat_concrete_fn, result_type_spec
predict_training, predict_training_type_spec = make_concrete_flat_predict_on_batch(
training=True)
with tf.Graph().as_default():
predict_training, predict_training_type_spec = make_concrete_flat_predict_on_batch(
training=True)
m.predict_on_batch_training = predict_training
m.predict_on_batch_training_type_spec = tf.Variable(
predict_training_type_spec.SerializeToString(deterministic=True),
trainable=False)
predict_inference, predict_inference_type_spec = make_concrete_flat_predict_on_batch(
training=False)
with tf.Graph().as_default():
predict_inference, predict_inference_type_spec = make_concrete_flat_predict_on_batch(
training=False)
m.predict_on_batch_inference = predict_inference
m.predict_on_batch_inference_type_spec = tf.Variable(
predict_inference_type_spec.SerializeToString(deterministic=True),
@@ -375,7 +375,7 @@ class SerializationTest(test_case.TestCase, parameterized.TestCase):
def initial_weights():
"""Returns lists of trainable variables and non-trainable variables."""
trainable_variables = (np.asarray([[0.0, 0.0, 0.0]], dtype=np.float32),
trainable_variables = (np.asarray([[0.0], [0.0], [0.0]], dtype=np.float32),
np.asarray([0.0], dtype=np.float32))
non_trainable_variables = ()
return (trainable_variables, non_trainable_variables)
@@ -385,7 +385,7 @@ def initial_weights():
def predict_on_batch(model_weights, x, training):
del training # Unused
trainable = model_weights[0]
return tf.matmul(x, trainable[0], transpose_b=True) + trainable[1]
return tf.matmul(x, trainable[0]) + trainable[1]
@tf.function
@@ -415,38 +415,73 @@ def get_dataset():
return preprocess(tf.data.Dataset.range(15).enumerate())
class FunctionalModelTest(tf.test.TestCase):
@tf.function
def get_example_batch(dataset):
return next(iter(dataset))
def create_test_functional_model(input_spec):
return functional.FunctionalModel(initial_weights(), forward_pass,
predict_on_batch, input_spec)
def test_functional_predict_on_batch(self):
def create_test_keras_functional_model(input_spec):
# We must create the functional model that wraps a keras model in a graph
# context (see IMPORTANT note in `functional_model_from_keras`), otherwise
# we'll get non-model Variables.
keras_model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=[3]),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros')
])
return functional.functional_model_from_keras(
keras_model,
loss_fn=tf.keras.losses.MeanSquaredError(),
input_spec=input_spec)
class FunctionalModelTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('tf_function', create_test_functional_model),
('keras_model', create_test_keras_functional_model))
def test_functional_predict_on_batch(self, model_fn):
dataset = get_dataset()
example_batch = next(iter(dataset))
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
model_weights = functional_model.initial_weights
functional_model = model_fn(input_spec=dataset.element_spec)
# The wrapped keras model can only be used inside a `tff.tf_computation`.
@computations.tf_computation
def _predict_on_batch(dataset, model_weights):
example_batch = get_example_batch(dataset)
return functional_model.predict_on_batch(model_weights, example_batch[0])
self.assertAllClose(
functional_model.predict_on_batch(model_weights, example_batch[0]),
_predict_on_batch(dataset, functional_model.initial_weights),
[[0.]] * 5)
def test_construct_tff_model_from_functional_predict_on_batch(self):
@parameterized.named_parameters(
('tf_function', create_test_functional_model),
('keras_model', create_test_keras_functional_model))
def test_construct_tff_model_from_functional_predict_on_batch(