Commit 90675172 authored by Michael Reneer, committed by tensorflow-copybara

Cleanup imports and visibility in the `research/utils` package.

PiperOrigin-RevId: 268795604
parent 5a93388b
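For orientation before the diff: the test updates below replace the `utils_impl as utils` import alias with a direct `utils_impl` import, so call sites name the module they actually test. A minimal before/after sketch of that pattern, using only names that appear in the diff below:

# Before (old style): the implementation module sits behind an alias.
from tensorflow_federated.python.research.utils import utils_impl as utils
utils.define_optimizer_flags('client', defaults=dict(learning_rate=1.25))

# After (this commit's style): the module is referenced under its own name.
from tensorflow_federated.python.research.utils import utils_impl
utils_impl.define_optimizer_flags(
    'server', defaults=dict(learning_rate=1.25))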
package(default_visibility = ["//visibility:private"])
package(default_visibility = ["//tensorflow_federated/python/research"])
licenses(["notice"]) # Apache 2.0
......@@ -6,7 +6,6 @@ py_library(
name = "models",
srcs = ["models.py"],
srcs_version = "PY3",
visibility = ["//tensorflow_federated/python/research"],
)
py_binary(
......
package(default_visibility = ["//tensorflow_federated/python/research:__subpackages__"])
package(default_visibility = ["//tensorflow_federated/python/research"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "utils",
srcs = ["__init__.py"],
deps = [":utils_impl"],
)
py_library(
name = "checkpoint_utils",
srcs = ["checkpoint_utils.py"],
......
@@ -12,30 +12,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Utility functions used in the research/ directory.
-
-General utilities used by the other directories under `research/`, for things
-like writing output, constructing grids of experiments, configuration via
-command-line flags, etc.
-
-These utilities are not part of the TFF pip package.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow_federated.python.research.utils.utils_impl import atomic_write_to_csv
-from tensorflow_federated.python.research.utils.utils_impl import define_optimizer_flags
-from tensorflow_federated.python.research.utils.utils_impl import get_optimizer_from_flags
-from tensorflow_federated.python.research.utils.utils_impl import iter_grid
-from tensorflow_federated.python.research.utils.utils_impl import record_new_flags
-
-# Used by doc generation script.
-_allowed_symbols = [
-    "iter_grid",
-    "atomic_write_to_csv",
-    "define_optimizer_flags",
-    "get_optimizer_from_flags",
-    "record_new_flags",
-]
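With the re-exports above deleted, research code imports these helpers from `utils_impl` directly rather than from the package `__init__`. A minimal sketch under that assumption; the grid keys are illustrative and not taken from this commit:

from tensorflow_federated.python.research.utils import utils_impl

# Hypothetical hyperparameter grid; iter_grid yields one dict per combination.
grid = dict(learning_rate=[0.01, 0.1], batch_size=[16, 32])
for hparams in utils_impl.iter_grid(grid):
  # Each `hparams` maps every key to one chosen value,
  # e.g. dict(learning_rate=0.01, batch_size=16).
  print(hparams)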
@@ -23,7 +23,7 @@ import numpy as np
 import pandas as pd
 import tensorflow as tf

-from tensorflow_federated.python.research.utils import utils_impl as utils
+from tensorflow_federated.python.research.utils import utils_impl

 FLAGS = flags.FLAGS
@@ -32,15 +32,16 @@ FLAGS = flags.FLAGS
 class UtilsTest(tf.test.TestCase):

   def test_get_optimizer_from_flags(self):
-    utils.define_optimizer_flags('server', defaults=dict(learning_rate=1.25))
+    utils_impl.define_optimizer_flags(
+        'server', defaults=dict(learning_rate=1.25))
     self.assertEqual(FLAGS.server_learning_rate, 1.25)
-    optimizer = utils.get_optimizer_from_flags('server')
+    optimizer = utils_impl.get_optimizer_from_flags('server')
     self.assertEqual(optimizer.get_config()['learning_rate'], 1.25)

   def test_define_optimizer_unused_default(self):
     with self.assertRaisesRegex(ValueError, 'not consumed'):
       # Use a different prefix to avoid declaring duplicate flags:
-      utils.define_optimizer_flags('client', defaults=dict(lr=1.25))
+      utils_impl.define_optimizer_flags('client', defaults=dict(lr=1.25))

   def test_atomic_write(self):
     # Ensure randomness for temp filenames.
@@ -49,29 +50,30 @@ class UtilsTest(tf.test.TestCase):
     for name in ['foo.csv', 'baz.csv.bz2']:
       dataframe = pd.DataFrame(dict(a=[1, 2], b=[4.0, 5.0]))
       output_file = os.path.join(absltest.get_default_test_tmpdir(), name)
-      utils.atomic_write_to_csv(dataframe, output_file)
+      utils_impl.atomic_write_to_csv(dataframe, output_file)
       dataframe2 = pd.read_csv(output_file, index_col=0)
       pd.testing.assert_frame_equal(dataframe, dataframe2)

       # Overwriting
       dataframe3 = pd.DataFrame(dict(a=[1, 2, 3], b=[4.0, 5.0, 6.0]))
-      utils.atomic_write_to_csv(dataframe3, output_file)
+      utils_impl.atomic_write_to_csv(dataframe3, output_file)
       dataframe4 = pd.read_csv(output_file, index_col=0)
       pd.testing.assert_frame_equal(dataframe3, dataframe4)

   def test_iter_grid(self):
     grid = dict(a=[], b=[])
-    self.assertCountEqual(list(utils.iter_grid(grid)), [])
+    self.assertCountEqual(list(utils_impl.iter_grid(grid)), [])

     grid = dict(a=[1])
-    self.assertCountEqual(list(utils.iter_grid(grid)), [dict(a=1)])
+    self.assertCountEqual(list(utils_impl.iter_grid(grid)), [dict(a=1)])

     grid = dict(a=[1, 2])
-    self.assertCountEqual(list(utils.iter_grid(grid)), [dict(a=1), dict(a=2)])
+    self.assertCountEqual(
+        list(utils_impl.iter_grid(grid)), [dict(a=1), dict(a=2)])

     grid = dict(a=[1, 2], b='b', c=[3.0, 4.0])
     self.assertCountEqual(
-        list(utils.iter_grid(grid)), [
+        list(utils_impl.iter_grid(grid)), [
             dict(a=1, b='b', c=3.0),
             dict(a=1, b='b', c=4.0),
             dict(a=2, b='b', c=3.0),
@@ -79,13 +81,13 @@ class UtilsTest(tf.test.TestCase):
         ])

   def test_record_new_flags(self):
-    with utils.record_new_flags() as hparam_flags:
+    with utils_impl.record_new_flags() as hparam_flags:
       flags.DEFINE_string('exp_name', 'name', 'Unique name for the experiment.')
       flags.DEFINE_integer('random_seed', 0, 'Random seed for the experiment.')

     self.assertCountEqual(hparam_flags, ['exp_name', 'random_seed'])

-  @mock.patch.object(utils, 'multiprocessing')
+  @mock.patch.object(utils_impl, 'multiprocessing')
   def test_launch_experiment(self, mock_multiprocessing):
     pool = mock_multiprocessing.Pool(processes=10)
@@ -94,7 +96,7 @@ class UtilsTest(tf.test.TestCase):
         collections.OrderedDict([('a_long', 1), ('b', 5.0)])
     ]

-    utils.launch_experiment(
+    utils_impl.launch_experiment(
         'run_exp.py', grid_dict, '/tmp_dir', short_names={'a_long': 'a'})
     expected = [
         'python run_exp.py --a_long=1 --b=4.0 --root_output_dir=/tmp_dir '
...
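The updated tests above double as a usage reference for `utils_impl`. A short end-to-end sketch built only from the calls those tests exercise; the output path and the `run` column are illustrative additions, not part of this commit:

import os

import pandas as pd

from tensorflow_federated.python.research.utils import utils_impl

# Enumerate a small hyperparameter grid, as in test_iter_grid.
grid = dict(a=[1, 2], b=[4.0, 5.0])
rows = [
    dict(run=i, **params)
    for i, params in enumerate(utils_impl.iter_grid(grid))
]

# Write the results table, as in test_atomic_write; the tests also pass a
# 'baz.csv.bz2' path to the same call.
results = pd.DataFrame(rows)
output_file = os.path.join('/tmp', 'results.csv')
utils_impl.atomic_write_to_csv(results, output_file)

# Round-trip read, mirroring the test's assertion.
restored = pd.read_csv(output_file, index_col=0)
pd.testing.assert_frame_equal(results, restored)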