diff --git a/model-trainer-huggingface/src/test_utils.py b/model-trainer-huggingface/src/test_utils.py index 1b5941e..74c4f3f 100644 --- a/model-trainer-huggingface/src/test_utils.py +++ b/model-trainer-huggingface/src/test_utils.py @@ -11,13 +11,13 @@ def test_parse_training_args_int_float(): args = parse_training_args(params) assert args.num_train_epochs == 1.0 assert args.max_steps == 5 - assert args.output_dir == "/content/model/checkpoints" + assert args.output_dir == "/content/artifacts/checkpoints" params = {"num_train_epochs": "2.0", "max_steps": "5"} args = parse_training_args(params) assert args.num_train_epochs == 2.0 assert args.max_steps == 5 - assert args.output_dir == "/content/model/checkpoints" + assert args.output_dir == "/content/artifacts/checkpoints" def test_parse_training_args_bool(): diff --git a/model-trainer-huggingface/src/utils.py b/model-trainer-huggingface/src/utils.py index 2f599f3..ddbe983 100644 --- a/model-trainer-huggingface/src/utils.py +++ b/model-trainer-huggingface/src/utils.py @@ -14,7 +14,7 @@ def parse_training_args(params: typing.Mapping) -> TrainingArguments: learning_rate=2e-4, fp16=True, logging_steps=1, - output_dir="/content/model/checkpoints", + output_dir="/content/artifacts/checkpoints", optim="paged_adamw_32bit", )