diff --git a/cameratokeyboard/model/train.py b/cameratokeyboard/model/train.py index dfb6e7f..369468e 100644 --- a/cameratokeyboard/model/train.py +++ b/cameratokeyboard/model/train.py @@ -23,7 +23,7 @@ def __init__(self, config: Config) -> None: def run(self): self._parition_data() - self._train() + return self._train() def _are_training_data_up_to_date(self): if not os.path.exists(self.dataset_path): @@ -64,3 +64,5 @@ def _train(self): model_path = os.path.join(results.save_dir, "weights", "best.pt") target_model_path = os.path.join("cameratokeyboard", "model.pt") shutil.copyfile(model_path, target_model_path) + + return results diff --git a/ci_train_and_upload.py b/ci_train_and_upload.py index 7698fba..14e2341 100644 --- a/ci_train_and_upload.py +++ b/ci_train_and_upload.py @@ -13,7 +13,6 @@ REGION = os.environ["AWS_REGION"] BUCKET_NAME = os.environ["AWS_BUCKET_NAME"] RAW_DATASET_PATH = "raw_dataset" -RUNS_DIR = os.path.join("runs", "detect") REMOTE_MODELS_DIR = "models" logger = get_logger() @@ -48,17 +47,11 @@ def train() -> str: logger.info("Training the model") config = Config() - config.processing_device = "cpu" trainer = Trainer(config) - trainer.run() + results = trainer.run() - def sort_key(p): - return os.path.getctime(os.path.join(RUNS_DIR, p)) - - last_created_dir = max(os.listdir(RUNS_DIR), key=sort_key) - - return os.path.join(RUNS_DIR, last_created_dir, "weights", "best.pt") + return os.path.join(results.save_dir, "weights", "best.pt") def upload_model(path: str, version: str):