Lazy file #105

Merged on Aug 28, 2023 (3 commits)
9 changes: 8 additions & 1 deletion tf_pwa/amp/preprocess.py
@@ -2,7 +2,11 @@

 import tensorflow as tf

-from tf_pwa.cal_angle import CalAngleData, cal_angle_from_momentum
+from tf_pwa.cal_angle import (
+    CalAngleData,
+    cal_angle_from_momentum,
+    parity_trans,
+)
 from tf_pwa.config import create_config, get_config, regist_config, temp_config
 from tf_pwa.data import HeavyCall, data_strip

@@ -50,6 +54,9 @@ def __init__(

     def __call__(self, x):
         p4 = x["p4"]
+        if self.kwargs.get("cp_trans", False):
+            charges = x.get("extra", {}).get("charge_conjugation", None)
+            p4 = {k: parity_trans(v, charges) for k, v in p4.items()}
         kwargs = {}
         for k in [
             "center_mass",
2 changes: 2 additions & 0 deletions tf_pwa/cal_angle.py
@@ -559,6 +559,8 @@ def add_relative_momentum(data: dict):


 def parity_trans(p, charges):
+    if charges is None:
+        return p
     charges = charges[: p.shape[0], None]
     return tf.where(charges > 0, p, LorentzVector.neg(p))
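For orientation, a small illustration (not part of the diff) of the selection parity_trans performs: events whose charge-conjugation flag is positive keep their four-momenta, the rest get the space-inverted ones. All names and values below are made up.

    import numpy as np
    import tensorflow as tf

    # Three events, four-momenta as (E, px, py, pz) rows.
    p = tf.constant(
        [[1.0, 0.1, 0.2, 0.3], [1.0, -0.1, 0.0, 0.2], [1.0, 0.0, 0.5, -0.4]]
    )
    # Per-event charge-conjugation flags, shaped (n, 1) as in the PR.
    charges = np.array([1.0, -1.0, 1.0])[:, None]
    # Keep p where charge > 0, otherwise flip the spatial part,
    # which is what LorentzVector.neg is expected to compute here.
    neg_p = p * tf.constant([[1.0, -1.0, -1.0, -1.0]])
    p_cp = tf.where(charges > 0, p, neg_p)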
30 changes: 16 additions & 14 deletions tf_pwa/config_loader/data.py
@@ -17,6 +17,7 @@

 from tf_pwa.config_loader.decay_config import DecayConfig
 from tf_pwa.data import (
     LazyCall,
+    LazyFile,
     data_index,
     data_shape,
     data_split,

@@ -86,6 +87,8 @@ def __init__(self, dic, decay_struct, config=None):

             self.re_map[v] = k
         self.scale_list = self.dic.get("scale_list", ["bg"])
         self.lazy_call = self.dic.get("lazy_call", False)
+        self.lazy_file = self.dic.get("lazy_file", False)
+        cp_trans = self.dic.get("cp_trans", True)
         center_mass = self.dic.get("center_mass", False)
         r_boost = self.dic.get("r_boost", True)
         random_z = self.dic.get("random_z", True)

@@ -105,6 +108,7 @@

             model=preprocessor_model,
             no_p4=no_p4,
             no_angle=no_angle,
+            cp_trans=cp_trans,
         )

     def get_data_file(self, idx):
@@ -177,7 +181,9 @@ def set_lazy_call(self, data, idx):

         if isinstance(data, LazyCall):
             name = idx
             cached_file = self.dic.get("cached_lazy_call", None)
+            prefetch = self.dic.get("lazy_prefetch", -1)
             data.set_cached_file(cached_file, name)
+            data.prefetch = prefetch

     def get_n_data(self):
         data = self.get_data("data")
@@ -186,30 +192,26 @@

     def load_p4(self, fnames):
         particles = self.get_dat_order()
-        p = load_dat_file(fnames, particles)
+        mmap_mode = "r" if self.lazy_file else None
+        p = load_dat_file(fnames, particles, mmap_mode=mmap_mode)
         return p

     def cal_angle(self, p4, **kwargs):
         if isinstance(p4, (list, tuple)):
             p4 = {k: v for k, v in zip(self.get_dat_order(), p4)}
-        charge = kwargs.get("charge_conjugation", None)
-        p4 = self.process_cp_trans(p4, charge)
+        # charge = kwargs.get("charge_conjugation", None)
+        # p4 = self.process_cp_trans(p4, charge)
         if self.lazy_call:
-            data = LazyCall(self.preprocessor, {"p4": p4, "extra": kwargs})
+            if self.lazy_file:
+                data = LazyCall(
+                    self.preprocessor, LazyFile({"p4": p4, "extra": kwargs})
+                )
+            else:
+                data = LazyCall(self.preprocessor, {"p4": p4, "extra": kwargs})
         else:
             data = self.preprocessor({"p4": p4, "extra": kwargs})
         return data

-    def process_cp_trans(self, p4, charges):
-        cp_trans = self.dic.get("cp_trans", True)
-        if cp_trans and charges is not None:
-            if self.lazy_call:
-                with tf.device("CPU"):
-                    p4 = {k: parity_trans(v, charges) for k, v in p4.items()}
-            else:
-                p4 = {k: parity_trans(v, charges) for k, v in p4.items()}
-        return p4
-
     def load_extra_var(self, n_data, **kwargs):
         extra_var = {}
         for k, v in self.extra_var.items():
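Taken together, the data section of a config file now understands a few more keys, with semantics read off the code above: lazy_file switches .npy input to memory-mapped loading and wraps the momenta in LazyFile; cp_trans (default True) is forwarded to the preprocessor, which now applies parity_trans itself in place of the removed process_cp_trans; and lazy_prefetch (default -1) steers tf.data prefetching, where -1 selects AUTOTUNE, 0 disables prefetching, and a positive value requests a fixed buffer. The test configuration config_toy_npy.yml below exercises these options.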
90 changes: 75 additions & 15 deletions tf_pwa/data.py
@@ -88,13 +88,17 @@

         self.cached_batch = {}
         self.cached_file = None
         self.name = ""
+        self.prefetch = -1

     def batch(self, batch, axis=0):
         return self.as_dataset(batch)

     def __iter__(self):
         assert self.batch_size is not None, ""
-        if isinstance(self.f, HeavyCall):
+        if (
+            isinstance(self.f, HeavyCall)
+            and self.batch_size in self.cached_batch
+        ):
             for i, j in zip(
                 self.cached_batch[self.batch_size],
                 split_generator(self.extra, self.batch_size),
@@ -104,13 +108,13 @@

             for i, j in zip(
                 self.x, split_generator(self.extra, self.batch_size)
             ):
-                yield {**i, **j}
+                yield {**self.f(i, *self.args, **self.kwargs), **j}
         else:
             for i, j in zip(
                 split_generator(self.x, self.batch_size),
                 split_generator(self.extra, self.batch_size),
             ):
-                yield {**i, **j}
+                yield {**self.f(i, *self.args, **self.kwargs), **j}

(Codecov: added lines tf_pwa/data.py#L111 and #L117 are not covered by tests.)

     def as_dataset(self, batch=65000):
         self.batch_size = batch
@@ -127,12 +131,14 @@

             ret = self.f(x, *self.args, **self.kwargs)
             return ret

-        if isinstance(self.x, LazyCall):
-            real_x = self.x.eval()
-        else:
-            real_x = self.x
-
-        data = tf.data.Dataset.from_tensor_slices(real_x)
+        if isinstance(self.x, LazyFile):
+            data = self.x.cached_batch[batch]
+        else:
+            if isinstance(self.x, LazyCall):
+                real_x = self.x.eval()
+            else:
+                real_x = self.x
+            data = tf.data.Dataset.from_tensor_slices(real_x).batch(batch)

(Codecov: added line tf_pwa/data.py#L138 is not covered by tests.)

         # data = data.batch(batch).cache().map(f)
         if self.cached_file is not None:
             from tf_pwa.utils import create_dir
@@ -141,14 +147,18 @@

             cached_file += "_" + str(batch)
             create_dir(cached_file)
-            data = data.batch(batch).map(f)
+            data = data.map(f)
             if self.cached_file == "":
                 data = data.cache()
             else:
                 data = data.cache(cached_file)
         else:
-            data = data.batch(batch).cache().map(f)
-        data = data.prefetch(tf.data.AUTOTUNE)
+            data = data.map(f)
+
+        if self.prefetch > 0:
+            data = data.prefetch(self.prefetch)  # buffer a fixed number of batches
+        elif self.prefetch < 0:
+            data = data.prefetch(tf.data.AUTOTUNE)

(Codecov: added line tf_pwa/data.py#L159 is not covered by tests.)

         self.cached_batch[batch] = data
         return self
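For orientation, a standalone sketch of the pipeline shape as_dataset now builds in the non-LazyFile case: batch first, map the heavy preprocessing per batch, optionally cache, then prefetch. The function and values are illustrative, not the PR's code.

    import tensorflow as tf

    def heavy(x):
        # stand-in for the per-batch preprocessing call
        return {"m2": tf.reduce_sum(x["p4"] ** 2, axis=-1)}

    data = tf.data.Dataset.from_tensor_slices({"p4": tf.random.normal((10, 4))})
    data = data.batch(4).map(heavy)         # batch, then run the heavy call per batch
    data = data.cache()                     # optional, cf. cached_lazy_call
    data = data.prefetch(tf.data.AUTOTUNE)  # what prefetch = -1 selects
    for batch in data:
        print(batch["m2"].shape)            # (4,), (4,), (2,)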
@@ -159,21 +169,25 @@

         self.cached_file = cached_file
         self.name = name

+    def create_new(self, f, x, *args, **kwargs):
+        return LazyCall(f, x, *args, **kwargs)
+
     def merge(self, *other, axis=0):
         all_x = [self.x]
         all_extra = [self.extra]
         for i in other:
             all_x.append(i.x)
             all_extra.append(i.extra)
         new_extra = data_merge(*all_extra, axis=axis)
-        ret = LazyCall(
+        ret = self.create_new(
             self.f, data_merge(*all_x, axis=axis), *self.args, **self.kwargs
         )
         ret.extra = new_extra
         ret.cached_file = self.cached_file
         ret.name = self.name
         for i in other:
             ret.name += "_" + i.name
+        ret.prefetch = self.prefetch
         return ret

     def __setitem__(self, index, value):
@@ -195,10 +209,11 @@

         return tf.ones(data_shape(self), dtype=get_config("dtype"))

     def copy(self):
-        ret = LazyCall(self.f, self.x, *self.args, **self.kwargs)
+        ret = self.create_new(self.f, self.x, *self.args, **self.kwargs)
         ret.extra = self.extra.copy()
         ret.cached_file = self.cached_file
         ret.name = self.name
+        ret.prefetch = self.prefetch
         return ret

     def eval(self):
@@ -219,6 +234,45 @@

         return data_shape(x)


+class LazyFile(LazyCall):
+    def __init__(self, x, *args, **kwargs):
+        self.x = x
+        self.f = lambda x: x
+        self.args = args
+        self.kwargs = kwargs
+        self.extra = {}
+        self.batch_size = None
+        self.cached_batch = {}
+        self.cached_file = None
+        self.name = ""
+        self.prefetch = -1
+
+    def as_dataset(self, batch=65000):
+        if batch in self.cached_batch:
+            return self.cached_batch[batch]
+
+        def gen():
+            for i in data_split(self.x, batch_size=batch):
+                yield data_map(i, np.array)
+
+        test_data = next(gen())
+        from tf_pwa.experimental.wrap_function import _wrap_struct
+
+        output_signature = _wrap_struct(test_data)
+        ret = tf.data.Dataset.from_generator(
+            gen, output_signature=output_signature
+        )
+        self.batch_size = batch
+        self.cached_batch[batch] = ret
+        return self
+
+    def create_new(self, f, x, *args, **kwargs):
+        return LazyFile(x)
+
+    def eval(self):
+        return self.x

(Codecov: added lines tf_pwa/data.py#L252 and #L273 are not covered by tests.)


 class EvalLazy:
     def __init__(self, f):
         self.f = f
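As background, LazyFile.as_dataset feeds tf.data from a Python generator of pre-batched NumPy chunks rather than from in-memory tensors. A minimal self-contained sketch of that pattern (illustrative shapes; the PR derives output_signature from a sample batch via _wrap_struct):

    import numpy as np
    import tensorflow as tf

    def gen():
        # yield already-batched chunks, mirroring data_split above
        data = np.arange(40.0).reshape(10, 4)
        for start in range(0, len(data), 4):
            yield {"p4": data[start : start + 4]}

    signature = {"p4": tf.TensorSpec(shape=(None, 4), dtype=tf.float64)}
    ds = tf.data.Dataset.from_generator(gen, output_signature=signature)
    for batch in ds:
        print(batch["p4"].shape)  # (4, 4), (4, 4), (2, 4)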
@@ -244,7 +298,13 @@


 def load_dat_file(
-    fnames, particles, dtype=None, split=None, order=None, _force_list=False
+    fnames,
+    particles,
+    dtype=None,
+    split=None,
+    order=None,
+    _force_list=False,
+    mmap_mode=None,
 ):
     """
     Load ``*.dat`` file(s) of 4-momenta of the final particles.
@@ -274,7 +334,7 @@

         if fname.endswith(".npz"):
             data = np.load(fname)["arr_0"]
         elif fname.endswith(".npy"):
-            data = np.load(fname)
+            data = np.load(fname, mmap_mode=mmap_mode)
         else:
             data = np.loadtxt(fname, dtype=dtype)
             data = np.reshape(data, (-1, 4))
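For context, mmap_mode="r" makes np.load return a read-only numpy.memmap instead of reading the whole array into RAM, so large .npy momentum files are paged in only as they are accessed. A small sketch with a throwaway file name:

    import numpy as np

    arr = np.arange(12.0).reshape(3, 4)
    np.save("p4_demo.npy", arr)                 # demo file, not part of the test data

    p4 = np.load("p4_demo.npy", mmap_mode="r")  # lazy, read-only view
    print(type(p4).__name__)                    # memmap
    print(np.asarray(p4[:2]))                   # only the touched rows are read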
42 changes: 42 additions & 0 deletions tf_pwa/tests/config_toy_npy.yml
@@ -0,0 +1,42 @@
data:
  dat_order: [B, C, D]
  data: ["toy_data/data_npy.npy"]
  bg: ["toy_data/bg_npy.npy"]
  phsp: ["toy_data/PHSP_npy.npy"]
  lazy_file: True
  lazy_call: True
  cached_lazy_call: toy_data/cached2/
  lazy_prefetch: 0
  bg_weight: 0.1

decay:
  A:
    - [R_BC, D]
    - [R_BD, C]
    - [R_CD, B]
  R_BC: [B, C]
  R_BD: [B, D]
  R_CD: [C, D]

particle:
  $top:
    A: { J: 1, P: -1, spins: [-1, 1], mass: 4.6 }
  $finals:
    B: { J: 1, P: -1, mass: 2.00698 }
    C: { J: 1, P: -1, mass: 2.01028 }
    D: { J: 0, P: -1, mass: 0.13957 }
  R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1, params: { mass_range: [4.0, 4.2] } }
  R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3 }
  R_CD: { J: 1, Par: 1, m0: 2.42, g0: 0.03 }

constrains:
  particle: null
  decay: null

plot:
  config:
    legend_outside: True
  mass:
    R_BC: { display: "$M_{BC}$" }
    R_BD: { display: "$M_{BD}$" }
    R_CD: { display: "$M_{CD}$" }
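This file mirrors config_toy.yml but reads .npy input with lazy_file and lazy_call enabled and prefetching turned off. Loading it follows the usual ConfigLoader pattern; a sketch, assuming the toy data files have been generated as in the test fixtures:

    from tf_pwa.config_loader import ConfigLoader

    config = ConfigLoader("tf_pwa/tests/config_toy_npy.yml")
    fcn = config.get_fcn()      # likelihood object, as in the new test
    nll, grad = fcn.nll_grad()  # value and gradient of the NLL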
19 changes: 19 additions & 0 deletions tf_pwa/tests/test_full.py
@@ -64,13 +64,27 @@ def gen_toy():

     np.savetxt("toy_data/phsp_eff_value.dat", np.ones((10000,)))


+@pytest.fixture
+def toy_npy(gen_toy):
+    for i in ["data", "bg", "PHSP"]:
+        data = np.loadtxt(f"toy_data/{i}.dat")
+        np.save(f"toy_data/{i}_npy.npy", data)
+
+
 @pytest.fixture
 def toy_config(gen_toy):
     config = ConfigLoader(f"{this_dir}/config_toy.yml")
     config.set_params(f"{this_dir}/exp_params.json")
     return config


+@pytest.fixture
+def toy_config_npy(toy_npy):
+    config = ConfigLoader(f"{this_dir}/config_toy_npy.yml")
+    config.set_params(f"{this_dir}/exp_params.json")
+    return config
+
+
 @pytest.fixture
 def toy_config_lazy(gen_toy):
     config = ConfigLoader(f"{this_dir}/config_lazycall.yml")
@@ -368,3 +382,8 @@ def test_plot_2dpull(toy_config):

     a, b = toy_config.get_dalitz_boundary("R_BC", "R_CD")
     plt.plot(a, b, color="red")
     plt.savefig("adptive_2d.png")
+
+
+def test_lazy_file(toy_config_npy):
+    fcn = toy_config_npy.get_fcn()
+    fcn.nll_grad()
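The new test_lazy_file test converts the toy .dat samples to .npy through the toy_npy fixture and then evaluates the likelihood and its gradient through the full lazy-file path; it can be run on its own with pytest tf_pwa/tests/test_full.py -k test_lazy_file.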