Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge external PR to dev branch #13

Merged
merged 13 commits into from
Jul 16, 2024
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ docs/build/
Manifest.toml

# Wandb files
.CondaPkg/
.CondaPkg/

# VSCode files
.vscode/
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Random = "1"
SciMLBase = "2.7.0 - 2"
Statistics = "1"
TFRecord = "0.4.2"
Wandb = "0.5"
Zygote = "0.6"
cuDNN = "1.3"
julia = "1.10"
Expand Down
6 changes: 3 additions & 3 deletions src/MeshGraphNets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Initializes the normalisers based on the given dataset and its metadata.
- Dictionary of each node feature and its normaliser as key-value pair.
- Dictionary of each output feature and its normaliser as key-value pair.
"""
function calc_norms(dataset, device)
function calc_norms(dataset, device, args::Args)
quantities = 0
n_norms = Dict{String, Union{NormaliserOffline, NormaliserOnline}}()
o_norms = Dict{String, Union{NormaliserOffline, NormaliserOnline}}()
Expand Down Expand Up @@ -225,7 +225,7 @@ function train_network(noise_stddevs, opt, ds_path, cp_path; kws...)

println("Building model...")

quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device)
quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device, args)

dims = dataset.meta["dims"]
outputs = 0
Expand Down Expand Up @@ -431,7 +431,7 @@ function eval_network(ds_path, cp_path::String, out_path::String, solver = nothi

println("Building model...")

quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device)
quantities, e_norms, n_norms, o_norms = calc_norms(dataset, device, args)

dims = dataset.meta["dims"]
outputs = 0
Expand Down
48 changes: 44 additions & 4 deletions src/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Distributions: Normal
import Random: MersenneTwister
import TFRecord: Example

import HDF5: h5open
import HDF5: h5open, Group, read_dataset
import JLD2: jldopen
import JSON: parse
import Random: seed!, make_seed, shuffle
Expand Down Expand Up @@ -301,10 +301,20 @@ function read_h5!(datafile, data_keys, meta, is_jld)
close(file)
end

if !haskey(meta, "dims")
throw(ErrorException("Edges for custom meshes without specifying domain dimensions is not supported yet"))
if haskey(meta, "custom_edges")
lock(l)
if is_jld
throw(ErrorException("Custom edge definition is not supported for JLD2 files."))
else
file = h5open(datafile, "r")

edges = read_edges(file[k], meta["custom_edges"], traj_dict["node_type"], haskey(meta, "no_edges_node_types") ? meta["no_edges_node_types"] : [], haskey(meta, "exclude_node_indices") ? meta["exclude_node_indices"] : [])
close(file)
end
unlock(l)
elseif haskey(meta, "dims") # this condition is basically useless, because if there would be no "dims", it would have failed earlier
edges = create_edges(dims, traj_dict["node_type"], haskey(meta, "no_edges_node_types") ? meta["no_edges_node_types"] : [])
end
edges = create_edges(dims, traj_dict["node_type"], haskey(meta, "no_edges_node_types") ? meta["no_edges_node_types"] : [])
traj_dict["edges"] = hcat(sort(edges)...)

put!(ch, traj_dict)
Expand Down Expand Up @@ -380,6 +390,36 @@ function create_edges(dims, node_type, no_edges_node_types)
return edges
end


"""
read_edges(traj::Group, node_type, no_edges_node_types::Vector{Int}, exclude_node_indices::Vector{Int})

Read edges from trajectory group.

## Arguments

- `traj`: HDF5 group containing this trajectory's data.
- `node_type`: Array of node types from the data file.
- `excluded_node_types`: Vector of node types that should not be connected with edges.
- `exclude_node_indices`: Vector of node indices that should not be connected with edges.

## Returns

- Vector of connected node pair indices (as vectors).
"""
function read_edges(traj::Group, edge_key, node_type, no_edges_node_types, exclude_node_indices)
@assert haskey(traj, edge_key) "Key '$(edge_key)' not found in trajectory group '$(HDF5.name(traj))'"
edges = read_dataset(traj, edge_key)
exclude_indices = findall(x -> x ∈ no_edges_node_types, node_type)
exclude_indices = vcat(exclude_indices, exclude_node_indices)
filter!(x -> x[1] ∉ exclude_indices && x[2] ∉ exclude_indices, edges)
edge_vec = Vector{Vector{Int32}}()
for edge in edges
push!(edge_vec, [edge[1], edge[2]])
end
return edge_vec
end

"""
add_targets!(data, fields, device)

Expand Down
Loading