Skip to content
Snippets Groups Projects
Unverified Commit 6947610c authored by Ronald Seoh's avatar Ronald Seoh
Browse files

update example code on TFF front page: dummy_batch was deprecated from from_keras_model()

parent e7ae539e
No related branches found
No related tags found
No related merge requests found
...@@ -69,10 +69,6 @@ landing_page: ...@@ -69,10 +69,6 @@ landing_page:
# Pick a subset of client devices to participate in training. # Pick a subset of client devices to participate in training.
train_data = [client_data(n) for n in range(3)] train_data = [client_data(n) for n in range(3)]
# Grab a single batch of data so that TFF knows what data looks like.
sample_batch = tf.nest.map_structure(
lambda x: x.numpy(), iter(train_data[0]).next())
# Wrap a Keras model for use with TFF. # Wrap a Keras model for use with TFF.
def model_fn(): def model_fn():
model = tf.keras.models.Sequential([ model = tf.keras.models.Sequential([
...@@ -81,7 +77,7 @@ landing_page: ...@@ -81,7 +77,7 @@ landing_page:
]) ])
return tff.learning.from_keras_model( return tff.learning.from_keras_model(
model, model,
dummy_batch=sample_batch, input_spec=train_data[0].element_spec,
loss=tf.keras.losses.SparseCategoricalCrossentropy(), loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]) metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment