diff --git a/src/operators.jl b/src/operators.jl index 77496d323..aa788553e 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -712,7 +712,7 @@ Determine how many elements of `x` are less than `i` for all `i` in `1:n`. """ function compute_shifts(n::Integer, x::AbstractArray) tmp = zeros(eltype(x), n) - tmp[x[2:end]] .= 1 + tmp[x] .= 1 return cumsum!(tmp, tmp) end @@ -743,29 +743,25 @@ julia> collect(edges(h)) Edge 3 => 4 ``` """ -function merge_vertices(g::AbstractGraph, vs) - labels = collect(1:nv(g)) +function merge_vertices(g::AbstractSimpleGraph, vs) # Use lowest value as new vertex id. - sort!(vs) - nvnew = nv(g) - length(unique(vs)) + 1 + vs = unique!(sort(vs)) + merged_vertex = popfirst!(vs) + + nvnew = nv(g) - length(vs) nvnew <= nv(g) || return g - (v0, vm) = extrema(vs) - v0 > 0 || throw(ArgumentError("invalid vertex ID: $v0 in list of vertices to be merged")) - vm <= nv(g) || throw(ArgumentError("vertex $vm not found in graph")) # TODO 0.7: change to DomainError? - labels[vs] .= v0 - shifts = compute_shifts(nv(g), vs[2:end]) - for v in vertices(g) - if labels[v] != v0 - labels[v] -= shifts[v] - end - end + merged_vertex > 0 || throw(ArgumentError("invalid vertex ID: $merged_vertex in list of vertices to be merged")) + vs[end] <= nv(g) || throw(ArgumentError("vertex $(vs[end]) not found in graph")) # TODO 0.7: change to DomainError? + + new_vertex_ids = collect(vertices(g)) .- compute_shifts(nv(g), vs) + new_vertex_ids[vs] .= merged_vertex #if v in vs then labels[v] == v0 else labels[v] == v newg = SimpleGraph(nvnew) for e in edges(g) u, w = src(e), dst(e) - if labels[u] != labels[w] #not a new self loop - add_edge!(newg, labels[u], labels[w]) + if new_vertex_ids[u] != new_vertex_ids[w] #not a new self loop + add_edge!(newg, new_vertex_ids[u], new_vertex_ids[w]) end end return newg @@ -812,35 +808,27 @@ julia> collect(edges(g)) ``` """ function merge_vertices!(g::Graph{T}, vs::Vector{U} where U <: Integer) where T - vs = sort!(unique(vs)) - merged_vertex = popfirst!(vs) + vs = unique!(sort(vs)) + (merged_vertex, vm) = extrema(vs) - x = zeros(Int, nv(g)) - x[vs] .= 1 - new_vertex_ids = collect(1:nv(g)) .- cumsum(x) + merged_vertex > 0 || throw(ArgumentError("invalid vertex ID: $merged_vertex in list of vertices to be merged")) + vm <= nv(g) || throw(ArgumentError("vertex $vm not found in graph")) # TODO 0.7: change to DomainError? + + new_vertex_ids = collect(vertices(g)) .- compute_shifts(nv(g), vs[2:end]) new_vertex_ids[vs] .= merged_vertex for i in vertices(g) # Adjust connections to merged vertices - if (i != merged_vertex) && !insorted(i, vs) + if new_vertex_ids[i] != merged_vertex nbrs_to_rewire = Set{T}() for j in outneighbors(g, i) - if insorted(j, vs) - push!(nbrs_to_rewire, merged_vertex) - else - push!(nbrs_to_rewire, new_vertex_ids[j]) - end + push!(nbrs_to_rewire, new_vertex_ids[j]) end - g.fadjlist[new_vertex_ids[i]] = sort(collect(nbrs_to_rewire)) - + g.fadjlist[new_vertex_ids[i]] = sort!(collect(nbrs_to_rewire)) # Collect connections to new merged vertex else nbrs_to_merge = Set{T}() - for element in filter(x -> !(insorted(x, vs)) && (x != merged_vertex), g.fadjlist[i]) - push!(nbrs_to_merge, new_vertex_ids[element]) - end - for j in vs, e in outneighbors(g, j) if new_vertex_ids[e] != merged_vertex push!(nbrs_to_merge, new_vertex_ids[e]) @@ -850,8 +838,9 @@ function merge_vertices!(g::Graph{T}, vs::Vector{U} where U <: Integer) where T end end + # Drop excess vertices - g.fadjlist = g.fadjlist[1:(end - length(vs))] + g.fadjlist = g.fadjlist[begin:(end - length(vs)+1)] # Correct edge counts g.ne = sum(degree(g, i) for i in vertices(g)) / 2 diff --git a/test/operators.jl b/test/operators.jl index a0809205f..39fa97963 100644 --- a/test/operators.jl +++ b/test/operators.jl @@ -103,6 +103,13 @@ @test neighbors(h2, 5) == [2] @test ne(h2) == 3 @test nv(h2) == 5 + + h3 = star_graph(5) + h3merged = merge_vertices(h3, [1,2]) + @test neighbors(h3merged, 1) == [2,3,4] + @test neighbors(h3merged, 2) == [1] + @test neighbors(h3merged, 3) == [1] + @test neighbors(h3merged, 4) == [1] end end