Skip to content
Snippets Groups Projects
Commit db837e9c authored by Zachary Garrett's avatar Zachary Garrett Committed by tensorflow-copybara
Browse files

Use python containers instead of `tff.NamedTupleType` for type specifications.

This will preserve the returned container value, making it usable with utilities such as `tf.nest`.

PiperOrigin-RevId: 322236441
parent 8094cf63
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data loader for Stackoverflow."""
import collections
from typing import List
from absl import logging
......@@ -134,19 +136,16 @@ def create_train_dataset_preprocess_fn(vocab: List[str],
else:
shuffle_buffer_size = max_training_elements_per_user
feature_dtypes = [
('creation_date', tf.string),
('title', tf.string),
('score', tf.int64),
('tags', tf.string),
('tokens', tf.string),
('type', tf.string),
]
@tff.tf_computation(
tff.SequenceType(
tff.NamedTupleType([(name, tff.TensorType(dtype=dtype, shape=()))
for name, dtype in feature_dtypes])))
feature_dtypes = collections.OrderedDict(
creation_date=tf.string,
title=tf.string,
score=tf.int64,
tags=tf.string,
tokens=tf.string,
type=tf.string,
)
@tff.tf_computation(tff.SequenceType(feature_dtypes))
def preprocess_train(dataset):
to_ids = build_to_ids_fn(vocab, max_seq_len)
dataset = dataset.take(max_training_elements_per_user)
......
......@@ -20,6 +20,16 @@ from tensorflow_federated.python.common_libs import test
from tensorflow_federated.python.research.utils.datasets import stackoverflow_dataset
TEST_DATA = collections.OrderedDict(
creation_date=(['unused date']),
title=(['unused title']),
score=([tf.constant(0, dtype=tf.int64)]),
tags=(['unused test tag']),
tokens=(['one must imagine']),
type=(['unused type']),
)
def _compute_length_of_dataset(ds):
return ds.reduce(0, lambda x, _: x + 1)
......@@ -103,10 +113,7 @@ class BatchAndSplitTest(tf.test.TestCase):
class DatasetPreprocessFnTest(tf.test.TestCase):
def test_train_preprocess_fn_return_dataset_element_spec(self):
token = collections.OrderedDict(tokens=([
'one must imagine',
]))
ds = tf.data.Dataset.from_tensor_slices(token)
ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
train_preprocess_fn = stackoverflow_dataset.create_train_dataset_preprocess_fn(
client_batch_size=32,
client_epochs_per_round=1,
......@@ -119,10 +126,7 @@ class DatasetPreprocessFnTest(tf.test.TestCase):
tf.TensorSpec(shape=[None, 10], dtype=tf.int64)))
def test_test_preprocess_fn_return_dataset_element_spec(self):
token = collections.OrderedDict(tokens=([
'one must imagine',
]))
ds = tf.data.Dataset.from_tensor_slices(token)
ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
test_preprocess_fn = stackoverflow_dataset.create_test_dataset_preprocess_fn(
max_seq_len=10, vocab=['one', 'must'])
test_preprocessed_ds = test_preprocess_fn(ds)
......@@ -131,10 +135,7 @@ class DatasetPreprocessFnTest(tf.test.TestCase):
tf.TensorSpec(shape=[None, 10], dtype=tf.int64)))
def test_train_preprocess_fn_returns_correct_sequence(self):
token = collections.OrderedDict(tokens=([
'one must imagine',
]))
ds = tf.data.Dataset.from_tensor_slices(token)
ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
train_preprocess_fn = stackoverflow_dataset.create_train_dataset_preprocess_fn(
client_batch_size=32,
client_epochs_per_round=1,
......@@ -148,10 +149,7 @@ class DatasetPreprocessFnTest(tf.test.TestCase):
self.evaluate(element[0]), np.array([[4, 1, 2, 3, 5, 0]]))
def test_test_preprocess_fn_returns_correct_sequence(self):
token = collections.OrderedDict(tokens=([
'one must imagine',
]))
ds = tf.data.Dataset.from_tensor_slices(token)
ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
test_preprocess_fn = stackoverflow_dataset.create_test_dataset_preprocess_fn(
max_seq_len=6, vocab=['one', 'must'])
test_preprocessed_ds = test_preprocess_fn(ds)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment