From 34f0bc6b3f494ec46be7f4cb2af5058b1a3e4b70 Mon Sep 17 00:00:00 2001 From: Irfan Alibay Date: Wed, 22 Nov 2023 16:13:00 +0000 Subject: [PATCH] Use uuid for unit repeat ids for HFE protocol (#650) * use uuid for unit repeat ids for HFE protocol * Update test_openmm_afe_solvation_protocol.py --- openfe/protocols/openmm_afe/equil_solvation_afe_method.py | 5 +++-- openfe/tests/protocols/test_openmm_afe_solvation_protocol.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index 086ec6d09..6de7daa1d 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -41,6 +41,7 @@ from openmmtools import multistate from typing import Dict, Optional, Union from typing import Any, Iterable +import uuid from gufe import ( settings, @@ -556,7 +557,7 @@ def _create( stateA=stateA, stateB=stateB, settings=self.settings, alchemical_components=alchem_comps, - generation=0, repeat_id=i, + generation=0, repeat_id=int(uuid.uuid4()), name=(f"Absolute Solvation, {alchname} solvent leg: " f"repeat {i} generation 0"), ) @@ -570,7 +571,7 @@ def _create( stateA=stateA, stateB=stateB, settings=self.settings, alchemical_components=alchem_comps, - generation=0, repeat_id=i, + generation=0, repeat_id=int(uuid.uuid4()), name=(f"Absolute Solvation, {alchname} vacuum leg: " f"repeat {i} generation 0"), ) diff --git a/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py b/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py index 5df0c0cfc..c6d02694e 100644 --- a/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py +++ b/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py @@ -552,8 +552,8 @@ def test_unit_tagging(benzene_solvation_dag, tmpdir): vac_repeats.add(ret.outputs['repeat_id']) else: solv_repeats.add(ret.outputs['repeat_id']) - assert vac_repeats == {0, 1, 2} - assert solv_repeats == {0, 1, 2} + # Repeat ids are random ints so just check their lengths + assert len(vac_repeats) == len(solv_repeats) == 3 def test_gather(benzene_solvation_dag, tmpdir):