From faa99d32a033a21b1867a3f83135d88690c46acf Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Wed, 28 Aug 2024 17:44:34 +0800 Subject: [PATCH 1/2] add overlap output in feature_to_block --- dptb/data/interfaces/ham_to_feature.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/dptb/data/interfaces/ham_to_feature.py b/dptb/data/interfaces/ham_to_feature.py index cc39b982..1acbf7a5 100644 --- a/dptb/data/interfaces/ham_to_feature.py +++ b/dptb/data/interfaces/ham_to_feature.py @@ -319,19 +319,25 @@ def block_to_feature(data, idp, blocks=False, overlap_blocks=False, orthogonal=F # if overlap_blocks: # data[_keys.EDGE_OVERLAP_KEY] = torch.as_tensor(np.array(edge_overlap), dtype=torch.get_default_dtype()) -def feature_to_block(data, idp): +def feature_to_block(data, idp, mode="H"): idp.get_orbital_maps() idp.get_orbpair_maps() has_block = False - if data.get(_keys.NODE_FEATURES_KEY, None) is not None: - node_features = data[_keys.NODE_FEATURES_KEY] - edge_features = data[_keys.EDGE_FEATURES_KEY] - has_block = True - blocks = {} - - idp.get_orbital_maps() - idp.get_orbpair_maps() + if mode == "H": + if data.get(_keys.NODE_FEATURES_KEY, None) is not None: + node_features = data[_keys.NODE_FEATURES_KEY] + edge_features = data[_keys.EDGE_FEATURES_KEY] + has_block = True + blocks = {} + elif mode == "S": + if data.get(_keys.NODE_OVERLAP_KEY, None) is not None: + node_features = data[_keys.NODE_OVERLAP_KEY] + edge_features = data[_keys.EDGE_OVERLAP_KEY] + has_block = True + blocks = {} + else: + raise ValueError("Mode should be either 'H' or 'S'.") if has_block: # get node blocks from node_features From cc336e2fdf0b95675489d047bb4fbfbd9f4b0e36 Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Wed, 28 Aug 2024 19:54:46 +0800 Subject: [PATCH 2/2] update feature_to_block generally for H, S and D --- dptb/data/interfaces/ham_to_feature.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dptb/data/interfaces/ham_to_feature.py b/dptb/data/interfaces/ham_to_feature.py index 1acbf7a5..cfd61a7e 100644 --- a/dptb/data/interfaces/ham_to_feature.py +++ b/dptb/data/interfaces/ham_to_feature.py @@ -319,25 +319,25 @@ def block_to_feature(data, idp, blocks=False, overlap_blocks=False, orthogonal=F # if overlap_blocks: # data[_keys.EDGE_OVERLAP_KEY] = torch.as_tensor(np.array(edge_overlap), dtype=torch.get_default_dtype()) -def feature_to_block(data, idp, mode="H"): +def feature_to_block(data, idp, overlap: bool = False): idp.get_orbital_maps() idp.get_orbpair_maps() has_block = False - if mode == "H": + if not overlap: if data.get(_keys.NODE_FEATURES_KEY, None) is not None: node_features = data[_keys.NODE_FEATURES_KEY] edge_features = data[_keys.EDGE_FEATURES_KEY] has_block = True blocks = {} - elif mode == "S": + else: if data.get(_keys.NODE_OVERLAP_KEY, None) is not None: node_features = data[_keys.NODE_OVERLAP_KEY] edge_features = data[_keys.EDGE_OVERLAP_KEY] has_block = True blocks = {} - else: - raise ValueError("Mode should be either 'H' or 'S'.") + else: + raise KeyError("Overlap features not found in data.") if has_block: # get node blocks from node_features