Skip to content
Snippets Groups Projects
Commit 1b69104a authored by Zachary Garrett's avatar Zachary Garrett
Browse files

Used tf.Variable.assign() instead of tf.assign() to be TF 2.0 compliant.

PiperOrigin-RevId: 272689129
parent 418e6997
No related branches found
No related tags found
No related merge requests found
......@@ -105,19 +105,21 @@ def server_update(model, server_optimizer, server_optimizer_vars, server_state,
Returns:
An updated `ServerState`.
"""
tf.nest.map_structure(tf.assign, (_get_weights(model), server_optimizer_vars),
model_weights = _get_weights(model)
# Initialize the model with the current state.
tf.nest.map_structure(lambda a, b: a.assign(b),
(model_weights, server_optimizer_vars),
(server_state.model, server_state.optimizer_state))
# Apply the update to the model.
grads_and_vars = tf.nest.map_structure(
lambda x, v: (-1.0 * x, v), tf.nest.flatten(weights_delta),
tf.nest.flatten(_get_weights(model).trainable))
tf.nest.flatten(model_weights.trainable))
server_optimizer.apply_gradients(grads_and_vars, name='server_update')
# Create a new state based on the updated model.
return tff.utils.update_state(
server_state,
model=_get_weights(model),
optimizer_state=server_optimizer_vars)
server_state, model=model_weights, optimizer_state=server_optimizer_vars)
@tf.function
......@@ -132,7 +134,9 @@ def client_update(model, dataset, initial_weights):
Returns:
A 'ClientOutput`.
"""
tf.nest.map_structure(tf.assign, _get_weights(model), initial_weights)
model_weights = _get_weights(model)
tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
initial_weights)
@tf.function
def reduce_fn(num_examples_sum, batch):
......@@ -143,8 +147,8 @@ def client_update(model, dataset, initial_weights):
num_examples_sum = dataset.reduce(
initial_state=tf.constant(0), reduce_func=reduce_fn)
weights_delta = tf.nest.map_structure(tf.subtract,
_get_weights(model).trainable,
weights_delta = tf.nest.map_structure(lambda a, b: a - b,
model_weights.trainable,
initial_weights.trainable)
aggregated_outputs = model.report_local_outputs()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment