-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor RankingModel class for Text+Numr use case
- Loading branch information
Wei-Cheng Chang
committed
Sep 4, 2024
1 parent
37028ca
commit f07c4f8
Showing
10 changed files
with
935 additions
and
386 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# PECOS XMR Reranker on MS-Marco Dataset | ||
|
||
This is an example of PECOS-based RankingModel that reproduced the [RankLlaMA paper](https://arxiv.org/abs/2310.08319). | ||
|
||
## How to run | ||
|
||
### Training | ||
```bash | ||
torchrun --nnodes 1 --nproc-per-node 8 \ | ||
-m pecos.xmr.reranker.train \ | ||
--config_json_path ./msmarco_qwen2-7B.train.json | ||
``` | ||
|
||
### Predictions | ||
```bash | ||
python -m pecos.xmr.reranker.predict \ | ||
--config_json_path ./msmarco_qwen2-7B.pred.json | ||
``` | ||
|
||
## Evaluation | ||
We first convert the predictions from parquet to TREC format: | ||
```python | ||
python -u parquet_to_trec_eval.py -i inference_outputs/ms_marco/qwen2-7B -o inference_outputs/ms_marco/qwen2-7B.pred.trec | ||
``` | ||
|
||
We then follow [Pyserini]() evaluation protocol to eval the NDCG@10, | ||
and you should see the results like: | ||
```python | ||
python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 dl19-passage inference_outputs/ms_marco/qwen2-7B.pred.trec | ||
|
||
Results: | ||
ndcg_cut_10 all 0.7619 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
{ | ||
"target_data_folder": "./datasets/ms_marco/eval_aux/target", | ||
"input_data_folder": "./datasets/ms_marco/eval_aux/input", | ||
"label_data_folder": "./datasets/ms_marco/eval_aux/label", | ||
"model_path": "./models/ms_marco/qwen2-7B/", | ||
"output_dir": "./inference_outputs/ms_marco/qwen2-7B/", | ||
"per_device_eval_batch_size": 1024, | ||
"dataloader_num_workers": 1, | ||
"dataloader_prefetch_factor": 10, | ||
"rerank_max_len": 196, | ||
"query_prefix": "query: ", | ||
"passage_prefix": "document: ", | ||
"inp_id_col": "inp_id", | ||
"lbl_id_col": "lbl_id", | ||
"inp_id_orig_col": "inp_id_orig", | ||
"lbl_id_orig_col": "lbl_id_orig", | ||
"keyword_col_name": "keywords", | ||
"content_col_names": ["title", "contents"], | ||
"append_eos_token": false, | ||
"pad_to_multiple_of": 8 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
{ | ||
"train_params": { | ||
"__meta__": { | ||
"class_fullname": "pecos.xmr.reranker.model###RankingModel.TrainParams" | ||
}, | ||
"target_data_folder": "./datasets/ms_marco/train/target", | ||
"input_data_folder": "./datasets/ms_marco/train/input", | ||
"label_data_folder": "./datasets/ms_marco/train/label", | ||
"hf_trainer_args": { | ||
"__meta__": { | ||
"class_fullname": "pecos.xmr.reranker.trainer###RankingTrainer.TrainingArgs" | ||
}, | ||
"output_dir": "./models/ms_marco/qwen2-7B", | ||
"ddp_find_unused_parameters": false, | ||
"loss_fn": "listwise", | ||
"loss_alpha": 1.0, | ||
"group_size": 16, | ||
"per_device_train_batch_size": 6, | ||
"gradient_accumulation_steps": 8, | ||
"disable_tqdm": false, | ||
"logging_strategy": "steps", | ||
"logging_first_step": false, | ||
"learning_rate": 1e-4, | ||
"max_steps": 1500, | ||
"save_steps": 50, | ||
"logging_steps": 10, | ||
"save_strategy": "steps", | ||
"save_total_limit": 5, | ||
"seed": 42, | ||
"data_seed": 42, | ||
"bf16": true, | ||
"dataloader_num_workers": 2, | ||
"dataloader_prefetch_factor": 10, | ||
"gradient_checkpointing": true, | ||
"deepseed": { | ||
"zero_optimization": { | ||
"stage": 3, | ||
"offload_optimizer": { | ||
"device": "none", | ||
"pin_memory": true | ||
}, | ||
"offload_param": { | ||
"device": "none", | ||
"pin_memory": true | ||
}, | ||
"overlap_comm": true, | ||
"contiguous_gradients": true, | ||
"sub_group_size": 1e9, | ||
"reduce_bucket_size": 1e6, | ||
"stage3_prefetch_bucket_size": "auto", | ||
"stage3_param_persistence_threshold": "auto", | ||
"stage3_max_live_parameters": 1e9, | ||
"stage3_max_reuse_distance": 1e9, | ||
"stage3_gather_16bit_weights_on_model_save": true | ||
}, | ||
"fp16": { | ||
"enabled": "auto", | ||
"loss_scale": 0, | ||
"initial_scale_power": 10, | ||
"loss_scale_window": 1000, | ||
"hysteresis": 2, | ||
"min_loss_scale": 1 | ||
}, | ||
"bf16": { | ||
"enabled": "auto", | ||
"loss_scale": 0, | ||
"initial_scale_power": 10, | ||
"loss_scale_window": 1000, | ||
"hysteresis": 2, | ||
"min_loss_scale": 1 | ||
}, | ||
"optimizer": { | ||
"type": "AdamW", | ||
"params": { | ||
"lr": "auto", | ||
"betas": "auto", | ||
"eps": "auto", | ||
"weight_decay": "auto", | ||
"torch_adam": true | ||
} | ||
}, | ||
"scheduler": { | ||
"type": "WarmupDecayLR", | ||
"params": { | ||
"warmup_min_lr": "auto", | ||
"warmup_max_lr": "auto", | ||
"warmup_num_steps": "auto", | ||
"total_num_steps": "auto" | ||
} | ||
}, | ||
"gradient_accumulation_steps": "auto", | ||
"gradient_clipping": "auto", | ||
"steps_per_print": 1000, | ||
"train_batch_size": "auto", | ||
"train_micro_batch_size_per_gpu": "auto", | ||
"wall_clock_breakdown": false | ||
} | ||
} | ||
}, | ||
"model_params": { | ||
"__meta__": { | ||
"class_fullname": "pecos.xmr.reranker.model###RankingModel.ModelParams" | ||
}, | ||
"encoder_config": { | ||
"text_config": { | ||
"model_type": "qwen2", | ||
"name_or_path": "Qwen/Qwen2-7B", | ||
"attn_implementation": "sdpa", | ||
"trust_remote_code": true, | ||
"token": null | ||
}, | ||
"numr_config": null, | ||
"text_pooling_type": "last", | ||
"head_size_list": [128] | ||
}, | ||
"model_modifier": { | ||
"modifier_type": "peft", | ||
"config_type": "LoraConfig" , | ||
"config": { | ||
"r": 16, | ||
"lora_alpha": 32, | ||
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], | ||
"modules_to_save": ["head_layers", "scorer"], | ||
"lora_dropout": 0.1 | ||
} | ||
}, | ||
"positive_passage_no_shuffle": false, | ||
"negative_passage_no_shuffle": false, | ||
"rerank_max_len": 196, | ||
"query_prefix": "query: ", | ||
"passage_prefix": "document: ", | ||
"inp_id_col": "inp_id", | ||
"lbl_idxs_col": "ret_idxs", | ||
"score_col": "rel", | ||
"keyword_col_name": "keywords", | ||
"content_col_names": ["title", "contents"], | ||
"append_eos_token": false, | ||
"pad_to_multiple_of": 16 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
|
||
import argparse | ||
import os | ||
import pandas as pd | ||
|
||
|
||
def main(args): | ||
""" | ||
Combine all results from the results folder and write them to the output file. | ||
""" | ||
result_files = [ | ||
os.path.join(args.input_parquet_path, x) | ||
for x in os.listdir(args.input_parquet_path) | ||
] | ||
all_results = pd.read_parquet(result_files[0]) | ||
for f in result_files[1:]: | ||
all_results = pd.concat([all_results, pd.read_parquet(f)]) | ||
# sort all results by 'inp_id' and then 'score' in descending order | ||
all_results = all_results.sort_values(by=['inp_id', 'score'], ascending=[True, False]) | ||
|
||
cur_inp_id = None | ||
with open(args.output_trec_path, "w") as fout: | ||
for row in all_results.itertuples(): | ||
if cur_inp_id != row.inp_id: | ||
cur_inp_id = row.inp_id | ||
rank = 0 | ||
rank += 1 | ||
fout.write(f"{row.inp_id} Q0 {row.lbl_id} {rank} {row.score} dense\n") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("-i", "--input-parquet-path", type=str, required=True) | ||
parser.add_argument("-o", "--output-trec-path", type=str, required=True) | ||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.