Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split gather #502

Merged
merged 13 commits into from
Aug 2, 2023
269 changes: 171 additions & 98 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,40 @@
from openfecli import OFECommandPlugin
import pathlib

from typing import Tuple
dwhswenson marked this conversation as resolved.
Show resolved Hide resolved

def _get_column(val):
import numpy as np
if val == 0:
return 0

log10 = np.log10(val)

if log10 >= 0.0:
col = np.floor(log10 + 1)
else:
col = np.floor(log10)
return int(col)

dwhswenson marked this conversation as resolved.
Show resolved Hide resolved

def format_estimate_uncertainty(
est: float,
unc: float,
unc_prec: int = 1,
) -> Tuple[str, str]:
import numpy as np
# get the last column needed for uncertainty
unc_col = _get_column(unc) - (unc_prec - 1)

if unc_col < 0:
est_str = f"{est:.{-unc_col}f}"
unc_str = f"{unc:.{-unc_col}f}"
else:
est_str = f"{np.round(est, -unc_col + 1)}"
unc_str = f"{np.round(unc, -unc_col + 1)}"

return est_str, unc_str


def is_results_json(f):
# sanity check on files before we try and deserialize
Expand Down Expand Up @@ -51,6 +85,117 @@ def legacy_get_type(res_fn):
return 'complex'


def _get_ddgs(legs):
import numpy as np
DDGs = []
for ligpair, vals in sorted(legs.items()):
DDGbind = None
DDGhyd = None
bind_unc = None
hyd_unc = None

if 'complex' in vals and 'solvent' in vals:
DG1_mag, DG1_unc = vals['complex']
DG2_mag, DG2_unc = vals['solvent']
if not ((DG1_mag is None) or (DG2_mag is None)):
# DDG(2,1)bind = DG(1->2)complex - DG(1->2)solvent
DDGbind = (DG1_mag - DG2_mag).m
bind_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))
elif 'solvent' in vals and 'vacuum' in vals:
DG1_mag, DG1_unc = vals['solvent']
DG2_mag, DG2_unc = vals['vacuum']
if not ((DG1_mag is None) or (DG2_mag is None)):
DDGhyd = (DG1_mag - DG2_mag).m
hyd_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))
else: # -no-cov-
raise RuntimeError(f"Unknown DDG type for {vals}")

DDGs.append((*ligpair, DDGbind, bind_unc, DDGhyd, hyd_unc))

return DDGs

dwhswenson marked this conversation as resolved.
Show resolved Hide resolved

def _write_ddg(legs, writer):
DDGs = _get_ddgs(legs)
writer.writerow(["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)",
"uncertainty (kcal/mol)"])
for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs:
name = f"{ligB}, {ligA}"
if DDGbind is not None:
DDGbind, bind_unc = format_estimate_uncertainty(DDGbind,
bind_unc)
writer.writerow([ligA, ligB, DDGbind, bind_unc])
if DDGhyd is not None:
DDGhyd, hyd_unc = format_estimate_uncertainty(DDGbind,
bind_unc)
writer.writerow([ligA, ligB, DDGhyd, hyd_unc])

dwhswenson marked this conversation as resolved.
Show resolved Hide resolved

def _write_raw_dg(legs, writer):
writer.writerow(["leg", "ligand_i", "ligand_j", "DG(i->j) (kcal/mol)",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the optimal case I would like this to write out all the individual replicas. Any thoughts on the necessary changes to results to ensure this? (I suspect it's a method we'd have to add at the gufe level)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be protocol-specific. Our default protocol has, in its outputs dict, outputs["unit_estimate"], which we could extract. But we make no promise that every unit will have this. Indeed, it doesn't make sense to do so: if you used a separate parametrization unit, there would be no meaning to that unit having a "unit_estimate" result.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be something we need to add to the abstract class - a method for getting a breakdown by repeat.

"uncertainty (kcal/mol)"])
for ligpair, vals in sorted(legs.items()):
name = ', '.join(ligpair)
for simtype, (m, u) in sorted(vals.items()):
if m is None:
m, u = 'NaN', 'NaN'
else:
m, u = format_estimate_uncertainty(m.m, u.m)
writer.writerow([simtype, *ligpair, m, u])

dwhswenson marked this conversation as resolved.
Show resolved Hide resolved

def _write_dg_mle(legs, writer):
import networkx as nx
import numpy as np
from cinnabar.stats import mle
DDGs = _get_ddgs(legs)
MLEs = []
# 4b) perform MLE
g = nx.DiGraph()
nm_to_idx = {}
DDGbind_count = 0
for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs:
if DDGbind is None:
continue
DDGbind_count += 1

# tl;dr this is super paranoid, but safer for now:
# cinnabar seems to rely on the ordering of values within the graph
# to correspond to the matrix that comes out from mle()
# internally they also convert the ligand names to ints, which I think
# has a side effect of giving the graph nodes a predictable order.
# fwiw this code didn't affect ordering locally
try:
idA = nm_to_idx[ligA]
except KeyError:
idA = len(nm_to_idx)
nm_to_idx[ligA] = idA
try:
idB = nm_to_idx[ligB]
except KeyError:
idB = len(nm_to_idx)
nm_to_idx[ligB] = idB

g.add_edge(
idA, idB, calc_DDG=DDGbind, calc_dDDG=bind_unc,
)
if DDGbind_count > 2:
idx_to_nm = {v: k for k, v in nm_to_idx.items()}

f_i, df_i = mle(g, factor='calc_DDG')
df_i = np.diagonal(df_i) ** 0.5

for node, f, df in zip(g.nodes, f_i, df_i):
ligname = idx_to_nm[node]
MLEs.append((ligname, f, df))

writer.writerow(["ligand", "DG(MLE) (kcal/mol)",
"uncertainty (kcal/mol)"])
for ligA, DG, unc_DG in MLEs:
DG, unc_DG = format_estimate_uncertainty(DG, unc_DG)
writer.writerow([ligA, DG, unc_DG])


@click.command(
'gather',
short_help="Gather result jsons for network of RFE results into a TSV file"
Expand All @@ -59,10 +204,20 @@ def legacy_get_type(res_fn):
type=click.Path(dir_okay=True, file_okay=False,
path_type=pathlib.Path),
required=True)
@click.option(
'--report',
type=click.Choice(['dg', 'ddg', 'legs'], case_sensitive=False),
default="dg", show_default=True,
help=(
"What data to report. 'dg' gives maximum-likelihood estimate of "
"asbolute deltaG, 'ddg' gives delta-delta-G, and 'leg' gives the "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"asbolute deltaG, 'ddg' gives delta-delta-G, and 'leg' gives the "
"asbolute deltaG, 'ddg' gives delta-delta-G, and 'legs' gives the "

ddg_legs? environments? transformations? raw?

@hannahbaumann @RiesBen please make a choice

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vote for dg_raw, though it's not a strong preference.

"raw result of the deltaG for a leg."
)
)
@click.option('output', '-o',
type=click.File(mode='w'),
default='-')
def gather(rootdir, output):
def gather(rootdir, output, report):
"""Gather simulation result jsons of relative calculations to a tsv file

Will walk ROOTDIR recursively and find all results files ending in .json
Expand Down Expand Up @@ -92,14 +247,7 @@ def gather(rootdir, output):
"""
from collections import defaultdict
import glob
import networkx as nx
import numpy as np
from cinnabar.stats import mle

def dp2(v: float) -> str:
# turns 0.0012345 -> '0.0012', round() would get this wrong
return np.format_float_positional(v, precision=2, trim='0',
fractional=False)
import csv

# 1) find all possible jsons
json_fns = glob.glob(str(rootdir) + '/**/*json', recursive=True)
Expand Down Expand Up @@ -129,103 +277,28 @@ def dp2(v: float) -> str:

legs[names][simtype] = result['estimate'], result['uncertainty']

# 4a for each ligand pair, resolve legs
DDGs = []
for ligpair, vals in sorted(legs.items()):
DDGbind = None
DDGhyd = None
bind_unc = None
hyd_unc = None

