Commit
refactor: cached_lazy_call
jiangyi15 committed Jul 31, 2023
1 parent b6cf9d4 commit d88a8f5
Showing 5 changed files with 54 additions and 29 deletions.
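This commit moves the `cached_lazy_call` option out of `ConfigLoader.get_fcn` and the `FCN` constructor and onto the `LazyCall` data objects themselves: the data loader now stamps each lazy dataset with a `name` and a `cached_file`, and `LazyCall.as_dataset` builds its own per-batch-size `tf.data` cache from them. A minimal usage sketch, assuming a `config.yml` whose data section enables lazy loading and sets `cached_lazy_call` (the file name and the `lazy_call` key are assumptions, not shown in this diff):

from tf_pwa.config_loader import ConfigLoader

# config.yml is assumed to contain, in its data section, something like
#   lazy_call: True
#   cached_lazy_call: "cache/"
config = ConfigLoader("config.yml")

# get_fcn no longer threads a cached_file argument into each FCN; the cache
# location now travels with the LazyCall datasets built by the data loader.
fcn = config.get_fcn(batch=65000)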
7 changes: 0 additions & 7 deletions tf_pwa/config_loader/config_loader.py
@@ -568,8 +568,6 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
model = self._get_model(vm=vm, name=name)
fcns = []

cached_file = self.config["data"].get("cached_lazy_call", None)

# print(self.config["data"].get("using_mix_likelihood", False))
if self.config["data"].get("using_mix_likelihood", False):
print(" Using Mix Likelihood")
@@ -587,9 +585,6 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
for idx, (md, dt, mc, sb, ij) in enumerate(
zip(model, data, phsp, bg, inmc)
):
cached_file2 = (
None if cached_file is None else cached_file + "s" + str(idx)
)
if self.config["data"].get("model", "auto") == "cfit":
fcns.append(
FCN(
@@ -599,7 +594,6 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
batch=batch,
inmc=ij,
gauss_constr=self.gauss_constr_dic,
cached_file=cached_file2,
)
)
else:
@@ -612,7 +606,6 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
batch=batch,
inmc=ij,
gauss_constr=self.gauss_constr_dic,
cached_file=cached_file2,
)
)
if len(fcns) == 1:
13 changes: 12 additions & 1 deletion tf_pwa/config_loader/data.py
@@ -125,7 +125,8 @@ def get_data(self, idx) -> dict:
weight_sign = self.get_weight_sign(idx)
charge = self.dic.get(idx + "_charge", None)
ret = self.load_data(files, weights, weight_sign, charge)
return self.process_scale(idx, ret)
ret = self.process_scale(idx, ret)
return ret

def process_scale(self, idx, data):
if idx in self.scale_list and self.dic.get("weight_scale", False):
@@ -136,6 +137,11 @@ def process_scale(self, idx, data):
)
return data

def set_lazy_call(self, data, idx):
if isinstance(data, LazyCall):
data.name = idx
data.cached_file = self.dic.get("cached_lazy_call", None)

def get_n_data(self):
data = self.get_data("data")
weight = data.get("weight", np.ones((data_shape(data),)))
@@ -344,6 +350,10 @@ def process_scale(self, idx, data):
)
return data

def set_lazy_call(self, data, idx):
for i, data_i in enumerate(data):
super().set_lazy_call(data_i, "s{}{}".format(i, idx))

def get_n_data(self):
data = self.get_data("data")
weight = [
@@ -407,6 +417,7 @@ def get_data(self, idx) -> list:
data_shape(k)
)
ret = self.process_scale(idx, ret)
self.set_lazy_call(ret, idx)
return ret

def get_phsp_noeff(self):
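In the multi-group loader, the override above builds one name per data group, so every group's LazyCall gets its own cache entry. A tiny stand-in for the naming scheme (a hypothetical helper, not a tf_pwa function; it only mirrors the format string used in set_lazy_call above):

def group_name(i: int, idx: str) -> str:
    # mirrors super().set_lazy_call(data_i, "s{}{}".format(i, idx))
    return "s{}{}".format(i, idx)

print(group_name(0, "phsp"), group_name(1, "phsp"))  # -> s0phsp s1phsp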
19 changes: 18 additions & 1 deletion tf_pwa/config_loader/plot.py
@@ -738,7 +738,7 @@ def _plot_partial_wave(
ax.set_title(display, fontsize="xx-large")
else:
ax.set_title(
"{}: -lnL= {:.5}".format(display, nll), fontsize="xx-large"
"{}: -lnL= {:.2f}".format(display, nll), fontsize="xx-large"
)
ax.set_xlabel(display + units)
ywidth = np.mean(
@@ -887,6 +887,23 @@ def _2d_plot(
plt.ylim(range2)
plt.savefig(prefix + k + "_data")
plt.clf()
print("Finish plotting 2D data " + prefix + k) # data

Check warning on line 890 in tf_pwa/config_loader/plot.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/config_loader/plot.py#L890

Added line #L890 was not covered by tests
if "data_hist" in plot_figs:
plt.hist2d(
data_1,
data_2,
bins=100,
weights=data_dict["data_weights"],
cmin=1e-12,
)
plt.xlabel(name1)
plt.ylabel(name2)
plt.title(display, fontsize="xx-large")
plt.legend()
plt.xlim(range1)
plt.ylim(range2)
plt.savefig(prefix + k + "_data_hist")
plt.clf()
print("Finish plotting 2D data " + prefix + k)
# sideband
if "sideband" in plot_figs:
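The new "data_hist" branch draws a plain weighted 2D histogram of the data, with cmin masking empty bins, alongside the existing scatter-style figure. A standalone matplotlib sketch of the same call on toy numbers (variable names and data are placeholders, not taken from this diff):

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
data_1, data_2 = rng.normal(size=1000), rng.normal(size=1000)
weights = np.ones_like(data_1)

# weighted 2D histogram; cmin=1e-12 leaves empty bins blank instead of colored as zero
plt.hist2d(data_1, data_2, bins=100, weights=weights, cmin=1e-12)
plt.xlabel("m1")
plt.ylabel("m2")
plt.savefig("demo_data_hist.png")
plt.clf()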
28 changes: 19 additions & 9 deletions tf_pwa/data.py
@@ -76,8 +76,10 @@ def __init__(self, f, x, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.extra = {}
self.cached_batch = None
self.cached_batch_dataset = None
self.cached_batch = {}
self.cached_file = None
self.name = ""
self.version = 0

def batch(self, batch, axis):
for i, j in zip(
@@ -89,9 +91,9 @@ def batch(self, batch, axis):
ret[k] = v
yield ret

def as_dataset(self, batch=65000, cached_file=None):
if self.cached_batch == batch:
return self.cached_batch_dataset
def as_dataset(self, batch=65000):
if batch in self.cached_batch:
return self.cached_batch[batch]

def f(x):
x_a = x["x"]
@@ -108,9 +110,11 @@ def f(x):
{"x": real_x, "extra": self.extra}
)
# data = data.batch(batch).cache().map(f)
if cached_file is not None:
if self.cached_file is not None:
from tf_pwa.utils import create_dir

cached_file = self.cached_file + self.name + str(self.version)
cached_file += "_" + str(batch)
create_dir(cached_file)
data = data.batch(batch).map(f)
Expand All @@ -119,8 +123,7 @@ def f(x):
data = data.batch(batch).cache().map(f)
data = data.prefetch(tf.data.AUTOTUNE)

self.cached_batch = batch
self.cached_batch_dataset = data
self.cached_batch[batch] = data
return data

def merge(self, *other, axis=0):
@@ -134,6 +137,10 @@
self.f, data_merge(*all_x, axis=axis), *self.args, **self.kwargs
)
ret.extra = new_extra
ret.cached_file = self.cached_file
ret.name = self.name
for i in other:
ret.name += "_" + i.name
return ret

def __setitem__(self, index, value):
@@ -155,8 +162,11 @@ def get_weight(self):
return tf.ones(data_shape(self), dtype=get_config("dtype"))

def copy(self):
ret = LazyCall(lambda x: x, self)
ret = LazyCall(self.f, self.x, *self.args, **self.kwargs)
ret.extra = self.extra.copy()
ret.cached_file = self.cached_file
ret.name = self.name
ret.version += self.version + 1
return ret

def eval(self):
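After this change each LazyCall keeps one prepared tf.data.Dataset per batch size in cached_batch, and, when cached_file is set, derives its on-disk cache path from cached_file + name + version plus the batch size. A sketch of the resulting behaviour, assuming a config with lazy loading enabled as above (object names and the indexing into get_data's return value are illustrative):

from tf_pwa.config_loader import ConfigLoader

config = ConfigLoader("config.yml")
phsp = config.get_data("phsp")[0]    # a LazyCall when lazy_call is enabled

ds_a = phsp.as_dataset(batch=65000)  # built once and stored under cached_batch[65000]
ds_b = phsp.as_dataset(batch=65000)  # the same dataset object is returned, no rebuild
assert ds_a is ds_b

ds_c = phsp.as_dataset(batch=20000)  # a different batch size gets its own dataset
# with cached_file set (e.g. "cache/"), the on-disk cache path becomes something like
# "cache/" + phsp.name + str(phsp.version) + "_65000"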
16 changes: 5 additions & 11 deletions tf_pwa/model/model.py
@@ -1098,9 +1098,7 @@ def __init__(
batch=65000,
inmc=None,
gauss_constr={},
cached_file=None,
):
self.cached_file = cached_file
self.model = model
self.vm = model.vm
self.n_call = 0
@@ -1116,9 +1114,9 @@ def __init__(
self.alpha = tf.reduce_sum(weight) / tf.reduce_sum(weight * weight)
self.weight = weight
self.data = data
self.batch_data = self._convert_batch(data, batch, "data")
self.batch_data = self._convert_batch(data, batch)
self.mcdata = mcdata
self.batch_mcdata = self._convert_batch(mcdata, batch, "mc")
self.batch_mcdata = self._convert_batch(mcdata, batch)
self.batch = batch
if mcdata.get("weight", None) is not None:
mc_weight = tf.convert_to_tensor(mcdata["weight"], dtype="float64")
@@ -1127,16 +1125,12 @@ def __init__(
self.mc_weight = tf.convert_to_tensor(
[1 / n_mcdata] * n_mcdata, dtype="float64"
)
self.batch_mc_weight = self._convert_batch(
self.mc_weight, self.batch, "mcweight"
)
self.batch_mc_weight = self._convert_batch(self.mc_weight, self.batch)
self.gauss_constr = GaussianConstr(self.vm, gauss_constr)
self.cached_mc = {}

def _convert_batch(self, data, batch, name):
return _convert_batch(
data, batch, cached_file=self.cached_file, name=name
)
def _convert_batch(self, data, batch):
return _convert_batch(data, batch)

def get_params(self, trainable_only=False):
return self.vm.get_all_dic(trainable_only)