Merge pull request #14 from una-auxme/dev
Reworked training strategies
JulianTrommer authored Jul 16, 2024
2 parents 25429eb + ce11079 commit 71eef20
Showing 7 changed files with 43 additions and 45 deletions.
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -45,4 +45,4 @@ function deployConfig()
return GitHubActions(github_repository, github_event_name, github_ref)
end

deploydocs(repo = "github.com/una-auxme/MeshGraphNets.jl.git", devbranch = "main", deploy_config = deployConfig())
deploydocs(repo = "github.com/una-auxme/MeshGraphNets.jl.git", devbranch = "main", versions = ["dev" => "dev", "v#.#"], deploy_config = deployConfig())
8 changes: 4 additions & 4 deletions docs/src/cylinder_flow.md
@@ -32,20 +32,20 @@ In order to train the system you can simply comment in/out the lines of code provided
# Train network #
#################

# with Collocation
# with DerivativeTraining

train_network(
noise, opt, ds_path, chk_path; mps = message_steps, layer_size = layer_size, hidden_layers = hidden_layers, batchsize = batch,
epochs = epo, steps = Int(ns), use_cuda = cuda, checkpoint = cp, norm_steps = 1000, types_updated = types_updated,
types_noisy = types_noisy, training_strategy = Collocation(), solver_valid = Euler(), solver_valid_dt = 0.01f0
types_noisy = types_noisy, training_strategy = DerivativeTraining(), solver_valid = Euler(), solver_valid_dt = 0.01f0
)

# with SingleShooting
# with SolverTraining

train_network(
noise, opt, ds_path, chk_path; mps = message_steps, layer_size = layer_size, hidden_layers = hidden_layers, batchsize = batch, epochs = epo,
steps = Int(ns), use_cuda = cuda, checkpoint = 10, norm_steps = 1000, types_updated = types_updated, types_noisy = types_noisy,
training_strategy = SingleShooting(0.0f0, 0.01f0, 5.99f0, Euler(); adaptive = false, tstops = 0.0f0:0.01f0:5.99f0)
training_strategy = SolverTraining(0.0f0, 0.01f0, 5.99f0, Euler(); adaptive = false, tstops = 0.0f0:0.01f0:5.99f0)
)
```

4 changes: 2 additions & 2 deletions docs/src/strategies.md
@@ -1,7 +1,7 @@
# Training Strategies

```@docs
Collocation
SingleShooting
DerivativeTraining
SolverTraining
MultipleShooting
```
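For quick reference, the renamed strategies documented above can be constructed as follows. This is a minimal sketch, not part of the commit: the argument values mirror the CylinderFlow example, the interval size of 100 is invented, and `Euler` is assumed to come from OrdinaryDiffEq.

```julia
using MeshGraphNets
using OrdinaryDiffEq: Euler

# Finite-difference training on the whole trajectory (default keywords).
derivative = DerivativeTraining()

# Solver-based training over t = 0.0..5.99 with a fixed step of 0.01.
solver = SolverTraining(0.0f0, 0.01f0, 5.99f0, Euler(); adaptive = false)

# Same time span, but split into independently solved intervals of 100 steps.
shooting = MultipleShooting(0.0f0, 0.01f0, 5.99f0, Euler(), 100)
```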
8 changes: 4 additions & 4 deletions examples/cylinder_flow/cylinder_flow.jl
@@ -50,20 +50,20 @@ eval_path = "data/CylinderFlow/eval"
# Train network #
#################

# with Collocation
# with DerivativeTraining

train_network(
noise, opt, ds_path, chk_path; mps = message_steps, layer_size = layer_size, hidden_layers = hidden_layers, batchsize = batch,
epochs = epo, steps = Int(ns), use_cuda = cuda, checkpoint = cp, norm_steps = 1000, types_updated = types_updated,
types_noisy = types_noisy, training_strategy = Collocation(), solver_valid = Euler(), solver_valid_dt = 0.01f0
types_noisy = types_noisy, training_strategy = DerivativeTraining(), solver_valid = Euler(), solver_valid_dt = 0.01f0
)

# with SingleShooting
# with SolverTraining

train_network(
noise, opt, ds_path, chk_path; mps = message_steps, layer_size = layer_size, hidden_layers = hidden_layers, batchsize = batch, epochs = epo,
steps = Int(ns), use_cuda = cuda, checkpoint = 10, norm_steps = 1000, types_updated = types_updated, types_noisy = types_noisy,
training_strategy = SingleShooting(0.0f0, 0.01f0, 5.99f0, Euler(); adaptive = false, tstops = 0.0f0:0.01f0:5.99f0)
training_strategy = SolverTraining(0.0f0, 0.01f0, 5.99f0, Euler(); adaptive = false, tstops = 0.0f0:0.01f0:5.99f0)
)


15 changes: 7 additions & 8 deletions src/MeshGraphNets.jl
@@ -28,7 +28,7 @@ include("graph.jl")
include("solve.jl")
include("dataset.jl")

export SingleShooting, MultipleShooting, RandomCollocation, Collocation
export SolverTraining, MultipleShooting, DerivativeTraining

export train_network, eval_network, der_minmax, data_meanstd

@@ -44,9 +44,9 @@ export train_network, eval_network, der_minmax, data_meanstd
max_norm_steps::Integer = 10f6
types_updated::Vector{Integer} = [0, 5]
types_noisy::Vector{Integer} = [0]
training_strategy::TrainingStrategy = Collocation()
training_strategy::TrainingStrategy = DerivativeTraining()
use_cuda::Bool = true
gpu_device::Integer = CUDA.device()
gpu_device::CuDevice = CUDA.device()
cell_idxs::Vector{Integer} = [0]
num_rollouts::Integer = 10
use_valid::Bool = true
@@ -182,7 +182,7 @@ Starts the training process with the given configuration.
- `max_norm_steps = 10f6`: Number of steps after which no more normalization stats are collected.
- `types_updated = [0, 5]`: Array containing node types which are updated after each step.
- `types_noisy = [0]`: Array containing node types to which noise is added.
- `training_strategy = Collocation()`: Methods used for training. See [documentation](https://una-auxme.github.io/MeshGraphNets.jl/dev/strategies/).
- `training_strategy = DerivativeTraining()`: Methods used for training. See [documentation](https://una-auxme.github.io/MeshGraphNets.jl/dev/strategies/).
- `use_cuda = true`: Whether a GPU is used for training or not (if available). Currently only CUDA GPUs are supported.
- `gpu_device = CUDA.device()`: Current CUDA device (aka GPU). See *nvidia-smi* for reference.
- `cell_idxs = [0]`: Indices of cells that are plotted during validation (if enabled).
@@ -192,9 +192,8 @@ Starts the training process with the given configuration.
- `reset_valid = false`: If set, the previous minimal validation loss will be overwritten.
## Training Strategies
- `Collocation`
- `RandomCollocation`
- `SingleShooting`
- `DerivativeTraining`
- `SolverTraining`
- `MultipleShooting`
See [CylinderFlow Example](https://una-auxme.github.io/MeshGraphNets.jl/dev/cylinder_flow) for reference.
@@ -520,7 +519,7 @@ function eval_network!(solver, mgn::GraphNetwork, dataset::Dataset, device::Function
errors[(ti, "error")] = cpu_device()(error[:, 1, :])
end

eval_path = joinpath(out_path, isnothing(solver) ? "collocation" : lowercase("$(nameof(typeof(solver)))"))
eval_path = joinpath(out_path, isnothing(solver) ? "derivative_training" : lowercase("$(nameof(typeof(solver)))"))
mkpath(eval_path)
h5open(joinpath(eval_path, "trajectories.h5"), "w") do f
for i in 1:maximum(getfield.(keys(traj_ops), 1))
11 changes: 5 additions & 6 deletions src/dataset.jl
@@ -423,7 +423,7 @@ end
"""
add_targets!(data, fields, device)
Shifts the datapoints beginning from second index back in order to use them as ground truth data (used for collocation strategies).
Shifts the datapoints back by one, beginning from the second index, so that they can be used as ground truth data (used for derivative-based strategies).
## Arguments
- `data`: Data from the dataset containing one trajectory.
@@ -455,7 +455,7 @@ end
"""
preprocess!(data, noise_fields, noise_stddevs, types_noisy, ts, device)
Adds noise to the given features and shuffles the datapoints if a collocation strategy is used.
Adds noise to the given features and shuffles the datapoints if a derivative-based strategy is used.
## Arguments
- `data`: Data from the dataset containing one trajectory.
@@ -485,8 +485,7 @@ function preprocess!(data, noise_fields, noise_stddevs, types_noisy, ts, device)
if key == "edges" || length(data[key]) == 1 || size(data[key])[end] == 1
continue
end
if typeof(ts) <: CollocationStrategy && ts.random

if typeof(ts) <: DerivativeStrategy && ts.random
data[key] = data[key][repeat([:], ndims(data[key])-1)..., shuffle(rng, ts.window_size == 0 ? collect(1:end) : collect(1:ts.window_size))]
end
seed!(rng, seed)
@@ -591,7 +590,7 @@ end
"""
prepare_trajectory!(data, meta, device; types_noisy, noise_stddevs, ts)
Transfers the data to the given device and configures the data if a collocation strategy is used.
Transfers the data to the given device and configures it if a derivative-based strategy is used.
## Arguments
- `data`: Data from the dataset containing one trajectory.
@@ -608,7 +607,7 @@ Transfers the data to the given device and configures the data if a collocation strategy is used.
- Metadata of the dataset.
"""
function prepare_trajectory!(data, meta, device::Function; types_noisy, noise_stddevs, ts)
if !isnothing(ts) && (typeof(ts) <: CollocationStrategy)
if !isnothing(ts) && (typeof(ts) <: DerivativeStrategy)
add_targets!(data, meta["target_features"], device)
preprocess!(data, meta["target_features"], noise_stddevs, types_noisy, ts, device)
for field in meta["feature_names"]
40 changes: 20 additions & 20 deletions src/strategies.jl
@@ -35,7 +35,7 @@ Returns the delta between samples in the training data.
## Arguments
- `strategy`: Used training strategy.
- Trajectory length (used for collocation strategies).
- Trajectory length (used for derivative-based strategies).
## Returns
- Delta between samples in the training data.
@@ -213,7 +213,7 @@ end


"""
SingleShooting(tstart, dt, tstop, solver; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()), solargs...)
SolverTraining(tstart, dt, tstop, solver; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()), solargs...)
The default solver-based training that is normally used for Neural ODEs.
Simulates the system from `tstart` to `tstop` and calculates the loss based on the difference between the prediction and the ground truth at the timesteps `tstart:dt:tstop`.
@@ -228,7 +228,7 @@ Simulates the system from `tstart` to `tstop` and calculates the loss based on the difference between the prediction and the ground truth at the timesteps `tstart:dt:tstop`.
- `sense = InterpolatingAdjoint(autojacvec = ZygoteVJP())`: The sensitivity algorithm that is used for calculating the sensitivities.
- `solargs`: Keyword arguments that are passed on to the solver.
"""
struct SingleShooting <: SolverStrategy
struct SolverTraining <: SolverStrategy
tstart::Float32
dt::Float32
tstop::Float32
@@ -237,11 +237,11 @@ struct SingleShooting <: SolverStrategy
solargs
end

function SingleShooting(tstart::Float32, dt::Float32, tstop::Float32, solver::OrdinaryDiffEqAlgorithm; sense::AbstractSensitivityAlgorithm = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)
SingleShooting(tstart, dt, tstop, solver, sense, solargs)
function SolverTraining(tstart::Float32, dt::Float32, tstop::Float32, solver::OrdinaryDiffEqAlgorithm; sense::AbstractSensitivityAlgorithm = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)
SolverTraining(tstart, dt, tstop, solver, sense, solargs)
end
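As an aside (not part of the diff), a minimal sketch of calling the renamed constructor, with values taken from the CylinderFlow example above:

```julia
# Fixed-step Euler over the full trajectory, with explicit stop times.
strategy = SolverTraining(
    0.0f0, 0.01f0, 5.99f0, Euler();
    adaptive = false, tstops = 0.0f0:0.01f0:5.99f0
)
```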

function train_loss(strategy::SingleShooting, t::Tuple)
function train_loss(strategy::SolverTraining, t::Tuple)

prob, ps, u0, callback_solve, gt, val_mask, n_norm, target_fields, target_dims = t

@@ -277,8 +277,8 @@ end
"""
MultipleShooting(tstart, dt, tstop, solver, interval_size, continuity_term = 100; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)
Similar to SingleShooting, but splits the trajectory into intervals that are solved independently and then combines them for loss calculation.
Useful if the network tends to get stuck in a local minimum if SingleShooting is used.
Similar to SolverTraining, but splits the trajectory into intervals that are solved independently and then combines them for loss calculation.
Useful if the network tends to get stuck in a local minimum when SolverTraining is used.
## Arguments
- `tstart`: Start time of the simulation.
@@ -356,17 +356,17 @@ function train_loss(strategy::MultipleShooting, t::Tuple)
return loss
end
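A hypothetical construction matching the documented signature; the interval size of 100 steps is invented for illustration, and `continuity_term` keeps its default of 100:

```julia
# Split the 0.0..5.99 trajectory into 100-step shooting intervals.
strategy = MultipleShooting(0.0f0, 0.01f0, 5.99f0, Euler(), 100)
```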

#########################################################################
# Abstract type and functions for collocation based training strategies #
#########################################################################
########################################################################
# Abstract type and functions for derivative-based training strategies #
########################################################################

abstract type CollocationStrategy <: TrainingStrategy end
abstract type DerivativeStrategy <: TrainingStrategy end

function get_delta(strategy::CollocationStrategy, trajectory_length::Integer)
function get_delta(strategy::DerivativeStrategy, trajectory_length::Integer)
return strategy.window_size > 0 ? strategy.window_size : trajectory_length - 1
end
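To make the window logic of this internal helper concrete, a small worked example (the trajectory length of 600 is invented for illustration):

```julia
get_delta(DerivativeTraining(window_size = 0), 600)    # 599: whole trajectory
get_delta(DerivativeTraining(window_size = 100), 600)  # 100: truncated window
```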

function init_train_step(::CollocationStrategy, t::Tuple, ::Tuple)
function init_train_step(::DerivativeStrategy, t::Tuple, ::Tuple)

mgn, data, meta, fields, target_fields, node_type, edge_features, senders, receivers, datapoint, mask, _ = t

@@ -380,14 +380,14 @@ function init_train_step(::CollocationStrategy, t::Tuple, ::Tuple)
return (mgn, graph, target_quantities_change, mask)
end

function train_step(::CollocationStrategy, t::Tuple)
function train_step(::DerivativeStrategy, t::Tuple)

mgn, graph, target_quantities_change, mask = t

return step!(mgn, graph, target_quantities_change, mask, mse_reduce)
end

function validation_step(::CollocationStrategy, t::Tuple)
function validation_step(::DerivativeStrategy, t::Tuple)
sim_interval = t[2]["dt"][1]:t[2]["dt"][2]-t[2]["dt"][1]:t[2]["dt"][t[4]]
data_interval = 1:t[4]

@@ -397,18 +397,18 @@ end


"""
Collocation(; window_size = 0)
DerivativeTraining(; window_size = 0)
Compares the prediction of the system with the derivative from the data (via finite differences).
Useful for initial training of the system since it is faster than training with a solver.
## Keyword Arguments
- `window_size = 0`: Number of steps from each trajectory (starting at the beginning) that are used for training. If the number is zero then the whole trajectory is used.
"""
struct Collocation <: CollocationStrategy
struct DerivativeTraining <: DerivativeStrategy
window_size::Integer
random::Bool
end
function Collocation(;window_size::Integer = 0, random = true)
Collocation(window_size, random)
function DerivativeTraining(;window_size::Integer = 0, random = true)
DerivativeTraining(window_size, random)
end
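Usage sketch for the renamed strategy (keyword values are illustrative):

```julia
DerivativeTraining()                                   # whole trajectory, shuffled
DerivativeTraining(window_size = 200, random = false)  # first 200 steps, in order
```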
