Skip to content
Snippets Groups Projects
Commit 770a22e9 authored by Zachary Garrett's avatar Zachary Garrett Committed by tensorflow-copybara
Browse files

Update FLARS code to work with 4ed574ab.

Also while we're here, update to some more modern TFF coding practices:
-  Use `tff.learning.framework.ModelWeights` instead of a custom type.
-  Delete `from_tff_result` conversion helper, it is no longer needed.
-  Remove unused test helper methods.
-  Replace numpy dependency with TF ops.

PiperOrigin-RevId: 322359152
parent 405ac3ac
No related branches found
No related tags found
No related merge requests found
......@@ -60,17 +60,10 @@ class ServerState(object):
model = attr.ib()
optimizer_state = attr.ib()
@classmethod
def from_tff_result(cls, anon_tuple):
return cls(
model=tff.learning.framework.ModelWeights.from_tff_result(
anon_tuple.model),
optimizer_state=list(anon_tuple.optimizer_state))
def _create_optimizer_vars(model, optimizer):
"""Generate variables for optimizer."""
model_weights = _get_weights(model)
model_weights = tff.learning.framework.ModelWeights.from_model(model)
delta = tf.nest.map_structure(tf.zeros_like, model_weights.trainable)
flat_trainable_weights = tf.nest.flatten(model_weights.trainable)
grads_and_vars = tf.nest.map_structure(
......@@ -83,14 +76,6 @@ def _create_optimizer_vars(model, optimizer):
return optimizer.variables()
def _get_weights(model):
model_weights = collections.namedtuple('ModelWeights',
'trainable non_trainable')
return model_weights(
trainable=tensor_utils.to_var_dict(model.trainable_variables),
non_trainable=tensor_utils.to_var_dict(model.non_trainable_variables))
@tf.function
def server_update(model, server_optimizer, server_optimizer_vars, server_state,
weights_delta, grads_norm):
......@@ -107,7 +92,7 @@ def server_update(model, server_optimizer, server_optimizer_vars, server_state,
Returns:
An updated `ServerState`.
"""
model_weights = _get_weights(model)
model_weights = tff.learning.framework.ModelWeights.from_model(model)
tf.nest.map_structure(lambda v, t: v.assign(t),
(model_weights, server_optimizer_vars),
(server_state.model, server_state.optimizer_state))
......@@ -140,7 +125,7 @@ def client_update(model, optimizer, dataset, initial_weights):
Returns:
A 'ClientOutput`.
"""
model_weights = _get_weights(model)
model_weights = tff.learning.framework.ModelWeights.from_model(model)
tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
initial_weights)
flat_trainable_weights = tuple(tf.nest.flatten(model_weights.trainable))
......@@ -221,7 +206,8 @@ def build_server_init_fn(model_fn, server_optimizer_fn):
# state.
server_optimizer_vars = _create_optimizer_vars(model, server_optimizer)
return ServerState(
model=_get_weights(model), optimizer_state=server_optimizer_vars)
model=tff.learning.framework.ModelWeights.from_model(model),
optimizer_state=server_optimizer_vars)
return server_init_tf
......@@ -373,13 +359,14 @@ def build_federated_averaging_process(
Returns:
A `tff.templates.IterativeProcess`.
"""
dummy_model_for_metadata = model_fn()
type_signature_grads_norm = tff.NamedTupleType([
with tf.Graph().as_default():
dummy_model_for_metadata = model_fn()
type_signature_grads_norm = tuple(
weight.dtype for weight in tf.nest.flatten(
_get_weights(dummy_model_for_metadata).trainable)
])
dummy_model_for_metadata.trainable_variables))
server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn)
server_state_type = server_init_tf.type_signature.result
server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
server_state_type,
......@@ -400,5 +387,5 @@ def build_federated_averaging_process(
return tff.templates.IterativeProcess(
initialize_fn=tff.federated_computation(
lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
next_fn=run_one_round_tff)
......@@ -15,7 +15,6 @@
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
......@@ -39,44 +38,10 @@ def _keras_model_fn():
loss=tf.keras.losses.SparseCategoricalCrossentropy())
def mnist_forward_pass(variables, batch):
inputs, label = batch
y = tf.nn.softmax(tf.matmul(inputs, variables.weights) + variables.bias)
predictions = tf.cast(tf.argmax(y, 1), tf.int32)
flat_labels = tf.reshape(label, [-1])
loss = -tf.reduce_mean(
tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
accuracy = tf.reduce_mean(
tf.cast(tf.equal(predictions, flat_labels), tf.float32))
num_examples = tf.cast(tf.size(batch['y']), tf.float32)
variables.num_examples.assign_add(num_examples)
variables.loss_sum.assign_add(loss * num_examples)
variables.accuracy_sum.assign_add(accuracy * num_examples)
return loss, predictions
def get_local_mnist_metrics(variables):
return collections.OrderedDict(
num_examples=variables.num_examples,
loss=variables.loss_sum / variables.num_examples,
accuracy=variables.accuracy_sum / variables.num_examples)
@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
return collections.OrderedDict(
num_examples=tff.federated_sum(metrics.num_examples),
loss=tff.federated_mean(metrics.loss, metrics.num_examples),
accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
def create_client_data():
emnist_batch = collections.OrderedDict(
label=[5], pixels=np.random.rand(28, 28))
label=[5],
pixels=tf.random.stateless_uniform(shape=(28, 28), seed=(7, 42)))
output_types = collections.OrderedDict(label=tf.int64, pixels=tf.float32)
output_shapes = collections.OrderedDict(
......@@ -116,8 +81,8 @@ class FlarsFedAvgTest(tf.test.TestCase):
@tf.function
def deterministic_batch():
return collections.OrderedDict(
x=np.ones([1, 784], dtype=np.float32),
y=np.ones([1, 1], dtype=np.int64))
x=tf.ones([1, 784], dtype=tf.float32),
y=tf.ones([1, 1], dtype=tf.int64))
batch = deterministic_batch()
federated_data = [[batch]]
......@@ -132,10 +97,10 @@ class FlarsFedAvgTest(tf.test.TestCase):
for _ in range(3):
keras_evaluate(server_state)
server_state, output = it_process.next(server_state, federated_data)
loss_list.append(output.loss)
loss_list.append(output['loss'])
keras_evaluate(server_state)
self.assertLess(np.mean(loss_list[1:]), loss_list[0])
self.assertLess(tf.reduce_mean(loss_list[1:]), loss_list[0])
def test_self_contained_example_keras_model(self):
client_data = create_client_data()
......@@ -148,7 +113,7 @@ class FlarsFedAvgTest(tf.test.TestCase):
for _ in range(2):
state, outputs = trainer.next(state, train_data)
# Track the loss.
losses.append(outputs.loss)
losses.append(outputs['loss'])
self.assertLess(losses[1], losses[0])
......@@ -164,7 +129,7 @@ def server_init(model, optimizer):
"""
optimizer_vars = flars_fedavg._create_optimizer_vars(model, optimizer)
return (flars_fedavg.ServerState(
model=flars_fedavg._get_weights(model),
model=tff.learning.framework.ModelWeights.from_model(model),
optimizer_state=optimizer_vars), optimizer_vars)
......@@ -178,17 +143,19 @@ class ServerTest(tf.test.TestCase):
grad_norm = [1.0, 1.0]
weights_delta = tf.nest.map_structure(
lambda t: tf.ones_like(t) * np.inf,
flars_fedavg._get_weights(model).trainable)
lambda t: tf.ones_like(t) * float('inf'),
flars_fedavg.tff.learning.framework.ModelWeights.from_model(
model).trainable)
old_model_vars = self.evaluate(state.model)
old_model_vars = state.model
for _ in range(2):
state = flars_fedavg.server_update(model, server_optimizer,
optimizer_vars, state, weights_delta,
grad_norm)
model_vars = self.evaluate(state.model)
self.assertAllClose(old_model_vars._asdict(), model_vars._asdict())
model_vars = state.model
# Assert the model hasn't changed.
self.assertAllClose(old_model_vars.trainable, model_vars.trainable)
self.assertAllClose(old_model_vars.non_trainable, model_vars.non_trainable)
class ClientTest(tf.test.TestCase):
......@@ -196,10 +163,9 @@ class ClientTest(tf.test.TestCase):
def test_self_contained_example(self):
client_data = create_client_data()
model = _keras_model_fn()
outputs = self.evaluate(
flars_fedavg.client_update(model, tf.keras.optimizers.SGD(0.1),
client_data(),
flars_fedavg._get_weights(model)))
outputs = flars_fedavg.client_update(
model, tf.keras.optimizers.SGD(0.1), client_data(),
flars_fedavg.tff.learning.framework.ModelWeights.from_model(model))
self.assertAllEqual(outputs.weights_delta_weight, 2)
# Expect a grad for each layer:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment