提交 a8ea2945 编辑于 作者: Harry Zhang's avatar Harry Zhang 提交者: tensorflow-copybara
浏览文件

Alter input and output history to be called broadcast and aggregate history

PiperOrigin-RevId: 300429296
上级 f4fcf0e2
......@@ -128,13 +128,13 @@ class TensorFlowComputationsTest(parameterized.TestCase):
comp(temperatures, threshold)
# Each client receives a tf.float32 and uploads two tf.float32 values.
expected_input_bits = num_clients * 32
expected_output_bits = expected_input_bits * 2
expected_broadcast_bits = num_clients * 32
expected_aggregate_bits = expected_broadcast_bits * 2
expected = ({
(('CLIENTS', num_clients),): [[1, tf.float32]] * num_clients
}, {
(('CLIENTS', num_clients),): [[1, tf.float32]] * num_clients * 2
}, [expected_input_bits], [expected_output_bits])
}, [expected_broadcast_bits], [expected_aggregate_bits])
self.assertEqual(expected, factory.get_size_info())
......
......@@ -191,42 +191,43 @@ class SizingExecutorFactoryImpl(ExecutorFactoryImpl):
self
) -> Tuple[Dict[Any, sizing_executor.SizeAndDTypes], Dict[
Any, sizing_executor.SizeAndDTypes], List[int], List[int]]:
"""Returns information about the input and output of each SizingExecutor.
"""Returns information about the transferred data of each SizingExecutor.
Returns the history of inputs and outputs for each executor as well as the
number of aggregated bits that has been passed through.
Returns the history of broadcast and aggregation for each executor as well
as the number of aggregated bits that has been passed through.
Returns:
A tuple of
2D ragged list of 2-tuples which represents the input history.
2D ragged list of 2-tuples which represents the output history.
A list of shape [number_of_execs] representing the number of input bits
passed through each executor.
A list of shape [number_of_execs] representing the number of output bits
passed through each executor.
2D ragged list of 2-tuples which represents the broadcast history.
2D ragged list of 2-tuples which represents the aggregation history.
A list of shape [number_of_execs] representing the number of broadcasted
bits passed through each executor.
A list of shape [number_of_execs] representing the number of aggregated
bits passed through each executor.
"""
size_ex_dict = self._sizing_executors
def _extract_history(sizing_exs: List[sizing_executor.SizingExecutor]):
input_history, output_history = [], []
broadcast_history, aggregate_history = [], []
for ex in sizing_exs:
input_history.extend(ex.input_history)
output_history.extend(ex.output_history)
return input_history, output_history
broadcast_history.extend(ex.broadcast_history)
aggregate_history.extend(ex.aggregate_history)
return broadcast_history, aggregate_history
input_history, output_history = {}, {}
broadcast_history, aggregate_history = {}, {}
for key, size_exs in size_ex_dict.items():
current_input_history, current_output_history = _extract_history(size_exs)
input_history[key] = current_input_history
output_history[key] = current_output_history
current_broadcast_history, current_aggregate_history = _extract_history(
size_exs)
broadcast_history[key] = current_broadcast_history
aggregate_history[key] = current_aggregate_history
input_bits = [
self._calculate_bit_size(hist) for hist in input_history.values()
broadcast_bits = [
self._calculate_bit_size(hist) for hist in broadcast_history.values()
]
output_bits = [
self._calculate_bit_size(hist) for hist in output_history.values()
aggregate_bits = [
self._calculate_bit_size(hist) for hist in aggregate_history.values()
]
return input_history, output_history, input_bits, output_bits
return broadcast_history, aggregate_history, broadcast_bits, aggregate_bits
def _bits_per_element(self, dtype: tf.DType) -> int:
"""Returns the number of bits that a tensorflow DType uses per element."""
......@@ -239,8 +240,8 @@ class SizingExecutorFactoryImpl(ExecutorFactoryImpl):
def _calculate_bit_size(self, history: sizing_executor.SizeAndDTypes) -> int:
"""Takes a list of 2 element lists and calculates the number of bits represented.
The input list should follow the format of self.input_history or
self.output_history. That is, each 2 element list should be
The input list should follow the format of self.broadcast_history or
self.aggregate_history. That is, each 2 element list should be
[num_elements, dtype].
Args:
......
......@@ -92,21 +92,21 @@ class SizingExecutor(executor_base.Executor):
"""
py_typecheck.check_type(target, executor_base.Executor)
self._target = target
self._input_history = []
self._output_history = []
self._broadcast_history = []
self._aggregate_history = []
@property
def output_history(self) -> SizeAndDTypes:
return self._output_history
def aggregate_history(self) -> SizeAndDTypes:
return self._aggregate_history
@property
def input_history(self) -> SizeAndDTypes:
return self._input_history
def broadcast_history(self) -> SizeAndDTypes:
return self._broadcast_history
async def create_value(self, value, type_spec=None):
target_val = await self._target.create_value(value, type_spec)
wrapped_val = SizingExecutorValue(self, target_val)
self._input_history.extend(get_type_information(value, type_spec))
self._broadcast_history.extend(get_type_information(value, type_spec))
return wrapped_val
async def create_call(self, comp, arg=None):
......@@ -160,6 +160,6 @@ class SizingExecutorValue(executor_value_base.ExecutorValue):
async def compute(self):
result = await self._value.compute()
self._owner.output_history.extend(
self._owner.aggregate_history.extend(
get_type_information(result, self._value.type_signature))
return result
......@@ -52,8 +52,8 @@ class SizingExecutorTest(parameterized.TestCase):
return await v5.compute()
asyncio.get_event_loop().run_until_complete(_make())
self.assertCountEqual(ex.input_history, [[10, tf.int32]])
self.assertCountEqual(ex.output_history, [[10, tf.int32]])
self.assertCountEqual(ex.broadcast_history, [[10, tf.int32]])
self.assertCountEqual(ex.aggregate_history, [[10, tf.int32]])
def test_string(self):
ex = sizing_executor.SizingExecutor(eager_tf_executor.EagerTFExecutor())
......@@ -66,8 +66,10 @@ class SizingExecutorTest(parameterized.TestCase):
return await v1.compute()
asyncio.get_event_loop().run_until_complete(_make())
self.assertCountEqual(ex.input_history, [[total_string_length, tf.string]])
self.assertCountEqual(ex.output_history, [[total_string_length, tf.string]])
self.assertCountEqual(ex.broadcast_history,
[[total_string_length, tf.string]])
self.assertCountEqual(ex.aggregate_history,
[[total_string_length, tf.string]])
def test_different_input_output(self):
ex = sizing_executor.SizingExecutor(eager_tf_executor.EagerTFExecutor())
......@@ -86,8 +88,8 @@ class SizingExecutorTest(parameterized.TestCase):
return await v3.compute()
asyncio.get_event_loop().run_until_complete(_make())
self.assertCountEqual(ex.input_history, [[10, tf.int32]])
self.assertCountEqual(ex.output_history, [[1, tf.int32]])
self.assertCountEqual(ex.broadcast_history, [[10, tf.int32]])
self.assertCountEqual(ex.aggregate_history, [[1, tf.int32]])
def test_multiple_inputs(self):
ex = sizing_executor.SizingExecutor(eager_tf_executor.EagerTFExecutor())
......@@ -111,8 +113,9 @@ class SizingExecutorTest(parameterized.TestCase):
return await v5.compute()
asyncio.get_event_loop().run_until_complete(_make())
self.assertCountEqual(ex.input_history, [[10, tf.int32], [10, tf.float64]])
self.assertCountEqual(ex.output_history, [[10, tf.int64]])
self.assertCountEqual(ex.broadcast_history,
[[10, tf.int32], [10, tf.float64]])
self.assertCountEqual(ex.aggregate_history, [[10, tf.int64]])
def test_nested_tuple(self):
ex = sizing_executor.SizingExecutor(eager_tf_executor.EagerTFExecutor())
......@@ -136,7 +139,7 @@ class SizingExecutorTest(parameterized.TestCase):
return await v1.compute()
asyncio.get_event_loop().run_until_complete(_make())
self.assertCountEqual(ex.input_history,
self.assertCountEqual(ex.broadcast_history,
[[4, tf.int32], [2, tf.bool], [6, tf.int64],
[4, tf.int32], [2, tf.bool], [6, tf.int64]])
......@@ -150,7 +153,7 @@ class SizingExecutorTest(parameterized.TestCase):
return await v1.compute()
asyncio.get_event_loop().run_until_complete(_make())
self.assertCountEqual(ex.input_history, [])
self.assertCountEqual(ex.broadcast_history, [])
def test_ordered_dict(self):
a = computation_types.TensorType(tf.string, [4])
......@@ -167,7 +170,7 @@ class SizingExecutorTest(parameterized.TestCase):
return await v1.compute()
asyncio.get_event_loop().run_until_complete(_make())
self.assertCountEqual(ex.input_history,
self.assertCountEqual(ex.broadcast_history,
[[total_string_length, tf.string], [6, tf.int64]])
def test_unnamed_tuple(self):
......@@ -179,8 +182,8 @@ class SizingExecutorTest(parameterized.TestCase):
return await v1.compute()
asyncio.get_event_loop().run_until_complete(_make())
self.assertCountEqual(ex.input_history, [[1, tf.int32], [1, tf.int32]])
self.assertCountEqual(ex.output_history, [[1, tf.int32], [1, tf.int32]])
self.assertCountEqual(ex.broadcast_history, [[1, tf.int32], [1, tf.int32]])
self.assertCountEqual(ex.aggregate_history, [[1, tf.int32], [1, tf.int32]])
@parameterized.named_parameters(
{
......@@ -247,7 +250,7 @@ class SizingExecutorTest(parameterized.TestCase):
return await v1.compute()
asyncio.get_event_loop().run_until_complete(_make())
self.assertCountEqual(ex.input_history,
self.assertCountEqual(ex.broadcast_history,
[[4, tf.int32], [2, tf.bool], [6, tf.int64],
[4, tf.int32], [2, tf.bool], [6, tf.int64]])
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册