Enable SNIP on multiple cards using DeepSpeed ZeRO-3 (#1492)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yiliu30 and pre-commit-ci[bot] authored Jan 12, 2024
1 parent 061884d commit 49ab28d
Showing 17 changed files with 447 additions and 59 deletions.
@@ -1,7 +1,7 @@
Step-by-Step
============

# single GPU
# Single GPU

```
export CUDA_VISIBLE_DEVICES=0
@@ -15,10 +15,11 @@ bash run.sh \
--pruning_frequency=1000
```

# multi GPU
# Multi GPU

we use `accelerate` and `deepspeed ZeRO Stage-2` to conduct weight magnitude pruning
We use `accelerate` and DeepSpeed ZeRO to conduct weight-magnitude and SNIP pruning. Below are two usage examples: 1) magnitude pruning with ZeRO Stage-2, and 2) SNIP pruning with ZeRO Stage-3.
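For intuition, the two criteria rank weights differently: magnitude keeps the largest weights, while SNIP weighs each weight by its gradient sensitivity. A minimal, illustrative sketch (toy tensors, not the Intel Neural Compressor implementation):

```python
import torch

# toy layer: compute a pruning score for every weight
weight = torch.randn(128, 128, requires_grad=True)
loss = (weight ** 2).sum()  # stand-in for a real training loss
loss.backward()

magnitude_score = torch.abs(weight)           # magnitude criterion: |w|
snip_score = torch.abs(weight * weight.grad)  # SNIP criterion: |w * dL/dw|
```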

## Magnitude pruning with ZeRO Stage-2
### Accelerate DeepSpeed Plugin

On your machine(s) just run:
@@ -105,3 +106,82 @@ bash run_ds.sh \
--pruning_pattern=4x1 \
--pruning_frequency=1000
```


## SNIP pruning with ZeRO Stage-3

To configure `accelerate` to use DeepSpeed ZeRO Stage-3, run the following on your machine(s):
``` shell
accelerate config

compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: config/zero_stage3_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
```
with the contents of `config/zero_stage3_config.json` being:

```
{
  "train_batch_size": 64,
  "train_micro_batch_size_per_gpu": 8,
  "gradient_accumulation_steps": 4,
  "fp16": {
    "enabled": true,
    "min_loss_scale": 1,
    "opt_level": "O2"
  },
  "zero_optimization": {
    "stage": 3,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "contiguous_gradients": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "torch_adam": true,
      "adam_w_mode": true
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": 0.0,
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto",
      "warmup_type": "cosine"
    }
  }
}
```
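The `"auto"` entries are presumably resolved at launch time rather than hard-coded. With `accelerate`'s DeepSpeed integration this is typically done by handing placeholder optimizer/scheduler objects to `accelerator.prepare`, letting DeepSpeed build the real ones from this JSON; a hedged sketch, assuming `accelerate.utils.DummyOptim`/`DummyScheduler` (`model` and `args` are stand-ins):

```python
from accelerate.utils import DummyOptim, DummyScheduler

# placeholders only: DeepSpeed instantiates the actual AdamW and
# WarmupDecayLR from the JSON config, filling its "auto" fields
optimizer = DummyOptim(model.parameters(), lr=args.learning_rate)
lr_scheduler = DummyScheduler(
    optimizer,
    total_num_steps=args.max_train_steps,
    warmup_num_steps=args.num_warmup_steps,
)
```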

### Pruning
> Note: Since ZeRO Stage-3 partitions all three model states (optimizer states, gradients, and parameters), please set `pruning_scope` to `local`. Choosing `global` would require gathering all parameters to update the mask, which undermines the benefits of ZeRO Stage-3.

```
# example with 2 GPU cards
export CUDA_VISIBLE_DEVICES=0,1 USE_DEEPSPEED=1
bash run_ds_z3.sh \
--model_name_or_path=facebook/opt-125m \
--dataset_name=NeelNanda/pile-10k \
--block_size=128 \
--output_dir=./test-clm \
--pruning_type=snip_momentum \
--pruning_scope=local \
--pruning_pattern=4x1 \
--pruning_frequency=1000
```
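To see why `local` sidesteps the gather, compare the two scopes on toy score tensors (an illustrative sketch, not INC's API):

```python
import torch

scores = {"layer1": torch.rand(100), "layer2": torch.rand(100)}
sparsity = 0.8

# local: each layer ranks only its own scores, so nothing has to be
# gathered across layers (or, under ZeRO-3, across ranks)
local_masks = {name: s > torch.quantile(s, sparsity) for name, s in scores.items()}

# global: scores are pooled and ranked together -- under ZeRO Stage-3 this
# would force gathering every partitioned parameter's scores first
threshold = torch.quantile(torch.cat(list(scores.values())), sparsity)
global_masks = {name: s > threshold for name, s in scores.items()}
```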
@@ -0,0 +1,34 @@
{
  "train_batch_size": 64,
  "train_micro_batch_size_per_gpu": 8,
  "gradient_accumulation_steps": 4,
  "fp16": {
    "enabled": true,
    "min_loss_scale": 1,
    "opt_level": "O2"
  },
  "zero_optimization": {
    "stage": 3,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "contiguous_gradients": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "torch_adam": true,
      "adam_w_mode": true
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": 0.0,
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto",
      "warmup_type": "cosine"
    }
  }
}
@@ -274,6 +274,13 @@ def parse_args():
        help="pruning criteria to use.",
        choices=["magnitude", "snip", "snip_momentum"],
    )
    parser.add_argument(
        "--pruning_scope",
        type=str,
        default="global",
        help="determine whether layers' scores should be gathered together for sorting.",
        choices=["local", "global"],
    )
    parser.add_argument(
        "--warm_epochs",
        type=int,
@@ -688,7 +695,7 @@ def group_texts(examples):
        pruning_configs=[
            {
                "pruning_type": args.pruning_type,
                "pruning_scope": "global",
                "pruning_scope": args.pruning_scope,
                "sparsity_decay_type": "exp",
                "excluded_op_names": ["pooler"],
                "pruning_op_types": ["Linear"],
@@ -800,7 +807,8 @@ def group_texts(examples):

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        # fetch the underlying DeepSpeed model from the INC model wrapper
        unwrapped_model = accelerator.unwrap_model(model.model)
        unwrapped_model.save_pretrained(
            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
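Under ZeRO Stage-3 the parameters themselves stay partitioned across ranks, so producing a usable checkpoint generally also needs a consolidated state dict. A hedged sketch of one way to do that with `accelerate` (assumed usage, not shown in this diff):

```python
# gather the full state dict across ZeRO-3 shards before writing to disk;
# accelerator.get_state_dict consolidates a DeepSpeed-wrapped model
state_dict = accelerator.get_state_dict(model.model)
unwrapped_model.save_pretrained(
    args.output_dir,
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
    state_dict=state_dict,
)
```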
@@ -0,0 +1,94 @@
#!/bin/bash
set -x

function main {

    init_params "$@"
    run_pruning

}

# init params
function init_params {
    dataset_name="NeelNanda/pile-10k"
    model_name_or_path="facebook/opt-125m"
    output_dir="./test-clm"
    per_device_train_batch_size=8
    block_size=128
    gradient_accumulation_steps=4
    num_train_epochs=3
    target_sparsity=0.8
    pruning_type="snip_momentum"
    pruning_scope="local"
    pruning_pattern="4x1"
    pruning_frequency=1000
    for var in "$@"
    do
        case $var in
            --dataset_name=*)
                dataset_name=$(echo $var | cut -f2 -d=)
            ;;
            --model_name_or_path=*)
                model_name_or_path=$(echo $var | cut -f2 -d=)
            ;;
            --output_dir=*)
                output_dir=$(echo $var | cut -f2 -d=)
            ;;
            --per_device_train_batch_size=*)
                per_device_train_batch_size=$(echo $var | cut -f2 -d=)
            ;;
            --block_size=*)
                block_size=$(echo $var | cut -f2 -d=)
            ;;
            --gradient_accumulation_steps=*)
                gradient_accumulation_steps=$(echo $var | cut -f2 -d=)
            ;;
            --num_train_epochs=*)
                num_train_epochs=$(echo $var | cut -f2 -d=)
            ;;
            --target_sparsity=*)
                target_sparsity=$(echo $var | cut -f2 -d=)
            ;;
            --pruning_type=*)
                pruning_type=$(echo $var | cut -f2 -d=)
            ;;
            --pruning_scope=*)
                pruning_scope=$(echo $var | cut -f2 -d=)
            ;;
            --pruning_pattern=*)
                pruning_pattern=$(echo $var | cut -f2 -d=)
            ;;
            --pruning_frequency=*)
                pruning_frequency=$(echo $var | cut -f2 -d=)
            ;;
            *)
                echo "Error: No such parameter: ${var}"
                exit 1
            ;;
        esac
    done

}

# run_pruning
function run_pruning {
    accelerate launch --deepspeed_config_file config/ds_config.json --mixed_precision fp16 \
        run_clm_no_trainer_deepspeed.py \
        --dataset_name $dataset_name \
        --model_name_or_path $model_name_or_path \
        --block_size $block_size \
        --per_device_train_batch_size $per_device_train_batch_size \
        --gradient_accumulation_steps $gradient_accumulation_steps \
        --output_dir $output_dir \
        --do_prune \
        --num_train_epochs $num_train_epochs \
        --target_sparsity $target_sparsity \
        --pruning_type $pruning_type \
        --pruning_scope $pruning_scope \
        --pruning_pattern $pruning_pattern \
        --pruning_frequency $pruning_frequency

}

main "$@"

31 changes: 21 additions & 10 deletions neural_compressor/compression/pruner/criteria.py
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import torch
from .utils import safe_get_data, safe_get_grad, safe_get_shape, torch

CRITERIA = {}
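The `safe_get_data`/`safe_get_grad`/`safe_get_shape` helpers imported here are defined elsewhere in the package; a plausible sketch of their behavior, assuming DeepSpeed's public utilities (`safe_get_full_fp32_param`, `safe_get_full_grad`) and the `ds_id`/`ds_shape` attributes ZeRO-3 attaches to partitioned parameters:

```python
def safe_get_data(param):
    """Full parameter data; gathers ZeRO-3 shards when needed."""
    if hasattr(param, "ds_id"):  # set by DeepSpeed on partitioned params
        from deepspeed.utils import safe_get_full_fp32_param
        return safe_get_full_fp32_param(param)
    return param.data


def safe_get_grad(param):
    """Full gradient; gathers ZeRO-3 gradient shards when needed."""
    if hasattr(param, "ds_id"):
        from deepspeed.utils import safe_get_full_grad
        return safe_get_full_grad(param)
    return param.grad


def safe_get_shape(param):
    """Full (unpartitioned) shape of the parameter."""
    return param.ds_shape if hasattr(param, "ds_id") else param.shape
```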

@@ -96,7 +96,8 @@ def on_step_begin(self):
"""Calculate and store the pruning scores based on a magnitude criterion."""
with torch.no_grad():
for key in self.modules.keys():
p = self.modules[key].weight.data
param = self.modules[key].weight
p = safe_get_data(param)
if hasattr(self.pattern, "reduce_score"):
self.scores[key] = self.pattern.reduce_score(torch.abs(p), key)
else:
@@ -161,12 +162,15 @@ def on_before_optimizer_step(self):
"""Calculate and store the pruning scores based on snip criterion."""
with torch.no_grad():
for key in self.modules.keys():
p = self.modules[key].weight
# p = self.modules[key].weight
param = self.modules[key].weight
data = safe_get_data(param)
grad = safe_get_grad(param)
# self.scores[key] = torch.abs(p * p.grad)
if hasattr(self.pattern, "reduce_score"):
self.scores[key] = self.pattern.reduce_score(torch.abs(p * p.grad), key)
self.scores[key] = self.pattern.reduce_score(torch.abs(data * grad), key)
else:
self.scores[key] = torch.abs(p * p.grad)
self.scores[key] = torch.abs(data * grad)
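Note that `on_before_optimizer_step` fires while gradients still exist; under ZeRO Stage-3 both the weight and its gradient are partitioned, so the score must be computed from the gathered full tensors rather than from the local `.data`/`.grad` shards.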


@register_criterion("snip_momentum")
@@ -191,15 +195,19 @@ def __init__(self, modules, config, pattern):
        super(SnipMomentumCriterion, self).__init__(modules, config, pattern)
        assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
        for key in modules.keys():
            p = modules[key].weight
            param = modules[key].weight
            # p = modules[key].weight
            param_shape = safe_get_shape(param)
            dtype = torch.float32
            if self.low_memory_usage:
                dtype = torch.bfloat16 if p.device.type == "cpu" else torch.float16
                dtype = torch.bfloat16 if param.device.type == "cpu" else torch.float16
            # self.scores[key] = torch.zeros(p.shape, dtype=dtype).to(p.device)
            if hasattr(self.pattern, "reduce_score"):
                self.scores[key] = self.pattern.reduce_score(torch.zeros(p.shape, dtype=dtype).to(p.device), key)
                self.scores[key] = self.pattern.reduce_score(
                    torch.zeros(param_shape, dtype=dtype).to(param.device), key
                )
            else:
                self.scores[key] = torch.zeros(p.shape, dtype=dtype).to(p.device)
                self.scores[key] = torch.zeros(param_shape, dtype=dtype).to(param.device)

        self.alpha = 0.9
        self.beta = 1.0
@@ -209,8 +217,11 @@ def on_before_optimizer_step(self):
        with torch.no_grad():
            for key in self.modules.keys():
                p = self.modules[key].weight
                param = self.modules[key].weight
                data = safe_get_data(param)
                grad = safe_get_grad(param)
                self.scores[key] *= self.alpha
                tmp = torch.abs(p * p.grad)
                tmp = torch.abs(data * grad)
                if hasattr(self.pattern, "reduce_score"):
                    tmp = self.pattern.reduce_score(tmp, key, force=True)
                if self.low_memory_usage:
24 changes: 17 additions & 7 deletions neural_compressor/compression/pruner/patterns/base.py
@@ -20,7 +20,7 @@

import numpy as np

from ..utils import tf, torch
from ..utils import safe_get_data, safe_get_grad, safe_get_shape, tf, torch

PATTERNS = {}

@@ -75,12 +75,18 @@ def _reshape_2dims_to_orig(data, orig_shape):
        Returns:
            Reshaped data.
        """
        if len(orig_shape) == 4:
        if len(orig_shape) == 2:
            return data
        elif len(orig_shape) == 4:
            data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1])
            data = data.permute(0, 3, 1, 2)
        if len(orig_shape) == 3:
        elif len(orig_shape) == 3:
            data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[1])
            data = data.permute(0, 2, 1)
        elif len(orig_shape) == 1:
            data = data.reshape(orig_shape)
        else:
            raise NotImplementedError(f"unsupported shape: {data.shape}")
        return data

    # some util functions which can be used.
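The inverse mapping above implies a forward layout of `(O, I, kH, kW) -> (O, kH*kW*I)` for 4-D weights; a small self-check (the forward permute is inferred from the inverse, not shown in this diff):

```python
import torch

w = torch.randn(16, 3, 5, 5)                   # conv weight: (O, I, kH, kW)
two_d = w.permute(0, 2, 3, 1).reshape(16, -1)  # assumed 2-D scoring layout

# invert, mirroring _reshape_2dims_to_orig for the 4-D case
restored = two_d.reshape(16, 5, 5, 3).permute(0, 3, 1, 2)
assert torch.equal(restored, w)
```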
@@ -601,12 +607,16 @@ def get_pattern_lock_masks(self, modules):
"""
pattern_lock_masks = {}
for key in modules.keys():
weight = modules[key].weight
shape = weight.shape
# weight = modules[key].weight
# shape = weight.shape
param = modules[key].weight
data = safe_get_data(param)
shape = safe_get_shape(param)
mask = torch.ones(shape)
mask[weight == 0] = 0.0
# mask[weight == 0] = 0.0
mask[data == 0] = 0.0
mask = mask.bool()
pattern_lock_masks[key] = mask.to(weight.device)
pattern_lock_masks[key] = mask.to(param.device)

return pattern_lock_masks
