Skip to content
Snippets Groups Projects
Commit 651def46 authored by tensorflow-copybara's avatar tensorflow-copybara
Browse files

Merge pull request #877 from tf-encrypted:resolving-typo

PiperOrigin-RevId: 320014281
parents ebb018ac 02b8ecf6
Branches
Tags
No related merge requests found
Showing
with 57 additions and 57 deletions
......@@ -64,7 +64,7 @@ from tensorflow_federated.python.core.impl.executors.executor_stacks import sizi
from tensorflow_federated.python.core.impl.executors.executor_stacks import worker_pool_executor_factory
from tensorflow_federated.python.core.impl.executors.executor_value_base import ExecutorValue
from tensorflow_federated.python.core.impl.executors.federated_composing_strategy import FederatedComposingStrategy
from tensorflow_federated.python.core.impl.executors.federated_resolving_strategy import FederatedResovlingStrategy
from tensorflow_federated.python.core.impl.executors.federated_resolving_strategy import FederatedResolvingStrategy
from tensorflow_federated.python.core.impl.executors.federating_executor import FederatingExecutor
from tensorflow_federated.python.core.impl.executors.federating_executor import FederatingStrategy
from tensorflow_federated.python.core.impl.executors.reference_resolving_executor import ReferenceResolvingExecutor
......
......@@ -185,7 +185,7 @@ class FederatingExecutorFactory(executor_factory.ExecutorFactory):
]
self._sizing_executors.extend(client_stacks)
federating_strategy_factory = federated_resolving_strategy.FederatedResovlingStrategy.factory(
federating_strategy_factory = federated_resolving_strategy.FederatedResolvingStrategy.factory(
{
placement_literals.CLIENTS: [
client_stacks[k % len(client_stacks)]
......
......@@ -34,7 +34,7 @@ def create_test_federated_stack(
executor = eager_tf_executor.EagerTFExecutor()
return reference_resolving_executor.ReferenceResolvingExecutor(executor)
factory = federated_resolving_strategy.FederatedResovlingStrategy.factory({
factory = federated_resolving_strategy.FederatedResolvingStrategy.factory({
placement_literals.SERVER:
create_bottom_stack(),
placement_literals.CLIENTS: [
......@@ -54,7 +54,7 @@ def create_test_aggregated_stack(
return reference_resolving_executor.ReferenceResolvingExecutor(executor)
def create_worker_stack():
factroy = federated_resolving_strategy.FederatedResovlingStrategy.factory({
factroy = federated_resolving_strategy.FederatedResolvingStrategy.factory({
placement_literals.SERVER:
create_bottom_stack(),
placement_literals.CLIENTS: [
......
......@@ -42,7 +42,7 @@ def _create_bottom_stack():
def _create_worker_stack():
factory = federated_resolving_strategy.FederatedResovlingStrategy.factory({
factory = federated_resolving_strategy.FederatedResolvingStrategy.factory({
placement_literals.SERVER: _create_bottom_stack(),
placement_literals.CLIENTS: [_create_bottom_stack() for _ in range(2)],
})
......
......@@ -29,7 +29,7 @@
| executors |
+-----------+
A `FederatedResovlingStrategy`:
A `FederatedResolvingStrategy`:
* Implements the logic for resolving federated types and intrinsics, while
delegating unplaced computations to the target executor(s) associated with
......@@ -61,11 +61,11 @@ from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.impl.types import type_factory
class FederatedResovlingStrategyValue(executor_value_base.ExecutorValue):
class FederatedResolvingStrategyValue(executor_value_base.ExecutorValue):
"""A value embedded in a `FederatedExecutor`."""
def __init__(self, value, type_signature):
"""Creates a `FederatedResovlingStrategyValue` embedding the given `value`.
"""Creates a `FederatedResolvingStrategyValue` embedding the given `value`.
Args:
value: An object to embed in the executor, one of the supported types
......@@ -98,7 +98,7 @@ class FederatedResovlingStrategyValue(executor_value_base.ExecutorValue):
return await self._value.compute()
elif isinstance(self._value, anonymous_tuple.AnonymousTuple):
results = await asyncio.gather(*[
FederatedResovlingStrategyValue(v, t).compute()
FederatedResolvingStrategyValue(v, t).compute()
for v, t in zip(self._value, self._type_signature)
])
element_types = anonymous_tuple.iter_elements(self._type_signature)
......@@ -120,7 +120,7 @@ class FederatedResovlingStrategyValue(executor_value_base.ExecutorValue):
py_typecheck.type_string(type(self._value))))
class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
class FederatedResolvingStrategy(federating_executor.FederatingStrategy):
"""A strategy for resolving federated types and intrinsics.
This strategy implements the `federating_executor.FederatingStrategy`
......@@ -143,7 +143,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
def __init__(self, executor: federating_executor.FederatingExecutor,
target_executors: Dict[str, executor_base.Executor]):
"""Creates a `FederatedResovlingStrategy`.
"""Creates a `FederatedResolvingStrategy`.
Args:
executor: A `federating_executor.FederatingExecutor` to use to handle
......@@ -160,7 +160,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
`executor_base.Executor` or a list of `executor_base.Executor`s.
ValueError: If `target_executors` contains a
`placement_literals.PlacementLiteral` key that is not a kind supported
by the `FederatedResovlingStrategy`.
by the `FederatedResolvingStrategy`.
"""
super().__init__(executor)
py_typecheck.check_type(target_executors, dict)
......@@ -259,11 +259,11 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
):
self._check_strategy_compatible_with_placement(
type_signature.result.placement)
return FederatedResovlingStrategyValue(value, type_signature)
return FederatedResolvingStrategyValue(value, type_signature)
async def compute_federated_value(
self, value: Any, type_signature: computation_types.Type
) -> FederatedResovlingStrategyValue:
) -> FederatedResolvingStrategyValue:
self._check_strategy_compatible_with_placement(type_signature.placement)
children = self._target_executors[type_signature.placement]
self._check_value_compatible_with_placement(value, type_signature.placement,
......@@ -274,7 +274,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
c.create_value(v, type_signature.member)
for v, c in zip(value, children)
])
return FederatedResovlingStrategyValue(result, type_signature)
return FederatedResolvingStrategyValue(result, type_signature)
@tracing.trace
async def _eval(self, arg, placement, all_equal):
......@@ -291,7 +291,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
return await child.create_call(await child.create_value(fn, fn_type))
results = await asyncio.gather(*[call(child) for child in children])
return FederatedResovlingStrategyValue(
return FederatedResolvingStrategyValue(
results,
computation_types.FederatedType(
fn_type.result, placement, all_equal=all_equal))
......@@ -321,7 +321,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
results = await asyncio.gather(*[
c.create_call(f, v) for c, (f, v) in zip(children, list(zip(fns, val)))
])
return FederatedResovlingStrategyValue(
return FederatedResolvingStrategyValue(
results,
computation_types.FederatedType(
fn_type.result, val_type.placement, all_equal=all_equal))
......@@ -345,7 +345,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
anonymous_tuple.AnonymousTuple([(k, v[idx]) for k, v in elements]))
new_vals = await asyncio.gather(
*[c.create_tuple(x) for c, x in zip(children, new_vals)])
return FederatedResovlingStrategyValue(
return FederatedResolvingStrategyValue(
new_vals,
computation_types.FederatedType(
computation_types.NamedTupleType((
......@@ -357,7 +357,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
@tracing.trace
async def compute_federated_aggregate(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
val_type, zero_type, accumulate_type, _, report_type = (
executor_utils.parse_federated_aggregate_argument_types(
arg.type_signature))
......@@ -377,7 +377,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
zero = arg.internal_representation[1]
accumulate = arg.internal_representation[2]
pre_report = await self.compute_federated_reduce(
FederatedResovlingStrategyValue(
FederatedResolvingStrategyValue(
anonymous_tuple.AnonymousTuple([(None, val), (None, zero),
(None, accumulate)]),
computation_types.NamedTupleType(
......@@ -389,7 +389,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
report = arg.internal_representation[4]
return await self.compute_federated_apply(
FederatedResovlingStrategyValue(
FederatedResolvingStrategyValue(
anonymous_tuple.AnonymousTuple([
(None, report), (None, pre_report.internal_representation)
]),
......@@ -399,13 +399,13 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
@tracing.trace
async def compute_federated_apply(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await self._map(arg)
@tracing.trace
async def compute_federated_broadcast(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
py_typecheck.check_type(arg.internal_representation, list)
if len(arg.internal_representation) != 1:
raise ValueError(
......@@ -417,7 +417,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
@tracing.trace
async def compute_federated_collect(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
type_analysis.check_federated_type(
arg.type_signature, placement=placement_literals.CLIENTS)
......@@ -428,7 +428,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
collected_items = await child.create_value(
await asyncio.gather(*[v.compute() for v in val]),
computation_types.SequenceType(member_type))
return FederatedResovlingStrategyValue(
return FederatedResolvingStrategyValue(
[collected_items],
computation_types.FederatedType(
computation_types.SequenceType(member_type),
......@@ -438,31 +438,31 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
@tracing.trace
async def compute_federated_eval_at_clients(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await self._eval(arg, placement_literals.CLIENTS, False)
@tracing.trace
async def compute_federated_eval_at_server(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await self._eval(arg, placement_literals.SERVER, True)
@tracing.trace
async def compute_federated_map(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await self._map(arg, all_equal=False)
@tracing.trace
async def compute_federated_map_all_equal(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await self._map(arg, all_equal=True)
@tracing.trace
async def compute_federated_mean(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
arg_sum = await self.compute_federated_sum(arg)
member_type = arg_sum.type_signature.member
count = float(len(arg.internal_representation))
......@@ -479,12 +479,12 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
arg_sum.internal_representation[0]),
(None, factor)]))
result = await child.create_call(multiply, multiply_arg)
return FederatedResovlingStrategyValue([result], arg_sum.type_signature)
return FederatedResolvingStrategyValue([result], arg_sum.type_signature)
@tracing.trace
async def compute_federated_reduce(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
self._check_arg_is_anonymous_tuple(arg)
if len(arg.internal_representation) != 3:
raise ValueError(
......@@ -517,7 +517,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
result = await child.create_call(
op, await child.create_tuple(
anonymous_tuple.AnonymousTuple([(None, result), (None, item)])))
return FederatedResovlingStrategyValue([result],
return FederatedResolvingStrategyValue([result],
computation_types.FederatedType(
result.type_signature,
placement_literals.SERVER,
......@@ -526,13 +526,13 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
@tracing.trace
async def compute_federated_secure_sum(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
raise NotImplementedError('The secure sum intrinsic is not implemented.')
@tracing.trace
async def compute_federated_sum(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
zero, plus = await asyncio.gather(
executor_utils.embed_tf_scalar_constant(self._executor,
......@@ -541,7 +541,7 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
arg.type_signature.member,
tf.add))
return await self.compute_federated_reduce(
FederatedResovlingStrategyValue(
FederatedResolvingStrategyValue(
anonymous_tuple.AnonymousTuple([
(None, arg.internal_representation),
(None, zero.internal_representation),
......@@ -554,32 +554,32 @@ class FederatedResovlingStrategy(federating_executor.FederatingStrategy):
@tracing.trace
async def compute_federated_value_at_clients(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await executor_utils.compute_intrinsic_federated_value(
self._executor, arg, placement_literals.CLIENTS)
@tracing.trace
async def compute_federated_value_at_server(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await executor_utils.compute_intrinsic_federated_value(
self._executor, arg, placement_literals.SERVER)
@tracing.trace
async def compute_federated_weighted_mean(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await executor_utils.compute_intrinsic_federated_weighted_mean(
self._executor, arg)
@tracing.trace
async def compute_federated_zip_at_clients(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await self._zip(arg, placement_literals.CLIENTS, all_equal=False)
@tracing.trace
async def compute_federated_zip_at_server(
self,
arg: FederatedResovlingStrategyValue) -> FederatedResovlingStrategyValue:
arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
return await self._zip(arg, placement_literals.SERVER, all_equal=True)
......@@ -24,13 +24,13 @@ from tensorflow_federated.python.core.impl.executors import federated_resolving_
from tensorflow_federated.python.core.impl.types import type_factory
class FederatedResovlingStrategyValueComputeTest(
class FederatedResolvingStrategyValueComputeTest(
executor_test_utils.AsyncTestCase):
def test_returns_value_with_embedded_value(self):
value = eager_tf_executor.EagerValue(10.0, None, tf.float32)
type_signature = computation_types.TensorType(tf.float32)
value = federated_resolving_strategy.FederatedResovlingStrategyValue(
value = federated_resolving_strategy.FederatedResolvingStrategyValue(
value, type_signature)
result = self.run_sync(value.compute())
......@@ -44,7 +44,7 @@ class FederatedResovlingStrategyValueComputeTest(
eager_tf_executor.EagerValue(12.0, None, tf.float32),
]
type_signature = type_factory.at_clients(tf.float32)
value = federated_resolving_strategy.FederatedResovlingStrategyValue(
value = federated_resolving_strategy.FederatedResolvingStrategyValue(
value, type_signature)
result = self.run_sync(value.compute())
......@@ -54,7 +54,7 @@ class FederatedResovlingStrategyValueComputeTest(
def test_returns_value_with_federated_type_at_clients_all_equal(self):
value = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
type_signature = type_factory.at_clients(tf.float32, all_equal=True)
value = federated_resolving_strategy.FederatedResovlingStrategyValue(
value = federated_resolving_strategy.FederatedResolvingStrategyValue(
value, type_signature)
result = self.run_sync(value.compute())
......@@ -64,7 +64,7 @@ class FederatedResovlingStrategyValueComputeTest(
def test_returns_value_with_federated_type_at_server(self):
value = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
type_signature = type_factory.at_server(tf.float32)
value = federated_resolving_strategy.FederatedResovlingStrategyValue(
value = federated_resolving_strategy.FederatedResolvingStrategyValue(
value, type_signature)
result = self.run_sync(value.compute())
......@@ -78,7 +78,7 @@ class FederatedResovlingStrategyValueComputeTest(
value = anonymous_tuple.AnonymousTuple((n, element) for n in names)
type_signature = computation_types.NamedTupleType(
(n, element_type) for n in names)
value = federated_resolving_strategy.FederatedResovlingStrategyValue(
value = federated_resolving_strategy.FederatedResolvingStrategyValue(
value, type_signature)
result = self.run_sync(value.compute())
......@@ -89,7 +89,7 @@ class FederatedResovlingStrategyValueComputeTest(
def test_raises_type_error_with_unembedded_federated_type(self):
value = [10.0, 11.0, 12.0]
type_signature = type_factory.at_clients(tf.float32)
value = federated_resolving_strategy.FederatedResovlingStrategyValue(
value = federated_resolving_strategy.FederatedResolvingStrategyValue(
value, type_signature)
with self.assertRaises(TypeError):
......@@ -98,7 +98,7 @@ class FederatedResovlingStrategyValueComputeTest(
def test_raises_runtime_error_with_unsupported_value_or_type(self):
value = 10.0
type_signature = computation_types.TensorType(tf.float32)
value = federated_resolving_strategy.FederatedResovlingStrategyValue(
value = federated_resolving_strategy.FederatedResolvingStrategyValue(
value, type_signature)
with self.assertRaises(RuntimeError):
......
......@@ -46,7 +46,7 @@ def create_test_executor(
executor = eager_tf_executor.EagerTFExecutor()
return reference_resolving_executor.ReferenceResolvingExecutor(executor)
factory = federated_resolving_strategy.FederatedResovlingStrategy.factory({
factory = federated_resolving_strategy.FederatedResolvingStrategy.factory({
placement_literals.SERVER:
create_bottom_stack(),
placement_literals.CLIENTS: [
......@@ -98,7 +98,7 @@ def get_named_parameters_for_supported_intrinsics() -> List[Tuple[str, Any]]:
class FederatingExecutorInitTest(executor_test_utils.AsyncTestCase):
def test_raises_type_error_with_no_target_executor_unplaced(self):
factory = federated_resolving_strategy.FederatedResovlingStrategy.factory({
factory = federated_resolving_strategy.FederatedResolvingStrategy.factory({
placement_literals.SERVER: eager_tf_executor.EagerTFExecutor(),
placement_literals.CLIENTS: eager_tf_executor.EagerTFExecutor(),
})
......@@ -303,7 +303,7 @@ class FederatingExecutorCreateValueTest(executor_test_utils.AsyncTestCase,
# pyformat: enable
def test_raises_value_error_with_no_target_executor_clients(
self, value, type_signature):
factory = federated_resolving_strategy.FederatedResovlingStrategy.factory({
factory = federated_resolving_strategy.FederatedResolvingStrategy.factory({
placement_literals.SERVER: eager_tf_executor.EagerTFExecutor(),
})
executor = federating_executor.FederatingExecutor(
......@@ -340,7 +340,7 @@ class FederatingExecutorCreateValueTest(executor_test_utils.AsyncTestCase,
# pyformat: enable
def test_raises_value_error_with_no_target_executor_server(
self, value, type_signature):
factory = federated_resolving_strategy.FederatedResovlingStrategy.factory({
factory = federated_resolving_strategy.FederatedResolvingStrategy.factory({
placement_literals.CLIENTS: eager_tf_executor.EagerTFExecutor(),
})
executor = federating_executor.FederatingExecutor(
......
......@@ -247,7 +247,7 @@ class ReferenceResolvingExecutorTest(absltest.TestCase):
def test_with_federated_map(self):
eager_ex = eager_tf_executor.EagerTFExecutor()
factory = federated_resolving_strategy.FederatedResovlingStrategy.factory(
factory = federated_resolving_strategy.FederatedResolvingStrategy.factory(
{placement_literals.SERVER: eager_ex})
federated_ex = federating_executor.FederatingExecutor(factory, eager_ex)
ex = reference_resolving_executor.ReferenceResolvingExecutor(federated_ex)
......@@ -270,7 +270,7 @@ class ReferenceResolvingExecutorTest(absltest.TestCase):
def test_with_federated_map_and_broadcast(self):
eager_ex = eager_tf_executor.EagerTFExecutor()
factory = federated_resolving_strategy.FederatedResovlingStrategy.factory({
factory = federated_resolving_strategy.FederatedResolvingStrategy.factory({
placement_literals.SERVER: eager_ex,
placement_literals.CLIENTS: [eager_ex for _ in range(3)]
})
......@@ -296,7 +296,7 @@ class ReferenceResolvingExecutorTest(absltest.TestCase):
def test_raises_with_closure(self):
eager_ex = eager_tf_executor.EagerTFExecutor()
factory = federated_resolving_strategy.FederatedResovlingStrategy.factory({
factory = federated_resolving_strategy.FederatedResolvingStrategy.factory({
placement_literals.SERVER: eager_ex,
})
federated_ex = federating_executor.FederatingExecutor(factory, eager_ex)
......
......@@ -82,7 +82,7 @@ def make_remote_executor(inferred_cardinalities):
worker_stack = create_worker_stack(remote_ex)
client_ex.append(worker_stack)
federating_strategy_factory = tff.framework.FederatedResovlingStrategy.factory(
federating_strategy_factory = tff.framework.FederatedResolvingStrategy.factory(
{
tff.SERVER: create_worker_stack(tff.framework.EagerTFExecutor()),
tff.CLIENTS: client_ex,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment