Skip to content
Snippets Groups Projects
Commit 130d2af4 authored by Zachary Charles's avatar Zachary Charles Committed by tensorflow-copybara
Browse files

Remove `from_tff_result` methods in optimization/

PiperOrigin-RevId: 322249517
parent e90385b1
No related branches found
No related tags found
No related merge requests found
......@@ -82,17 +82,6 @@ class ServerState(object):
# This is a float to avoid type incompatibility when calculating learning rate
# schedules.
@classmethod
def from_tff_result(cls, anon_tuple) -> 'ServerState':
"""Constructs a `ServerState` from any compatible anonymous tuple."""
model = ModelWeights(
trainable=tuple(anon_tuple.model.trainable),
non_trainable=tuple(anon_tuple.model.non_trainable))
return cls(
model=model,
optimizer_state=list(anon_tuple.optimizer_state),
round_num=anon_tuple.round_num)
@classmethod
def assign_weights_to_keras_model(cls, reference_model: ModelWeights,
keras_model: tf.keras.Model):
......@@ -280,8 +269,7 @@ class FederatedAveragingProcessAdapter(adapters.IterativeProcessPythonAdapter):
self._iterative_process = iterative_process
def initialize(self) -> ServerState:
initial_state = self._iterative_process.initialize()
return ServerState.from_tff_result(initial_state)
return self._iterative_process.initialize()
def next(
self,
......@@ -289,8 +277,6 @@ class FederatedAveragingProcessAdapter(adapters.IterativeProcessPythonAdapter):
data: Collection[tf.data.Dataset],
) -> adapters.IterationResult:
state, metrics = self._iterative_process.next(state, data)
state = ServerState.from_tff_result(state)
metrics = metrics._asdict(recursive=True)
outputs = None
return adapters.IterationResult(state, metrics, outputs)
......
......@@ -187,19 +187,6 @@ class ModelDeltaProcessTest(tf.test.TestCase):
train_gap_second_half = train_outputs[3]['loss'] - train_outputs[5]['loss']
self.assertLess(train_gap_second_half, train_gap_first_half)
def test_conversion_from_tff_result(self):
federated_data = [[_batch_fn()]]
iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
_uncompiled_model_builder,
client_optimizer_fn=tf.keras.optimizers.SGD,
server_optimizer_fn=tf.keras.optimizers.SGD)
state, _, _ = self._run_rounds(iterproc_adapter, federated_data, 1)
converted_state = fed_avg_schedule.ServerState.from_tff_result(state)
self.assertIsInstance(converted_state, fed_avg_schedule.ServerState)
self.assertIsInstance(converted_state.model, fed_avg_schedule.ModelWeights)
def test_build_with_preprocess_function(self):
test_dataset = tf.data.Dataset.range(5)
client_datasets_type = tff.FederatedType(
......@@ -239,9 +226,8 @@ class ModelDeltaProcessTest(tf.test.TestCase):
expected_type = tff.FunctionType(
parameter=(server_state_type, client_datasets_type),
result=(server_state_type, metrics_type))
self.assertEqual(
iterproc.next.type_signature,
expected_type,
self.assertTrue(
iterproc.next.type_signature.is_equivalent_to(expected_type),
msg='{s}\n!={t}'.format(
s=iterproc.next.type_signature, t=expected_type))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment