Support auto device for GPTQ and RTN (#1622)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
yiliu30 authored Feb 26, 2024
1 parent 071ab31 commit 2a86aea
Showing 10 changed files with 277 additions and 17 deletions.
5 changes: 3 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -28,7 +28,8 @@
import transformers
from tqdm import tqdm

from neural_compressor.torch.utils import fetch_module, logger, set_module
from neural_compressor.torch.utils import fetch_module, get_device, logger, set_module
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .modules import WeightOnlyLinear

@@ -255,7 +256,7 @@ def __init__(
self.check_layer_config()

# device
self.device = device
self.device = get_device(kwargs.pop("device", "auto"))
if str(self.model.device).startswith("cuda"):
self.device = self.model.device
self.is_ready = False
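For context, a minimal standalone sketch of the device resolution the new constructor code performs; `get_device` is the helper this commit adds and exports from `neural_compressor.torch.utils`, and the tiny GPT-J checkpoint is the one used in the new test below:

import transformers

from neural_compressor.torch.utils import get_device

model = transformers.AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
device = get_device("auto")  # e.g. "cuda" on a GPU machine, otherwise "cpu"
if str(model.device).startswith("cuda"):
    device = model.device  # a model already placed on CUDA keeps its own device
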
2 changes: 1 addition & 1 deletion neural_compressor/torch/algorithms/weight_only/hqq/core.py
@@ -24,8 +24,8 @@
import torch

from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .auto_accelerator import auto_detect_accelerator
from .bitpack import Packer
from .config import HQQModuleConfig, QTensorConfig, default_hqq_module_config, hqq_global_option
from .optimizer import optimize_weights_proximal
@@ -20,8 +20,7 @@
import torch

from neural_compressor.torch.utils import logger

from .auto_accelerator import auto_detect_accelerator
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator


# Proximal solver || W - dequantize(quantize(W))||_p^p
@@ -17,8 +17,8 @@
import torch

from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .auto_accelerator import auto_detect_accelerator
from .config import ConfigMappingType, default_hqq_module_config, hqq_global_option
from .core import HQQLinear

11 changes: 9 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -21,7 +21,8 @@

import torch

from neural_compressor.torch.utils import logger, set_module
from neural_compressor.torch.utils import get_device, logger, set_module
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .utility import quant_tensor, search_clip

@@ -73,7 +74,12 @@ def rtn_quantize(
Returns:
model: fake quantized torch module
"""
device = "cpu"
device = get_device(kwargs.pop("device", "auto"))

# Put model on device explicitly
# TODO: refine it later; put modules on the device one by one instead of the whole model
model.to(device)

assert isinstance(model, torch.nn.Module), "only support torch module"
supported_layers = ["Linear"]
# initialize global configuration
@@ -94,6 +100,7 @@
dtype = weight_config[name].get("dtype", "int")
if dtype == "fp32":
continue
logger.debug("Apply RTN on module %s.", name)
bits = weight_config[name].get("bits", 4)
group_size = weight_config[name]["group_size"]
scheme = weight_config[name]["scheme"]
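A hedged end-to-end sketch of how the new auto device behavior surfaces through the 3.x quantization API; it mirrors the RTN test added at the bottom of this commit and assumes a CUDA machine for the resulting device:

import transformers

from neural_compressor.torch.quantization import get_default_rtn_config, quantize

model = transformers.AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
q_model = quantize(model, get_default_rtn_config())  # device resolves from "auto"
print(q_model.device)  # expected to be a cuda device when a GPU is available
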
@@ -203,16 +203,23 @@ def empty_cache(self):
return torch.cuda.empty_cache()


def auto_detect_accelerator() -> Auto_Accelerator:
# if runtime_accelerator.accelerator:
# return runtime_accelerator.accelerator
def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
# The environment variable `FORCE_DEVICE` has higher priority than the `device_name`.
# TODO: refine the docs and logic later
FORCE_DEVICE = os.environ.get("FORCE_DEVICE", None)
if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None:
logger.warning("Forcing use of the %s accelerator.", FORCE_DEVICE)
return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)()
if device_name != "auto":
if accelerator_registry.get_accelerator_cls_by_name(device_name) is not None:
accelerator_cls = accelerator_registry.get_accelerator_cls_by_name(device_name)
logger.warning("Selected accelerator %s by device_name.", accelerator_cls.__name__)
return accelerator_cls()
else:
logger.warning("The device name %s is not supported; falling back to auto-detection.", device_name)
for accelerator_cls in accelerator_registry.get_sorted_accelerators():
if accelerator_cls.is_available():
logger.debug("Auto detect accelerator: %s.", accelerator_cls.__name__)
logger.warning("Auto detect accelerator: %s.", accelerator_cls.__name__)
accelerator = accelerator_cls()
return accelerator

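A hedged sketch of the new selection priority (the FORCE_DEVICE environment variable first, then an explicit device_name, then auto-detection); that "cpu" and "cuda" are registered accelerator names is an assumption based on the surrounding code:

import os

from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

os.environ["FORCE_DEVICE"] = "cpu"  # highest priority, overrides device_name
acc = auto_detect_accelerator("cuda")  # returns the CPU accelerator here
del os.environ["FORCE_DEVICE"]

acc = auto_detect_accelerator("cuda")  # honors device_name when that accelerator is registered
acc = auto_detect_accelerator()  # "auto": first available accelerator by priority
print(acc.name())
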
8 changes: 8 additions & 0 deletions neural_compressor/torch/utils/environ.py
@@ -61,3 +61,11 @@ def get_torch_version():
assert False, "Got an unknown version of torch: {}".format(e)
version = Version(torch_version)
return version


def get_device(device_name="auto"):
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

runtime_accelerator = auto_detect_accelerator(device_name)
device = runtime_accelerator.name()
return device
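
A hedged usage sketch of the new helper; the returned strings assume the CPU and CUDA accelerators report "cpu" and "cuda" from name(), and an unsupported name falls back to auto-detection as implemented in auto_detect_accelerator above:

from neural_compressor.torch.utils import get_device

print(get_device())  # "auto": e.g. "cuda" on a GPU machine, otherwise "cpu"
print(get_device("cpu"))  # explicit, registered name
print(get_device("xyz"))  # unsupported name: warning, then the auto-detected device
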
@@ -4,10 +4,10 @@
import torch
from transformers import AutoModelForCausalLM

from neural_compressor.torch.algorithms.weight_only.hqq.auto_accelerator import auto_detect_accelerator
from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option
from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
from neural_compressor.torch.algorithms.weight_only.hqq.utility import see_cuda_memory_usage
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator


def _common_cuda_test(nbits=4, group_size=64, quant_zero=True, quant_scale=False, scale_quant_group_size=128):
@@ -3,10 +3,7 @@
import pytest
import torch

from neural_compressor.torch.algorithms.weight_only.hqq.auto_accelerator import (
accelerator_registry,
auto_detect_accelerator,
)
from neural_compressor.torch.utils.auto_accelerator import accelerator_registry, auto_detect_accelerator


class Test_CPU_Accelerator:
241 changes: 241 additions & 0 deletions test/3x/torch/quantization/weight_only/test_woq_on_cuda.py
@@ -0,0 +1,241 @@
import random

import pytest
import torch
import transformers
from tqdm import tqdm

from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device
from neural_compressor.torch.quantization import GPTQConfig, get_default_rtn_config, quantize


class GPTQDataloaderPreprocessor:
def __init__(self, dataloader_original, use_max_length=False, max_seq_length=2048, nsamples=128):
self.dataloader_original = dataloader_original
self.use_max_length = use_max_length
self.max_seq_length = max_seq_length
self.nsamples = nsamples
self.dataloader = []
self.is_ready = False

def get_prepared_dataloader(self):
if not self.is_ready:
self.prepare_dataloader()
return self.dataloader

def prepare_dataloader(self):
if self.use_max_length:
# (Recommended) only take sequences whose length exceeds self.max_seq_length,
# which ensures that all calibration tokens are valid.
# This is the official GPTQ dataloader implementation.
self.obtain_first_n_samples_fulllength()
else:
# general selection, no padding; not the original GPTQ implementation.
self.obtain_first_n_samples()
self.is_ready = True

def obtain_first_n_samples(self, seed=0):
"""Get first nsample data as the real calibration dataset."""
self.dataloader.clear()
random.seed(seed)
for batch in self.dataloader_original:
# process data; handling depends on its data type.
if len(self.dataloader) == self.nsamples:
logger.info(f"Successfully collected {self.nsamples} calibration samples.")
break
# list, tuple
if isinstance(batch, list) or isinstance(batch, tuple):
if batch[0].shape[-1] > self.max_seq_length:
i = random.randint(0, batch[0].shape[-1] - self.max_seq_length - 1)
j = i + self.max_seq_length
batch_final = []
for item in batch:
if isinstance(item, torch.Tensor) and item.shape.__len__() == 2:
batch_final.append(item[:, i:j])
else:
batch_final.append(item)
else:
batch_final = batch[:]
# dict
elif isinstance(batch, dict):
try:
length = batch["input_ids"].shape[-1]
except:
logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
continue
batch_final = {}
if length > self.max_seq_length:
i = random.randint(0, length - self.max_seq_length - 1)
j = i + self.max_seq_length
# may have to slice every sequence related data
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
batch_final[key] = batch[key][:, i:j] # slice on sequence length dim
else:
batch_final[key] = batch[key]
else:
batch_final = batch
# tensor
else:
if batch.shape[-1] > self.max_seq_length:
i = random.randint(0, batch.shape[-1] - self.max_seq_length - 1)
j = i + self.max_seq_length
batch_final = batch[:, i:j]
else:
batch_final = batch
self.dataloader.append(batch_final)

if len(self.dataloader) < self.nsamples:
logger.warning(f"Tried to use {self.nsamples} samples, but the entire dataset size is only {len(self.dataloader)}.")

def obtain_first_n_samples_fulllength(self, seed=0):
self.dataloader.clear()
random.seed(seed)
unified_length = self.max_seq_length
for batch in self.dataloader_original:
if len(self.dataloader) == self.nsamples:
logger.info(f"Successfully collected {self.nsamples} calibration samples.")
break
# list & tuple, gpt-j-6b mlperf, etc.
if isinstance(batch, list) or isinstance(batch, tuple):
if batch[0].shape[-1] == unified_length:
batch_final = batch[:]
elif batch[0].shape[-1] > unified_length:
i = random.randint(0, batch[0].shape[-1] - unified_length - 1)
j = i + unified_length
batch_final = []
for item in batch:
if isinstance(item, torch.Tensor) and item.shape.__len__() == 2:
batch_final.append(item[:, i:j])
else:
batch_final.append(item)
else:
# does not match max length; not included in target dataset
continue
# dict
elif isinstance(batch, dict):
try:
length = batch["input_ids"].shape[-1]
except:
logger.warning("Please make sure your dict-like data contains the key 'input_ids'.")
continue
batch_final = {}
if length == self.max_seq_length:
batch_final = batch
elif length > self.max_seq_length:
i = random.randint(0, length - self.max_seq_length - 1)
j = i + self.max_seq_length
# may have to slice every sequence related data
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
batch_final[key] = batch[key][:, i:j] # slice on sequence length dim with same position
else:
batch_final[key] = batch[key]
else:
# does not match max length; not included in target dataset
continue
# tensor
else:
if batch.shape[-1] == unified_length:
batch_final = batch
elif batch.shape[-1] > unified_length:
i = random.randint(0, batch.shape[-1] - unified_length - 1)
j = i + unified_length
batch_final = batch[:, i:j]
else:
# does not match max length; not included in target dataset
continue
self.dataloader.append(batch_final)
if len(self.dataloader) < self.nsamples: # pragma: no cover
logger.warning(
f"Trying to allocate {self.nsamples} samples with fixed length {unified_length}, \
but only {len(self.dataloader)} samples were found. Please use a smaller 'self.max_seq_length' value."
)


class TestGPTQ:
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
def test_GPTQ_fixed_length_quant(self):
class GPTQLLMDataLoader:
def __init__(self):
self.batch_size = 1

def __iter__(self):
for i in range(10):
yield torch.ones([1, 512], dtype=torch.long)

class GPTQLLMDataLoaderList:
def __init__(self):
self.batch_size = 1

def __iter__(self):
for i in range(10):
yield (torch.ones([1, 512], dtype=torch.long), torch.ones([1, 512], dtype=torch.long))

class GPTQLLMDataLoaderDict:
def __init__(self):
self.batch_size = 1

def __iter__(self):
for i in range(10):
yield {
"input_ids": torch.ones([1, 512], dtype=torch.long),
"attention_mask": torch.ones([1, 512], dtype=torch.long),
}

dataloader_list = GPTQLLMDataLoaderList()
dataloader_dict = GPTQLLMDataLoaderDict()

quant_config = GPTQConfig()
quant_config.set_local("lm_head", GPTQConfig(dtype="fp32"))

gptq_use_max_length = False
gptq_max_seq_length = 2048
dataloaderPreprocessor = GPTQDataloaderPreprocessor(
dataloader_original=dataloader_list,
use_max_length=gptq_use_max_length,
max_seq_length=gptq_max_seq_length,
)
dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader()

def run_fn_for_gptq(model, dataloader_for_calibration, *args):
for batch in tqdm(dataloader_for_calibration):
batch = move_input_to_device(batch, device=model.device)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
model(batch[0])
elif isinstance(batch, dict):
model(**batch)
else:
model(batch)
except ValueError:
pass
return

user_model = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
)

user_model = quantize(
model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration
)
model_device = str(user_model.device)
assert "cuda" in model_device, f"Model device is {model_device}"


class TestRTNQuant:

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
def test_rtn(self):
self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
)
self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
model = self.tiny_gptj
# record label for comparison
self.label = model(self.example_inputs.to(model.device))[0]
# test_default_config
quant_config = get_default_rtn_config()
q_model = quantize(model, quant_config)
assert "cuda" in str(q_model.device), f"Expected q_model device to be cuda, got {q_model.device}"
