Merge pull request #12 from una-auxme/dev
Added new arguments, refactoring and bug fixes
JulianTrommer authored Jul 16, 2024
2 parents c126cb3 + 7f6897c commit fc799b8
Showing 5 changed files with 58 additions and 117 deletions.
28 changes: 0 additions & 28 deletions docs/src/index.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/src/strategies.md
@@ -2,7 +2,6 @@

```@docs
Collocation
RandomCollocation
SingleShooting
MultipleShooting
```
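`RandomCollocation` is dropped from the docs because random sampling is now a flag on the collocation strategy itself (see the `ts.random` check in `src/dataset.jl` further down). A minimal sketch of constructing the documented strategies; the `random` and `window_size` keywords are assumptions inferred from this diff, not from the package documentation:

```julia
using MeshGraphNets

# Sequential collocation, as documented above.
strategy = Collocation()

# Shuffled collocation, replacing the removed RandomCollocation.
# `random` and `window_size` are assumed keyword fields, inferred from the
# `ts.random` / `ts.window_size` checks in src/dataset.jl below
# (window_size == 0 shuffles the whole trajectory).
strategy = Collocation(random = true, window_size = 0)
```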
85 changes: 45 additions & 40 deletions src/MeshGraphNets.jl
@@ -41,17 +41,19 @@ export train_network, eval_network, der_minmax, data_meanstd
steps::Integer = 10e6
checkpoint::Integer = 10000
norm_steps::Integer = 1000
max_norm_steps::Integer = 10f6
types_updated::Vector{Integer} = [0, 5]
types_noisy::Vector{Integer} = [0]
training_strategy::TrainingStrategy = Collocation()
use_cuda::Bool = true
gpu_idx::Integer = CUDA.deviceid()
gpu_device::Integer = CUDA.device()
cell_idxs::Vector{Integer} = [0]
num_rollouts::Integer = 10
use_valid::Bool = true
solver_valid::OrdinaryDiffEqAlgorithm = Tsit5()
solver_valid_dt::Union{Nothing, Float32} = nothing
wandb_logger::Union{Nothing, Wandb.WandbLogger} = nothing
reset_valid::Bool = false
end
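The new `max_norm_steps`, `gpu_device`, and `reset_valid` fields surface as keyword arguments of `train_network`, documented in the docstring below. A hedged sketch of a call exercising them; the noise level, optimiser, and paths are illustrative, not taken from this commit:

```julia
using MeshGraphNets, Optimisers, CUDA

# Illustrative values only; paths are hypothetical.
train_network(
    1f-2,                         # noise_stddevs
    Optimisers.Adam(1f-4),        # opt
    "data/cylinder_flow",         # ds_path (hypothetical)
    "checkpoints/cylinder_flow";  # cp_path (hypothetical)
    steps = 10^6,
    max_norm_steps = 10_000,      # stop collecting normalisation stats here
    gpu_device = CUDA.device(),   # renamed from gpu_idx in this commit
    reset_valid = true,           # discard the previous minimal validation loss
    training_strategy = Collocation(),
)
```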

"""
@@ -69,7 +71,7 @@ Initializes the normalisers based on the given dataset and its metadata.
- Dictionary of each node feature and its normaliser as key-value pair.
- Dictionary of each output feature and its normaliser as key-value pair.
"""
function calc_norms(dataset, norm_steps, device)
function calc_norms(dataset, device)
quantities = 0
n_norms = Dict{String, Union{NormaliserOffline, NormaliserOnline}}()
o_norms = Dict{String, Union{NormaliserOffline, NormaliserOnline}}()
@@ -122,7 +124,7 @@ function calc_norms(dataset, norm_steps, device)
if haskey(dataset.meta["features"][feature], "output_min") && haskey(dataset.meta["features"][feature], "output_max")
o_norms[feature] = NormaliserOfflineMinMax(Float32(dataset.meta["features"][feature]["output_min"]), Float32(dataset.meta["features"][feature]["output_max"]), Float32(dataset.meta["features"][feature]["target_min"]), Float32(dataset.meta["features"][feature]["target_max"]))
else
o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps))
o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps))
end
end
else
@@ -131,7 +133,7 @@ function calc_norms(dataset, norm_steps, device)
if haskey(dataset.meta["features"][feature], "output_min") && haskey(dataset.meta["features"][feature], "output_max")
o_norms[feature] = NormaliserOfflineMinMax(Float32(dataset.meta["features"][feature]["output_min"]), Float32(dataset.meta["features"][feature]["output_max"]))
else
o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps))
o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps))
end
end
end
@@ -141,13 +143,13 @@ function calc_norms(dataset, norm_steps, device)
if haskey(dataset.meta["features"][feature], "output_min") && haskey(dataset.meta["features"][feature], "output_max")
o_norms[feature] = NormaliserOfflineMeanStd(Float32(dataset.meta["features"][feature]["output_mean"]), Float32(dataset.meta["features"][feature]["output_std"]))
else
o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps))
o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps))
end
end
else
n_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps))
n_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps))
if feature in dataset.meta["target_features"]
o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps))
o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps))
end
end
end
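The branches above implement one rule per feature: prefer an offline normaliser whenever the metadata provides bounds (or mean/std), and fall back to an online normaliser capped at `args.max_norm_steps` otherwise. A condensed sketch of the output-normaliser case, with `feature_meta` standing in for one entry of `dataset.meta["features"]`; `NormaliserOfflineMinMax` and `NormaliserOnline` are the in-package types used above, and `args` is passed explicitly here for clarity:

```julia
# Condensed restatement of the selection rule above, not new behaviour.
function output_normaliser(feature_meta, device, args)
    if haskey(feature_meta, "output_min") && haskey(feature_meta, "output_max")
        NormaliserOfflineMinMax(
            Float32(feature_meta["output_min"]),
            Float32(feature_meta["output_max"]))
    else
        NormaliserOnline(feature_meta["dim"], device;
            max_acc = Float32(args.max_norm_steps))
    end
end
```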
@@ -177,15 +179,17 @@ Starts the training process with the given configuration.
- `steps = 10e6`: Number of training steps.
- `checkpoint = 10000`: Number of steps after which checkpoints are created.
- `norm_steps = 1000`: Number of steps before training (accumulate normalization stats).
- `max_norm_steps = 10f6`: Number of steps after which no more normalization stats are collected.
- `types_updated = [0, 5]`: Array containing node types which are updated after each step.
- `types_noisy = [0]`: Array containing node types to which noise is added.
- `training_strategy = Collocation()`: Method used for training. See [documentation](https://una-auxme.github.io/MeshGraphNets.jl/dev/strategies/).
- `use_cuda = true`: Whether a GPU is used for training or not (if available). Currently only CUDA GPUs are supported.
- `gpu_idx = CUDA.deviceid()`: Index of GPU. See *nvidia-smi* for reference.
- `gpu_device = CUDA.device()`: Current CUDA device (aka GPU). See *nvidia-smi* for reference.
- `cell_idxs = [0]`: Indices of cells that are plotted during validation (if enabled).
- `solver_valid = Tsit5()`: Which solver should be used for validation during training.
- `solver_valid_dt = nothing`: If set, the solver for validation will use fixed timesteps.
- `wandb_logger = nothing`: If set, a [Wandb](https://github.com/avik-pal/Wandb.jl) WandbLogger will be used for logging the training.
- `reset_valid = false`: If set, the previous minimal validation loss will be overwritten.
## Training Strategies
- `Collocation`
@@ -197,13 +201,14 @@ See [CylinderFlow Example](https://una-auxme.github.io/MeshGraphNets.jl/dev/cyli
## Returns
- Trained network as a [`GraphNetwork`](@ref) struct.
- Minimum of validation loss (for hyperparameter tuning).
"""
function train_network(noise_stddevs, opt, ds_path, cp_path; kws...)
args = Args(;kws...)

if CUDA.functional() && args.use_cuda
@info "Training on CUDA GPU..."
CUDA.device!(args.gpu_idx)
CUDA.device!(args.gpu_device)
CUDA.allowscalar(false)
device = gpu_device()
else
@@ -220,7 +225,7 @@ function train_network(noise_stddevs, opt, ds_path, cp_path; kws...)

println("Building model...")

quantities, e_norms, n_norms, o_norms = calc_norms(dataset, args.norm_steps, device)
quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device)

dims = dataset.meta["dims"]
outputs = 0
@@ -261,20 +266,23 @@ Initializes the network and performs the training loop.
- `device`: Device where the normaliser should be loaded (see [Lux GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management#gpu-management)).
- `cp_path`: Path where checkpoints are saved.
- `args`: Keyword arguments for configuring the training.
## Returns
- Minimum of validation loss (for hyperparameter tuning).
"""
function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_train, df_valid, device::Function, cp_path, args::Args)
function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_train, df_valid, device, cp_path, args::Args)
checkpoint = length(df_train.step) > 0 ? last(df_train.step) : 0
step = checkpoint
cp_progress = 0
min_validation_loss = length(df_valid.loss) > 0 ? last(df_valid.loss) : Inf32
last_validation_loss = min_validation_loss

if isnothing(args.wandb_logger)
pr = Progress(args.epochs*args.steps; desc = "Training progress: ", dt=1.0, barlen=50, start=checkpoint, showspeed=true)
update!(pr)
if args.reset_valid
min_validation_loss = Inf32
else
pr = nothing
min_validation_loss = length(df_valid.loss) > 0 ? last(df_valid.loss) : Inf32
end
last_validation_loss = min_validation_loss

pr = Progress(args.epochs*args.steps; desc = "Training progress: ", dt=1.0, barlen=50, start=checkpoint, showspeed=true)
update!(pr)

local tmp_loss = 0.0f0
local avg_loss = 0.0f0
@@ -306,15 +314,12 @@ function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_tr
opt_state, ps = Optimisers.update(opt_state, mgn.ps, gs[i])
mgn.ps = ps
end
if isnothing(args.wandb_logger)
next!(pr, showvalues=[(:train_step,"$(step + datapoint)/$(args.epochs*args.steps)"), (:train_loss, sum(losses)), (:checkpoint, length(df_train.step) > 0 ? last(df_train.step) : 0), (:data_interval, delta == 1 ? "1:end" : 1:delta), (:min_validation_loss, min_validation_loss), (:last_validation_loss, last_validation_loss)])
else
Wandb.log(args.wandb_logger, Dict("train_loss" => sum(losses)))
next!(pr, showvalues=[(:train_step,"$(step + datapoint)/$(args.epochs*args.steps)"), (:train_loss, sum(losses)), (:checkpoint, length(df_train.step) > 0 ? last(df_train.step) : 0), (:data_interval, delta == 1 ? "1:end" : 1:delta), (:min_validation_loss, min_validation_loss), (:last_validation_loss, last_validation_loss)])
if !isnothing(args.wandb_logger)
Wandb.log(args.wandb_logger, Dict("train_loss" => l))
end
else
if isnothing(args.wandb_logger)
next!(pr, showvalues=[(:step,"$(step + datapoint)/$(args.epochs*args.steps)"), (:loss,"acc norm stats..."), (:checkpoint, 0)])
end
next!(pr, showvalues=[(:step,"$(step + datapoint)/$(args.epochs*args.steps)"), (:loss,"acc norm stats..."), (:checkpoint, 0)])
end
end
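The parameter update above follows the standard Optimisers.jl pattern: `update` returns both the new optimiser state and the new parameters, and both must be kept. A self-contained sketch of that round trip, with toy parameters standing in for `mgn.ps`:

```julia
using Optimisers, Zygote

# Toy parameters; the MeshGraphNets types are not needed to show the pattern.
ps = (W = rand(Float32, 4, 4), b = zeros(Float32, 4))
opt_state = Optimisers.setup(Optimisers.Adam(1f-4), ps)

x = rand(Float32, 4)
gs = Zygote.gradient(p -> sum(abs2, p.W * x .+ p.b), ps)[1]

# Same shape as `opt_state, ps = Optimisers.update(opt_state, mgn.ps, gs[i])`.
opt_state, ps = Optimisers.update(opt_state, ps, gs)
```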

@@ -329,20 +334,21 @@ function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_tr
valid_error = 0.0f0
gt = nothing
prediction = nothing
pr_valid = Progress(meta["n_trajectories_valid"]; desc = "Validation progress: ", barlen = 50)
pr_valid = Progress(dataset.meta["n_trajectories_valid"]; desc = "Validation progress: ", barlen = 50)

for i in 1:meta["n_trajectories_valid"]
for i in 1:dataset.meta["n_trajectories_valid"]
data_valid, meta_valid = next_trajectory!(dataset, device; types_noisy = args.types_noisy, is_training = false)

node_type_valid, senders_valid, receivers_valid, edge_features_valid = create_base_graph(data_valid, meta["features"]["node_type"]["data_max"], meta["features"]["node_type"]["data_min"], device)
mask = Int32.(findall(x -> x in args.types_updated, data_valid["node_type"][1, :, 1])) |> device
node_type_valid, senders_valid, receivers_valid, edge_features_valid = create_base_graph(data_valid, meta_valid["features"]["node_type"]["data_max"], meta_valid["features"]["node_type"]["data_min"], device)
val_mask_valid = Float32.(map(x -> x in args.types_updated, data_valid["node_type"][:, :, 1]))
val_mask_valid = repeat(val_mask_valid, sum(size(data_valid[field], 1) for field in meta_valid["target_features"]), 1) |> device

inflow_mask_valid = repeat(data["node_type"][:, :, 1] .== 1, sum(size(data[field], 1) for field in meta["target_features"]), 1) |> device
inflow_mask_valid = repeat(data_valid["node_type"][:, :, 1] .== 1, sum(size(data_valid[field], 1) for field in meta_valid["target_features"]), 1) |> device

ve, g, p = validation_step(args.training_strategy, (
mgn, data_valid, meta_valid, delta, args.solver_valid, args.solver_valid_dt, fields, node_type_valid,
edge_features_valid, senders_valid, receivers_valid, mask, val_mask_valid, inflow_mask_valid, data
edge_features_valid, senders_valid, receivers_valid, mask, val_mask_valid, inflow_mask_valid, data_valid
))

valid_error += ve
Expand All @@ -351,19 +357,19 @@ function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_tr
prediction = p
end

next!(pr_valid, showvalues = [(:trajectory, "$i/$(meta["n_trajectories_valid"])"), (:valid_loss, "$((valid_error + ve) / i)")])
next!(pr_valid, showvalues = [(:trajectory, "$i/$(meta_valid["n_trajectories_valid"])"), (:valid_loss, "$((valid_error + ve) / i)")])
end

if !isnothing(args.wandb_logger)
Wandb.log(args.wandb_logger, Dict("validation_loss" => valid_error / meta["n_trajectories_valid"]))
Wandb.log(args.wandb_logger, Dict("validation_loss" => valid_error / dataset.meta["n_trajectories_valid"]))
end

if valid_error / meta["n_trajectories_valid"] < min_validation_loss
save!(mgn, opt_state, df_train, df_valid, step, valid_error / meta["n_trajectories_valid"], joinpath(cp_path, "valid"); is_training = false)
min_validation_loss = valid_error / meta["n_trajectories_valid"]
if valid_error / dataset.meta["n_trajectories_valid"] < min_validation_loss
save!(mgn, opt_state, df_train, df_valid, step, valid_error / dataset.meta["n_trajectories_valid"], joinpath(cp_path, "valid"); is_training = false)
min_validation_loss = valid_error / dataset.meta["n_trajectories_valid"]
cp_progress = args.checkpoint
end
last_validation_loss = valid_error / meta["n_trajectories_valid"]
last_validation_loss = valid_error / dataset.meta["n_trajectories_valid"]
end

if cp_progress >= args.checkpoint
@@ -399,7 +405,7 @@ Starts the evaluation process with the given configuration.
- `hidden_layers = 2`: Number of hidden layers inside MLPs.
- `types_updated = [0, 5]`: Array containing node types which are updated after each step.
- `use_cuda = true`: Whether a GPU is used for evaluation or not (if available). Currently only CUDA GPUs are supported.
- `gpu_idx = CUDA.deviceid()`: Index of GPU. See *nvidia-smi* for reference.
- `gpu_device = CUDA.device()`: Current CUDA device (aka GPU). See *nvidia-smi* for reference.
- `num_rollouts = 10`: Number of trajectories that are simulated (from the test dataset).
- `use_valid = true`: Whether the last checkpoint with the minimal validation loss should be used.
"""
@@ -408,7 +414,7 @@ function eval_network(ds_path, cp_path::String, out_path::String, solver = nothi

if CUDA.functional() && args.use_cuda
@info "Evaluating on CUDA GPU..."
CUDA.device!(args.gpu_idx)
CUDA.device!(args.gpu_device)
CUDA.allowscalar(false)
device = gpu_device()
else
@@ -425,7 +431,7 @@

println("Building model...")

quantities, e_norms, n_norms, o_norms = calc_norms(dataset, args.norm_steps, device)
quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device)

dims = dataset.meta["dims"]
outputs = 0
@@ -499,7 +505,6 @@ function eval_network!(solver, mgn::GraphNetwork, dataset::Dataset, device::Func
error = mean((prediction - vcat([data[field][:, :, 1:length(saves)] for field in meta["target_features"]]...)) .^ 2; dims = 2)
timesteps[(ti, "timesteps")] = sol_t

clear_log(1)
@info "Rollout trajectory $ti completed!"

println("MSE of state prediction:")
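Mirroring the training call, a hedged sketch of an evaluation run using the renamed `gpu_device` argument; all paths are hypothetical:

```julia
using MeshGraphNets, OrdinaryDiffEq, CUDA

# Illustrative call; positional arguments follow
# eval_network(ds_path, cp_path, out_path, solver; kws...).
eval_network(
    "data/cylinder_flow",
    "checkpoints/cylinder_flow",
    "out/cylinder_flow",
    Tsit5();
    gpu_device = CUDA.device(),
    num_rollouts = 10,
    use_valid = true,
)
```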
11 changes: 9 additions & 2 deletions src/dataset.jl
@@ -445,8 +445,9 @@ function preprocess!(data, noise_fields, noise_stddevs, types_noisy, ts, device)
if key == "edges" || length(data[key]) == 1 || size(data[key])[end] == 1
continue
end
if typeof(ts) == Collocation && ts.window_size <= 0
data[key] = data[key][:, :, shuffle(rng, 1:end)]
if typeof(ts) <: CollocationStrategy && ts.random

data[key] = data[key][repeat([:], ndims(data[key])-1)..., shuffle(rng, ts.window_size == 0 ? collect(1:end) : collect(1:ts.window_size))]
end
seed!(rng, seed)
end
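The new indexing expression generalises the shuffle to arrays of any rank: `repeat([:], ndims(data[key]) - 1)...` expands to one colon per leading dimension, so only the trailing (time) axis is permuted, over the whole trajectory when `window_size == 0` and over the first `window_size` steps otherwise. A standalone illustration with an example shape:

```julia
using Random

rng = MersenneTwister(0)
A = reshape(collect(1:24), 2, 3, 4)  # e.g. (features, nodes, timesteps)

window_size = 0  # 0 means: shuffle the entire time axis
idx = shuffle(rng, window_size == 0 ? collect(1:size(A, ndims(A))) :
                                      collect(1:window_size))

# One `:` per leading dimension, shuffled indices on the last axis.
B = A[repeat([:], ndims(A) - 1)..., idx]
```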
@@ -570,6 +571,12 @@ function prepare_trajectory!(data, meta, device::Function; types_noisy, noise_st
if !isnothing(ts) && (typeof(ts) <: CollocationStrategy)
add_targets!(data, meta["target_features"], device)
preprocess!(data, meta["target_features"], noise_stddevs, types_noisy, ts, device)
for field in meta["feature_names"]
if field == "mesh_pos" || field == "node_type" || field == "cells" || field in meta["target_features"]
continue
end
data[field] = device(data[field])
end
else
for field in meta["feature_names"]
if field == "mesh_pos" || field == "node_type" || field == "cells"
