diff --git a/db/python/layers/sample.py b/db/python/layers/sample.py index dca54c8d..11dea6d1 100644 --- a/db/python/layers/sample.py +++ b/db/python/layers/sample.py @@ -8,12 +8,14 @@ from db.python.layers.sequencing_group import SequencingGroupLayer from db.python.tables.sample import SampleFilter, SampleTable from db.python.utils import NoOpAenter, NotFoundError +from models.models.assay import AssayUpsertInternal from models.models.project import ( FullWriteAccessRoles, ProjectId, ReadAccessRoles, ) from models.models.sample import SampleInternal, SampleUpsertInternal +from models.models.sequencing_group import SequencingGroupUpsertInternal from models.utils.sample_id_format import sample_id_format_list @@ -223,10 +225,14 @@ async def upsert_sample( with_function = ( self.connection.connection.transaction if open_transaction else NoOpAenter ) - # safely ignore nested samples here async with with_function(): + alayer = AssayLayer(self.connection) + for r in self.unwrap_nested_samples([sample]): s = r.sample + sample_parent_id = getattr(r.parent, 'id', sample_parent_id) + sample_root_id = getattr(r.root, 'id', sample_root_id) + if not s.id: s.id = await self.st.insert_sample( external_ids=s.external_ids, @@ -235,41 +241,49 @@ async def upsert_sample( meta=s.meta, participant_id=s.participant_id, project=project, - sample_parent_id=r.parent.id if r.parent else sample_parent_id, - sample_root_id=r.root.id if r.root else sample_root_id, + sample_parent_id=sample_parent_id, + sample_root_id=sample_root_id, ) else: - # Otherwise update await self.st.update_sample( - id_=s.id, # type: ignore + id_=s.id, external_ids=s.external_ids, meta=s.meta, participant_id=s.participant_id, type_=s.type, active=s.active, - sample_parent_id=r.parent.id if r.parent else sample_parent_id, - sample_root_id=r.root.id if r.root else sample_root_id, + sample_parent_id=sample_parent_id, + sample_root_id=sample_root_id, ) - if sample.sequencing_groups: + if process_sequencing_groups and sample.sequencing_groups: sglayer = SequencingGroupLayer(self.connection) - for seqg in sample.sequencing_groups: - seqg.sample_id = sample.id - - if process_sequencing_groups: - await sglayer.upsert_sequencing_groups(sample.sequencing_groups) + self.set_sample_ids(s.id, sample.sequencing_groups) + await sglayer.upsert_sequencing_groups(sample.sequencing_groups) - if sample.non_sequencing_assays: - alayer = AssayLayer(self.connection) - for assay in sample.non_sequencing_assays: - assay.sample_id = sample.id - if process_assays: - await alayer.upsert_assays( - sample.non_sequencing_assays, open_transaction=False - ) + if process_assays and sample.non_sequencing_assays: + self.set_sample_ids(s.id, sample.non_sequencing_assays) + await alayer.upsert_assays( + sample.non_sequencing_assays, open_transaction=False + ) return sample + def set_sample_ids( + self, + sample_id: int, + internal_objs: list[SequencingGroupUpsertInternal] | list[AssayUpsertInternal], + ): + """ + Set the sample id for upserting sequencing group or assay + These internal upsert models will be children of a SampleUpsertInternal + but may not have the correct sample_id set. This is to ensure that they do. + """ + for obj in internal_objs: + assert hasattr(obj, 'sample_id') + assert obj.sample_id is None or obj.sample_id == sample_id + obj.sample_id = sample_id + async def upsert_samples( self, samples: list[SampleUpsertInternal], @@ -277,6 +291,7 @@ async def upsert_samples( project: ProjectId = None, ) -> list[SampleUpsertInternal]: """Batch upsert a list of samples with sequences""" + alayer = AssayLayer(self.connection) seqglayer: SequencingGroupLayer = SequencingGroupLayer(self.connection) with_function = ( @@ -292,6 +307,8 @@ async def upsert_samples( async with with_function(): # Create or update samples + sequencing_groups: list[SequencingGroupUpsertInternal] = [] + assays: list[AssayUpsertInternal] = [] for sample in samples: await self.upsert_sample( sample, @@ -301,21 +318,13 @@ async def upsert_samples( open_transaction=False, ) - # Upsert all sequencing_groups (in turn relevant assays) - sequencing_groups = [ - seqg for sample in samples for seqg in (sample.sequencing_groups or []) - ] - if sequencing_groups: - await seqglayer.upsert_sequencing_groups(sequencing_groups) + # Collect all sequencing_groups and assays + sequencing_groups.extend(getattr(sample, 'sequencing_groups', [])) + assays.extend(getattr(sample, 'non_sequencing_assays', [])) - assays = [ - assay - for sample in samples - for assay in (sample.non_sequencing_assays or []) - ] - if assays: - alayer = AssayLayer(self.connection) - await alayer.upsert_assays(assays, open_transaction=False) + # Upsert all sequencing_groups (in turn relevant assays) + await seqglayer.upsert_sequencing_groups(sequencing_groups) + await alayer.upsert_assays(assays, open_transaction=False) return samples @@ -340,57 +349,38 @@ def unwrap_nested_samples( out the insert order, keeping reference to the root, and parent. Just keep a soft limit on the depth, as we don't want to go too deep. - - NB: Opting for a non-recursive approach here, as I'm a bit afraid of recursive - Python after a weird Hail Batch thing, and sounded like a nightmare to debug """ retval: list[SampleLayer.UnwrappedSample] = [] + seen_samples = set() + stack: list[ + tuple[ + SampleUpsertInternal | None, + SampleUpsertInternal | None, + SampleUpsertInternal | None, + int, + ] + ] = [(None, None, sample, 0) for sample in samples] + + while stack: + root, parent, sample, depth = stack.pop() + if depth > max_depth: + raise SampleLayer.SampleUnwrapMaxDepthError( + f'Exceeded max depth of {max_depth} for nested samples. ' + f'Parents: {parent}' + ) - seen_samples = {id(s) for s in samples} + if id(sample) in seen_samples: + raise ValueError(f'Sample sample was seen in the list ({sample})') + seen_samples.add(id(sample)) - rounds: list[ - list[ - tuple[ - SampleUpsertInternal | None, - SampleUpsertInternal | None, - list[SampleUpsertInternal], - ] - ] - ] = [[(None, None, samples)]] - - round_idx = 0 - while round_idx < len(rounds): - prev_round = rounds[round_idx] - new_round = [] - round_idx += 1 - for root, parent, nested_samples in prev_round: - for sample in nested_samples: - retval.append( - SampleLayer.UnwrappedSample( - root=root, parent=parent, sample=sample - ) - ) - if not sample.nested_samples: - continue - - # do the seen check - for s in sample.nested_samples: - if id(s) in seen_samples: - raise ValueError( - f'Sample sample was seen in the list ({s})' - ) - seen_samples.add(id(s)) - new_round.append((root or sample, sample, sample.nested_samples)) - - if new_round: - if round_idx >= max_depth: - parents = ', '.join(str(s) for _, s, _ in new_round) - raise SampleLayer.SampleUnwrapMaxDepthError( - f'Exceeded max depth of {max_depth} for nested samples. ' - f'Parents: {parents}' - ) - rounds.append(new_round) + retval.append( + SampleLayer.UnwrappedSample(root=root, parent=parent, sample=sample) + ) + + if sample.nested_samples: + for nested_sample in sample.nested_samples: + stack.append((root or sample, sample, nested_sample, depth + 1)) return retval