Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rework handle_trajectories to use replicas #539

Merged
merged 4 commits into from
Aug 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 38 additions & 23 deletions openfe/utils/handle_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
from pathlib import Path
from openff.units import unit
from openfe import __version__
from typing import Optional

Check warning on line 8 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L8

Added line #L8 was not covered by tests


def _effective_replica(dataset: nc.Dataset, state_num: int,
frame_num: int) -> int:
"""
Helper method to extract the relevant replica number which
represents a given state number based on the frame number.
def _state_to_replica(dataset: nc.Dataset, state_num: int,

Check warning on line 11 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L11

Added line #L11 was not covered by tests
frame_num: int) -> int:
"""Convert a state index to replica index at a given frame

Parameters
----------
Expand All @@ -33,17 +32,18 @@
return np.where(state_distribution == state_num)[0][0]


def _state_positions_at_frame(dataset: nc.Dataset, state_num: int,
frame_num: int) -> unit.Quantity:
def _replica_positions_at_frame(dataset: nc.Dataset,

Check warning on line 35 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L35

Added line #L35 was not covered by tests
replica_index: int,
frame_num: int) -> unit.Quantity:
"""
Helper method to extract atom positions of a state at a given frame.

Parameters
----------
dataset : netCDF4.Dataset
Dataset containing the MultiState information.
state_num : int
State index to extract positions for.
replica_index : int
Replica index to extract positions for.
frame_num : int
Frame number to extract positions for.

Expand All @@ -52,8 +52,7 @@
unit.Quantity
n_atoms * 3 position Quantity array
"""
effective_replica = _effective_replica(dataset, state_num, frame_num)
pos = dataset.variables['positions'][frame_num][effective_replica].data
pos = dataset.variables['positions'][frame_num][replica_index].data

Check warning on line 55 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L55

Added line #L55 was not covered by tests
pos_units = dataset.variables['positions'].units
return pos * unit(pos_units)

Expand Down Expand Up @@ -120,7 +119,7 @@
return ncfile


def _get_unitcell(dataset: nc.Dataset, state_num: int, frame_num: int):
def _get_unitcell(dataset: nc.Dataset, replica_index: int, frame_num: int):

Check warning on line 122 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L122

Added line #L122 was not covered by tests
"""
Helper method to extract a unit cell from the stored
box vectors in a MultiState reporter generated NetCDF file
Expand All @@ -130,8 +129,8 @@
----------
dataset : netCDF4.Dataset
Dataset of MultiState reporter generated NetCDF file.
state_num : int
State for which to get the unit cell for.
replica_index : int
Replica for which to get the unit cell for.
frame_num : int
Frame for which to get the unit cell for.

Expand All @@ -140,8 +139,7 @@
Tuple[lx, ly, lz, alpha, beta, gamma]
Unit cell lengths and angles in angstroms and degrees.
"""
effective_replica = _effective_replica(dataset, state_num, frame_num)
vecs = dataset.variables['box_vectors'][frame_num][effective_replica].data
vecs = dataset.variables['box_vectors'][frame_num][replica_index].data

Check warning on line 142 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L142

Added line #L142 was not covered by tests
vecs_units = dataset.variables['box_vectors'].units
x, y, z = (vecs * unit(vecs_units)).to('angstrom').m
lx = np.linalg.norm(x)
Expand All @@ -158,28 +156,38 @@


def trajectory_from_multistate(input_file: Path, output_file: Path,
state_number: int) -> None:
state_number: Optional[int] = None,
replica_number: Optional[int] = None) -> None:
"""
Extract a state's trajectory (in an AMBER compliant format)
from a MultiState sampler generated NetCDF file.

Either a state or replica index must be supplied, but not both!

Parameters
----------
input_file : path.Pathlib
Path to the input MultiState sampler generated NetCDF file.
output_file : path.Pathlib
Path to the AMBER-style NetCDF trajectory to be written.
state_number : int
state_number : int, optional
Index of the state to write out to the trajectory.
replica_number : int, optional
Index of the replica to write out
"""
if not ((state_number is None) ^ (replica_number is None)):
raise ValueError("Supply either state or replica number, "

Check warning on line 179 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L178-L179

Added lines #L178 - L179 were not covered by tests
f"got state_number={state_number} "
f"and replica_number={replica_number}")

# Open MultiState NC file and get number of atoms and frames
multistate = nc.Dataset(input_file, 'r')
n_atoms = len(multistate.variables['positions'][0][0])
n_replicas = len(multistate.variables['positions'][0])
n_frames = len(multistate.variables['positions'])

# Sanity check
if state_number + 1 > n_replicas:
if state_number is not None and (state_number + 1 > n_replicas):

Check warning on line 190 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L190

Added line #L190 was not covered by tests
# Note this works for now, but when we have more states
# than replicas (e.g. SAMS) this won't really work
errmsg = "State does not exist"
Expand All @@ -190,13 +198,20 @@
output_file, n_atoms,
title=f"state {state_number} trajectory from {input_file}"
)


replica_id: int = -1
if replica_number is not None:
replica_id = replica_number

Check warning on line 204 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L202-L204

Added lines #L202 - L204 were not covered by tests

# Loopy de loop
for frame in range(n_frames):
traj.variables['coordinates'][frame] = _state_positions_at_frame(
multistate, state_number, frame
if state_number is not None:
replica_id = _state_to_replica(multistate, state_number, frame)

Check warning on line 209 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L208-L209

Added lines #L208 - L209 were not covered by tests

traj.variables['coordinates'][frame] = _replica_positions_at_frame(

Check warning on line 211 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L211

Added line #L211 was not covered by tests
multistate, replica_id, frame
).to('angstrom').m
unitcell = _get_unitcell(multistate, state_number, frame)
unitcell = _get_unitcell(multistate, replica_id, frame)

Check warning on line 214 in openfe/utils/handle_trajectories.py

View check run for this annotation

Codecov / codecov/patch

openfe/utils/handle_trajectories.py#L214

Added line #L214 was not covered by tests
traj.variables['cell_lengths'][frame] = unitcell[:3]
traj.variables['cell_angles'][frame] = unitcell[3:]

Expand Down
Loading