From 49ab28d362b03338194160e5ef67a3b8c7967a86 Mon Sep 17 00:00:00 2001
From: Yi30 <106061964+yiliu30@users.noreply.github.com>
Date: Fri, 12 Jan 2024 09:31:07 +0800
Subject: [PATCH] Enable SNIP on multiple cards using DeepSpeed ZeRO-3 (#1492)

Signed-off-by: yiliu30
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../{magnitude => multi_cards}/README.md      | 86 ++++++++++++++++-
 .../config/zero_stage2_config.json            |  0
 .../config/zero_stage3_config.json            | 34 +++++++
 .../requirements.txt                          |  0
 .../pruning/{magnitude => multi_cards}/run.sh |  0
 .../run_clm_no_trainer.py                     |  0
 .../run_clm_no_trainer_deepspeed.py           | 12 ++-
 .../{magnitude => multi_cards}/run_ds.sh      |  0
 .../pruning/multi_cards/run_ds_z3.sh          | 94 +++++++++++++++++++
 .../compression/pruner/criteria.py            | 31 ++++--
 .../compression/pruner/patterns/base.py       | 24 +++--
 .../compression/pruner/patterns/ninm.py       | 53 ++++++++---
 .../compression/pruner/patterns/nxm.py        | 67 +++++++++----
 .../compression/pruner/pruners/base.py        | 10 +-
 .../compression/pruner/pruners/basic.py       |  6 +-
 .../pruner/pruners/pattern_lock.py            |  4 +
 neural_compressor/compression/pruner/utils.py | 85 +++++++++++++++++
 17 files changed, 447 insertions(+), 59 deletions(-)
 rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{magnitude => multi_cards}/README.md (51%)
 rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{magnitude => multi_cards}/config/zero_stage2_config.json (100%)
 create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/config/zero_stage3_config.json
 rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{magnitude => multi_cards}/requirements.txt (100%)
 rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{magnitude => multi_cards}/run.sh (100%)
 rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{magnitude => multi_cards}/run_clm_no_trainer.py (100%)
 rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{magnitude => multi_cards}/run_clm_no_trainer_deepspeed.py (98%)
 rename examples/pytorch/nlp/huggingface_models/language-modeling/pruning/{magnitude => multi_cards}/run_ds.sh (100%)
 create mode 100644 examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_ds_z3.sh

diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/README.md
similarity index 51%
rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md
rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/README.md
index b0d049f9547..1d41f53b17d 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/README.md
@@ -1,7 +1,7 @@
 Step-by-Step
 ============
 
-# single GPU
+# Single GPU
 
 ```
 export CUDA_VISIBLE_DEVICES=0
 bash run.sh \
     --model_name_or_path=facebook/opt-125m \
     --dataset_name=NeelNanda/pile-10k \
     --block_size=128 \
     --output_dir=./test-clm \
     --pruning_type=magnitude \
     --pruning_pattern=4x1 \
     --pruning_frequency=1000
 ```
 
-# multi GPU
+# Multi GPU
 
-we use `accelerate` and `deepspeed ZeRO Stage-2` to conduct weight magnitude pruning
+We use `accelerate` and `deepspeed ZeRO` to conduct weight magnitude and SNIP pruning. Below are two usage examples: 1) magnitude pruning with ZeRO Stage-2, and 2) SNIP pruning with ZeRO Stage-3.
+## Magnitude pruning with ZeRO Stage-2
 ### Accelerate DeepSpeed Plugin
 On your machine(s) just run:
@@ -105,3 +106,82 @@ bash run_ds.sh \
     --pruning_pattern=4x1 \
     --pruning_frequency=1000
 ```
+
+
+## SNIP pruning with ZeRO Stage-3
+
+To configure `accelerate` to use DeepSpeed ZeRO Stage-3, on your machine(s) just run:
+``` shell
+accelerate config
+
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  deepspeed_config_file: config/zero_stage3_config.json
+  zero3_init_flag: true
+distributed_type: DEEPSPEED
+fsdp_config: {}
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 2
+use_cpu: false
+```
+with the contents of `config/zero_stage3_config.json` being:
+
+```
+{
+    "train_batch_size": 64,
+    "train_micro_batch_size_per_gpu": 8,
+    "gradient_accumulation_steps": 4,
+    "fp16": {
+        "enabled": true,
+        "min_loss_scale": 1,
+        "opt_level": "O2"
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
+        "contiguous_gradients": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": 0.0,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto",
+            "warmup_type": "cosine"
+        }
+    }
+}
+```
+
+### Pruning
+> Note: As ZeRO Stage-3 partitions all three model states (optimizer states, gradients, and parameters), please specify the `pruning_scope` as `local`. Choosing `global` requires gathering all parameters to update the mask, which compromises the benefits of ZeRO Stage-3.
+
+
+```
+# 2 gpu cards example
+export CUDA_VISIBLE_DEVICES=0,1 USE_DEEPSPEED=1
+bash run_ds_z3.sh \
+    --model_name_or_path=facebook/opt-125m \
+    --dataset_name=NeelNanda/pile-10k \
+    --block_size=128 \
+    --output_dir=./test-clm \
+    --pruning_type=snip_momentum \
+    --pruning_scope=local \
+    --pruning_pattern=4x1 \
+    --pruning_frequency=1000
+```
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/config/zero_stage2_config.json b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/config/zero_stage2_config.json
similarity index 100%
rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/config/zero_stage2_config.json
rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/config/zero_stage2_config.json
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/config/zero_stage3_config.json b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/config/zero_stage3_config.json
new file mode 100644
index 00000000000..c81ffa45c4e
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/config/zero_stage3_config.json
@@ -0,0 +1,34 @@
+{
+    "train_batch_size": 64,
+    "train_micro_batch_size_per_gpu": 8,
+    "gradient_accumulation_steps": 4,
+    "fp16": {
+        "enabled": true,
+        "min_loss_scale": 1,
+        "opt_level": "O2"
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
+        "contiguous_gradients": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": 0.0,
+            "warmup_max_lr": "auto",
"warmup_num_steps": "auto", + "total_num_steps": "auto", + "warmup_type": "cosine" + } + } +} diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/requirements.txt b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/requirements.txt similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/requirements.txt rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/requirements.txt diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run.sh similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run.sh rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run.sh diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_clm_no_trainer.py similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer.py rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_clm_no_trainer.py diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_clm_no_trainer_deepspeed.py similarity index 98% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_clm_no_trainer_deepspeed.py index c8c86fdd971..24496e38f12 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_clm_no_trainer_deepspeed.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_clm_no_trainer_deepspeed.py @@ -274,6 +274,13 @@ def parse_args(): help="pruning criteria to use.", choices=["magnitude", "snip", "snip_momentum"], ) + parser.add_argument( + "--pruning_scope", + type=str, + default="global", + help="determine layers' scores should be gather together to sort.", + choices=["local", "global"], + ) parser.add_argument( "--warm_epochs", type=int, @@ -688,7 +695,7 @@ def group_texts(examples): pruning_configs=[ { "pruning_type": args.pruning_type, - "pruning_scope": "global", + "pruning_scope": args.pruning_scope, "sparsity_decay_type": "exp", "excluded_op_names": ["pooler"], "pruning_op_types": ["Linear"], @@ -800,7 +807,8 @@ def group_texts(examples): if args.output_dir is not None: accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) + # fetch the ds model from inc model + unwrapped_model = accelerator.unwrap_model(model.model) unwrapped_model.save_pretrained( args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save ) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_ds.sh similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/run_ds.sh rename to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_ds.sh 
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_ds_z3.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_ds_z3.sh
new file mode 100644
index 00000000000..12599ca7c01
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/run_ds_z3.sh
@@ -0,0 +1,94 @@
+#!/bin/bash
+set -x
+
+function main {
+
+    init_params "$@"
+    run_pruning
+
+}
+
+# init params
+function init_params {
+    dataset_name="NeelNanda/pile-10k"
+    model_name_or_path="facebook/opt-125m"
+    output_dir="./test-clm"
+    per_device_train_batch_size=8
+    block_size=128
+    gradient_accumulation_steps=4
+    num_train_epochs=3
+    target_sparsity=0.8
+    pruning_type="snip_momentum"
+    pruning_scope="local"
+    pruning_pattern="4x1"
+    pruning_frequency=1000
+    for var in "$@"
+    do
+        case $var in
+            --dataset_name=*)
+                dataset_name=$(echo $var |cut -f2 -d=)
+            ;;
+            --model_name_or_path=*)
+                model_name_or_path=$(echo $var |cut -f2 -d=)
+            ;;
+            --output_dir=*)
+                output_dir=$(echo $var |cut -f2 -d=)
+            ;;
+            --per_device_train_batch_size=*)
+                per_device_train_batch_size=$(echo $var |cut -f2 -d=)
+            ;;
+            --block_size=*)
+                block_size=$(echo $var |cut -f2 -d=)
+            ;;
+            --gradient_accumulation_steps=*)
+                gradient_accumulation_steps=$(echo $var |cut -f2 -d=)
+            ;;
+            --num_train_epochs=*)
+                num_train_epochs=$(echo $var |cut -f2 -d=)
+            ;;
+            --target_sparsity=*)
+                target_sparsity=$(echo $var |cut -f2 -d=)
+            ;;
+            --pruning_type=*)
+                pruning_type=$(echo $var |cut -f2 -d=)
+            ;;
+            --pruning_scope=*)
+                pruning_scope=$(echo $var |cut -f2 -d=)
+            ;;
+            --pruning_pattern=*)
+                pruning_pattern=$(echo $var |cut -f2 -d=)
+            ;;
+            --pruning_frequency=*)
+                pruning_frequency=$(echo $var |cut -f2 -d=)
+            ;;
+            *)
+                echo "Error: No such parameter: ${var}"
+                exit 1
+            ;;
+        esac
+    done
+
+}
+
+# run pruning
+function run_pruning {
+    accelerate launch --deepspeed_config_file config/zero_stage3_config.json --mixed_precision fp16 \
+        run_clm_no_trainer_deepspeed.py \
+        --dataset_name $dataset_name \
+        --model_name_or_path $model_name_or_path \
+        --block_size $block_size \
+        --per_device_train_batch_size $per_device_train_batch_size \
+        --gradient_accumulation_steps $gradient_accumulation_steps \
+        --output_dir $output_dir \
+        --do_prune \
+        --num_train_epochs $num_train_epochs \
+        --target_sparsity $target_sparsity \
+        --pruning_type $pruning_type \
+        --pruning_scope $pruning_scope \
+        --pruning_pattern $pruning_pattern \
+        --pruning_frequency $pruning_frequency
+
+}
+
+main "$@"
+
diff --git a/neural_compressor/compression/pruner/criteria.py b/neural_compressor/compression/pruner/criteria.py
index 5af9ed65a32..ef9dd8ac14c 100644
--- a/neural_compressor/compression/pruner/criteria.py
+++ b/neural_compressor/compression/pruner/criteria.py
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .utils import torch +from .utils import safe_get_data, safe_get_grad, safe_get_shape, torch CRITERIA = {} @@ -96,7 +96,8 @@ def on_step_begin(self): """Calculate and store the pruning scores based on a magnitude criterion.""" with torch.no_grad(): for key in self.modules.keys(): - p = self.modules[key].weight.data + param = self.modules[key].weight + p = safe_get_data(param) if hasattr(self.pattern, "reduce_score"): self.scores[key] = self.pattern.reduce_score(torch.abs(p), key) else: @@ -161,12 +162,15 @@ def on_before_optimizer_step(self): """Calculate and store the pruning scores based on snip criterion.""" with torch.no_grad(): for key in self.modules.keys(): - p = self.modules[key].weight + # p = self.modules[key].weight + param = self.modules[key].weight + data = safe_get_data(param) + grad = safe_get_grad(param) # self.scores[key] = torch.abs(p * p.grad) if hasattr(self.pattern, "reduce_score"): - self.scores[key] = self.pattern.reduce_score(torch.abs(p * p.grad), key) + self.scores[key] = self.pattern.reduce_score(torch.abs(data * grad), key) else: - self.scores[key] = torch.abs(p * p.grad) + self.scores[key] = torch.abs(data * grad) @register_criterion("snip_momentum") @@ -191,15 +195,19 @@ def __init__(self, modules, config, pattern): super(SnipMomentumCriterion, self).__init__(modules, config, pattern) assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion" for key in modules.keys(): - p = modules[key].weight + param = modules[key].weight + # p = modules[key].weight + param_shape = safe_get_shape(param) dtype = torch.float32 if self.low_memory_usage: - dtype = torch.bfloat16 if p.device.type == "cpu" else torch.float16 + dtype = torch.bfloat16 if param.device.type == "cpu" else torch.float16 # self.scores[key] = torch.zeros(p.shape, dtype=dtype).to(p.device) if hasattr(self.pattern, "reduce_score"): - self.scores[key] = self.pattern.reduce_score(torch.zeros(p.shape, dtype=dtype).to(p.device), key) + self.scores[key] = self.pattern.reduce_score( + torch.zeros(param_shape, dtype=dtype).to(param.device), key + ) else: - self.scores[key] = torch.zeros(p.shape, dtype=dtype).to(p.device) + self.scores[key] = torch.zeros(param_shape, dtype=dtype).to(param.device) self.alpha = 0.9 self.beta = 1.0 @@ -209,8 +217,11 @@ def on_before_optimizer_step(self): with torch.no_grad(): for key in self.modules.keys(): p = self.modules[key].weight + param = self.modules[key].weight + data = safe_get_data(param) + grad = safe_get_grad(param) self.scores[key] *= self.alpha - tmp = torch.abs(p * p.grad) + tmp = torch.abs(data * grad) if hasattr(self.pattern, "reduce_score"): tmp = self.pattern.reduce_score(tmp, key, force=True) if self.low_memory_usage: diff --git a/neural_compressor/compression/pruner/patterns/base.py b/neural_compressor/compression/pruner/patterns/base.py index 6bc1a325572..723db325dc6 100644 --- a/neural_compressor/compression/pruner/patterns/base.py +++ b/neural_compressor/compression/pruner/patterns/base.py @@ -20,7 +20,7 @@ import numpy as np -from ..utils import tf, torch +from ..utils import safe_get_data, safe_get_grad, safe_get_shape, tf, torch PATTERNS = {} @@ -75,12 +75,18 @@ def _reshape_2dims_to_orig(data, orig_shape): Returns: Reshaped data. 
""" - if len(orig_shape) == 4: + if len(orig_shape) == 2: + return data + elif len(orig_shape) == 4: data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]) data = data.permute(0, 3, 1, 2) - if len(orig_shape) == 3: + elif len(orig_shape) == 3: data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[1]) data = data.permute(0, 2, 1) + elif len(orig_shape) == 1: + data = data.reshape(orig_shape) + else: + raise NotImplementedError(f"not support {data.shape}") return data # some util functions which can be used. @@ -601,12 +607,16 @@ def get_pattern_lock_masks(self, modules): """ pattern_lock_masks = {} for key in modules.keys(): - weight = modules[key].weight - shape = weight.shape + # weight = modules[key].weight + # shape = weight.shape + param = modules[key].weight + data = safe_get_data(param) + shape = safe_get_shape(param) mask = torch.ones(shape) - mask[weight == 0] = 0.0 + # mask[weight == 0] = 0.0 + mask[data == 0] = 0.0 mask = mask.bool() - pattern_lock_masks[key] = mask.to(weight.device) + pattern_lock_masks[key] = mask.to(param.device) return pattern_lock_masks diff --git a/neural_compressor/compression/pruner/patterns/ninm.py b/neural_compressor/compression/pruner/patterns/ninm.py index d02508cbbcd..53264b47fd5 100644 --- a/neural_compressor/compression/pruner/patterns/ninm.py +++ b/neural_compressor/compression/pruner/patterns/ninm.py @@ -15,7 +15,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ..utils import logger, nn, tf, torch + +from ..utils import logger, nn, safe_get_data, safe_get_grad, safe_get_shape, tf, torch from .base import ProgressivePatternUtils, PytorchBasePattern, SparsityInfo, register_pattern @@ -145,12 +146,22 @@ def _reshape_orig_to_2dims(self, data): Returns: Reshaped data. """ - if len(data.shape) == 4: # TODO: need to verify whether it's ok for transposed conv + from ..utils import FLATTEN_DIM2 + + if len(data.shape) == 2: + return data + elif len(data.shape) == 4: # TODO: need to verify whether it's ok for transposed conv data = data.permute(0, 2, 3, 1) # cout,k,k,cin data = data.reshape(data.shape[0], -1) - if len(data.shape) == 3: + elif len(data.shape) == 3: data = data.permute(0, 2, 1) # cout,k,cin data = data.reshape(data.shape[0], -1) + elif len(data.shape) == 1: + data = data.reshape(-1, FLATTEN_DIM2) + else: + raise NotImplementedError( + f"Currently only support reshape data with 1,3,4-dims, but got shape {data.shape}" + ) return data def _reshape_2dims_to_orig(self, data, orig_shape): @@ -162,12 +173,21 @@ def _reshape_2dims_to_orig(self, data, orig_shape): Returns: Reshaped data. 
""" - if len(orig_shape) == 4: + if len(orig_shape) == 2: + return data + + elif len(orig_shape) == 4: data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]) data = data.permute(0, 3, 1, 2) - if len(orig_shape) == 3: + elif len(orig_shape) == 3: data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[1]) data = data.permute(0, 2, 1) + elif len(orig_shape) == 1: + data = data.reshape(orig_shape) + else: + raise NotImplementedError( + f"Currently only support reshape data with 1,3,4-dims, but got shape {data.shape}" + ) return data def reshape_orig_to_pattern(self, data, key): @@ -342,11 +362,13 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, keep_ex for key in masks.keys(): if key in self.invalid_layers: continue - orig_shape = self.modules[key].weight.shape - if len(orig_shape) == 4 or len(orig_shape) == 3: # need to permute - mask = masks[key] - mask = self._reshape_2dims_to_orig(mask, orig_shape) - masks[key] = mask + # orig_shape = self.modules[key].weight.shape + param = self.modules[key].weight + orig_shape = safe_get_shape(param) + # if len(orig_shape) == 4 or len(orig_shape) == 3 or: # need to permute + mask = masks[key] + mask = self._reshape_2dims_to_orig(mask, orig_shape) + masks[key] = mask layer_ratio = torch.sum(masks[key] == 0.0).data.item() / masks[key].numel() logger.info(f"layer {key} sparsity_ratio is {layer_ratio}") return masks @@ -362,13 +384,16 @@ def get_pattern_lock_masks(self, modules): """ pattern_lock_masks = {} for key in modules.keys(): - weight = modules[key].weight - orig_shape = weight.shape + # weight = modules[key].weight + param = modules[key].weight + # orig_shape = weight.shape + orig_shape = safe_get_shape(param) + data = safe_get_data(param) if key in self.invalid_layers: - mask = torch.ones(orig_shape, device=weight.device) + mask = torch.ones(orig_shape, device=param.device) pattern_lock_masks[key] = mask.bool() continue - reduced_mask = self.get_reduced_masks_from_data(weight, key) + reduced_mask = self.get_reduced_masks_from_data(data, key) mask = self.reshape_reduced_to_orig(reduced_mask, key, orig_shape) pattern_lock_masks[key] = mask return pattern_lock_masks diff --git a/neural_compressor/compression/pruner/patterns/nxm.py b/neural_compressor/compression/pruner/patterns/nxm.py index d9812c366df..560c6c05f8c 100644 --- a/neural_compressor/compression/pruner/patterns/nxm.py +++ b/neural_compressor/compression/pruner/patterns/nxm.py @@ -17,7 +17,7 @@ # limitations under the License. import numpy as np -from ..utils import logger, nn, tf, torch +from ..utils import logger, nn, safe_get_data, safe_get_grad, safe_get_shape, tf, torch from .base import KerasBasePattern, ProgressivePatternUtils, PytorchBasePattern, SparsityInfo, register_pattern @@ -71,7 +71,9 @@ def get_block_size_dict(self): if not (self.N == "channel" or self.M == "channel"): continue if isinstance(datas[key], torch.nn.Module): - shape = datas[key].weight.shape + param = datas[key].weight + shape = safe_get_shape(param) + # shape = datas[key].weight.shape else: shape = datas[key].shape if self.N == "channel": # support "channelxM" format @@ -86,7 +88,8 @@ def check_layer_validity(self): block_sizes = self.block_size datas = self.modules for key in datas.keys(): - data = datas[key].weight + param = datas[key].weight + data = safe_get_data(param) data = self._reshape_orig_to_2dims(data) shape = data.shape block_size = block_sizes[key] @@ -153,18 +156,29 @@ def _reshape_orig_to_2dims(self, data): Reshaped data. 
""" # TODO: need to verify whether it's ok for transposed conv - if len(data.shape) == 4: + from ..utils import FLATTEN_DIM2 + + if len(data.shape) == 2: + return data + elif len(data.shape) == 4: if isinstance(data, np.ndarray): data = np.transpose(data, (0, 2, 3, 1)) else: data = data.permute(0, 2, 3, 1) # cout,k,k,cin data = data.reshape(data.shape[0], -1) - if len(data.shape) == 3: + elif len(data.shape) == 3: if isinstance(data, np.ndarray): data = np.transpose(data, (0, 2, 1)) else: data = data.permute(0, 2, 1) # cout,k,cin data = data.reshape(data.shape[0], -1) + # TODO(Yi) support handle 1-dim (flatten param from DeepSpeed or FSDP) + elif len(data.shape) == 1: # pragma: no cover + data = data.reshape(-1, FLATTEN_DIM2) + else: + raise NotImplementedError( + f"Currently only support reshape data with 1,3,4-dims, but got shape {data.shape}" + ) return data def _reshape_2dims_to_orig(self, data, orig_shape): @@ -177,18 +191,26 @@ def _reshape_2dims_to_orig(self, data, orig_shape): Returns: Reshaped data. """ - if len(orig_shape) == 4: + if len(orig_shape) == 2: + return data + elif len(orig_shape) == 4: data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]) if isinstance(data, np.ndarray): # pragma: no cover data = np.transpose(data, (0, 3, 1, 2)) else: data = data.permute(0, 3, 1, 2) - if len(orig_shape) == 3: + elif len(orig_shape) == 3: data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[1]) if isinstance(data, np.ndarray): # pragma: no cover data = np.transpose(data, (0, 2, 1)) else: data = data.permute(0, 2, 1) + elif len(orig_shape) == 1: + data = data.reshape(-1) + else: + raise NotImplementedError( + f"Currently only support reshape data with 1,3,4-dims, but got shape {data.shape}" + ) return data def reshape_orig_to_pattern(self, data, key): @@ -359,12 +381,14 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, keep_ex for key in masks.keys(): if key in self.invalid_layers: continue - orig_shape = self.modules[key].weight.shape - if len(orig_shape) == 4 or len(orig_shape) == 3: # need to permute - mask = masks[key] - # orig_shape = scores[key].shape - mask = self._reshape_2dims_to_orig(mask, orig_shape) - masks[key] = mask + param = self.modules[key].weight + orig_shape = safe_get_shape(param) + # orig_shape = self.modules[key].weight.shape + # if len(orig_shape) == 4 or len(orig_shape) == 3 : # need to permute + mask = masks[key] + # orig_shape = scores[key].shape + mask = self._reshape_2dims_to_orig(mask, orig_shape) + masks[key] = mask layer_ratio = torch.sum(masks[key] == 0.0).data.item() / masks[key].numel() logger.info(f"{key} sparsity is {layer_ratio}") return masks @@ -380,13 +404,15 @@ def get_pattern_lock_masks(self, modules): """ pattern_lock_masks = {} for key in modules.keys(): - weight = modules[key].weight - ori_shape = weight.shape + param = modules[key].weight + data = safe_get_data(param) + ori_shape = safe_get_shape(param) + # ori_shape = weight.shape if key in self.invalid_layers: - mask = torch.ones(weight.shape, device=weight.device) + mask = torch.ones(ori_shape, device=param.device) pattern_lock_masks[key] = mask continue - reduced_mask = self.get_reduced_masks_from_data(weight, key) + reduced_mask = self.get_reduced_masks_from_data(data, key) mask = self.reshape_reduced_to_orig(reduced_mask, key, ori_shape) pattern_lock_masks[key] = mask @@ -424,7 +450,8 @@ def mask_block_weights(self, masks): continue module = self.modules[key] block_size = self.block_size[key] - org_shape = module.weight.shape 
+ # org_shape = module.weight.shape + org_shape = safe_get_shape(module.weight) mask = ( masks[key] .data.repeat_interleave(block_size[0], dim=0) @@ -531,7 +558,9 @@ def fasterprune(self, gpt, blocksize=128, percdamp=0.01): if isinstance(module, transformers.Conv1D): W = W.t() - module.weight.data = W.reshape(module.weight.shape).to(dtype=module.weight.data.dtype) + param = module.weight + param_shape = safe_get_shape(param) + module.weight.data = W.reshape(param_shape).to(dtype=module.weight.data.dtype) if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/neural_compressor/compression/pruner/pruners/base.py b/neural_compressor/compression/pruner/pruners/base.py index bbdfcf344a9..47842389a16 100644 --- a/neural_compressor/compression/pruner/pruners/base.py +++ b/neural_compressor/compression/pruner/pruners/base.py @@ -18,7 +18,7 @@ import numpy as np -from ..utils import F, tf, torch +from ..utils import F, safe_get_data, safe_get_grad, safe_get_shape, safe_set_data, tf, torch PRUNERS = {} @@ -205,7 +205,8 @@ def __init__(self, config, modules): for key in self.modules.keys(): module = self.modules[key] # TODO: support bias or others - self.masks[key] = torch.ones(module.weight.shape).to(module.weight.device).bool() + param_shape = safe_get_shape(module.weight) + self.masks[key] = torch.ones(param_shape).to(module.weight.device).bool() self._init() def mask_weights(self): @@ -216,7 +217,10 @@ def mask_weights(self): with torch.no_grad(): for key in self.modules.keys(): module = self.modules[key] - module.weight.data = module.weight.data * self.masks[key] + param = module.weight + param_data = safe_get_data(param) + new_val = param_data * self.masks[key] + safe_set_data(new_val=new_val, param=param) class KerasBasePruner(BasePruner): diff --git a/neural_compressor/compression/pruner/pruners/basic.py b/neural_compressor/compression/pruner/pruners/basic.py index 003ae6cce21..da84ef35f5e 100644 --- a/neural_compressor/compression/pruner/pruners/basic.py +++ b/neural_compressor/compression/pruner/pruners/basic.py @@ -16,14 +16,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +# from ..utils import logger +from neural_compressor.utils.logger import Logger + from ..criteria import get_criterion from ..patterns import get_pattern from ..regs import get_reg from ..schedulers import get_scheduler from ..tf_criteria import get_tf_criterion -from ..utils import logger from .base import KerasBasePruner, PytorchBasePruner, register_pruner +logger = Logger().get_logger() + @register_pruner("pt_basic") class PytorchBasicPruner(PytorchBasePruner): diff --git a/neural_compressor/compression/pruner/pruners/pattern_lock.py b/neural_compressor/compression/pruner/pruners/pattern_lock.py index a2786b871c4..d30fc1cf35c 100644 --- a/neural_compressor/compression/pruner/pruners/pattern_lock.py +++ b/neural_compressor/compression/pruner/pruners/pattern_lock.py @@ -16,9 +16,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from neural_compressor.utils.logger import Logger
+
 from ..patterns import get_pattern
 from .base import KerasBasePruner, PytorchBasePruner, register_pruner
 
+logger = Logger().get_logger()
+
 
 @register_pruner("pt_pattern_lock")
 class PytorchPatternLockPruner(PytorchBasePruner):
diff --git a/neural_compressor/compression/pruner/utils.py b/neural_compressor/compression/pruner/utils.py
index d31bdb231b6..6f6aa985f21 100644
--- a/neural_compressor/compression/pruner/utils.py
+++ b/neural_compressor/compression/pruner/utils.py
@@ -745,3 +745,88 @@ def forward(_, hidden_states, *positional_args, **kwargs):
         inputs.append(batch)
 
     return inputs, positional_inputs, other_input_infos
+
+
+########################################################
+## Utilities for integrating DeepSpeed
+########################################################
+import os
+
+USE_DEEPSPEED = False
+FLATTEN_DIM2 = 8
+
+
+def is_deepspeed_available():  # pragma: no cover
+    import importlib
+    import importlib.metadata as importlib_metadata
+
+    package_exists = importlib.util.find_spec("deepspeed") is not None
+    # Check we're not importing a "deepspeed" directory somewhere but the actual
+    # library, by trying to grab its package metadata.
+    if not package_exists:
+        return False
+    try:
+        importlib_metadata.metadata("deepspeed")
+        return True
+    except importlib_metadata.PackageNotFoundError:
+        return False
+
+
+from packaging.version import Version
+
+
+def get_deepspeed_version():  # pragma: no cover
+    try:
+        import deepspeed  # pylint: disable=E0401
+
+        deepspeed_version = deepspeed.__version__.split("+")[0]
+    except ValueError as e:  # pragma: no cover
+        assert False, "Got an unknown version of deepspeed: {}".format(e)
+    version = Version(deepspeed_version)
+    return version
+
+
+def check_deepspeed_version():  # pragma: no cover
+    version = get_deepspeed_version()
+    assert version >= Version("0.12.4"), f"The minimum version requirement of deepspeed is 0.12.4, but got {version}."
+
+
+USE_DEEPSPEED = os.environ.get("USE_DEEPSPEED", False)
+if USE_DEEPSPEED:  # pragma: no cover
+    assert is_deepspeed_available(), "DeepSpeed is required: `pip install deepspeed>=0.12.4`"
+    check_deepspeed_version()
+
+
+def safe_get_shape(param):  # pragma: no cover
+    if USE_DEEPSPEED:
+        # param.ds_tensor is the partitioned tensor
+        return param.ds_tensor.shape
+    else:
+        return param.shape
+
+
+def safe_get_data(param):  # pragma: no cover
+    if USE_DEEPSPEED:
+        from deepspeed.utils import safe_get_local_fp32_param  # pylint: disable=E0401
+
+        return safe_get_local_fp32_param(param)
+    else:
+        return param.data
+
+
+def safe_get_grad(param):  # pragma: no cover
+    if USE_DEEPSPEED:
+        from deepspeed.utils import safe_get_local_grad  # pylint: disable=E0401
+
+        return safe_get_local_grad(param)
+    else:
+        return param.grad
+
+
+def safe_set_data(param, new_val):  # pragma: no cover
+    if USE_DEEPSPEED:
+        from deepspeed.utils import safe_set_local_fp32_param  # pylint: disable=E0401
+
+        safe_set_local_fp32_param(new_val, param)
+    else:
+        param.data = new_val
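Taken together, these helpers let the criteria and patterns above treat a plain `torch.nn.Parameter` and a ZeRO-3 local shard uniformly. The following sketch shows the intended score-mask-update flow; it is illustrative only, not part of the patch: `snip_mask` and `apply_mask` are hypothetical names, the 50% sparsity is arbitrary, and the shard size is assumed divisible by `FLATTEN_DIM2`, as the patterns above also assume.

```python
# Sketch: score -> mask -> update for one parameter, working on either a full
# tensor or a DeepSpeed ZeRO-3 local shard (USE_DEEPSPEED=1). Illustrative only.
import torch

from neural_compressor.compression.pruner.utils import (
    FLATTEN_DIM2,
    safe_get_data,
    safe_get_grad,
    safe_get_shape,
    safe_set_data,
)


def snip_mask(param, sparsity_ratio=0.5):
    """Build an unstructured SNIP-style mask |w * grad|; call after backward()."""
    data = safe_get_data(param)  # full fp32 tensor, or the local ZeRO-3 shard
    grad = safe_get_grad(param)  # gradient with the same (possibly 1-D) layout
    scores = torch.abs(data * grad)
    if scores.dim() == 1:
        # ZeRO-3 shards are flattened; view them as 2-D so NxM/N:M patterns
        # can be applied locally -- this is what FLATTEN_DIM2 exists for.
        scores = scores.reshape(-1, FLATTEN_DIM2)
    k = max(1, int(scores.numel() * sparsity_ratio))
    threshold = torch.kthvalue(scores.flatten(), k).values
    mask = (scores > threshold).to(data.dtype).reshape(safe_get_shape(param))
    return mask


def apply_mask(param, mask):
    """Zero out pruned weights in place, ZeRO-3-safe."""
    safe_set_data(param=param, new_val=safe_get_data(param) * mask)
```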