Skip to content

Commit

Permalink
Add Frames constructor for rotation matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Jan 18, 2024
1 parent 800cd5d commit 071731c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Backboner"
uuid = "9ac9c2a2-1cfe-46d3-b3fd-6fa470ea56a7"
authors = ["Anton Oresten"]
version = "0.7.0"
version = "0.7.1"

[deps]
BioStructures = "de9282ab-8554-53be-b2d6-f6c222edabfc"
Expand Down
18 changes: 13 additions & 5 deletions src/frames.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,16 @@ function Frames{T}(rotations::AbstractMatrix{<:Real}, locations::AbstractMatrix{
Frames{T}(rotations_T, locations_T)
end

function Frames(rotations::AbstractMatrix{<:Real}, locations::AbstractMatrix{<:Real})
function Frames{T}(rotmats::AbstractArray{<:Real, 3}, locations::AbstractMatrix{<:Real}) where T <: Real
rotations_T = Matrix{T}(undef, 4, size(rotmats, 3))
for i in axes(rotmats, 3)
rotations_T[:, i] = params(QuatRotation(rotmats[:, :, i]))
end
locations_T = convert(Matrix{T}, locations)
Frames{T}(rotations_T, locations_T)
end

function Frames(rotations::AbstractArray{<:Real}, locations::AbstractMatrix{<:Real})
T = promote_type(eltype(rotations), eltype(locations))
Frames{T}(rotations, locations)
end
Expand All @@ -37,7 +46,6 @@ Base.length(frames::Frames) = size(frames.rotations, 2)
Base.size(frames::Frames) = Tuple(length(frames))
Base.getindex(frames::Frames, i::Integer) = QuatRotation(frames.rotations[:, i]), frames.locations[:, i]

# conversion to quatrotation is necessary, as multiple quaternions could represent the same rotation, and quatrotation accounts for this
Base.:(==)(frames1::Frames, frames2::Frames) = all(r1 == r2 && l1 == l2 for ((r1, l1), (r2, l2)) in zip(frames1, frames2))
Base.:()(frames1::Frames, frames2::Frames) = all(isapprox(r1, r2; atol=1e-10) && isapprox(l1, l2; atol=1e-10) for ((r1, l1), (r2, l2)) in zip(frames1, frames2))

Expand All @@ -63,15 +71,15 @@ function Frames(backbone::Backbone{T}, ideal_coords::AbstractMatrix{<:Real}) whe
num_frame_points = size(ideal_coords, 2)
L, r = divrem(length(backbone), num_frame_points)
iszero(r) || throw(ArgumentError("Backbone length must be a multiple of the number of points in a frame ($num_frame_points)"))
rotations = Matrix{T}(undef, 4, L)
rotmats = Array{T, 3}(undef, 3, 3, L)
locations = Matrix{T}(undef, 3, L)
all_raw_coords = reshape(backbone.coords, 3, num_frame_points, L)
for (i, raw_coords) in enumerate(eachslice(all_raw_coords, dims=3))
rotmat, _, raw_centroid = kabsch_algorithm(ideal_coords, raw_coords)
rotations[:, i] = params(QuatRotation(rotmat))
rotmats[:, :, i] = rotmat
locations[:, i] = raw_centroid
end
return Frames(rotations, locations)
return Frames(rotmats, locations)
end

function Backbone(frames::Frames{T}, ideal_coords::AbstractMatrix{<:Real}) where T <: Real
Expand Down
8 changes: 8 additions & 0 deletions test/frames.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,12 @@ import Rotations: QuatRotation, params
@test Frames(backbone, standard_coords) != frames # due to numerical error
@test Frames(backbone, standard_coords) frames

@testset "constructor with rotmats" begin
rotations = frames.rotations
rotmats = stack(collect(QuatRotation(rot)) for rot in eachcol(rotations))
locations = frames.locations
frames2 = Frames(rotmats, locations)
@test frames2 == frames
end

end

0 comments on commit 071731c

Please sign in to comment.