Skip to content
Snippets Groups Projects
Commit e7199779 authored by Shanshan Wu's avatar Shanshan Wu Committed by tensorflow-copybara
Browse files

Fix a bug in public doc.

PiperOrigin-RevId: 277124116
parent b62640b8
No related branches found
No related tags found
No related merge requests found
......@@ -74,15 +74,15 @@ landing_page:
# Wrap a Keras model for use with TFF.
def model_fn():
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(10, tf.nn.softmax, input_shape=(784,),
kernel_initializer='zeros')
])
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
optimizer=tf.keras.optimizers.SGD(0.1),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
return tff.learning.from_compiled_keras_model(model, sample_batch)
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(10, tf.nn.softmax, input_shape=(784,),
kernel_initializer='zeros')
])
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
optimizer=tf.keras.optimizers.SGD(0.1),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
return tff.learning.from_compiled_keras_model(model, sample_batch)
# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.build_federated_averaging_process(model_fn)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment