Skip to content

Commit

Permalink
allow loading checkpoint from url
Browse files Browse the repository at this point in the history
  • Loading branch information
tung-nd committed Feb 15, 2023
1 parent d6dec5f commit 99d0b4d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/climax/global_forecast/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ def __init__(
self.load_pretrained_weights(pretrained_path)

def load_pretrained_weights(self, pretrained_path):
checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu"))

if pretrained_path.startswith('http'):
checkpoint = torch.hub.load_state_dict_from_url(pretrained_path)
else:
checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu"))
print("Loading pre-trained checkpoint from: %s" % pretrained_path)
checkpoint_model = checkpoint["state_dict"]
# interpolate positional embedding
Expand Down
5 changes: 4 additions & 1 deletion src/climax/regional_forecast/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def __init__(
self.load_pretrained_weights(pretrained_path)

def load_pretrained_weights(self, pretrained_path):
checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu"))
if pretrained_path.startswith('http'):
checkpoint = torch.hub.load_state_dict_from_url(pretrained_path)
else:
checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu"))

print("Loading pre-trained checkpoint from: %s" % pretrained_path)
checkpoint_model = checkpoint["state_dict"]
Expand Down

0 comments on commit 99d0b4d

Please sign in to comment.