
dynamic_embedding table: error when inserting data #463

Open
lwmonster opened this issue Sep 15, 2024 · 2 comments

Comments

@lwmonster

lwmonster commented Sep 15, 2024

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 22.04):
  • TensorFlow version and how it was installed (source or binary): tensorflow 2.8.3, installed via pip
  • TensorFlow-Recommenders-Addons version and how it was installed (source or binary): tfra 0.5.0
  • Python version: Python3.9.19
  • Is GPU used? (yes/no): NO

Describe the bug

When using Dynamic Embedding, I created a Variable; this Variable internally holds several hashtables that store the embeddings of sparse features.
Now I want to update the embedding vectors in those hashtables with external data, but calling table.insert raises the following error:

ValueError: Operation name: "hashTableGroup_0_mht_1of1_lookup_table_insert/TFRA>CuckooHashTableInsert"
op: "TFRA>CuckooHashTableInsert"
input: "group_0/hashTableGroup_0/hashTableGroup_0_mht_1of1"
input: "hashTableGroup_0_mht_1of1_lookup_table_insert/keys"
input: "hashTableGroup_0_mht_1of1_lookup_table_insert/values"
device: "/job:ps/replica:0/task:0"
attr {
  key: "Tin"
  value {
    type: DT_INT64
  }
}
attr {
  key: "Tout"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "_class"
  value {
    list {
      s: "loc:@group_0/hashTableGroup_0/hashTableGroup_0_mht_1of1"
    }
  }
}
 is not an element of this graph.

Code to reproduce the issue

My code is as follows:

def _save_merged_ckpt(self,
                          cluster,
                          server,
                          model_config,
                          valid_data_dirs,
                          cluster_spec,
                          task_type,
                          task_index,
                          merged_variables,
                          merged_hashtable,
                          target_ckpt_dir):

        if task_index != 0:
            print('Worker {} is not chief, skip save merged ckpt'.format(task_index))
            return

        if tf.io.gfile.exists(target_ckpt_dir):
            tf.io.gfile.rmtree(target_ckpt_dir)
        tf.io.gfile.makedirs(target_ckpt_dir)

        # First, reset the default graph
        tf.reset_default_graph()
        with tf.Graph().as_default():
            eval_ds = DatasetManager.create_dataset(model_config.dataset_type,
                                                    model_config,
                                                    valid_data_dirs,
                                                    training=False,
                                                    is_eval=True)
            inputs = eval_ds.build(batch_size=1,
                                num_epochs=1,
                                num_worker=model_config.num_shard_of_export,
                                worker_idx=task_index)

            with tf.device(
                tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % task_index,
                                            cluster=cluster)):
                # dataset
                inference_iterator = inputs.make_one_shot_iterator()
                sample = inference_iterator.get_next()
                self.model_fn(sample, eval_ds, training=True)
                
                global_step_assign = tf.assign_add(self._global_step, 1)
                # Get the values of all model parameters
                all_vars = tf.trainable_variables()
                # Update the dense parameters
                assign_ops = []
                for var in all_vars:
                    print("{}'s shape is {}".format(var.name, var.shape))
                    if var.name not in merged_variables:
                        print('var {} not in merged_variables, SKIP it'.format(var.name))
                        continue
                    assign_ops.append(tf.assign(var, merged_variables[var.name]))

                # Update the hashtable parameters
                for table in self._hash_tables:
                    if table.name not in merged_hashtable:
                        print('table {} not in merged_hashtable, SKIP it'.format(table.name))
                        continue
                    keys, values = list(merged_hashtable[table.name].keys()), list(merged_hashtable[table.name].values())
                    values = np.array(values)
                    #values_tensor = tf.convert_to_tensor(values, dtype=tf.float32)
                    assign_ops.append(table.insert(keys, values))

            parallelism_conf = self.get_parallelism_conf(task_type, task_index)
            scaffold = self.get_scaffold()
            start = time.time()



            saver = tf.train.Saver()
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                is_chief=(task_index == 0),
                                                scaffold=scaffold,
                                                checkpoint_dir=target_ckpt_dir,
                                                hooks=[],
                                                save_checkpoint_secs=None,
                                                save_checkpoint_steps=None,
                                                save_summaries_steps=None,
                                                save_summaries_secs=None,
                                                config=parallelism_conf) as sess:
                print('Worker {} restore successful....'.format(task_index))

                _, step = sess.run([assign_ops, global_step_assign])
                print('Assign merged_variables to checkpoint done')

My assign_ops are built inside the graph, so why does it complain that the CuckooHashTableInsert op is not an element of the graph?
Could anyone point me in the right direction?
PS: self._hash_tables here holds the internal tables of all the dynamic_embedding Variables.
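One likely explanation (an assumption — the snippet does not show where self._hash_tables is populated): the table handles in self._hash_tables were captured before `tf.reset_default_graph()`, so `table.insert(...)` wires new ops to resources that live in the discarded graph. The robust pattern is to create (or re-fetch) the table and build its insert op inside the same graph the session will run. A minimal sketch, using a plain `tf.lookup.experimental.DenseHashTable` as a stand-in for the TFRA table:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default() as g:
    # Create the table *inside* the graph the session will run.
    table = tf.lookup.experimental.DenseHashTable(
        key_dtype=tf.int64,
        value_dtype=tf.float32,
        default_value=0.0,
        empty_key=-1,
        deleted_key=-2)
    keys = tf.constant([1, 2], dtype=tf.int64)
    values = tf.constant([10.0, 20.0], dtype=tf.float32)
    insert_op = table.insert(keys, values)  # same graph as the table
    lookup = table.lookup(tf.constant([1, 2, 3], dtype=tf.int64))

with tf.Session(graph=g) as sess:
    sess.run(insert_op)
    print(sess.run(lookup))  # key 3 falls back to the default value
```

Applied to the code above, this would presumably mean re-collecting the tables from the model rebuilt by self.model_fn inside the `with tf.Graph().as_default():` block, rather than reusing handles captured in an earlier graph.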

Provide a reproducible test case that is the bare minimum necessary to generate the problem.

Other info / logs

Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

@lwmonster
Author

lwmonster commented Sep 15, 2024

@MoFHeka @rhdong @jq
HELP Please ~

@lwmonster
Author

The problem is solved, but I still don't know why...
