[Update] Deploy to TVM && TQT fake quantize.
zhangqi3 committed Nov 4, 2021
1 parent 371cc4d commit 3567c06
Showing 15 changed files with 1,817 additions and 188 deletions.
215 changes: 69 additions & 146 deletions mqbench/adaround.py
@@ -6,11 +6,16 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import GraphModule, Node
from torch.quantization.observer import ObserverBase


from mqbench.observer import MinMaxObserver, ObserverBase
from mqbench.utils import deepcopy_graphmodule
from mqbench.utils.state import enable_quantization, disable_all
from mqbench.utils.logger import logger


_ADAROUND_SUPPORT_TYPE = (nn.Conv2d, nn.Linear, )
__all__ = ['adaround']
_ADAROUND_SUPPORT_TYPE = (nn.Conv2d, nn.Linear)


def lp_norm(prediction, target, p=2.0):
@@ -26,6 +31,7 @@ def lp_norm(prediction, target, p=2.0):
"""
return (prediction - target).abs().pow(p).sum(1).mean()
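
# A quick sanity check of lp_norm above; the tensors below are illustrative.
# Each row of |pred - target|^2 sums to 8, so the batch mean is 8.0.
import torch
_pred, _target = torch.zeros(4, 8), torch.ones(4, 8)
assert lp_norm(_pred, _target, p=2.0).item() == 8.0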


def _rectified_sigmoid(x, zeta, gamma):
"""Function to generate rounding mask.
@@ -39,60 +45,28 @@ def _rectified_sigmoid(x, zeta, gamma):
"""
return ((zeta - gamma) * torch.sigmoid(x) + gamma).clamp(0, 1)
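
# Standalone check of the rounding mask above with the defaults zeta=1.1 and
# gamma=-0.1 that adaround() uses below: large negative alpha saturates at 0
# (round down), large positive at 1 (round up), values near 0 stay fractional.
import torch
_alpha = torch.tensor([-6.0, -1.0, 0.0, 1.0, 6.0])
_mask = _rectified_sigmoid(_alpha, zeta=1.1, gamma=-0.1)  # ≈ [0.00, 0.22, 0.50, 0.78, 1.00]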

def get_cali_samples(train_data_loader, num_samples, no_label=True):
"""Generate sub-dataset for calibration.
Args:
train_data_loader (torch.utils.data.DataLoader):
num_samples (int):
no_label (bool, optional): If the dataloader has no labels. Defaults to True.

Returns:
torch.Tensor: Concatenated data matrix.
"""
cali_data_list = []
if no_label:
for batch_data in train_data_loader:
cali_data_list.append(batch_data["image"])
if len(cali_data_list) >= num_samples:
break
else:
for batch_data, _ in train_data_loader:
cali_data_list.append(batch_data)
if len(cali_data_list) >= num_samples:
break
return torch.cat(cali_data_list, dim=0)[:num_samples].cpu()
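
# With get_cali_samples removed, callers are expected to hand adaround() a
# pre-stacked calibration tensor. A minimal sketch assuming an (image, label)
# DataLoader; the loader and the 128-sample budget are illustrative only.
import torch

def stack_cali_data(train_loader, num_samples=128):
    """Collect about num_samples images into one stacked tensor."""
    batches, collected = [], 0
    for images, _ in train_loader:  # assumes (image, label) batches
        batches.append(images)
        collected += images.size(0)
        if collected >= num_samples:
            break
    return torch.cat(batches, dim=0)[:num_samples]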

def adaround(model: GraphModule, train_data, n_samples: int = 128,
lr: float = 4e-3, batch_size: int = 128, max_iter: int = 8000,
weight: float = 0.01, beta: float = 20, gamma: float = -0.1, zeta: float = 1.1,
quant_min: int = -128, quant_max: int = 127, per_channel: bool = False):
def adaround(model: GraphModule, cali_data,
lr: float = 0.001, batch_size: int = 128, max_iter: int = 8000,
weight: float = 0.01, beta: float = 20, gamma: float = -0.1, zeta: float = 1.1):
"""Main function to run AdaRound on a given model.
Args:
model (GraphModule):
train_data (torch.utils.data.DataLoader):
n_samples (int, optional): Defaults to 128.
lr (float, optional): Defaults to 4e-3.
model (GraphModule): Model to apply AdaRound to.
cali_data (torch.Tensor): Stacked calibration tensor.
lr (float, optional): Defaults to 0.001.
batch_size (int, optional): Defaults to 128.
max_iter (int, optional): Defaults to 8000.
weight (float, optional): Defaults to 0.01.
beta (float, optional): Defaults to 20.
gamma (float, optional): Defaults to -0.1.
zeta (float, optional): Defaults to 1.1.
quant_min (int, optional): Defaults to -128.
quant_max (int, optional): Defaults to 127.
per_channel (bool, optional): Defaults to False.
Returns:
GraphModule: Modified copy of the given model.
"""
model.cpu()
print("AdaRound: Quant-Range="
"[{}, {}], Per-Channel={}".format(quant_min, quant_max, per_channel))

# sample data from training data
cali_data = get_cali_samples(train_data, n_samples)
device = cali_data.device
model.to(device)

# apply rewritten deepcopy of GraphModule
quant_model = deepcopy_graphmodule(model)
@@ -103,50 +77,33 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
fp_observer_binding_dict = _insert_observer(model, "output")
quant_observer_binding_dict = _insert_observer(quant_model, "input")

print("Record Outputs (by CPU) ...")
logger.info("Record Outputs ...")
# apply data to record output
disable_all(model)
enable_quantization(quant_model)

saver = FpOutputSaver(model, observer_binding_dict=fp_observer_binding_dict,
input_data=cali_data)

# get layers for reconstruction
modules = dict(quant_model.named_modules())
quant_module_name_list = _get_quant_modules_by_topology(quant_model)

# TODO: more observer types / affine mode
if per_channel:
qscheme = torch.per_channel_symmetric
ch_axis = 0
else:
qscheme = torch.per_tensor_symmetric
ch_axis = -1

observer_type = MinMaxObserver.with_args(dtype=torch.qint8, quant_min=quant_min, quant_max=quant_max,
reduce_range=False, qscheme=qscheme, ch_axis=ch_axis)

scale_dict = _init_weight_scale(quant_model, quant_observer_binding_dict.keys(), observer_type)

# disable gradient for all parameters
for n, m in quant_model.named_modules():
if hasattr(m, "weight"):
m.weight.requires_grad = False
if hasattr(m, "bias") and getattr(m, "bias") is not None:
m.bias.requires_grad = False

quant_model.cuda()
cali_data = cali_data.cuda()
for p in quant_model.parameters():
p.requires_grad = False

# learn the rounding mask for each layer
for node_name in quant_module_name_list:
print("===> Train for Layer: {}".format(node_name))
logger.info("Adaround for Layer: {}".format(node_name))
# get input and output tensors
output_tensor = saver.get_result_by_name(node_name).cuda()
output_tensor = saver.get_result_by_name(node_name).to(device)
input_observer = modules[quant_observer_binding_dict[node_name].name]
cur_node = _get_node_by_name(quant_model, node_name)
if cur_node is not None:
module = modules[cur_node.target]
else:
raise RuntimeError("Node not found in graph.")
module.eval()

with _Recorder(input_observer):
with torch.no_grad():
@@ -158,12 +115,14 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
ada_reg_loss = AdaRoundReg(zeta=zeta, gamma=gamma, weight=weight,
temp_anneal=temp_anneal, h_func=_rectified_sigmoid)

scale, zero_point = scale_dict[node_name]
ada_quantizer = AdaRoundQuantizer(reg=ada_reg_loss, ch_axis=ch_axis,
scale=scale, zero_point=zero_point,
quant_min=quant_min, quant_max=quant_max)
weight_fake_quant = module.weight_fake_quant
ch_axis = weight_fake_quant.activation_post_process.ch_axis
scale, zero_point = weight_fake_quant.activation_post_process.calculate_qparams()
quant_min, quant_max = weight_fake_quant.activation_post_process._calculate_qmin_qmax()
ada_quantizer = AdaRoundQuantizer(reg=ada_reg_loss, scale=scale, zero_point=zero_point,
quant_min=quant_min, quant_max=quant_max, ch_axis=ch_axis)

ada_layer = AdaRoundLayer(module, ada_reg_loss, ada_quantizer).cuda()
ada_layer = AdaRoundLayer(module, ada_reg_loss, ada_quantizer).to(device)

alpha = learning_alpha(input_tensor, output_tensor,
ada_layer, ada_reg_loss, lr,
@@ -173,9 +132,26 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
module.weight.data = ada_quantizer(module.weight, alpha)
module.weight.requires_grad = False

_del_tensor_observer(quant_model, quant_observer_binding_dict)

return quant_model
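
# Hypothetical driver for the new signature. The model must already be a
# prepared GraphModule whose Conv2d/Linear modules carry a weight_fake_quant
# attribute (the loop above reads scale, zero-point and quant range from it);
# the input shape and hyper-parameters below are illustrative.
import torch
cali_data = torch.randn(128, 3, 224, 224)  # stacked calibration tensor
adarounded = adaround(prepared_model, cali_data, batch_size=32, max_iter=2000)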


def _del_tensor_observer(gm: GraphModule, observer_binding_dict):
modules = dict(gm.named_modules())
nodes = list(gm.graph.nodes)
# Tensor observers were inserted into the quant model in 'input' mode.
for node in observer_binding_dict.values():
delattr(gm, node.name)
for _node in list(node.users.keys()):
_node.args = node.args
for node in observer_binding_dict.values():
gm.graph.erase_node(node)

gm.recompile()
gm.graph.lint()
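
# For reference, the same remove-and-rewire idea on a toy torch.fx graph; this
# is a generic sketch, not the helper above.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class _Toy(nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

_gm = symbolic_trace(_Toy())
for _node in list(_gm.graph.nodes):
    if _node.op == "call_function" and _node.target is torch.relu:
        _node.replace_all_uses_with(_node.args[0])  # users now consume the add
        _gm.graph.erase_node(_node)
_gm.graph.lint()
_gm.recompile()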


def _insert_observer(gm: GraphModule, insert_type="input"):
"""Insert observers to record the input and output of target layers.
@@ -260,7 +236,7 @@ class FpOutputSaver:
@torch.no_grad()
def __init__(self, fp_gm: GraphModule,
observer_binding_dict: Dict[str, Node],
save_loc="disk", root="./calibration",
save_loc="disk", root="./cali_data_cache",
input_data=None):
"""
Currently, there are two options provided to save floating point model
@@ -283,8 +259,8 @@ def __init__(self, fp_gm: GraphModule,
self._data = dict()

if self.save_loc == "disk" and not os.path.exists(self.data_root):
raise NotADirectoryError("The given path is not a folder."
"Ensure you give the correct path.")
logger.info('Save data on disk, create directory {}'.format(self.data_root))
os.mkdir(self.data_root)
saving_operation = self._disk_saving_operation \
if self.save_loc == "disk" else self._gpu_saving_operation

@@ -352,30 +328,6 @@ def _get_quant_modules_by_topology(gm: GraphModule):
module_name_list.append(node.name)
return module_name_list

def _init_weight_scale(gm: GraphModule, observed_module_list, observer_type: Callable):
"""Simulate the fake quant modules to calculate scales and zero-points.
Args:
gm (GraphModule):
observed_module_list (list):
observer_type (Callable):
Returns:
dict:
"""
scale_dict = dict()
modules = dict(gm.named_modules())

for name in observed_module_list:
node = _get_node_by_name(gm, name)
if node.op == "call_module":
observer = observer_type()
module = modules[node.target]
weight = module.weight
observer(weight)
scale, zero_point = observer.calculate_qparams()
scale_dict[name] = (scale.cuda().detach(), zero_point.cuda().detach())
return scale_dict

def _get_node_by_name(gm: GraphModule, node_name: str):
"""
@@ -446,8 +398,8 @@ def __call__(self, t):


class AdaRoundQuantizer:
def __init__(self, reg: AdaRoundReg, ch_axis: int,
scale, zero_point, quant_min=-128, quant_max=127,
def __init__(self, reg: AdaRoundReg, scale, zero_point,
quant_min=-128, quant_max=127, ch_axis=-1,
soft=True):
self.quant_min = quant_min
self.quant_max = quant_max
@@ -465,11 +417,6 @@ def __init__(self, reg: AdaRoundReg, ch_axis: int,
def __call__(self, w, alpha):
scale = self.scale
zero_point = self.zero_point
if self.ch_axis != -1:
new_shape = [1] * len(w.shape)
new_shape[self.ch_axis] = w.shape[self.ch_axis]
scale = self.scale.reshape(new_shape)
zero_point = self.zero_point.reshape(new_shape)

if self.soft_quantize:
w = (w / scale).floor() + self.h_func(alpha, self.zeta, self.gamma)
@@ -483,15 +430,6 @@ def __call__(self, w, alpha):
w = w * scale
return w
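
# Following the AdaRound paper, the full soft/hard fake-quant step typically
# looks like the per-tensor sketch below (an assumption, not necessarily the
# exact code in this file).
import torch

def adaround_fake_quant(w, scale, zero_point, alpha, quant_min, quant_max,
                        zeta=1.1, gamma=-0.1, soft=True):
    if soft:
        # continuous rounding mask in [0, 1] learned through alpha
        w_q = (w / scale).floor() + ((zeta - gamma) * torch.sigmoid(alpha) + gamma).clamp(0, 1)
    else:
        # hard mode: commit each weight to floor or ceil from the sign of alpha
        w_q = (w / scale).floor() + (alpha >= 0).float()
    w_q = (w_q + zero_point).clamp(quant_min, quant_max) - zero_point
    return w_q * scale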

def __repr__(self):
scale = self.scale.item()
if self.ch_axis != -1:
scale = "per-channel scale of " + str(tuple(self.scale.shape))
repr_str = "AdaRoundQuantizer(quant_min={}, quant_max={}, scale={}, " \
"gamma={}, zeta={}, soft_quantize={})".format(self.quant_min, self.quant_max, scale,
self.gamma, self.zeta, self.soft_quantize)
return repr_str


class AdaRoundLayer(nn.Module):
def __init__(self, module: nn.Module,
@@ -506,16 +444,17 @@ def __init__(self, module: nn.Module,
if self.module.bias is not None:
self.module.bias.requires_grad = False

scale = self.quantizer.scale
if self.quantizer.ch_axis != -1:
new_shape = [1] * len(self.module.weight.shape)
new_shape[self.quantizer.ch_axis] = self.module.weight.shape[self.quantizer.ch_axis]
scale = self.quantizer.scale.reshape(new_shape)
self.quantizer.scale = self.quantizer.scale.reshape(new_shape)
self.quantizer.zero_point = self.quantizer.zero_point.reshape(new_shape)

# Init rest.
scale = self.quantizer.scale
rest = self.module.weight / scale - (self.module.weight / scale).floor()
rest = -torch.log((reg.zeta - reg.gamma) / (rest - reg.gamma) - 1)

self.alpha = torch.nn.Parameter(rest.cuda(), True)
self.alpha = torch.nn.Parameter(rest, True)

def forward(self, x):
weight = self.quantizer(self.module.weight, self.alpha)
@@ -529,6 +468,10 @@ def forward(self, x):
else:
raise RuntimeError("Unsupported module type.")

if isinstance(self.module, (torch.nn.intrinsic.qat.ConvReLU2d,
torch.nn.intrinsic.qat.LinearReLU)):
x = F.relu(x)

return x
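
# The reshape in __init__ above exists so a per-channel scale broadcasts against
# a 4-D convolution weight; a standalone illustration with made-up shapes.
import torch
_w = torch.randn(8, 3, 3, 3)      # Conv2d weight: [C_out, C_in, kH, kW]
_s = torch.rand(8) + 0.01         # per-channel scale along ch_axis = 0
_new_shape = [1] * _w.dim()
_new_shape[0] = _w.shape[0]       # -> [8, 1, 1, 1]
_w_int = (_w / _s.reshape(_new_shape)).floor()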


@@ -541,7 +484,7 @@ def learning_alpha(in_tensor: torch.Tensor,
batch_size: int,
max_iter: int) -> torch.Tensor:

optimizer = torch.optim.Adam([ada_layer.alpha], lr=learning_rate)
optimizer = torch.optim.Adam([ada_layer.alpha])

for epoch in range(max_iter):
for idx in range(np.ceil(len(in_tensor) / batch_size).astype(int)):
@@ -560,33 +503,13 @@ def learning_alpha(in_tensor: torch.Tensor,
loss.backward()
optimizer.step()

if epoch % 200 == 0:
print("Epoch: {:<4} L2 Loss: {:>10.3f} Loss P: "
"{:>8.6f} Loss Reg: {:>5.3f} Beta: {:>3.3f}".format(epoch, loss, loss_p,
loss_reg, ada_reg.beta))
if epoch % 100 == 0:
logger.info("Epoch: {:<4} L2 Loss: {:>10.3f} Loss P: "
"{:>8.6f} Loss Reg: {:>5.3f} Beta: {:>3.3f}".format(epoch, loss, loss_p,
loss_reg, ada_reg.beta))
res = ada_reg.round_mask(ada_layer.alpha)
print("Loss: {:>5.3f} Ceil: {:>5} Floor: {:>5} Total: {:>5} Ratio: {:>.3f}".format(
logger.info("Loss: {:>5.3f} Ceil: {:>5} Floor: {:>5} Total: {:>5} Ratio: {:>.3f}".format(
loss,
res[res + 1e-4 >= 1.0].numel(), res[res <= 1e-4].numel(), torch.numel(res),
(res[res + 1e-4 >= 1.0].numel() + res[res <= 1e-4].numel()) / torch.numel(res)))
return ada_layer.alpha

@torch.no_grad()
def round_to_nearset_quant(m: nn.Module, scale, zero_point, quant_min, quant_max, ch_axis):
w = m.weight
if ch_axis != -1:
new_shape = [1] * len(w.shape)
new_shape[ch_axis] = w.shape[ch_axis]
scale = scale.reshape(new_shape)
zero_point = zero_point.reshape(new_shape)

w = (w / scale).round()
w += zero_point
w = w.clamp(quant_min, quant_max)
w -= zero_point
w = w * scale

return w

if __name__ == "__main__":
pass
return ada_layer.alpha
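
# For context, the objective optimized in learning_alpha combines the layer
# reconstruction loss with the rounding regularizer from the AdaRound paper; a
# minimal sketch assuming AdaRoundReg follows that formulation (the repo's
# exact implementation may differ).
import torch

def adaround_objective(pred, target, alpha, beta, weight=0.01,
                       zeta=1.1, gamma=-0.1, p=2.0):
    recon = (pred - target).abs().pow(p).sum(1).mean()        # lp_norm above
    h = ((zeta - gamma) * torch.sigmoid(alpha) + gamma).clamp(0, 1)
    reg = weight * (1 - (2 * h - 1).abs().pow(beta)).sum()    # pushes h to {0, 1}
    return recon + reg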