diff --git a/dptb/data/interfaces/ham_to_feature.py b/dptb/data/interfaces/ham_to_feature.py index cc39b982..cfd61a7e 100644 --- a/dptb/data/interfaces/ham_to_feature.py +++ b/dptb/data/interfaces/ham_to_feature.py @@ -319,19 +319,25 @@ def block_to_feature(data, idp, blocks=False, overlap_blocks=False, orthogonal=F # if overlap_blocks: # data[_keys.EDGE_OVERLAP_KEY] = torch.as_tensor(np.array(edge_overlap), dtype=torch.get_default_dtype()) -def feature_to_block(data, idp): +def feature_to_block(data, idp, overlap: bool = False): idp.get_orbital_maps() idp.get_orbpair_maps() has_block = False - if data.get(_keys.NODE_FEATURES_KEY, None) is not None: - node_features = data[_keys.NODE_FEATURES_KEY] - edge_features = data[_keys.EDGE_FEATURES_KEY] - has_block = True - blocks = {} - - idp.get_orbital_maps() - idp.get_orbpair_maps() + if not overlap: + if data.get(_keys.NODE_FEATURES_KEY, None) is not None: + node_features = data[_keys.NODE_FEATURES_KEY] + edge_features = data[_keys.EDGE_FEATURES_KEY] + has_block = True + blocks = {} + else: + if data.get(_keys.NODE_OVERLAP_KEY, None) is not None: + node_features = data[_keys.NODE_OVERLAP_KEY] + edge_features = data[_keys.EDGE_OVERLAP_KEY] + has_block = True + blocks = {} + else: + raise KeyError("Overlap features not found in data.") if has_block: # get node blocks from node_features