qwen2 is not supported by QAT #1818

Open
Tracked by #1747
elfisworking opened this issue Oct 12, 2024 · 0 comments
Labels: bug (Something isn't working), high-priority

elfisworking commented Oct 12, 2024

I tried to use QAT to quantize the Qwen2 1.5B model.
The error is raised by the call `training.load_from_full_model_state_dict(model, model_state_dict, self._device, self._is_rank_zero, strict=True)` in `recipes/qat_distributed`.
Tracing it further, I found that the error comes from this code:

```python
# torchtune/torchtune/training/_distributed.py (excerpt)
from typing import Any, Dict

import torch


def load_from_full_model_state_dict(
    model: "FSDPModule",  # noqa
    full_sd: Dict[str, Any],
    device: torch.device,
    is_rank_zero: bool,
    strict: bool = False,
    cpu_offload: bool = False,
):
    """
    Converting full state dict into a sharded state dict
    and loading it into FSDP model
    - 'full' means plain tensor
    - 'sharded' means `DTensor` where each rank has a shard of the plain tensor
    - `is_rank_zero` matters if only rank 0 passes in non-empty `full_sd` and
       we need to broadcast from rank 0
    """
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        # .get() returns None if param_name is not a key of the model's state dict
        sharded_meta_param = meta_sharded_sd.get(param_name)
        # ...so this line fails for such keys: None has no attribute "dtype"
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
        # ... (rest of the function omitted)
```

This shows that `sharded_meta_param` is `None`, so accessing `sharded_meta_param.dtype` fails.
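Why this ends in an attribute error rather than a clear message: `dict.get` returns `None` for a key the model does not have, and the next line reads `.dtype` from that `None`. Below is a minimal sketch of a stricter lookup, assuming only that `meta_sharded_sd` behaves like a mapping (as `model.state_dict()` does). `lookup_meta_param` is a hypothetical helper for illustration, not a torchtune API, and the dict values are plain strings standing in for meta DTensors.

```python
from typing import Any, Mapping


def lookup_meta_param(meta_sharded_sd: Mapping[str, Any], param_name: str) -> Any:
    """Hypothetical stand-in for ``meta_sharded_sd.get(param_name)`` that raises a
    descriptive KeyError instead of silently returning None for unknown keys."""
    if param_name not in meta_sharded_sd:
        raise KeyError(
            f"checkpoint key {param_name!r} has no counterpart in model.state_dict(); "
            "the model may have been built or transformed without this parameter"
        )
    return meta_sharded_sd[param_name]


if __name__ == "__main__":
    # Stand-in for model.state_dict(): strings instead of meta DTensors.
    meta_sd = {"layers.0.mlp_norm.scale": "meta param"}
    print(lookup_meta_param(meta_sd, "layers.0.mlp_norm.scale"))  # meta param

    # The original .get() path: returns None, so reading .dtype raises AttributeError.
    try:
        meta_sd.get("layers.0.attn.k_proj.bias").dtype
    except AttributeError as err:
        print(type(err).__name__)  # AttributeError

    # The stricter helper names the offending key instead.
    try:
        lookup_meta_param(meta_sd, "layers.0.attn.k_proj.bias")
    except KeyError as err:
        print(err)
```

Failing with the key name makes it obvious that the checkpoint and the model disagree about which parameters exist, instead of pointing at a `None` dtype.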
By adding print statements, I found that `meta_sharded_sd` has no entries for the bias parameters:

```
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(151936, 1536), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.sa_norm.scale
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(1536,), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.mlp.w2.weight
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(1536, 8960), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.mlp.w1.weight
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(8960, 1536), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.mlp.w3.weight
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(8960, 1536), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.mlp_norm.scale
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(1536,), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.attn.k_proj.bias
### error raised here
```
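To list every mismatched key at once instead of stopping at the first one, a sketch of a key-diff diagnostic follows. `diff_state_dict_keys` is a hypothetical name, not part of torchtune; it only compares key names, so it runs without a GPU or a checkpoint. The key lists in the usage are illustrative, modeled on the log in this issue rather than a full Qwen2 state dict.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Hypothetical diagnostic comparing checkpoint keys with model keys.

    Returns (unexpected, missing):
      unexpected: keys in the checkpoint but not in the model; these are the
                  keys for which meta_sharded_sd.get() returns None
      missing:    keys in the model but not in the checkpoint
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(ckpt - model), sorted(model - ckpt)


if __name__ == "__main__":
    # Illustrative key names only.
    checkpoint_keys = [
        "layers.0.attn.k_proj.weight",
        "layers.0.attn.k_proj.bias",
        "layers.0.mlp_norm.scale",
    ]
    model_keys = ["layers.0.attn.k_proj.weight", "layers.0.mlp_norm.scale"]

    unexpected, missing = diff_state_dict_keys(checkpoint_keys, model_keys)
    print(unexpected)  # ['layers.0.attn.k_proj.bias']
    print(missing)  # []
    # Narrow down to bias parameters specifically.
    print([k for k in unexpected if k.endswith(".bias")])  # ['layers.0.attn.k_proj.bias']
```

In a real run this could be called on rank 0 as `diff_state_dict_keys(model_state_dict.keys(), model.state_dict().keys())` just before the failing call; if every unexpected key ends in `.bias`, that would indicate the model being loaded into has lost its bias parameters.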

My YAML config is as follows:

```yaml
# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /QAT/Qwen2-1.5B/vocab.json
  merges_file: /QAT/Qwen2-1.5B/merges.txt
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  source: parquet
  data_files: /QAT/dataset/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
seed: 42
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /QAT/Qwen2-1.5B/
  checkpoint_files: [
   model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /QAT/Qwen2-1.5B
  model_type: QWEN2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 8
epochs: 1

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /QAT/Qwen2-1.5B/finetune-logs
log_every_n_steps: 1
log_peak_memory_stats: False
```