提交 ed5c3a9a 编辑于 作者: Zheng Xu's avatar Zheng Xu 提交者: tensorflow-copybara

Remove the final softmax layer from the CNN model in simple_fedavg.

PiperOrigin-RevId: 344259493
上级 8d3641ed
......@@ -81,9 +81,6 @@ def get_emnist_dataset():
def create_original_fedavg_cnn_model(only_digits=True):
"""The CNN model used in https://arxiv.org/abs/1602.05629.
This function is duplicated from research/optimization/emnist/models.py to
make this example completely stand-alone.
only_digits: If True, uses a final layer with 10 outputs, for use with the
digits only EMNIST dataset. If False, uses 62 outputs for the larger
......@@ -115,7 +112,6 @@ def create_original_fedavg_cnn_model(only_digits=True):
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10 if only_digits else 62),
return model
......@@ -138,7 +134,7 @@ def main(argv):
def tff_model_fn():
"""Constructs a fully initialized model for use in federated averaging."""
keras_model = create_original_fedavg_cnn_model(only_digits=True)
loss = tf.keras.losses.SparseCategoricalCrossentropy()
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
return simple_fedavg_tf.KerasModelWrapper(keras_model,
test_data.element_spec, loss)
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册