Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
KMSCAKKSCFKA AKFACAMADCAS
tensorflow-federated
Commits
86084d98
Commit
86084d98
authored
Nov 29, 2022
by
Isha Arkatkar
Committed by
tensorflow-copybara
Nov 29, 2022
Browse files
Introduce a wrapper C++ API for dtensor device to be used by Tensorflow Federated.
PiperOrigin-RevId: 491812820
parent
852a88dc
Changes
3
Hide whitespace changes
Inline
Side-by-side
tensorflow_federated/cc/core/impl/executors/BUILD
View file @
86084d98
...
...
@@ -198,6 +198,27 @@ cc_test(
],
)
cc_library
(
name
=
"dtensor_api"
,
srcs
=
[
"dtensor_api.cc"
],
hdrs
=
[
"dtensor_api.h"
],
deps
=
[
"@org_tensorflow//tensorflow/c:c_api_experimental"
,
"@org_tensorflow//tensorflow/c:tf_datatype"
,
"@org_tensorflow//tensorflow/c:tf_status_headers"
,
"@org_tensorflow//tensorflow/c:tf_status_helper"
,
"@org_tensorflow//tensorflow/c/eager:c_api"
,
"@org_tensorflow//tensorflow/core:core_cpu_base"
,
"@org_tensorflow//tensorflow/core/common_runtime:core"
,
"@org_tensorflow//tensorflow/core/platform:statusor"
,
"@org_tensorflow//tensorflow/dtensor/cc:dstatus"
,
"@org_tensorflow//tensorflow/dtensor/cc:dtensor_device_cc"
,
"@org_tensorflow//tensorflow/dtensor/cc:dtensor_device_util"
,
"@org_tensorflow//tensorflow/dtensor/cc:tensor_layout"
,
"@org_tensorflow//tensorflow/dtensor/proto:layout_proto_cc"
,
],
)
cc_library
(
name
=
"eager_computation"
,
srcs
=
[
"eager_computation.cc"
],
...
...
tensorflow_federated/cc/core/impl/executors/dtensor_api.cc
0 → 100644
View file @
86084d98
/* Copyright 2022, The TensorFlow Federated Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License
==============================================================================*/
#include
"tensorflow_federated/cc/core/impl/executors/dtensor_api.h"
#include
<memory>
#include
<optional>
#include
<string>
#include
"tensorflow/c/eager/c_api.h"
#include
"tensorflow/c/tf_status.h"
#include
"tensorflow/c/tf_status_helper.h"
#include
"tensorflow/core/platform/statusor.h"
#include
"tensorflow/dtensor/cc/dstatus.h"
#include
"tensorflow/dtensor/cc/dtensor_device.h"
#include
"tensorflow/dtensor/cc/dtensor_device_util.h"
#include
"tensorflow/dtensor/cc/tensor_layout.h"
namespace
tensorflow_federated
{
namespace
dtensor
{
extern
"C"
{
tensorflow
::
StatusOr
<
void
*>
RegisterDTensorDevice
(
TFE_Context
*
context
,
tensorflow
::
dtensor
::
MeshProto
mesh_proto
,
const
std
::
string
&
dtensor_device_name
)
{
std
::
unique_ptr
<
TF_Status
,
decltype
(
&
TF_DeleteStatus
)
>
status
(
TF_NewStatus
(),
TF_DeleteStatus
);
TFE_CustomDevice
device
;
void
*
device_info
;
tensorflow
::
dtensor
::
AllocateDTensorDevice
(
/*device_name=*/
dtensor_device_name
,
&
device
,
&
device_info
);
TF_ASSIGN_OR_RETURN
(
tensorflow
::
dtensor
::
Mesh
mesh
,
tensorflow
::
dtensor
::
Mesh
::
ParseFromProto
(
mesh_proto
));
std
::
string
mesh_string
=
mesh
.
ToString
();
TFE_RegisterCustomDevice
(
context
,
device
,
dtensor_device_name
.
c_str
(),
device_info
,
status
.
get
());
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
tensorflow
::
dtensor
::
AddMesh
(
mesh_string
,
device_info
,
/*is_async=*/
false
,
/*is_host_mesh=*/
false
,
/*in_flight_nodes_limit=*/
0
,
status
.
get
());
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
return
device_info
;
}
tensorflow
::
StatusOr
<
bool
>
IsTensorHandleOnDevice
(
TFE_Context
*
context
,
TFE_TensorHandle
*
tensor_handle
,
const
std
::
string
&
device_name
)
{
std
::
unique_ptr
<
TF_Status
,
decltype
(
&
TF_DeleteStatus
)
>
status
(
TF_NewStatus
(),
TF_DeleteStatus
);
const
char
*
tensor_device
=
TFE_TensorHandleDeviceName
(
tensor_handle
,
status
.
get
());
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
if
(
tensor_device
==
device_name
)
{
return
true
;
}
return
false
;
}
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
TensorToDTensor
(
TFE_Context
*
context
,
TFE_TensorHandle
*
handle
,
const
tensorflow
::
dtensor
::
LayoutProto
&
layout
,
const
std
::
string
&
device_name
)
{
TF_ASSIGN_OR_RETURN
(
tensorflow
::
dtensor
::
Layout
layout_object
,
tensorflow
::
dtensor
::
Layout
::
FromProto
(
layout
));
tensorflow
::
dtensor
::
Layout
replicated_layout
;
bool
relayout_needed
=
false
;
if
(
layout_object
.
IsFullyReplicated
())
{
replicated_layout
=
layout_object
;
}
else
{
replicated_layout
=
tensorflow
::
dtensor
::
Layout
::
ReplicatedOnMesh
(
layout_object
.
mesh
(),
layout_object
.
rank
());
relayout_needed
=
true
;
}
TF_ASSIGN_OR_RETURN
(
TFE_TensorHandle
*
replicated_result
,
CopyToMesh
(
context
,
handle
,
replicated_layout
,
device_name
));
if
(
relayout_needed
)
{
return
Relayout
(
context
,
replicated_result
,
layout_object
,
device_name
);
}
else
{
return
replicated_result
;
}
}
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
DTensorToTensor
(
TFE_Context
*
context
,
TFE_TensorHandle
*
dtensor_handle
,
const
std
::
string
&
device_name
)
{
std
::
unique_ptr
<
TF_Status
,
decltype
(
&
TF_DeleteStatus
)
>
status
(
TF_NewStatus
(),
TF_DeleteStatus
);
tensorflow
::
dtensor
::
TensorWithLayout
*
t
=
reinterpret_cast
<
tensorflow
::
dtensor
::
TensorWithLayout
*>
(
TFE_TensorHandleDevicePointer
(
dtensor_handle
,
status
.
get
()));
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
if
(
t
->
layout
().
IsFullyReplicated
())
{
return
TFE_TensorHandleCopySharingTensor
(
t
->
get_tensor
(
0
),
status
.
get
());
}
auto
replicated_layout
=
tensorflow
::
dtensor
::
Layout
::
ReplicatedOnMesh
(
t
->
layout
().
mesh
(),
t
->
layout
().
rank
());
TF_ASSIGN_OR_RETURN
(
TFE_TensorHandle
*
result
,
Relayout
(
context
,
dtensor_handle
,
replicated_layout
,
device_name
));
tensorflow
::
dtensor
::
TensorWithLayout
*
t_replicated
=
reinterpret_cast
<
tensorflow
::
dtensor
::
TensorWithLayout
*>
(
TFE_TensorHandleDevicePointer
(
result
,
status
.
get
()));
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
auto
tensor
=
TFE_TensorHandleCopySharingTensor
(
t_replicated
->
get_tensor
(
0
),
status
.
get
());
TFE_DeleteTensorHandle
(
result
);
return
tensor
;
}
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
CopyToMeshWithProto
(
TFE_Context
*
context
,
TFE_TensorHandle
*
tensor_handle
,
const
tensorflow
::
dtensor
::
LayoutProto
&
layout
,
const
std
::
string
&
device_name
)
{
TF_ASSIGN_OR_RETURN
(
tensorflow
::
dtensor
::
Layout
replicated_layout
,
tensorflow
::
dtensor
::
Layout
::
FromProto
(
layout
));
return
CopyToMesh
(
context
,
tensor_handle
,
replicated_layout
,
device_name
);
}
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
CopyToMesh
(
TFE_Context
*
context
,
TFE_TensorHandle
*
tensor_handle
,
const
tensorflow
::
dtensor
::
Layout
&
layout
,
const
std
::
string
&
device_name
)
{
std
::
unique_ptr
<
TF_Status
,
decltype
(
&
TF_DeleteStatus
)
>
status
(
TF_NewStatus
(),
TF_DeleteStatus
);
std
::
unique_ptr
<
TFE_Op
,
decltype
(
&
TFE_DeleteOp
)
>
op
(
TFE_NewOp
(
context
,
"CopyToMesh"
,
status
.
get
()),
TFE_DeleteOp
);
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
TFE_OpSetDevice
(
op
.
get
(),
device_name
.
c_str
(),
status
.
get
());
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
std
::
string
serialized_layout
=
layout
.
ToString
();
TFE_OpSetAttrString
(
op
.
get
(),
"layout"
,
serialized_layout
.
data
(),
serialized_layout
.
length
());
TFE_OpAddInput
(
op
.
get
(),
tensor_handle
,
status
.
get
());
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
int
num_results
=
1
;
TFE_TensorHandle
*
replicated_result
;
TFE_Execute
(
op
.
get
(),
&
replicated_result
,
&
num_results
,
status
.
get
());
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
return
replicated_result
;
}
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
RelayoutWithProto
(
TFE_Context
*
context
,
TFE_TensorHandle
*
handle
,
const
tensorflow
::
dtensor
::
LayoutProto
&
layout
,
const
std
::
string
&
device_name
)
{
TF_ASSIGN_OR_RETURN
(
tensorflow
::
dtensor
::
Layout
layout_object
,
tensorflow
::
dtensor
::
Layout
::
FromProto
(
layout
));
return
Relayout
(
context
,
handle
,
layout_object
,
device_name
);
}
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
Relayout
(
TFE_Context
*
context
,
TFE_TensorHandle
*
handle
,
const
tensorflow
::
dtensor
::
Layout
&
layout
,
const
std
::
string
&
device_name
)
{
std
::
unique_ptr
<
TF_Status
,
decltype
(
&
TF_DeleteStatus
)
>
status
(
TF_NewStatus
(),
TF_DeleteStatus
);
std
::
unique_ptr
<
TFE_Op
,
decltype
(
&
TFE_DeleteOp
)
>
relayout
(
TFE_NewOp
(
context
,
"Relayout"
,
status
.
get
()),
TFE_DeleteOp
);
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
TFE_OpSetDevice
(
relayout
.
get
(),
device_name
.
c_str
(),
status
.
get
());
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
std
::
string
serialized_layout
=
layout
.
ToString
();
TFE_OpSetAttrString
(
relayout
.
get
(),
"layout"
,
serialized_layout
.
data
(),
serialized_layout
.
length
());
TFE_OpAddInput
(
relayout
.
get
(),
handle
,
status
.
get
());
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
int
num_results
=
1
;
TFE_TensorHandle
*
result
;
TFE_Execute
(
relayout
.
get
(),
&
result
,
&
num_results
,
status
.
get
());
if
(
TF_GetCode
(
status
.
get
())
!=
TF_OK
)
return
tensorflow
::
StatusFromTF_Status
(
status
.
get
());
return
result
;
}
}
}
// namespace dtensor
}
// namespace tensorflow_federated
tensorflow_federated/cc/core/impl/executors/dtensor_api.h
0 → 100644
View file @
86084d98
/* Copyright 2022, The TensorFlow Federated Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
#define THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
#include
<string>
#include
"tensorflow/c/eager/c_api.h"
#include
"tensorflow/dtensor/cc/tensor_layout.h"
#include
"tensorflow/dtensor/proto/layout.pb.h"
namespace
tensorflow_federated
{
namespace
dtensor
{
extern
"C"
{
// Registers a DTensor device with provided mesh.
// Returns a DeviceInfo object which can be used to add mesh
tensorflow
::
StatusOr
<
void
*>
RegisterDTensorDevice
(
TFE_Context
*
context
,
tensorflow
::
dtensor
::
MeshProto
mesh_proto
,
const
std
::
string
&
dtensor_device_name
);
// Returns true, if given tensor_handle points to a DTensor on provided device
// name.
tensorflow
::
StatusOr
<
bool
>
IsTensorHandleOnDevice
(
TFE_Context
*
context
,
TFE_TensorHandle
*
tensor_handle
,
const
std
::
string
&
device_name
);
// Converts a Tensor to DTensor by sharding or replicating the input tensor
// according to specified layout.
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
TensorToDTensor
(
TFE_Context
*
context
,
TFE_TensorHandle
*
tensor_handle
,
const
tensorflow
::
dtensor
::
LayoutProto
&
layout
,
const
std
::
string
&
device_name
);
// Converts input DTensor to Tensor, by removing the sharding and
// returns the global tensor value handle.
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
DTensorToTensor
(
TFE_Context
*
context
,
TFE_TensorHandle
*
dtensor_handle
,
const
std
::
string
&
device_name
);
// Copies a Tensor onto mesh with replicated layout and returns DTensor.
// CopyToMesh only supports replicated layout.
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
CopyToMeshWithProto
(
TFE_Context
*
context
,
TFE_TensorHandle
*
tensor_handle
,
const
tensorflow
::
dtensor
::
LayoutProto
&
layout
,
const
std
::
string
&
device_name
);
// Same as CopyToMesh above, except accepts Layout object instead of Layout
// Proto.
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
CopyToMesh
(
TFE_Context
*
context
,
TFE_TensorHandle
*
tensor_handle
,
const
tensorflow
::
dtensor
::
Layout
&
layout
,
const
std
::
string
&
device_name
);
// Changes the layout of input DTensor to provided layout and returns resulting
// DTensor handle.
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
RelayoutWithProto
(
TFE_Context
*
context
,
TFE_TensorHandle
*
handle
,
const
tensorflow
::
dtensor
::
LayoutProto
&
layout
,
const
std
::
string
&
device_name
);
// Same as Relayout above, except accepts Layout object instead of Layout Proto.
tensorflow
::
StatusOr
<
TFE_TensorHandle
*>
Relayout
(
TFE_Context
*
context
,
TFE_TensorHandle
*
handle
,
const
tensorflow
::
dtensor
::
Layout
&
layout
,
const
std
::
string
&
device_name
);
}
/* end extern "C" */
}
// namespace dtensor
}
// namespace tensorflow_federated
#endif // THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment