Skip to content

Commit

Permalink
[Op] Implement SliceSend/SliceRecv Op.
Browse files Browse the repository at this point in the history
Signed-off-by: chenbangduo.cbd <chenbangduo.cbd@alibaba-inc.com>
  • Loading branch information
JackMoriarty committed Nov 14, 2023
1 parent 89c7d63 commit eac8e8e
Show file tree
Hide file tree
Showing 11 changed files with 1,118 additions and 5 deletions.
2 changes: 2 additions & 0 deletions tensorflow/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,7 @@ tf_gen_op_libs(
"set_ops",
"script_ops",
"sendrecv_ops",
"slice_sendrecv_ops",
"sparse_ops",
"spectral_ops",
"state_ops",
Expand Down Expand Up @@ -1497,6 +1498,7 @@ cc_library(
":sdca_ops_op_lib",
":sendrecv_ops_op_lib",
":set_ops_op_lib",
":slice_sendrecv_ops_op_lib",
":sparse_ops_op_lib",
":star_run_graph_op_op_lib",
":summary_ops_op_lib",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/framework/rendezvous.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class Rendezvous : public core::RefCounted {
friend class SendOp;
friend class RecvOp;
friend class FuseRecvOp;
friend class SliceSendOp;
friend class SliceRecvOp;
friend class RefSendOp;
friend class RefRecvOp;
string buf_;
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,13 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
{"_Send", NC_SEND},
{"_HostSend", NC_HOST_SEND},
{"_RefSend", NC_REF_SEND},
{"_SliceSend", NC_SLICE_SEND},
{"_Recv", NC_RECV},
{"_HostRecv", NC_HOST_RECV},
{"_RefRecv", NC_REF_RECV},
{"_FuseRecv", NC_FUSE_RECV},
{"_HostFuseRecv", NC_HOST_FUSE_RECV},
{"_SliceRecv", NC_SLICE_RECV},
{"Const", NC_CONSTANT},
{"HostConst", NC_CONSTANT},
{"Variable", NC_VARIABLE},
Expand Down
12 changes: 10 additions & 2 deletions tensorflow/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,16 @@ class Node {
bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
bool IsSend() const { return class_ == NC_SEND ||
class_ == NC_HOST_SEND ||
class_ == NC_REF_SEND; }
class_ == NC_REF_SEND ||
class_ == NC_SLICE_SEND; }
bool IsSliceSend() const { return class_ == NC_SLICE_SEND; }
bool IsRecv() const { return class_ == NC_RECV ||
class_ == NC_HOST_RECV ||
class_ == NC_REF_RECV; }
class_ == NC_REF_RECV ||
class_ == NC_SLICE_RECV; }
bool IsFuseRecv() const { return class_ == NC_FUSE_RECV ||
class_ == NC_HOST_FUSE_RECV; }
bool IsSliceRecv() const {return class_ == NC_SLICE_RECV; }
bool IsConstant() const { return class_ == NC_CONSTANT; }
bool IsStage() const { return class_ == NC_TENSOR_BUFFER_PUT; }
bool IsUnstage() const { return class_ == NC_TENSOR_BUFFER_TAKE; }
Expand Down Expand Up @@ -334,11 +338,13 @@ class Node {
NC_SEND,
NC_HOST_SEND,
NC_REF_SEND,
NC_SLICE_SEND,
NC_RECV,
NC_HOST_RECV,
NC_REF_RECV,
NC_FUSE_RECV,
NC_HOST_FUSE_RECV,
NC_SLICE_RECV,
NC_CONSTANT,
NC_VARIABLE,
NC_KV_VAR_HANDLE,
Expand Down Expand Up @@ -844,7 +850,9 @@ inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
inline bool IsSend(const Node* node) { return node->IsSend(); }
inline bool IsSliceSend(const Node* node) { return node->IsSliceSend(); }
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
inline bool IsSliceRecv(const Node* node) { return node->IsSliceRecv(); }
inline bool IsFuseRecv(const Node* node) { return node->IsFuseRecv(); }
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
Expand Down
8 changes: 6 additions & 2 deletions tensorflow/core/grappler/op_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ bool IsReciprocalGrad(const NodeDef& node) {
}

bool IsRecv(const NodeDef& node) {
return node.op() == "_Recv" || node.op() == "_HostRecv";
return node.op() == "_Recv" || node.op() == "_HostRecv" || IsSliceRecv(node);
}

bool IsFuseRecv(const NodeDef& node) {
Expand Down Expand Up @@ -502,7 +502,7 @@ bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }

bool IsSend(const NodeDef& node) {
return node.op() == "_Send" || node.op() == "_HostSend";
return node.op() == "_Send" || node.op() == "_HostSend" || IsSliceSend(node);
}

bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
Expand All @@ -517,6 +517,10 @@ bool IsSize(const NodeDef& node) { return node.op() == "Size"; }

bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }

bool IsSliceRecv(const NodeDef& node) { return node.op() == "_SliceRecv"; }

bool IsSliceSend(const NodeDef& node) { return node.op() == "_SliceSend"; }

bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }

bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/grappler/op_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ bool IsShuffle(const NodeDef& node);
bool IsSigmoidGrad(const NodeDef& node);
bool IsSize(const NodeDef& node);
bool IsSlice(const NodeDef& node);
bool IsSliceRecv(const NodeDef& node);
bool IsSliceSend(const NodeDef& node);
bool IsSnapshot(const NodeDef& node);
bool IsSoftmax(const NodeDef& node);
bool IsSoftplusGrad(const NodeDef& node);
Expand Down
27 changes: 26 additions & 1 deletion tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5423,8 +5423,9 @@ cc_library(
name = "required",
deps = [
":no_op",
":sendrecv_ops",
":fuserecv_ops",
":sendrecv_ops",
":slice_sendrecv_ops",
],
)

Expand All @@ -5445,6 +5446,12 @@ tf_kernel_library(
deps = REQUIRED_DEPS,
)

tf_kernel_library(
name = "slice_sendrecv_ops",
prefix = "slice_sendrecv_ops",
deps = REQUIRED_DEPS,
)

tf_kernel_library(
name = "group_embedding_ops",
hdrs = ["group_embedding/group_embedding_lookup_sparse_forward_base_ops.h"],
Expand Down Expand Up @@ -5509,6 +5516,24 @@ tf_cc_test(
],
)

tf_cc_test(
name = "slice_sendrecv_ops_test",
srcs = ["slice_sendrecv_ops_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
deps = [
":control_flow_ops",
":cwise_op",
":logging_ops",
":ops_testutil",
":ops_util",
":slice_sendrecv_ops",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

tf_kernel_library(
name = "fuserecv_ops",
prefix = "fuserecv_ops",
Expand Down
Loading

0 comments on commit eac8e8e

Please sign in to comment.