diff --git a/.gitignore b/.gitignore index b83d2f5..d528b11 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,7 @@ docs/build/ Manifest.toml # Wandb files -.CondaPkg/ \ No newline at end of file +.CondaPkg/ + +# VSCode files +.vscode/ \ No newline at end of file diff --git a/Project.toml b/Project.toml index 51a451d..3df428d 100644 --- a/Project.toml +++ b/Project.toml @@ -47,6 +47,7 @@ Random = "1" SciMLBase = "2.7.0 - 2" Statistics = "1" TFRecord = "0.4.2" +Wandb = "0.5" Zygote = "0.6" cuDNN = "1.3" julia = "1.10" diff --git a/src/MeshGraphNets.jl b/src/MeshGraphNets.jl index d831989..7c121d9 100644 --- a/src/MeshGraphNets.jl +++ b/src/MeshGraphNets.jl @@ -71,7 +71,7 @@ Initializes the normalisers based on the given dataset and its metadata. - Dictionary of each node feature and its normaliser as key-value pair. - Dictionary of each output feature and its normaliser as key-value pair. """ -function calc_norms(dataset, device) +function calc_norms(dataset, device, args::Args) quantities = 0 n_norms = Dict{String, Union{NormaliserOffline, NormaliserOnline}}() o_norms = Dict{String, Union{NormaliserOffline, NormaliserOnline}}() @@ -225,7 +225,7 @@ function train_network(noise_stddevs, opt, ds_path, cp_path; kws...) 
println("Building model...") - quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device) + quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device, args) dims = dataset.meta["dims"] outputs = 0 @@ -431,7 +431,7 @@ function eval_network(ds_path, cp_path::String, out_path::String, solver = nothi println("Building model...") - quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device) + quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device, args) dims = dataset.meta["dims"] outputs = 0 diff --git a/src/dataset.jl b/src/dataset.jl index 444126e..b59b4e2 100644 --- a/src/dataset.jl +++ b/src/dataset.jl @@ -7,7 +7,7 @@ import Distributions: Normal import Random: MersenneTwister import TFRecord: Example -import HDF5: h5open +import HDF5: h5open, Group, read_dataset import JLD2: jldopen import JSON: parse import Random: seed!, make_seed, shuffle @@ -301,10 +301,20 @@ function read_h5!(datafile, data_keys, meta, is_jld) close(file) end - if !haskey(meta, "dims") - throw(ErrorException("Edges for custom meshes without specifying domain dimensions is not supported yet")) + if haskey(meta, "custom_edges") + lock(l) + if is_jld + throw(ErrorException("Custom edge definition is not supported for JLD2 files.")) + else + file = h5open(datafile, "r") + + edges = read_edges(file[k], meta["custom_edges"], traj_dict["node_type"], haskey(meta, "no_edges_node_types") ? meta["no_edges_node_types"] : [], haskey(meta, "exclude_node_indices") ? meta["exclude_node_indices"] : []) + close(file) + end + unlock(l) + elseif haskey(meta, "dims") # "dims" should always be present here (loading would have failed earlier without it); NOTE(review): if neither key exists, `edges` stays undefined and the hcat below throws UndefVarError — an explicit else-branch error would be safer + edges = create_edges(dims, traj_dict["node_type"], haskey(meta, "no_edges_node_types") ? meta["no_edges_node_types"] : []) end - edges = create_edges(dims, traj_dict["node_type"], haskey(meta, "no_edges_node_types") ? meta["no_edges_node_types"] : []) traj_dict["edges"] = hcat(sort(edges)...) 
put!(ch, traj_dict) @@ -380,6 +390,37 @@ function create_edges(dims, node_type, no_edges_node_types) return edges end + +""" + read_edges(traj::Group, edge_key, node_type, no_edges_node_types, exclude_node_indices) + + Read edges from trajectory group. + + ## Arguments + + - `traj`: HDF5 group containing this trajectory's data. + - `edge_key`: Name of the dataset inside `traj` that stores the edge list. + - `node_type`: Array of node types from the data file. + - `no_edges_node_types`: Vector of node types that should not be connected with edges. + - `exclude_node_indices`: Vector of node indices that should not be connected with edges. + + ## Returns + + - Vector of connected node pair indices (as vectors). +""" +function read_edges(traj::Group, edge_key, node_type, no_edges_node_types, exclude_node_indices) + @assert haskey(traj, edge_key) "Key '$(edge_key)' not found in trajectory group '$(HDF5.name(traj))'" + edges = read_dataset(traj, edge_key) + exclude_indices = findall(x -> x ∈ no_edges_node_types, node_type) + exclude_indices = vcat(exclude_indices, exclude_node_indices) + filter!(x -> x[1] ∉ exclude_indices && x[2] ∉ exclude_indices, edges) + edge_vec = Vector{Vector{Int32}}() + for edge in edges + push!(edge_vec, [edge[1], edge[2]]) + end + return edge_vec +end + """ add_targets!(data, fields, device)