Merge pull request #146 from jiangyi15/multigpu1
Refactor: distribute data to multiple GPUs in FCN
jiangyi15 authored Jun 9, 2024
2 parents f3793eb + 3c42d56 commit 8d419f5
Showing 4 changed files with 81 additions and 41 deletions.
56 changes: 39 additions & 17 deletions config.sample.yml
@@ -11,9 +11,14 @@ data:
## the `finals` in `particle` will be used by default.
## it is necessary when the momenta in the dat files are not in the order of the final particles
dat_order: [B, C, D]
## The file path is grouped by `name[_tail]`
## Each tail corresponds to a variable in a dataset
## basic data files, every line is `E px py pz` as the 4-momentum of a particle,
## and every m lines are grouped as the m final particles
## supports inputting multiple datasets as [["data.dat"], ["data2.dat"]] to do a simultaneous fit using the same model
data: ["data/data4600_new.dat"]
## The additional file is grouped by `name[_tail]`
## Each tail corresponds to the variables in a dataset `name`
## data weight file, each line for a weight
# data_weight: ["data/data_weight.dat"]
## phase space data for normalizing the amplitude
@@ -28,19 +33,34 @@ data:
bg: ["data/bg4600_new.dat"]
## background weight
bg_weight: 0.731
## inject MC in data
# inmc: ["data/inMC.dat"]
# inject_ratio: 0.01
# float_inmc_ratio_in_pdf: False
## use the total momentum direction as the initial axis (or laboratory axis)
## set to False if the initial particle is e+e- directly
# random_z: True
## whether to boost data to the center-of-mass frame first
# center_mass: True
# cached data file
## cached data file
# cached_data: "data/all_data.npy"
# data_charge: ["data/data4600_cc.dat"] # charge conjugation condition same as weight
## charge conjugation condition same as weight
# data_charge: ["data/data4600_cc.dat"]
# cp_trans: True # when charge conjugation is used as above, this does p -> -p for the charge conjugation process.

## Currently, additional configuration for the fit can also be put here; some options might be changed frequently.
## likelihood formula, cfit; some models require other options (bg_frac)
# model: default
# bg_frac: 0.3
## use tf.function to compile the amplitude model
# use_tf_function: True
## preprocessor and amplitude model for different ways of amplitude calculation: ["default", "cached_amp", "cached_shape", "p4_directly"]
# preprocessor: cached_shape
# amp_model: cached_shape
## use TensorFlow Dataset instead of loading all data to GPU directly.
# lazy_call: True
## caching preprocessor results in file_name
# cached_lazy_call: file_name
## use memmap instead of loading data into memory
# lazy_file: True
## use multiple GPUs (experimental), only supported with `model: simple`
# multi_gpu: True


## `decay` describe the decay structure, each node can be a list of particle name
decay:
## each entry is a list of direct decay particles, or a list of such lists.
@@ -112,25 +132,27 @@ particle:

## The following config is used for DecayChain
# decay_chain:
# $all:
# is_cp: True
# $all:
# is_cp: True

## constraints for params [WIP]
constrains:
# particle:
# equal:
# mass: [[D1_2430p, D1_2430]]
# equal:
# mass: [[D1_2430p, D1_2430]]
## fix the first decay chain total to 1.0
decay:
fix_chain_idx: 0
fix_chain_val: 1
# decay_d: 3.0
# decay_d: 3.0 # supports a number or a list of numbers
# fix_var:
# "A->Zc_4025.DZc_4025->B.C_total_0r": 1.0
# "A->Zc_4025.DZc_4025->B.C_total_0r": 1.0
# free_var:
# - "A->Zc_4025.DZc_4025->B.C_total_0r"
# var_range:
# "A->Zc_4025.DZc_4025->B.C_total_0r": [1.0, null]
# "A->Zc_4025.DZc_4025->B.C_total_0r": [1.0, null]
# var_equal:
# - ["A->D2_2460p.CD2_2460p->B.D_total_0r", "A->D2_2460.BD2_2460->C.D_total_0r"]
# - ["A->D2_2460p.CD2_2460p->B.D_total_0r", "A->D2_2460.BD2_2460->C.D_total_0r"]

## plot describes the configuration for plotting 1-d data distributions
plot:
@@ -199,4 +221,4 @@ plot:
x: m_R_CD**2
y: m_R_BC**2
display: "$M_{CD}^2$ vs $M_{BC}^2$"
plot_figs: ["data", "sideband", "fitted"]
plot_figs: ["data", "sideband", "fitted", "pull"]
8 changes: 7 additions & 1 deletion fit.py
@@ -279,7 +279,13 @@ def main():
parser.add_argument(
"--CPU-plot", action="store_true", default=False, dest="cpu_plot"
)
parser.add_argument("-c", "--config", default="config.yml", dest="config")
parser.add_argument(
"-c",
"--config",
default="config.yml",
dest="config",
help="config.yml files, support input multiply file as `config1.yml,config2.yml` to do simultaneous fit using different model",
)
parser.add_argument(
"-i", "--init_params", default="init_params.json", dest="init"
)
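The new help text implies the driver itself splits the comma-separated `--config` value; a minimal standard-library sketch of that parsing (the per-file loading below is only a placeholder):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", default="config.yml", dest="config")
args = parser.parse_args(["--config", "config1.yml,config2.yml"])

config_files = args.config.split(",")  # -> ["config1.yml", "config2.yml"]
for name in config_files:
    # placeholder: each file may describe a different model; the
    # simultaneous fit ties shared parameters across all of them
    print("would load", name)
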
25 changes: 4 additions & 21 deletions tf_pwa/model/custom.py
@@ -74,16 +74,7 @@ def _fast_nll_part_grad(self, data, int_mc=None, idx=0):
@tf.function
def _fast_int_mc_grad_multi(self, ia):
strategy = self.Amp.vm.strategy
n_p = strategy.num_replicas_in_sync
ia = list(
split_generator(ia, batch_size=(data_shape(ia) + n_p - 1) // n_p)
)

def _tmp_fun(ctx):
return ia[ctx.replica_id_in_sync_group]

i = strategy.experimental_distribute_values_from_function(_tmp_fun)
a, b = i
a, b = ia
vm = self.Amp.vm
per_replica_losses = vm.strategy.run(
self.value_and_grad(self.eval_normal_factors), args=(a, b)
@@ -96,22 +87,14 @@ def _tmp_fun(ctx):
@tf.function
def _fast_nll_part_grad_multi(self, ia, int_mc_x, int_mc_g, idx):
strategy = self.Amp.vm.strategy
n_p = strategy.num_replicas_in_sync
ia = list(
split_generator(ia, batch_size=(data_shape(ia) + n_p - 1) // n_p)
)

def _tmp_fun(ctx):
return ia[ctx.replica_id_in_sync_group]

ab = strategy.experimental_distribute_values_from_function(_tmp_fun)
a, b = ia
int_mc = SumVar(int_mc_x, int_mc_g, self.Amp.trainable_variables)
vm = self.Amp.vm
per_replica_losses = vm.strategy.run(
self.value_and_grad(
lambda i: self.eval_nll_part(i[0], i[1], int_mc(), idx)
lambda i0, i1: self.eval_nll_part(i0, i1, int_mc(), idx)
),
args=(ab,),
args=(a, b),
)
tmp = vm.strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None
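After the refactor, `_fast_nll_part_grad_multi` receives values that are already distributed per replica and keeps only the run/reduce core. A self-contained sketch of that tf.distribute pattern, with a toy loss standing in for `eval_nll_part`:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # one replica per visible GPU

def per_replica_loss(a, b):
    # toy stand-in for eval_nll_part; any per-replica computation works
    return tf.reduce_sum(a) + tf.reduce_sum(b)

# in the diff, a and b arrive pre-split per replica via _distribute_multi_gpu
a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
losses = strategy.run(per_replica_loss, args=(a, b))
total = strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)
print(total.numpy())  # 10.0 on a single replica
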
33 changes: 31 additions & 2 deletions tf_pwa/model/model.py
@@ -1159,12 +1159,41 @@ def __init__(
self.mc_weight = tf.convert_to_tensor(
[1 / n_mcdata] * n_mcdata, dtype="float64"
)

self.batch_weight = self._convert_batch(self.weight, self.batch)
self.batch_mc_weight = self._convert_batch(self.mc_weight, self.batch)
self.gauss_constr = GaussianConstr(self.vm, gauss_constr)
self.cached_mc = {}

def _convert_batch(self, data, batch):
return _convert_batch(data, batch)
ret = _convert_batch(data, batch)
if self.vm.strategy is not None:
ret = self._distribute_multi_gpu(ret)
return ret

def _distribute_multi_gpu(self, data):
strategy = self.vm.strategy
if isinstance(data, tf.data.Dataset):
data = strategy.experimental_distribute_dataset(data)
elif isinstance(data, list):
ret = []
n_p = strategy.num_replicas_in_sync
for ia in data:
ia = list(
split_generator(
ia, batch_size=(data_shape(ia) + n_p - 1) // n_p
)
)

def _tmp_fun(ctx):
return ia[ctx.replica_id_in_sync_group]

tmp = strategy.experimental_distribute_values_from_function(
_tmp_fun
)
ret.append(tmp)
data = ret
return data

def get_params(self, trainable_only=False):
return self.vm.get_all_dic(trainable_only)
@@ -1214,7 +1243,7 @@ def get_nll_grad(self, x={}):
nll, g = self.model.nll_grad_batch(
self.batch_data,
self.batch_mcdata,
weight=list(data_split(self.weight, self.batch)),
weight=self.batch_weight,
mc_weight=self.batch_mc_weight,
)
self.n_call += 1
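`_distribute_multi_gpu` splits each batch into `num_replicas_in_sync` nearly equal pieces (note the ceiling division in `batch_size`) and hands one piece to each replica. A self-contained sketch of the same TensorFlow mechanism, using a plain tensor in place of tf_pwa's data dicts:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
n_p = strategy.num_replicas_in_sync

data = tf.range(10, dtype=tf.float64)
size = (int(tf.size(data)) + n_p - 1) // n_p  # ceiling division, as in the diff
pieces = [data[i * size : (i + 1) * size] for i in range(n_p)]

def value_fn(ctx):
    # ctx is a tf.distribute.experimental.ValueContext
    return pieces[ctx.replica_id_in_sync_group]

per_replica = strategy.experimental_distribute_values_from_function(value_fn)
sums = strategy.run(tf.reduce_sum, args=(per_replica,))
print(strategy.reduce(tf.distribute.ReduceOp.SUM, sums, axis=None))  # 45.0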
