Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
621f8edd
提交
621f8edd
编辑于
11月 23, 2021
作者:
Yu Xiao
提交者:
tensorflow-copybara
11月 23, 2021
浏览文件
Make `from_keras_model` able to take metric constructors.
PiperOrigin-RevId: 411842768
上级
0f2d1f21
变更
3
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/learning/keras_utils.py
浏览文件 @
621f8edd
...
...
@@ -14,7 +14,7 @@
"""Utility methods for working with Keras in TensorFlow Federated."""
import
collections
from
typing
import
List
,
Optional
,
OrderedDict
,
Sequence
,
Union
from
typing
import
Callable
,
List
,
Optional
,
OrderedDict
,
Sequence
,
Union
import
warnings
import
tensorflow
as
tf
...
...
@@ -33,12 +33,16 @@ from tensorflow_federated.python.learning.metrics import finalizer
Loss
=
Union
[
tf
.
keras
.
losses
.
Loss
,
List
[
tf
.
keras
.
losses
.
Loss
]]
# TODO(b/197746608): Remove the code path that takes in constructed Keras
# metrics, because reconstructing metrics via `from_config` can cause problems.
def
from_keras_model
(
keras_model
:
tf
.
keras
.
Model
,
loss
:
Loss
,
input_spec
,
loss_weights
:
Optional
[
List
[
float
]]
=
None
,
metrics
:
Optional
[
List
[
tf
.
keras
.
metrics
.
Metric
]]
=
None
)
->
model_lib
.
Model
:
metrics
:
Optional
[
Union
[
List
[
tf
.
keras
.
metrics
.
Metric
],
List
[
Callable
[[],
tf
.
keras
.
metrics
.
Metric
]]]]
=
None
)
->
model_lib
.
Model
:
"""Builds a `tff.learning.Model` from a `tf.keras.Model`.
The `tff.learning.Model` returned by this function uses `keras_model` for
...
...
@@ -83,7 +87,8 @@ def from_keras_model(
loss_weights: (Optional) A list of Python floats used to weight the loss
contribution of each model output (when providing a list of losses for the
`loss` argument).
metrics: (Optional) a list of `tf.keras.metrics.Metric` objects.
metrics: (Optional) a list of `tf.keras.metrics.Metric` objects or a list of
no-arg callables that each constructs a `tf.keras.metrics.Metric`.
Returns:
A `tff.learning.Model` object.
...
...
@@ -171,8 +176,6 @@ def from_keras_model(
metrics
=
[]
else
:
py_typecheck
.
check_type
(
metrics
,
list
)
for
metric
in
metrics
:
py_typecheck
.
check_type
(
metric
,
tf
.
keras
.
metrics
.
Metric
)
for
layer
in
keras_model
.
layers
:
if
isinstance
(
layer
,
tf
.
keras
.
layers
.
BatchNormalization
):
...
...
@@ -191,8 +194,10 @@ def from_keras_model(
def
federated_aggregate_keras_metric
(
metrics
:
Union
[
tf
.
keras
.
metrics
.
Metric
,
Sequence
[
tf
.
keras
.
metrics
.
Metric
]],
federated_values
):
metrics
:
Union
[
tf
.
keras
.
metrics
.
Metric
,
Sequence
[
tf
.
keras
.
metrics
.
Metric
],
Callable
[[],
tf
.
keras
.
metrics
.
Metric
],
Sequence
[
Callable
[[],
tf
.
keras
.
metrics
.
Metric
]]],
federated_values
):
"""Aggregates variables a keras metric placed at CLIENTS to SERVER.
Args:
...
...
@@ -232,8 +237,11 @@ def federated_aggregate_keras_metric(
def
report
(
accumulators
):
"""Insert `accumulators` back into the keras metric to obtain result."""
def
finalize_metric
(
metric
:
tf
.
keras
.
metrics
.
Metric
,
values
):
# Note: the following call requires that `type(metric)` have a no argument
def
finalize_metric
(
metric
:
Union
[
tf
.
keras
.
metrics
.
Metric
,
Callable
[[],
tf
.
keras
.
metrics
.
Metric
]],
values
):
# Note: if the input metric is an instance of `tf.keras.metrics.Metric`,
# the following call requires that `type(metric)` have a no argument
# __init__ method, which will restrict the types of metrics that can be
# used. This is somewhat limiting, but the pattern to use default
# arguments and export the values in `get_config()` (see
...
...
@@ -242,8 +250,7 @@ def federated_aggregate_keras_metric(
# If type(metric) is subclass of another tf.keras.metric arguments passed
# to __init__ must include arguments expected by the superclass and
# specified in superclass get_config().
finalizer
.
check_keras_metric_config_constructable
(
metric
)
keras_metric
=
type
(
metric
).
from_config
(
metric
.
get_config
())
keras_metric
=
finalizer
.
create_keras_metric
(
metric
)
assignments
=
[]
for
v
,
a
in
zip
(
keras_metric
.
variables
,
values
):
...
...
@@ -270,12 +277,45 @@ class _KerasModel(model_lib.Model):
def
__init__
(
self
,
keras_model
:
tf
.
keras
.
Model
,
input_spec
,
loss_fns
:
List
[
tf
.
keras
.
losses
.
Loss
],
loss_weights
:
List
[
float
],
metrics
:
List
[
tf
.
keras
.
metrics
.
Metric
]):
metrics
:
Union
[
List
[
tf
.
keras
.
metrics
.
Metric
],
List
[
Callable
[[],
tf
.
keras
.
metrics
.
Metric
]]]):
self
.
_keras_model
=
keras_model
self
.
_input_spec
=
input_spec
self
.
_loss_fns
=
loss_fns
self
.
_loss_weights
=
loss_weights
self
.
_metrics
=
metrics
self
.
_metrics
=
[]
self
.
_metric_constructors
=
[]
if
metrics
:
has_keras_metric
=
False
has_keras_metric_constructor
=
False
for
metric
in
metrics
:
if
isinstance
(
metric
,
tf
.
keras
.
metrics
.
Metric
):
self
.
_metrics
.
append
(
metric
)
has_keras_metric
=
True
elif
callable
(
metric
):
constructed_metric
=
metric
()
if
not
isinstance
(
constructed_metric
,
tf
.
keras
.
metrics
.
Metric
):
raise
TypeError
(
f
'Metric constructor
{
metric
}
is not a no-arg callable that '
'creates a `tf.keras.metrics.Metric`.'
)
self
.
_metric_constructors
.
append
(
metric
)
self
.
_metrics
.
append
(
constructed_metric
)
has_keras_metric_constructor
=
True
else
:
raise
TypeError
(
'Expected the input metric to be either a '
'`tf.keras.metrics.Metric` or a no-arg callable that constructs '
'a `tf.keras.metrics.Metric`, found a non-callable '
f
'
{
py_typecheck
.
type_string
(
type
(
metric
))
}
.'
)
if
has_keras_metric
and
has_keras_metric_constructor
:
raise
TypeError
(
'Expected the input `metrics` to be either a list of '
'`tf.keras.metrics.Metric` objects or a list of no-arg callables '
'that each constructs a `tf.keras.metrics.Metric`, '
f
'found both types in the `metrics`:
{
metrics
}
.'
)
# This is defined here so that it closes over the `loss_fn`.
class
_WeightedMeanLossMetric
(
tf
.
keras
.
metrics
.
Mean
):
...
...
@@ -302,7 +342,9 @@ class _KerasModel(model_lib.Model):
return
super
().
update_state
(
batch_loss
,
batch_size
)
self
.
_loss_metric
=
_WeightedMeanLossMetric
()
self
.
_metrics
.
append
(
_WeightedMeanLossMetric
())
if
not
metrics
or
self
.
_metric_constructors
:
self
.
_metric_constructors
.
append
(
_WeightedMeanLossMetric
)
metric_variable_type_dict
=
tf
.
nest
.
map_structure
(
tf
.
TensorSpec
.
from_tensor
,
self
.
report_local_outputs
())
...
...
@@ -310,6 +352,9 @@ class _KerasModel(model_lib.Model):
metric_variable_type_dict
,
placements
.
CLIENTS
)
def
federated_output
(
local_outputs
):
if
self
.
_metric_constructors
:
return
federated_aggregate_keras_metric
(
self
.
_metric_constructors
,
local_outputs
)
return
federated_aggregate_keras_metric
(
self
.
get_metrics
(),
local_outputs
)
self
.
_federated_output_computation
=
computations
.
federated_computation
(
...
...
@@ -331,7 +376,7 @@ class _KerasModel(model_lib.Model):
return
local_variables
def
get_metrics
(
self
):
return
self
.
_metrics
+
[
self
.
_loss_metric
]
return
self
.
_metrics
@
property
def
input_spec
(
self
):
...
...
tensorflow_federated/python/learning/keras_utils_test.py
浏览文件 @
621f8edd
...
...
@@ -1070,6 +1070,128 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
self
.
assertIsInstance
(
tff_model
,
model_lib
.
Model
)
@
parameterized
.
named_parameters
(
# Test cases for the cartesian product of all parameter values.
*
_create_tff_model_from_keras_model_tuples
())
def
test_keras_model_with_metric_constructors
(
self
,
feature_dims
,
model_fn
):
keras_model
=
model_fn
(
feature_dims
)
tff_model
=
keras_utils
.
from_keras_model
(
keras_model
=
keras_model
,
input_spec
=
_create_whimsy_types
(
feature_dims
),
loss
=
tf
.
keras
.
losses
.
MeanSquaredError
(),
metrics
=
[
NumBatchesCounter
,
NumExamplesCounter
])
self
.
assertIsInstance
(
tff_model
,
model_lib
.
Model
)
# Metrics should be zero, though the model wrapper internally executes the
# forward pass once.
self
.
assertSequenceEqual
(
tff_model
.
local_variables
,
[
0
,
0
,
0.0
,
0.0
])
batch
=
collections
.
OrderedDict
(
x
=
np
.
stack
([
np
.
zeros
(
feature_dims
,
np
.
float32
),
np
.
ones
(
feature_dims
,
np
.
float32
)
]),
y
=
[[
0.0
],
[
1.0
]])
# from_model() was called without an optimizer which creates a tff.Model.
# There is no train_on_batch() method available in tff.Model.
with
self
.
assertRaisesRegex
(
AttributeError
,
'no attribute
\'
train_on_batch
\'
'
):
tff_model
.
train_on_batch
(
batch
)
output
=
tff_model
.
forward_pass
(
batch
)
# Since the model initializes all weights and biases to zero, we expect
# all predictions to be zero:
# 0*x1 + 0*x2 + ... + 0 = 0
self
.
assertAllEqual
(
output
.
predictions
,
[[
0.0
],
[
0.0
]])
# For the single batch:
#
# Example | Prediction | Label | Residual | Loss
# --------+------------+-------+----------+ -----
# 1 | 0.0 | 0.0 | 0.0 | 0.0
# 2 | 0.0 | 1.0 | 1.0 | 1.0
#
# Note that though regularization might be applied, this has no effect on
# the loss since all weights are 0.
# Total loss: 1.0
# Batch average loss: 0.5
self
.
assertEqual
(
output
.
loss
,
0.5
)
self
.
assertAllEqual
(
tff_model
.
report_local_outputs
(),
tff_model
.
report_local_unfinalized_metrics
())
metrics
=
tff_model
.
report_local_unfinalized_metrics
()
self
.
assertEqual
(
metrics
[
'num_batches'
],
[
1
])
self
.
assertEqual
(
metrics
[
'num_examples'
],
[
2
])
self
.
assertGreater
(
metrics
[
'loss'
][
0
],
0
)
self
.
assertEqual
(
metrics
[
'loss'
][
1
],
2
)
@
parameterized
.
named_parameters
(
# Test cases for the cartesian product of all parameter values.
*
_create_tff_model_from_keras_model_tuples
())
def
test_keras_model_without_input_metrics
(
self
,
feature_dims
,
model_fn
):
keras_model
=
model_fn
(
feature_dims
)
tff_model
=
keras_utils
.
from_keras_model
(
keras_model
=
keras_model
,
input_spec
=
_create_whimsy_types
(
feature_dims
),
loss
=
tf
.
keras
.
losses
.
MeanSquaredError
())
self
.
assertIsInstance
(
tff_model
,
model_lib
.
Model
)
# Metrics should be zero, though the model wrapper internally executes the
# forward pass once.
self
.
assertSequenceEqual
(
tff_model
.
local_variables
,
[
0
,
0
])
batch
=
collections
.
OrderedDict
(
x
=
np
.
stack
([
np
.
zeros
(
feature_dims
,
np
.
float32
),
np
.
ones
(
feature_dims
,
np
.
float32
)
]),
y
=
[[
0.0
],
[
1.0
]])
# from_model() was called without an optimizer which creates a tff.Model.
# There is no train_on_batch() method available in tff.Model.
with
self
.
assertRaisesRegex
(
AttributeError
,
'no attribute
\'
train_on_batch
\'
'
):
tff_model
.
train_on_batch
(
batch
)
output
=
tff_model
.
forward_pass
(
batch
)
# Since the model initializes all weights and biases to zero, we expect
# all predictions to be zero:
# 0*x1 + 0*x2 + ... + 0 = 0
self
.
assertAllEqual
(
output
.
predictions
,
[[
0.0
],
[
0.0
]])
# For the single batch:
#
# Example | Prediction | Label | Residual | Loss
# --------+------------+-------+----------+ -----
# 1 | 0.0 | 0.0 | 0.0 | 0.0
# 2 | 0.0 | 1.0 | 1.0 | 1.0
#
# Note that though regularization might be applied, this has no effect on
# the loss since all weights are 0.
# Total loss: 1.0
# Batch average loss: 0.5
self
.
assertEqual
(
output
.
loss
,
0.5
)
self
.
assertAllEqual
(
tff_model
.
report_local_outputs
(),
tff_model
.
report_local_unfinalized_metrics
())
metrics
=
tff_model
.
report_local_unfinalized_metrics
()
self
.
assertGreater
(
metrics
[
'loss'
][
0
],
0
)
self
.
assertEqual
(
metrics
[
'loss'
][
1
],
2
)
@
parameterized
.
named_parameters
(
(
'both_metrics_and_constructors'
,
[
NumExamplesCounter
,
NumBatchesCounter
()],
'found both types'
),
(
'non_callable'
,
[
tf
.
constant
(
1.0
)],
'found a non-callable'
),
(
'non_keras_metric_constructor'
,
[
tf
.
keras
.
losses
.
MeanSquaredError
],
'not a no-arg callable'
))
def
test_keras_model_provided_invalid_metrics_raises
(
self
,
metrics
,
error_message
):
feature_dims
=
3
keras_model
=
model_examples
.
build_linear_regression_keras_functional_model
(
feature_dims
)
with
self
.
assertRaisesRegex
(
TypeError
,
error_message
):
keras_utils
.
from_keras_model
(
keras_model
=
keras_model
,
input_spec
=
_create_whimsy_types
(
feature_dims
),
loss
=
tf
.
keras
.
losses
.
MeanSquaredError
(),
metrics
=
metrics
)
if
__name__
==
'__main__'
:
execution_contexts
.
set_local_python_execution_context
()
...
...
tensorflow_federated/python/learning/metrics/finalizer.py
浏览文件 @
621f8edd
...
...
@@ -57,22 +57,7 @@ def create_keras_metric_finalizer(
# we need the `tf.Variable`s to be created in the current scope in order to
# use `keras_metric.result()`.
with
tf
.
init_scope
():
if
isinstance
(
metric
,
tf
.
keras
.
metrics
.
Metric
):
check_keras_metric_config_constructable
(
metric
)
keras_metric
=
type
(
metric
).
from_config
(
metric
.
get_config
())
elif
callable
(
metric
):
keras_metric
=
metric
()
if
not
isinstance
(
keras_metric
,
tf
.
keras
.
metrics
.
Metric
):
raise
TypeError
(
'Expected input `metric` to be either a `tf.keras.metrics.Metric`'
' or a no-arg callable that creates a `tf.keras.metrics.Metric`, '
'found a callable that returns a '
f
'
{
py_typecheck
.
type_string
(
type
(
keras_metric
))
}
.'
)
else
:
raise
TypeError
(
'Expected input `metric` to be either a `tf.keras.metrics.Metric` '
'or a no-arg callable that constructs a `tf.keras.metrics.Metric`, '
f
'found a non-callable
{
py_typecheck
.
type_string
(
type
(
metric
))
}
.'
)
keras_metric
=
create_keras_metric
(
metric
)
py_typecheck
.
check_type
(
unfinalized_metric_values
,
list
)
if
len
(
keras_metric
.
variables
)
!=
len
(
unfinalized_metric_values
):
raise
ValueError
(
...
...
@@ -94,16 +79,12 @@ def create_keras_metric_finalizer(
return
finalizer
def
check_keras_metric_config_constructable
(
metric
:
tf
.
keras
.
metrics
.
Metric
)
->
tf
.
keras
.
metrics
.
Metric
:
def
_check_keras_metric_config_constructable
(
metric
:
tf
.
keras
.
metrics
.
Metric
):
"""Checks that a Keras metric is constructable from the `get_config()` method.
Args:
metric: A single `tf.keras.metrics.Metric`.
Returns:
The metric.
Raises:
TypeError: If the metric is not an instance of `tf.keras.metrics.Metric`, if
the metric is not constructable from the `get_config()` method.
...
...
@@ -114,28 +95,64 @@ def check_keras_metric_config_constructable(
metric_type_str
=
type
(
metric
).
__name__
if
hasattr
(
tf
.
keras
.
metrics
,
metric_type_str
):
return
metric
init_args
=
inspect
.
getfullargspec
(
metric
.
__init__
).
args
init_args
.
remove
(
'self'
)
get_config_args
=
metric
.
get_config
().
keys
()
extra_args
=
[
arg
for
arg
in
init_args
if
arg
not
in
get_config_args
]
if
extra_args
:
# TODO(b/197746608): Updates the error message to redirect users to use
# metric constructors instead of constructed metrics when we support both
# cases in `from_keras_model`.
raise
TypeError
(
f
'Metric
{
metric_type_str
}
is not constructable from '
'the `get_config()` method, because `__init__` takes extra '
'arguments that are not included in the `get_config()`: '
f
'
{
extra_args
}
. Override or update the `get_config()` in '
'the metric class to include these extra arguments.
\n
'
'Example:
\n
'
'class CustomMetric(tf.keras.metrics.Metric):
\n
'
' def __init__(self, arg1):
\n
'
' self._arg1 = arg1
\n\n
'
' def get_config(self)
\n
'
' config = super().get_config()
\n
'
' config[
\'
arg1
\'
] = self._arg1
\n
'
' return config'
)
return
metric
if
not
hasattr
(
tf
.
keras
.
metrics
,
metric_type_str
):
init_args
=
inspect
.
getfullargspec
(
metric
.
__init__
).
args
init_args
.
remove
(
'self'
)
get_config_args
=
metric
.
get_config
().
keys
()
extra_args
=
[
arg
for
arg
in
init_args
if
arg
not
in
get_config_args
]
if
extra_args
:
# TODO(b/197746608): Remove the suggestion of updating `get_config` if
# that code path is removed.
raise
TypeError
(
f
'Metric
{
metric_type_str
}
is not constructable from the '
'`get_config()` method, because `__init__` takes extra arguments '
f
'that are not included in the `get_config()`:
{
extra_args
}
. '
'Pass the metric constructor instead, or update the `get_config()` '
'in the metric class to include these extra arguments.
\n
'
'Example:
\n
'
'class CustomMetric(tf.keras.metrics.Metric):
\n
'
' def __init__(self, arg1):
\n
'
' self._arg1 = arg1
\n\n
'
' def get_config(self)
\n
'
' config = super().get_config()
\n
'
' config[
\'
arg1
\'
] = self._arg1
\n
'
' return config'
)
def
create_keras_metric
(
metric
:
Union
[
tf
.
keras
.
metrics
.
Metric
,
Callable
[[],
tf
.
keras
.
metrics
.
Metric
]]
)
->
tf
.
keras
.
metrics
.
Metric
:
"""Create a `tf.keras.metrics.Metric` from a `tf.keras.metrics.Metric`.
So the `tf.Variable`s in the metric can get created in the right scope in TFF.
Args:
metric: A single `tf.keras.metrics.Metric` or a no-arg callable that creates
a `tf.keras.metrics.Metric`.
Returns:
A `tf.keras.metrics.Metric` object.
Raises:
TypeError: If input metric is neither a `tf.keras.metrics.Metric` or a
no-arg callable that creates a `tf.keras.metrics.Metric`.
"""
keras_metric
=
None
if
isinstance
(
metric
,
tf
.
keras
.
metrics
.
Metric
):
_check_keras_metric_config_constructable
(
metric
)
keras_metric
=
type
(
metric
).
from_config
(
metric
.
get_config
())
elif
callable
(
metric
):
keras_metric
=
metric
()
if
not
isinstance
(
keras_metric
,
tf
.
keras
.
metrics
.
Metric
):
raise
TypeError
(
'Expected input `metric` to be either a `tf.keras.metrics.Metric` '
'or a no-arg callable that creates a `tf.keras.metrics.Metric`, '
'found a callable that returns a '
f
'
{
py_typecheck
.
type_string
(
type
(
keras_metric
))
}
.'
)
else
:
raise
TypeError
(
'Expected input `metric` to be either a `tf.keras.metrics.Metric` '
'or a no-arg callable that constructs a `tf.keras.metrics.Metric`, '
f
'found a non-callable
{
py_typecheck
.
type_string
(
type
(
metric
))
}
.'
)
return
keras_metric
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录