Commit a89c263f, authored by Zheng Xu, committed by tensorflow-copybara

Add Multi-GPU configuration to simple_fedavg as an example.

This is a temporary public exposure before an official tutorial on multi-GPU configuration is released. Also update some default parameters to be more accelerator-friendly.

PiperOrigin-RevId: 346644343
Parent d05c6ef9
@@ -35,8 +35,8 @@ flags.DEFINE_integer('train_clients_per_round', 2,
                      'How many clients to sample per round.')
 flags.DEFINE_integer('client_epochs_per_round', 1,
                      'Number of epochs in the client to take per round.')
-flags.DEFINE_integer('batch_size', 20, 'Batch size used on the client.')
-flags.DEFINE_integer('test_batch_size', 100, 'Minibatch size of test data.')
+flags.DEFINE_integer('batch_size', 16, 'Batch size used on the client.')
+flags.DEFINE_integer('test_batch_size', 128, 'Minibatch size of test data.')
 # Optimizer configuration (this defines one or more flags per optimizer).
 flags.DEFINE_float('server_learning_rate', 1.0, 'Server learning rate.')
@@ -129,6 +129,17 @@ def main(argv):
   if len(argv) > 1:
     raise app.UsageError('Too many command-line arguments.')
+  # If GPU is provided, TFF will by default use the first GPU like TF. The
+  # following lines will configure TFF to use multi-GPUs and distribute client
+  # computation on the GPUs. Note that we put server computation on CPU to
+  # avoid a potential out-of-memory issue when a large number of clients is
+  # sampled per round. The client devices below can be an empty list when no
+  # GPU can be detected by TF.
+  client_devices = tf.config.list_logical_devices('GPU')
+  server_device = tf.config.list_logical_devices('CPU')[0]
+  tff.backends.native.set_local_execution_context(
+      server_tf_device=server_device, client_tf_devices=client_devices)
+
   train_data, test_data = get_emnist_dataset()

   def tff_model_fn():
......
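For readers who want to try the same setup outside this patch, below is a minimal standalone sketch of the multi-GPU configuration added in main() above. It is an illustration under assumptions, not part of the commit: it presumes the TFF version at this revision exposes tff.backends.native.set_local_execution_context with the server_tf_device and client_tf_devices keyword arguments used in the diff, and that TensorFlow and TensorFlow Federated are importable as tf and tff as in emnist_fedavg_main.py.

# Sketch only: mirrors the configuration added in main() above, pinning
# server-side computation to the CPU and spreading client computation
# across every GPU that TF can see.
import tensorflow as tf
import tensorflow_federated as tff

# May be an empty list on a CPU-only machine, as the patch comment notes.
client_devices = tf.config.list_logical_devices('GPU')

# Use the first logical CPU for server computation to avoid GPU
# out-of-memory errors when many clients are sampled per round.
server_device = tf.config.list_logical_devices('CPU')[0]

# Configure TFF's local execution context with the chosen device placement.
tff.backends.native.set_local_execution_context(
    server_tf_device=server_device, client_tf_devices=client_devices)

Federated computations run after this call use the configured placement (client work on the GPUs, server work on the CPU), which is the setting the updated accelerator-friendly defaults above are aimed at.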