提交 45e358da 编辑于 作者: Shanshan Wu's avatar Shanshan Wu 提交者: tensorflow-copybara
浏览文件

Change `max_num_samples` to `max_num_clients` in...

Change `max_num_samples` to `max_num_clients` in `tff.learning.build_personalization_eval`, and update the docstring. No change in the API or functionality.

PiperOrigin-RevId: 343883721
上级 cb8e4f80
......@@ -173,7 +173,7 @@ def main(argv):
model_fn=model_fn,
personalize_fn_dict=personalize_fn_dict,
baseline_evaluate_fn=evaluate_fn,
max_num_samples=100) # Metrics from at most 100 clients will be returned.
max_num_clients=100) # Metrics from at most 100 clients will be returned.
# Train a global model using the standard FedAvg algorithm.
num_total_rounds = 5
......@@ -201,7 +201,7 @@ def main(argv):
# maps keys (strategy names) in `personalize_fn_dict` to the evaluation
# metrics of the corresponding personalization strategies.
#
# Only metrics from at most `max_num_samples` participating clients are
# Only metrics from at most `max_num_clients` participating clients are
# collected (clients are sampled without replacement). Each metric is
# mapped to a list of scalars (each scalar comes from one client). Metric
# values at the same position, e.g., `metric_1[i]`, `metric_2[i]`, ...,
......
......@@ -30,13 +30,13 @@ from tensorflow_federated.python.learning import model_utils
def build_personalization_eval(model_fn,
personalize_fn_dict,
baseline_evaluate_fn,
max_num_samples=100,
max_num_clients=100,
context_tff_type=None):
"""Builds the TFF computation for evaluating personalization strategies.
The returned TFF computation broadcasts model weights from `tff.SERVER` to
`tff.CLIENTS`. Each client evaluates the personalization strategies given in
`personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
`personalize_fn_dict`. Evaluation metrics from at most `max_num_clients`
participating clients are collected to the server.
NOTE: The functions in `personalize_fn_dict` and `baseline_evaluate_fn` are
......@@ -67,10 +67,12 @@ def build_personalization_eval(model_fn,
`OrderedDict`) of `string` metric names to scalar `tf.Tensor`s. This
function is *only* used to compute the baseline metrics of the initial
model.
max_num_samples: A positive `int` specifying the maximum number of metric
samples to collect in a round. Each sample contains the personalization
metrics from a single client. If the number of participating clients in a
round is smaller than this value, all clients' metrics are collected.
max_num_clients: A positive `int` specifying the maximum number of clients
to collect metrics in a round (default is 100). The clients are sampled
without replacement. For each sampled client, all the personalization
metrics from this client will be collected. If the number of participating
clients in a round is smaller than this value, then metrics from all
clients will be collected.
context_tff_type: A `tff.Type` of the optional context object used by the
personalization strategies defined in `personalization_fn_dict`. We use a
context object to hold any extra information (in addition to the training
......@@ -93,7 +95,7 @@ def build_personalization_eval(model_fn,
(computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
`personalize_fn_dict` to the evaluation metrics of the corresponding
personalization strategies.
* Note: only metrics from at most `max_num_samples` participating clients
* Note: only metrics from at most `max_num_clients` participating clients
(sampled without replacement) are collected to the SERVER. All collected
metrics are stored in a single `OrderedDict` (`personalization_metrics`
shown above), where each metric is mapped to a list of scalars (each
......@@ -103,7 +105,7 @@ def build_personalization_eval(model_fn,
Raises:
TypeError: If arguments are of the wrong types.
ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
ValueError: If `max_num_samples` is not positive.
ValueError: If `max_num_clients` is not positive.
"""
# Obtain the types by constructing the model first.
# TODO(b/124477628): Replace it with other ways of handling metadata.
......@@ -153,9 +155,9 @@ def build_personalization_eval(model_fn,
final_metrics.update(p13n_metrics)
return final_metrics
py_typecheck.check_type(max_num_samples, int)
if max_num_samples <= 0:
raise ValueError('max_num_samples must be a positive integer.')
py_typecheck.check_type(max_num_clients, int)
if max_num_clients <= 0:
raise ValueError('max_num_clients must be a positive integer.')
@computations.federated_computation(
computation_types.FederatedType(model_weights_type, placements.SERVER),
......@@ -170,7 +172,7 @@ def build_personalization_eval(model_fn,
# make sure that it is proper to collect those metrics from clients.
# TODO(b/147889283): Add a link to the TFF doc once it exists.
results = federated_aggregations.federated_sample(client_final_metrics,
max_num_samples)
max_num_clients)
return results
return personalization_eval
......
......@@ -399,16 +399,16 @@ class PersonalizationEvalTest(test_case.TestCase):
p13n_fn_dict = _create_p13n_fn_dict(learning_rate=1.0)
with self.assertRaises(TypeError):
# `max_num_samples` should be an `int`.
bad_num_samples = 1.0
# `max_num_clients` should be an `int`.
bad_num_clients = 1.0
p13n_eval.build_personalization_eval(
model_fn, p13n_fn_dict, _evaluate_fn, max_num_samples=bad_num_samples)
model_fn, p13n_fn_dict, _evaluate_fn, max_num_clients=bad_num_clients)
with self.assertRaises(ValueError):
# `max_num_samples` should be a positive `int`.
bad_num_samples = 0
# `max_num_clients` should be a positive `int`.
bad_num_clients = 0
p13n_eval.build_personalization_eval(
model_fn, p13n_fn_dict, _evaluate_fn, max_num_samples=bad_num_samples)
model_fn, p13n_fn_dict, _evaluate_fn, max_num_clients=bad_num_clients)
def test_success_with_small_sample_size(self):
......@@ -419,7 +419,7 @@ class PersonalizationEvalTest(test_case.TestCase):
p13n_fn_dict = _create_p13n_fn_dict(learning_rate=1.0)
federated_p13n_eval = p13n_eval.build_personalization_eval(
model_fn, p13n_fn_dict, _evaluate_fn, max_num_samples=1)
model_fn, p13n_fn_dict, _evaluate_fn, max_num_clients=1)
# Perform p13n eval on two clients.
results = federated_p13n_eval(zero_model_weights, [
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册