Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change get_simsteps function to handle NVT equilibration in plain MD protocol #647

Merged
merged 7 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,13 @@ def _run_simulation(
# Get the relevant simulation steps
mc_steps = settings['integrator_settings'].n_steps.m

equil_steps, prod_steps = settings_validation.get_simsteps(
equil_length=settings['simulation_settings'].equilibration_length,
prod_length=settings['simulation_settings'].production_length,
equil_steps = settings_validation.get_simsteps(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great, de-hardcoding re-usable methods is super useful!

sim_length=settings['simulation_settings'].equilibration_length,
timestep=settings['integrator_settings'].timestep,
mc_steps=mc_steps,
)
prod_steps = settings_validation.get_simsteps(
sim_length=settings['simulation_settings'].production_length,
timestep=settings['integrator_settings'].timestep,
mc_steps=mc_steps,
)
Expand Down
11 changes: 7 additions & 4 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,10 +637,13 @@ def run(self, *, dry=False, verbose=True,
settings_validation.validate_timestep(
forcefield_settings.hydrogen_mass, timestep
)
equil_steps, prod_steps = settings_validation.get_simsteps(
equil_length=sim_settings.equilibration_length,
prod_length=sim_settings.production_length,
timestep=timestep, mc_steps=mc_steps
equil_steps = settings_validation.get_simsteps(
sim_length=sim_settings.equilibration_length,
timestep=timestep, mc_steps=mc_steps,
)
prod_steps = settings_validation.get_simsteps(
sim_length=sim_settings.production_length,
timestep=timestep, mc_steps=mc_steps,
)

solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA)
Expand Down
42 changes: 16 additions & 26 deletions openfe/protocols/openmm_utils/settings_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,37 @@ def validate_timestep(hmass: float, timestep: unit.Quantity):
raise ValueError(errmsg)


def get_simsteps(equil_length: unit.Quantity, prod_length: unit.Quantity,
timestep: unit.Quantity, mc_steps: int) -> Tuple[int, int]:
def get_simsteps(sim_length: unit.Quantity,
timestep: unit.Quantity, mc_steps: int) -> int:
"""
Gets and validates the number of equilibration and production steps.
Gets and validates the number of simulation steps.

Parameters
----------
equil_length : unit.Quantity
Simulation equilibration length.
prod_length : unit.Quantity
Simulation production length.
sim_length : unit.Quantity
Simulation length.
timestep : unit.Quantity
Integration timestep.
mc_steps : int
Number of integration timesteps between MCMC moves.

Returns
-------
equil_steps : int
The number of equilibration timesteps.
prod_steps : int
The number of production timesteps.
sim_steps : int
The number of simulation timesteps.
"""

equil_time = round(equil_length.to('attosecond').m)
prod_time = round(prod_length.to('attosecond').m)
sim_time = round(sim_length.to('attosecond').m)
ts = round(timestep.to('attosecond').m)

equil_steps, mod = divmod(equil_time, ts)
sim_steps, mod = divmod(sim_time, ts)
if mod != 0:
raise ValueError("Equilibration time not divisible by timestep")
prod_steps, mod = divmod(prod_time, ts)
if mod != 0:
raise ValueError("Production time not divisible by timestep")
raise ValueError("Simulation time not divisible by timestep")

for var in [("Equilibration", equil_steps, equil_time),
("Production", prod_steps, prod_time)]:
if (var[1] % mc_steps) != 0:
errmsg = (f"{var[0]} time {var[2]/1000000} ps should contain a "
"number of steps divisible by the number of integrator "
f"timesteps between MC moves {mc_steps}")
raise ValueError(errmsg)
if (sim_steps % mc_steps) != 0:
errmsg = (f"Simulation time {sim_time/1000000} ps should contain a "
"number of steps divisible by the number of integrator "
f"timesteps between MC moves {mc_steps}")
raise ValueError(errmsg)

return equil_steps, prod_steps
return sim_steps
49 changes: 18 additions & 31 deletions openfe/tests/protocols/test_openmmutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,30 @@ def test_validate_timestep():
settings_validation.validate_timestep(2.0, 4.0 * unit.femtoseconds)


@pytest.mark.parametrize('e,p,ts,mc,es,ps', [
[1 * unit.nanoseconds, 5 * unit.nanoseconds, 4 * unit.femtoseconds,
250, 250000, 1250000],
[1 * unit.picoseconds, 1 * unit.picoseconds, 2 * unit.femtoseconds,
250, 500, 500],
@pytest.mark.parametrize('s,ts,mc,es', [
[5 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 1250000],
[1 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 250000],
[1 * unit.picoseconds, 2 * unit.femtoseconds, 250, 500],
])
def test_get_simsteps(e, p, ts, mc, es, ps):
equil_steps, prod_steps = settings_validation.get_simsteps(e, p, ts, mc)
def test_get_simsteps(s, ts, mc, es):
sim_steps = settings_validation.get_simsteps(s, ts, mc)

assert equil_steps == es
assert prod_steps == ps
assert sim_steps == es


@pytest.mark.parametrize('nametype, timelengths', [
['Equilibration', [1.003 * unit.picoseconds, 1 * unit.picoseconds]],
['Production', [1 * unit.picoseconds, 1.003 * unit.picoseconds]],
])
def test_get_simsteps_indivisible_simtime(nametype, timelengths):
errmsg = f"{nametype} time not divisible by timestep"
def test_get_simsteps_indivisible_simtime():
errmsg = "Simulation time not divisible by timestep"
timelength = 1.003 * unit.picosecond
with pytest.raises(ValueError, match=errmsg):
settings_validation.get_simsteps(
timelengths[0],
timelengths[1],
2 * unit.femtoseconds,
100)
settings_validation.get_simsteps(timelength, 2 * unit.femtoseconds, 100)


@pytest.mark.parametrize('nametype, timelengths', [
['Equilibration', [1 * unit.picoseconds, 10 * unit.picoseconds]],
['Production', [10 * unit.picoseconds, 1 * unit.picoseconds]],
])
def test_mc_indivisible(nametype, timelengths):
errmsg = f"{nametype} time 1.0 ps should contain"
def test_mc_indivisible():
errmsg = "Simulation time 1.0 ps should contain"
timelength = 1 * unit.picoseconds
with pytest.raises(ValueError, match=errmsg):
settings_validation.get_simsteps(
timelengths[0], timelengths[1],
2 * unit.femtoseconds, 1000)
timelength, 2 * unit.femtoseconds, 1000)


def test_get_alchemical_components(benzene_modifications,
Expand All @@ -90,7 +77,7 @@ def test_get_alchemical_components(benzene_modifications,

def test_duplicate_chemical_components(benzene_modifications):
stateA = openfe.ChemicalSystem({'A': benzene_modifications['toluene'],
'B': benzene_modifications['toluene'],})
'B': benzene_modifications['toluene'], })
stateB = openfe.ChemicalSystem({'A': benzene_modifications['toluene']})

errmsg = "state A components B:"
Expand Down Expand Up @@ -139,7 +126,7 @@ def test_multiple_proteins(T4_protein_component):
def test_get_components_gas(benzene_modifications):

state = openfe.ChemicalSystem({'A': benzene_modifications['benzene'],
'B': benzene_modifications['toluene'],})
'B': benzene_modifications['toluene'], })

s, p, mols = system_validation.get_components(state)

Expand All @@ -152,7 +139,7 @@ def test_components_solvent(benzene_modifications):

state = openfe.ChemicalSystem({'S': openfe.SolventComponent(),
'A': benzene_modifications['benzene'],
'B': benzene_modifications['toluene'],})
'B': benzene_modifications['toluene'], })

s, p, mols = system_validation.get_components(state)

Expand Down