Skip to content

Commit

Permalink
Merge branch 'main' of github.com:calico/baskerville into main
Browse files Browse the repository at this point in the history
  • Loading branch information
davek44 committed Jun 30, 2024
2 parents 89e6a33 + 9d51a09 commit 699b377
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/baskerville/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

import numpy as np
import tensorflow as tf

import tempfile
from baskerville.helpers.gcs_utils import is_gcs_path, upload_folder_gcs
from baskerville import metrics


Expand Down Expand Up @@ -119,6 +120,15 @@ def __init__(
self.batch_size = self.train_data[0].batch_size
self.compiled = False

# if log_dir is in gcs then create a local temp dir
if is_gcs_path(self.log_dir):
folder_name = "/".join(self.log_dir.split("/")[3:])
self.log_dir = tempfile.mkdtemp() + "/" + folder_name
self.gcs_log_dir = log_dir
self.gcs = True
else:
self.gcs = False

# early stopping
self.patience = self.params.get("patience", 20)

Expand Down Expand Up @@ -498,6 +508,10 @@ def eval_step1_distr(xd, yd):
print(" - valid_r2: %.4f" % valid_r2[di].result().numpy(), end="")
early_stop_stat = valid_r[di].result().numpy()

# upload to gcs
if self.gcs:
upload_folder_gcs(train_log_dir, self.gcs_log_dir)
upload_folder_gcs(valid_log_dir, self.gcs_log_dir)
# checkpoint
managers[di].save()
model.save(
Expand Down Expand Up @@ -697,6 +711,11 @@ def eval_step_distr(xd, yd):
end="",
)

# upload to gcs
if self.gcs:
upload_folder_gcs(train_log_dir, self.gcs_log_dir)
upload_folder_gcs(valid_log_dir, self.gcs_log_dir)

# checkpoint
manager.save()
seqnn_model.save("%s/model_check.h5" % self.out_dir)
Expand Down

0 comments on commit 699b377

Please sign in to comment.