if 'complex' in vals and 'solvent' in vals:
DG1_mag, DG1_unc = vals['complex']
DG2_mag, DG2_unc = vals['solvent']
if not ((DG1_mag is None) or (DG2_mag is None)):
# DDG(2,1)bind = DG(1->2)complex - DG(1->2)solvent
DDGbind = (DG1_mag - DG2_mag).m
bind_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))
if 'solvent' in vals and 'vacuum' in vals:
DG1_mag, DG1_unc = vals['solvent']
DG2_mag, DG2_unc = vals['vacuum']
if not ((DG1_mag is None) or (DG2_mag is None)):
DDGhyd = (DG1_mag - DG2_mag).m
hyd_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))

DDGs.append((*ligpair, DDGbind, bind_unc, DDGhyd, hyd_unc))

MLEs = []
# 4b) perform MLE
g = nx.DiGraph()
nm_to_idx = {}
DDGbind_count = 0
for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs:
if DDGbind is None:
continue
DDGbind_count += 1

# tl;dr this is super paranoid, but safer for now:
# cinnabar seems to rely on the ordering of values within the graph
# to correspond to the matrix that comes out from mle()
# internally they also convert the ligand names to ints, which I think
# has a side effect of giving the graph nodes a predictable order.
# fwiw this code didn't affect ordering locally
try:
idA = nm_to_idx[ligA]
except KeyError:
idA = len(nm_to_idx)
nm_to_idx[ligA] = idA
try:
idB = nm_to_idx[ligB]
except KeyError:
idB = len(nm_to_idx)
nm_to_idx[ligB] = idB

g.add_edge(
idA, idB, calc_DDG=DDGbind, calc_dDDG=bind_unc,
)
if DDGbind_count > 2:
idx_to_nm = {v: k for k, v in nm_to_idx.items()}

f_i, df_i = mle(g, factor='calc_DDG')
df_i = np.diagonal(df_i) ** 0.5
writer = csv.writer(
output,
delimiter="\t",
lineterminator="\n", # to exactly reproduce previous, prefer "\r\n"
)

for node, f, df in zip(g.nodes, f_i, df_i):
ligname = idx_to_nm[node]
MLEs.append((ligname, f, df))

output.write('measurement\ttype\tligand_i\tligand_j\testimate (kcal/mol)'
'\tuncertainty (kcal/mol)\n')
# 5a) write out MLE values
for ligA, DG, unc_DG in MLEs:
DG, unc_DG = dp2(DG), dp2(unc_DG)
output.write(f'DGbind({ligA})\tDG(MLE)\tZero\t{ligA}\t{DG}\t{unc_DG}\n')

# 5b) write out DDG values
for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs:
name = f"{ligB}, {ligA}"
if DDGbind is not None:
DDGbind, bind_unc = dp2(DDGbind), dp2(bind_unc)
output.write(f'DDGbind({name})\tRBFE\t{ligA}\t{ligB}'
f'\t{DDGbind}\t{bind_unc}\n')
if DDGhyd is not None:
DDGhyd, hyd_unc = dp2(DDGhyd), dp2(hyd_unc)
output.write(f'DDGhyd({name})\tRHFE\t{ligA}\t{ligB}\t'
f'{DDGhyd}\t{hyd_unc}\n')

# 5c) write out each leg
for ligpair, vals in sorted(legs.items()):
name = ', '.join(ligpair)
for simtype, (m, u) in sorted(vals.items()):
if m is None:
m, u = 'NaN', 'NaN'
else:
m, u = dp2(m.m), dp2(u.m)
output.write(f'DG{simtype}({name})\t{simtype}\t{ligpair[0]}\t'
f'{ligpair[1]}\t{m}\t{u}\n')
writing_func = {
'dg': _write_dg_mle,
'ddg': _write_ddg,
'legs': _write_raw_dg,
}[report.lower()]
writing_func(legs, writer)


PLUGIN = OFECommandPlugin(
command=gather,
section='Quickrun Executor',
requires_ofe=(0, 6),
)

if __name__ == "__main__":
gather()