Skip to content
GitLab
菜单
项目
群组
代码片段
/
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
92104850
提交
92104850
编辑于
8月 31, 2021
作者:
Zachary Garrett
提交者:
tensorflow-copybara
8月 31, 2021
浏览文件
Require Functional Model functions to be tf.function decorated.
PiperOrigin-RevId: 394089953
上级
92105eba
变更
3
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/learning/models/BUILD
浏览文件 @
92104850
...
...
@@ -18,6 +18,7 @@ py_library(
srcs
=
[
"functional.py"
],
srcs_version
=
"PY3"
,
deps
=
[
"//tensorflow_federated/python/common_libs:py_typecheck"
,
"//tensorflow_federated/python/core/api:computation_base"
,
"//tensorflow_federated/python/core/api:computations"
,
"//tensorflow_federated/python/core/impl/federated_context:intrinsics"
,
...
...
tensorflow_federated/python/learning/models/functional.py
浏览文件 @
92104850
...
...
@@ -27,6 +27,7 @@ from typing import Any, Callable, Mapping, Sequence, Tuple, Union
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow_federated.python.common_libs
import
py_typecheck
from
tensorflow_federated.python.core.api
import
computation_base
from
tensorflow_federated.python.core.api
import
computations
from
tensorflow_federated.python.core.impl.federated_context
import
intrinsics
...
...
@@ -39,6 +40,14 @@ WeightStruct = Union[Sequence[Weight], Mapping[str, Weight]]
ModelWeights
=
Tuple
[
WeightStruct
,
WeightStruct
]
class
CallableMustBeTFFunctionError
(
TypeError
):
"""Error raised when a callable is not decorated as a tf.function."""
class
ValueMustNotBeTFError
(
TypeError
):
"""Error raised a value must not be a `tf.Tensor` or `tf.Variable`."""
class
FunctionalModel
():
"""A model that parameterizes forward pass by model weights."""
...
...
@@ -106,9 +115,17 @@ class FunctionalModel():
corresponds to batched labels for those inputs.
"""
def
check_tf_function_decorated
(
fn
,
arg_name
):
if
not
hasattr
(
fn
,
'get_concrete_function'
):
type_string
=
py_typecheck
.
type_string
(
type
(
fn
))
raise
CallableMustBeTFFunctionError
(
f
'
{
arg_name
}
does not have a `get_concrete_function` attribute '
'meaning it is not a callable decorated with `tf.function`. '
f
'Got a
{
type_string
}
with value
{
fn
!r}
.'
)
def
check_non_tf_value
(
value
):
if
tf
.
is_tensor
(
value
)
or
isinstance
(
value
,
tf
.
Variable
):
raise
Type
Error
(
raise
ValueMustNotBeTF
Error
(
'initial_weights may not contain TensorFlow values '
f
'(tf.Tensor or tf.Variable). Got:
{
type
(
value
)
!r}
. Try '
'converting to a np.ndarray by using the `.numpy()` '
...
...
@@ -117,7 +134,9 @@ class FunctionalModel():
tf
.
nest
.
map_structure
(
check_non_tf_value
,
initial_weights
)
self
.
_initial_weights
=
initial_weights
check_tf_function_decorated
(
forward_pass_fn
,
'forward_pass_fn'
)
self
.
_forward_pass_fn
=
forward_pass_fn
check_tf_function_decorated
(
predict_on_batch_fn
,
'predict_on_batch_fn'
)
self
.
_predict_on_batch_fn
=
predict_on_batch_fn
self
.
_input_spec
=
input_spec
...
...
@@ -125,6 +144,7 @@ class FunctionalModel():
def
initial_weights
(
self
)
->
ModelWeights
:
return
self
.
_initial_weights
@
tf
.
function
def
forward_pass
(
self
,
model_weights
:
ModelWeights
,
batch_input
:
Any
,
...
...
@@ -132,6 +152,7 @@ class FunctionalModel():
"""Runs the forward pass and returns results."""
return
self
.
_forward_pass_fn
(
model_weights
,
batch_input
,
training
)
@
tf
.
function
def
predict_on_batch
(
self
,
model_weights
:
ModelWeights
,
x
:
Any
,
...
...
tensorflow_federated/python/learning/models/functional_test.py
浏览文件 @
92104850
...
...
@@ -28,6 +28,7 @@ def initial_weights():
return
(
trainable_variables
,
non_trainable_variables
)
@
tf
.
function
def
predict_on_batch
(
model_weights
,
x
,
training
):
"""Test predict_on_batch implementing linear regression."""
trainable
=
model_weights
[
0
]
...
...
@@ -40,6 +41,7 @@ def predict_on_batch(model_weights, x, training):
return
tf
.
matmul
(
x
,
w
,
transpose_b
=
True
)
@
tf
.
function
def
forward_pass
(
model_weights
,
batch_input
,
training
):
"""Test forward_pass implementing linear regression on MSE."""
x
,
y
=
batch_input
...
...
@@ -75,13 +77,29 @@ class FunctionalTest(tf.test.TestCase):
def
test_fail_construction_on_tf_value
(
self
):
dataset
=
create_test_dataset
()
input_spec
=
dataset
.
element_spec
with
self
.
assertRaisesRegex
(
TypeError
,
'initial_weights may not contain'
):
with
self
.
assertRaisesRegex
(
functional
.
ValueMustNotBeTFError
,
'initial_weights may not contain'
):
functional
.
FunctionalModel
((
tf
.
constant
(
1.0
),
()),
forward_pass
,
predict_on_batch
,
input_spec
)
with
self
.
assertRaisesRegex
(
TypeError
,
'initial_weights may not contain'
):
with
self
.
assertRaisesRegex
(
functional
.
ValueMustNotBeTFError
,
'initial_weights may not contain'
):
functional
.
FunctionalModel
((
tf
.
Variable
(
1.0
),
()),
forward_pass
,
predict_on_batch
,
input_spec
)
def
test_fail_non_tf_function
(
self
):
dataset
=
create_test_dataset
()
input_spec
=
dataset
.
element_spec
with
self
.
assertRaisesRegex
(
functional
.
CallableMustBeTFFunctionError
,
'forward_pass_fn does not have a `get_concrete_function`'
):
functional
.
FunctionalModel
((),
forward_pass
.
python_function
,
predict_on_batch
,
input_spec
)
with
self
.
assertRaisesRegex
(
functional
.
CallableMustBeTFFunctionError
,
'predict_on_batch_fn does not have a `get_concrete_function`'
):
functional
.
FunctionalModel
((),
forward_pass
,
predict_on_batch
.
python_function
,
input_spec
)
def
test_predict_on_batch
(
self
):
dataset
=
create_test_dataset
()
example_batch
=
next
(
iter
(
dataset
))
...
...
编辑
预览
支持
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录