Skip to content
Snippets Groups Projects
Commit c82f1bfa authored by Shanshan Wu's avatar Shanshan Wu Committed by tensorflow-copybara
Browse files

A problem shows up with using subclasses of `tff.learning.Model`: after...

A problem shows up with using subclasses of `tff.learning.Model`: after wrapping the model as an EnhancedModel, one cannot access the methods that are specifically defined by the subclass model.

This CL removes the EnhancedModel wrapper used when computing baseline metrics and training personalized models. This makes sure that users can access the full functionality of the model returned by `model_fn`.

PiperOrigin-RevId: 321578480
parent 2b531140
No related branches found
No related tags found
No related merge requests found
......@@ -201,11 +201,12 @@ def _remove_batch_dim(spec):
def _compute_baseline_metrics(model_fn, initial_model_weights, test_data,
baseline_evaluate_fn):
"""Evaluate the model with weights being the `initial_model_weights`."""
model = model_utils.enhance(model_fn())
model = model_fn()
model_weights = model_utils.ModelWeights.from_model(model)
@tf.function
def assign_and_compute():
tff.utils.assign(model.weights, initial_model_weights)
tff.utils.assign(model_weights, initial_model_weights)
py_typecheck.check_callable(baseline_evaluate_fn)
return baseline_evaluate_fn(model, test_data)
......@@ -215,7 +216,8 @@ def _compute_baseline_metrics(model_fn, initial_model_weights, test_data,
def _compute_p13n_metrics(model_fn, initial_model_weights, train_data,
test_data, personalize_fn_dict, context):
"""Train and evaluate the personalized models."""
model = model_utils.enhance(model_fn())
model = model_fn()
model_weights = model_utils.ModelWeights.from_model(model)
# Construct the `personalize_fn` (and the associated `tf.Variable`s) here.
# This ensures that the new variables are created in the graphs that TFF
# controls. This is the key reason why we need `personalize_fn_dict` to
......@@ -233,7 +235,7 @@ def _compute_p13n_metrics(model_fn, initial_model_weights, train_data,
def loop_and_compute():
p13n_metrics = collections.OrderedDict()
for name, personalize_fn in personalize_fns.items():
tff.utils.assign(model.weights, initial_model_weights)
tff.utils.assign(model_weights, initial_model_weights)
py_typecheck.check_callable(personalize_fn)
p13n_metrics[name] = personalize_fn(model, train_data, test_data, context)
return p13n_metrics
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment