Skip to content
Snippets Groups Projects
Commit 025b2cba authored by Zachary Charles's avatar Zachary Charles Committed by Zachary Garrett
Browse files

Fix flag prefix issues in federated trainer due to abbreviation of "stackoverflow".

PiperOrigin-RevId: 332259391
parent f99b51e4
No related branches found
No related tags found
No related merge requests found
......@@ -148,6 +148,14 @@ TASK_FLAGS = collections.OrderedDict(
stackoverflow_nwp=so_nwp_flags,
stackoverflow_lr=so_lr_flags)
TASK_FLAG_PREFIXES = collections.OrderedDict(
cifar100='cifar100',
emnist_cr='emnist_cr',
emnist_ae='emnist_ae',
shakespeare='shakespeare',
stackoverflow_nwp='so_nwp',
stackoverflow_lr='so_lr')
def _get_hparam_flags():
"""Returns an ordered dictionary of pertinent hyperparameter flags."""
......@@ -179,18 +187,17 @@ def _get_task_args():
An ordered dictionary of (arg_name, arg_value) pairs.
"""
task_name = FLAGS.task
task_args = collections.OrderedDict()
if task_name in TASK_FLAGS:
task_flag_list = TASK_FLAGS[task_name]
task_flag_dict = utils_impl.lookup_flag_values(task_flag_list)
for key in task_flag_dict:
if key.startswith(task_name):
value = task_flag_dict.pop(key)
key = key[len(task_name):].lstrip('_-')
task_flag_dict[key] = value
return task_flag_dict
else:
return collections.OrderedDict()
task_flag_prefix = TASK_FLAG_PREFIXES[task_name]
for (key, value) in task_flag_dict.items():
if key.startswith(task_flag_prefix):
key = key[len(task_flag_prefix):].lstrip('_-')
task_args[key] = value
return task_args
def main(argv):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment