From 796903ae2086d5611e97ca6f9e04cacf1ff002dd Mon Sep 17 00:00:00 2001 From: Jinsun Yoo Date: Sun, 1 Sep 2024 18:41:28 -0400 Subject: [PATCH] [et_generator, text_converter] Change 'comm_size' attr from uint64 to int64 --- src/converter/text_converter.py | 2 +- src/feeder/et_feeder_node.cpp | 2 +- src/generator/generator.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/converter/text_converter.py b/src/converter/text_converter.py index 1bad778f..f28e1230 100644 --- a/src/converter/text_converter.py +++ b/src/converter/text_converter.py @@ -104,7 +104,7 @@ def get_comm_type(self, comm_type: str) -> int: def get_comm_coll_node(self, layer_name: str, comm_type: str, comm_size: int) -> Any: node = self.get_node(f"COMM_COLL_NODE_{layer_name}_{comm_type}", COMM_COLL_NODE) node.attr.append(ChakraAttr(name="comm_type", int64_val=self.get_comm_type(comm_type))) - node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size)) + node.attr.append(ChakraAttr(name="comm_size", int64_val=comm_size)) return node def add_parent(self, child_node: Any, parent_node: Any) -> None: diff --git a/src/feeder/et_feeder_node.cpp b/src/feeder/et_feeder_node.cpp index 2d89b93c..e0427e41 100644 --- a/src/feeder/et_feeder_node.cpp +++ b/src/feeder/et_feeder_node.cpp @@ -25,7 +25,7 @@ ETFeederNode::ETFeederNode(std::shared_ptr node) { } else if (attr_name == "comm_priority") { this->comm_priority_ = static_cast(attr.int32_val()); } else if (attr_name == "comm_size") { - this->comm_size_ = attr.int64_val(); + this->comm_size_ = static_cast(attr.int64_val()); } else if (attr_name == "comm_src") { this->comm_src_ = static_cast(attr.int32_val()); } else if (attr_name == "comm_dst") { diff --git a/src/generator/generator.py b/src/generator/generator.py index 036c16f7..d1f547f0 100644 --- a/src/generator/generator.py +++ b/src/generator/generator.py @@ -191,7 +191,7 @@ def generate_comm_coll_node(num_npus: int, comm_size: int, comm_type: int, node_ node = get_node(node_name, COMM_COLL_NODE) node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False)) - node.attr.extend([get_comm_type_attr(comm_type), ChakraAttr(name="comm_size", uint64_val=comm_size)]) + node.attr.extend([get_comm_type_attr(comm_type), ChakraAttr(name="comm_size", int64_val=comm_size)]) encode_message(et, node)