Skip to content

Commit

Permalink
Merge pull request #88 from jiangyi15/dataloader
Browse files Browse the repository at this point in the history
Better support for large data
  • Loading branch information
jiangyi15 authored Aug 3, 2023
2 parents f3e7662 + cc8ba51 commit 377a000
Show file tree
Hide file tree
Showing 19 changed files with 428 additions and 123 deletions.
20 changes: 16 additions & 4 deletions tf_pwa/amp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def _get_cg_matrix(self, ls): # CG factor inside H
lambda_b - lambda_c,
)
)
return tf.convert_to_tensor(ret)
return ret

def get_helicity_amp(self, data, data_p, **kwargs):
m_dep = self.get_ls_amp(data, data_p, **kwargs)
Expand Down Expand Up @@ -1622,6 +1622,8 @@ def add_used_chains(self, used_chains):
self.chains_idx.append(i)

def set_used_chains(self, used_chains):
if isinstance(used_chains, str):
used_chains = [used_chains]
self.chains_idx = list(used_chains)
if len(self.chains_idx) != len(self.chains):
self.not_full = True
Expand Down Expand Up @@ -1704,10 +1706,18 @@ def value_and_grad(f, var):

class AmplitudeModel(object):
def __init__(
self, decay_group, name="", polar=None, vm=None, use_tf_function=False
self,
decay_group,
name="",
polar=None,
vm=None,
use_tf_function=False,
no_id_cached=False,
jit_compile=False,
):
self.decay_group = decay_group
self._name = name
self.no_id_cached = no_id_cached
with variable_scope(vm) as vm:
if polar is not None:
vm.polar = polar
Expand All @@ -1720,7 +1730,9 @@ def __init__(
if use_tf_function:
from tf_pwa.experimental.wrap_function import WrapFun

self.cached_fun = WrapFun(self.decay_group.sum_amp)
self.cached_fun = WrapFun(
self.decay_group.sum_amp, jit_compile=jit_compile
)
else:
self.cached_fun = self.decay_group.sum_amp

Expand Down Expand Up @@ -1783,7 +1795,7 @@ def trainable_variables(self):
def __call__(self, data, cached=False):
if isinstance(data, LazyCall):
data = data.eval()
if id(data) in self.f_data:
if id(data) in self.f_data or self.no_id_cached:
if not self.decay_group.not_full:
return self.cached_fun(data)
else:
Expand Down
37 changes: 31 additions & 6 deletions tf_pwa/cal_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from .angle import SU2M, EulerAngle, LorentzVector, Vector3, _epsilon
from .config import get_config
from .data import (
HeavyCall,
LazyCall,
data_index,
data_merge,
Expand Down Expand Up @@ -261,8 +262,8 @@ def cal_single_boost(data, decay_chain: DecayChain) -> dict:
def cal_helicity_angle(
data: dict,
decay_chain: DecayChain,
base_z=np.array([[0.0, 0.0, 1.0]]),
base_x=np.array([[1.0, 0.0, 0.0]]),
base_z=np.array([0.0, 0.0, 1.0]),
base_x=np.array([1.0, 0.0, 0.0]),
) -> dict:
"""
Calculate helicity angle for A -> B + C: :math:`\\theta_{B}^{A}, \\phi_{B}^{A}` from momentum.
Expand All @@ -276,7 +277,6 @@ def cal_helicity_angle(

# print(decay_chain, part_data)
part_data = cal_chain_boost(data, decay_chain)
# print(decay_chain , part_data)
# calculate angle and base x,z axis from mother particle rest frame momentum and base axis
set_x = {decay_chain.top: base_x}
set_z = {decay_chain.top: base_z}
Expand Down Expand Up @@ -405,6 +405,7 @@ def cal_angle_from_particle(
r_boost=True,
final_rest=True,
align_ref=None, # "center_mass",
only_left_angle=False,
):
"""
Calculate helicity angle for particle momentum, add aligned angle.
Expand All @@ -422,7 +423,7 @@ def cal_angle_from_particle(
# get base z axis
p4 = data[decay_group.top]["p"]
p3 = LorentzVector.vect(p4)
base_z = np.array([[0.0, 0.0, 1.0]]) + tf.zeros_like(p3)
base_z = np.array([0.0, 0.0, 1.0]) + tf.zeros_like(p3)
if random_z:
p3_norm = Vector3.norm(p3)
mask = tf.expand_dims(p3_norm < 1e-5, -1)
Expand Down Expand Up @@ -474,6 +475,10 @@ def cal_angle_from_particle(
# ang = AlignmentAngle.angle_px_px(z1, x1, z2, x2)
part_data[i]["aligned_angle"] = ang
ret = data_strip(decay_data, ["r_matrix", "b_matrix", "x", "z"])
if only_left_angle:
for i in ret:
for j in ret[i]:
del ret[i][j][j.outs[1]]["ang"]
return ret


Expand Down Expand Up @@ -629,6 +634,7 @@ def cal_angle_from_momentum_base(
random_z=False,
batch=65000,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
"""
Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
Expand All @@ -646,6 +652,7 @@ def cal_angle_from_momentum_base(
r_boost,
random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
ret = []
for i in split_generator(p, batch):
Expand All @@ -658,6 +665,7 @@ def cal_angle_from_momentum_base(
r_boost,
random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
)
return data_merge(*ret)
Expand Down Expand Up @@ -707,11 +715,20 @@ def cal_angle_from_momentum_id_swap(
random_z=False,
batch=65000,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
ret = []
id_particles = decs.identical_particles
data = cal_angle_from_momentum_base(
p, decs, using_topology, center_mass, r_boost, random_z, batch
p,
decs,
using_topology,
center_mass,
r_boost,
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
if id_particles is None or len(id_particles) == 0:
return data
Expand All @@ -727,6 +744,7 @@ def cal_angle_from_momentum_id_swap(
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
return data

Expand All @@ -740,6 +758,7 @@ def cal_angle_from_momentum(
random_z=False,
batch=65000,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
"""
Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
Expand All @@ -750,13 +769,15 @@ def cal_angle_from_momentum(
"""
if isinstance(p, LazyCall):
return LazyCall(
cal_angle_from_momentum,
HeavyCall(cal_angle_from_momentum),
p,
decs=decs,
using_topology=using_topology,
center_mass=center_mass,
r_boost=r_boost,
random_z=random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
batch=batch,
)
ret = []
Expand All @@ -771,6 +792,7 @@ def cal_angle_from_momentum(
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
if cp_particles is None or len(cp_particles) == 0:
return data
Expand All @@ -785,6 +807,7 @@ def cal_angle_from_momentum(
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
return data

Expand All @@ -797,6 +820,7 @@ def cal_angle_from_momentum_single(
r_boost=True,
random_z=True,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
"""
Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
Expand Down Expand Up @@ -824,6 +848,7 @@ def cal_angle_from_momentum_single(
r_boost=r_boost,
random_z=random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
data = {"particle": data_p, "decay": data_d}
add_relative_momentum(data)
Expand Down
21 changes: 16 additions & 5 deletions tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,23 @@ def get_decay(self, full=True):

@functools.lru_cache()
def get_amplitude(self, vm=None, name=""):
use_tf_function = self.config.get("data", {}).get(
"use_tf_function", False
)
amp_config = self.config.get("data", {})
use_tf_function = amp_config.get("use_tf_function", False)
no_id_cached = amp_config.get("no_id_cached", False)
jit_compile = amp_config.get("jit_compile", False)
decay_group = self.full_decay
self.check_valid_jp(decay_group)
if vm is None:
vm = self.vm
if vm in self.amps:
return self.amps[vm]
amp = AmplitudeModel(
decay_group, vm=vm, name=name, use_tf_function=use_tf_function
decay_group,
vm=vm,
name=name,
use_tf_function=use_tf_function,
no_id_cached=no_id_cached,
jit_compile=jit_compile,
)
self.add_constraints(amp)
self.amps[vm] = amp
Expand Down Expand Up @@ -561,6 +567,7 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
bg = [None] * self._Ngroup
model = self._get_model(vm=vm, name=name)
fcns = []

# print(self.config["data"].get("using_mix_likelihood", False))
if self.config["data"].get("using_mix_likelihood", False):
print(" Using Mix Likelihood")
Expand All @@ -575,7 +582,9 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
if all_data is None:
self.cached_fcn[vm] = fcn
return fcn
for md, dt, mc, sb, ij in zip(model, data, phsp, bg, inmc):
for idx, (md, dt, mc, sb, ij) in enumerate(
zip(model, data, phsp, bg, inmc)
):
if self.config["data"].get("model", "auto") == "cfit":
fcns.append(
FCN(
Expand Down Expand Up @@ -644,6 +653,7 @@ def fit(
maxiter=None,
jac=True,
print_init_nll=True,
callback=None,
):
if data is None and phsp is None:
data, phsp, bg, inmc = self.get_all_data()
Expand Down Expand Up @@ -677,6 +687,7 @@ def fit(
improve=False,
maxiter=maxiter,
jac=jac,
callback=callback,
)
if self.fit_params.hess_inv is not None:
self.inv_he = self.fit_params.hess_inv
Expand Down
46 changes: 31 additions & 15 deletions tf_pwa/config_loader/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def get_data(self, idx) -> dict:
weight_sign = self.get_weight_sign(idx)
charge = self.dic.get(idx + "_charge", None)
ret = self.load_data(files, weights, weight_sign, charge)
return self.process_scale(idx, ret)
ret = self.process_scale(idx, ret)
return ret

def process_scale(self, idx, data):
if idx in self.scale_list and self.dic.get("weight_scale", False):
Expand All @@ -136,6 +137,12 @@ def process_scale(self, idx, data):
)
return data

def set_lazy_call(self, data, idx):
    """Attach an on-disk cache name to lazily evaluated data.

    If *data* is a ``LazyCall`` wrapper, register the configured
    ``cached_lazy_call`` file (``None`` when unset) under the tag *idx*
    so later evaluations can be reused.  Eager data is left untouched.
    """
    if not isinstance(data, LazyCall):
        return
    data.set_cached_file(self.dic.get("cached_lazy_call", None), idx)

def get_n_data(self):
data = self.get_data("data")
weight = data.get("weight", np.ones((data_shape(data),)))
Expand All @@ -156,13 +163,15 @@ def cal_angle(self, p4, charge=None):
r_boost = self.dic.get("r_boost", True)
random_z = self.dic.get("random_z", True)
align_ref = self.dic.get("align_ref", None)
only_left_angle = self.dic.get("only_left_angle", False)
data = cal_angle_from_momentum(
p4,
self.decay_struct,
center_mass=center_mass,
r_boost=r_boost,
random_z=random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
if charge is not None:
data["charge_conjugation"] = charge
Expand All @@ -185,18 +194,17 @@ def load_data(
p4 = self.load_p4(files)
charges = None if charges is None else charges[: data_shape(p4)]
data = self.cal_angle(p4, charges)
if weights is not None:
if isinstance(weights, float):
data["weight"] = np.array(
[weights * weights_sign] * data_shape(data)
)
elif isinstance(weights, str): # weight files
weight = self.load_weight_file(weights)
data["weight"] = weight[: data_shape(data)] * weights_sign
else:
raise TypeError(
"weight format error: {}".format(type(weights))
)
if weights is None:
data["weight"] = np.array([1.0 * weights_sign] * data_shape(data))
elif isinstance(weights, float):
data["weight"] = np.array(
[weights * weights_sign] * data_shape(data)
)
elif isinstance(weights, str): # weight files
weight = self.load_weight_file(weights)
data["weight"] = weight[: data_shape(data)] * weights_sign
else:
raise TypeError("weight format error: {}".format(type(weights)))

if charge is None:
data["charge_conjugation"] = tf.ones((data_shape(data),))
Expand Down Expand Up @@ -322,8 +330,11 @@ def savetxt(self, file_name, data):
else:
raise ValueError("not support data")
p4 = data_to_numpy(p4)
p4 = np.stack(p4).transpose((1, 0, 2)).reshape((-1, 4))
np.savetxt(file_name, p4)
p4 = np.stack(p4).transpose((1, 0, 2))
if file_name.endswith("npy"):
np.save(file_name, p4)
else:
np.savetxt(file_name, p4.reshape((-1, 4)))


@register_data_mode("multi")
Expand All @@ -342,6 +353,10 @@ def process_scale(self, idx, data):
)
return data

def set_lazy_call(self, data, idx):
    """Register each member dataset under a distinct cache tag.

    Delegates to the single-dataset implementation, prefixing the tag
    with ``s<position>`` ("s0<idx>", "s1<idx>", ...) so the per-dataset
    lazy-call cache files do not collide.
    """
    pos = 0
    for sub in data:
        super().set_lazy_call(sub, "s{}{}".format(pos, idx))
        pos += 1

def get_n_data(self):
data = self.get_data("data")
weight = [
Expand Down Expand Up @@ -405,6 +420,7 @@ def get_data(self, idx) -> list:
data_shape(k)
)
ret = self.process_scale(idx, ret)
self.set_lazy_call(ret, idx)
return ret

def get_phsp_noeff(self):
Expand Down
Loading

0 comments on commit 377a000

Please sign in to comment.