[fbsync] Fix quantized references (#8073)
Reviewed By: vmoens

Differential Revision: D50789093

fbshipit-source-id: 17b6840f89063eeef1fb429ff04817f46e5919eb
NicolasHug authored and facebook-github-bot committed Nov 13, 2023
1 parent 235f632 commit 8b381f3
Showing 2 changed files with 21 additions and 13 deletions.
14 changes: 7 additions & 7 deletions references/classification/README.md
@@ -289,10 +289,10 @@ For all post training quantized models, the settings are:
2. num_workers: 16
3. batch_size: 32
4. eval_batch_size: 128
-5. backend: 'fbgemm'
+5. qbackend: 'fbgemm'

```
-python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL'
+python train_quantization.py --device='cpu' --post-training-quantize --qbackend='fbgemm' --model='$MODEL'
```
Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d`, `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0`.
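For reference, the resulting post-training-quantized weights can also be consumed directly through `torchvision.models.quantization` rather than via this script. A minimal sketch, assuming a recent torchvision where the `quantize=True` flag and the `weights="DEFAULT"` shorthand are available:

```python
import torch
import torchvision

# Select the quantized inference kernels: 'fbgemm' on x86, 'qnnpack' on ARM.
torch.backends.quantized.engine = "fbgemm"

# quantize=True returns the int8 model with pre-calibrated quantized weights.
model = torchvision.models.quantization.resnet18(weights="DEFAULT", quantize=True)
model.eval()

with torch.inference_mode():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```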

@@ -301,12 +301,12 @@ Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `re
Here are commands that we use to quantize the `shufflenet_v2_x1_5` and `shufflenet_v2_x2_0` models.
```
# For shufflenet_v2_x1_5
-python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
+python train_quantization.py --device='cpu' --post-training-quantize --qbackend='fbgemm' \
--model=shufflenet_v2_x1_5 --weights="ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1" \
--train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
# For shufflenet_v2_x2_0
-python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
+python train_quantization.py --device='cpu' --post-training-quantize --qbackend='fbgemm' \
--model=shufflenet_v2_x2_0 --weights="ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1" \
--train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
```
@@ -317,7 +317,7 @@ For Mobilenet-v2, the model was trained with quantization aware training, the se
1. num_workers: 16
2. batch_size: 32
3. eval_batch_size: 128
-4. backend: 'qnnpack'
+4. qbackend: 'qnnpack'
5. learning-rate: 0.0001
6. num_epochs: 90
7. num_observer_update_epochs:4
@@ -339,7 +339,7 @@ For Mobilenet-v3 Large, the model was trained with quantization aware training,
1. num_workers: 16
2. batch_size: 32
3. eval_batch_size: 128
-4. backend: 'qnnpack'
+4. qbackend: 'qnnpack'
5. learning-rate: 0.001
6. num_epochs: 90
7. num_observer_update_epochs:4
@@ -359,7 +359,7 @@ For post training quant, device is set to CPU. For training, the device is set t
### Command to evaluate quantized models using the pre-trained weights:

```
-python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
+python train_quantization.py --device='cpu' --test-only --qbackend='<qbackend>' --model='<model_name>'
```

For inception_v3 you need to pass the following extra parameters:
20 changes: 14 additions & 6 deletions references/classification/train_quantization.py
@@ -23,9 +23,9 @@ def main(args):
raise RuntimeError("Post training quantization example should not be performed on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
-    if args.backend not in torch.backends.quantized.supported_engines:
-        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
-    torch.backends.quantized.engine = args.backend
+    if args.qbackend not in torch.backends.quantized.supported_engines:
+        raise RuntimeError("Quantized backend not supported: " + str(args.qbackend))
+    torch.backends.quantized.engine = args.qbackend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True
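As an aside, the engine check above can be reproduced interactively; a small sketch (the exact list of supported engines depends on how PyTorch was built and on the host CPU):

```python
import torch

# Engines compiled into this PyTorch build, e.g. including 'fbgemm'/'x86' on x86 CPUs.
print(torch.backends.quantized.supported_engines)

# Assigning an engine that is not in the list raises an error, which is why the
# script validates args.qbackend before setting torch.backends.quantized.engine.
torch.backends.quantized.engine = "fbgemm"
```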
@@ -55,7 +55,7 @@ def main(args):

    if not (args.test_only or args.post_training_quantize):
        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
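For context, the QAT branch above follows the standard eager-mode flow (fuse, attach a QAT qconfig, prepare, fine-tune, convert); a minimal self-contained sketch using a quantizable torchvision model, with the training loop elided:

```python
import torch
import torchvision

model = torchvision.models.quantization.mobilenet_v2(weights=None, quantize=False)
model.fuse_model(is_qat=True)  # fuse conv/bn/relu modules for QAT
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
torch.ao.quantization.prepare_qat(model, inplace=True)

# ... fine-tune the prepared model here with fake-quant observers enabled ...

model.eval()
quantized_model = torch.ao.quantization.convert(model)  # produces the int8 model
```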
@@ -89,7 +89,7 @@ def main(args):
        )
        model.eval()
        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+        model.qconfig = torch.ao.quantization.get_default_qconfig(args.qbackend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
@@ -161,7 +161,7 @@ def get_args_parser(add_help=True):

parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
parser.add_argument("--qbackend", default="qnnpack", type=str, help="Quantized backend: fbgemm or qnnpack")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")

    parser.add_argument(
@@ -257,9 +257,17 @@ def get_args_parser(add_help=True):
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

return parser


if __name__ == "__main__":
args = get_args_parser().parse_args()
if args.backend in ("fbgemm", "qnnpack"):
raise ValueError(
"The --backend parameter has been re-purposed to specify the backend of the transforms (PIL or Tensor) "
"instead of the quantized backend. Please use the --qbackend parameter to specify the quantized backend."
)
main(args)
