提交 c638c139 编辑于 作者: Jakub Konecny's avatar Jakub Konecny 提交者: tensorflow-copybara
浏览文件

Improves tests for `federated_quantized_sum`.

This removes the context manager which temporarily replaced the reference `federated_secure_sum` with `federated_sum` while executors did not support this in tests. This has been addresses, thus no need for the workaround.

PiperOrigin-RevId: 341866125
上级 3aae728d
......@@ -164,13 +164,12 @@ py_test(
srcs_version = "PY3",
deps = [
":federated_aggregations",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:intrinsics",
"//tensorflow_federated/python/core/api:placements",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/backends/test:execution_contexts",
"//tensorflow_federated/python/core/test:static_assert",
],
)
......
......@@ -13,19 +13,17 @@
# limitations under the License.
import collections
import contextlib
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.api import intrinsics
from tensorflow_federated.python.core.api import placements
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.backends.test import execution_contexts
from tensorflow_federated.python.core.test import static_assert
from tensorflow_federated.python.core.utils import federated_aggregations
......@@ -1173,9 +1171,8 @@ def _build_test_sum_fn_py_bounds(value_type, lower_bound, upper_bound):
@computations.federated_computation(
computation_types.FederatedType(value_type, placements.CLIENTS))
def call_secure_sum(value):
with _hijack_federated_secure_sum():
summed_value = federated_aggregations.secure_quantized_sum(
value, lower_bound, upper_bound)
summed_value = federated_aggregations.secure_quantized_sum(
value, lower_bound, upper_bound)
return summed_value
return call_secure_sum
......@@ -1204,53 +1201,18 @@ def _build_test_sum_fn_tff_bounds(value_type, lower_bound_type,
computation_types.FederatedType(lower_bound_type, placements.SERVER),
computation_types.FederatedType(upper_bound_type, placements.SERVER))
def call_secure_sum(value, lower_bound, upper_bound):
with _hijack_federated_secure_sum():
summed_value = federated_aggregations.secure_quantized_sum(
value, lower_bound, upper_bound)
summed_value = federated_aggregations.secure_quantized_sum(
value, lower_bound, upper_bound)
return summed_value
return call_secure_sum
# TODO(b/162706090): Remove when simulated executors support
# `federated_secure_sum`.
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
def _hijack_federated_secure_sum():
"""Context manager which replaces `federated_secure_sum` with `federated_sum`.
The effect is that within the context, the use of
`intrinsics.federated_secure_sum` inside the `federated_aggregations` module
will be relaced by `intrinsics.federated_sum`.
"""
def fake_secure_sum(value, bitwidth):
# TODO(b/165856119): update parameter validation to reflect
# `federated_secure_sum` it it becomes possible to broadcast `bitwidth`.
value_type = value.type_signature.member
if value_type.is_struct():
bitwidth_struct = structure.from_container(bitwidth)
if not structure.is_same_structure(value_type, bitwidth_struct):
raise TypeError('value and bitwidth must have the same structure.\n'
'value: {v}\nbitwidth:{b}'.format(
v=value.type_signature.member,
b=bitwidth.type_signature))
return federated_aggregations.intrinsics.federated_sum(value)
real_secure_sum = getattr(federated_aggregations.intrinsics,
'federated_secure_sum')
setattr(federated_aggregations.intrinsics, 'federated_secure_sum',
fake_secure_sum)
yield
setattr(federated_aggregations.intrinsics, 'federated_secure_sum',
real_secure_sum)
def _np_val_fn(value, tf_dtype):
"""Converts `value` to numpy array of dtype corresponding to `tf_dtype`."""
return np.array(value, tf_dtype.as_numpy_dtype)
if __name__ == '__main__':
execution_contexts.set_local_execution_context()
execution_contexts.set_test_execution_context()
test_case.main()
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册