diff --git a/ggt/modules/autocrop.py b/ggt/modules/autocrop.py index ae5f3857bc..0bd7fd1a57 100644 --- a/ggt/modules/autocrop.py +++ b/ggt/modules/autocrop.py @@ -90,7 +90,8 @@ def main( model = model.to(device) # Load the model from a saved state if provided - model.load_state_dict(torch.load(model_path)) + model.load_state_dict(torch.load(model_path, + map_location=torch.device(device))) # Collect all images, then iterate images = glob.glob(str(Path(image_dir) / "*.fits"))