From 66c6b6c39f3ab3438619edc233fc35d6f4c6f4dd Mon Sep 17 00:00:00 2001 From: madt2709 Date: Sat, 19 Oct 2024 22:09:21 -0700 Subject: [PATCH 1/2] Fix load config when using bools Signed-off-by: madt2709 --- tests/data/test_config.yaml | 1 + tests/test_utils.py | 2 ++ vllm/utils.py | 8 ++++++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/data/test_config.yaml b/tests/data/test_config.yaml index 42f4f6f7bb992..a16857b5f2fbd 100644 --- a/tests/data/test_config.yaml +++ b/tests/data/test_config.yaml @@ -1,3 +1,4 @@ port: 12312 served_model_name: mymodel tensor_parallel_size: 2 +trust_remote_code: true diff --git a/tests/test_utils.py b/tests/test_utils.py index 0fed8e678fc76..6393303ad515f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -141,6 +141,7 @@ def parser_with_config(): parser.add_argument('--config', type=str) parser.add_argument('--port', type=int) parser.add_argument('--tensor-parallel-size', type=int) + parser.add_argument('--trust-remote-code', action='store_true') return parser @@ -214,6 +215,7 @@ def test_config_args(parser_with_config): args = parser_with_config.parse_args( ['serve', 'mymodel', '--config', './data/test_config.yaml']) assert args.tensor_parallel_size == 2 + assert args.trust_remote_code def test_config_file(parser_with_config): diff --git a/vllm/utils.py b/vllm/utils.py index fba9804289b94..16c7a05cd984d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1283,8 +1283,12 @@ def _load_config_file(file_path: str) -> List[str]: raise ex for key, value in config.items(): - processed_args.append('--' + key) - processed_args.append(str(value)) + if isinstance(value, bool): + if value: + processed_args.append('--' + key) + else: + processed_args.append('--' + key) + processed_args.append(str(value)) return processed_args From 0431d2d5ac5683e4d4d68eb6d3aaf62ef5a4b35a Mon Sep 17 00:00:00 2001 From: madt2709 Date: Sun, 20 Oct 2024 11:14:45 -0700 Subject: [PATCH 2/2] Handle StoreBoolean args Signed-off-by: madt2709 --- tests/data/test_config.yaml | 1 + tests/test_utils.py | 4 +++- vllm/engine/arg_utils.py | 14 +------------- vllm/utils.py | 29 ++++++++++++++++++++++------- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/tests/data/test_config.yaml b/tests/data/test_config.yaml index a16857b5f2fbd..5090e8f357bb8 100644 --- a/tests/data/test_config.yaml +++ b/tests/data/test_config.yaml @@ -2,3 +2,4 @@ port: 12312 served_model_name: mymodel tensor_parallel_size: 2 trust_remote_code: true +multi_step_stream_outputs: false diff --git a/tests/test_utils.py b/tests/test_utils.py index 6393303ad515f..a731b11eae81c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,7 +6,7 @@ import pytest -from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs, +from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs, get_open_port, merge_async_iterators, supports_kw) from .utils import error_on_warning @@ -142,6 +142,7 @@ def parser_with_config(): parser.add_argument('--port', type=int) parser.add_argument('--tensor-parallel-size', type=int) parser.add_argument('--trust-remote-code', action='store_true') + parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean) return parser @@ -216,6 +217,7 @@ def test_config_args(parser_with_config): ['serve', 'mymodel', '--config', './data/test_config.yaml']) assert args.tensor_parallel_size == 2 assert args.trust_remote_code + assert not args.multi_step_stream_outputs def test_config_file(parser_with_config): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c49f475b9ee61..38687809a31f6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -19,7 +19,7 @@ from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, StoreBoolean if TYPE_CHECKING: from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup @@ -1144,18 +1144,6 @@ def add_cli_args(parser: FlexibleArgumentParser, return parser -class StoreBoolean(argparse.Action): - - def __call__(self, parser, namespace, values, option_string=None): - if values.lower() == "true": - setattr(namespace, self.dest, True) - elif values.lower() == "false": - setattr(namespace, self.dest, False) - else: - raise ValueError(f"Invalid boolean value: {values}. " - "Expected 'true' or 'false'.") - - # These functions are used by sphinx to build the documentation def _engine_args_parser(): return EngineArgs.add_cli_args(FlexibleArgumentParser()) diff --git a/vllm/utils.py b/vllm/utils.py index 16c7a05cd984d..ac27b1ac4ab2d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1155,6 +1155,18 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: return wrapper +class StoreBoolean(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + if values.lower() == "true": + setattr(namespace, self.dest, True) + elif values.lower() == "false": + setattr(namespace, self.dest, False) + else: + raise ValueError(f"Invalid boolean value: {values}. " + "Expected 'true' or 'false'.") + + class FlexibleArgumentParser(argparse.ArgumentParser): """ArgumentParser that allows both underscore and dash in names.""" @@ -1163,7 +1175,7 @@ def parse_args(self, args=None, namespace=None): args = sys.argv[1:] if '--config' in args: - args = FlexibleArgumentParser._pull_args_from_config(args) + args = self._pull_args_from_config(args) # Convert underscores to dashes and vice versa in argument names processed_args = [] @@ -1181,8 +1193,7 @@ def parse_args(self, args=None, namespace=None): return super().parse_args(processed_args, namespace) - @staticmethod - def _pull_args_from_config(args: List[str]) -> List[str]: + def _pull_args_from_config(self, args: List[str]) -> List[str]: """Method to pull arguments specified in the config file into the command-line args variable. @@ -1226,7 +1237,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]: file_path = args[index + 1] - config_args = FlexibleArgumentParser._load_config_file(file_path) + config_args = self._load_config_file(file_path) # 0th index is for {serve,chat,complete} # followed by model_tag (only for serve) @@ -1247,8 +1258,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]: return args - @staticmethod - def _load_config_file(file_path: str) -> List[str]: + def _load_config_file(self, file_path: str) -> List[str]: """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml @@ -1282,8 +1292,13 @@ def _load_config_file(file_path: str) -> List[str]: Make sure path is correct", file_path) raise ex + store_boolean_arguments = [ + action.dest for action in self._actions + if isinstance(action, StoreBoolean) + ] + for key, value in config.items(): - if isinstance(value, bool): + if isinstance(value, bool) and key not in store_boolean_arguments: if value: processed_args.append('--' + key) else: