Commit bb5236f
Merge pull request #105 from jiangyi15/lazy_file
Lazy file
jiangyi15 authored Aug 28, 2023
2 parents 5a7ef7e + 691ea2f commit bb5236f
Showing 7 changed files with 187 additions and 30 deletions.
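In short: lazy_file keeps .npy four-momentum files memory-mapped and feeds them to the preprocessor through a tf.data generator, lazy_prefetch controls Dataset.prefetch, and the cp_trans parity flip now happens inside the preprocessor. A minimal sketch of a data block using the new keys, modeled on the test config added below (paths are placeholders):

data:
  dat_order: [B, C, D]
  data: ["toy_data/data_npy.npy"]
  phsp: ["toy_data/PHSP_npy.npy"]
  lazy_call: True
  lazy_file: True               # load .npy inputs with numpy mmap_mode="r"
  cached_lazy_call: cache_dir/  # optional on-disk tf.data cache
  lazy_prefetch: 0              # <0: AUTOTUNE, 0: no prefetch, >0: prefetch that many batches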
9 changes: 8 additions & 1 deletion tf_pwa/amp/preprocess.py
@@ -2,7 +2,11 @@

import tensorflow as tf

-from tf_pwa.cal_angle import CalAngleData, cal_angle_from_momentum
+from tf_pwa.cal_angle import (
+    CalAngleData,
+    cal_angle_from_momentum,
+    parity_trans,
+)
from tf_pwa.config import create_config, get_config, regist_config, temp_config
from tf_pwa.data import HeavyCall, data_strip

@@ -50,6 +54,9 @@ def __init__(

def __call__(self, x):
p4 = x["p4"]
if self.kwargs.get("cp_trans", False):
charges = x.get("extra", {}).get("charge_conjugation", None)
p4 = {k: parity_trans(v, charges) for k, v in p4.items()}
kwargs = {}
for k in [
"center_mass",
2 changes: 2 additions & 0 deletions tf_pwa/cal_angle.py
@@ -559,6 +559,8 @@ def add_relative_momentum(data: dict):


def parity_trans(p, charges):
if charges is None:
return p
charges = charges[: p.shape[0], None]
return tf.where(charges > 0, p, LorentzVector.neg(p))

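With the new early return, parity_trans is a no-op when no charge-conjugation flags are given; otherwise it keeps events with a positive flag and parity-flips the rest. A self-contained sketch of the same pattern in plain TensorFlow, with an explicit sign flip of the spatial components standing in for LorentzVector.neg (assumed here to perform that parity flip):

import tensorflow as tf

def parity_trans_sketch(p, charges):
    # p: (n_events, 4) four-momenta as (E, px, py, pz); charges: (n_events,) or None
    if charges is None:
        return p
    charges = charges[: p.shape[0], None]
    flipped = p * tf.constant([1.0, -1.0, -1.0, -1.0], dtype=p.dtype)
    return tf.where(charges > 0, p, flipped)

p4 = tf.constant([[1.0, 0.1, 0.2, 0.3], [1.0, 0.4, 0.5, 0.6]])
print(parity_trans_sketch(p4, tf.constant([1.0, -1.0])))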
30 changes: 16 additions & 14 deletions tf_pwa/config_loader/data.py
@@ -17,6 +17,7 @@
from tf_pwa.config_loader.decay_config import DecayConfig
from tf_pwa.data import (
LazyCall,
LazyFile,
data_index,
data_shape,
data_split,
@@ -86,6 +87,8 @@ def __init__(self, dic, decay_struct, config=None):
self.re_map[v] = k
self.scale_list = self.dic.get("scale_list", ["bg"])
self.lazy_call = self.dic.get("lazy_call", False)
self.lazy_file = self.dic.get("lazy_file", False)
cp_trans = self.dic.get("cp_trans", True)
center_mass = self.dic.get("center_mass", False)
r_boost = self.dic.get("r_boost", True)
random_z = self.dic.get("random_z", True)
@@ -105,6 +108,7 @@
model=preprocessor_model,
no_p4=no_p4,
no_angle=no_angle,
cp_trans=cp_trans,
)

def get_data_file(self, idx):
@@ -177,7 +181,9 @@ def set_lazy_call(self, data, idx):
if isinstance(data, LazyCall):
name = idx
cached_file = self.dic.get("cached_lazy_call", None)
prefetch = self.dic.get("lazy_prefetch", -1)
data.set_cached_file(cached_file, name)
data.prefetch = prefetch

def get_n_data(self):
data = self.get_data("data")
@@ -186,30 +192,26 @@ def get_n_data(self):

def load_p4(self, fnames):
particles = self.get_dat_order()
-        p = load_dat_file(fnames, particles)
+        mmap_mode = "r" if self.lazy_file else None
+        p = load_dat_file(fnames, particles, mmap_mode=mmap_mode)
return p

def cal_angle(self, p4, **kwargs):
if isinstance(p4, (list, tuple)):
p4 = {k: v for k, v in zip(self.get_dat_order(), p4)}
-        charge = kwargs.get("charge_conjugation", None)
-        p4 = self.process_cp_trans(p4, charge)
+        # charge = kwargs.get("charge_conjugation", None)
+        # p4 = self.process_cp_trans(p4, charge)
if self.lazy_call:
-            data = LazyCall(self.preprocessor, {"p4": p4, "extra": kwargs})
+            if self.lazy_file:
+                data = LazyCall(
+                    self.preprocessor, LazyFile({"p4": p4, "extra": kwargs})
+                )
+            else:
+                data = LazyCall(self.preprocessor, {"p4": p4, "extra": kwargs})
else:
data = self.preprocessor({"p4": p4, "extra": kwargs})
return data

-    def process_cp_trans(self, p4, charges):
-        cp_trans = self.dic.get("cp_trans", True)
-        if cp_trans and charges is not None:
-            if self.lazy_call:
-                with tf.device("CPU"):
-                    p4 = {k: parity_trans(v, charges) for k, v in p4.items()}
-            else:
-                p4 = {k: parity_trans(v, charges) for k, v in p4.items()}
-        return p4

def load_extra_var(self, n_data, **kwargs):
extra_var = {}
for k, v in self.extra_var.items():
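At the file-loading level the lazy mode only changes mmap_mode: with mmap_mode="r", np.load returns a read-only numpy.memmap, so the four-momentum array stays on disk and rows are read only when accessed. A small illustration with a hypothetical demo file:

import numpy as np

a = np.arange(12.0).reshape(-1, 4)
np.save("p4_demo.npy", a)                   # hypothetical demo file

p4 = np.load("p4_demo.npy", mmap_mode="r")  # numpy.memmap: data stays on disk
print(type(p4))                             # <class 'numpy.memmap'>
print(np.asarray(p4[:2]))                   # only the accessed rows are materialised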
90 changes: 75 additions & 15 deletions tf_pwa/data.py
@@ -88,13 +88,17 @@ def __init__(self, f, x, *args, **kwargs):
self.cached_batch = {}
self.cached_file = None
self.name = ""
self.prefetch = -1

def batch(self, batch, axis=0):
return self.as_dataset(batch)

def __iter__(self):
assert self.batch_size is not None, ""
-        if isinstance(self.f, HeavyCall):
+        if (
+            isinstance(self.f, HeavyCall)
+            and self.batch_size in self.cached_batch
+        ):
for i, j in zip(
self.cached_batch[self.batch_size],
split_generator(self.extra, self.batch_size),
@@ -104,13 +108,13 @@ def __iter__(self):
for i, j in zip(
self.x, split_generator(self.extra, self.batch_size)
):
-                yield {**i, **j}
+                yield {**self.f(i, *self.args, **self.kwargs), **j}
else:
for i, j in zip(
split_generator(self.x, self.batch_size),
split_generator(self.extra, self.batch_size),
):
-                yield {**i, **j}
+                yield {**self.f(i, *self.args, **self.kwargs), **j}

def as_dataset(self, batch=65000):
self.batch_size = batch
@@ -127,12 +131,14 @@ def f(x):
ret = self.f(x, *self.args, **self.kwargs)
return ret

-        if isinstance(self.x, LazyCall):
-            real_x = self.x.eval()
+        if isinstance(self.x, LazyFile):
+            data = self.x.cached_batch[batch]
        else:
-            real_x = self.x
-
-        data = tf.data.Dataset.from_tensor_slices(real_x)
+            if isinstance(self.x, LazyCall):
+                real_x = self.x.eval()
+            else:
+                real_x = self.x
+            data = tf.data.Dataset.from_tensor_slices(real_x).batch(batch)
# data = data.batch(batch).cache().map(f)
if self.cached_file is not None:
from tf_pwa.utils import create_dir
@@ -141,14 +147,18 @@ def f(x):

cached_file += "_" + str(batch)
create_dir(cached_file)
-            data = data.batch(batch).map(f)
+            data = data.map(f)
if self.cached_file == "":
data = data.cache()
else:
data = data.cache(cached_file)
else:
-            data = data.batch(batch).cache().map(f)
-        data = data.prefetch(tf.data.AUTOTUNE)
+            data = data.map(f)
+
+        if self.prefetch > 0:
+            data = data.prefetch(self.prefetch)
+        elif self.prefetch < 0:
+            data = data.prefetch(tf.data.AUTOTUNE)

self.cached_batch[batch] = data
return self
@@ -159,21 +169,25 @@ def set_cached_file(self, cached_file, name):
self.cached_file = cached_file
self.name = name

def create_new(self, f, x, *args, **kwargs):
return LazyCall(f, x, *args, **kwargs)

def merge(self, *other, axis=0):
all_x = [self.x]
all_extra = [self.extra]
for i in other:
all_x.append(i.x)
all_extra.append(i.extra)
new_extra = data_merge(*all_extra, axis=axis)
-        ret = LazyCall(
+        ret = self.create_new(
self.f, data_merge(*all_x, axis=axis), *self.args, **self.kwargs
)
ret.extra = new_extra
ret.cached_file = self.cached_file
ret.name = self.name
for i in other:
ret.name += "_" + i.name
ret.prefetch = self.prefetch
return ret

def __setitem__(self, index, value):
@@ -195,10 +209,11 @@ def get_weight(self):
return tf.ones(data_shape(self), dtype=get_config("dtype"))

def copy(self):
-        ret = LazyCall(self.f, self.x, *self.args, **self.kwargs)
+        ret = self.create_new(self.f, self.x, *self.args, **self.kwargs)
ret.extra = self.extra.copy()
ret.cached_file = self.cached_file
ret.name = self.name
ret.prefetch = self.prefetch
return ret

def eval(self):
@@ -219,6 +234,45 @@ def __len__(self):
return data_shape(x)


class LazyFile(LazyCall):
def __init__(self, x, *args, **kwargs):
self.x = x
self.f = lambda x: x
self.args = args
self.kwargs = kwargs
self.extra = {}
self.batch_size = None
self.cached_batch = {}
self.cached_file = None
self.name = ""
self.prefetch = -1

def as_dataset(self, batch=65000):
if batch in self.cached_batch:
return self.cached_batch[batch]

def gen():
for i in data_split(self.x, batch_size=batch):
yield data_map(i, np.array)

test_data = next(gen())
from tf_pwa.experimental.wrap_function import _wrap_struct

output_signature = _wrap_struct(test_data)
ret = tf.data.Dataset.from_generator(
gen, output_signature=output_signature
)
self.batch_size = batch
self.cached_batch[batch] = ret
return self

def create_new(self, f, x, *args, **kwargs):
return LazyFile(x)

def eval(self):
return self.x


class EvalLazy:
def __init__(self, f):
self.f = f
@@ -244,7 +298,13 @@ def set_random_seed(seed):


def load_dat_file(
-    fnames, particles, dtype=None, split=None, order=None, _force_list=False
+    fnames,
+    particles,
+    dtype=None,
+    split=None,
+    order=None,
+    _force_list=False,
+    mmap_mode=None,
):
"""
Load ``*.dat`` file(s) of 4-momenta of the final particles.
@@ -274,7 +334,7 @@
if fname.endswith(".npz"):
data = np.load(fname)["arr_0"]
elif fname.endswith(".npy"):
-            data = np.load(fname)
+            data = np.load(fname, mmap_mode=mmap_mode)
else:
data = np.loadtxt(fname, dtype=dtype)
data = np.reshape(data, (-1, 4))
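LazyFile.as_dataset builds the pipeline with tf.data.Dataset.from_generator, so memory-mapped input is converted to tensors one batch at a time, and the new prefetch attribute maps onto Dataset.prefetch (negative -> tf.data.AUTOTUNE, zero -> no prefetch, positive -> that many batches). A stripped-down sketch of the same construction, using a plain tf.TensorSpec in place of the _wrap_struct helper:

import numpy as np
import tensorflow as tf

p4 = np.arange(40.0).reshape(-1, 4)  # stand-in for a memory-mapped .npy array
batch = 4

def gen():
    for start in range(0, len(p4), batch):
        yield np.asarray(p4[start : start + batch])  # materialise one batch at a time

ds = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=(None, 4), dtype=tf.float64)
)
ds = ds.prefetch(tf.data.AUTOTUNE)  # what prefetch = -1 selects

for x in ds:
    print(x.shape)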
42 changes: 42 additions & 0 deletions tf_pwa/tests/config_toy_npy.yml
@@ -0,0 +1,42 @@
data:
dat_order: [B, C, D]
data: ["toy_data/data_npy.npy"]
bg: ["toy_data/bg_npy.npy"]
phsp: ["toy_data/PHSP_npy.npy"]
lazy_file: True
lazy_call: True
cached_lazy_call: toy_data/cached2/
lazy_prefetch: 0
bg_weight: 0.1

decay:
A:
- [R_BC, D]
- [R_BD, C]
- [R_CD, B]
R_BC: [B, C]
R_BD: [B, D]
R_CD: [C, D]

particle:
$top:
A: { J: 1, P: -1, spins: [-1, 1], mass: 4.6 }
$finals:
B: { J: 1, P: -1, mass: 2.00698 }
C: { J: 1, P: -1, mass: 2.01028 }
D: { J: 0, P: -1, mass: 0.13957 }
R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1, params: { mass_range: [4.0, 4.2] } }
R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3 }
R_CD: { J: 1, Par: 1, m0: 2.42, g0: 0.03 }

constrains:
particle: null
decay: null

plot:
config:
legend_outside: True
mass:
R_BC: { display: "$M_{BC}$" }
R_BD: { display: "$M_{BD}$" }
R_CD: { display: "$M_{CD}$" }
19 changes: 19 additions & 0 deletions tf_pwa/tests/test_full.py
@@ -64,13 +64,27 @@ def gen_toy():
np.savetxt("toy_data/phsp_eff_value.dat", np.ones((10000,)))


@pytest.fixture
def toy_npy(gen_toy):
for i in ["data", "bg", "PHSP"]:
data = np.loadtxt(f"toy_data/{i}.dat")
np.save(f"toy_data/{i}_npy.npy", data)


@pytest.fixture
def toy_config(gen_toy):
config = ConfigLoader(f"{this_dir}/config_toy.yml")
config.set_params(f"{this_dir}/exp_params.json")
return config


@pytest.fixture
def toy_config_npy(toy_npy):
config = ConfigLoader(f"{this_dir}/config_toy_npy.yml")
config.set_params(f"{this_dir}/exp_params.json")
return config


@pytest.fixture
def toy_config_lazy(gen_toy):
config = ConfigLoader(f"{this_dir}/config_lazycall.yml")
@@ -368,3 +382,8 @@ def test_plot_2dpull(toy_config):
a, b = toy_config.get_dalitz_boundary("R_BC", "R_CD")
plt.plot(a, b, color="red")
plt.savefig("adptive_2d.png")


def test_lazy_file(toy_config_npy):
fcn = toy_config_npy.get_fcn()
fcn.nll_grad()
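For reference, the new test reduces to the usual fit entry point. A sketch of the equivalent interactive usage, assuming the toy .npy files and exp_params.json produced by the test fixtures already exist relative to the working directory:

from tf_pwa.config_loader import ConfigLoader

config = ConfigLoader("tf_pwa/tests/config_toy_npy.yml")
config.set_params("tf_pwa/tests/exp_params.json")

fcn = config.get_fcn()
nll, grad = fcn.nll_grad()  # first call builds and caches the lazy tf.data pipeline
print(float(nll))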