Add draft of "yank analyze cluster" CLI #1020

Open
wants to merge 1 commit into base: master
274 changes: 266 additions & 8 deletions Yank/analyze.py
@@ -22,12 +22,17 @@
import os
import abc
import yaml
import mdtraj
import logging
import itertools

import numpy as np
import simtk.unit as units
import mdtraj as md

import simtk.unit as unit
import openmmtools as mmtools
from pymbar import timeseries
from pymbar import timeseries, MBAR
from msmbuilder.cluster import RegularSpatial

from . import multistate

logger = logging.getLogger(__name__)
@@ -253,8 +258,8 @@ def analyze_directory(source_directory, **analyzer_kwargs):
# Print energies
logger.info('')
logger.info('Free energy{:<13}: {:9.3f} +- {:.3f} kT ({:.3f} +- {:.3f} kcal/mol)'.format(
calculation_type, DeltaF, dDeltaF, DeltaF * kT / units.kilocalories_per_mole,
dDeltaF * kT / units.kilocalories_per_mole))
calculation_type, DeltaF, dDeltaF, DeltaF * kT / unit.kilocalories_per_mole,
dDeltaF * kT / unit.kilocalories_per_mole))
logger.info('')

for phase in phase_names:
@@ -265,8 +270,8 @@ def analyze_directory(source_directory, **analyzer_kwargs):
data[phase]['DeltaF_standard_state_correction']))
logger.info('')
logger.info('Enthalpy{:<16}: {:9.3f} +- {:.3f} kT ({:.3f} +- {:.3f} kcal/mol)'.format(
calculation_type, DeltaH, dDeltaH, DeltaH * kT / units.kilocalories_per_mole,
dDeltaH * kT / units.kilocalories_per_mole))
calculation_type, DeltaH, dDeltaH, DeltaH * kT / unit.kilocalories_per_mole,
dDeltaH * kT / unit.kilocalories_per_mole))


# ==========================================
@@ -479,7 +484,7 @@ def extract_trajectory(nc_path, nc_checkpoint_file=None, state_index=None, repli

# Create trajectory object
logger.info('Creating trajectory object...')
trajectory = mdtraj.Trajectory(positions, topology)
trajectory = md.Trajectory(positions, topology)
if is_periodic:
    trajectory.unitcell_vectors = box_vectors

@@ -496,3 +501,256 @@ def extract_trajectory(nc_path, nc_checkpoint_file=None, state_index=None, repli
logger.warning('The molecules will not be imaged because the system is non-periodic.')

return trajectory

# ==============================================================================
# Cluster ligand conformations and estimate populations in fully-interacting state
# ==============================================================================

# TODO: This is a preliminary draft. This can be heavily refactored after generalizing the analysis code in the MultiStateAnalyzer

def cluster(reference_pdb_filename, netcdf_filename, output_prefix='cluster', nsnapshots_per_cluster=5,
receptor_dsl_selection='protein and name CA', ligand_dsl_selection='not protein and (mass > 1.5)',
fully_interacting_state=0, ligand_rmsd_cutoff=0.3, ligand_filter_cutoff=0.3,
cluster_filter_threshold=0.95):
"""
Cluster ligand conformations and estimate populations in fully-interacting state

Parameters
----------
reference_pdb_filename : str
The name of the PDB file for the solvated complex
netcdf_filename : str
The complex NetCDF file to read
output_prefix : str
String to prepend to cluster PDB files and populations written
nsnapshots_per_cluster : int, optional, default=5
The number of snapshots per state to write
receptor_dsl_selection : str, optional, default='protein and name CA'
MDTraj DSL to use for selecting receptor atoms for alignment
ligand_dsl_selection : str, optional, default='not protein and (mass > 1.5)'
MDTraj DSL to use for selecting ligand atoms to cluster
fully_interacting_state : int, optional, default=0
0 specifies the fully-interacting state
1 specifies the first alchemical state
ligand_rmsd_cutoff : float, optional, default=0.3
RMSD cutoff to use for ligand clustering (in nanometers)
ligand_filter_cutoff : float, optional, default=0.3
Snapshots in which the closest ligand heavy atom is farther than this cutoff (in nanometers) from the protein are filtered out
cluster_filter_threshold : float, optional, default=0.95
Only the most populous clusters that add up to more than this threshold in population are written.

The algorithm
-------------
* Compute per-snapshot weights (using MBAR) representing the relative weight of each snapshot in the fully interacting state
* Cluster the remaining snapshots
* Assign relative populations to the clusters
* Sort clusters by population, writing only the most populous clusters
* Sample representative snapshots from the clusters proportional to their weights, writing out PDB files
* Write out cluster populations

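Examples
--------
A minimal usage sketch of this draft function; the file names below are
illustrative placeholders, not files shipped with YANK:

>>> cluster('complex.pdb', 'complex.nc', output_prefix='clusters',  # doctest: +SKIP
...         nsnapshots_per_cluster=5)
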
"""
# mdtraj works in nanometers
ligand_rmsd_cutoff /= unit.nanometers
ligand_filter_cutoff /= unit.nanometers

topology = md.load(reference_pdb_filename)
solute_indices = topology.top.select('not water')
logger.info('There are {:d} non-water atoms'.format(len(solute_indices)))
topology = topology.atom_slice(solute_indices) # check that this is the same as w

from netCDF4 import Dataset
ncfile = Dataset(netcdf_filename, 'r')

# TODO: Extend this to handle more than one replica
replica_index = 0

# Extract energy trajectories
sampled_energy_matrix = np.transpose(np.array(ncfile.variables['energies'][:,replica_index:(replica_index+1),:], np.float32), axes=[1,2,0])
unsampled_energy_matrix = np.transpose(np.array(ncfile.variables['unsampled_energies'][:,replica_index:(replica_index+1),:], np.float32), axes=[1,2,0])

# Initialize the MBAR matrices in ln form.
n_replicas, n_sampled_states, n_iterations = sampled_energy_matrix.shape
_, n_unsampled_states, _ = unsampled_energy_matrix.shape
logger.info('n_replicas: {:d}'.format(n_replicas))
logger.info('n_sampled_states: {:d}'.format(n_sampled_states))
logger.info('n_iterations: {:d}'.format(n_iterations))

# Remove some frames
# TODO: Change this to instead extract the "good" portion of the trajectory that isn't corrupted
ntrim = 100
logger.info('Trimming {:d} frames from either end'.format(ntrim))
retained_snapshot_indices = list(range(ntrim, (n_iterations-ntrim)))
n_iterations = len(retained_snapshot_indices)
sampled_energy_matrix = sampled_energy_matrix[:,:,retained_snapshot_indices]
unsampled_energy_matrix = unsampled_energy_matrix[:,:,retained_snapshot_indices]

# Extract thermodynamic state indices
replicas_state_indices = np.transpose(np.array(ncfile.variables['states'][retained_snapshot_indices,replica_index:(replica_index+1)], np.int64), axes=[1,0])

# TODO: Pre-filter all states remote from fully-interacting state

# TODO: We could detect the equilibration time and discard data to equilibration here
#[t0, g, Neff_max] = timeseries.detectEquilibration(replicas_state_indices[replica_index,:], nskip=100)

# Compute snapshot weights with MBAR

#
# Note: This comes from multistateanalyzer.py L1445-1479.
# That section could be refactored to be more general to avoid code duplication
#

logger.info('Reformatting energies...')
n_total_states = n_sampled_states + n_unsampled_states
energy_matrix = np.zeros([n_total_states, n_iterations*n_replicas])
samples_per_state = np.zeros([n_total_states], dtype=int)
# Compute shift index for how many unsampled states there were.
# This assume that we set an equal number of unsampled states at the end points.
# This assumes that an equal number of unsampled states was added at each end point.
first_sampled_state = int(n_unsampled_states/2.0)
last_sampled_state = n_total_states - first_sampled_state
# Cast the sampled energy matrix from kln' to ln form.
energy_matrix[first_sampled_state:last_sampled_state, :] = multistate.MultiStateSamplerAnalyzer.reformat_energies_for_mbar(sampled_energy_matrix)
# Determine how many samples and which states they were drawn from.
unique_sampled_states, counts = np.unique(replicas_state_indices, return_counts=True)
# Assign those counts to the correct range of states.
samples_per_state[first_sampled_state:last_sampled_state][unique_sampled_states] = counts
# Add energies of unsampled states to the end points.
if n_unsampled_states > 0:
    energy_matrix[[0, -1], :] = multistate.MultiStateSamplerAnalyzer.reformat_energies_for_mbar(unsampled_energy_matrix)
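# A worked example of the bookkeeping above (illustrative numbers, not taken from this PR):
# with n_sampled_states = 4 and n_unsampled_states = 2, n_total_states = 6,
# first_sampled_state = 1 and last_sampled_state = 5, so rows 1:5 of energy_matrix hold the
# sampled states while rows 0 and -1 hold the two unsampled end-point states.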

# TODO: Should we instead only run MBAR *after* we have already filtered out problematic snapshots?
logger.info('Estimating weights...')
mbar = MBAR(energy_matrix, samples_per_state)
# Extract weights
w_n = mbar.W_nk[:,fully_interacting_state]

# Extract unitcell lengths and angles
# TODO: Make this more general for non-rectilinear boxes
x = np.array(ncfile.variables['box_vectors'][retained_snapshot_indices,replica_index,:,:])
unitcell_lengths = x[:,[0,1,2],[0,1,2]]
unitcell_angles = 90.0 * np.ones(unitcell_lengths.shape, np.float32)

# Extract solute trajectory as MDTraj Trajectory
# NOTE: Only retained snapshots are extracted to speed things up
# NOTE: This will store the whole trajectory in memory
traj = md.Trajectory(ncfile.variables['positions'][retained_snapshot_indices,replica_index,solute_indices,:], topology.top, unitcell_lengths=unitcell_lengths, unitcell_angles=unitcell_angles)


# Remove counterions
# TODO: Is there a better way to eliminate everything but receptor and ligand?
logger.info(traj)
ion_dsl_selection = 'not (resname "Na+" or resname "Cl-")'
indices = traj.top.select(ion_dsl_selection)
traj = traj.atom_slice(indices)
logger.info(traj)

# Check snapshot weights are small
logger.info('Maximum weight from any given snapshot (SHOULD BE SMALL!): {:f}'.format(w_n.max()))
indices = np.argsort(-w_n)
MAX_SNAPSHOT_WEIGHT = 0.01
if w_n.max() > MAX_SNAPSHOT_WEIGHT:
    filename = '%s-outlier.pdb' % (output_prefix)
    logger.warning('WARNING: One snapshot is dominating the weights so clusters and populations will be unreliable')
    logger.warning('Writing outlier to {}'.format(filename))
    snapshot_index = w_n.argmax()
    logger.warning('snapshot {:d} has weight {:f}'.format(snapshot_index, w_n.max()))
    traj[snapshot_index].save(filename)

# Image molecules into periodic box, ensuring ligand is in closest image to receptor
traj.image_molecules(inplace=True)

# Compute minimum heavy atom distance from ligand to protein
residues = [residue for residue in traj.top.residues]
protein_residues = [residue.index for residue in traj.top.residues if residue.is_protein]
logger.info('There are {:d} protein residues'.format(len(protein_residues)))
ligand_residues = [residue.index for residue in traj.top.residues if not residue.is_protein]
logger.info('There are {:d} ligand residues'.format(len(ligand_residues)))

pairs = list(itertools.product(ligand_residues, protein_residues))
distances, contacts = md.compute_contacts(traj, contacts=pairs, scheme='closest-heavy', ignore_nonprotein=False)
min_distances = distances.min(1)
logger.info('Maximum ligand heavy atom distance from protein: {:f} nm'.format(min_distances.max()))
logger.info('Minimum ligand heavy atom distance from protein: {:f} nm'.format(min_distances.min()))

# Filter out snapshots where ligand is too far from the protein
filtered_snapshot_indices = np.where(min_distances <= ligand_filter_cutoff)[0]
logger.info('Retaining {:d} of {:d} snapshots where ligand heavy atoms are less than {:f} nm from protein'.format(len(filtered_snapshot_indices), len(traj), ligand_filter_cutoff))

# Apply the filter and renormalize the snapshot weights
filtered_traj = traj[filtered_snapshot_indices]
filtered_w_n = np.array(w_n[filtered_snapshot_indices])
filtered_w_n /= filtered_w_n.sum() # renormalize

# Align receptor to first frame
atoms_to_align = filtered_traj.top.select(receptor_dsl_selection)
if (len(atoms_to_align) == 0):
    raise Exception("Please check receptor_dsl_selection since no atoms were found in selection!")
logger.info('aligning on {:d} atoms from receptor_dsl_selection'.format(len(atoms_to_align)))
aligned_traj = filtered_traj.superpose(filtered_traj[0], frame=0, atom_indices=atoms_to_align)

# Extract ligand trajectory
ligand_atom_indices = aligned_traj.topology.select(ligand_dsl_selection)
ligand_trajectory = aligned_traj.atom_slice(ligand_atom_indices)
logger.info('{:d} atoms in ligand trajectory'.format(len(ligand_atom_indices)))

# Perform regular spatial clustering on ligand trajectory
nsnapshots, natoms, _ = ligand_trajectory.xyz.shape
x = np.array(ligand_trajectory.xyz).reshape([nsnapshots, natoms*3], order='C')
reg_space = RegularSpatial(d_min=3*natoms*ligand_rmsd_cutoff**2, metric='sqeuclidean').fit([x])
cluster_assignments = reg_space.fit_predict([x])[0]
nclusters = cluster_assignments.max() + 1
logger.info('There are {:d} clusters'.format(nclusters))

# Sort clusters by probability
cluster_probabilities = np.zeros([nclusters], np.float64)
for cluster_index in range(nclusters):
    snapshot_indices = np.where(cluster_assignments == cluster_index)[0]
    cluster_probabilities[cluster_index] = filtered_w_n[snapshot_indices].sum()
# Permute clusters
sorted_indices = np.argsort(-cluster_probabilities)
new_cluster_assignments = np.array(cluster_assignments)
for cluster_index in range(nclusters):
    indices = np.where(cluster_assignments == sorted_indices[cluster_index])[0]
    new_cluster_assignments[indices] = cluster_index
cluster_assignments = new_cluster_assignments
cluster_probabilities = cluster_probabilities[sorted_indices]
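# Worked example of the relabeling above (illustrative numbers, not taken from this PR):
# with cluster_probabilities = [0.1, 0.6, 0.3], sorted_indices = [1, 2, 0], so old cluster 1
# is relabeled 0, old cluster 2 becomes 1, old cluster 0 becomes 2, and the reordered
# probabilities are [0.6, 0.3, 0.1].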

# Write cluster populations
for cluster_index in range(nclusters):
    logger.info('Cluster {:5d} : {:12.8f}'.format(cluster_index, cluster_probabilities[cluster_index]))

cumsum = np.cumsum(cluster_probabilities)
cutoff_index = np.where(cumsum > cluster_filter_threshold)[0][0] # first index at which the cumulative probability exceeds the threshold
nclusters = max(cutoff_index, 1)
logger.info('There are {:d} clusters after retaining only those where the cumulative weight exceeds {:f}'.format(nclusters, cluster_filter_threshold))
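# Worked example of the cutoff above (illustrative numbers, not taken from this PR):
# cluster_probabilities = [0.7, 0.2, 0.06, 0.04] with cluster_filter_threshold = 0.95 give
# cumsum = [0.7, 0.9, 0.96, 1.0], so cutoff_index = 2 and only the two most populous
# clusters are written out below.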

# Write reference protein conformation
receptor_atom_indices = aligned_traj.topology.select('protein')
filename = '%s-reference.pdb' % (output_prefix)
logger.info('Writing reference coordinates to {}'.format(filename))
aligned_traj[0].atom_slice(receptor_atom_indices).save(filename)

# Write aligned frames
for cluster_index in range(nclusters):
    indices = np.where(cluster_assignments == cluster_index)[0]
    # Remove indices with zero probability
    indices = indices[filtered_w_n[indices] > 0.0]

    nsnapshots = len(indices)
    logger.info('Cluster {:5d} : pop {:12.8f} : {:8d} members'.format(cluster_index, cluster_probabilities[cluster_index], nsnapshots))

    # Sample frames
    filename = '%s-cluster%03d.pdb' % (output_prefix, cluster_index)
    logger.info(' writing {}'.format(filename))
    if nsnapshots <= nsnapshots_per_cluster:
        aligned_traj[indices].save(filename)
    else:
        p = filtered_w_n[indices] / filtered_w_n[indices].sum()
        sampled_indices = np.random.choice(indices, size=nsnapshots_per_cluster, p=p, replace=False)
        aligned_traj[sampled_indices].save(filename)
# Write cluster populations to a file
filename = '%s-populations.txt' % (output_prefix)
logger.info('Writing populations to {}'.format(filename))
outfile = open(filename, 'w')
for cluster_index in range(nclusters):
    outfile.write('%05d %12.8f\n' % (cluster_index, cluster_probabilities[cluster_index]))
outfile.close()