提交 c40e3083 编辑于 作者: Michael Reneer's avatar Michael Reneer 提交者: tensorflow-copybara
浏览文件

Fix PY 3.8 warning about using ABCs from `collections`.

Do not use or import the ABCs from `collections`; this has been deprecated since Python 3.3, and in Python 3.9 it will stop working. These ABCs have moved to `collections.abc` and should be imported from that module.

PiperOrigin-RevId: 340678321
上级 ded6d2b9
......@@ -174,7 +174,7 @@ def is_name_value_pair(element, name_required=True, value_type=None):
Returns:
`True` if `element` is a named tuple element, otherwise `False`.
"""
if not isinstance(element, collections.Sequence) or len(element) != 2:
if not isinstance(element, collections.abc.Sequence) or len(element) != 2:
return False
if ((name_required or element[0] is not None) and
not isinstance(element[0], str)):
......
......@@ -71,7 +71,7 @@ class Struct(object):
TypeError: if the `elements` are not a list, or if any of the items on
the list is not a pair with a string at the first position.
"""
py_typecheck.check_type(elements, collections.Iterable)
py_typecheck.check_type(elements, collections.abc.Iterable)
values = []
names = []
name_to_index = {}
......
......@@ -489,7 +489,7 @@ class StructType(structure.Struct, Type, metaclass=_Intern):
@staticmethod
def _normalize_init_args(elements, convert=True):
py_typecheck.check_type(elements, collections.Iterable)
py_typecheck.check_type(elements, collections.abc.Iterable)
if convert:
if py_typecheck.is_named_tuple(elements):
elements = typing.cast(Any, elements)
......@@ -986,7 +986,7 @@ def to_type(spec) -> Type:
return StructWithPythonType(spec, type(spec))
elif py_typecheck.is_attrs(spec):
return _to_type_from_attrs(spec)
elif isinstance(spec, collections.Mapping):
elif isinstance(spec, collections.abc.Mapping):
# This is an unsupported mapping, likely a `dict`. StructType adds an
# ordering, which the original container did not have.
raise TypeError(
......
......@@ -387,8 +387,8 @@ class WrapParameterAsTupleTest(test_case.TestCase, parameterized.TestCase):
def assertSequenceEqual(self, a, b):
"""Assert two tff.SequenceType values are the same."""
if (isinstance(a, collections.Sequence) and
isinstance(b, collections.Sequence)):
if (isinstance(a, collections.abc.Sequence) and
isinstance(b, collections.abc.Sequence)):
sequence = zip(a, b)
elif isinstance(a, tf.data.Dataset) and isinstance(b, tf.data.Dataset):
sequence = tf.data.Dataset.zip(a, b)
......
......@@ -30,7 +30,7 @@ from tensorflow_federated.python.core.impl.executors import executor_base
from tensorflow_federated.python.core.impl.executors import executor_value_base
class HashableWrapper(collections.Hashable):
class HashableWrapper(collections.abc.Hashable):
"""A wrapper around non-hashable objects to be compared by identity."""
def __init__(self, target):
......@@ -80,17 +80,17 @@ def _get_hashable_key(value, type_spec):
elif isinstance(value, np.ndarray):
return ('<dtype={},shape={}>'.format(value.dtype,
value.shape), value.tobytes())
elif (isinstance(value, collections.Hashable) and
elif (isinstance(value, collections.abc.Hashable) and
not isinstance(value, (tf.Tensor, tf.Variable))):
# TODO(b/139200385): Currently Tensor and Variable returns True for
# `isinstance(value, collections.Hashable)` even when it's not hashable.
# Hence this workaround.
# `isinstance(value, collections.abc.Hashable)` even when it's not hashable.
# Hence this workaround.
return value
else:
return HashableWrapper(value)
class CachedValueIdentifier(collections.Hashable):
class CachedValueIdentifier(collections.abc.Hashable):
"""An identifier for a cached value."""
def __init__(self, identifier):
......@@ -132,7 +132,7 @@ class CachedValue(executor_value_base.ExecutorValue):
TypeError: If the arguments are of the wrong types.
"""
py_typecheck.check_type(identifier, CachedValueIdentifier)
py_typecheck.check_type(hashable_key, collections.Hashable)
py_typecheck.check_type(hashable_key, collections.abc.Hashable)
py_typecheck.check_type(type_spec, computation_types.Type)
if not asyncio.isfuture(target_future):
raise TypeError('Expected an asyncio future, got {}'.format(
......
......@@ -72,7 +72,7 @@ def infer_cardinalities(value, type_spec):
if type_spec.is_federated():
if type_spec.all_equal:
return {}
py_typecheck.check_type(value, collections.Sized)
py_typecheck.check_type(value, collections.abc.Sized)
return {type_spec.placement: len(value)}
elif type_spec.is_struct():
structure_value = structure.from_container(value, recursive=False)
......
......@@ -182,9 +182,9 @@ def _serialize_sequence_value(
.format(value_type, type_spec if type_spec is not None else 'unknown'))
# TFF must store the type spec here because TF will lose the ordering of the
# names for `tf.data.Dataset` that return elements of `collections.Mapping`
# type. This allows TFF to preserve and restore the key ordering upon
# deserialization.
# names for `tf.data.Dataset` that return elements of
# `collections.abc.Mapping` type. This allows TFF to preserve and restore the
# key ordering upon deserialization.
element_type = computation_types.to_type(value.element_spec)
return executor_pb2.Value(
sequence=executor_pb2.Value.Sequence(
......
......@@ -98,7 +98,7 @@ def get_tf_typespec_and_binding(parameter_type, arg_names, unpack=None):
pack_in_struct = False
arg_types, kwarg_types = [parameter_type], {}
py_typecheck.check_type(arg_names, collections.Iterable)
py_typecheck.check_type(arg_names, collections.abc.Iterable)
if len(arg_names) < len(arg_types):
raise ValueError(
'If provided, arg_names must be a list of at least {} strings to '
......@@ -357,7 +357,7 @@ def capture_result_from_graph(result, graph):
elif isinstance(result, structure.Struct):
return _get_bindings_for_elements(
structure.to_elements(result), graph, computation_types.StructType)
elif isinstance(result, collections.Mapping):
elif isinstance(result, collections.abc.Mapping):
if isinstance(result, collections.OrderedDict):
name_value_pairs = result.items()
else:
......
......@@ -33,7 +33,7 @@ def update_state(state, **kwargs):
"""
# TODO(b/129569441): Support Struct as well.
if not (py_typecheck.is_named_tuple(state) or py_typecheck.is_attrs(state) or
isinstance(state, collections.Mapping)):
isinstance(state, collections.abc.Mapping)):
raise TypeError('state must be a structure with named fields (e.g. '
'dict, attrs class, collections.namedtuple), '
'but found {}'.format(type(state)))
......@@ -50,6 +50,6 @@ def update_state(state, **kwargs):
'state does not contain a field named "{!s}"'.format(key))
d = state
d.update(kwargs)
if isinstance(state, collections.Mapping):
if isinstance(state, collections.abc.Mapping):
return d
return type(state)(**d)
......@@ -56,8 +56,8 @@ class KerasModelWrapper(object):
"""Forward pass of the model to get loss for a batch of data.
Args:
batch_input: A `collections.Mapping` with two keys, `x` for inputs and `y`
for labels.
batch_input: A `collections.abc.Mapping` with two keys, `x` for inputs and
`y` for labels.
training: Boolean scalar indicating training or inference mode.
Returns:
......
......@@ -46,8 +46,8 @@ class KerasModelWrapper(object):
"""Forward pass of the model to get loss for a batch of data.
Args:
batch_input: A `collections.Mapping` with two keys, `x` for inputs and `y`
for labels.
batch_input: A `collections.abc.Mapping` with two keys, `x` for inputs and
`y` for labels.
training: Boolean scalar indicating training or inference mode.
Returns:
......
......@@ -153,8 +153,8 @@ def state_with_new_model_weights(
raise TypeError('Element is not the same tensor type. old '
f'({old_value.dtype}, {old_value.shape}) != '
f'new ({new_value.dtype}, {new_value.shape})')
elif (isinstance(new_value, collections.Sequence) and
isinstance(old_value, collections.Sequence)):
elif (isinstance(new_value, collections.abc.Sequence) and
isinstance(old_value, collections.abc.Sequence)):
if len(old_value) != len(new_value):
raise TypeError('Model weights have different lengths: '
f'(old) {len(old_value)} != (new) {len(new_value)})\n'
......@@ -166,7 +166,7 @@ def state_with_new_model_weights(
'handled.\nOld weights structure: {old}\n'
'New weights structure: {new}\n'
'Must be one of (int, float, np.ndarray, tf.Tensor, '
'collections.Sequence)'.format(
'collections.abc.Sequence)'.format(
old=tf.nest.map_structure(type, old_value),
new=tf.nest.map_structure(type, new_value)))
......
......@@ -127,15 +127,15 @@ def from_keras_model(
else:
for type_elem in input_spec:
py_typecheck.check_type(type_elem, computation_types.TensorType)
if isinstance(input_spec, collections.Mapping):
if isinstance(input_spec, collections.abc.Mapping):
if 'x' not in input_spec:
raise ValueError(
'The `input_spec` is a collections.Mapping (e.g., a dict), so it '
'The `input_spec` is a collections.abc.Mapping (e.g., a dict), so it '
'must contain an entry with key `\'x\'`, representing the input(s) '
'to the Keras model.')
if 'y' not in input_spec:
raise ValueError(
'The `input_spec` is a collections.Mapping (e.g., a dict), so it '
'The `input_spec` is a collections.abc.Mapping (e.g., a dict), so it '
'must contain an entry with key `\'y\'`, representing the label(s) '
'to be used in the Keras loss(es).')
......@@ -312,7 +312,7 @@ class _KerasModel(model_lib.Model):
def _forward_pass(self, batch_input, training=True):
if hasattr(batch_input, '_asdict'):
batch_input = batch_input._asdict()
if isinstance(batch_input, collections.Mapping):
if isinstance(batch_input, collections.abc.Mapping):
inputs = batch_input.get('x')
else:
inputs = batch_input[0]
......@@ -321,7 +321,7 @@ class _KerasModel(model_lib.Model):
'Instead have keys {}'.format(list(batch_input.keys())))
predictions = self._keras_model(inputs, training=training)
if isinstance(batch_input, collections.Mapping):
if isinstance(batch_input, collections.abc.Mapping):
y_true = batch_input.get('y')
else:
y_true = batch_input[1]
......
......@@ -133,7 +133,7 @@ def enhance(model):
def _check_iterable_of_variables(variables):
py_typecheck.check_type(variables, collections.Iterable)
py_typecheck.check_type(variables, collections.abc.Iterable)
for v in variables:
py_typecheck.check_type(v, tf.Variable)
return variables
......
......@@ -350,7 +350,7 @@ class ConcreteClientData(ClientData):
will expose a `dataset_computation` attribute which can be used for
high-performance distributed simulations.
"""
py_typecheck.check_type(client_ids, collections.Iterable)
py_typecheck.check_type(client_ids, collections.abc.Iterable)
py_typecheck.check_callable(create_tf_dataset_for_client_fn)
if not client_ids:
......
......@@ -89,7 +89,7 @@ def concatenate_inputs_and_outputs(arg_list):
out_name_maps: Similar to `in_name_maps`.
"""
if not isinstance(arg_list, collections.Iterable):
if not isinstance(arg_list, collections.abc.Iterable):
raise TypeError('Please pass an iterable to '
'`concatenate_inputs_and_outputs`.')
(graph_def_list, init_op_names_list, in_names_list, out_names_list,
......
支持 Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册