提交 db28145d 编辑于 作者: Zachary Charles's avatar Zachary Charles 提交者: tensorflow-copybara
浏览文件

Add a call to `tf.convert_to_tensor` within...

Add a call to `tf.convert_to_tensor` within `tff.simulation.datasets.TestClientData` in order to allow passing strings to `.serializable_dataset_fn`.

This change also adds tests that will fail without this conversion to tensor.

PiperOrigin-RevId: 390674662
上级 c1a5ee8d
......@@ -192,7 +192,8 @@ class TestClientData(client_data.ClientData):
# Recover data relating to the given client_id from the hash table.
tensor_slices_list = [
tf.io.parse_tensor(table.lookup(client_id), out_type=dtype)
tf.io.parse_tensor(
table.lookup(tf.convert_to_tensor(client_id)), out_type=dtype)
for table, dtype in zip(hash_tables, self._dtypes)
]
......
......@@ -112,7 +112,7 @@ class TestClientDataTest(tf.test.TestCase, parameterized.TestCase):
from_tensor_slices_client_data.TestClientData(tensor_slices_dict)
self.assertSameStructure(tensor_slices_dict, copy_of_tensor_slices_dict)
def test_basic(self):
def test_client_data_constructs_with_correct_clients_and_types(self):
tensor_slices_dict = {'a': [1, 2, 3], 'b': [4, 5]}
client_data = from_tensor_slices_client_data.TestClientData(
tensor_slices_dict)
......@@ -120,6 +120,11 @@ class TestClientDataTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(client_data.element_type_structure,
tf.TensorSpec(shape=(), dtype=tf.int32))
def test_create_tf_dataset_for_client_constructs(self):
tensor_slices_dict = {'a': [1, 2, 3], 'b': [4, 5]}
client_data = from_tensor_slices_client_data.TestClientData(
tensor_slices_dict)
def as_list(dataset):
return [self.evaluate(x) for x in dataset]
......@@ -128,6 +133,17 @@ class TestClientDataTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(
as_list(client_data.create_tf_dataset_for_client('b')), [4, 5])
def test_serializable_dataset_fn_constructs(self):
tensor_slices_dict = {'a': [1, 2, 3], 'b': [4, 5]}
client_data = from_tensor_slices_client_data.TestClientData(
tensor_slices_dict)
def as_list(dataset):
return [self.evaluate(x) for x in dataset]
self.assertEqual(
as_list(client_data.serializable_dataset_fn('a')), [1, 2, 3])
def test_where_client_data_is_tensors(self):
client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
self.assertCountEqual(TEST_DATA.keys(), client_data.client_ids)
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册