Skip to content

Commit

Permalink
Add quaternion support
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Jul 22, 2024
1 parent a1d81bf commit 5225831
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ ZygoteIdealizationExt = ["Zygote"]

[compat]
BioStructures = "4"
NNlib = "0.9"
LinearAlgebra = "1"
NNlib = "0.9"
PrecompileTools = "1"
Rotations = "1"
StaticArrays = "1"
Expand Down
79 changes: 77 additions & 2 deletions src/frames.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using NNlib

centroid(A::AbstractArray{<:Real}; dims=2) = sum(A; dims) ./ size(A, 2)

# TODO: batched version? possible? batched svd?
function kabsch_algorithm(P::AbstractMatrix{T}, Q::AbstractMatrix{T}) where T <: Real
size(P) == size(Q) || throw(ArgumentError("P and Q must have the same size"))
P_centroid = centroid(P)
Expand Down Expand Up @@ -32,9 +33,9 @@ struct Frames{T<:Real,A<:AbstractArray{T,3},B<:AbstractArray{T,2}}
end
end

Frames(rotations::A,translations::B) where {T<:Real,A<:AbstractArray{T,3},B<:AbstractMatrix{T}} = Frames{T,A,B}(rotations, translations)
Frames(rotations::A, translations::B) where {T<:Real,A<:AbstractArray{T,3},B<:AbstractArray{T,2}} = Frames{T,A,B}(rotations, translations)

function Frames(rotations::AbstractArray{<:Real,3}, translations::AbstractArray{<:Real})
function Frames(rotations::AbstractArray{<:Real}, translations::AbstractArray{<:Real})
T = promote_type(eltype(rotations), eltype(translations))
return Frames(T.(rotations), T.(translations))
end
Expand All @@ -56,3 +57,77 @@ end
(frames::Frames{T})(coords::AbstractMatrix{<:Real}) where T<:Real = frames(T.(coords))

Backbone(frames::Frames, ideal_coords::AbstractMatrix{<:Real}) = Backbone(frames(ideal_coords))

### Quaternion support

# takes a batch of unit quaternions in a 4xN matrix and returns a batch of rotation matrices in a 3x3xN array
function quaternions_to_rotation_matrices(q::AbstractArray{<:Real,2})
size(q, 1) == 4 || throw(ArgumentError("Quaternion batch must have shape 4xN"))

sx = 2q[1, :] .* q[2, :]
sy = 2q[1, :] .* q[3, :]
sz = 2q[1, :] .* q[4, :]
xx = 2q[2, :] .^ 2
xy = 2q[2, :] .* q[3, :]
xz = 2q[2, :] .* q[4, :]
yy = 2q[3, :] .^ 2
yz = 2q[3, :] .* q[4, :]
zz = 2q[4, :] .^ 2

r1 = 1 .- (yy + zz)
r2 = xy - sz
r3 = xz + sy
r4 = xy + sz
r5 = 1 .- (xx + zz)
r6 = yz - sx
r7 = xz - sy
r8 = yz + sx
r9 = 1 .- (xx + yy)

return reshape(vcat(r1', r4', r7', r2', r5', r8', r3', r6', r9'), 3, 3, :)
end

Frames(rotations::AbstractArray{T,2}, translations::AbstractArray{T,2}) = Frames(quaternions_to_rotation_matrices(rotations), translations)

# takes a batch of rotation matrices in a 3x3xN array and returns a batch of unit quaternions in a 4xN matrix
function rotation_matrices_to_quaternions(R::AbstractArray{<:Real,3})
size(q)[1:2] == (3,3) || throw(ArgumentError("Rotation matrix batch must have shape 3x3xN"))
# 1x1xN
r11, r12, r13 = R[1:1, 1:1, :], R[1:1, 2:2, :], R[1:1, 3:3, :]
r21, r22, r23 = R[2:2, 1:1, :], R[2:2, 2:2, :], R[2:2, 3:3, :]
r31, r32, r33 = R[3:3, 1:1, :], R[3:3, 2:2, :], R[3:3, 3:3, :]

# 4x1xN
q0 = [1 .+ r11 + r22 + r33
r32 - r23
r13 - r31
r21 - r12]
q1 = [r32 - r23
1 .+ r11 - r22 - r33
r12 + r21
r13 + r31]
q2 = [r13 - r31
r12 + r21
1 .- r11 + r22 - r33
r23 + r32]
q3 = [r21 - r12
r13 + r31
r23 + r32
1 .- r11 - r22 + r33]

# 4x4xN
Q = hcat(q0, q1, q2, q3)

# 1x4xN, norm of each quaternion
norms = sqrt.(sum(abs2, Q, dims=1))

exp_norms = exp.(norms)
# 1x4xN, norm weights
weights = exp_norms ./ sum(exp_norms, dims=3)

# batched matmul, 4x1xN
q = Q reshape(weights, 4, 1, :)
q_normalized = q ./ sqrt.(sum(abs2, q, dims=1))

return reshape(q_normalized, 4, :)
end
57 changes: 35 additions & 22 deletions test/frames.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,42 @@
@testset "frames.jl" begin

frames = Frames(
[
[0 0 1; 1 0 0; 0 1 0];;;
[0 1 0; 0 0 1; 1 0 0];;;
[1 0 0; 0 1 0; 0 0 1];;;
],
[
0 10 100;
0 0 0;
0 0 0;
]
)
@testset "Frames" begin
frames = Frames(
[
[0 0 1; 1 0 0; 0 1 0];;;
[0 1 0; 0 0 1; 1 0 0];;;
[1 0 0; 0 1 0; 0 0 1];;;
],
[
0 10 100;
0 0 0;
0 0 0;
]
)

standard_coords = [3 1 -4; 1 -1 0; 0 0 0]
backbone = Backbone(frames, standard_coords)

standard_coords = [3 1 -4; 1 -1 0; 0 0 0]
backbone = Backbone(frames, standard_coords)
@test backbone.coords == [
0.0 0.0 0.0 11.0 9.0 10.0 103.0 101.0 96.0
3.0 1.0 -4.0 0.0 0.0 0.0 1.0 -1.0 0.0
1.0 -1.0 0.0 3.0 1.0 -4.0 0.0 0.0 0.0
]

@test backbone.coords == [
0.0 0.0 0.0 11.0 9.0 10.0 103.0 101.0 96.0
3.0 1.0 -4.0 0.0 0.0 0.0 1.0 -1.0 0.0
1.0 -1.0 0.0 3.0 1.0 -4.0 0.0 0.0 0.0
]
new_frames = Frames(backbone, standard_coords)
@test isapprox(frames.rotations, new_frames.rotations; atol=1e-10)
@test isapprox(frames.translations, new_frames.translations; atol=1e-10)
end

new_frames = Frames(backbone, standard_coords)
@test isapprox(frames.rotations, new_frames.rotations; atol=1e-10)
@test isapprox(frames.translations, new_frames.translations; atol=1e-10)
@testset "Quaternion conversion" begin
N = 10
_Q = randn(4, N)
Q = _Q ./ sqrt.(sum(abs2, _Q, dims=1))
R = quaternions_to_rotation_matrices(Q)
Q2 = rotation_matrices_to_quaternions(R)
R2 = quaternions_to_rotation_matrices(Q2)
@test R R2
@test all((Q .≈ Q2) .| (Q .≈ -Q2))
end

end

0 comments on commit 5225831

Please sign in to comment.