提交 8184a1c2 编辑于 作者: Keith Rush's avatar Keith Rush 提交者: tensorflow-copybara
浏览文件

Ensures parameter_type_hint is passed as a keyword-arg in value_impl.to_value.

Several sites in TFF are still passing in a context stack as the third argument to to_value; this changes ensures that this is no longer occurring.

PiperOrigin-RevId: 413784015
上级 4bbf4bc7
......@@ -65,7 +65,7 @@ def federated_computation_serializer(
result = yield value_impl.Value(
building_blocks.Reference(parameter_name, parameter_type))
annotated_result_type = type_conversions.infer_type(result)
result = value_impl.to_value(result, annotated_result_type, context_stack)
result = value_impl.to_value(result, annotated_result_type)
result_comp = result.comp
symbols_bound_in_context = context_stack.current.symbol_bindings
if symbols_bound_in_context:
......
......@@ -16,7 +16,7 @@
import abc
import collections
import itertools
from typing import Any, Union
from typing import Any, Optional, Union
import attr
import tensorflow as tf
......@@ -206,7 +206,7 @@ class Value(typed_object.TypedObject, metaclass=abc.ABCMeta):
arg = function_utils.pack_args(self.type_signature.parameter, args,
kwargs,
context_stack_impl.context_stack.current)
arg = to_value(arg, None, self).comp
arg = to_value(arg, None).comp
else:
arg = None
call = building_blocks.Call(self._comp, arg)
......@@ -296,9 +296,9 @@ def _dictlike_items_to_value(items, type_spec, container_type) -> Value:
def to_value(
arg: Any,
type_spec,
parameter_type_hint=None,
type_spec: Optional[computation_types.Type],
*,
parameter_type_hint=None,
zip_if_needed: bool = False,
) -> Value:
"""Converts the argument into an instance of the abstract class `tff.Value`.
......
......@@ -27,8 +27,6 @@ from tensorflow_federated.python.core.impl.federated_context import value_utils
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
_context_stack = context_stack_impl.context_stack
class ValueUtilsTest(parameterized.TestCase):
......@@ -59,7 +57,7 @@ class ValueUtilsTest(parameterized.TestCase):
@computations.federated_computation(
computation_types.FederatedType(tf.int32, placements.CLIENTS))
def _(x):
x = value_impl.to_value(x, None, _context_stack)
x = value_impl.to_value(x, None)
value_utils.ensure_federated_value(x, placements.CLIENTS)
return x
......@@ -68,7 +66,7 @@ class ValueUtilsTest(parameterized.TestCase):
@computations.federated_computation(
computation_types.FederatedType(tf.int32, placements.CLIENTS))
def _(x):
x = value_impl.to_value(x, None, _context_stack)
x = value_impl.to_value(x, None)
with self.assertRaises(TypeError):
value_utils.ensure_federated_value(x, placements.SERVER)
return x
......@@ -80,7 +78,7 @@ class ValueUtilsTest(parameterized.TestCase):
(computation_types.FederatedType(tf.int32, placements.CLIENTS),
computation_types.FederatedType(tf.int32, placements.CLIENTS))))
def _(x):
x = value_impl.to_value(x, None, _context_stack)
x = value_impl.to_value(x, None)
value_utils.ensure_federated_value(x)
return x
......@@ -91,7 +89,7 @@ class ValueUtilsTest(parameterized.TestCase):
(computation_types.FederatedType(tf.int32, placements.CLIENTS),
computation_types.FederatedType(tf.int32, placements.SERVER))))
def _(x):
x = value_impl.to_value(x, None, _context_stack)
x = value_impl.to_value(x, None)
with self.assertRaises(TypeError):
value_utils.ensure_federated_value(x)
return x
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册