提交 92104850 编辑于 作者: Zachary Garrett's avatar Zachary Garrett 提交者: tensorflow-copybara
浏览文件

Require Functional Model functions to be tf.function decorated.

PiperOrigin-RevId: 394089953
上级 92105eba
......@@ -18,6 +18,7 @@ py_library(
srcs = ["functional.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
......
......@@ -27,6 +27,7 @@ from typing import Any, Callable, Mapping, Sequence, Tuple, Union
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl.federated_context import intrinsics
......@@ -39,6 +40,14 @@ WeightStruct = Union[Sequence[Weight], Mapping[str, Weight]]
ModelWeights = Tuple[WeightStruct, WeightStruct]
class CallableMustBeTFFunctionError(TypeError):
"""Error raised when a callable is not decorated as a tf.function."""
class ValueMustNotBeTFError(TypeError):
"""Error raised a value must not be a `tf.Tensor` or `tf.Variable`."""
class FunctionalModel():
"""A model that parameterizes forward pass by model weights."""
......@@ -106,9 +115,17 @@ class FunctionalModel():
corresponds to batched labels for those inputs.
"""
def check_tf_function_decorated(fn, arg_name):
if not hasattr(fn, 'get_concrete_function'):
type_string = py_typecheck.type_string(type(fn))
raise CallableMustBeTFFunctionError(
f'{arg_name} does not have a `get_concrete_function` attribute '
'meaning it is not a callable decorated with `tf.function`. '
f'Got a {type_string} with value {fn!r}.')
def check_non_tf_value(value):
if tf.is_tensor(value) or isinstance(value, tf.Variable):
raise TypeError(
raise ValueMustNotBeTFError(
'initial_weights may not contain TensorFlow values '
f'(tf.Tensor or tf.Variable). Got: {type(value)!r}. Try '
'converting to a np.ndarray by using the `.numpy()` '
......@@ -117,7 +134,9 @@ class FunctionalModel():
tf.nest.map_structure(check_non_tf_value, initial_weights)
self._initial_weights = initial_weights
check_tf_function_decorated(forward_pass_fn, 'forward_pass_fn')
self._forward_pass_fn = forward_pass_fn
check_tf_function_decorated(predict_on_batch_fn, 'predict_on_batch_fn')
self._predict_on_batch_fn = predict_on_batch_fn
self._input_spec = input_spec
......@@ -125,6 +144,7 @@ class FunctionalModel():
def initial_weights(self) -> ModelWeights:
return self._initial_weights
@tf.function
def forward_pass(self,
model_weights: ModelWeights,
batch_input: Any,
......@@ -132,6 +152,7 @@ class FunctionalModel():
"""Runs the forward pass and returns results."""
return self._forward_pass_fn(model_weights, batch_input, training)
@tf.function
def predict_on_batch(self,
model_weights: ModelWeights,
x: Any,
......
......@@ -28,6 +28,7 @@ def initial_weights():
return (trainable_variables, non_trainable_variables)
@tf.function
def predict_on_batch(model_weights, x, training):
"""Test predict_on_batch implementing linear regression."""
trainable = model_weights[0]
......@@ -40,6 +41,7 @@ def predict_on_batch(model_weights, x, training):
return tf.matmul(x, w, transpose_b=True)
@tf.function
def forward_pass(model_weights, batch_input, training):
"""Test forward_pass implementing linear regression on MSE."""
x, y = batch_input
......@@ -75,13 +77,29 @@ class FunctionalTest(tf.test.TestCase):
def test_fail_construction_on_tf_value(self):
dataset = create_test_dataset()
input_spec = dataset.element_spec
with self.assertRaisesRegex(TypeError, 'initial_weights may not contain'):
with self.assertRaisesRegex(functional.ValueMustNotBeTFError,
'initial_weights may not contain'):
functional.FunctionalModel((tf.constant(1.0), ()), forward_pass,
predict_on_batch, input_spec)
with self.assertRaisesRegex(TypeError, 'initial_weights may not contain'):
with self.assertRaisesRegex(functional.ValueMustNotBeTFError,
'initial_weights may not contain'):
functional.FunctionalModel((tf.Variable(1.0), ()), forward_pass,
predict_on_batch, input_spec)
def test_fail_non_tf_function(self):
dataset = create_test_dataset()
input_spec = dataset.element_spec
with self.assertRaisesRegex(
functional.CallableMustBeTFFunctionError,
'forward_pass_fn does not have a `get_concrete_function`'):
functional.FunctionalModel((), forward_pass.python_function,
predict_on_batch, input_spec)
with self.assertRaisesRegex(
functional.CallableMustBeTFFunctionError,
'predict_on_batch_fn does not have a `get_concrete_function`'):
functional.FunctionalModel((), forward_pass,
predict_on_batch.python_function, input_spec)
def test_predict_on_batch(self):
dataset = create_test_dataset()
example_batch = next(iter(dataset))
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册