From dccfaf78cf92c0263e6de1d55706d4319b455615 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 26 Aug 2023 21:48:29 +0800 Subject: [PATCH 1/3] feat: lazy file option to use memmap for npy --- tf_pwa/config_loader/data.py | 18 ++++++++--- tf_pwa/data.py | 63 +++++++++++++++++++++++++++++------- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/tf_pwa/config_loader/data.py b/tf_pwa/config_loader/data.py index e8d26a37..06d43242 100644 --- a/tf_pwa/config_loader/data.py +++ b/tf_pwa/config_loader/data.py @@ -17,6 +17,7 @@ from tf_pwa.config_loader.decay_config import DecayConfig from tf_pwa.data import ( LazyCall, + LazyFile, data_index, data_shape, data_split, @@ -86,6 +87,10 @@ def __init__(self, dic, decay_struct, config=None): self.re_map[v] = k self.scale_list = self.dic.get("scale_list", ["bg"]) self.lazy_call = self.dic.get("lazy_call", False) + self.lazy_file = self.dic.get("lazy_file", False) + self.cp_trans = self.dic.get("cp_trans", True) + if self.lazy_file and self.cp_trans: + warnings.warn("use lazy_file with cp_trans") center_mass = self.dic.get("center_mass", False) r_boost = self.dic.get("r_boost", True) random_z = self.dic.get("random_z", True) @@ -186,7 +191,8 @@ def get_n_data(self): def load_p4(self, fnames): particles = self.get_dat_order() - p = load_dat_file(fnames, particles) + mmap_mode = "r" if self.lazy_file else None + p = load_dat_file(fnames, particles, mmap_mode=mmap_mode) return p def cal_angle(self, p4, **kwargs): @@ -195,14 +201,18 @@ def cal_angle(self, p4, **kwargs): charge = kwargs.get("charge_conjugation", None) p4 = self.process_cp_trans(p4, charge) if self.lazy_call: - data = LazyCall(self.preprocessor, {"p4": p4, "extra": kwargs}) + if self.lazy_file: + data = LazyCall( + self.preprocessor, LazyFile({"p4": p4, "extra": kwargs}) + ) + else: + data = LazyCall(self.preprocessor, {"p4": p4, "extra": kwargs}) else: data = self.preprocessor({"p4": p4, "extra": kwargs}) return data def process_cp_trans(self, p4, charges): - cp_trans = self.dic.get("cp_trans", True) - if cp_trans and charges is not None: + if self.cp_trans and charges is not None: if self.lazy_call: with tf.device("CPU"): p4 = {k: parity_trans(v, charges) for k, v in p4.items()} diff --git a/tf_pwa/data.py b/tf_pwa/data.py index 7bc5be46..0f697e00 100644 --- a/tf_pwa/data.py +++ b/tf_pwa/data.py @@ -94,7 +94,10 @@ def batch(self, batch, axis=0): def __iter__(self): assert self.batch_size is not None, "" - if isinstance(self.f, HeavyCall): + if ( + isinstance(self.f, HeavyCall) + and self.batch_size in self.cached_batch + ): for i, j in zip( self.cached_batch[self.batch_size], split_generator(self.extra, self.batch_size), @@ -104,13 +107,13 @@ def __iter__(self): for i, j in zip( self.x, split_generator(self.extra, self.batch_size) ): - yield {**i, **j} + yield {**self.f(i), **j} else: for i, j in zip( split_generator(self.x, self.batch_size), split_generator(self.extra, self.batch_size), ): - yield {**i, **j} + yield {**self.f(i), **j} def as_dataset(self, batch=65000): self.batch_size = batch @@ -127,12 +130,14 @@ def f(x): ret = self.f(x, *self.args, **self.kwargs) return ret - if isinstance(self.x, LazyCall): - real_x = self.x.eval() + if isinstance(self.x, LazyFile): + data = self.x.cached_batch[batch] else: - real_x = self.x - - data = tf.data.Dataset.from_tensor_slices(real_x) + if isinstance(self.x, LazyCall): + real_x = self.x.eval() + else: + real_x = self.x + data = tf.data.Dataset.from_tensor_slices(real_x).batch(batch) # data = data.batch(batch).cache().map(f) if self.cached_file is not None: from tf_pwa.utils import create_dir @@ -141,13 +146,13 @@ def f(x): cached_file += "_" + str(batch) create_dir(cached_file) - data = data.batch(batch).map(f) + data = data.map(f) if self.cached_file == "": data = data.cache() else: data = data.cache(cached_file) else: - data = data.batch(batch).cache().map(f) + data = data.cache().map(f) data = data.prefetch(tf.data.AUTOTUNE) self.cached_batch[batch] = data @@ -219,6 +224,34 @@ def __len__(self): return data_shape(x) +class LazyFile(LazyCall): + def __init__(self, x): + self.x = x + self.batch_size = None + self.f = None + self.extra = {} + self.cached_batch = {} + + def as_dataset(self, batch=65000): + if batch in self.cached_batch: + return self.cached_batch[batch] + + def gen(): + for i in data_split(self.x, batch_size=batch): + yield data_map(i, np.array) + + test_data = next(gen()) + from tf_pwa.experimental.wrap_function import _wrap_struct + + output_signature = _wrap_struct(test_data) + ret = tf.data.Dataset.from_generator( + gen, output_signature=output_signature + ) + self.batch_size = batch + self.cached_batch[batch] = ret + return self + + class EvalLazy: def __init__(self, f): self.f = f @@ -244,7 +277,13 @@ def set_random_seed(seed): def load_dat_file( - fnames, particles, dtype=None, split=None, order=None, _force_list=False + fnames, + particles, + dtype=None, + split=None, + order=None, + _force_list=False, + mmap_mode=None, ): """ Load ``*.dat`` file(s) of 4-momenta of the final particles. @@ -274,7 +313,7 @@ def load_dat_file( if fname.endswith(".npz"): data = np.load(fname)["arr_0"] elif fname.endswith(".npy"): - data = np.load(fname) + data = np.load(fname, mmap_mode=mmap_mode) else: data = np.loadtxt(fname, dtype=dtype) data = np.reshape(data, (-1, 4)) From 0c966f9e2d345d6ec8b5664c775af3693930d197 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sun, 27 Aug 2023 21:13:47 +0800 Subject: [PATCH 2/3] feat: lazy_file and lazy_prefetch options; move cp_trans to preprocessor --- tf_pwa/amp/preprocess.py | 9 ++++++++- tf_pwa/cal_angle.py | 2 ++ tf_pwa/config_loader/data.py | 20 ++++++-------------- tf_pwa/data.py | 15 +++++++++++---- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/tf_pwa/amp/preprocess.py b/tf_pwa/amp/preprocess.py index 7f3364cc..dcb95529 100644 --- a/tf_pwa/amp/preprocess.py +++ b/tf_pwa/amp/preprocess.py @@ -2,7 +2,11 @@ import tensorflow as tf -from tf_pwa.cal_angle import CalAngleData, cal_angle_from_momentum +from tf_pwa.cal_angle import ( + CalAngleData, + cal_angle_from_momentum, + parity_trans, +) from tf_pwa.config import create_config, get_config, regist_config, temp_config from tf_pwa.data import HeavyCall, data_strip @@ -50,6 +54,9 @@ def __init__( def __call__(self, x): p4 = x["p4"] + if self.kwargs.get("cp_trans", False): + charges = x.get("extra", {}).get("charge_conjugation", None) + p4 = {k: parity_trans(v, charges) for k, v in p4.items()} kwargs = {} for k in [ "center_mass", diff --git a/tf_pwa/cal_angle.py b/tf_pwa/cal_angle.py index b7113912..72f5586f 100644 --- a/tf_pwa/cal_angle.py +++ b/tf_pwa/cal_angle.py @@ -559,6 +559,8 @@ def add_relative_momentum(data: dict): def parity_trans(p, charges): + if charges is None: + return p charges = charges[: p.shape[0], None] return tf.where(charges > 0, p, LorentzVector.neg(p)) diff --git a/tf_pwa/config_loader/data.py b/tf_pwa/config_loader/data.py index 06d43242..a5ff2216 100644 --- a/tf_pwa/config_loader/data.py +++ b/tf_pwa/config_loader/data.py @@ -88,9 +88,7 @@ def __init__(self, dic, decay_struct, config=None): self.scale_list = self.dic.get("scale_list", ["bg"]) self.lazy_call = self.dic.get("lazy_call", False) self.lazy_file = self.dic.get("lazy_file", False) - self.cp_trans = self.dic.get("cp_trans", True) - if self.lazy_file and self.cp_trans: - warnings.warn("use lazy_file with cp_trans") + cp_trans = self.dic.get("cp_trans", True) center_mass = self.dic.get("center_mass", False) r_boost = self.dic.get("r_boost", True) random_z = self.dic.get("random_z", True) @@ -110,6 +108,7 @@ def __init__(self, dic, decay_struct, config=None): model=preprocessor_model, no_p4=no_p4, no_angle=no_angle, + cp_trans=cp_trans, ) def get_data_file(self, idx): @@ -182,7 +181,9 @@ def set_lazy_call(self, data, idx): if isinstance(data, LazyCall): name = idx cached_file = self.dic.get("cached_lazy_call", None) + prefetch = self.dic.get("lazy_prefetch", -1) data.set_cached_file(cached_file, name) + data.prefetch = prefetch def get_n_data(self): data = self.get_data("data") @@ -198,8 +199,8 @@ def load_p4(self, fnames): def cal_angle(self, p4, **kwargs): if isinstance(p4, (list, tuple)): p4 = {k: v for k, v in zip(self.get_dat_order(), p4)} - charge = kwargs.get("charge_conjugation", None) - p4 = self.process_cp_trans(p4, charge) + # charge = kwargs.get("charge_conjugation", None) + # p4 = self.process_cp_trans(p4, charge) if self.lazy_call: if self.lazy_file: data = LazyCall( @@ -211,15 +212,6 @@ def cal_angle(self, p4, **kwargs): data = self.preprocessor({"p4": p4, "extra": kwargs}) return data - def process_cp_trans(self, p4, charges): - if self.cp_trans and charges is not None: - if self.lazy_call: - with tf.device("CPU"): - p4 = {k: parity_trans(v, charges) for k, v in p4.items()} - else: - p4 = {k: parity_trans(v, charges) for k, v in p4.items()} - return p4 - def load_extra_var(self, n_data, **kwargs): extra_var = {} for k, v in self.extra_var.items(): diff --git a/tf_pwa/data.py b/tf_pwa/data.py index 0f697e00..64f913df 100644 --- a/tf_pwa/data.py +++ b/tf_pwa/data.py @@ -88,6 +88,7 @@ def __init__(self, f, x, *args, **kwargs): self.cached_batch = {} self.cached_file = None self.name = "" + self.prefetch = -1 def batch(self, batch, axis=0): return self.as_dataset(batch) @@ -107,13 +108,13 @@ def __iter__(self): for i, j in zip( self.x, split_generator(self.extra, self.batch_size) ): - yield {**self.f(i), **j} + yield {**self.f(i, *self.args, **self.kwargs), **j} else: for i, j in zip( split_generator(self.x, self.batch_size), split_generator(self.extra, self.batch_size), ): - yield {**self.f(i), **j} + yield {**self.f(i, *self.args, **self.kwargs), **j} def as_dataset(self, batch=65000): self.batch_size = batch @@ -152,8 +153,12 @@ def f(x): else: data = data.cache(cached_file) else: - data = data.cache().map(f) - data = data.prefetch(tf.data.AUTOTUNE) + data = data.map(f) + + if self.prefetch > 0: + data = data.prefetch(tf.prefetch) + elif self.prefetch < 0: + data = data.prefetch(tf.data.AUTOTUNE) self.cached_batch[batch] = data return self @@ -179,6 +184,7 @@ def merge(self, *other, axis=0): ret.name = self.name for i in other: ret.name += "_" + i.name + ret.prefetch = self.prefetch return ret def __setitem__(self, index, value): @@ -204,6 +210,7 @@ def copy(self): ret.extra = self.extra.copy() ret.cached_file = self.cached_file ret.name = self.name + ret.prefetch = self.prefetch return ret def eval(self): From 691ea2fe23586889f88e3480b356159bf6852d87 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sun, 27 Aug 2023 22:44:21 +0800 Subject: [PATCH 3/3] add tests and script for lazy_file --- tf_pwa/data.py | 24 ++++++++--- tf_pwa/tests/config_toy_npy.yml | 42 +++++++++++++++++++ tf_pwa/tests/test_full.py | 19 +++++++++ tutorials/examples/create_lazy_cached_data.py | 25 +++++++++++ 4 files changed, 105 insertions(+), 5 deletions(-) create mode 100644 tf_pwa/tests/config_toy_npy.yml create mode 100644 tutorials/examples/create_lazy_cached_data.py diff --git a/tf_pwa/data.py b/tf_pwa/data.py index 64f913df..a28dff57 100644 --- a/tf_pwa/data.py +++ b/tf_pwa/data.py @@ -169,6 +169,9 @@ def set_cached_file(self, cached_file, name): self.cached_file = cached_file self.name = name + def create_new(self, f, x, *args, **kwargs): + return LazyCall(f, x, *args, **kwargs) + def merge(self, *other, axis=0): all_x = [self.x] all_extra = [self.extra] @@ -176,7 +179,7 @@ def merge(self, *other, axis=0): all_x.append(i.x) all_extra.append(i.extra) new_extra = data_merge(*all_extra, axis=axis) - ret = LazyCall( + ret = self.create_new( self.f, data_merge(*all_x, axis=axis), *self.args, **self.kwargs ) ret.extra = new_extra @@ -206,7 +209,7 @@ def get_weight(self): return tf.ones(data_shape(self), dtype=get_config("dtype")) def copy(self): - ret = LazyCall(self.f, self.x, *self.args, **self.kwargs) + ret = self.create_new(self.f, self.x, *self.args, **self.kwargs) ret.extra = self.extra.copy() ret.cached_file = self.cached_file ret.name = self.name @@ -232,12 +235,17 @@ def __len__(self): class LazyFile(LazyCall): - def __init__(self, x): + def __init__(self, x, *args, **kwargs): self.x = x - self.batch_size = None - self.f = None + self.f = lambda x: x + self.args = args + self.kwargs = kwargs self.extra = {} + self.batch_size = None self.cached_batch = {} + self.cached_file = None + self.name = "" + self.prefetch = -1 def as_dataset(self, batch=65000): if batch in self.cached_batch: @@ -258,6 +266,12 @@ def gen(): self.cached_batch[batch] = ret return self + def create_new(self, f, x, *args, **kwargs): + return LazyFile(x) + + def eval(self): + return self.x + class EvalLazy: def __init__(self, f): diff --git a/tf_pwa/tests/config_toy_npy.yml b/tf_pwa/tests/config_toy_npy.yml new file mode 100644 index 00000000..e84091b2 --- /dev/null +++ b/tf_pwa/tests/config_toy_npy.yml @@ -0,0 +1,42 @@ +data: + dat_order: [B, C, D] + data: ["toy_data/data_npy.npy"] + bg: ["toy_data/bg_npy.npy"] + phsp: ["toy_data/PHSP_npy.npy"] + lazy_file: True + lazy_call: True + cached_lazy_call: toy_data/cached2/ + lazy_prefetch: 0 + bg_weight: 0.1 + +decay: + A: + - [R_BC, D] + - [R_BD, C] + - [R_CD, B] + R_BC: [B, C] + R_BD: [B, D] + R_CD: [C, D] + +particle: + $top: + A: { J: 1, P: -1, spins: [-1, 1], mass: 4.6 } + $finals: + B: { J: 1, P: -1, mass: 2.00698 } + C: { J: 1, P: -1, mass: 2.01028 } + D: { J: 0, P: -1, mass: 0.13957 } + R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1, params: { mass_range: [4.0, 4.2] } } + R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3 } + R_CD: { J: 1, Par: 1, m0: 2.42, g0: 0.03 } + +constrains: + particle: null + decay: null + +plot: + config: + legend_outside: True + mass: + R_BC: { display: "$M_{BC}$" } + R_BD: { display: "$M_{BD}$" } + R_CD: { display: "$M_{CD}$" } diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index d4803cbd..68df04c9 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -64,6 +64,13 @@ def gen_toy(): np.savetxt("toy_data/phsp_eff_value.dat", np.ones((10000,))) +@pytest.fixture +def toy_npy(gen_toy): + for i in ["data", "bg", "PHSP"]: + data = np.loadtxt(f"toy_data/{i}.dat") + np.save(f"toy_data/{i}_npy.npy", data) + + @pytest.fixture def toy_config(gen_toy): config = ConfigLoader(f"{this_dir}/config_toy.yml") @@ -71,6 +78,13 @@ def toy_config(gen_toy): return config +@pytest.fixture +def toy_config_npy(toy_npy): + config = ConfigLoader(f"{this_dir}/config_toy_npy.yml") + config.set_params(f"{this_dir}/exp_params.json") + return config + + @pytest.fixture def toy_config_lazy(gen_toy): config = ConfigLoader(f"{this_dir}/config_lazycall.yml") @@ -368,3 +382,8 @@ def test_plot_2dpull(toy_config): a, b = toy_config.get_dalitz_boundary("R_BC", "R_CD") plt.plot(a, b, color="red") plt.savefig("adptive_2d.png") + + +def test_lazy_file(toy_config_npy): + fcn = toy_config_npy.get_fcn() + fcn.nll_grad() diff --git a/tutorials/examples/create_lazy_cached_data.py b/tutorials/examples/create_lazy_cached_data.py new file mode 100644 index 00000000..22bb08c0 --- /dev/null +++ b/tutorials/examples/create_lazy_cached_data.py @@ -0,0 +1,25 @@ +from tf_pwa.config_loader import ConfigLoader + +# import extra_amp + +batch_size = 100000 +config = ConfigLoader("config.yml") + + +def create_cached_file(d): + d.prefetch = 0 + d.batch(batch_size) + for i in d: + # iter for Dataset + pass + + +data, phsp, bg, *_ = config.get_all_data() + +for i in range(len(data)): + if bg is None or bg[i] is None: + data_bg_merged = data[i] + else: + data_bg_merged = data[i].merge(bg[i]) + create_cached_file(data_bg_merged) + create_cached_file(phsp[i])