Commit 86084d98 authored by Isha Arkatkar, committed by tensorflow-copybara

Introduce a wrapper C++ API for the DTensor device, to be used by TensorFlow Federated.

PiperOrigin-RevId: 491812820
parent 852a88dc
......@@ -198,6 +198,27 @@ cc_test(
],
)
cc_library(
name = "dtensor_api",
srcs = ["dtensor_api.cc"],
hdrs = ["dtensor_api.h"],
deps = [
"@org_tensorflow//tensorflow/c:c_api_experimental",
"@org_tensorflow//tensorflow/c:tf_datatype",
"@org_tensorflow//tensorflow/c:tf_status_headers",
"@org_tensorflow//tensorflow/c:tf_status_helper",
"@org_tensorflow//tensorflow/c/eager:c_api",
"@org_tensorflow//tensorflow/core:core_cpu_base",
"@org_tensorflow//tensorflow/core/common_runtime:core",
"@org_tensorflow//tensorflow/core/platform:statusor",
"@org_tensorflow//tensorflow/dtensor/cc:dstatus",
"@org_tensorflow//tensorflow/dtensor/cc:dtensor_device_cc",
"@org_tensorflow//tensorflow/dtensor/cc:dtensor_device_util",
"@org_tensorflow//tensorflow/dtensor/cc:tensor_layout",
"@org_tensorflow//tensorflow/dtensor/proto:layout_proto_cc",
],
)
cc_library(
name = "eager_computation",
srcs = ["eager_computation.cc"],
......
/* Copyright 2022, The TensorFlow Federated Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_federated/cc/core/impl/executors/dtensor_api.h"
#include <memory>
#include <optional>
#include <string>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/dtensor_device.h"
#include "tensorflow/dtensor/cc/dtensor_device_util.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
namespace tensorflow_federated {
namespace dtensor {
extern "C" {
tensorflow::StatusOr<void*> RegisterDTensorDevice(
TFE_Context* context, tensorflow::dtensor::MeshProto mesh_proto,
const std::string& dtensor_device_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_CustomDevice device;
void* device_info;
tensorflow::dtensor::AllocateDTensorDevice(
/*device_name=*/dtensor_device_name, &device, &device_info);
TF_ASSIGN_OR_RETURN(tensorflow::dtensor::Mesh mesh,
tensorflow::dtensor::Mesh::ParseFromProto(mesh_proto));
std::string mesh_string = mesh.ToString();
TFE_RegisterCustomDevice(context, device, dtensor_device_name.c_str(),
device_info, status.get());
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
tensorflow::dtensor::AddMesh(mesh_string, device_info, /*is_async=*/false,
/*is_host_mesh=*/false,
/*in_flight_nodes_limit=*/0, status.get());
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
return device_info;
}
tensorflow::StatusOr<bool> IsTensorHandleOnDevice(
TFE_Context* context, TFE_TensorHandle* tensor_handle,
const std::string& device_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* tensor_device =
TFE_TensorHandleDeviceName(tensor_handle, status.get());
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
return tensor_device == device_name;
}
tensorflow::StatusOr<TFE_TensorHandle*> TensorToDTensor(
TFE_Context* context, TFE_TensorHandle* handle,
const tensorflow::dtensor::LayoutProto& layout,
const std::string& device_name) {
TF_ASSIGN_OR_RETURN(tensorflow::dtensor::Layout layout_object,
tensorflow::dtensor::Layout::FromProto(layout));
tensorflow::dtensor::Layout replicated_layout;
bool relayout_needed = false;
if (layout_object.IsFullyReplicated()) {
replicated_layout = layout_object;
} else {
replicated_layout = tensorflow::dtensor::Layout::ReplicatedOnMesh(
layout_object.mesh(), layout_object.rank());
relayout_needed = true;
}
TF_ASSIGN_OR_RETURN(
TFE_TensorHandle * replicated_result,
CopyToMesh(context, handle, replicated_layout, device_name));
if (relayout_needed) {
return Relayout(context, replicated_result, layout_object, device_name);
} else {
return replicated_result;
}
}
tensorflow::StatusOr<TFE_TensorHandle*> DTensorToTensor(
TFE_Context* context, TFE_TensorHandle* dtensor_handle,
const std::string& device_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
tensorflow::dtensor::TensorWithLayout* t =
reinterpret_cast<tensorflow::dtensor::TensorWithLayout*>(
TFE_TensorHandleDevicePointer(dtensor_handle, status.get()));
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
if (t->layout().IsFullyReplicated()) {
return TFE_TensorHandleCopySharingTensor(t->get_tensor(0), status.get());
}
auto replicated_layout = tensorflow::dtensor::Layout::ReplicatedOnMesh(
t->layout().mesh(), t->layout().rank());
TF_ASSIGN_OR_RETURN(
TFE_TensorHandle * result,
Relayout(context, dtensor_handle, replicated_layout, device_name));
tensorflow::dtensor::TensorWithLayout* t_replicated =
reinterpret_cast<tensorflow::dtensor::TensorWithLayout*>(
TFE_TensorHandleDevicePointer(result, status.get()));
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
auto tensor = TFE_TensorHandleCopySharingTensor(t_replicated->get_tensor(0),
status.get());
TFE_DeleteTensorHandle(result);
return tensor;
}
tensorflow::StatusOr<TFE_TensorHandle*> CopyToMeshWithProto(
TFE_Context* context, TFE_TensorHandle* tensor_handle,
const tensorflow::dtensor::LayoutProto& layout,
const std::string& device_name) {
TF_ASSIGN_OR_RETURN(tensorflow::dtensor::Layout replicated_layout,
tensorflow::dtensor::Layout::FromProto(layout));
return CopyToMesh(context, tensor_handle, replicated_layout, device_name);
}
tensorflow::StatusOr<TFE_TensorHandle*> CopyToMesh(
TFE_Context* context, TFE_TensorHandle* tensor_handle,
const tensorflow::dtensor::Layout& layout, const std::string& device_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "CopyToMesh", status.get()), TFE_DeleteOp);
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
TFE_OpSetDevice(op.get(), device_name.c_str(), status.get());
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
std::string serialized_layout = layout.ToString();
TFE_OpSetAttrString(op.get(), "layout", serialized_layout.data(),
serialized_layout.length());
TFE_OpAddInput(op.get(), tensor_handle, status.get());
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
int num_results = 1;
TFE_TensorHandle* replicated_result;
TFE_Execute(op.get(), &replicated_result, &num_results, status.get());
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
return replicated_result;
}
tensorflow::StatusOr<TFE_TensorHandle*> RelayoutWithProto(
TFE_Context* context, TFE_TensorHandle* handle,
const tensorflow::dtensor::LayoutProto& layout,
const std::string& device_name) {
TF_ASSIGN_OR_RETURN(tensorflow::dtensor::Layout layout_object,
tensorflow::dtensor::Layout::FromProto(layout));
return Relayout(context, handle, layout_object, device_name);
}
tensorflow::StatusOr<TFE_TensorHandle*> Relayout(
TFE_Context* context, TFE_TensorHandle* handle,
const tensorflow::dtensor::Layout& layout, const std::string& device_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> relayout(
TFE_NewOp(context, "Relayout", status.get()), TFE_DeleteOp);
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
TFE_OpSetDevice(relayout.get(), device_name.c_str(), status.get());
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
std::string serialized_layout = layout.ToString();
TFE_OpSetAttrString(relayout.get(), "layout", serialized_layout.data(),
serialized_layout.length());
TFE_OpAddInput(relayout.get(), handle, status.get());
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
int num_results = 1;
TFE_TensorHandle* result;
TFE_Execute(relayout.get(), &result, &num_results, status.get());
if (TF_GetCode(status.get()) != TF_OK)
return tensorflow::StatusFromTF_Status(status.get());
return result;
}
}  // extern "C"
} // namespace dtensor
} // namespace tensorflow_federated
/* Copyright 2022, The TensorFlow Federated Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
#define THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
#include <string>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/proto/layout.pb.h"
namespace tensorflow_federated {
namespace dtensor {
extern "C" {
// Registers a DTensor device with the provided mesh.
// Returns an opaque device_info pointer, which can be used to add meshes to
// the registered device.
tensorflow::StatusOr<void*> RegisterDTensorDevice(
TFE_Context* context, tensorflow::dtensor::MeshProto mesh_proto,
const std::string& dtensor_device_name);
// Returns true if the given tensor_handle points to a DTensor on the provided
// device name.
tensorflow::StatusOr<bool> IsTensorHandleOnDevice(
TFE_Context* context, TFE_TensorHandle* tensor_handle,
const std::string& device_name);
// Converts a Tensor to a DTensor by sharding or replicating the input tensor
// according to the specified layout.
tensorflow::StatusOr<TFE_TensorHandle*> TensorToDTensor(
TFE_Context* context, TFE_TensorHandle* tensor_handle,
const tensorflow::dtensor::LayoutProto& layout,
const std::string& device_name);
// Converts the input DTensor to a Tensor by removing the sharding and returns
// a handle to the global tensor value.
tensorflow::StatusOr<TFE_TensorHandle*> DTensorToTensor(
TFE_Context* context, TFE_TensorHandle* dtensor_handle,
const std::string& device_name);
// Copies a Tensor onto the mesh with a replicated layout and returns a
// DTensor. Note that CopyToMesh only supports replicated layouts.
tensorflow::StatusOr<TFE_TensorHandle*> CopyToMeshWithProto(
TFE_Context* context, TFE_TensorHandle* tensor_handle,
const tensorflow::dtensor::LayoutProto& layout,
const std::string& device_name);
// Same as CopyToMeshWithProto above, except it accepts a Layout object instead
// of a LayoutProto.
tensorflow::StatusOr<TFE_TensorHandle*> CopyToMesh(
TFE_Context* context, TFE_TensorHandle* tensor_handle,
const tensorflow::dtensor::Layout& layout, const std::string& device_name);
// Changes the layout of the input DTensor to the provided layout and returns
// the resulting DTensor handle.
tensorflow::StatusOr<TFE_TensorHandle*> RelayoutWithProto(
TFE_Context* context, TFE_TensorHandle* handle,
const tensorflow::dtensor::LayoutProto& layout,
const std::string& device_name);
// Same as RelayoutWithProto above, except it accepts a Layout object instead
// of a LayoutProto.
tensorflow::StatusOr<TFE_TensorHandle*> Relayout(
TFE_Context* context, TFE_TensorHandle* handle,
const tensorflow::dtensor::Layout& layout, const std::string& device_name);
} /* end extern "C" */
} // namespace dtensor
} // namespace tensorflow_federated
#endif // THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
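For reviewers, a minimal usage sketch of the new API follows. It is illustrative only and not part of this change: the RunExample wrapper, the custom device name, and the caller-supplied mesh/layout protos are assumptions, and error propagation follows the same TF_ASSIGN_OR_RETURN pattern as the implementation above.

// Illustrative usage sketch (not part of this commit). Assumes a valid
// TFE_Context, an eager input tensor handle, and caller-supplied mesh and
// layout protos; the device name below is a typical DTensor custom-device
// name and may differ in practice.
#include <string>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow_federated/cc/core/impl/executors/dtensor_api.h"
tensorflow::Status RunExample(
    TFE_Context* context, TFE_TensorHandle* input,
    const tensorflow::dtensor::MeshProto& mesh_proto,
    const tensorflow::dtensor::LayoutProto& layout_proto) {
  const std::string device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  // Register the DTensor custom device backed by the provided mesh.
  TF_ASSIGN_OR_RETURN(void* device_info,
                      tensorflow_federated::dtensor::RegisterDTensorDevice(
                          context, mesh_proto, device_name));
  (void)device_info;  // Unused in this sketch.
  // Shard or replicate the input according to the requested layout.
  TF_ASSIGN_OR_RETURN(TFE_TensorHandle* dtensor,
                      tensorflow_federated::dtensor::TensorToDTensor(
                          context, input, layout_proto, device_name));
  // Convert back to a handle on the global (unsharded) tensor value.
  TF_ASSIGN_OR_RETURN(TFE_TensorHandle* global_tensor,
                      tensorflow_federated::dtensor::DTensorToTensor(
                          context, dtensor, device_name));
  TFE_DeleteTensorHandle(dtensor);
  TFE_DeleteTensorHandle(global_tensor);
  return tensorflow::OkStatus();
}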