Commit 46594ce6 authored by Zachary Garrett, committed by tensorflow-copybara

Add Functional Model serialization.

This also concretizes some generic methods used in previous model serialization
that appeared to interact poorly with tf.function tracing.

PiperOrigin-RevId: 394295569
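
For context on the tracing problem the message refers to: a polymorphic `tf.function` retraces for every distinct Python-value argument, while `get_concrete_function` pins a single graph up front. A minimal illustrative sketch of this behavior (not code from this change):

import tensorflow as tf

@tf.function
def forward(weights, training):
  # `training` is a Python bool, so each distinct value triggers a new trace.
  if training:
    return weights * 2.0
  return weights

# Concretizing bakes `training=True` into one fixed graph; the resulting
# concrete function no longer takes a `training` argument at all.
concrete = forward.get_concrete_function(tf.TensorSpec([]), training=True)
print(concrete(tf.constant(3.0)).numpy())  # 6.0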
Parent c7b20bf9
@@ -44,6 +44,7 @@ py_library(
srcs = ["serialization.py"],
srcs_version = "PY3",
deps = [
":functional",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
@@ -61,14 +62,17 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":functional",
":serialization",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/core/impl/types:type_serialization",
"//tensorflow_federated/python/learning:federated_averaging",
"//tensorflow_federated/python/learning:keras_utils",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_examples",
@@ -16,4 +16,6 @@
from tensorflow_federated.python.learning.models.functional import FunctionalModel
from tensorflow_federated.python.learning.models.functional import model_from_functional
from tensorflow_federated.python.learning.models.serialization import load
from tensorflow_federated.python.learning.models.serialization import load_functional_model
from tensorflow_federated.python.learning.models.serialization import save
from tensorflow_federated.python.learning.models.serialization import save_functional_model
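
Taken together, the new exports support a save/load round trip along these lines (a hedged sketch mirroring the tests added below; `functional_model` and `example_batch` are assumed to exist):

import tensorflow_federated as tff

path = '/tmp/functional_model'  # illustrative path
tff.learning.models.save_functional_model(functional_model, path)
loaded = tff.learning.models.load_functional_model(path)
predictions = loaded.predict_on_batch(
    model_weights=loaded.initial_weights, x=example_batch[0])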
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for `tff.learning.Model` serialization."""
"""Module for model serialization."""
import collections
import functools
@@ -26,6 +26,7 @@ from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.impl.types import type_serialization
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning.models import functional
class _LoadedSavedModel(model_lib.Model):
@@ -281,3 +282,262 @@ def load(path: str) -> model_lib.Model:
raise ValueError('`path` must be a non-empty string, cannot deserialize '
'models without an output path.')
return _LoadedSavedModel(tf.saved_model.load(path))
def save_functional_model(functional_model: functional.FunctionalModel,
path: str):
"""Serializes a `FunctionalModel` as a `tf.SavedModel` to `path`.
Args:
functional_model: A `tff.learning.models.FunctionalModel`.
path: A `str` directory path to serialize the model to.
"""
m = tf.Module()
# Serialize the initial_weights values as a tf.function that creates a
# structure of tensors with the initial weights. This way we can add it to the
# tf.SavedModel and call it to create initial weights after deserialization.
create_initial_weights = lambda: functional_model.initial_weights
concrete_structured_fn = tf.function(
create_initial_weights).get_concrete_function()
model_weights_tensor_specs = tf.nest.map_structure(
tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
initial_weights_result_type_spec = type_serialization.serialize_type(
computation_types.to_type(model_weights_tensor_specs))
m.create_initial_weights_type_spec = tf.Variable(
initial_weights_result_type_spec.SerializeToString(deterministic=True))
def flat_initial_weights():
return tf.nest.flatten(create_initial_weights())
m.create_initial_weights = tf.function(
flat_initial_weights).get_concrete_function()
# Serialize forward pass concretely, once for training and once for
# non-training.
# TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
# the need to serialize two different function graphs.
def make_concrete_flat_forward_pass(training: bool):
"""Create a concrete forward_pass function that has flattened output.
Args:
training: A boolean indicating whether this is a call in a training loop
or an evaluation loop.
Returns:
A 2-tuple of a concrete `tf.function` instance and a `tff.Type` protocol
buffer message documenting the result structure returned by the concrete
function.
"""
# Save the un-flattened type spec for deserialization later.
# Note: `training` is a Python boolean, which gets "curried", in a sense,
# during function concretization. The resulting concrete function only has
# parameters for `model_weights` and `batch_input`, which are
# `tf.TensorSpec` structures here.
concrete_structured_fn = functional_model.forward_pass.get_concrete_function(
model_weights_tensor_specs,
functional_model.input_spec,
# Note: training does not appear in the resulting concrete function.
training=training)
output_tensor_spec_structure = tf.nest.map_structure(
tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
result_type_spec = type_serialization.serialize_type(
computation_types.to_type(output_tensor_spec_structure))
@tf.function
def flat_forward_pass(model_weights, batch_input, training):
return tf.nest.flatten(
functional_model.forward_pass(model_weights, batch_input, training))
flat_concrete_fn = flat_forward_pass.get_concrete_function(
model_weights_tensor_specs,
functional_model.input_spec,
# Note: training does not appear in the resulting concrete function.
training=training)
return flat_concrete_fn, result_type_spec
fw_pass_training, fw_pass_training_type_spec = make_concrete_flat_forward_pass(
training=True)
m.flat_forward_pass_training = fw_pass_training
m.forward_pass_training_type_spec = tf.Variable(
fw_pass_training_type_spec.SerializeToString(deterministic=True),
trainable=False)
fw_pass_inference, fw_pass_inference_type_spec = make_concrete_flat_forward_pass(
training=False)
m.flat_forward_pass_inference = fw_pass_inference
m.forward_pass_inference_type_spec = tf.Variable(
fw_pass_inference_type_spec.SerializeToString(deterministic=True),
trainable=False)
# Serialize predict_on_batch, once for training, once for non-training.
x_type = functional_model.input_spec[0]
# TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
# the need to serialize two different function graphs.
def make_concrete_flat_predict_on_batch(training: bool):
"""Create a concrete predict_on_batch function that has flattened output.
Args:
training: A boolean indicating whether this is a call in a training loop
or an evaluation loop.
Returns:
A 2-tuple of a concrete `tf.function` instance and a `tff.Type` protocol
buffer message documenting the result structure returned by the concrete
function.
"""
# Save the un-flattened type spec for deserialization later.
# Note: `training` is a Python boolean, which gets "curried", in a sense,
# during function concretization. The resulting concrete function only has
# parameters for `model_weights` and `batch_input`, which are
# `tf.TensorSpec` structures here.
concrete_structured_fn = tf.function(
functional_model.predict_on_batch).get_concrete_function(
model_weights_tensor_specs,
x_type,
# Note: training does not appear in the resulting concrete function.
training=training)
output_tensor_spec_structure = tf.nest.map_structure(
tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
result_type_spec = type_serialization.serialize_type(
computation_types.to_type(output_tensor_spec_structure))
@tf.function
def flat_predict_on_batch(model_weights, x, training):
return tf.nest.flatten(
functional_model.predict_on_batch(model_weights, x, training))
flat_concrete_fn = tf.function(flat_predict_on_batch).get_concrete_function(
model_weights_tensor_specs,
x_type,
# Note: training does not appear in the resulting concrete function.
training=training)
return flat_concrete_fn, result_type_spec
predict_training, predict_training_type_spec = make_concrete_flat_predict_on_batch(
training=True)
m.predict_on_batch_training = predict_training
m.predict_on_batch_training_type_spec = tf.Variable(
predict_training_type_spec.SerializeToString(deterministic=True),
trainable=False)
predict_inference, predict_inference_type_spec = make_concrete_flat_predict_on_batch(
training=False)
m.predict_on_batch_inference = predict_inference
m.predict_on_batch_inference_type_spec = tf.Variable(
predict_inference_type_spec.SerializeToString(deterministic=True),
trainable=False)
# Serialize TFF values as string variables that contain the serialized
# protos from the computation or the type.
m.serialized_input_spec = tf.Variable(
type_serialization.serialize_type(
computation_types.to_type(
functional_model.input_spec)).SerializeToString(
deterministic=True),
trainable=False)
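# For reference, reading such a variable back at load time is the mirror
# image; a hypothetical sketch (the loading path below relies on a
# `_deserialize_type_spec` helper for this):
#   proto = computation_pb2.Type.FromString(variable.read_value().numpy())
#   type_spec = type_serialization.deserialize_type(proto)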
# Save everything
_save_tensorflow_module(m, path)
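
The flatten-on-save convention above pairs with a pack-on-load step in `_LoadedFunctionalModel` below: results cross the SavedModel boundary as a flat list of tensors, while the structure travels separately as a serialized type proto. The `tf.nest` round trip in isolation (names here are illustrative):

import tensorflow as tf

structured = {'loss': tf.constant(74.25), 'predictions': tf.zeros([5, 1])}
flat = tf.nest.flatten(structured)  # order-stable flat list of tensors
specs = tf.nest.map_structure(tf.TensorSpec.from_tensor, structured)
rebuilt = tf.nest.pack_sequence_as(specs, flat)  # structure restored
print(rebuilt['loss'].numpy())  # 74.25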
class _LoadedFunctionalModel(functional.FunctionalModel):
"""Creates a `FunctionalModel` from a loaded SavedModel."""
def __init__(self, loaded_module):
self._loaded_module = loaded_module
self._input_spec = tf.nest.map_structure(
lambda t: tf.TensorSpec(dtype=t.dtype, shape=t.shape),
_deserialize_type_spec(loaded_module.serialized_input_spec))
weights_nested_tensor_specs = _deserialize_type_spec(
loaded_module.create_initial_weights_type_spec, tuple)
self._initial_weights = tf.nest.pack_sequence_as(
weights_nested_tensor_specs,
# Convert EagerTensors to numpy arrays, necessary to avoid trying
# to capture EagerTensors in different graphs when doing:
#   build_federated_averaging_process(
#       model_from_functional(_LoadedFunctionalModel))
[w.numpy() for w in loaded_module.create_initial_weights()])
def unflatten_forward_pass_fn(flat_forward_pass,
serialized_result_type_variable):
result_tensor_specs = _deserialize_type_spec(
serialized_result_type_variable, model_lib.BatchOutput)
def forward_pass(model_weights, batch_input):
result = flat_forward_pass(model_weights, batch_input)
return tf.nest.pack_sequence_as(result_tensor_specs, result)
return forward_pass
self._forward_pass_training = unflatten_forward_pass_fn(
loaded_module.flat_forward_pass_training,
loaded_module.forward_pass_training_type_spec)
self._forward_pass_inference = unflatten_forward_pass_fn(
loaded_module.flat_forward_pass_inference,
loaded_module.forward_pass_inference_type_spec)
def unflatten_predict_on_batch_fn(flat_predict_on_batch,
serialized_result_type_variable):
result_tensor_specs = _deserialize_type_spec(
serialized_result_type_variable, tuple)
def predict_on_batch(model_weights, x):
result = flat_predict_on_batch(model_weights=model_weights, x=x)
if tf.is_tensor(result):
return result
return tf.nest.pack_sequence_as(result_tensor_specs, result)
return predict_on_batch
self._predict_on_batch_training = unflatten_predict_on_batch_fn(
loaded_module.predict_on_batch_training,
loaded_module.predict_on_batch_training_type_spec)
self._predict_on_batch_inference = unflatten_predict_on_batch_fn(
loaded_module.predict_on_batch_inference,
loaded_module.predict_on_batch_inference_type_spec)
@property
def initial_weights(self):
return self._initial_weights
def forward_pass(self,
model_weights,
batch_input,
training=True) -> model_lib.BatchOutput:
"""Runs the forward pass and returns results."""
if training:
return self._forward_pass_training(
model_weights=model_weights, batch_input=batch_input)
else:
return self._forward_pass_inference(
model_weights=model_weights, batch_input=batch_input)
def predict_on_batch(self, model_weights, x, training=True):
"""Returns tensor(s) interpretable by the loss function."""
if training:
return self._predict_on_batch_training(model_weights=model_weights, x=x)
else:
return self._predict_on_batch_inference(model_weights=model_weights, x=x)
@property
def input_spec(self):
return self._input_spec
def load_functional_model(path: str) -> functional.FunctionalModel:
"""Deserializes a TensorFlow SavedModel at `path` to a functional model.
Args:
path: The `str` path pointing to a SavedModel.
Returns:
A `tff.learning.models.FunctionalModel`.
"""
py_typecheck.check_type(path, str)
if not path:
raise ValueError('`path` must be a non-empty string, cannot deserialize '
'models without an output path.')
return _LoadedFunctionalModel(tf.saved_model.load(path))
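
One detail of the loader worth highlighting: initial weights are materialized as numpy arrays rather than kept as EagerTensors, so that a later `tf.function` trace (for example inside `build_federated_averaging_process`) embeds plain constants instead of capturing tensors that belong to another graph. A minimal sketch of the distinction:

import tensorflow as tf

restored = [tf.constant([[0.0, 0.0, 0.0]]), tf.constant([0.0])]
# Keeping EagerTensors risks cross-graph capture errors when another
# tf.function later closes over them; numpy arrays trace as constants.
initial_weights = [w.numpy() for w in restored]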
@@ -23,14 +23,17 @@ import tensorflow as tf
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.impl.types import type_serialization
from tensorflow_federated.python.learning import federated_averaging
from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning import model_examples
from tensorflow_federated.python.learning.models import functional
from tensorflow_federated.python.learning.models import serialization
# Convenience aliases.
@@ -316,5 +319,199 @@ class SerializationTest(test_case.TestCase, parameterized.TestCase):
self.assertNotEmpty(tflite_flatbuffer)
def initial_weights():
"""Returns lists of trainable variables and non-trainable variables."""
trainable_variables = (np.asarray([[0.0, 0.0, 0.0]], dtype=np.float32),
np.asarray([0.0], dtype=np.float32))
non_trainable_variables = ()
return (trainable_variables, non_trainable_variables)
@tf.function
def predict_on_batch(model_weights, x, training):
del training # Unused
trainable = model_weights[0]
return tf.matmul(x, trainable[0], transpose_b=True) + trainable[1]
@tf.function
def forward_pass(model_weights, batch_input, training):
predictions = predict_on_batch(model_weights, batch_input[0], training)
residuals = predictions - batch_input[1]
num_examples = tf.gather(tf.shape(predictions), 0)
total_loss = tf.reduce_sum(tf.pow(residuals, 2))
average_loss = total_loss / tf.cast(num_examples, tf.float32)
return model_lib.BatchOutput(
loss=average_loss, predictions=predictions, num_examples=num_examples)
def preprocess(ds):
def generate_example(i, t):
del t # Unused.
features = tf.random.stateless_uniform(shape=[3], seed=(0, i))
label = tf.expand_dims(
tf.reduce_sum(features * tf.constant([1.0, 2.0, 3.0])), axis=-1) + 5.0
return (features, label)
return ds.map(generate_example).batch(5, drop_remainder=True)
def get_dataset():
return preprocess(tf.data.Dataset.range(15).enumerate())
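
Since `initial_weights()` above is all zeros, `predict_on_batch` returns zeros for any batch, which is why the assertions below expect `[[0.]] * 5`; and because labels are generated as `features · [1.0, 2.0, 3.0] + 5.0`, training to convergence should recover exactly those weights, as the final test asserts. A quick check of the zero-weights case (illustrative):

import numpy as np

trainable = (np.zeros([1, 3], np.float32), np.zeros([1], np.float32))
x = np.ones([5, 3], np.float32)
predictions = x @ trainable[0].T + trainable[1]  # all zeros, shape (5, 1)
print(predictions)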
class FunctionalModelTest(tf.test.TestCase):
def test_functional_predict_on_batch(self):
dataset = get_dataset()
example_batch = next(iter(dataset))
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
model_weights = functional_model.initial_weights
self.assertAllClose(
functional_model.predict_on_batch(model_weights, example_batch[0]),
[[0.]] * 5)
def test_construct_tff_model_from_functional_predict_on_batch(self):
dataset = get_dataset()
example_batch = next(iter(dataset))
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
tff_model = functional.model_from_functional(functional_model)
self.assertAllClose(
tff_model.predict_on_batch(example_batch[0]), [[0.]] * 5)
def test_construct_tff_model_from_functional_and_train(self):
dataset = get_dataset()
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
def model_fn():
return functional.model_from_functional(functional_model)
training_process = federated_averaging.build_federated_averaging_process(
model_fn, lambda: tf.keras.optimizers.SGD(learning_rate=0.1))
state = training_process.initialize()
for _ in range(2):
state, _ = training_process.next(state, [dataset])
def test_save_functional_model(self):
dataset = get_dataset()
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
path = self.get_temp_dir()
serialization.save_functional_model(functional_model, path)
def test_save_and_load_functional_model(self):
dataset = get_dataset()
example_batch = next(iter(dataset))
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
path = self.get_temp_dir()
serialization.save_functional_model(functional_model, path)
loaded_model = serialization.load_functional_model(path)
model_weights = loaded_model.initial_weights
self.assertAllClose(
loaded_model.predict_on_batch(
model_weights=model_weights, x=example_batch[0]), [[0.]] * 5)
def test_initial_model_weights_before_after_save(self):
dataset = get_dataset()
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
model_weights1 = functional_model.initial_weights
path = self.get_temp_dir()
serialization.save_functional_model(functional_model, path)
loaded_model = serialization.load_functional_model(path)
model_weights2 = loaded_model.initial_weights
self.assertAllClose(model_weights1, model_weights2)
def test_convert_to_tff_model(self):
dataset = get_dataset()
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
tff_model = functional.model_from_functional(functional_model)
example_batch = next(iter(dataset))
self.assertAllClose(
tff_model.predict_on_batch(x=example_batch[0]), [[0.]] * 5)
def test_save_load_convert_to_tff_model(self):
dataset = get_dataset()
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
path = self.get_temp_dir()
serialization.save_functional_model(functional_model, path)
loaded_model = serialization.load_functional_model(path)
tff_model = functional.model_from_functional(loaded_model)
example_batch = next(iter(dataset))
for training in [True, False]:
self.assertAllClose(
tff_model.predict_on_batch(x=example_batch[0], training=training),
[[0.]] * 5)
for training in [True, False]:
tf.nest.map_structure(
lambda x, y: self.assertAllClose(x, y, atol=1e-2, rtol=1e-2),
tff_model.forward_pass(batch_input=example_batch, training=training),
model_lib.BatchOutput(
loss=74.250, predictions=np.zeros(shape=[5, 1]), num_examples=5))
def test_save_load_convert_to_tff_model_and_train_to_convergence(self):
dataset = get_dataset()
input_spec = dataset.element_spec
functional_model = functional.FunctionalModel(initial_weights(),
forward_pass,
predict_on_batch, input_spec)
path = self.get_temp_dir()
serialization.save_functional_model(functional_model, path)
loaded_model = serialization.load_functional_model(path)
def model_fn():
return functional.model_from_functional(loaded_model)
training_process = federated_averaging.build_federated_averaging_process(
model_fn, lambda: tf.keras.optimizers.SGD(learning_rate=0.05))
state = training_process.initialize()
self.assertAllClose(state.model.trainable,
[np.zeros([1, 3]), np.zeros([1])])
num_rounds = 50
for _ in range(num_rounds):
state, _ = training_process.next(state, [dataset])
# Test that we came close to convergence.
self.assertAllClose(
state.model.trainable,
[np.asarray([[1.0, 2.0, 3.0]]),
np.asarray([5.0])],
atol=0.5)
if __name__ == '__main__':
# TODO(b/198454066): the EagerTFExecutor in the local execution stack shares
# a global function library with this test. `tf.function` tracing happens in
# this test, and we end up with two conflicting FunctionDefs in the global
# eager context: one added after the TFF disable-grappler transformation,
# which is applied during the `tf.compat.v1.wrap_function` call at execution
# time. This conflict does not occur in the C++ executor, which does not use
# the eager context.
tf.config.optimizer.set_experimental_options({'disable_meta_optimizer': True})
execution_contexts.set_local_execution_context()
test_case.main()