Skip to content
Snippets Groups Projects
Commit 405ac3ac authored by Zachary Charles's avatar Zachary Charles Committed by tensorflow-copybara
Browse files

Remove use of anonymous tuple conversion calls in research/adaptive_lr_decay.

PiperOrigin-RevId: 322264433
parent 8bb2ffb4
No related branches found
No related tags found
No related merge requests found
......@@ -60,22 +60,6 @@ class ServerState(object):
client_lr_callback = attr.ib()
server_lr_callback = attr.ib()
@classmethod
def from_tff_result(cls, anon_tuple, from_anon_client_callback,
from_anon_server_callback):
"""Constructs a `ServerState` from any compatible anonymous tuple."""
model = tff.learning.ModelWeights(
trainable=tuple(anon_tuple.model.trainable),
non_trainable=tuple(anon_tuple.model.non_trainable))
return cls(
model=model,
optimizer_state=list(anon_tuple.optimizer_state),
client_lr_callback=from_anon_client_callback(
anon_tuple.client_lr_callback),
server_lr_callback=from_anon_server_callback(
anon_tuple.server_lr_callback))
@classmethod
def assign_weights_to_keras_model(cls, reference_weights, keras_model):
"""Assign the model weights to the weights of a `tf.keras.Model`.
......@@ -398,9 +382,7 @@ def build_fed_avg_process(model_fn,
tff_iterative_process = tff.templates.IterativeProcess(
initialize_fn=initialize_fn, next_fn=run_one_round)
return FedAvgDecayAdapter(tff_iterative_process,
client_lr_callback.from_anonymous_tuple,
server_lr_callback.from_anonymous_tuple)
return FedAvgDecayAdapter(tff_iterative_process)
class FedAvgDecayAdapter(adapters.IterativeProcessPythonAdapter):
......@@ -410,24 +392,14 @@ class FedAvgDecayAdapter(adapters.IterativeProcessPythonAdapter):
recording metrics.
"""
def __init__(self, iterative_process, from_anon_client_callback,
from_anon_server_callback):
def __init__(self, iterative_process):
self._iterative_process = iterative_process
self._from_anon_client_callback = from_anon_client_callback
self._from_anon_server_callback = from_anon_server_callback
def initialize(self):
initial_state = self._iterative_process.initialize()
return ServerState.from_tff_result(initial_state,
self._from_anon_client_callback,
self._from_anon_server_callback)
return self._iterative_process.initialize()
def next(self, state, data):
state, initial_metrics, metrics = self._iterative_process.next(state, data)
state = ServerState.from_tff_result(state, self._from_anon_client_callback,
self._from_anon_server_callback)
initial_metrics = initial_metrics._asdict(recursive=True)
metrics = metrics._asdict(recursive=True)
total_metrics = {
'before_training': initial_metrics,
'during_training': metrics
......
......@@ -77,7 +77,7 @@ class AdaptiveFedAvgTest(tf.test.TestCase):
state, outputs = iterative_process.next(state, client_datasets)
logging.info('Round %d: %s', round_num, outputs)
logging.info('Model: %s', state.model)
train_outputs.append(outputs)
train_outputs.append(outputs['train'])
return state, train_outputs
def test_comparable_to_fed_avg(self):
......@@ -114,7 +114,7 @@ class AdaptiveFedAvgTest(tf.test.TestCase):
for i in range(5):
self.assertAllClose(train_outputs[i]['during_training']['loss'],
reference_train_outputs[i].train['loss'], 1e-4)
reference_train_outputs[i]['loss'], 1e-4)
def test_fed_avg_without_decay_decreases_loss(self):
client_lr_callback = callbacks.create_reduce_lr_on_plateau(
......
......@@ -85,35 +85,6 @@ class ReduceLROnPlateau(object):
cooldown = attr.ib(default=None)
cooldown_counter = attr.ib(default=None)
@classmethod
def from_anonymous_tuple(cls, anon_tuple):
"""Creates a `ReduceLROnPlateau` instance from an anonymout tuple.
Used to convert objects such as anonymous tuples with matching named
attributes into `ReduceLROnPlateau` objects.
Args:
anon_tuple: An object with named attributes matching that of
`ReduceLROnPlateau`.
Returns:
A `ReduceLROnPlateau` instance.
"""
return cls(
learning_rate=anon_tuple.learning_rate,
monitor=anon_tuple.monitor,
decay_factor=anon_tuple.decay_factor,
minimize=anon_tuple.minimize,
best=anon_tuple.best,
min_delta=anon_tuple.min_delta,
min_lr=anon_tuple.min_lr,
window_size=anon_tuple.window_size,
metrics_window=list(anon_tuple.metrics_window),
patience=anon_tuple.patience,
wait=anon_tuple.wait,
cooldown=anon_tuple.cooldown,
cooldown_counter=anon_tuple.cooldown_counter)
def create_reduce_lr_on_plateau(**kwargs):
"""Initializes a callback in a way that automatically infers attributes."""
......
......@@ -105,11 +105,9 @@ class DecayIterativeProcessBuilderTest(tf.test.TestCase):
server_state_type = tff.FederatedType(
adaptive_fed_avg.ServerState(
model=tff.learning.ModelWeights(
trainable=[
tff.TensorType(tf.float32, [1, 1]),
tff.TensorType(tf.float32, [1])
],
non_trainable=[]),
trainable=(tff.TensorType(tf.float32, [1, 1]),
tff.TensorType(tf.float32, [1])),
non_trainable=()),
optimizer_state=[tf.int64],
client_lr_callback=lr_callback_type,
server_lr_callback=lr_callback_type), tff.SERVER)
......@@ -129,11 +127,12 @@ class DecayIterativeProcessBuilderTest(tf.test.TestCase):
mean_squared_error=tff.TensorType(tf.float32),
loss=tff.TensorType(tf.float32)), tff.SERVER)
self.assertEqual(
iterative_process._iterative_process.next.type_signature,
tff.FunctionType(
parameter=(server_state_type, dataset_type),
result=(server_state_type, metrics_type, metrics_type)))
actual_type = iterative_process._iterative_process.next.type_signature
expected_type = tff.FunctionType(
parameter=(server_state_type, dataset_type),
result=(server_state_type, metrics_type, metrics_type))
self.assertTrue(actual_type.is_equivalent_to(expected_type))
def test_iterative_process_decreases_loss(self):
iterative_process = decay_iterative_process_builder.from_flags(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment