diff --git a/src/vllm_tgis_adapter/tgis_utils/args.py b/src/vllm_tgis_adapter/tgis_utils/args.py index 7be4ce1..475e14a 100644 --- a/src/vllm_tgis_adapter/tgis_utils/args.py +++ b/src/vllm_tgis_adapter/tgis_utils/args.py @@ -15,6 +15,10 @@ def _to_env_var(arg_name: str) -> str: return arg_name.upper().replace("-", "_") +def _bool_from_string(val: str) -> bool: + return val.lower().strip() == "true" or val == "1" + + def _switch_action_default(action: argparse.Action) -> None: """Switch to using env var fallback for all args.""" env_val = os.environ.get(_to_env_var(action.dest)) @@ -22,8 +26,8 @@ def _switch_action_default(action: argparse.Action) -> None: return val: bool | str | int - if action.type is bool: - val = env_val.lower() == "true" or env_val == "1" + if action.type in [bool, _bool_from_string]: + val = _bool_from_string(env_val) elif action.type is int: val = int(env_val) else: @@ -111,9 +115,11 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: # map to tensor_parallel_size parser.add_argument("--num-shard", type=int) # TODO check boolean behaviour for env vars and defaults - parser.add_argument("--output-special-tokens", type=bool, default=False) parser.add_argument( - "--default-include-stop-seqs", type=bool, default=True + "--output-special-tokens", type=_bool_from_string, default=False + ) + parser.add_argument( + "--default-include-stop-seqs", type=_bool_from_string, default=True ) # TODO TBD parser.add_argument("--grpc-port", type=int, default=8033) @@ -135,9 +141,13 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--speculator-n-candidates", type=int) parser.add_argument("--speculator-max-batch-size", type=int) # allow re-enabling vllm native per-request logging - parser.add_argument("--enable-vllm-log-requests", type=bool, default=False) + parser.add_argument( + "--enable-vllm-log-requests", type=_bool_from_string, default=False + ) # set to true to disable producing prompt logprobs on all requests - parser.add_argument("--disable-prompt-logprobs", type=bool, default=False) + parser.add_argument( + "--disable-prompt-logprobs", type=_bool_from_string, default=False + ) # TODO check/add other args here