
Commit: resolve conflict
wgzintel committed Jun 4, 2024
2 parents ca46673 + 096d94b commit d2350c0
Showing 42 changed files with 2,508 additions and 393 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/test_inc.yml
@@ -32,7 +32,7 @@ jobs:
python -m pip install --upgrade pip
pip install cmake
pip install py-cpuinfo
-pip install torch==2.2 torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
+pip install torch==2.3.0 torchaudio==2.3.0 torchvision==0.18 --index-url https://download.pytorch.org/whl/cpu
pip install .[neural-compressor,diffusers,tests]
pip install intel-extension-for-transformers
pip install peft
@@ -43,7 +43,6 @@ jobs:
- name: Test IPEX
run: |
pip uninstall -y intel-extension-for-transformers
-pip install torch==2.1.0 torchaudio==2.1.0 torchvision==0.16 --extra-index-url https://download.pytorch.org/whl/cpu
-pip install intel-extension-for-pytorch==2.1.100
+pip install intel-extension-for-pytorch==2.3.0
pytest tests/neural_compressor/test_ipex.py
2 changes: 1 addition & 1 deletion .github/workflows/test_ipex.yml
@@ -30,7 +30,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
-pip install torch==2.2 torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
+pip install torch torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
pip install .[ipex,tests]
- name: Test with Pytest
run: |
5 changes: 5 additions & 0 deletions README.md
@@ -239,3 +239,8 @@ Do not forget to install requirements for every example:
cd <example-folder>
pip install -r requirements.txt
```


+## Gaudi
+
+To train your model on [Intel Gaudi AI Accelerators (HPU)](https://docs.habana.ai/en/latest/index.html), check out [Optimum Habana](https://github.com/huggingface/optimum-habana), which provides a set of tools enabling easy model loading, training and inference on single- and multi-HPU settings for different downstream tasks. After training your model, feel free to submit it to the Intel [leaderboard](https://huggingface.co/spaces/Intel/powered_by_intel_llm_leaderboard), which is designed to evaluate, score, and rank open-source LLMs that have been pre-trained or fine-tuned on Intel hardware. Models submitted to the leaderboard are evaluated on the Intel Developer Cloud. The evaluation platform consists of Gaudi Accelerators and Xeon CPUs running benchmarks from the Eleuther AI Language Model Evaluation Harness.
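For orientation, here is a minimal sketch of what an Optimum Habana training setup looks like, assuming `optimum-habana` is installed on a Gaudi (HPU) machine; the model id, Gaudi config name, and dataset handling below are illustrative, not part of this commit:

```python
# A minimal sketch of fine-tuning with Optimum Habana; assumes a Gaudi machine
# with `pip install optimum-habana`. Model id and Gaudi config name are
# illustrative placeholders.
from optimum.habana import GaudiTrainer, GaudiTrainingArguments
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# GaudiTrainingArguments mirrors transformers.TrainingArguments and adds
# HPU-specific switches; the Gaudi config name points to a hosted config.
args = GaudiTrainingArguments(
    output_dir="out",
    use_habana=True,
    use_lazy_mode=True,
    gaudi_config_name="Habana/bert-base-uncased",
)

trainer = GaudiTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    # train_dataset=..., eval_dataset=...  (prepared as for transformers.Trainer)
)
# trainer.train()
```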
@@ -36,8 +36,6 @@
"ignored_scopes": [
"{re}.*__add___[0-1]",
"{re}.*layer_norm_0",
"{re}.*matmul_1",
"{re}.*__truediv__*"
]
}
]
@@ -36,8 +36,6 @@
"ignored_scopes": [
"{re}.*__add___[0-1]",
"{re}.*layer_norm_0",
"{re}.*matmul_1",
"{re}.*__truediv__*"
]
}
]
@@ -40,8 +40,6 @@
"ignored_scopes": [
"{re}.*__add___[0-1]",
"{re}.*layer_norm_0",
"{re}.*matmul_1",
"{re}.*__truediv__*"
]
}
]
2 changes: 1 addition & 1 deletion notebooks/openvino/quantized_generation_demo.ipynb
@@ -32,7 +32,7 @@
"metadata": {},
"outputs": [],
"source": [
"# ! pip install optimum[openvino,nncf] torch"
"# ! pip install optimum[openvino,nncf] torch==2.2.2"
]
},
{
14 changes: 10 additions & 4 deletions notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb
@@ -52,7 +52,8 @@
"import transformers\n",
"from pathlib import Path\n",
"from openvino.runtime import Core\n",
"from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig\n",
"from optimum.intel import OVConfig, OVQuantizer, OVStableDiffusionPipeline, OVWeightQuantizationConfig\n",
"from optimum.intel.openvino.configuration import OVQuantizationMethod\n",
"\n",
"transformers.logging.set_verbosity_error()\n",
"datasets.logging.set_verbosity_error()"
@@ -198,9 +199,14 @@
},
"outputs": [],
"source": [
"quantization_config = OVWeightQuantizationConfig(bits=8, dataset=calibration_dataset, num_samples=NUM_SAMPLES)\n",
"int8_pipe = OVStableDiffusionPipeline.from_pretrained(model_id=MODEL_ID, export=True, quantization_config=quantization_config)\n",
"int8_pipe.save_pretrained(int8_model_path)"
"int8_pipe = OVStableDiffusionPipeline.from_pretrained(model_id=MODEL_ID, export=True)\n",
"quantization_config = OVWeightQuantizationConfig(bits=8, num_samples=NUM_SAMPLES, quant_method=OVQuantizationMethod.HYBRID)\n",
"quantizer = OVQuantizer(int8_pipe)\n",
"quantizer.quantize(\n",
" ov_config=OVConfig(quantization_config=quantization_config),\n",
" calibration_dataset=calibration_dataset,\n",
" save_directory=int8_model_path\n",
")"
]
},
{
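The rewritten cell above separates export from quantization: the pipeline is exported first, then `OVQuantizer` applies hybrid quantization using the calibration dataset. A minimal follow-up sketch, assuming that cell ran and saved the pipeline to `int8_model_path` (the prompt is illustrative):

```python
# A minimal sketch: reload the hybrid-quantized pipeline saved by the cell
# above and run inference. Assumes `int8_model_path` exists on disk.
from optimum.intel import OVStableDiffusionPipeline

int8_pipe = OVStableDiffusionPipeline.from_pretrained(int8_model_path)
image = int8_pipe("sailing ship in storm by Rembrandt", num_inference_steps=20).images[0]
image.save("result.png")
```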
31 changes: 28 additions & 3 deletions optimum/commands/export/openvino.py
@@ -119,6 +119,15 @@ def parse_args_openvino(parser: "ArgumentParser"):
"or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models."
),
)
+    optional_group.add_argument(
+        "--all-layers",
+        action="store_true",
+        default=None,
+        help=(
+            "Whether embeddings and last MatMul layers should be compressed to INT4. If not provided, "
+            "they are compressed to INT8 when weight compression is applied."
+        ),
+    )
optional_group.add_argument(
"--disable-stateful",
action="store_true",
@@ -198,6 +207,7 @@ def run(self):
and self.args.ratio is None
and self.args.group_size is None
and self.args.sym is None
+            and self.args.all_layers is None
and self.args.model in _DEFAULT_4BIT_CONFIGS
):
quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model]
@@ -207,6 +217,7 @@
"ratio": 1 if is_int8 else (self.args.ratio or 0.8),
"sym": self.args.sym or False,
"group_size": -1 if is_int8 else self.args.group_size,
"all_layers": None if is_int8 else self.args.all_layers,
}

if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
@@ -226,6 +237,9 @@
)
library_name = "transformers"

+        if self.args.convert_tokenizer:
+            logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.")

if (
library_name == "diffusers"
and ov_config
@@ -261,10 +275,21 @@
)
model.save_pretrained(self.args.output)

-        else:
-            if self.args.convert_tokenizer:
-                logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.")
+            if self.args.disable_convert_tokenizer:
+                return
+
+            # avoid import when using other exporters (IPEX, INC)
+            from ...exporters.openvino.convert import export_tokenizer
+
+            output = Path(self.args.output)
+            tokenizer = getattr(model, "tokenizer", None)
+            if tokenizer is not None:
+                export_tokenizer(tokenizer, output / "tokenizer")
+
+            tokenizer_2 = getattr(model, "tokenizer_2", None)
+            if tokenizer_2 is not None:
+                export_tokenizer(tokenizer_2, output / "tokenizer_2")
+        else:
# TODO : add input shapes
main_export(
model_name_or_path=self.args.model,
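The new `--all-layers` flag feeds the `all_layers` field of the weight-compression config, so the same behavior should be reachable from Python. A hedged sketch of that programmatic counterpart (the model id is illustrative, and the other 4-bit settings are left at their defaults):

```python
# A hedged sketch, assuming optimum-intel is installed with its OpenVINO
# extras; the model id is an illustrative placeholder.
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# all_layers=True also compresses embeddings and the last MatMul to INT4;
# without it they stay at INT8, as the new CLI help text describes.
quantization_config = OVWeightQuantizationConfig(bits=4, all_layers=True)
model = OVModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    export=True,
    quantization_config=quantization_config,
)
model.save_pretrained("tinyllama-int4-ov")
```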
14 changes: 11 additions & 3 deletions optimum/exporters/openvino/__main__.py
@@ -24,7 +24,7 @@
from optimum.exporters import TasksManager
from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
-from optimum.exporters.openvino.convert import export_from_model, export_tokenizer
+from optimum.exporters.openvino.convert import export_from_model
from optimum.intel.utils.import_utils import is_openvino_tokenizers_available, is_transformers_version
from optimum.utils.save_utils import maybe_load_preprocessors

@@ -219,6 +219,10 @@ def main_export(
model_type = config.model_type.replace("_", "-")
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True
+        if custom_export_configs is None:
+            raise ValueError(
+                f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom export configuration was passed as `custom_export_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum-intel/issues if you would like the model type {model_type} to be supported natively in the OpenVINO export."
+            )
elif task not in TasksManager.get_supported_tasks_for_model_type(
model_type, exporter="openvino", library_name=library_name
):
@@ -232,6 +236,7 @@
raise ValueError(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
loading_kwargs["attn_implementation"] = "eager"
# there are some difference between remote and in library representation of past key values for some models,
@@ -355,6 +360,9 @@ class StoreAttr(object):
**kwargs_shapes,
)

+    # hide openvino import when using other exporters
+    from optimum.exporters.openvino.convert import export_tokenizer

if convert_tokenizer and is_openvino_tokenizers_available():
if library_name != "diffusers":
tokenizer = next(
@@ -373,11 +381,11 @@
else:
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
-                export_tokenizer(tokenizer, output)
+                export_tokenizer(tokenizer, output / "tokenizer")

tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
-                export_tokenizer(tokenizer_2, output, suffix="_2")
+                export_tokenizer(tokenizer_2, output / "tokenizer_2")
elif convert_tokenizer and not is_openvino_tokenizers_available():
logger.warning("Tokenizer won't be converted.")

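The hunks above move the converted tokenizers from the export root (with a `_2` suffix) into dedicated `tokenizer/` and `tokenizer_2/` subfolders. A hedged sketch of reading one back, assuming `openvino` and `openvino-tokenizers` are installed and that the artifact name resolves to `openvino_tokenizer.xml` (the export directory is illustrative):

```python
# A hedged sketch: load a converted tokenizer from the new subfolder layout.
# The directory and artifact file name are assumptions, not from this diff.
from pathlib import Path

import openvino as ov
import openvino_tokenizers  # noqa: F401  (registers the tokenizer custom ops)

output = Path("exported_model")  # illustrative export directory
core = ov.Core()
ov_tokenizer = core.read_model(output / "tokenizer" / "openvino_tokenizer.xml")
compiled_tokenizer = core.compile_model(ov_tokenizer, "CPU")
```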
20 changes: 13 additions & 7 deletions optimum/exporters/openvino/convert.py
@@ -547,7 +547,7 @@ def export_from_model(
# TODO: support onnx_config.py in the model repo
if custom_architecture and custom_export_configs is None:
raise ValueError(
f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom export configuration was passed as `custom_export_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export."
f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom export configuration was passed as `custom_export_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum-intel/issues if you would like the model type {model_type} to be supported natively in the OpenVINO export."
)

if task.startswith("text-generation") and model.config.is_encoder_decoder:
@@ -614,7 +614,12 @@ def export_from_model(
model.config.save_pretrained(output)
generation_config = getattr(model, "generation_config", None)
if generation_config is not None:
-        generation_config.save_pretrained(output)
+        try:
+            generation_config.save_pretrained(output)
+        except Exception as exception:
+            logger.warning(
+                f"The generation config will not be saved, saving failed with following error:\n{exception}"
+            )

model_name_or_path = model.config._name_or_path
maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)
@@ -667,20 +672,21 @@ def export_tokenizer(
output: Union[str, Path],
suffix: Optional[str] = "",
):
-    from optimum.intel.openvino import OV_DETOKENIZER_NAME, OV_TOKENIZER_NAME  # avoid circular imports
+    # avoid circular imports
+    from optimum.intel.openvino import OV_DETOKENIZER_NAME, OV_TOKENIZER_NAME
+    from optimum.intel.openvino.utils import maybe_convert_tokenizer_to_fast

try:
from openvino_tokenizers import convert_tokenizer
except ModuleNotFoundError:
-        # avoid this message before tokenizers are part of the openvino dependencies
-        # logger.info(
-        #     "Run `pip install openvino-tokenizers[transformers]` to get OpenVINO tokenizer/detokenizer models."
-        # )
return

if not isinstance(output, Path):
output = Path(output)

+    if output.exists():
+        tokenizer = maybe_convert_tokenizer_to_fast(tokenizer, output)

try:
converted = convert_tokenizer(tokenizer, with_detokenizer=True)
except NotImplementedError:
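For completeness, a minimal sketch of calling `export_tokenizer` directly with the signature shown in this diff; it assumes `openvino-tokenizers` is installed (otherwise the function returns silently, as the `ModuleNotFoundError` branch shows), and the model id and output directory are illustrative:

```python
# A minimal sketch using the export_tokenizer signature from the diff above.
# Model id and output directory are illustrative placeholders.
from pathlib import Path

from transformers import AutoTokenizer

from optimum.exporters.openvino.convert import export_tokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
export_tokenizer(tokenizer, Path("exported_model") / "tokenizer")
```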