From 499a502cf802e38aa2d009dfd66cbd36d9f21567 Mon Sep 17 00:00:00 2001
From: Julian Trommer
Date: Tue, 16 Jul 2024 10:42:14 +0200
Subject: [PATCH 1/2] Reworked training strategies

---
 docs/src/cylinder_flow.md               |  8 ++---
 docs/src/strategies.md                  |  4 +--
 examples/cylinder_flow/cylinder_flow.jl |  8 ++---
 src/MeshGraphNets.jl                    | 13 ++++----
 src/dataset.jl                          | 11 ++++---
 src/strategies.jl                       | 40 ++++++++++++-------------
 6 files changed, 41 insertions(+), 43 deletions(-)

diff --git a/docs/src/cylinder_flow.md b/docs/src/cylinder_flow.md
index e158663..13155e1 100644
--- a/docs/src/cylinder_flow.md
+++ b/docs/src/cylinder_flow.md
@@ -32,20 +32,20 @@ In order to train the system you can simply comment in/out the lines of code pro
 # Train network #
 #################
 
-# with Collocation
+# with DerivativeTraining
 train_network(
     noise, opt, ds_path, chk_path;
     mps = message_steps, layer_size = layer_size, hidden_layers = hidden_layers,
     batchsize = batch, epochs = epo, steps = Int(ns), use_cuda = cuda,
     checkpoint = cp, norm_steps = 1000, types_updated = types_updated,
-    types_noisy = types_noisy, training_strategy = Collocation(), solver_valid = Euler(), solver_valid_dt = 0.01f0
+    types_noisy = types_noisy, training_strategy = DerivativeTraining(), solver_valid = Euler(), solver_valid_dt = 0.01f0
 )
 
-# with SingleShooting
+# with SolverTraining
 train_network(
     noise, opt, ds_path, chk_path;
     mps = message_steps, layer_size = layer_size, hidden_layers = hidden_layers,
     batchsize = batch, epochs = epo, steps = Int(ns), use_cuda = cuda, checkpoint = 10, norm_steps = 1000, types_updated = types_updated, types_noisy = types_noisy,
-    training_strategy = SingleShooting(0.0f0, 0.01f0, 5.99f0, Euler(); adaptive = false, tstops = 0.0f0:0.01f0:5.99f0)
+    training_strategy = SolverTraining(0.0f0, 0.01f0, 5.99f0, Euler(); adaptive = false, tstops = 0.0f0:0.01f0:5.99f0)
 )
 ```
diff --git a/docs/src/strategies.md b/docs/src/strategies.md
index 851cc54..5ec161f 100644
--- a/docs/src/strategies.md
+++ b/docs/src/strategies.md
@@ -1,7 +1,7 @@
 # Training Strategies
 
 ```@docs
-Collocation
-SingleShooting
+DerivativeTraining
+SolverTraining
 MultipleShooting
 ```
\ No newline at end of file
diff --git a/examples/cylinder_flow/cylinder_flow.jl b/examples/cylinder_flow/cylinder_flow.jl
index 9abdeb6..dbcae7b 100644
--- a/examples/cylinder_flow/cylinder_flow.jl
+++ b/examples/cylinder_flow/cylinder_flow.jl
@@ -50,20 +50,20 @@ eval_path = "data/CylinderFlow/eval"
 # Train network #
 #################
 
-# with Collocation
+# with DerivativeTraining
 train_network(
     noise, opt, ds_path, chk_path;
     mps = message_steps, layer_size = layer_size, hidden_layers = hidden_layers,
     batchsize = batch, epochs = epo, steps = Int(ns), use_cuda = cuda,
     checkpoint = cp, norm_steps = 1000, types_updated = types_updated,
-    types_noisy = types_noisy, training_strategy = Collocation(), solver_valid = Euler(), solver_valid_dt = 0.01f0
+    types_noisy = types_noisy, training_strategy = DerivativeTraining(), solver_valid = Euler(), solver_valid_dt = 0.01f0
 )
 
-# with SingleShooting
+# with SolverTraining
 train_network(
     noise, opt, ds_path, chk_path;
     mps = message_steps, layer_size = layer_size, hidden_layers = hidden_layers,
     batchsize = batch, epochs = epo, steps = Int(ns), use_cuda = cuda, checkpoint = 10, norm_steps = 1000, types_updated = types_updated, types_noisy = types_noisy,
-    training_strategy = SingleShooting(0.0f0, 0.01f0, 5.99f0, Euler(); adaptive = false, tstops = 0.0f0:0.01f0:5.99f0)
+    training_strategy = SolverTraining(0.0f0, 0.01f0, 5.99f0, Euler(); adaptive = false, tstops = 0.0f0:0.01f0:5.99f0)
 )
 
diff --git a/src/MeshGraphNets.jl b/src/MeshGraphNets.jl
index 7c121d9..84788a4 100644
--- a/src/MeshGraphNets.jl
+++ b/src/MeshGraphNets.jl
@@ -28,7 +28,7 @@ include("graph.jl")
 include("solve.jl")
 include("dataset.jl")
 
-export SingleShooting, MultipleShooting, RandomCollocation, Collocation
+export SolverTraining, MultipleShooting, DerivativeTraining
 
 export train_network, eval_network, der_minmax, data_meanstd
 
@@ -44,7 +44,7 @@ export train_network, eval_network, der_minmax, data_meanstd
     max_norm_steps::Integer = 10f6
     types_updated::Vector{Integer} = [0, 5]
     types_noisy::Vector{Integer} = [0]
-    training_strategy::TrainingStrategy = Collocation()
+    training_strategy::TrainingStrategy = DerivativeTraining()
     use_cuda::Bool = true
     gpu_device::Integer = CUDA.device()
     cell_idxs::Vector{Integer} = [0]
@@ -182,7 +182,7 @@ Starts the training process with the given configuration.
 - `max_norm_steps = 10f6`: Number of steps after which no more normalization stats are collected.
 - `types_updated = [0, 5]`: Array containing node types which are updated after each step.
 - `types_noisy = [0]`: Array containing node types which noise is added to.
-- `training_strategy = Collocation()`: Methods used for training. See [documentation](https://una-auxme.github.io/MeshGraphNets.jl/dev/strategies/).
+- `training_strategy = DerivativeTraining()`: Methods used for training. See [documentation](https://una-auxme.github.io/MeshGraphNets.jl/dev/strategies/).
 - `use_cuda = true`: Whether a GPU is used for training or not (if available). Currently only CUDA GPUs are supported.
 - `gpu_device = CUDA.device()`: Current CUDA device (aka GPU). See *nvidia-smi* for reference.
 - `cell_idxs = [0]`: Indices of cells that are plotted during validation (if enabled).
@@ -192,9 +192,8 @@ Starts the training process with the given configuration.
 - `reset_valid = false`: If set, the previous minimal validation loss will be overwritten.
 
 ## Training Strategies
-- `Collocation`
-- `RandomCollocation`
-- `SingleShooting`
+- `DerivativeTraining`
+- `SolverTraining`
 - `MultipleShooting`
 
 See [CylinderFlow Example](https://una-auxme.github.io/MeshGraphNets.jl/dev/cylinder_flow) for reference.
@@ -520,7 +519,7 @@ function eval_network!(solver, mgn::GraphNetwork, dataset::Dataset, device::Func
         errors[(ti, "error")] = cpu_device()(error[:, 1, :])
     end
 
-    eval_path = joinpath(out_path, isnothing(solver) ? "collocation" : lowercase("$(nameof(typeof(solver)))"))
+    eval_path = joinpath(out_path, isnothing(solver) ? "derivative_training" : lowercase("$(nameof(typeof(solver)))"))
     mkpath(eval_path)
     h5open(joinpath(eval_path, "trajectories.h5"), "w") do f
         for i in 1:maximum(getfield.(keys(traj_ops), 1))
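As a usage note on the renamed exports above: a minimal sketch of how the two strategies are constructed after this patch. The constructor signatures are taken from the `src/strategies.jl` hunks below; the `using OrdinaryDiffEq: Euler` import is an assumption about where the `Euler()` solver in the examples comes from.

```julia
using MeshGraphNets
using OrdinaryDiffEq: Euler  # assumed source of the Euler() used in the examples

# Derivative training: fits the network to finite-difference targets taken
# directly from the data, so no ODE solve runs inside the training loop.
strategy_fast = DerivativeTraining()  # whole trajectory, shuffled datapoints

# Solver training: rolls the system out with an ODE solver and compares the
# solution against the ground truth at the timesteps tstart:dt:tstop.
strategy_solver = SolverTraining(
    0.0f0, 0.01f0, 5.99f0, Euler();
    adaptive = false, tstops = 0.0f0:0.01f0:5.99f0
)
```

Per the docstrings in this patch, `DerivativeTraining` is the cheap starting point and `SolverTraining` the NeuralODE-style refinement.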
diff --git a/src/dataset.jl b/src/dataset.jl
index b59b4e2..e73d92e 100644
--- a/src/dataset.jl
+++ b/src/dataset.jl
@@ -423,7 +423,7 @@ end
 """
     add_targets!(data, fields, device)
 
-Shifts the datapoints beginning from second index back in order to use them as ground truth data (used for collocation strategies).
+Shifts the datapoints, beginning from the second index, back by one in order to use them as ground truth data (used for derivative-based strategies).
 
 ## Arguments
 - `data`: Data from the dataset containing one trajectory.
@@ -455,7 +455,7 @@
 """
     preprocess!(data, noise_fields, noise_stddevs, types_noisy, ts, device)
 
-Adds noise to the given features and shuffles the datapoints if a collocation strategy is used.
+Adds noise to the given features and shuffles the datapoints if a derivative-based strategy is used.
 
 ## Arguments
 - `data`: Data from the dataset containing one trajectory.
@@ -485,8 +485,7 @@ function preprocess!(data, noise_fields, noise_stddevs, types_noisy, ts, device)
         if key == "edges" || length(data[key]) == 1 || size(data[key])[end] == 1
             continue
         end
-        if typeof(ts) <: CollocationStrategy && ts.random
-
+        if typeof(ts) <: DerivativeStrategy && ts.random
             data[key] = data[key][repeat([:], ndims(data[key])-1)..., shuffle(rng, ts.window_size == 0 ? collect(1:end) : collect(1:ts.window_size))]
         end
         seed!(rng, seed)
@@ -591,7 +590,7 @@ end
 """
     prepare_trajectory!(data, meta, device; types_noisy, noise_stddevs, ts)
 
-Transfers the data to the given device and configures the data if a collocation strategy is used.
+Transfers the data to the given device and configures the data if a derivative-based strategy is used.
 
 ## Arguments
 - `data`: Data from the dataset containing one trajectory.
@@ -608,7 +607,7 @@ Transfers the data to the given device and configures the data if a collocation
 - Metadata of the dataset.
 """
 function prepare_trajectory!(data, meta, device::Function; types_noisy, noise_stddevs, ts)
-    if !isnothing(ts) && (typeof(ts) <: CollocationStrategy)
+    if !isnothing(ts) && (typeof(ts) <: DerivativeStrategy)
         add_targets!(data, meta["target_features"], device)
         preprocess!(data, meta["target_features"], noise_stddevs, types_noisy, ts, device)
         for field in meta["feature_names"]
diff --git a/src/strategies.jl b/src/strategies.jl
index b211e80..b58084d 100644
--- a/src/strategies.jl
+++ b/src/strategies.jl
@@ -35,7 +35,7 @@ Returns the delta between samples in the training data.
 
 ## Arguments
 - `strategy`: Used training strategy.
-- Trajectory length (used for collocation strategies).
+- Trajectory length (used for derivative-based strategies).
 
 ## Returns
 - Delta between samples in the training data.
@@ -213,7 +213,7 @@ end
 
 """
-    SingleShooting(tstart, dt, tstop, solver; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()), solargs...)
+    SolverTraining(tstart, dt, tstop, solver; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()), solargs...)
 
 The default solver-based training that is normally used for NeuralODEs.
 Simulates the system from `tstart` to `tstop` and calculates the loss based on the difference between the prediction and the ground truth at the timesteps `tstart:dt:tstop`.
@@ -228,7 +228,7 @@ Simulates the system from `tstart` to `tstop` and calculates the loss based on t
 - `sense = InterpolatingAdjoint(autojacvec = ZygoteVJP())`: The sensitivity algorithm that is used for calculating the sensitivities.
 - `solargs`: Keyword arguments that are passed on to the solver.
 """
-struct SingleShooting <: SolverStrategy
+struct SolverTraining <: SolverStrategy
     tstart::Float32
     dt::Float32
     tstop::Float32
@@ -237,11 +237,11 @@ struct SingleShooting <: SolverStrategy
     solargs
 end
 
-function SingleShooting(tstart::Float32, dt::Float32, tstop::Float32, solver::OrdinaryDiffEqAlgorithm; sense::AbstractSensitivityAlgorithm = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)
-    SingleShooting(tstart, dt, tstop, solver, sense, solargs)
+function SolverTraining(tstart::Float32, dt::Float32, tstop::Float32, solver::OrdinaryDiffEqAlgorithm; sense::AbstractSensitivityAlgorithm = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)
+    SolverTraining(tstart, dt, tstop, solver, sense, solargs)
 end
 
-function train_loss(strategy::SingleShooting, t::Tuple)
+function train_loss(strategy::SolverTraining, t::Tuple)
 
     prob, ps, u0, callback_solve, gt, val_mask, n_norm, target_fields, target_dims = t
@@ -277,8 +277,8 @@ end
 """
     MultipleShooting(tstart, dt, tstop, solver, interval_size, continuity_term = 100; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)
 
-Similar to SingleShooting, but splits the trajectory into intervals that are solved independently and then combines them for loss calculation.
-Useful if the network tends to get stuck in a local minimum if SingleShooting is used.
+Similar to SolverTraining, but splits the trajectory into intervals that are solved independently and then combines them for loss calculation.
+Useful if the network tends to get stuck in a local minimum when SolverTraining is used.
 
 ## Arguments
 - `tstart`: Start time of the simulation.
@@ -356,17 +356,17 @@ function train_loss(strategy::MultipleShooting, t::Tuple)
     return loss
 end
 
-##########################################################################
-# Abstract type and functions for collocation based training strategies #
-##########################################################################
+##########################################################################
+# Abstract type and functions for derivative-based training strategies  #
+##########################################################################
 
-abstract type CollocationStrategy <: TrainingStrategy end
+abstract type DerivativeStrategy <: TrainingStrategy end
 
-function get_delta(strategy::CollocationStrategy, trajectory_length::Integer)
+function get_delta(strategy::DerivativeStrategy, trajectory_length::Integer)
     return strategy.window_size > 0 ? strategy.window_size : trajectory_length - 1
 end
 
-function init_train_step(::CollocationStrategy, t::Tuple, ::Tuple)
+function init_train_step(::DerivativeStrategy, t::Tuple, ::Tuple)
 
     mgn, data, meta, fields, target_fields, node_type, edge_features, senders, receivers, datapoint, mask, _ = t
@@ -380,14 +380,14 @@
     return (mgn, graph, target_quantities_change, mask)
 end
 
-function train_step(::CollocationStrategy, t::Tuple)
+function train_step(::DerivativeStrategy, t::Tuple)
 
     mgn, graph, target_quantities_change, mask = t
 
     return step!(mgn, graph, target_quantities_change, mask, mse_reduce)
 end
 
-function validation_step(::CollocationStrategy, t::Tuple)
+function validation_step(::DerivativeStrategy, t::Tuple)
 
     sim_interval = t[2]["dt"][1]:t[2]["dt"][2]-t[2]["dt"][1]:t[2]["dt"][t[4]]
     data_interval = 1:t[4]
@@ -397,7 +397,7 @@
 
 """
-    Collocation(; window_size = 0)
+    DerivativeTraining(; window_size = 0)
 
 Compares the prediction of the system with the derivative from the data (via finite differences).
 Useful for initial training of the system since it is faster than training with a solver.
 
 ## Keyword Arguments
 - `window_size = 0`: Number of steps from each trajectory (starting at the beginning) that are used for training. If the number is zero then the whole trajectory is used.
 """
-struct Collocation <: CollocationStrategy
+struct DerivativeTraining <: DerivativeStrategy
     window_size::Integer
     random::Bool
 end
 
-function Collocation(;window_size::Integer = 0, random = true)
-    Collocation(window_size, random)
+function DerivativeTraining(;window_size::Integer = 0, random = true)
+    DerivativeTraining(window_size, random)
 end
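One note on the renamed strategy: the `DerivativeTraining` docstring above lists only `window_size`, but the constructor also accepts `random = true`, which controls the datapoint shuffling performed in `preprocess!`. A minimal sketch of the `get_delta` window semantics, using simplified stand-in types rather than the package's actual definitions:

```julia
# Simplified stand-ins for the package types, for illustration only.
abstract type TrainingStrategy end
abstract type DerivativeStrategy <: TrainingStrategy end

struct DerivativeTraining <: DerivativeStrategy
    window_size::Int
    random::Bool
end
DerivativeTraining(; window_size = 0, random = true) =
    DerivativeTraining(window_size, random)

# Mirrors get_delta in src/strategies.jl: a window of 0 means
# "use the whole trajectory" (trajectory length minus one usable steps).
get_delta(s::DerivativeStrategy, trajectory_length::Integer) =
    s.window_size > 0 ? s.window_size : trajectory_length - 1

@show get_delta(DerivativeTraining(window_size = 100), 600)  # -> 100
@show get_delta(DerivativeTraining(), 600)                   # -> 599
```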
""" -struct Collocation <: CollocationStrategy +struct DerivativeTraining <: DerivativeStrategy window_size::Integer random::Bool end -function Collocation(;window_size::Integer = 0, random = true) - Collocation(window_size, random) +function DerivativeTraining(;window_size::Integer = 0, random = true) + DerivativeTraining(window_size, random) end From ce1107979c5691f5fa5528b7d2a092379cba8365 Mon Sep 17 00:00:00 2001 From: Julian Trommer Date: Tue, 16 Jul 2024 15:29:24 +0200 Subject: [PATCH 2/2] Added version specific docs & fixed bug for CUDA device --- docs/make.jl | 2 +- src/MeshGraphNets.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 839f7c4..cd95485 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -45,4 +45,4 @@ function deployConfig() return GitHubActions(github_repository, github_event_name, github_ref) end -deploydocs(repo = "github.com/una-auxme/MeshGraphNets.jl.git", devbranch = "main", deploy_config = deployConfig()) \ No newline at end of file +deploydocs(repo = "github.com/una-auxme/MeshGraphNets.jl.git", devbranch = "main", versions = ["dev" => "dev", "v#.#"], deploy_config = deployConfig()) \ No newline at end of file diff --git a/src/MeshGraphNets.jl b/src/MeshGraphNets.jl index 84788a4..56cab82 100644 --- a/src/MeshGraphNets.jl +++ b/src/MeshGraphNets.jl @@ -46,7 +46,7 @@ export train_network, eval_network, der_minmax, data_meanstd types_noisy::Vector{Integer} = [0] training_strategy::TrainingStrategy = DerivativeTraining() use_cuda::Bool = true - gpu_device::Integer = CUDA.device() + gpu_device::CuDevice = CUDA.device() cell_idxs::Vector{Integer} = [0] num_rollouts::Integer = 10 use_valid::Bool = true