Commit 8d55fb7e authored by Zachary Garrett

Don't assign variables immediately after initialization when the initial_value is already set.
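As a rough illustration of this point (a minimal graph-mode sketch with made-up shapes, not code from this commit): once a tf.Variable is constructed with an initial_value, running the variable initializer already gives it that value, so an extra tf.group of tf.assign ops right after initialization is redundant.

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  initial_weights = tf.zeros([784, 10])  # hypothetical stand-in for state.model.weights
  initial_bias = tf.zeros([10])          # hypothetical stand-in for state.model.bias
  weights = tf.Variable(initial_value=initial_weights, name='weights')
  bias = tf.Variable(initial_value=initial_bias, name='bias')
  init_model = tf.compat.v1.global_variables_initializer()
  read_bias = bias.read_value()

with tf.compat.v1.Session(graph=graph) as sess:
  sess.run(init_model)        # the variables now hold their initial_value
  print(sess.run(read_bias))  # no follow-up tf.assign was needed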

Update to TF 2.0 optimizers and use gradient tape.
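In TF 2.x this is the standard tape-based training step: compute the loss under a tf.GradientTape and hand the gradients to a tf.keras.optimizers optimizer via apply_gradients, rather than wrapping Optimizer.minimize in control dependencies. A self-contained eager sketch with hypothetical shapes and learning rate (written with tf.math.log and axis instead of the deprecated tf.log and reduction_indices spellings that appear in the diff below):

import tensorflow as tf

weights = tf.Variable(tf.zeros([784, 10]), name='weights')
bias = tf.Variable(tf.zeros([10]), name='bias')
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(x, y):
  """One gradient step computed with GradientTape instead of Optimizer.minimize."""
  with tf.GradientTape() as tape:
    pred_y = tf.nn.softmax(tf.matmul(x, weights) + bias)
    loss = -tf.reduce_mean(
        tf.reduce_sum(tf.one_hot(y, 10) * tf.math.log(pred_y), axis=1))
  grads = tape.gradient(loss, [weights, bias])
  optimizer.apply_gradients(zip(grads, [weights, bias]))
  return loss

# Example usage with a random batch.
x = tf.random.uniform([32, 784])
y = tf.random.uniform([32], maxval=10, dtype=tf.int32)
print(train_step(x, y).numpy())

apply_gradients takes flat (gradient, variable) pairs, which is why the change below flattens the model_vars namedtuple with tf.nest.flatten before zipping it with the gradients.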

PiperOrigin-RevId: 272705511
parent 1b69104a
@@ -181,25 +181,27 @@ def get_mnist_training_example():
     model_vars = model_nt(
         weights=tf.Variable(initial_value=state.model.weights, name='weights'),
         bias=tf.Variable(initial_value=state.model.bias, name='bias'))
-    with tf.control_dependencies([tf.compat.v1.global_variables_initializer()]):
-      init_model = tf.group(
-          tf.assign(model_vars.weights, state.model.weights),
-          tf.assign(model_vars.bias, state.model.bias))
-    optimizer = tf.train.GradientDescentOptimizer(state.learning_rate)
+    init_model = tf.compat.v1.global_variables_initializer()
+    optimizer = tf.keras.optimizers.SGD(state.learning_rate)
     @tf.function
     def reduce_fn(loop_state, batch):
-      pred_y = tf.nn.softmax(
-          tf.matmul(batch.x, model_vars.weights) + model_vars.bias)
-      loss = -tf.reduce_mean(
-          tf.reduce_sum(
-              tf.one_hot(batch.y, 10) * tf.log(pred_y), reduction_indices=[1]))
-      with tf.control_dependencies([optimizer.minimize(loss)]):
-        return loop_state_nt(
-            num_examples=loop_state.num_examples + 1,
-            total_loss=loop_state.total_loss + loss)
+      """Compute a single gradient step on a given batch of examples."""
+      with tf.GradientTape() as tape:
+        pred_y = tf.nn.softmax(
+            tf.matmul(batch.x, model_vars.weights) + model_vars.bias)
+        loss = -tf.reduce_mean(
+            tf.reduce_sum(
+                tf.one_hot(batch.y, 10) * tf.log(pred_y),
+                reduction_indices=[1],
+            ))
+      grads = tape.gradient(loss, model_vars)
+      optimizer.apply_gradients(
+          zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
+      return loop_state_nt(
+          num_examples=loop_state.num_examples + 1,
+          total_loss=loop_state.total_loss + loss)
     with tf.control_dependencies([init_model]):
       loop_state = data.reduce(
......