Commit dcbf78d0 authored by Michael Reneer's avatar Michael Reneer Committed by tensorflow-copybara
Browse files

Cleanup usage of various `flatten` functions.

In the tff.program library, there is usage of various `flatten` functions:

* `tf.nest.flatten`
* `tree.flatten`
* `structure_utils.flatten`, which is a wrapper around `tree.flatten_with_path`

This change does two things:

1. Consolidates usage of various `flatten` functions to the `tree` package, `tree` is used over `tf.nest` because it has more functionality.
2. Renames `structure_utils.flatten` to `structure_utils.flatten_with_name` to disambiguate this functionality.

PiperOrigin-RevId: 410910920
parent 67425029
......@@ -26,6 +26,7 @@ from typing import Any, List, Optional, Union
from absl import logging
import tensorflow as tf
import tree
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.program import file_utils
......@@ -187,7 +188,7 @@ class FileProgramStateManager(program_state_manager.ProgramStateManager):
module = tf.saved_model.load(path)
flattened_value = module()
try:
program_state = tf.nest.pack_sequence_as(self._structure, flattened_value)
program_state = tree.unflatten_as(self._structure, flattened_value)
except ValueError as e:
raise FileProgramStateManagerStructureError(
f'The structure of type {type(self._structure)}:\n'
......@@ -237,7 +238,7 @@ class FileProgramStateManager(program_state_manager.ProgramStateManager):
raise program_state_manager.ProgramStateManagerStateAlreadyExistsError(
f'Program state already exists for version: {version}')
materialized_value = value_reference.materialize_value(program_state)
flattened_value = tf.nest.flatten(materialized_value)
flattened_value = tree.flatten(materialized_value)
module = file_utils.ValueModule(flattened_value)
file_utils.write_saved_model(module, path)
self._remove_old_program_state()
......@@ -30,6 +30,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Tuple, Sequence, Union
import numpy as np
import tensorflow as tf
import tree
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.program import file_utils
......@@ -221,7 +222,7 @@ class CSVFileReleaseManager(release_manager.ReleaseManager):
else:
self._remove_values_after(key - 1)
materialized_value = value_reference.materialize_value(value)
flattened_value = structure_utils.flatten(materialized_value)
flattened_value = structure_utils.flatten_with_name(materialized_value)
normalized_value = collections.OrderedDict()
for x, y in flattened_value.items():
......@@ -304,6 +305,6 @@ class SavedModelFileReleaseManager(release_manager.ReleaseManager):
py_typecheck.check_type(key, int)
path = self._get_path_for_key(key)
materialized_value = value_reference.materialize_value(value)
flattened_value = tf.nest.flatten(materialized_value)
flattened_value = tree.flatten(materialized_value)
module = file_utils.ValueModule(flattened_value)
file_utils.write_saved_model(module, path, overwrite=True)
......@@ -19,15 +19,15 @@ from typing import Any, OrderedDict
import tree
def flatten(structure: Any) -> OrderedDict[str, Any]:
"""Creates a flattened representation of the given `structure`.
def flatten_with_name(structure: Any) -> OrderedDict[str, Any]:
"""Creates a flattened representation of the given `structure` with names.
Args:
structure: A possibly nested structure.
Returns:
A `collections.OrderedDict` representing the flattened version of the given
`structure`, where the keys are string uniquely identifying the position of
`structure`, where the keys are names uniquely identifying the position of
the values in the structure of the given `structure`.
"""
flattened = tree.flatten_with_path(structure)
......
......@@ -21,7 +21,7 @@ from tensorflow_federated.python.program import structure_utils
from tensorflow_federated.python.program import test_utils
class FlattenTest(parameterized.TestCase, tf.test.TestCase):
class FlattenWithNameTest(parameterized.TestCase, tf.test.TestCase):
# pyformat: disable
@parameterized.named_parameters(
......@@ -64,7 +64,7 @@ class FlattenTest(parameterized.TestCase, tf.test.TestCase):
)
# pyformat: enable
def test_returns_result(self, structure, expected_result):
actual_result = structure_utils.flatten(structure)
actual_result = structure_utils.flatten_with_name(structure)
# The results are zipped and each item is compared because `assertAllEqual`
# fails to compare nonscalar numpy arrays correctly when they are nested in
......
......@@ -77,7 +77,7 @@ class TensorboardReleaseManager(release_manager.ReleaseManager):
"""
py_typecheck.check_type(key, int)
materialized_value = value_reference.materialize_value(value)
flattened_value = structure_utils.flatten(materialized_value)
flattened_value = structure_utils.flatten_with_name(materialized_value)
with self._summary_writer.as_default():
for name, value in flattened_value.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment