Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
提交
90feb1a0
提交
90feb1a0
编辑于
11月 17, 2021
作者:
Yu Xiao
提交者:
tensorflow-copybara
11月 17, 2021
浏览文件
Automated rollback of commit
2cec0ed7
PiperOrigin-RevId: 410631896
上级
238bf340
变更
5
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/python/learning/keras_utils.py
浏览文件 @
90feb1a0
...
...
@@ -242,21 +242,8 @@ def federated_aggregate_keras_metric(
# If type(metric) is subclass of another tf.keras.metric arguments passed
# to __init__ must include arguments expected by the superclass and
# specified in superclass get_config().
# TODO(b/197746608): finds a safer way of reconstructing the metric,
# default argument values in Metric constructors can cause problems here.
keras_metric
=
None
try
:
# This is some trickery to reconstruct a metric object in the current
# scope, so that the `tf.Variable`s get created when we desire.
keras_metric
=
type
(
metric
).
from_config
(
metric
.
get_config
())
except
TypeError
as
e
:
# Re-raise the error with a more helpful message, but the previous stack
# trace.
raise
TypeError
(
'Caught exception trying to call `{t}.from_config()` with '
'config {c}. Confirm that {t}.__init__() has an argument for '
'each member of the config.
\n
Exception: {e}'
.
format
(
t
=
type
(
metric
),
c
=
metric
.
get_config
(),
e
=
e
))
finalizer
.
check_keras_metric_config_constructable
(
metric
)
keras_metric
=
type
(
metric
).
from_config
(
metric
.
get_config
())
assignments
=
[]
for
v
,
a
in
zip
(
keras_metric
.
variables
,
values
):
...
...
tensorflow_federated/python/learning/keras_utils_test.py
浏览文件 @
90feb1a0
...
...
@@ -1020,6 +1020,57 @@ class KerasUtilsTest(test_case.TestCase, parameterized.TestCase):
metrics
=
[
NumBatchesCounter
(),
NumExamplesCounter
()])
def
test_custom_keras_metric_with_extra_init_args_raises
(
self
):
class
CustomCounter
(
tf
.
keras
.
metrics
.
Sum
):
"""A custom `tf.keras.metrics.Metric` with extra args in `__init__`."""
def
__init__
(
self
,
name
=
'new_counter'
,
arg1
=
0
,
dtype
=
tf
.
int64
):
super
().
__init__
(
name
,
dtype
)
self
.
_arg1
=
arg1
def
update_state
(
self
,
y_true
,
y_pred
,
sample_weight
=
None
):
return
super
().
update_state
(
1
,
sample_weight
)
feature_dims
=
3
keras_model
=
model_examples
.
build_linear_regression_keras_functional_model
(
feature_dims
)
with
self
.
assertRaisesRegex
(
TypeError
,
'extra arguments'
):
keras_utils
.
from_keras_model
(
keras_model
=
keras_model
,
input_spec
=
_create_whimsy_types
(
feature_dims
),
loss
=
tf
.
keras
.
losses
.
MeanSquaredError
(),
metrics
=
[
CustomCounter
(
arg1
=
1
)])
def
test_custom_keras_metric_no_extra_init_args_builds
(
self
):
class
CustomCounter
(
tf
.
keras
.
metrics
.
Sum
):
"""A custom `tf.keras.metrics.Metric` without extra args in `__init__`."""
def
__init__
(
self
,
name
=
'new_counter'
,
arg1
=
0
,
dtype
=
tf
.
int64
):
super
().
__init__
(
name
,
dtype
)
self
.
_arg1
=
arg1
def
update_state
(
self
,
y_true
,
y_pred
,
sample_weight
=
None
):
return
super
().
update_state
(
1
,
sample_weight
)
def
get_config
(
self
):
config
=
super
().
get_config
()
config
[
'arg1'
]
=
self
.
_arg1
return
config
feature_dims
=
3
keras_model
=
model_examples
.
build_linear_regression_keras_functional_model
(
feature_dims
)
tff_model
=
keras_utils
.
from_keras_model
(
keras_model
=
keras_model
,
input_spec
=
_create_whimsy_types
(
feature_dims
),
loss
=
tf
.
keras
.
losses
.
MeanSquaredError
(),
metrics
=
[
CustomCounter
(
arg1
=
1
)])
self
.
assertIsInstance
(
tff_model
,
model_lib
.
Model
)
if
__name__
==
'__main__'
:
execution_contexts
.
set_local_python_execution_context
()
...
...
tensorflow_federated/python/learning/metrics/finalizer.py
浏览文件 @
90feb1a0
...
...
@@ -13,6 +13,8 @@
# limitations under the License.
"""Helper functions for creating metric finalizers."""
import
inspect
from
typing
import
Any
,
Callable
,
List
,
Union
import
tensorflow
as
tf
...
...
@@ -56,6 +58,7 @@ def create_keras_metric_finalizer(
# use `keras_metric.result()`.
with
tf
.
init_scope
():
if
isinstance
(
metric
,
tf
.
keras
.
metrics
.
Metric
):
check_keras_metric_config_constructable
(
metric
)
keras_metric
=
type
(
metric
).
from_config
(
metric
.
get_config
())
elif
callable
(
metric
):
keras_metric
=
metric
()
...
...
@@ -89,3 +92,50 @@ def create_keras_metric_finalizer(
return
keras_metric
.
result
()
return
finalizer
def
check_keras_metric_config_constructable
(
metric
:
tf
.
keras
.
metrics
.
Metric
)
->
tf
.
keras
.
metrics
.
Metric
:
"""Checks that a Keras metric is constructable from the `get_config()` method.
Args:
metric: A single `tf.keras.metrics.Metric`.
Returns:
The metric.
Raises:
TypeError: If the metric is not an instance of `tf.keras.metrics.Metric`, if
the metric is not constructable from the `get_config()` method.
"""
if
not
isinstance
(
metric
,
tf
.
keras
.
metrics
.
Metric
):
raise
TypeError
(
f
'Metric
{
type
(
metric
)
}
is not a `tf.keras.metrics.Metric` '
'to be constructable from the `get_config()` method.'
)
metric_type_str
=
type
(
metric
).
__name__
if
hasattr
(
tf
.
keras
.
metrics
,
metric_type_str
):
return
metric
init_args
=
inspect
.
getfullargspec
(
metric
.
__init__
).
args
init_args
.
remove
(
'self'
)
get_config_args
=
metric
.
get_config
().
keys
()
extra_args
=
[
arg
for
arg
in
init_args
if
arg
not
in
get_config_args
]
if
extra_args
:
# TODO(b/197746608): Updates the error message to redirect users to use
# metric constructors instead of constructed metrics when we support both
# cases in `from_keras_model`.
raise
TypeError
(
f
'Metric
{
metric_type_str
}
is not constructable from '
'the `get_config()` method, because `__init__` takes extra '
'arguments that are not included in the `get_config()`: '
f
'
{
extra_args
}
. Override or update the `get_config()` in '
'the metric class to include these extra arguments.
\n
'
'Example:
\n
'
'class CustomMetric(tf.keras.metrics.Metric):
\n
'
' def __init__(self, arg1):
\n
'
' self._arg1 = arg1
\n\n
'
' def get_config(self)
\n
'
' config = super().get_config()
\n
'
' config[
\'
arg1
\'
] = self._arg1
\n
'
' return config'
)
return
metric
tensorflow_federated/python/learning/metrics/finalizer_test.py
浏览文件 @
90feb1a0
...
...
@@ -82,6 +82,17 @@ class CustomSumMetric(tf.keras.metrics.Sum):
return
config
class
CustomCounter
(
tf
.
keras
.
metrics
.
Sum
):
"""A custom `tf.keras.metrics.Metric` with extra arguments in `__init__`."""
def
__init__
(
self
,
name
=
'new_metric'
,
arg1
=
0
,
dtype
=
tf
.
int64
):
super
().
__init__
(
name
,
dtype
)
self
.
_arg1
=
arg1
def
update_state
(
self
,
y_true
,
y_pred
,
sample_weight
=
None
):
return
super
().
update_state
(
1
,
sample_weight
)
class
FinalizerTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
named_parameters
(
...
...
@@ -115,8 +126,10 @@ class FinalizerTest(parameterized.TestCase, tf.test.TestCase):
@
parameterized
.
named_parameters
(
(
'tensor'
,
tf
.
constant
(
1.0
),
'found a non-callable'
),
(
'loss_constructor'
,
tf
.
keras
.
losses
.
MeanSquaredError
,
'found a callable'
)
)
(
'loss_constructor'
,
tf
.
keras
.
losses
.
MeanSquaredError
,
'found a callable'
),
# go/pyformat-break
(
'custom_metric_with_extra_init_args'
,
CustomCounter
(
arg1
=
1
),
'extra arguments'
))
def
test_create_keras_metric_finalizer_fails_with_invalid_input
(
self
,
invalid_metric
,
error_message
):
unused_type
=
[
tf
.
TensorSpec
(
shape
=
[],
dtype
=
tf
.
float32
)]
...
...
tensorflow_federated/python/simulation/baselines/keras_metrics.py
浏览文件 @
90feb1a0
...
...
@@ -40,6 +40,11 @@ class NumTokensCounter(tf.keras.metrics.Sum):
sample_weight
=
tf
.
reshape
(
sample_weight
,
[
-
1
])
super
().
update_state
(
sample_weight
)
def
get_config
(
self
):
config
=
super
().
get_config
()
config
[
'masked_tokens'
]
=
tuple
(
self
.
_masked_tokens
)
return
config
class
MaskedCategoricalAccuracy
(
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
):
"""An accuracy metric that masks some tokens."""
...
...
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录