Skip to content

Commit

Permalink
fix typing, fix 2D plotting issues (#896)
Browse files Browse the repository at this point in the history
* fix typing, fix 2D plotting issues

* Update plotting.py

* Ignore extra flatten calls

* Add smoke tests

* Add changelog

* ignore one more flatten call

---------

Co-authored-by: Mike Henry <11765982+mikemhenry@users.noreply.github.com>
  • Loading branch information
IAlibay and mikemhenry authored Jul 23, 2024
1 parent cc5ed8c commit 7615ccc
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 8 deletions.
25 changes: 25 additions & 0 deletions news/pr_896.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
**Added:**

* <news item>

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* 2D RMSD plotting now allows for fewer than 5 states (PR #896).
* 2D RMSD plotting no longer draws empty axes when
the number of states - 1 is not divisible by 4 (PR #896).

**Security:**

* <news item>
36 changes: 29 additions & 7 deletions openfe/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,40 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes:
)

# anotate axes
base_settings = {
base_settings: dict[str, Union[str, int]] = {
'size': 10, 'va': 'center', 'ha': 'center', 'color': 'k',
'family': 'sans-serif'
}
for i in range(num_states):
ax.annotate(
i, xy=(i + 0.5, 1), xytext=(i + 0.5, num_states + 0.5),
text=f"{i}",
xy=(i + 0.5, 1),
xytext=(i + 0.5, num_states + 0.5),
xycoords='data',
textcoords=None,
arrowprops=None,
annotation_clip=None,
**base_settings,
)
ax.annotate(
i, xy=(-0.5, num_states - (num_states - 0.5)),
text=f"{i}",
xy=(-0.5, num_states - (num_states - 0.5)),
xytext=(-0.5, num_states - (i + 0.5)),
xycoords='data',
textcoords=None,
arrowprops=None,
annotation_clip=None,
**base_settings,
)

ax.annotate(
r"$\lambda$", xy=(-0.5, num_states - (num_states - 0.5)),
r"$\lambda$",
xy=(-0.5, num_states - (num_states - 0.5)),
xytext=(-0.5, num_states + 0.5),
xycoords='data',
textcoords=None,
arrowprops=None,
annotation_clip=None,
**base_settings,
)

Expand Down Expand Up @@ -278,15 +294,21 @@ def plot_2D_rmsd(data: list[list[float]],
fig, axes = plt.subplots(nrows, 4)

for i, (arr, ax) in enumerate(
zip(twod_rmsd_arrs, chain.from_iterable(axes))):
zip(twod_rmsd_arrs, axes.flatten())): # type: ignore
ax.imshow(arr,
vmin=0, vmax=vmax,
cmap=plt.get_cmap('cividis'))
ax.axis('off') # turn off ticks/labels
ax.set_title(f'State {i}')

plt.colorbar(axes[0][0].images[0],
cax=axes[-1][-1],
# if we have any leftover plots then we turn them off
# except the last one!
overage = len(axes.flatten()) - len(twod_rmsd_arrs) # type: ignore
for i in range(overage, len(axes.flatten())-1): # type: ignore
axes.flatten()[i].set_axis_off() # type: ignore

plt.colorbar(axes.flatten()[0].images[0], # type: ignore
cax=axes.flatten()[-1], # type: ignore
label="RMSD scale (A)",
orientation="horizontal")

Expand Down
14 changes: 14 additions & 0 deletions openfe/tests/analysis/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pytest
from openfe.analysis.plotting import (
plot_lambda_transition_matrix,
plot_2D_rmsd,
)


Expand Down Expand Up @@ -150,3 +152,15 @@ def test_mbar_overlap_plot_high_warn(matrix):
def test_mbar_overlap_plot():
ax = plot_lambda_transition_matrix(MBAR_OVERLAP_NORMAL)
assert isinstance(ax, matplotlib.axes.Axes)


@pytest.mark.parametrize('num', [i for i in range(1, 30)])
def test_plot_2D_rmsd(num):
"""
Smoke test:
Loop through and test plotting fictitious 2D data
"""
points = num * (num-1) // 2
data = [[0.5 for x in range(points)] for i in range(num)]
fig = plot_2D_rmsd(data)
plt.close(fig)
2 changes: 1 addition & 1 deletion openfe/tests/protocols/test_openmm_equil_rfe_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,7 +1633,7 @@ def test_filenotfound_replica_states(self, protocolresult):
def test_get_charge_difference(mapping_name, result, request):
mapping = request.getfixturevalue(mapping_name)
if result != 0:
ion = 'Na\+' if result == -1 else 'Cl\-'
ion = r'Na\+' if result == -1 else r'Cl\-'
wmsg = (f"A charge difference of {result} is observed "
"between the end states. This will be addressed by "
f"transforming a water into a {ion} ion")
Expand Down

0 comments on commit 7615ccc

Please sign in to comment.