From f0ef58e108c09314b8f4f9f48a83b027e4f27e2f Mon Sep 17 00:00:00 2001 From: Julian Trommer Date: Tue, 16 Jul 2024 09:22:56 +0200 Subject: [PATCH 1/3] Bug fixes & refactoring --- docs/src/index.md | 28 ------------------------- src/MeshGraphNets.jl | 45 +++++++++++++++++---------------------- src/dataset.jl | 11 ++++++++-- src/strategies.jl | 50 ++++---------------------------------------- 4 files changed, 32 insertions(+), 102 deletions(-) delete mode 100644 docs/src/index.md diff --git a/docs/src/index.md b/docs/src/index.md deleted file mode 100644 index ad1e660..0000000 --- a/docs/src/index.md +++ /dev/null @@ -1,28 +0,0 @@ -![MeshGraphNets.jl Logo](https://github.com/una-auxme/MeshGraphNets.jl/blob/main/logo/meshgraphnetsjl_logo.png?raw=true "MeshGraphNets.jl Logo") - -# MeshGraphNets.jl - -[![Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://una-auxme.github.io/MeshGraphNets.jl/stable) - -[*MeshGraphNets.jl*](https://github.com/una-auxme/MeshGraphNets.jl) is a software package for the Julia programming language that provides an implementation of the [MeshGraphNets](https://arxiv.org/abs/2010.03409) framework by [Google DeepMind](https://deepmind.google/) for simulating mesh-based physical systems via graph neural networks: - -> Tobias Pfaff, Meire Fortunato, Alvaro Sanchez-Gonzalez, and Peter W. Battaglia. 2021. **Learning Mesh-Based Simulation with Graph Networks**. In International Conference on Learning Representations. - -## Overwiev - -[*MeshGraphNets.jl*](https://github.com/una-auxme/MeshGraphNets.jl) is designed to be part of the [SciML](https://sciml.ai/) ecosystem. The original framework was remodeled into a NeuralODE so that solvers from the [DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) can be used to evaluate the system. - -## How to use MeshGraphNets.jl - -Examples from the original paper are implemented in the [examples folder](https://github.com/una-auxme/MeshGraphNets.jl/tree/main/examples). You can also refer to the [documentation](https://una-auxme.github.io/MeshGraphNets.jl/stable) if you want to model your own system. - -## Currently supported - -- Customizable input & output quantities -- 1D & 3D meshes (2D meshes coming soon) -- Different strategies for training (see [here](https://una-auxme.github.io/MeshGraphNets.jl/stable)) -- Evaluation of system with [DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) solvers - -## Citation - -Coming soon! \ No newline at end of file diff --git a/src/MeshGraphNets.jl b/src/MeshGraphNets.jl index b7a9e30..688e47d 100644 --- a/src/MeshGraphNets.jl +++ b/src/MeshGraphNets.jl @@ -262,19 +262,15 @@ Initializes the network and performs the training loop. - `cp_path`: Path where checkpoints are saved. - `args`: Keyword arguments for configuring the training. """ -function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_train, df_valid, device::Function, cp_path, args::Args) +function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_train, df_valid, device, cp_path, args::Args) checkpoint = length(df_train.step) > 0 ? last(df_train.step) : 0 step = checkpoint cp_progress = 0 min_validation_loss = length(df_valid.loss) > 0 ? 
last(df_valid.loss) : Inf32
     last_validation_loss = min_validation_loss
 
-    if isnothing(args.wandb_logger)
-        pr = Progress(args.epochs*args.steps; desc = "Training progress: ", dt=1.0, barlen=50, start=checkpoint, showspeed=true)
-        update!(pr)
-    else
-        pr = nothing
-    end
+    pr = Progress(args.epochs*args.steps; desc = "Training progress: ", dt=1.0, barlen=50, start=checkpoint, showspeed=true)
+    update!(pr)
 
     local tmp_loss = 0.0f0
     local avg_loss = 0.0f0
@@ -306,15 +302,12 @@ function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_tr
                 opt_state, ps = Optimisers.update(opt_state, mgn.ps, gs[i])
                 mgn.ps = ps
             end
-            if isnothing(args.wandb_logger)
-                next!(pr, showvalues=[(:train_step,"$(step + datapoint)/$(args.epochs*args.steps)"), (:train_loss, sum(losses)), (:checkpoint, length(df_train.step) > 0 ? last(df_train.step) : 0), (:data_interval, delta == 1 ? "1:end" : 1:delta), (:min_validation_loss, min_validation_loss), (:last_validation_loss, last_validation_loss)])
-            else
-                Wandb.log(args.wandb_logger, Dict("train_loss" => sum(losses)))
+            next!(pr, showvalues=[(:train_step,"$(step + datapoint)/$(args.epochs*args.steps)"), (:train_loss, sum(losses)), (:checkpoint, length(df_train.step) > 0 ? last(df_train.step) : 0), (:data_interval, delta == 1 ? "1:end" : 1:delta), (:min_validation_loss, min_validation_loss), (:last_validation_loss, last_validation_loss)])
+            if !isnothing(args.wandb_logger)
+                Wandb.log(args.wandb_logger, Dict("train_loss" => sum(losses)))
             end
         else
-            if isnothing(args.wandb_logger)
-                next!(pr, showvalues=[(:step,"$(step + datapoint)/$(args.epochs*args.steps)"), (:loss,"acc norm stats..."), (:checkpoint, 0)])
-            end
+            next!(pr, showvalues=[(:step,"$(step + datapoint)/$(args.epochs*args.steps)"), (:loss,"acc norm stats..."), (:checkpoint, 0)])
         end
     end
 
@@ -329,20 +322,21 @@ function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_tr
         valid_error = 0.0f0
         gt = nothing
         prediction = nothing
-        pr_valid = Progress(meta["n_trajectories_valid"]; desc = "Validation progress: ", barlen = 50)
+        pr_valid = Progress(dataset.meta["n_trajectories_valid"]; desc = "Validation progress: ", barlen = 50)
 
-        for i in 1:meta["n_trajectories_valid"]
+        for i in 1:dataset.meta["n_trajectories_valid"]
             data_valid, meta_valid = next_trajectory!(dataset, device; types_noisy = args.types_noisy, is_training = false)
-            node_type_valid, senders_valid, receivers_valid, edge_features_valid = create_base_graph(data_valid, meta["features"]["node_type"]["data_max"], meta["features"]["node_type"]["data_min"], device)
+            mask = Int32.(findall(x -> x in args.types_updated, data_valid["node_type"][1, :, 1])) |> device
+            node_type_valid, senders_valid, receivers_valid, edge_features_valid = create_base_graph(data_valid, meta_valid["features"]["node_type"]["data_max"], meta_valid["features"]["node_type"]["data_min"], device)
             val_mask_valid = Float32.(map(x -> x in args.types_updated, data_valid["node_type"][:, :, 1]))
             val_mask_valid = repeat(val_mask_valid, sum(size(data_valid[field], 1) for field in meta_valid["target_features"]), 1) |> device
-            inflow_mask_valid = repeat(data["node_type"][:, :, 1] .== 1, sum(size(data[field], 1) for field in meta["target_features"]), 1) |> device
+            inflow_mask_valid = repeat(data_valid["node_type"][:, :, 1] .== 1, sum(size(data_valid[field], 1) for field in meta_valid["target_features"]), 1) |> device
 
             ve, g, p = validation_step(args.training_strategy, (
                 mgn, data_valid, meta_valid, delta, args.solver_valid,
                 args.solver_valid_dt, fields, node_type_valid,
-                edge_features_valid, senders_valid, receivers_valid, mask, val_mask_valid, inflow_mask_valid, data
+                edge_features_valid, senders_valid, receivers_valid, mask, val_mask_valid, inflow_mask_valid, data_valid
             ))
 
             valid_error += ve
@@ -351,19 +345,19 @@ function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_tr
                 prediction = p
             end
 
-            next!(pr_valid, showvalues = [(:trajectory, "$i/$(meta["n_trajectories_valid"])"), (:valid_loss, "$((valid_error + ve) / i)")])
+            next!(pr_valid, showvalues = [(:trajectory, "$i/$(dataset.meta["n_trajectories_valid"])"), (:valid_loss, "$(valid_error / i)")])
         end
 
         if !isnothing(args.wandb_logger)
-            Wandb.log(args.wandb_logger, Dict("validation_loss" => valid_error / meta["n_trajectories_valid"]))
+            Wandb.log(args.wandb_logger, Dict("validation_loss" => valid_error / dataset.meta["n_trajectories_valid"]))
        end
 
-        if valid_error / meta["n_trajectories_valid"] < min_validation_loss
-            save!(mgn, opt_state, df_train, df_valid, step, valid_error / meta["n_trajectories_valid"], joinpath(cp_path, "valid"); is_training = false)
-            min_validation_loss = valid_error / meta["n_trajectories_valid"]
+        if valid_error / dataset.meta["n_trajectories_valid"] < min_validation_loss
+            save!(mgn, opt_state, df_train, df_valid, step, valid_error / dataset.meta["n_trajectories_valid"], joinpath(cp_path, "valid"); is_training = false)
+            min_validation_loss = valid_error / dataset.meta["n_trajectories_valid"]
             cp_progress = args.checkpoint
         end
-        last_validation_loss = valid_error / meta["n_trajectories_valid"]
+        last_validation_loss = valid_error / dataset.meta["n_trajectories_valid"]
     end
 
     if cp_progress >= args.checkpoint
@@ -499,7 +493,6 @@ function eval_network!(solver, mgn::GraphNetwork, dataset::Dataset, device::Func
         error = mean((prediction - vcat([data[field][:, :, 1:length(saves)] for field in meta["target_features"]]...)) .^ 2; dims = 2)
         timesteps[(ti, "timesteps")] = sol_t
 
-        clear_log(1)
         @info "Rollout trajectory $ti completed!"
 
         println("MSE of state prediction:")
diff --git a/src/dataset.jl b/src/dataset.jl
index 4b37aeb..444126e 100644
--- a/src/dataset.jl
+++ b/src/dataset.jl
@@ -445,8 +445,9 @@ function preprocess!(data, noise_fields, noise_stddevs, types_noisy, ts, device)
         if key == "edges" || length(data[key]) == 1 || size(data[key])[end] == 1
             continue
         end
-        if typeof(ts) == Collocation && ts.window_size <= 0
-            data[key] = data[key][:, :, shuffle(rng, 1:end)]
+        if typeof(ts) <: CollocationStrategy && ts.random
+
+            data[key] = data[key][repeat([:], ndims(data[key])-1)..., shuffle(rng, ts.window_size == 0 ? collect(1:end) : collect(1:ts.window_size))]
         end
         seed!(rng, seed)
     end
@@ -570,6 +571,12 @@ function prepare_trajectory!(data, meta, device::Function; types_noisy, noise_st
     if !isnothing(ts) && (typeof(ts) <: CollocationStrategy)
         add_targets!(data, meta["target_features"], device)
         preprocess!(data, meta["target_features"], noise_stddevs, types_noisy, ts, device)
+        for field in meta["feature_names"]
+            if field == "mesh_pos" || field == "node_type" || field == "cells" || field in meta["target_features"]
+                continue
+            end
+            data[field] = device(data[field])
+        end
     else
         for field in meta["feature_names"]
             if field == "mesh_pos" || field == "node_type" || field == "cells"
diff --git a/src/strategies.jl b/src/strategies.jl
index 306c39b..b211e80 100644
--- a/src/strategies.jl
+++ b/src/strategies.jl
@@ -303,7 +303,7 @@ struct MultipleShooting <: SolverStrategy
     solargs
 end
 
-function MultipleShooting(tstart::Float32, dt::Float32, tstop::Float32, solver::OrdinaryDiffEqAlgorithm, interval_size, continuity_term = 100; sense::AbstractSensitivityAlgorithm = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)
+function MultipleShooting(tstart::Float32, dt::Float32, tstop::Float32, solver::OrdinaryDiffEqAlgorithm; sense::AbstractSensitivityAlgorithm = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), interval_size, continuity_term = 100, solargs...)
     MultipleShooting(tstart, dt, tstop, solver, sense, interval_size, continuity_term, solargs)
 end
 
@@ -407,50 +407,8 @@ Useful for initial training of the system since it is faster than training with
 """
 struct Collocation <: CollocationStrategy
     window_size::Integer
+    random::Bool
 end
-
-function Collocation(; window_size::Integer = 0)
-    Collocation(window_size)
-end
-
-
-
-"""
-    RandomCollocation(; window_size = 0)
-
-Similar to Collocation, but timesteps are sampled randomly from the trajectory instead of sequentially.
-
-## Keyword Arguments
-- `window_size = 0`: Number of steps from each trajectory (starting at the beginning) that are used for training. If the number is zero then the whole trajectory is used.
-"""
-struct RandomCollocation <: CollocationStrategy
-    window_size::Integer
-end
-
-function RandomCollocation(; window_size::Integer = 0)
-    RandomCollocation(window_size)
-end
-
-function prepare_training(strategy::RandomCollocation)
-    samples = shuffle(1:strategy.window_size)
-
-    return (samples,)
-end
-
-function init_train_step(::RandomCollocation, t::Tuple, ta::Tuple)
-
-    mgn, data, meta, fields, target_fields, node_type, edge_features, senders, receivers, datapoint, mask, _ = t
-
-    sample = ta[1][datapoint]
-
-    cur_quantities = vcat([data[field][:, :, sample] for field in target_fields]...)
-    target_quantities = vcat([data["target|" * field][:, :, sample] for field in target_fields]...)
-    if typeof(meta["dt"]) <: AbstractArray
-        target_quantities_change = mgn.o_norm((target_quantities - cur_quantities) / (meta["dt"][sample + 1] - meta["dt"][sample]))
-    else
-        target_quantities_change = mgn.o_norm((target_quantities - cur_quantities) / Float32(meta["dt"]))
-    end
-    graph = build_graph(mgn, data, fields, sample, node_type, edge_features, senders, receivers)
-
-    return (mgn, graph, target_quantities_change, mask)
+function Collocation(; window_size::Integer = 0, random::Bool = true)
+    Collocation(window_size, random)
 end

From c53d812fa767c0db2b03121d9bfba8fd3cfe3ea4 Mon Sep 17 00:00:00 2001
From: Julian Trommer
Date: Tue, 16 Jul 2024 09:53:58 +0200
Subject: [PATCH 2/3] Added new arguments & refactoring

---
 src/MeshGraphNets.jl | 40 ++++++++++++++++++++++++++--------------
 1 file changed, 26 insertions(+), 14 deletions(-)

diff --git a/src/MeshGraphNets.jl b/src/MeshGraphNets.jl
index 688e47d..d831989 100644
--- a/src/MeshGraphNets.jl
+++ b/src/MeshGraphNets.jl
@@ -41,17 +41,19 @@ export train_network, eval_network, der_minmax, data_meanstd
     steps::Integer = 10e6
     checkpoint::Integer = 10000
     norm_steps::Integer = 1000
+    max_norm_steps::Integer = 10f6
     types_updated::Vector{Integer} = [0, 5]
     types_noisy::Vector{Integer} = [0]
     training_strategy::TrainingStrategy = Collocation()
     use_cuda::Bool = true
-    gpu_idx::Integer = CUDA.deviceid()
+    gpu_device::CUDA.CuDevice = CUDA.device()
     cell_idxs::Vector{Integer} = [0]
     num_rollouts::Integer = 10
     use_valid::Bool = true
     solver_valid::OrdinaryDiffEqAlgorithm = Tsit5()
     solver_valid_dt::Union{Nothing, Float32} = nothing
     wandb_logger::Union{Nothing, Wandb.WandbLogger} = nothing
+    reset_valid::Bool = false
 end
 
 """
@@ -69,7 +71,7 @@ Initializes the normalisers based on the given dataset and its metadata.
 - Dictionary of each node feature and its normaliser as key-value pair.
 - Dictionary of each output feature and its normaliser as key-value pair.
""" -function calc_norms(dataset, norm_steps, device) +function calc_norms(dataset, device) quantities = 0 n_norms = Dict{String, Union{NormaliserOffline, NormaliserOnline}}() o_norms = Dict{String, Union{NormaliserOffline, NormaliserOnline}}() @@ -122,7 +124,7 @@ function calc_norms(dataset, norm_steps, device) if haskey(dataset.meta["features"][feature], "output_min") && haskey(dataset.meta["features"][feature], "output_max") o_norms[feature] = NormaliserOfflineMinMax(Float32(dataset.meta["features"][feature]["output_min"]), Float32(dataset.meta["features"][feature]["output_max"]), Float32(dataset.meta["features"][feature]["target_min"]), Float32(dataset.meta["features"][feature]["target_max"])) else - o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps)) + o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps)) end end else @@ -131,7 +133,7 @@ function calc_norms(dataset, norm_steps, device) if haskey(dataset.meta["features"][feature], "output_min") && haskey(dataset.meta["features"][feature], "output_max") o_norms[feature] = NormaliserOfflineMinMax(Float32(dataset.meta["features"][feature]["output_min"]), Float32(dataset.meta["features"][feature]["output_max"])) else - o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps)) + o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps)) end end end @@ -141,13 +143,13 @@ function calc_norms(dataset, norm_steps, device) if haskey(dataset.meta["features"][feature], "output_min") && haskey(dataset.meta["features"][feature], "output_max") o_norms[feature] = NormaliserOfflineMeanStd(Float32(dataset.meta["features"][feature]["output_mean"]), Float32(dataset.meta["features"][feature]["output_std"])) else - o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps)) + o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps)) end end else - n_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps)) + n_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps)) if feature in dataset.meta["target_features"] - o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(norm_steps)) + o_norms[feature] = NormaliserOnline(dataset.meta["features"][feature]["dim"], device; max_acc = Float32(args.max_norm_steps)) end end end @@ -177,15 +179,17 @@ Starts the training process with the given configuration. - `steps = 10e6`: Number of training steps. - `checkpoint = 10000`: Number of steps after which checkpoints are created. - `norm_steps = 1000`: Number of steps before training (accumulate normalization stats). +- `max_norm_steps = 10f6`: Number of steps after which no more normalization stats are collected. - `types_updated = [0, 5]`: Array containing node types which are updated after each step. - `types_noisy = [0]`: Array containing node types which noise is added to. - `training_strategy = Collocation()`: Methods used for training. See [documentation](https://una-auxme.github.io/MeshGraphNets.jl/dev/strategies/). - `use_cuda = true`: Whether a GPU is used for training or not (if available). 
-- `gpu_idx = CUDA.deviceid()`: Index of GPU. See *nvidia-smi* for reference.
+- `gpu_device = CUDA.device()`: CUDA device (i.e. GPU) that is used for training. See *nvidia-smi* for reference.
 - `cell_idxs = [0]`: Indices of cells that are plotted during validation (if enabled).
 - `solver_valid = Tsit5()`: Which solver should be used for validation during training.
 - `solver_valid_dt = nothing`: If set, the solver for validation will use fixed timesteps.
 - `wandb_logger = nothing`: If set, a [Wandb](https://github.com/avik-pal/Wandb.jl) WandbLogger will be used for logging the training.
+- `reset_valid = false`: If set, the previous minimal validation loss will be overwritten.
 
 ## Training Strategies
 - `Collocation`
@@ -197,13 +201,14 @@ See [CylinderFlow Example](https://una-auxme.github.io/MeshGraphNets.jl/dev/cyli
 
 ## Returns
 - Trained network as a [`GraphNetwork`](@ref) struct.
+- Minimum validation loss (useful for hyperparameter tuning).
 """
 function train_network(noise_stddevs, opt, ds_path, cp_path; kws...)
     args = Args(;kws...)
 
     if CUDA.functional() && args.use_cuda
         @info "Training on CUDA GPU..."
-        CUDA.device!(args.gpu_idx)
+        CUDA.device!(args.gpu_device)
         CUDA.allowscalar(false)
         device = gpu_device()
     else
@@ -220,7 +225,7 @@ function train_network(noise_stddevs, opt, ds_path, cp_path; kws...)
 
     println("Building model...")
 
-    quantities, e_norms, n_norms, o_norms = calc_norms(dataset, args.norm_steps, device)
+    quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device, args)
 
     dims = dataset.meta["dims"]
     outputs = 0
@@ -261,12 +266,19 @@ Initializes the network and performs the training loop.
 - `device`: Device where the normaliser should be loaded (see [Lux GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management#gpu-management)).
 - `cp_path`: Path where checkpoints are saved.
 - `args`: Keyword arguments for configuring the training.
+
+## Returns
+- Minimum validation loss (useful for hyperparameter tuning).
 """
 function train_mgn!(mgn::GraphNetwork, opt_state, dataset::Dataset, noise, df_train, df_valid, device, cp_path, args::Args)
     checkpoint = length(df_train.step) > 0 ? last(df_train.step) : 0
     step = checkpoint
     cp_progress = 0
-    min_validation_loss = length(df_valid.loss) > 0 ? last(df_valid.loss) : Inf32
+    if args.reset_valid
+        min_validation_loss = Inf32
+    else
+        min_validation_loss = length(df_valid.loss) > 0 ? last(df_valid.loss) : Inf32
+    end
     last_validation_loss = min_validation_loss
 
     pr = Progress(args.epochs*args.steps; desc = "Training progress: ", dt=1.0, barlen=50, start=checkpoint, showspeed=true)
@@ -393,7 +405,7 @@ Starts the evaluation process with the given configuration.
 - `hidden_layers = 2`: Number of hidden layers inside MLPs.
 - `types_updated = [0, 5]`: Array containing node types which are updated after each step.
 - `use_cuda = true`: Whether a GPU is used for evaluation or not (if available). Currently only CUDA GPUs are supported.
-- `gpu_idx = CUDA.deviceid()`: Index of GPU. See *nvidia-smi* for reference.
+- `gpu_device = CUDA.device()`: CUDA device (i.e. GPU) that is used for evaluation. See *nvidia-smi* for reference.
 - `num_rollouts = 10`: Number of trajectories that are simulated (from the test dataset).
 - `use_valid = true`: Whether the last checkpoint with the minimal validation loss should be used.
 """
@@ -402,7 +414,7 @@ function eval_network(ds_path, cp_path::String, out_path::String, solver = nothi
 
     if CUDA.functional() && args.use_cuda
         @info "Evaluating on CUDA GPU..."
-        CUDA.device!(args.gpu_idx)
+        CUDA.device!(args.gpu_device)
         CUDA.allowscalar(false)
         device = gpu_device()
     else
@@ -419,7 +431,7 @@ function eval_network(ds_path, cp_path::String, out_path::String, solver = nothi
 
     println("Building model...")
 
-    quantities, e_norms, n_norms, o_norms = calc_norms(dataset, args.norm_steps, device)
+    quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device, args)
 
     dims = dataset.meta["dims"]
     outputs = 0

From 7f6897c3a25edf8616d0fb881864041422e47137 Mon Sep 17 00:00:00 2001
From: Julian Trommer
Date: Tue, 16 Jul 2024 10:16:41 +0200
Subject: [PATCH 3/3] Removed RandomCollocation strategy

---
 docs/src/strategies.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/docs/src/strategies.md b/docs/src/strategies.md
index 27ce700..851cc54 100644
--- a/docs/src/strategies.md
+++ b/docs/src/strategies.md
@@ -2,7 +2,6 @@
 
 ```@docs
 Collocation
-RandomCollocation
 SingleShooting
 MultipleShooting
 ```
\ No newline at end of file
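
Usage note (a sketch, not part of the patches): the snippet below shows how the API looks after this series. It assumes the two documented return values of `train_network` and an `Optimisers.Adam` optimiser; the dataset/checkpoint paths, noise level, and learning rate are illustrative placeholders, not values from the repository.

```julia
using MeshGraphNets, Optimisers
import CUDA

# Collocation now subsumes the removed RandomCollocation strategy:
# `random = true` shuffles the sampled timesteps, `window_size = 0`
# means the whole trajectory is used.
strategy = Collocation(window_size = 100, random = true)

mgn, min_valid_loss = train_network(
    [1.0f-2],                    # noise_stddevs (illustrative)
    Optimisers.Adam(1.0f-4),     # opt (illustrative learning rate)
    "data/CylinderFlow",         # ds_path (placeholder)
    "checkpoints/cylinder";      # cp_path (placeholder)
    training_strategy = strategy,
    max_norm_steps = 10f6,       # new: stop accumulating normalization stats after this step
    gpu_device = CUDA.device(),  # new: replaces the old gpu_idx argument
    reset_valid = true           # new: discard the previously saved minimal validation loss
)
```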
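The `MultipleShooting` change in patch 1 moves `interval_size` and `continuity_term` behind the semicolon, so callers now pass them as keyword arguments after the positional `tstart`, `dt`, `tstop` and solver. A minimal sketch (solver and values illustrative, assuming `Tsit5` from OrdinaryDiffEq as used elsewhere in the package):

```julia
using MeshGraphNets
using OrdinaryDiffEq: Tsit5  # assumed source of the solver type

# interval_size is now a required keyword; continuity_term keeps its default of 100.
ms = MultipleShooting(0.0f0, 1.0f-2, 1.0f0, Tsit5(); interval_size = 10)
```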