Skip to content
Snippets Groups Projects
Commit 8893893a authored by Scott Wegner's avatar Scott Wegner Committed by tensorflow-copybara
Browse files

Create CanonicalForm.summary() to print a statically-known type information.

This follows a similar convetion to tf.kersas.Model.summary(): https://www.tensorflow.org/api_docs/python/tf/keras/Model#summary

PiperOrigin-RevId: 273609432
parent 623dad4f
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@
<meta itemprop="property" content="work"/>
<meta itemprop="property" content="zero"/>
<meta itemprop="property" content="__init__"/>
<meta itemprop="property" content="summary"/>
</div>
# tff.backends.mapreduce.CanonicalForm
......@@ -320,3 +321,21 @@ decorator/wrapper.
<h3 id="work"><code>work</code></h3>
<h3 id="zero"><code>zero</code></h3>
## Methods
<h3 id="summary"><code>summary</code></h3>
<a target="_blank" href="http://github.com/tensorflow/federated/tree/master/tensorflow_federated/python/core/backends/mapreduce/canonical_form.py">View
source</a>
```python
summary(print_fn=print)
```
Prints a string summary of the `CanonicalForm`.
#### Arguments:
* <b>`print_fn`</b>: Print function to use. It will be called on each line of
the summary in order to capture the string summary.
......@@ -22,6 +22,14 @@ py_library(
],
)
py_test(
name = "canonical_form_test",
srcs = ["canonical_form_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":test_utils"],
)
py_library(
name = "canonical_form_utils",
srcs = ["canonical_form_utils.py"],
......
......@@ -415,3 +415,25 @@ class CanonicalForm(object):
@property
def update(self):
return self._update
def summary(self, print_fn=print):
"""Prints a string summary of the `CanonicalForm`.
Arguments:
print_fn: Print function to use. It will be called on each line of the
summary in order to capture the string summary.
"""
computations = [
('initialize', self.initialize),
('prepare', self.prepare),
('work', self.work),
('zero', self.zero),
('accumulate', self.accumulate),
('merge', self.merge),
('report', self.report),
('update', self.initialize),
]
for name, comp in computations:
# Add sufficient padding to align first column; len('initialize') == 10
print_fn('{:<10}: {}'.format(
name, comp.type_signature.compact_representation()))
# Lint as: python2, python3
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap
from absl.testing import absltest
import tensorflow as tf
from tensorflow_federated.python.core.backends.mapreduce import test_utils
class CanonicalFormTest(absltest.TestCase):
def test_summary(self):
cf = test_utils.get_temperature_sensor_example()
class CapturePrint(object):
def __init__(self):
self.summary = ''
def __call__(self, msg):
self.summary += msg + '\n'
capture = CapturePrint()
cf.summary(print_fn=capture)
self.assertEqual(
capture.summary,
textwrap.dedent("""\
initialize: ( -> <num_rounds=int32>)
prepare : (<num_rounds=int32> -> <max_temperature=float32>)
work : (<float32*,<max_temperature=float32>> -> <<is_over=bool>,<num_readings=int32>>)
zero : ( -> <num_total=int32,num_over=int32>)
accumulate: (<<num_total=int32,num_over=int32>,<is_over=bool>> -> <num_total=int32,num_over=int32>)
merge : (<<num_total=int32,num_over=int32>,<num_total=int32,num_over=int32>> -> <num_total=int32,num_over=int32>)
report : (<num_total=int32,num_over=int32> -> <ratio_over_threshold=float32>)
update : ( -> <num_rounds=int32>)
"""))
if __name__ == '__main__':
tf.compat.v1.enable_v2_behavior()
absltest.main()
package(default_visibility = ["//visibility:private"])
licenses(["notice"])
filegroup(
name = "build_docs",
srcs = ["build_docs.py"],
tags = ["ignore_srcs"],
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment