diff --git a/src/utils.jl b/src/utils.jl index 5c25df9..04ff522 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,27 +12,9 @@ function _pad(x::T, arr::AbstractArray{T, N}, paddings::Vararg{Tuple{Int, Int}, return PaddedView(x, arr, Tuple(Base.OneTo.(new_size)), Tuple(offsets)) end -function _moveaxis(arr::AbstractArray, source::Union{Int, Vector{Int}}, destination::Union{Int, Vector{Int}}) +function _moveaxis(arr::AbstractArray, src::Int, dest::Int) ndim = ndims(arr) - source = source isa Int ? [source] : source - destination = destination isa Int ? [destination] : destination - - if length(source) != length(destination) - throw(ArgumentError("Length of source and destination must match")) - end - - source .= mod.(source .- 1, ndim) .+ 1 - destination .= mod.(destination .- 1, ndim) .+ 1 - - if length(unique(source)) != length(source) || length(unique(destination)) != length(destination) - throw(ArgumentError("Repeated indices are not allowed")) - end - - permute_dims = setdiff(1:ndim, source) - - for (s, d) in zip(source, destination) - insert!(permute_dims, d, s) - end - - return permutedims(arr, permute_dims) + src = (src - 1) % ndim + 1 + dest = (dest - 1) % ndim + 1 + return permutedims(arr, insert!(setdiff(1:ndim, [src]), dest, src)) end \ No newline at end of file