Support auto device for TEQ and AWQ (#1634)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
yiliu30 authored Feb 27, 2024
1 parent 7e1fa90 commit 5343009
Showing 4 changed files with 177 additions and 14 deletions.
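
For orientation, here is a hedged sketch of the user-facing effect of this change, mirroring the new tests further down: `quantize` now runs AWQ (and TEQ) with the model and calibration data placed on the auto-detected accelerator, and the `FORCE_DEVICE` environment variable can override the detection. The model name and calibration loop below are illustrative only.

```python
import torch
import transformers

from neural_compressor.torch.quantization import get_default_awq_config, quantize

# Optional override; auto detection is used when unset (see auto_accelerator.py below).
# os.environ["FORCE_DEVICE"] = "cpu"

model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM"
)
example_inputs = torch.ones([1, 10], dtype=torch.long)


def calib_func(model):
    # Calibration runs on whichever device the model was moved to.
    for _ in range(2):
        model(example_inputs.to(model.device))


q_model = quantize(
    model=model,
    quant_config=get_default_awq_config(),
    example_inputs=example_inputs,
    run_fn=calib_func,
)
print(q_model.device)  # a CUDA device when a GPU is detected, otherwise "cpu"
```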
15 changes: 12 additions & 3 deletions neural_compressor/torch/algorithms/weight_only/awq.py
@@ -15,11 +15,10 @@
# Copied from neural_compressor/adaptor/torch_utils/awq.py

import copy
from functools import partial

import torch

from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils import get_device, logger

from .modules import MulLinear
from .utility import (
@@ -33,6 +32,8 @@
set_module,
)

__all__ = ["awq_quantize"]


def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}):
"""Get absorbed layer per block.
@@ -122,10 +123,13 @@ def __init__(
use_full_range=False,
weight_config={},
):

self.example_inputs = example_inputs
self.model = model
if example_inputs is None:
assert dataloader is not None, "dataloader or example_inputs is required."
self.example_inputs = get_example_input(dataloader)
self._move_model_and_data_to_device()
# Step 1: get hidden states and kwargs of first block.
self.total_block_args, self.total_block_kwargs = get_hidden_states(
model, dataloader=dataloader, n_samples=n_samples, calib_func=calib_func
@@ -139,7 +143,12 @@ def __init__(
self.scheme = scheme
self.use_full_range = use_full_range
self.weight_config = weight_config
self.model = model

def _move_model_and_data_to_device(self):
# Move the model and example_inputs onto the target device
device = get_device()
self.model.to(device)
self.example_inputs = self.example_inputs.to(device)

def quantize(self, use_auto_scale=True, use_mse_search=True, folding=False, return_int=False):
"""Execute AWQ quantization.
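
The new `_move_model_and_data_to_device` helper above boils down to the pattern sketched here. A toy model stands in for the real one; `get_device` is the utility imported at the top of the file, and its return value is assumed to be a torch-compatible device identifier.

```python
import torch

from neural_compressor.torch.utils import get_device

# Toy stand-ins for the quantizer's model and example_inputs.
model = torch.nn.Linear(8, 8)
example_inputs = torch.randn(2, 8)

device = get_device()  # e.g. "cuda" when a GPU is detected, else "cpu" (assumption)
model.to(device)  # nn.Module.to moves parameters in place
example_inputs = example_inputs.to(device)  # Tensor.to returns a new tensor, so reassign

assert next(model.parameters()).device == example_inputs.device
```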
14 changes: 11 additions & 3 deletions neural_compressor/torch/algorithms/weight_only/teq.py
@@ -19,11 +19,13 @@
import torch
import transformers

from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils import get_device, logger

from .modules import MulLinear, TEQLinearFakeQuant
from .utility import get_module, quant_tensor, set_module

__all__ = ["teq_quantize", "TEQuantizer"]


class TEQuantizer:
"""Weight-only quantization, Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input."""
@@ -38,16 +40,22 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
self.weight_config = weight_config
self.folding = folding
self.example_inputs = example_inputs
self.device, self.dtype = self._get_device()
self.device = self._get_device()
self.dtype = self._get_dtype()
self.model.eval()
self.trained_alphas = {}
self.absorb_to_layer = absorb_to_layer

def _get_device(self):
"""Get the model device
:return:Model device."""
device = get_device()
self.model.to(device)
return device

def _get_dtype(self):
for _, p in self.model.named_parameters():
return p.data.device, p.data.dtype
return p.data.dtype

def add_tuning_scale(self, sqrt_w_init=False):
"""The main entry of smooth quant
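
Splitting `_get_device` and `_get_dtype` above means the device now comes from accelerator detection while the working dtype is still read from the first registered parameter. A minimal sketch of that dtype probe follows; the toy model and the fallback dtype are assumptions, not part of the commit.

```python
import torch


def first_param_dtype(model: torch.nn.Module) -> torch.dtype:
    # Mirrors TEQuantizer._get_dtype: the first parameter's dtype stands in
    # for the model's working dtype.
    for _, p in model.named_parameters():
        return p.data.dtype
    return torch.float32  # fallback for parameter-less modules (assumption)


print(first_param_dtype(torch.nn.Linear(4, 4).to(torch.bfloat16)))  # torch.bfloat16
```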
11 changes: 10 additions & 1 deletion neural_compressor/torch/utils/auto_accelerator.py
@@ -19,7 +19,6 @@

# NOTICE: The design adapted from:
# https://github.com/microsoft/DeepSpeed/blob/master/accelerator/abstract_accelerator.py.
# TODO: move it into torch/utils


# To keep it simple, only add the APIs we need.
@@ -204,19 +203,27 @@ def empty_cache(self):


def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
# Force the CPU on a node that has both a CPU and a GPU: `FORCE_DEVICE=cpu python main.py ...`.
# `FORCE_DEVICE` is case-insensitive.
# The environment variable `FORCE_DEVICE` has higher priority than the `device_name` argument.
# TODO: refine the docs and logic later
# 1. Get the device setting from environment variable `FORCE_DEVICE`.
FORCE_DEVICE = os.environ.get("FORCE_DEVICE", None)
if FORCE_DEVICE:
FORCE_DEVICE = FORCE_DEVICE.lower()
# 2. If the `FORCE_DEVICE` is set and the accelerator is available, use it.
if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None:
logger.warning("Force use %s accelerator.", FORCE_DEVICE)
return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)()
# 3. If the `device_name` is set and the accelerator is available, use it.
if device_name != "auto":
if accelerator_registry.get_accelerator_cls_by_name(device_name) is not None:
accelerator_cls = accelerator_registry.get_accelerator_cls_by_name(device_name)
logger.warning("Selected accelerator %s by device_name.", accelerator_cls.__name__)
return accelerator_cls()
else:
logger.warning("The device name %s is not supported, use auto detect instead.", device_name)
# 4. Select the accelerator by priority.
for accelerator_cls in accelerator_registry.get_sorted_accelerators():
if accelerator_cls.is_available():
logger.warning("Auto detect accelerator: %s.", accelerator_cls.__name__)
@@ -227,4 +234,6 @@ def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
# Force the cpu accelerator even if cuda is available:
# FORCE_DEVICE=cpu python ...
# or
# FORCE_DEVICE=CPU python ...
# or
# CUDA_VISIBLE_DEVICES="" python ...
151 changes: 144 additions & 7 deletions test/3x/torch/quantization/weight_only/test_woq_on_cuda.py
@@ -7,7 +7,22 @@

from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device
from neural_compressor.torch.quantization import GPTQConfig, get_default_rtn_config, quantize
from neural_compressor.torch.quantization import (
AWQConfig,
GPTQConfig,
get_default_awq_config,
get_default_rtn_config,
get_default_teq_config,
quantize,
)


def get_gpt_j():
tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
torchscript=True,
)
return tiny_gptj


class GPTQDataloaderPreprocessor:
@@ -213,9 +228,7 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
pass
return

user_model = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
)
user_model = get_gpt_j()

user_model = quantize(
model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration
@@ -228,9 +241,7 @@ class TestRTNQuant:

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
def test_rtn(self):
self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
)
self.tiny_gptj = get_gpt_j()
self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
model = self.tiny_gptj
# record label for comparison
@@ -239,3 +250,129 @@ def test_rtn(self):
quant_config = get_default_rtn_config()
q_model = quantize(model, quant_config)
assert "cuda" in str(q_model.device), f"Expect qmodel device is cuda, got {q_model.device}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
class TestAWQOnCuda:

def test_awq(self):
self.lm_input = torch.ones([1, 10], dtype=torch.long)
self.gptj = get_gpt_j()
example_inputs = torch.ones([1, 10], dtype=torch.long)

def calib_func(model):
for i in range(2):
model(self.lm_input.to(model.device))

quant_config = get_default_awq_config()
logger.info("Test quantization with config", quant_config)
q_model = quantize(
model=self.gptj, quant_config=quant_config, example_inputs=self.lm_input, run_fn=calib_func, inplace=False
)
out2 = q_model(example_inputs.to(q_model.device))
assert "cuda" in str(q_model.device), f"Expect qmodel device is cuda, got {q_model.device}"
assert "cuda" in str(out2[0].device), f"Expect out2 device is cuda, got {out2.device}"


def generate_random_corpus(nsamples=32):
meta_data = []
for _ in range(nsamples):
inp = torch.ones([1, 512], dtype=torch.long)
tar = torch.ones([1, 512], dtype=torch.long)
meta_data.append((inp, tar))
return meta_data


def train(
model,
train_steps=1000,
lr=1e-3,
warmup_ratio=0.05,
gradient_accumulation_steps=1,
logging_steps=10,
betas=[0.9, 0.9],
weight_decay=0,
lr_scheduler_type="linear",
):
"""Train function."""
trained_alphas_list = [torch.ones([128], requires_grad=True)]
optimizer = torch.optim.Adam(trained_alphas_list, lr=lr, weight_decay=weight_decay, betas=betas)

lr_scheduler = transformers.get_scheduler( # pylint: disable=E1111
name=lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=int(train_steps * warmup_ratio) // gradient_accumulation_steps,
num_training_steps=train_steps // gradient_accumulation_steps,
)

logger.info("start training")
model.train()
global_steps = 0
dataloader = generate_random_corpus()
while global_steps <= train_steps:
for inputs in dataloader:
if isinstance(inputs, torch.Tensor):
input_id = inputs
elif isinstance(inputs, dict):
input_id = inputs["input_ids"]
else:
input_id = inputs[0]
output = model(input_id.to(model.device), labels=input_id.to(model.device))
loss = output[0] / gradient_accumulation_steps
loss.backward()
global_steps += 1

if global_steps % logging_steps == 0:
logger.info("steps: {}, loss: {}".format(global_steps, loss.detach().cpu().item()))

if global_steps % gradient_accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()

if global_steps >= train_steps: # pragma: no cover
break

logger.info("finish training")
model.eval()
return None


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
class TestTEQOnCuda:

def test_teq(self):
quant_config = {
"teq": {
"global": {
"dtype": "fp32",
},
"local": {
"transformer.h.0.mlp.fc_in": {
"dtype": "int",
"bits": 8,
"group_size": -1,
"use_sym": True,
"folding": True,
"absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
},
"transformer.h.0.mlp.fc_out": {
"dtype": "int",
"bits": 4,
"group_size": 32,
"use_sym": False,
"folding": True,
"absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
},
},
}
}
example_inputs = torch.ones([1, 512], dtype=torch.long)
test_input = torch.ones([1, 512], dtype=torch.long)
model = get_gpt_j()

qdq_model = quantize(model=model, quant_config=quant_config, run_fn=train, example_inputs=example_inputs)
assert isinstance(qdq_model, torch.nn.Module), "Expect qdq_model is a torch module"
out2 = qdq_model(test_input.to(qdq_model.device))
assert "cuda" in str(qdq_model.device), f"Expect qmodel device is cuda, got {qdq_model.device}"
assert "cuda" in str(out2[0].device), f"Expect out2 device is cuda, got {out2.device}"
