Skip to content

Commit

Permalink
Update utils.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
anton083 committed Oct 29, 2023
1 parent 9f3e31e commit da793b6
Showing 1 changed file with 4 additions and 22 deletions.
26 changes: 4 additions & 22 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,9 @@ function _pad(x::T, arr::AbstractArray{T, N}, paddings::Vararg{Tuple{Int, Int},
return PaddedView(x, arr, Tuple(Base.OneTo.(new_size)), Tuple(offsets))
end

function _moveaxis(arr::AbstractArray, source::Union{Int, Vector{Int}}, destination::Union{Int, Vector{Int}})
function _moveaxis(arr::AbstractArray, src::Int, dest::Int)
ndim = ndims(arr)
source = source isa Int ? [source] : source
destination = destination isa Int ? [destination] : destination

if length(source) != length(destination)
throw(ArgumentError("Length of source and destination must match"))
end

source .= mod.(source .- 1, ndim) .+ 1
destination .= mod.(destination .- 1, ndim) .+ 1

if length(unique(source)) != length(source) || length(unique(destination)) != length(destination)
throw(ArgumentError("Repeated indices are not allowed"))
end

permute_dims = setdiff(1:ndim, source)

for (s, d) in zip(source, destination)
insert!(permute_dims, d, s)
end

return permutedims(arr, permute_dims)
src = (src - 1) % ndim + 1
dest = (dest - 1) % ndim + 1
return permutedims(arr, insert!(setdiff(1:ndim, [src]), dest, src))
end

0 comments on commit da793b6

Please sign in to comment.