Commit 4bd1be79, authored by Zachary Garrett, committed by tensorflow-copybara.
Browse files

Introduce a Functional Model API for tff.learning.

This will allow for serializable Model that are also composable post
deserialization, as well as more Jax-friendly APIs to go along with
tensorflow_federated/experimental/python/learning/jax_components.py

PiperOrigin-RevId: 393835777
Parent: eeaea774
......@@ -7,7 +7,35 @@ py_library(
srcs = ["__init__.py"],
srcs_version = "PY3",
visibility = ["//tensorflow_federated/python/learning:__pkg__"],
deps = [":serialization"],
deps = [
":functional",
":serialization",
],
)
# Library for the functional model API: models whose forward pass and
# prediction methods take weights as explicit arguments.
py_library(
    name = "functional",
    srcs = ["functional.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow_federated/python/core/api:computation_base",
        "//tensorflow_federated/python/core/api:computations",
        "//tensorflow_federated/python/core/impl/federated_context:intrinsics",
        "//tensorflow_federated/python/core/impl/types:computation_types",
        "//tensorflow_federated/python/core/impl/types:placements",
        "//tensorflow_federated/python/learning:model",
    ],
)
# Unit tests for the functional model library.
py_test(
    name = "functional_test",
    srcs = ["functional_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":functional",
        "//tensorflow_federated/python/learning:model",
    ],
)
py_library(
......
......@@ -13,5 +13,7 @@
# limitations under the License.
"""Libraries for working with models in Federated Learning algorithms."""
from tensorflow_federated.python.learning.models.functional import FunctionalModel
from tensorflow_federated.python.learning.models.functional import model_from_functional
from tensorflow_federated.python.learning.models.serialization import load
from tensorflow_federated.python.learning.models.serialization import save
# Copyright 2021, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for creating functional implementations of a `tff.learning.Model`.
This version of the model parameterizes its `forward_pass` and
`predict_on_batch` methods by model weights, rather than storing them in the
model. This allows for greater flexibility in model portability.
To use with `tff.learning.build_federated_averaging_process` and other APIs that
construct learning processes expecting stateful models, wrap the functional
model with `tff.learning.models.model_from_functional`.
"""
from typing import Any, Callable, Mapping, Sequence, Tuple, Union
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.learning import model as model_lib
# A single weight value: a NumPy array or a Python numeric scalar.
Weight = Union[np.ndarray, int, float]
# A structure of weights: either a sequence of weights or a string-keyed
# mapping of weights.
WeightStruct = Union[Sequence[Weight], Mapping[str, Weight]]
# A 2-tuple of `(trainable, non_trainable)` weight structures.
ModelWeights = Tuple[WeightStruct, WeightStruct]
class FunctionalModel():
  """A model that parameterizes forward pass by model weights.

  Unlike a stateful `tff.learning.Model`, weights are not stored on the model
  object; every method takes them as an explicit argument, which keeps the
  model free of TensorFlow state and easier to serialize and compose.
  """

  def __init__(
      self,
      initial_weights: ModelWeights,
      forward_pass_fn: Callable[[ModelWeights, Any, bool],
                                model_lib.BatchOutput],
      predict_on_batch_fn: Callable[[ModelWeights, Any, bool], Any],
      input_spec,
  ):
    """Initializes a `FunctionalModel`.

    Example model implementing linear regression:

    ```
    w, b = np.zeros(shape=[1,3]), np.zeros([1])
    trainable_weights = (w, b)
    non_trainable_weights = ()
    initial_weights = (trainable_weights, non_trainable_weights)

    @tf.function
    def predict_on_batch(model_weights, x, training):
      del training  # Unused.
      trainable, non_trainable = model_weights
      w, b = trainable
      return tf.matmul(x, w, transpose_b=True) + b

    @tf.function
    def forward_pass(model_weights, batch_input, training):
      x, y = batch_input
      predictions = predict_on_batch(model_weights, x, training)
      residuals = predictions - y
      total_loss = tf.reduce_sum(tf.pow(residuals, 2.))
      num_examples = tf.shape(predictions)[0]
      average_loss = total_loss / tf.cast(num_examples, tf.float32)
      return tff.learning.BatchOutput(
          loss=average_loss, predictions=predictions,
          num_examples=num_examples)

    model = FunctionalModel(
        initial_weights, forward_pass, predict_on_batch,
        (tf.TensorSpec(shape=[None, 3], dtype=tf.float32),
         tf.TensorSpec(shape=[None, 1], dtype=tf.float32))
    )
    ```

    Args:
      initial_weights: A 2-tuple `(trainable, non_trainable)` where the two
        elements are sequences of weights. Weights must be values convertible
        to `tf.Tensor` (e.g. `numpy.ndarray`, Python sequences, etc), but
        _not_ `tf.Tensor` values.
      forward_pass_fn: A `tf.function` decorated callable that takes three
        arguments, `model_weights` the same structure as `initial_weights`,
        `batch_input` a nested structure of tensors matching `input_spec`, and
        `training` a boolean determining whether the call is during a training
        pass (e.g. for Dropout, BatchNormalization, etc).
      predict_on_batch_fn: A `tf.function` decorated callable that takes three
        arguments, `model_weights` the same structure as `initial_weights`,
        `x` the first element of `batch_input` (or `input_spec`), and
        `training` a boolean determining whether the call is during a training
        pass (e.g. for Dropout, BatchNormalization, etc).
      input_spec: A 2-tuple of `(x, y)` where each element is a nested
        structure of `tf.TensorSpec` that defines the shape and dtypes of
        `batch_input` to `forward_pass_fn`. `x` corresponds to batched model
        inputs and `y` corresponds to batched labels for those inputs.

    Raises:
      TypeError: If `initial_weights` contains `tf.Tensor` or `tf.Variable`
        values.
    """

    def check_non_tf_value(value):
      # Weights must be plain Python/NumPy values so the model stays
      # serializable and independent of any TensorFlow graph or session.
      if tf.is_tensor(value) or isinstance(value, tf.Variable):
        raise TypeError(
            'initial_weights may not contain TensorFlow values '
            f'(tf.Tensor or tf.Variable). Got: {type(value)!r}. Try '
            'converting to a np.ndarray by using the `.numpy()` '
            'attribute for tf.Tensor, or `.read_value().numpy()` '
            'for tf.Variable.')

    tf.nest.map_structure(check_non_tf_value, initial_weights)
    self._initial_weights = initial_weights
    self._forward_pass_fn = forward_pass_fn
    self._predict_on_batch_fn = predict_on_batch_fn
    self._input_spec = input_spec

  @property
  def initial_weights(self) -> ModelWeights:
    """Returns the `(trainable, non_trainable)` weights given at construction."""
    return self._initial_weights

  def forward_pass(self,
                   model_weights: ModelWeights,
                   batch_input: Any,
                   training: bool = True) -> model_lib.BatchOutput:
    """Runs the forward pass and returns results."""
    return self._forward_pass_fn(model_weights, batch_input, training)

  def predict_on_batch(self,
                       model_weights: ModelWeights,
                       x: Any,
                       training: bool = True):
    """Returns tensor(s) interpretable by the loss function."""
    return self._predict_on_batch_fn(model_weights, x, training)

  @property
  def input_spec(self):
    """Returns the `(x, y)` spec of batches accepted by `forward_pass`."""
    return self._input_spec
class _ModelFromFunctional(model_lib.Model):
  """A `tff.learning.Model` wrapping a `tff.learning.model.FunctionalModel`."""

  def __init__(self, functional_model: FunctionalModel):
    self._functional_model = functional_model
    # Materialize the initial weights as `tf.Variable`s so that optimizers
    # can mutate them during the learning process.
    initial_trainable, initial_non_trainable = functional_model.initial_weights
    self._trainable_variables = tuple(
        tf.Variable(weight) for weight in initial_trainable)
    self._non_trainable_variables = tuple(
        tf.Variable(weight, trainable=False)
        for weight in initial_non_trainable)
    self._model_weights = (self._trainable_variables,
                           self._non_trainable_variables)

  @property
  def trainable_variables(self) -> Tuple[tf.Variable, ...]:
    return self._trainable_variables

  @property
  def non_trainable_variables(self) -> Tuple[tf.Variable, ...]:
    return self._non_trainable_variables

  @property
  def local_variables(self) -> Tuple[tf.Variable, ...]:
    # The wrapped functional model keeps no local (metrics) state.
    return ()

  @property
  def input_spec(self):
    return self._functional_model.input_spec

  @tf.function
  def forward_pass(self, batch_input, training=True):
    # Read the current variable values into the weight structure the
    # functional model expects.
    current_weights = tf.nest.map_structure(lambda var: var.read_value(),
                                            self._model_weights)
    return self._functional_model.forward_pass(
        model_weights=current_weights,
        batch_input=batch_input,
        training=training)

  @tf.function
  def predict_on_batch(self, x, training=True):
    current_weights = tf.nest.map_structure(lambda var: var.read_value(),
                                            self._model_weights)
    return self._functional_model.predict_on_batch(
        model_weights=current_weights, x=x, training=training)

  @tf.function
  def report_local_outputs(self):
    # No metrics are tracked, so there is nothing to report.
    return {}

  @property
  def federated_output_computation(self) -> computation_base.Computation:

    @computations.federated_computation(computation_types.at_clients(()))
    def aggregate(values):
      del values  # Unused.
      return intrinsics.federated_value((), placements.SERVER)

    return aggregate
def model_from_functional(functional_model: FunctionalModel) -> model_lib.Model:
  """Converts a `FunctionalModel` to a `tff.learning.Model`."""
  wrapped_model = _ModelFromFunctional(functional_model)
  return wrapped_model
# Copyright 2021, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for FunctionModel."""
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning.models import functional
def initial_weights():
  """Returns lists of trainable variables and non-trainable variables."""
  zero_weights = np.asarray([[0.0, 0.0, 0.0]], dtype=np.float32)
  zero_bias = np.asarray([0.0], dtype=np.float32)
  # The test model has no non-trainable weights.
  return ((zero_weights, zero_bias), ())
def predict_on_batch(model_weights, x, training):
  """Test predict_on_batch implementing linear regression."""
  w, b = model_weights[0]
  linear_output = tf.matmul(x, w, transpose_b=True)
  # For the sake of testing, the bias is only applied in training mode so
  # that training and inference produce distinguishable outputs.
  if training:
    return linear_output + b
  return linear_output
def forward_pass(model_weights, batch_input, training):
  """Test forward_pass implementing linear regression on MSE."""
  x, y = batch_input
  predictions = predict_on_batch(model_weights, x, training)
  errors = predictions - y
  batch_size = tf.shape(predictions)[0]
  sum_squared_error = tf.reduce_sum(tf.pow(errors, 2.))
  mean_loss = sum_squared_error / tf.cast(batch_size, tf.float32)
  return model_lib.BatchOutput(
      loss=mean_loss, predictions=predictions, num_examples=batch_size)
def create_test_dataset():
  """Create a test dataset."""
  num_examples = 25

  def generate_example(index, element):
    del element  # Unused.
    # Deterministic features keyed on the example index, labels produced by
    # a known linear function so convergence tests have a ground truth.
    features = tf.random.stateless_uniform(shape=[3], seed=(0, index))
    label = tf.expand_dims(
        tf.reduce_sum(features * tf.constant([1.0, 2.0, 3.0])), axis=-1) + 5.0
    return (features, label)

  raw_dataset = tf.data.Dataset.range(num_examples).enumerate()
  return raw_dataset.map(generate_example).batch(5, drop_remainder=True)
class FunctionalTest(tf.test.TestCase):
  """Tests for `FunctionalModel` and its `tff.learning.Model` wrapper."""

  def test_fail_construction_on_tf_value(self):
    # Constructing with tf.Tensor or tf.Variable weights must raise.
    dataset = create_test_dataset()
    input_spec = dataset.element_spec
    with self.assertRaisesRegex(TypeError, 'initial_weights may not contain'):
      functional.FunctionalModel((tf.constant(1.0), ()), forward_pass,
                                 predict_on_batch, input_spec)
    with self.assertRaisesRegex(TypeError, 'initial_weights may not contain'):
      functional.FunctionalModel((tf.Variable(1.0), ()), forward_pass,
                                 predict_on_batch, input_spec)

  def test_predict_on_batch(self):
    dataset = create_test_dataset()
    example_batch = next(iter(dataset))
    input_spec = dataset.element_spec
    functional_model = functional.FunctionalModel(initial_weights(),
                                                  forward_pass,
                                                  predict_on_batch, input_spec)
    # With all-zero weights the linear model predicts zero for every example.
    self.assertAllClose(
        functional_model.predict_on_batch(functional_model.initial_weights,
                                          example_batch[0]), [[0.]] * 5)

  def test_forward_pass(self):
    dataset = create_test_dataset()
    example_batch = next(iter(dataset))
    input_spec = dataset.element_spec
    functional_model = functional.FunctionalModel(initial_weights(),
                                                  forward_pass,
                                                  predict_on_batch, input_spec)
    # Fixed: the original test duplicated `test_predict_on_batch` and never
    # called `forward_pass`; exercise `forward_pass` as the name promises.
    batch_output = functional_model.forward_pass(
        functional_model.initial_weights, example_batch)
    self.assertAllClose(batch_output.predictions, [[0.]] * 5)
    self.assertEqual(int(batch_output.num_examples), 5)
    # Labels are bounded below by 5.0, so MSE against all-zero predictions
    # must be strictly positive.
    self.assertGreater(float(batch_output.loss), 0.0)

  def test_tff_model_from_functional_same_result(self):
    # The stateful wrapper must produce identical outputs to calling the
    # functional model directly with its initial weights.
    dataset = create_test_dataset()
    input_spec = dataset.element_spec
    functional_model = functional.FunctionalModel(initial_weights(),
                                                  forward_pass,
                                                  predict_on_batch, input_spec)
    tff_model = functional.model_from_functional(functional_model)
    for training in [True, False]:
      for batch in dataset:
        self.assertAllClose(
            tff_model.predict_on_batch(batch[0], training),
            functional_model.predict_on_batch(functional_model.initial_weights,
                                              batch[0], training))
        tf.nest.map_structure(
            self.assertAllClose, tff_model.forward_pass(batch, training),
            functional_model.forward_pass(functional_model.initial_weights,
                                          batch, training))

  def test_functional_model_converges(self):
    # SGD on externally-held tf.Variables should recover the generating
    # linear function w=[1,2,3], b=5 (see `create_test_dataset`).
    dataset = create_test_dataset()
    input_spec = dataset.element_spec
    functional_model = functional.FunctionalModel(initial_weights(),
                                                  forward_pass,
                                                  predict_on_batch, input_spec)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)
    variables = tf.nest.map_structure(tf.Variable,
                                      functional_model.initial_weights)
    trainable = variables[0]
    loss = None
    num_epochs = 50
    for batch in dataset.repeat(num_epochs):
      with tf.GradientTape() as tape:
        batch_output = functional_model.forward_pass(
            variables, batch, training=True)
      gradients = tape.gradient(batch_output.loss, trainable)
      optimizer.apply_gradients(zip(gradients, trainable))
      loss = batch_output.loss
    # Expect some amount of convergence after a few epochs of the dataset.
    self.assertLess(loss, 0.1)
    self.assertAllClose(trainable, ([[1.0, 2.0, 3.0]], [5.0]), atol=0.5)

  def test_tff_model_from_functional_converges(self):
    # Same convergence check through the stateful wrapper's own variables.
    dataset = create_test_dataset()
    input_spec = dataset.element_spec
    functional_model = functional.FunctionalModel(initial_weights(),
                                                  forward_pass,
                                                  predict_on_batch, input_spec)
    tff_model = functional.model_from_functional(functional_model)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)
    loss = None
    num_epochs = 50
    for batch in dataset.repeat(num_epochs):
      with tf.GradientTape() as tape:
        batch_output = tff_model.forward_pass(batch, training=True)
      gradients = tape.gradient(batch_output.loss,
                                tff_model.trainable_variables)
      optimizer.apply_gradients(zip(gradients, tff_model.trainable_variables))
      loss = batch_output.loss
    # Expect some amount of convergence after a few epochs of the dataset.
    self.assertLess(loss, 0.1)
    self.assertAllClose(
        tff_model.trainable_variables, ([[1.0, 2.0, 3.0]], [5.0]), atol=0.5)
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册