Commit 238bf340 authored by Zachary Charles's avatar Zachary Charles Committed by tensorflow-copybara
Browse files

Move learning-specific templates to their own library.

PiperOrigin-RevId: 410595976
parent 8a0c8866
......@@ -28,16 +28,11 @@ py_library(
visibility = ["//tensorflow_federated:__pkg__"],
deps = [
":client_weight_lib",
":client_works",
":composers",
":debug_measurements",
":distributors",
":federated_averaging",
":federated_evaluation",
":federated_sgd",
":finalizers",
":keras_utils",
":learning_process",
":model",
":model_update_aggregator",
":model_utils",
......@@ -49,6 +44,7 @@ py_library(
"//tensorflow_federated/python/learning/models",
"//tensorflow_federated/python/learning/optimizers",
"//tensorflow_federated/python/learning/reconstruction",
"//tensorflow_federated/python/learning/templates",
],
)
......@@ -58,99 +54,6 @@ py_library(
srcs_version = "PY3",
)
py_library(
name = "client_works",
srcs = ["client_works.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:optimizer",
],
)
py_test(
name = "client_works_test",
size = "small",
srcs = ["client_works_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":client_works",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model_examples",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
py_library(
name = "composers",
srcs = ["composers.py"],
srcs_version = "PY3",
deps = [
":client_works",
":distributors",
":finalizers",
"//tensorflow_federated/python/aggregators:mean",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:aggregation_process",
"//tensorflow_federated/python/learning:learning_process",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
py_test(
name = "composers_test",
size = "small",
srcs = ["composers_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":client_works",
":composers",
":distributors",
":finalizers",
"//tensorflow_federated/python/aggregators:mean",
"//tensorflow_federated/python/aggregators:sum_factory",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:learning_process",
"//tensorflow_federated/python/learning:model_examples",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
py_library(
name = "debug_measurements",
srcs = ["debug_measurements.py"],
......@@ -178,42 +81,6 @@ py_test(
],
)
py_library(
name = "distributors",
srcs = ["distributors.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
],
)
py_test(
name = "distributors_test",
size = "small",
srcs = ["distributors_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":distributors",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
],
)
py_library(
name = "federated_averaging",
srcs = ["federated_averaging.py"],
......@@ -332,47 +199,6 @@ py_cpu_gpu_test(
],
)
py_library(
name = "finalizers",
srcs = ["finalizers.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:keras_optimizer",
"//tensorflow_federated/python/learning/optimizers:optimizer",
],
)
py_test(
name = "finalizers_test",
size = "small",
srcs = ["finalizers_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":finalizers",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
py_library(
name = "keras_utils",
srcs = ["keras_utils.py"],
......@@ -411,37 +237,6 @@ py_test(
],
)
py_library(
name = "learning_process",
srcs = ["learning_process.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:iterative_process",
],
)
py_test(
name = "learning_process_test",
size = "small",
srcs = ["learning_process_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":learning_process",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:errors",
],
)
py_library(
name = "model",
srcs = ["model.py"],
......
......@@ -19,27 +19,17 @@ from tensorflow_federated.python.learning import metrics
from tensorflow_federated.python.learning import models
from tensorflow_federated.python.learning import optimizers
from tensorflow_federated.python.learning import reconstruction
from tensorflow_federated.python.learning import templates
from tensorflow_federated.python.learning.client_weight_lib import ClientWeighting
from tensorflow_federated.python.learning.client_works import build_model_delta_client_work
from tensorflow_federated.python.learning.client_works import ClientResult
from tensorflow_federated.python.learning.client_works import ClientWorkProcess
from tensorflow_federated.python.learning.composers import compose_learning_process
from tensorflow_federated.python.learning.composers import LearningAlgorithmState
from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements
from tensorflow_federated.python.learning.distributors import build_broadcast_process
from tensorflow_federated.python.learning.distributors import DistributionProcess
from tensorflow_federated.python.learning.federated_averaging import build_federated_averaging_process
from tensorflow_federated.python.learning.federated_averaging import ClientFedAvg
from tensorflow_federated.python.learning.federated_evaluation import build_federated_evaluation
from tensorflow_federated.python.learning.federated_evaluation import build_local_evaluation
from tensorflow_federated.python.learning.federated_sgd import build_federated_sgd_process
from tensorflow_federated.python.learning.finalizers import build_apply_optimizer_finalizer
from tensorflow_federated.python.learning.finalizers import FinalizerProcess
from tensorflow_federated.python.learning.framework.optimizer_utils import state_with_new_model_weights
from tensorflow_federated.python.learning.keras_utils import federated_aggregate_keras_metric
from tensorflow_federated.python.learning.keras_utils import from_keras_model
from tensorflow_federated.python.learning.learning_process import LearningProcess
from tensorflow_federated.python.learning.learning_process import LearningProcessOutput
from tensorflow_federated.python.learning.model import BatchOutput
from tensorflow_federated.python.learning.model import Model
from tensorflow_federated.python.learning.model_update_aggregator import compression_aggregator
......
......@@ -35,14 +35,14 @@ py_library(
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:client_works",
"//tensorflow_federated/python/learning:composers",
"//tensorflow_federated/python/learning:distributors",
"//tensorflow_federated/python/learning:finalizers",
"//tensorflow_federated/python/learning:learning_process",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:optimizer",
"//tensorflow_federated/python/learning/templates:client_works",
"//tensorflow_federated/python/learning/templates:composers",
"//tensorflow_federated/python/learning/templates:distributors",
"//tensorflow_federated/python/learning/templates:finalizers",
"//tensorflow_federated/python/learning/templates:learning_process",
],
)
......@@ -75,15 +75,15 @@ py_library(
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:client_works",
"//tensorflow_federated/python/learning:composers",
"//tensorflow_federated/python/learning:distributors",
"//tensorflow_federated/python/learning:finalizers",
"//tensorflow_federated/python/learning:learning_process",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/optimizers:optimizer",
"//tensorflow_federated/python/learning/templates:client_works",
"//tensorflow_federated/python/learning/templates:composers",
"//tensorflow_federated/python/learning/templates:distributors",
"//tensorflow_federated/python/learning/templates:finalizers",
"//tensorflow_federated/python/learning/templates:learning_process",
"//tensorflow_federated/python/tensorflow_libs:tensor_utils",
],
)
......
......@@ -26,15 +26,15 @@ from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import client_works
from tensorflow_federated.python.learning import composers
from tensorflow_federated.python.learning import distributors
from tensorflow_federated.python.learning import finalizers
from tensorflow_federated.python.learning import learning_process
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.algorithms import example_weighted_federated_averaging
from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base
from tensorflow_federated.python.learning.templates import client_works
from tensorflow_federated.python.learning.templates import composers
from tensorflow_federated.python.learning.templates import distributors
from tensorflow_federated.python.learning.templates import finalizers
from tensorflow_federated.python.learning.templates import learning_process
TFFOrKerasOptimizer = Union[optimizer_base.Optimizer,
tf.keras.optimizers.Optimizer]
......
......@@ -38,15 +38,15 @@ from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import client_works
from tensorflow_federated.python.learning import composers
from tensorflow_federated.python.learning import distributors
from tensorflow_federated.python.learning import finalizers
from tensorflow_federated.python.learning import learning_process
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.framework import dataset_reduce
from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base
from tensorflow_federated.python.learning.templates import client_works
from tensorflow_federated.python.learning.templates import composers
from tensorflow_federated.python.learning.templates import distributors
from tensorflow_federated.python.learning.templates import finalizers
from tensorflow_federated.python.learning.templates import learning_process
from tensorflow_federated.python.tensorflow_libs import tensor_utils
......
......@@ -5,6 +5,7 @@ package(default_visibility = [
"//tensorflow_federated/python/learning:learning_visibility",
"//tensorflow_federated/python/learning/algorithms:algorithms_packages",
"//tensorflow_federated/python/learning/reconstruction:reconstruction_packages",
"//tensorflow_federated/python/learning/templates:templates_packages",
# TODO(b/151441025): This visibility is temporary and can be removed once
# the dependencies between `tff.learning` and `tff.learning.framework` are
......
load("@rules_python//python:defs.bzl", "py_library", "py_test")
package(default_visibility = [
":templates_packages",
"//tensorflow_federated/python/learning:learning_visibility",
"//tensorflow_federated/python/learning/algorithms:algorithms_packages",
])
package_group(
name = "templates_packages",
packages = ["//tensorflow_federated/python/learning/templates/..."],
)
licenses(["notice"])
py_library(
name = "templates",
srcs = ["__init__.py"],
srcs_version = "PY3",
visibility = ["//tensorflow_federated/python/learning:__pkg__"],
deps = [
":client_works",
":composers",
":distributors",
":finalizers",
":learning_process",
],
)
py_library(
name = "client_works",
srcs = ["client_works.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:optimizer",
],
)
py_test(
name = "client_works_test",
size = "small",
srcs = ["client_works_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":client_works",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model_examples",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
py_library(
name = "composers",
srcs = ["composers.py"],
srcs_version = "PY3",
deps = [
":client_works",
":distributors",
":finalizers",
":learning_process",
"//tensorflow_federated/python/aggregators:mean",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:aggregation_process",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
py_test(
name = "composers_test",
size = "small",
srcs = ["composers_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":client_works",
":composers",
":distributors",
":finalizers",
":learning_process",
"//tensorflow_federated/python/aggregators:mean",
"//tensorflow_federated/python/aggregators:sum_factory",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model_examples",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
py_library(
name = "distributors",
srcs = ["distributors.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
],
)
py_test(
name = "distributors_test",
size = "small",
srcs = ["distributors_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":distributors",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
],
)
py_library(
name = "finalizers",
srcs = ["finalizers.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:keras_optimizer",
"//tensorflow_federated/python/learning/optimizers:optimizer",
],
)
py_test(
name = "finalizers_test",
size = "small",
srcs = ["finalizers_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":finalizers",
"//tensorflow_federated/python/core/api:computations",
"//tensorflow_federated/python/core/api:test_case",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
py_library(
name = "learning_process",
srcs = ["learning_process.py"],
srcs_version = "PY3",
deps = [
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/core/api:computation_base",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:errors",
"//tensorflow_federated/python/core/templates:iterative_process",
],
)
py_test(
name = "learning_process_test",
size = "small",
srcs = ["learning_process_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [