提交 60192a19 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Adds failing test for 'aggregation-by-dataset-concatenation' pattern.

PiperOrigin-RevId: 414018974
上级 b19d3800
......@@ -74,6 +74,43 @@ class NoClientAggregationsTest(parameterized.TestCase):
fed_mean([])
class DatasetConcatAggregationTest(parameterized.TestCase):
@test_contexts.with_contexts
def test_executes_dataset_concat_aggregation(self):
self.skipTest('b/209050033')
tensor_spec = tf.TensorSpec(shape=[2], dtype=tf.float32)
@tff.tf_computation
def create_empty_ds():
empty_tensor = tf.zeros(
shape=[0] + tensor_spec.shape, dtype=tensor_spec.dtype)
return tf.data.Dataset.from_tensor_slices(empty_tensor)
@tff.tf_computation
def concat_datasets(ds1, ds2):
return ds1.concatenate(ds2)
@tff.tf_computation
def identity(ds):
return ds
@tff.federated_computation(
tff.type_at_clients(tff.SequenceType(tensor_spec)))
def do_a_federated_aggregate(client_ds):
return tff.federated_aggregate(
value=client_ds,
zero=create_empty_ds(),
accumulate=concat_datasets,
merge=concat_datasets,
report=identity)
input_data = tf.data.Dataset.from_tensor_slices([[0.1, 0.2]])
ds = do_a_federated_aggregate([input_data])
self.assertIsInstance(ds, tf.data.Dataset)
class TemperatureSensorExampleTest(parameterized.TestCase):
@test_contexts.with_contexts
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册