
Commit

Climate projection task (#26)
* add climate projection code

* add readme for climate projection

* minor updates
tung-nd authored Jul 10, 2023
1 parent efd6de4 commit 5533b8c
Showing 10 changed files with 977 additions and 28 deletions.
121 changes: 121 additions & 0 deletions configs/climate_projection.yaml
@@ -0,0 +1,121 @@
seed_everything: 42

# ---------------------------- TRAINER -------------------------------------------
trainer:
  default_root_dir: ${oc.env:AMLT_OUTPUT_DIR,/home/tungnd/ClimaX/exps/climate_projection_climax}

  precision: 16

  gpus: null
  num_nodes: 1
  accelerator: gpu
  strategy: ddp

  min_epochs: 1
  max_epochs: 50
  enable_progress_bar: true

  sync_batchnorm: True
  enable_checkpointing: True
  resume_from_checkpoint: null

  # debugging
  fast_dev_run: false

  logger:
    class_path: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
    init_args:
      save_dir: ${trainer.default_root_dir}/logs
      name: null
      version: null
      log_graph: False
      default_hp_metric: True
      prefix: ""

  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "step"

    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: "${trainer.default_root_dir}/checkpoints/"
        monitor: "val/w_mse" # name of the logged metric which determines when the model is improving
        mode: "min" # "max" means a higher metric value is better; can also be "min"
        save_top_k: 1 # save the k best models (determined by the metric above)
        save_last: True # additionally always save the model from the last epoch
        verbose: False
        filename: "epoch_{epoch:03d}"
        auto_insert_metric_name: False

    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        monitor: "val/w_mse" # name of the logged metric which determines when the model is improving
        mode: "min" # "max" means a higher metric value is better; can also be "min"
        patience: 5 # how many validation epochs without improvement until training stops
        min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement

    - class_path: pytorch_lightning.callbacks.RichModelSummary
      init_args:
        max_depth: -1

    - class_path: pytorch_lightning.callbacks.RichProgressBar

# ---------------------------- MODEL -------------------------------------------
model:
  lr: 5e-4
  beta_1: 0.9
  beta_2: 0.999
  weight_decay: 1e-5
  warmup_epochs: 60
  max_epochs: 600
  warmup_start_lr: 1e-8
  eta_min: 1e-8
  pretrained_path: "https://huggingface.co/tungnd/climax/resolve/main/5.625deg.ckpt"

  net:
    class_path: climax.climate_projection.arch.ClimaXClimateBench
    init_args:
      default_vars: [
        'CO2',
        'SO2',
        'CH4',
        'BC'
      ]
      out_vars: "tas" # diurnal_temperature_range, tas, pr, pr90
      img_size: [32, 64]
      time_history: 10
      patch_size: 2
      embed_dim: 1024
      depth: 8
      num_heads: 16
      mlp_ratio: 4
      drop_path: 0.1
      drop_rate: 0.1
      parallel_patch_embed: False
      freeze_encoder: True

# ---------------------------- DATA -------------------------------------------
data:
  root_dir: /home/data/datasets/climate-learn/climatebench/5.625deg/
  history: 10
  list_train_simu: [
    'ssp126',
    'ssp370',
    'ssp585',
    'historical',
    'hist-GHG',
    'hist-aer'
  ]
  list_test_simu: ['ssp245']
  variables: [
    'CO2',
    'SO2',
    'CH4',
    'BC'
  ]
  out_variables: 'tas'
  train_ratio: 0.9
  batch_size: 1
  num_workers: 1
  pin_memory: False
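
For reference, the `net` block above maps one-to-one onto the constructor of `ClimaXClimateBench` added in `src/climax/climate_projection/arch.py` below. The following sketch only illustrates that mapping (the training script performs this instantiation automatically when parsing the config); it is not code from the repository.

```python
# Illustrative only: constructing the `model.net` section of the config by hand.
from climax.climate_projection.arch import ClimaXClimateBench

net = ClimaXClimateBench(
    default_vars=["CO2", "SO2", "CH4", "BC"],
    out_vars="tas",
    img_size=[32, 64],
    time_history=10,
    patch_size=2,
    embed_dim=1024,
    depth=8,
    num_heads=16,
    mlp_ratio=4,
    drop_path=0.1,
    drop_rate=0.1,
    parallel_patch_embed=False,
    freeze_encoder=True,  # freezes the encoder blocks; norm layers, head, and time aggregation stay trainable
)
```
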
36 changes: 36 additions & 0 deletions docs/usage.md
@@ -118,6 +118,42 @@ python src/climax/regional_forecast/train.py --config configs/regional_forecast_
```
To train ClimaX from scratch, set `--model.pretrained_path=""`.

## Climate Projection

### Data Preparation

First, download the [ClimateBench](https://doi.org/10.5281/zenodo.5196512) data. ClimaX can work with either the original ClimateBench data or a regridded version. For the experiments in the paper, we regridded the ClimateBench data to 5.625 degrees. To do that, run
```bash
python src/data_preprocessing/regrid_climatebench.py /mnt/data/climatebench/train_val \
--save_path /mnt/data/climatebench/5.625deg/train_val --ddeg_out 5.625
```
and
```bash
python src/data_preprocessing/regrid_climatebench.py /mnt/data/climatebench/test \
--save_path /mnt/data/climatebench/5.625deg/test --ddeg_out 5.625
```
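
Conceptually, the regridding just interpolates every field onto the coarser 5.625-degree grid (32 × 64 latitude/longitude points). Below is a minimal sketch of that idea using plain xarray interpolation; the file name and coordinate names are placeholders, and the bundled `regrid_climatebench.py` script remains the supported way to do this (it may use a different interpolation scheme and handle additional variables and metadata).

```python
# Minimal regridding sketch (assumptions: coordinates named 'lat'/'lon', hypothetical file name).
# Use src/data_preprocessing/regrid_climatebench.py in practice.
import numpy as np
import xarray as xr

ds = xr.open_dataset("/mnt/data/climatebench/train_val/inputs_ssp126.nc")  # hypothetical path
new_lat = np.arange(-90 + 5.625 / 2, 90, 5.625)  # 32 latitude points
new_lon = np.arange(0, 360, 5.625)               # 64 longitude points
ds_5p625 = ds.interp(lat=new_lat, lon=new_lon)   # linear interpolation onto the coarse grid
ds_5p625.to_netcdf("/mnt/data/climatebench/5.625deg/train_val/inputs_ssp126.nc")
```
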

### Training

To finetune ClimaX for climate projection, use
```bash
python src/climax/climate_projection/train.py --config <path/to/config>
```
For example, to finetune ClimaX on 8 GPUs use
```bash
python src/climax/climate_projection/train.py --config configs/climate_projection.yaml \
--trainer.strategy=ddp --trainer.devices=8 \
--trainer.max_epochs=50 \
--data.root_dir=/mnt/data/climatebench/5.625deg \
--data.out_variables="tas" \
--data.batch_size=16 \
--model.pretrained_path='https://huggingface.co/tungnd/climax/resolve/main/5.625deg.ckpt' \
--model.out_vars="tas" \
--model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \
--model.weight_decay=1e-5
```
To train ClimaX from scratch, set `--model.pretrained_path=""`.

## Visualization

Coming soon
Empty file.
144 changes: 144 additions & 0 deletions src/climax/climate_projection/arch.py
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import numpy as np
import torch
import torch.nn as nn

from climax.arch import ClimaX
from climax.utils.pos_embed import get_1d_sincos_pos_embed_from_grid


class ClimaXClimateBench(ClimaX):
    def __init__(
        self,
        default_vars,
        out_vars,
        img_size=[32, 64],
        time_history=1,
        patch_size=2,
        embed_dim=1024,
        depth=8,
        decoder_depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        drop_path=0.1,
        drop_rate=0.1,
        parallel_patch_embed=False,
        freeze_encoder=False,
    ):
        assert out_vars is not None

        super().__init__(
            default_vars,
            img_size,
            patch_size,
            embed_dim,
            depth,
            decoder_depth,
            num_heads,
            mlp_ratio,
            drop_path,
            drop_rate,
            parallel_patch_embed,
        )

        self.out_vars = out_vars
        self.time_history = time_history
        self.freeze_encoder = freeze_encoder

        # used to aggregate multiple timesteps in the input
        self.time_pos_embed = nn.Parameter(torch.zeros(1, time_history, embed_dim), requires_grad=True)
        self.time_agg = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.time_query = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)

        # initialize time embedding
        time_pos_embed = get_1d_sincos_pos_embed_from_grid(self.time_pos_embed.shape[-1], np.arange(self.time_history))
        self.time_pos_embed.data.copy_(torch.from_numpy(time_pos_embed).float().unsqueeze(0))

        # overwrite ClimaX
        # use a linear prediction head for this task
        self.head = nn.Linear(embed_dim, img_size[0] * img_size[1])

        if freeze_encoder:
            for name, p in self.blocks.named_parameters():
                name = name.lower()
                # we do not freeze the norm layers, as suggested by https://arxiv.org/abs/2103.05247
                if 'norm' in name:
                    continue
                else:
                    p.requires_grad_(False)

    def forward_encoder(self, x: torch.Tensor, lead_times: torch.Tensor, variables):
        # x: `[B, T, V, H, W]` shape.

        if isinstance(variables, list):
            variables = tuple(variables)

        b, t, _, _, _ = x.shape
        x = x.flatten(0, 1)  # BxT, V, H, W

        # tokenize each variable separately
        embeds = []
        var_ids = self.get_var_ids(variables, x.device)

        if self.parallel_patch_embed:
            x = self.token_embeds(x, var_ids)  # BxT, V, L, D
        else:
            for i in range(len(var_ids)):
                id = var_ids[i]
                embeds.append(self.token_embeds[id](x[:, i : i + 1]))
            x = torch.stack(embeds, dim=1)  # BxT, V, L, D

        # add variable embedding
        var_embed = self.get_var_emb(self.var_embed, variables)
        x = x + var_embed.unsqueeze(2)  # BxT, V, L, D

        # variable aggregation
        x = self.aggregate_variables(x)  # BxT, L, D

        # add pos embedding
        x = x + self.pos_embed

        # add time embedding
        # time emb: 1, T, D
        x = x.unflatten(0, sizes=(b, t))  # B, T, L, D
        x = x + self.time_pos_embed.unsqueeze(2)

        # add lead time embedding
        lead_time_emb = self.lead_time_embed(lead_times.unsqueeze(-1))  # B, D
        lead_time_emb = lead_time_emb.unsqueeze(1).unsqueeze(2)
        x = x + lead_time_emb  # B, T, L, D

        x = x.flatten(0, 1)  # BxT, L, D

        x = self.pos_drop(x)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)  # BxT, L, D
        x = x.unflatten(0, sizes=(b, t))  # B, T, L, D

        # global average pooling, also used in CNN-LSTM baseline in ClimateBench
        x = x.mean(-2)  # B, T, D
        time_query = self.time_query.repeat_interleave(x.shape[0], dim=0)
        x, _ = self.time_agg(time_query, x, x)  # B, 1, D

        return x

    def forward(self, x, y, lead_times, variables, out_variables, metric, lat):
        x = self.forward_encoder(x, lead_times, variables)  # B, 1, D
        preds = self.head(x)
        preds = preds.reshape(-1, 1, self.img_size[0], self.img_size[1])  # B, 1, H, W
        if metric is None:
            loss = None
        else:
            loss = [m(preds, y, out_variables, lat) for m in metric]
        return loss, preds
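
To make the tensor shapes concrete, here is a small sanity-check sketch (not part of the commit) that runs one forward pass with random inputs. It assumes the `climax` package is installed; the reduced `embed_dim` and `depth` are only there to keep the example cheap.

```python
import torch

from climax.climate_projection.arch import ClimaXClimateBench

# Tiny configuration purely for a shape check (the real config uses embed_dim=1024, depth=8).
model = ClimaXClimateBench(
    default_vars=["CO2", "SO2", "CH4", "BC"],
    out_vars="tas",
    img_size=[32, 64],
    time_history=10,
    embed_dim=128,
    depth=2,
    num_heads=4,
)

x = torch.randn(2, 10, 4, 32, 64)   # B, T, V, H, W
lead_times = torch.zeros(2)         # B; fed to the lead-time embedding
variables = ["CO2", "SO2", "CH4", "BC"]

loss, preds = model(x, None, lead_times, variables, out_variables="tas", metric=None, lat=None)
print(loss)         # None, since no metric was passed
print(preds.shape)  # torch.Size([2, 1, 32, 64]) -- B, 1, H, W
```
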
