diff --git a/code4me-server/src/query_filter.py b/code4me-server/src/query_filter.py index 9dc582a..b26f002 100644 --- a/code4me-server/src/query_filter.py +++ b/code4me-server/src/query_filter.py @@ -180,11 +180,13 @@ def get_model(model_name): return model class Filter(enum.Enum): + NO_FILTER = 'no_filter' FEATURE = 'feature' CONTEXT = 'context' JOINT_H = 'joint_h' JOINT_A = 'joint_a' +no_filter = lambda request_json: True logres = Logres(coef, intercept) set_all_seeds() # just in case context_filter = MyPipeline( device=DEVICE, task='text-classification', @@ -200,6 +202,7 @@ class Filter(enum.Enum): model_name='13_jonberta-biased-12_codeberta-biased-2e-05lr--0-(ATTN-208C_f-[0]L)-2e-05lr--4' ) filters = { + Filter.NO_FILTER: no_filter, Filter.FEATURE: logres.predict, Filter.CONTEXT: context_filter, Filter.JOINT_H: joint_h_filter,