Skip to content

Commit

Permalink
added support for bert pruning (#1)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <ubuntu@ip-172-31-89-56.ec2.internal>
  • Loading branch information
anandhu-eng and Ubuntu authored Aug 2, 2023
1 parent 8fab885 commit c99186c
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 14 deletions.
2 changes: 1 addition & 1 deletion cm-mlops/script/get-ml-model-huggingface-zoo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,4 @@ ___
___
### Maintainers

* [Open MLCommons taskforce on automation and reproducibility](https://github.com/mlcommons/ck/blob/master/docs/taskforce.md)
* [Open MLCommons taskforce on automation and reproducibility](https://github.com/mlcommons/ck/blob/master/docs/taskforce.md)
5 changes: 5 additions & 0 deletions cm-mlops/script/get-ml-model-huggingface-zoo/_cm.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
"env": {
"CM_MODEL_ZOO_STUB": "pierreguillou/bert-base-cased-squad-v1.1-portuguese"
}
},
"prune":{
"env":{
"CM_MODEL_TASK": "prune"
}
}
}
}
Expand Down
26 changes: 19 additions & 7 deletions cm-mlops/script/get-ml-model-huggingface-zoo/download_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
from huggingface_hub import hf_hub_download
import os

model_stub= os.environ.get('CM_MODEL_ZOO_STUB', '')
model_filename= os.environ.get('CM_MODEL_ZOO_FILENAME', 'model.onnx')
model_stub = os.environ.get('CM_MODEL_ZOO_STUB', '')
model_task = os.environ.get('CM_MODEL_TASK', '')

print("Downloading model: "+model_stub)
if model_task == "prune":
print("Downloading model: "+model_stub)
downloaded_model_path = hf_hub_download(repo_id=model_stub,
filename="pytorch_model.bin",
cache_dir=os.getcwd())
downloaded_model_path = hf_hub_download(repo_id=model_stub,
filename="config.json",
cache_dir=os.getcwd())
with open('tmp-run-env.out', 'w') as f:
f.write(f"CM_ML_MODEL_FILE_WITH_PATH={os.path.join(os.getcwd(),'')}")
else:
model_filename= os.environ.get('CM_MODEL_ZOO_FILENAME', 'model.onnx')

downloaded_model_path = hf_hub_download(repo_id=model_stub,
print("Downloading model: "+model_stub)

downloaded_model_path = hf_hub_download(repo_id=model_stub,
filename=model_filename,
cache_dir=os.getcwd(),
force_filename=model_filename)

with open('tmp-run-env.out', 'w') as f:
f.write(f"CM_ML_MODEL_FILE_WITH_PATH={os.path.join(os.getcwd(),model_filename)}")

with open('tmp-run-env.out', 'w') as f:
f.write(f"CM_ML_MODEL_FILE_WITH_PATH={os.path.join(os.getcwd(),model_filename)}")
38 changes: 36 additions & 2 deletions cm-mlops/script/prune-bert-models/_cm.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,53 @@
"tags": "get,generic-python-lib,_tqdm"
},
{
"tags": "get,generic-python-lib,_torch"
"tags": "get,generic-python-lib,_torch_cuda"
},
{
"tags": "get,generic-python-lib,_datasets"
},
{
"tags": "get,generic-python-lib,_transformers"
},
{
"tags": "get,generic-python-lib,_scikit-learn"
},
{
"tags": "get,git,repo,_repo.https://github.com/anandhu-eng/retraining-free-pruning"
},
{
"names": [
"get-model"
],
"tags": "get, ml-model, model, zoo, model-zoo, huggingface, _prune"
}
],
"tags": [
"prune",
"bert-prune",
"prune-bert-models"
],
"uid": "76182d4896414216"
"uid": "76182d4896414216",
"variations":{
"path.#":{
"env":{
"CM_UNPRUNED_MODEL_PATH":"#"
}
},
"task.#":{
"env":{
"CM_PRUNE_TASK":"#"
}
},
"model-name.#":{
"adr":{
"get-model":{
"tags":"_model-stub.#"
}
},
"env":{
"CM_PRUNE_MODEL_NAME":"#"
}
}
}
}
11 changes: 7 additions & 4 deletions cm-mlops/script/prune-bert-models/customize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ def preprocess(i):
os_info = i['os_info']

env = i['env']

print("Entered preprocess")

env['BERT_PRUNE_REPO_PATH'] = env['CM_GIT_CHECKOUT_PATH']
print("Pruning repo path:"+env['BERT_PRUNE_REPO_PATH'])
env['CM_UNPRUNED_MODEL_PATH']=env['CM_ML_MODEL_FILE_WITH_PATH']+"models--bert-large-uncased/snapshots/80792f8e8216b29f3c846b653a0ff0a37c210431"
out_dir="/home/ubuntu/prune_model/out"
cmd = "python3 "+env['BERT_PRUNE_REPO_PATH']+"/main.py --model_name " + env['CM_PRUNE_MODEL_NAME'] + " --task_name " + env['CM_PRUNE_TASK'] + " --ckpt_dir "+env['CM_UNPRUNED_MODEL_PATH']+" --constraint 0.5 --output_dir "+out_dir
os.system(cmd)
return {'return': 0}

def postprocess(i):
Expand All @@ -17,4 +20,4 @@ def postprocess(i):

print("Entered postprocess")

return {'return': 0}
return {'return': 0}

0 comments on commit c99186c

Please sign in to comment.