Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add devtools for json results generation + replace old ones #691

Merged
merged 3 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added devtools/data/AHFEProtocol_json_results.gz
Binary file not shown.
Binary file added devtools/data/MDProtocol_json_results.gz
Binary file not shown.
Binary file added devtools/data/RHFEProtocol_json_results.gz
Binary file not shown.
124 changes: 124 additions & 0 deletions devtools/data/gen-serialized-results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import gzip
import json
import logging
import pathlib
import tempfile
from openff.toolkit import Molecule
from openff.units import unit
from kartograf.atom_aligner import align_mol_shape
from kartograf import KartografAtomMapper
import gufe
from gufe.tokenization import JSON_HANDLER
import openfe
from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocol
from openfe.protocols.openmm_afe import AbsoluteSolvationProtocol
from openfe.protocols.openmm_rfe import RelativeHybridTopologyProtocol


logger = logging.getLogger(__name__)

LIGA = "[H]C([H])([H])C([H])([H])C(=O)C([H])([H])C([H])([H])[H]"
LIGB = "[H]C([H])([H])C(=O)C([H])([H])C([H])([H])C([H])([H])[H]"


def get_molecule(smi, name):
m = Molecule.from_smiles(smi)
m.generate_conformers()
m.assign_partial_charges(partial_charge_method="am1bcc")
return openfe.SmallMoleculeComponent.from_openff(m, name=name)


def execute_and_serialize(dag, protocol, simname):
logger.info(f"running {simname}")
with tempfile.TemporaryDirectory() as tmpdir:
workdir = pathlib.Path(tmpdir)
dagres = gufe.protocols.execute_DAG(
dag,
shared_basedir=workdir,
scratch_basedir=workdir,
keep_shared=False,
n_retries=3
)
protres = protocol.gather([dagres])

outdict = {
"estimate": protres.get_estimate(),
"uncertainty": protres.get_uncertainty(),
"protocol_result": protres.to_dict(),
"unit_results": {
unit.key: unit.to_keyed_dict()
for unit in dagres.protocol_unit_results
}
}

with gzip.open(f"{simname}_json_results.gz", 'wt') as zipfile:
json.dump(outdict, zipfile, cls=JSON_HANDLER.encoder)


def generate_md_json(smc):
settings = PlainMDProtocol.default_settings()
settings.simulation_settings.equilibration_length_nvt = 0.01 * unit.nanosecond
settings.simulation_settings.equilibration_length = 0.01 * unit.nanosecond
settings.simulation_settings.production_length = 0.01 * unit.nanosecond
settings.system_settings.nonbonded_method = "nocutoff"
protocol = PlainMDProtocol(settings=settings)
system = openfe.ChemicalSystem({"ligand": smc})
dag = protocol.create(stateA=system, stateB=system, mapping=None)

execute_and_serialize(dag, protocol, "MDProtocol")


def generate_ahfe_json(smc):
settings = AbsoluteSolvationProtocol.default_settings()
settings.solvent_simulation_settings.equilibration_length = 10 * unit.picosecond
settings.solvent_simulation_settings.production_length = 500 * unit.picosecond
settings.vacuum_simulation_settings.equilibration_length = 10 * unit.picosecond
settings.vacuum_simulation_settings.production_length = 1000 * unit.picosecond
settings.alchemical_settings.lambda_elec_windows = 5
settings.alchemical_settings.lambda_vdw_windows = 9
settings.alchemsampler_settings.n_repeats = 3
settings.alchemsampler_settings.n_replicas = 14
settings.alchemsampler_settings.online_analysis_target_error = 0.2 * unit.boltzmann_constant * unit.kelvin
settings.vacuum_engine_settings.compute_platform = 'CPU'
settings.solvent_engine_settings.compute_platform = 'CUDA'

protocol = AbsoluteSolvationProtocol(settings=settings)
sysA = openfe.ChemicalSystem(
{"ligand": smc, "solvent": openfe.SolventComponent()}
)
sysB = openfe.ChemicalSystem(
{"solvent": openfe.SolventComponent()}
)

dag = protocol.create(stateA=sysA, stateB=sysB, mapping=None)

execute_and_serialize(dag, protocol, "AHFEProtocol")


def generate_rfe_json(smcA, smcB):
settings = RelativeHybridTopologyProtocol.default_settings()
settings.simulation_settings.equilibration_length = 10 * unit.picosecond
settings.simulation_settings.production_length = 250 * unit.picosecond
settings.system_settings.nonbonded_method = "nocutoff"
protocol = RelativeHybridTopologyProtocol(settings=settings)

a_smcB = align_mol_shape(smcB, ref_mol=smcA)
mapper = KartografAtomMapper(atom_map_hydrogens=True)
mapping = next(mapper.suggest_mappings(smcA, a_smcB))

systemA = openfe.ChemicalSystem({'ligand': smcA})
systemB = openfe.ChemicalSystem({'ligand': a_smcB})

