Skip to content
Snippets Groups Projects
Commit f3e1cc88 authored by Zachary Charles's avatar Zachary Charles Committed by tensorflow-copybara
Browse files

Remove anonymous tuple conversion in training_loop_test.

PiperOrigin-RevId: 322377616
parent 770a22e9
No related branches found
No related tags found
No related merge requests found
......@@ -31,15 +31,6 @@ _Batch = collections.namedtuple('Batch', ['x', 'y'])
FLAGS = flags.FLAGS
def _from_tff_result(state):
return tff.learning.framework.ServerState(
model=tff.learning.ModelWeights(
list(state.model.trainable), list(state.model.non_trainable)),
optimizer_state=list(state.optimizer_state),
delta_aggregate_state=[],
model_broadcast_state=[])
class BasicAdapter(adapters.IterativeProcessPythonAdapter):
"""Converts iterative process results from anonymous tuples."""
......@@ -47,13 +38,10 @@ class BasicAdapter(adapters.IterativeProcessPythonAdapter):
self._iterative_process = iterative_process
def initialize(self):
initial_state = self._iterative_process.initialize()
return _from_tff_result(initial_state)
return self._iterative_process.initialize()
def next(self, state, data):
state, metrics = self._iterative_process.next(state, data)
state = _from_tff_result(state)
metrics = metrics._asdict(recursive=True)
outputs = None
return adapters.IterationResult(state, metrics, outputs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment