提交 8d2f989f 编辑于 作者: Zachary Garrett's avatar Zachary Garrett 提交者: tensorflow-copybara
浏览文件

Clean-up old unused code path, TFF is now released on TF 2.3.0.

PiperOrigin-RevId: 341461479
上级 9d2a339f
......@@ -41,7 +41,6 @@ py_library(
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/tensorflow_libs:version_check",
],
)
......@@ -51,10 +50,7 @@ py_test(
srcs = ["client_data_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":client_data",
"//tensorflow_federated/python/tensorflow_libs:version_check",
],
deps = [":client_data"],
)
py_library(
......
......@@ -24,7 +24,6 @@ import tensorflow as tf
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.api import computation_base
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.tensorflow_libs import version_check
class ClientData(object, metaclass=abc.ABCMeta):
......@@ -148,30 +147,9 @@ class ClientData(object, metaclass=abc.ABCMeta):
# Note: simply calling Dataset.concatenate() will result in too deep
# recursion depth.
# Note: Tests are via the simple concrete from_tensor_slices_client_data.
# TODO(b/154763092): remove this check and only use the newer path.
if version_check.is_tensorflow_version_newer('2.3.0', tf):
logging.info('Using newer tf.data.Dataset construction behavior.')
# This works in tf-nightly, but isn't in a released tensorflow
# version yet.
client_datasets = [d for d in self.datasets(seed=seed)]
nested_dataset = tf.data.Dataset.from_tensor_slices(client_datasets)
example_dataset = nested_dataset.flat_map(lambda x: x)
else:
logging.info('Old TensorFlow version detected; defaulting to slower '
'tf.data.Dataset construction.')
def _generator():
for dataset in self.datasets(seed=seed):
for example in dataset:
yield example
types = tf.nest.map_structure(lambda t: t.dtype,
self.element_type_structure)
shapes = tf.nest.map_structure(lambda t: t.shape,
self.element_type_structure)
example_dataset = tf.data.Dataset.from_generator(_generator, types,
shapes)
client_datasets = list(self.datasets(seed=seed))
nested_dataset = tf.data.Dataset.from_tensor_slices(client_datasets)
example_dataset = nested_dataset.flat_map(lambda x: x)
return example_dataset
def preprocess(
......
......@@ -16,7 +16,6 @@ from absl.testing import absltest
import tensorflow as tf
from tensorflow_federated.python.simulation import client_data as cd
from tensorflow_federated.python.tensorflow_libs import version_check
class ConcreteClientDataTest(tf.test.TestCase, absltest.TestCase):
......@@ -36,12 +35,7 @@ class ConcreteClientDataTest(tf.test.TestCase, absltest.TestCase):
tf.TensorSpec(shape=(), dtype=tf.int64))
def length(ds):
if version_check.is_tensorflow_version_newer('2.3.0', tf):
# ds.cardinality() only works for RangeDataset at HEAD,
# and is not in a released version of TensorFlow yet.
return ds.cardinality().numpy()
else:
return tf.data.experimental.cardinality(ds).numpy()
return tf.data.experimental.cardinality(ds).numpy()
for i in client_ids:
self.assertEqual(
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册