提交 09b14dc9 编辑于 作者: Taylor Cramer's avatar Taylor Cramer 提交者: tensorflow-copybara
浏览文件

Clarify error message following failed value serialization

Previously, we relied on numpy to provide helpful error messages in the case where the provided value did not match the required dtype. However, this resulted in somewhat unfriendly error messages such as the following (in the case where an int64 tensor was expected but a tf.data.Dataset was provided):

TypeError: Cannot cast scalar from dtype('O') to dtype('int64') according to the rule 'same_kind'

The additional context provided by this error message (the value being serialized, the Python type of the value being serialized, and the desired TFF TensorType) should help users find and correct errors more easily.

PiperOrigin-RevId: 411833481
上级 0fd36cab
......@@ -86,6 +86,7 @@ def _serialize_tensor_value(
TypeError: If the arguments are of the wrong types.
ValueError: If the value is malformed.
"""
original_value = value
if tf.is_tensor(value):
if isinstance(value, tf.Variable):
value = value.read_value()
......@@ -107,7 +108,13 @@ def _serialize_tensor_value(
raise TypeError(f'Cannot serialize tensor with shape {value.shape} to '
f'shape {type_spec.shape}.')
if value.dtype != type_spec.dtype.as_numpy_dtype:
value = value.astype(type_spec.dtype.as_numpy_dtype, casting='same_kind')
try:
value = value.astype(type_spec.dtype.as_numpy_dtype, casting='same_kind')
except TypeError as te:
value_type_string = py_typecheck.type_string(type(original_value))
raise TypeError(
f'Failed to serialize value of Python type {value_type_string} to '
f'a tensor of type {type_spec}.\nValue: {original_value}') from te
return serialization_bindings.serialize_tensor_value(value), type_spec
......
......@@ -28,7 +28,6 @@ from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_serialization
# Convenience aliases.
TensorType = computation_types.TensorType
......@@ -110,20 +109,20 @@ class ValueSerializationtest(test_case.TestCase, parameterized.TestCase):
self.assert_types_identical(type_spec, TensorType(tf.float32))
self.assertAllEqual(x, y)
def test_serialize_raises_on_incompatible_dtype_float_to_int(self):
x = tf.constant(10.0)
with self.assertRaisesRegex(TypeError, 'Failed to serialize value'):
value_serialization.serialize_value(x, TensorType(tf.int32))
def test_serialize_deserialize_tensor_value_with_different_dtype(self):
with self.subTest('float2int'):
x = tf.constant(10.0)
with self.assertRaisesRegex(TypeError, 'Cannot cast scalar'):
value_serialization.serialize_value(x, TensorType(tf.int32))
with self.subTest('int2float'):
x = tf.constant(10)
value_proto, value_type = value_serialization.serialize_value(
x, TensorType(tf.float32))
self.assertIsInstance(value_proto, executor_pb2.Value)
self.assert_types_identical(value_type, TensorType(tf.float32))
y, type_spec = value_serialization.deserialize_value(value_proto)
self.assert_types_identical(type_spec, TensorType(tf.float32))
self.assertEqual(y, 10.0)
x = tf.constant(10)
value_proto, value_type = value_serialization.serialize_value(
x, TensorType(tf.float32))
self.assertIsInstance(value_proto, executor_pb2.Value)
self.assert_types_identical(value_type, TensorType(tf.float32))
y, type_spec = value_serialization.deserialize_value(value_proto)
self.assert_types_identical(type_spec, TensorType(tf.float32))
self.assertEqual(y, 10.0)
def test_serialize_deserialize_tensor_value_with_nontrivial_shape(self):
x = tf.constant([10, 20, 30])
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册