Bug Report #21

Open
Patrick-Ni opened this issue Aug 9, 2024 · 3 comments

Comments

@Patrick-Ni

@Zefan-Cai Hello,
Excellent work, and thank you for open-sourcing the code!
Yesterday, while trying to reproduce your work, I ran into a problem. At line 964 of pyramidkv.llama_model.py:

  if past_key_value is not None:
      if self.layer_idx is None:
          raise ValueError(
              f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
              "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
              "with a layer index."
          )
      if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
          if self.kv_seq_len != 0:
              kv_seq_len += self.kv_seq_len
          else:
              kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
      else:
          kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

Whenever self.kv_seq_len is non-zero, this code adds the entire accumulated self.kv_seq_len onto the current kv_seq_len. However, starting at your line 985:

  if past_key_value is not None:
      cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
      # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
      # print('kv_seq_len:', kv_seq_len)
      # print('key_states.shape:', key_states.shape)
      if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
          self.kv_seq_len = kv_seq_len # [SnapKV] register kv_seq_len
          key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
          past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
      else:
          self.kv_seq_len += q_len
          key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

self.kv_seq_len is incremented by q_len on every decoding step. Within a single sample this is fine, but once the current generation finishes and the next one starts, self.kv_seq_len is never reset to 0. As a result, from the second sample onward key_states.shape[-2] never equals kv_seq_len, so self.kv_cluster.update_kv is never executed; in other words, from the second sample on, the configured KV reduction method no longer takes effect.
For example: if the first sample's prompt is 106 tokens and generation finishes at a total length of 279, then self.kv_seq_len = 279. The second sample then starts with that 279 already added, so the KV cache compression strategy can never run.
I tried to find a definition or explanation of self.kv_seq_len elsewhere, but could not find one.
Is my understanding mistaken? I hope to hear back from you!
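For reference, here is a minimal sketch of the manual reset I have in mind between samples (`model`, `tokenizer`, and `prompt` are placeholders, not names from the repo; it assumes the patched attention modules carry the kv_seq_len attribute shown above):

    # Sketch: reset the per-layer SnapKV counter before generating the next sample.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    for layer in model.model.layers:
        if hasattr(layer.self_attn, "kv_seq_len"):
            # without this, the counter still holds the previous sample's total length
            layer.self_attn.kv_seq_len = 0
    output = model.generate(**inputs, max_new_tokens=256)

With the counter back at 0, key_states.shape[-2] == kv_seq_len holds again during prefill, so update_kv runs for the new sample.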

@Zefan-Cai
Owner

    transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = prepare_inputs_for_generation_llama

This line replaces transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation. The replacement prepare_inputs_for_generation resets self.kv_seq_len at the start of every new sample, so this problem should not occur.

@Patrick-Ni
Author

@Zefan-Cai Hello!
I did notice that function, and it is indeed swapped in successfully, but on the second sample it does not seem to reset self.kv_seq_len. I added the following print statements to prepare_inputs_for_generation_llama:

def prepare_inputs_for_generation_llama(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    print("clear kv seq len!!!")
    print(past_key_values is None)
    if past_key_values is None:
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0

The output was:
[screenshot of the printed output]
In other words, past_key_values is not None even for the first sample? I am not sure whether this observation is correct; could you print this value on your side?
Looking forward to your reply!
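In case it is useful, this is the kind of check I was imagining instead of testing past_key_values is None (purely a sketch; it assumes past_key_values is either None or a transformers Cache object exposing get_seq_length(), which may depend on the transformers version):

    # Sketch: treat an empty cache the same as a missing one when deciding to reset kv_seq_len.
    def is_fresh_generation(past_key_values):
        if past_key_values is None:
            return True
        if hasattr(past_key_values, "get_seq_length"):
            # newer transformers may pass a Cache object even on the very first forward pass
            return past_key_values.get_seq_length() == 0
        return len(past_key_values) == 0  # legacy tuple-of-tuples cache

Resetting whenever is_fresh_generation(past_key_values) is True would also cover the case where the cache object exists but is still empty.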

@Zefan-Cai
Owner

Zefan-Cai commented Aug 16, 2024

Sorry for the late reply. Please take a look at #15, which is very similar to your problem. Could you switch the transformers version to 4.37 and try again? If the same problem remains, I will print this value and investigate.
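For example, after switching versions, a quick check like this confirms the environment actually picked up a 4.37.x build (just a sanity check; the exact patch release should not matter):

    import transformers
    print(transformers.__version__)
    assert transformers.__version__.startswith("4.37")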
