Commit 078120c0 authored by Nicole Mitchell, committed by tensorflow-copybara

Memory reduction measures for entropy compression.

Reduces the number of ops in the graph by:
1) casting the stochastic discretization `step_size` to `value_type`, rather than casting `value` to float32.
2) applying run-length encoding layer-wise to each tensor, rather than concatenating structures into a single tensor for encoding (sketched below).
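
As an illustration of measure (2), here is a minimal sketch of layer-wise run-length gamma coding over a structure of int32 tensors, using the same `tensorflow_compression` ops as the diff below; the helper names `encode_struct`/`decode_struct` are hypothetical, not part of the factory API:

```python
import collections
import tensorflow as tf
import tensorflow_compression as tfc

def encode_struct(value):
  # Encode each tensor in the structure separately (layer-wise), instead of
  # concatenating the whole structure into one flat tensor before encoding.
  return tf.nest.map_structure(
      lambda x: tfc.run_length_gamma_encode(data=x), value)

def decode_struct(encoded, shapes):
  # Decode each bitstring back to its per-layer shape.
  return tf.nest.map_structure(
      lambda code, shape: tfc.run_length_gamma_decode(code=code, shape=shape),
      encoded, shapes)

value = collections.OrderedDict(
    layer1=tf.constant([1, 0, 0, -2], tf.int32),
    layer2=tf.constant([[3, 0], [0, 1]], tf.int32))
shapes = tf.nest.map_structure(lambda x: x.shape, value)
roundtrip = decode_struct(encode_struct(value), shapes)  # equals `value`
```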

PiperOrigin-RevId: 491695106
parent 0264e58d
......@@ -275,8 +275,7 @@ py_library(
srcs = ["elias_gamma_encode.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/aggregators:concat",
"//tensorflow_federated/python/aggregators:factory",
":factory",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/federated_context:federated_computation",
......@@ -284,6 +283,7 @@ py_library(
"//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/core/templates:aggregation_process",
"//tensorflow_federated/python/core/templates:measured_process",
],
......@@ -297,7 +297,7 @@ py_test(
srcs_version = "PY3",
deps = [
":elias_gamma_encode",
"//tensorflow_federated/python/aggregators:mean",
":mean",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:type_test_utils",
......
......@@ -19,7 +19,6 @@ from typing import Optional
import tensorflow as tf
import tensorflow_compression as tfc
from tensorflow_federated.python.aggregators import concat
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
......@@ -28,16 +27,15 @@ from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process
@tensorflow_computation.tf_computation
def _get_bitrate(value, num_elements):
"""Return size (in bits) of an encoded value."""
bitstring_length = 8. * tf.cast(
tf.strings.length(value, unit="BYTE"), dtype=tf.float64)
return tf.math.divide_no_nan(bitstring_length, num_elements)
def _get_bits(value):
"""Return size (in bits) of an encoded tensor."""
return 8. * tf.cast(tf.strings.length(value, unit="BYTE"), dtype=tf.float64)
def _is_int32_or_structure_of_int32s(type_spec: computation_types.Type) -> bool:
......@@ -116,28 +114,36 @@ class EliasGammaEncodedSumFactory(factory.UnweightedAggregationFactory):
raise ValueError("Expect value_type to be an int32 tensor or a structure "
"containing only other structures of int32 tensors, "
f"found {value_type}.")
concat_fn, unconcat_fn = concat.create_concat_fns(value_type)
concat_value_type = concat_fn.type_signature.result
if self._bitrate_mean_factory is not None:
bitrate_mean_process = self._bitrate_mean_factory.create(
computation_types.to_type(tf.float64))
@tensorflow_computation.tf_computation(value_type)
def encode(value):
return tf.nest.map_structure(
lambda x: tfc.run_length_gamma_encode(data=x), value)
def sum_encoded_value(value):
@tensorflow_computation.tf_computation
def get_accumulator():
return tf.zeros(shape=concat_value_type.shape, dtype=tf.int32)
return type_conversions.structure_from_tensor_type_tree(
lambda x: tf.zeros(shape=x.shape, dtype=tf.int32), value_type)
@tensorflow_computation.tf_computation
def decode_accumulate_values(accumulator, encoded_value):
decoded_value = tfc.run_length_gamma_decode(
code=encoded_value, shape=concat_value_type.shape)
return accumulator + decoded_value
shapes = type_conversions.structure_from_tensor_type_tree(
lambda x: x.shape, value_type)
return tf.nest.map_structure(
lambda a, x, y: a + tfc.run_length_gamma_decode(code=x, shape=y),
accumulator, encoded_value, shapes)
@tensorflow_computation.tf_computation
def merge_decoded_values(decoded_value_1, decoded_value_2):
return decoded_value_1 + decoded_value_2
return tf.nest.map_structure(
tensorflow_computation.tf_computation(lambda x, y: x + y),
decoded_value_1, decoded_value_2)
@tensorflow_computation.tf_computation
def report_decoded_summation(summed_decoded_values):
......@@ -158,27 +164,35 @@ class EliasGammaEncodedSumFactory(factory.UnweightedAggregationFactory):
init_fn.type_signature.result, computation_types.at_clients(value_type))
def next_fn(state, value):
measurements = ()
concat_value = intrinsics.federated_map(concat_fn, value)
encoded_value = intrinsics.federated_map(
tensorflow_computation.tf_computation(
lambda x: tfc.run_length_gamma_encode(data=x)), concat_value)
encoded_value = intrinsics.federated_map(encode, value)
if self._bitrate_mean_factory is not None:
@tensorflow_computation.tf_computation
def get_num_elements():
return tf.constant(concat_value_type.shape.num_elements(), tf.float64)
num_elements = type_conversions.structure_from_tensor_type_tree(
lambda x: tf.constant(x.shape.num_elements(), tf.float64),
value_type)
return tf.math.add_n(tf.nest.flatten(num_elements))
num_elements = intrinsics.federated_eval(get_num_elements,
placements.CLIENTS)
bitrates = intrinsics.federated_map(_get_bitrate,
(encoded_value, num_elements))
@tensorflow_computation.tf_computation
def struct_get_bits(x):
return tf.math.add_n([_get_bits(t) for t in tf.nest.flatten(x)])
total_bits = intrinsics.federated_map(struct_get_bits, encoded_value)
bitrates = intrinsics.federated_map(
tensorflow_computation.tf_computation(
lambda x, y: tf.math.divide_no_nan(x=x, y=y, name="divide")),
(total_bits, num_elements))
avg_bitrate = bitrate_mean_process.next(
bitrate_mean_process.initialize(), bitrates).result
measurements = intrinsics.federated_zip(
collections.OrderedDict(elias_gamma_code_avg_bitrate=avg_bitrate))
decoded_value = sum_encoded_value(encoded_value)
unconcat_value = intrinsics.federated_map(unconcat_fn, decoded_value)
return measured_process.MeasuredProcessOutput(
state=state, result=unconcat_value, measurements=measurements)
state=state, result=decoded_value, measurements=measurements)
return aggregation_process.AggregationProcess(init_fn, next_fn)
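
For context, a rough sketch of how the new measurement path derives an average bitrate from the layer-wise bitstrings, mirroring `_get_bits` and `struct_get_bits` above; the free-standing `avg_bitrate` helper is hypothetical:

```python
import tensorflow as tf

def _get_bits(code):
  # 8 bits per byte of an encoded bitstring.
  return 8. * tf.cast(tf.strings.length(code, unit="BYTE"), dtype=tf.float64)

def avg_bitrate(encoded_struct, num_elements):
  # Sum the bits of every per-layer bitstring, then divide by the total
  # element count (a float64 scalar); divide_no_nan guards against zero.
  total_bits = tf.math.add_n(
      [_get_bits(code) for code in tf.nest.flatten(encoded_struct)])
  return tf.math.divide_no_nan(total_bits, num_elements)
```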
......@@ -63,9 +63,15 @@ _test_expected_result_struct_int32_tensors = collections.OrderedDict({
'layer2': [-6, 2, 0, 0],
})
_test_avg_bitrate_int32_tensor_rank_1 = 16. / 4.
_test_avg_bitrate_int32_tensor_rank_2 = 32. / 8.
_test_avg_bitrate_struct_int32_tensors = 32. / 8.
# Avg num bits across clients to represent each element
_test_avg_bitrate_int32_tensor_rank_1 = 16.0 / 4.0
_test_avg_bitrate_int32_tensor_rank_2 = 32.0 / 8.0
_test_avg_bitrate_struct_int32_tensors = 32.0 / 8.0
# Avg num bits across clients to represent each tensor
_test_avg_bitstring_length_int32_tensor_rank_1 = 16.0
_test_avg_bitstring_length_int32_tensor_rank_2 = 32.0
class EncodeUtilTest(tf.test.TestCase, parameterized.TestCase):
......@@ -168,16 +174,13 @@ class EncodeExecutionTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('int32_tensor_rank_1', _test_client_values_int32_tensor_rank_1,
_test_avg_bitrate_int32_tensor_rank_1),
_test_avg_bitstring_length_int32_tensor_rank_1),
('int32_tensor_rank_2', _test_client_values_int32_tensor_rank_2,
_test_avg_bitrate_int32_tensor_rank_2))
def test_bitstring_impl(self, client_values, avg_bitrate):
num_elements = tf.size(client_values[0], out_type=tf.float64)
_test_avg_bitstring_length_int32_tensor_rank_2))
def test_bitstring_impl(self, client_values, avg_total_bits):
bitstrings = [tfc.run_length_gamma_encode(x) for x in client_values]
bitrates = [
elias_gamma_encode._get_bitrate(x, num_elements) for x in bitstrings
]
self.assertEqual(np.mean(bitrates), avg_bitrate)
total_bits = [elias_gamma_encode._get_bits(x) for x in bitstrings]
self.assertEqual(np.mean(total_bits), avg_total_bits)
if __name__ == '__main__':
......
......@@ -187,9 +187,9 @@ def _discretize_struct(struct, step_size):
seed = tf.cast(
tf.stack([tf.timestamp() * 1e6,
tf.timestamp() * 1e6]), dtype=tf.int64)
scaled_x = tf.divide(tf.cast(x, tf.float32), step_size)
prob_x = scaled_x - tf.cast(tf.floor(scaled_x), tf.float32)
random_x = tf.random.stateless_uniform(x.shape, seed=seed, dtype=tf.float32)
scaled_x = tf.divide(x, tf.cast(step_size, x.dtype))
prob_x = scaled_x - tf.floor(scaled_x)
random_x = tf.random.stateless_uniform(x.shape, seed=seed, dtype=x.dtype)
discretized_x = tf.where(
tf.less_equal(random_x, prob_x), tf.math.ceil(scaled_x),
tf.math.floor(scaled_x))
......@@ -202,7 +202,7 @@ def _undiscretize_struct(struct, step_size, tf_dtype_struct):
"""Unscales the discretized structure and casts back to original dtypes."""
def undiscretize_tensor(x, original_dtype):
unscaled_x = tf.cast(x, tf.float32) * step_size
return tf.cast(unscaled_x, original_dtype)
unscaled_x = tf.cast(x, original_dtype) * tf.cast(step_size, original_dtype)
return unscaled_x
return tf.nest.map_structure(undiscretize_tensor, struct, tf_dtype_struct)
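
A minimal sketch of measure (1): the round trip stays in the value's own dtype and casts only `step_size`. The standalone `discretize`/`undiscretize` helpers and the fixed seed are illustrative assumptions, not the factory's actual interface:

```python
import tensorflow as tf

def discretize(x, step_size, seed=(1, 2)):
  # Scale in the value's own dtype; only step_size is cast.
  scaled = x / tf.cast(step_size, x.dtype)
  floor = tf.floor(scaled)
  prob = scaled - floor
  noise = tf.random.stateless_uniform(tf.shape(x), seed=seed, dtype=x.dtype)
  # Stochastic rounding: round up with probability equal to the fractional part.
  rounded = tf.where(noise <= prob, tf.math.ceil(scaled), floor)
  return tf.cast(rounded, tf.int32)

def undiscretize(q, step_size, original_dtype):
  # Rescale directly in the original dtype; no float32 round trip.
  return tf.cast(q, original_dtype) * tf.cast(step_size, original_dtype)

x = tf.constant([0.25, -1.5, 3.0], tf.float16)
q = discretize(x, step_size=0.125)
x_hat = undiscretize(q, step_size=0.125, original_dtype=tf.float16)
```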
......@@ -35,16 +35,16 @@ _test_struct_type_int = [tf.int32, (tf.int32, (2,)), (tf.int32, (3, 3))]
_test_struct_type_float = [tf.float32, (tf.float32, (2,)), (tf.float32, (3, 3))]
_test_nested_struct_type_float = collections.OrderedDict(
a=[tf.float32, [(tf.float32, (2, 2, 1))]], b=(tf.float32, (3, 3)))
a=[tf.float16, [(tf.float32, (2, 2, 1))]], b=(tf.float16, (3, 3)))
def _make_test_nested_struct_value(value):
return collections.OrderedDict(
a=[
tf.constant(value, dtype=tf.float32),
tf.constant(value, dtype=tf.float16),
[tf.constant(value, dtype=tf.float32, shape=[2, 2, 1])]
],
b=tf.constant(value, dtype=tf.float32, shape=(3, 3)))
b=tf.constant(value, dtype=tf.float16, shape=(3, 3)))
def _named_test_cases_product(*args):
......@@ -141,14 +141,14 @@ class StochasticDiscretizationExecutionTest(tf.test.TestCase,
def test_discretize_impl(self, value_type, client_values, expected_sum):
factory = stochastic_discretization.StochasticDiscretizationFactory(
inner_agg_factory=_measurement_aggregator,
step_size=0.1,
step_size=0.125,
distortion_aggregation_factory=mean.UnweightedMeanFactory())
value_type = computation_types.to_type(value_type)
process = factory.create(value_type)
state = process.initialize()
expected_result = expected_sum
expected_quantized_result = tf.nest.map_structure(lambda x: x * 10,
expected_quantized_result = tf.nest.map_structure(lambda x: x * 8,
expected_sum)
expected_measurements = collections.OrderedDict(
stochastic_discretization=expected_quantized_result, distortion=0.)
......@@ -160,19 +160,21 @@ class StochasticDiscretizationExecutionTest(tf.test.TestCase,
result = output.result
self.assertAllClose(result, expected_result)
@parameterized.named_parameters(('int32', tf.int32), ('int64', tf.int64),
('float64', tf.float64))
@parameterized.named_parameters(
('float16', tf.float16), ('float32', tf.float32), ('float64', tf.float64))
def test_output_dtype(self, dtype):
"""Checks the tensor type gets casted during preprocessing."""
x = tf.range(8, dtype=dtype)
x = tf.range(8)
x = tf.cast(x, dtype=dtype)
encoded_x = stochastic_discretization._discretize_struct(x, step_size=0.1)
self.assertEqual(encoded_x.dtype, stochastic_discretization.OUTPUT_TF_TYPE)
@parameterized.named_parameters(('int32', tf.int32), ('int64', tf.int64),
('float64', tf.float64))
@parameterized.named_parameters(
('float16', tf.float16), ('float32', tf.float32), ('float64', tf.float64))
def test_revert_to_input_dtype(self, dtype):
"""Checks that postprocessing restores the original dtype."""
x = tf.range(8, dtype=dtype)
x = tf.range(8)
x = tf.cast(x, dtype=dtype)
encoded_x = stochastic_discretization._discretize_struct(x, step_size=1)
decoded_x = stochastic_discretization._undiscretize_struct(
encoded_x, step_size=1, tf_dtype_struct=dtype)
......@@ -232,15 +234,15 @@ class ScalingTest(tf.test.TestCase, parameterized.TestCase):
('step_size_3', 10**-5))
def test_scaling(self, step_size):
# Integers to prevent rounding.
x = tf.random.stateless_uniform([100], (1, 1), -100, 100, dtype=tf.int32)
discretized_x = stochastic_discretization._discretize_struct(
x, tf.cast(step_size, tf.float32))
x = tf.cast(
tf.random.stateless_uniform([100], (1, 1), -100, 100, dtype=tf.int32),
tf.float32)
discretized_x = stochastic_discretization._discretize_struct(x, step_size)
reverted_x = stochastic_discretization._undiscretize_struct(
discretized_x, step_size, tf_dtype_struct=tf.int32)
discretized_x, step_size, tf_dtype_struct=tf.float32)
x, discretized_x, reverted_x = self.evaluate([x, discretized_x, reverted_x])
self.assertAllEqual(
tf.round(tf.divide(tf.cast(x, tf.float32), step_size)),
discretized_x) # Scaling up.
self.assertAllEqual(tf.round(tf.divide(x, step_size)),
discretized_x) # Scaling up.
self.assertAllEqual(x, reverted_x) # Scaling down.
......