Skip to content
Snippets Groups Projects
Commit 12ab02e8 authored by Scott Wegner's avatar Scott Wegner Committed by tensorflow-copybara
Browse files

Fix error-checking logic for `get_canonical_form_for_iterative_process()` for...

Fix error-checking logic for `get_canonical_form_for_iterative_process()` for `IterativeProcess`'s with no aggregation.

The logic was previously failing for this case, but with a harder to grok error message.

The previous logic had a slight bug in the branching logic, such that we would never hit the explicit `ValueError` in the final `else:` block.

The tests were also incorrectly passing because they hit a different validation check:

```
ValueError: Expected an AST containing an intrinsic with the uri: federated_secure_sum, found none.
```
PiperOrigin-RevId: 322385240
parent f3e1cc88
No related branches found
No related tags found
No related merge requests found
......@@ -940,6 +940,11 @@ def get_canonical_form_for_iterative_process(ip):
next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
if not (contains_federated_aggregate or contains_federated_secure_sum):
raise ValueError(
'Expected an `tff.templates.IterativeProcess` containing at least one '
'`federated_aggregate` or `federated_secure_sum`, found none.')
if contains_federated_aggregate and contains_federated_secure_sum:
before_aggregate, after_aggregate = (
transformations.force_align_and_split_by_intrinsics(
......@@ -947,18 +952,16 @@ def get_canonical_form_for_iterative_process(ip):
intrinsic_defs.FEDERATED_AGGREGATE.uri,
intrinsic_defs.FEDERATED_SECURE_SUM.uri,
]))
elif not contains_federated_aggregate:
elif contains_federated_secure_sum:
assert not contains_federated_aggregate
before_aggregate, after_aggregate = (
_create_before_and_after_aggregate_for_no_federated_aggregate(
after_broadcast))
elif not contains_federated_secure_sum:
else:
assert contains_federated_aggregate and not contains_federated_secure_sum
before_aggregate, after_aggregate = (
_create_before_and_after_aggregate_for_no_federated_secure_sum(
after_broadcast))
else:
raise ValueError(
'Expected an `tff.templates.IterativeProcess` containing at least one '
'`federated_aggregate` or `federated_secure_sum`, found none.')
type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
before_aggregate, after_aggregate)
......
......@@ -793,7 +793,10 @@ class GetCanonicalFormForIterativeProcessTest(CanonicalFormTestCase,
def test_raises_value_error_for_sum_example_with_no_aggregation(self):
ip = get_iterative_process_for_sum_example_with_no_aggregation()
with self.assertRaises(ValueError):
with self.assertRaisesRegex(
ValueError,
r'Expected .* containing at least one `federated_aggregate` or '
r'`federated_secure_sum`'):
canonical_form_utils.get_canonical_form_for_iterative_process(ip)
def test_returns_canonical_form_with_indirection_to_intrinsic(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment