diff --git a/Project.toml b/Project.toml
index 2d179d0..9b90f17 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "MeshGraphNets"
 uuid = "f7b4726d-4195-44ab-b39c-37bbdadae004"
 authors = ["JT "]
-version = "0.4.0"
+version = "0.4.1"
 
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/src/strategies.jl b/src/strategies.jl
index 006bd1b..3e1a3cd 100644
--- a/src/strategies.jl
+++ b/src/strategies.jl
@@ -302,11 +302,11 @@ Useful if the network tends to get stuck in a local minimum if SolverTraining is
 - `dt`: Interval at which the simulation is saved.
 - `tstop`: Stop time of the simulation.
 - `solver`: Solver that is used for simulating the system.
-- `interval_size`: Size of the intervals (i.e. number of datapoints in one interval).
-- `continuity_term = 100`: Factor by which the error between points of concurrent intervals is multiplied.
 
 ## Keyword Arguments
-- `sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true)`:
+- `sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true)`: The sensitivity algorithm that is used for calculating the sensitivities.
+- `interval_size`: Size of the intervals (i.e. number of datapoints in one interval).
+- `continuity_term = 100`: Factor by which the error between points of concurrent intervals is multiplied.
 - `solargs`: Keyword arguments that are passed on to the solver.
 """
 struct MultipleShooting <: SolverStrategy
@@ -436,6 +436,7 @@ Useful for initial training of the system since it it faster than training with
 
 ## Keyword Arguments
 - `window_size = 0`: Number of steps from each trajectory (starting at the beginning) that are used for training. If the number is zero then the whole trajectory is used.
+- `random = true`: Whether the derivatives of the data should be shuffled before training.
 """
 struct DerivativeTraining <: DerivativeStrategy
     window_size::Integer
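For context, a minimal usage sketch of the two strategies whose docstrings change above. The constructor call forms (positional `dt`/`tstop`/`solver` for `MultipleShooting`, keyword-only for `DerivativeTraining`), the exports from `MeshGraphNets`, and the `Tsit5` solver choice are assumptions for illustration, not taken from the package's documented API:

```julia
# Sketch only: constructor signatures are assumed from the docstrings in the
# diff above and have not been verified against the package.
using MeshGraphNets
using OrdinaryDiffEq: Tsit5
using SciMLSensitivity: InterpolatingAdjoint, ZygoteVJP

# Multiple shooting over intervals of 10 datapoints; the mismatch between
# neighbouring intervals is weighted by `continuity_term`.
ms = MultipleShooting(0.01, 1.0, Tsit5();
    interval_size = 10,
    continuity_term = 100,
    sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true))

# Derivative training on whole trajectories; the new `random = true` keyword
# shuffles the training derivatives before they are used.
dts = DerivativeTraining(; window_size = 0, random = true)
```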