dag = protocol.create(
stateA=systemA, stateB=systemB, mapping={'ligands': mapping}
)

execute_and_serialize(dag, protocol, "RHFEProtocol")


if __name__ == "__main__":
molA = get_molecule(LIGA, "ligandA")
molB = get_molecule(LIGB, "ligandB")
generate_md_json(molA)
generate_ahfe_json(molA)
generate_rfe_json(molA, molB)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed openfe/tests/data/openmm_md/md_results.json.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
8 changes: 4 additions & 4 deletions openfe/tests/protocols/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,10 @@ def toluene_many_solv_system(benzene_modifications):

@pytest.fixture
def rfe_transformation_json() -> str:
"""string of a RFE result of quickrun"""
"""string of a RFE results similar to quickrun"""
d = resources.files('openfe.tests.data.openmm_rfe')

with gzip.open((d / 'RFE-ProtocolUnitResult-0f3457edf947483aa03d0f4fe88bf566.json.gz').as_posix(), 'r') as f: # type: ignore
with gzip.open((d / 'RHFEProtocol_json_results.gz').as_posix(), 'r') as f: # type: ignore
return f.read().decode() # type: ignore


Expand All @@ -208,7 +208,7 @@ def afe_solv_transformation_json() -> str:
string of a Absolute Solvation result (CN in water) generated by quickrun
"""
d = resources.files('openfe.tests.data.openmm_afe')
fname = "CN_absolute_solvation_transformation.json.gz"
fname = "AHFEProtocol_json_results.gz"

with gzip.open((d / fname).as_posix(), 'r') as f: # type: ignore
return f.read().decode() # type: ignore
Expand All @@ -220,7 +220,7 @@ def md_json() -> str:
string of a MD result (TYK ligand lig_ejm_31 in water) generated by quickrun
"""
d = resources.files('openfe.tests.data.openmm_md')
fname = "md_results.json.gz"
fname = "MDProtocol_json_results.gz"

with gzip.open((d / fname).as_posix(), 'r') as f: # type: ignore
return f.read().decode() # type: ignore
10 changes: 5 additions & 5 deletions openfe/tests/protocols/test_openmm_afe_solvation_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,15 +600,15 @@ def test_get_estimate(self, protocolresult):
est = protocolresult.get_estimate()

assert est
assert est.m == pytest.approx(-2.977553138764437)
assert est.m == pytest.approx(-2.7514342223922856)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)

def test_get_uncertainty(self, protocolresult):
est = protocolresult.get_uncertainty()

assert est
assert est.m == pytest.approx(0.19617297299036018)
assert est.m == pytest.approx(0.1417058859527063)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)

Expand Down Expand Up @@ -649,7 +649,7 @@ def test_get_overlap_matrices(self, key, protocolresult):

ovp1 = ovp[key][0]
assert isinstance(ovp1['matrix'], np.ndarray)
assert ovp1['matrix'].shape == (15, 15)
assert ovp1['matrix'].shape == (14, 14)

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
def test_get_replica_transition_statistics(self, key, protocolresult):
Expand All @@ -661,8 +661,8 @@ def test_get_replica_transition_statistics(self, key, protocolresult):
rpx1 = rpx[key][0]
assert 'eigenvalues' in rpx1
assert 'matrix' in rpx1
assert rpx1['eigenvalues'].shape == (15,)
assert rpx1['matrix'].shape == (15, 15)
assert rpx1['eigenvalues'].shape == (14,)
assert rpx1['matrix'].shape == (14, 14)

@pytest.mark.parametrize('key', ['solvent', 'vacuum'])
def test_equilibration_iterations(self, key, protocolresult):
Expand Down
4 changes: 2 additions & 2 deletions openfe/tests/protocols/test_openmm_equil_rfe_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,15 +1414,15 @@ def test_get_estimate(self, protocolresult):
est = protocolresult.get_estimate()

assert est
assert est.m == pytest.approx(3.5531577581450953)
assert est.m == pytest.approx(16.887389)
assert isinstance(est, unit.Quantity)
assert est.is_compatible_with(unit.kilojoule_per_mole)

def test_get_uncertainty(self, protocolresult):
est = protocolresult.get_uncertainty()

assert est
assert est.m == pytest.approx(0.03431704941311493)
assert est.m == pytest.approx(0.12354885)
assert isinstance(est, unit.Quantity)
assert est.is_compatible_with(unit.kilojoule_per_mole)

Expand Down
2 changes: 1 addition & 1 deletion openfe/tests/protocols/test_solvation_afe_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_key_stable(self):

class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin):
cls = openmm_afe.AbsoluteSolvationProtocolResult
key = "AbsoluteSolvationProtocolResult-e7d74b8ccc009d071b8c6eb0420da4bf"
key = "AbsoluteSolvationProtocolResult-291fef7bbbad3ffda898be6c01a22f16"
repr = f"<{key}>"

@pytest.fixture()
Expand Down
Loading