Commit 61e5a882 authored by Keith Rush's avatar Keith Rush
Browse files

Cleaning up HDF5 client data impl.

parent 181cd6d4
......@@ -29,12 +29,13 @@ class HDF5ClientData(client_data.ClientData):
"""A `tff.simulation.ClientData` backed by an HDF5 file.
This class expects that the HDF5 file has a top-level group `examples` which
contains further subgroups, one per user, named by the user ID.
contains further subgroups, one per user, named by the user ID. Further, the
users must have identical keys.
The `tf.data.Dataset` returned by
`HDF5ClientData.create_tf_dataset_for_client(client_id)` yields tuples from
zipping all datasets that were found at `/data/client_id` group, in a similar
fashion to `tf.data.Dataset.from_tensor_slices()`.
`HDF5ClientData.create_tf_dataset_for_client(client_id)` yields ordered dicts
from zipping all datasets that were found at `/data/client_id` group, in a
similar fashion to `tf.data.Dataset.from_tensor_slices()`.
"""
_EXAMPLES_GROUP = "examples"
......@@ -55,8 +56,10 @@ class HDF5ClientData(client_data.ClientData):
# Get the types and shapes from the first client. We do it once during
# initialization so we can get both properties in one go.
tf_dataset = self._create_python_dataset(self._client_ids[0])
self._element_type_structure = tf_dataset.element_spec
example_tf_dataset = tf.data.Dataset.from_tensor_slices(
collections.OrderedDict((name, ds[()]) for name, ds in sorted(
self._h5_file[HDF5ClientData._EXAMPLES_GROUP][self._client_ids[0]].items())))
self._element_type_structure = example_tf_dataset.element_spec
@computations.tf_computation(tf.string)
def dataset_computation(client_id):
......@@ -70,11 +73,6 @@ class HDF5ClientData(client_data.ClientData):
self._dataset_computation = dataset_computation
def _create_python_dataset(self, client_id):
return tf.data.Dataset.from_tensor_slices(
collections.OrderedDict((name, ds[()]) for name, ds in sorted(
self._h5_file[HDF5ClientData._EXAMPLES_GROUP][client_id].items())))
@property
def client_ids(self):
return self._client_ids
......
......@@ -20,6 +20,7 @@ import h5py
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.simulation import hdf5_client_data
TEST_DATA = {
......@@ -127,4 +128,5 @@ class HDF5ClientDataTest(tf.test.TestCase, absltest.TestCase):
if __name__ == '__main__':
execution_contexts.set_local_execution_context()
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment