# train_config.yaml
dataset: PUBMED
feature_type: TAPE
cache_dir: .cache
seed: 42
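
# Note: the `${...}` values below look like OmegaConf-style interpolations
# (assumed from the syntax); they resolve to the top-level keys above,
# e.g. `${cache_dir}` -> `.cache`.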

lm_encoder:
  dataset_name: ${dataset}
  feature_type: ${feature_type}
  model_name_or_path: avsolatorio/GIST-Embedding-v0
  model_library: sentence_transformer
  sentence_transformer_encoder_args:
    batch_size: 100
    show_progress_bar: True
    precision: float32
  cache_dir: ${cache_dir}
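
# How this block plausibly maps onto the `sentence_transformers` API — a
# minimal sketch, assuming the encoder args are forwarded to `encode()`
# (`texts` is a hypothetical list of node texts, not defined in this file):
#
#   from sentence_transformers import SentenceTransformer
#   model = SentenceTransformer("avsolatorio/GIST-Embedding-v0", cache_folder=".cache")
#   features = model.encode(texts, batch_size=100, show_progress_bar=True,
#                           precision="float32")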

llm_online_engine:
  cache_dir: ${cache_dir}
  sampling_kwargs:
    max_tokens: 500  # LLM completion tokens
  # Pick any provider model from https://litellm.vercel.app/docs/providers
  # model: anthropic/claude-3-haiku-20240307
  # rate_limit_per_minute: 5  # https://docs.anthropic.com/en/api/rate-limits#rate-limits
  model: huggingface/meta-llama/Meta-Llama-3-8B-Instruct
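
# Sketch of the corresponding litellm call, assuming `sampling_kwargs` is
# passed through to `completion()` (`prompt` is hypothetical):
#
#   from litellm import completion
#   response = completion(
#       model="huggingface/meta-llama/Meta-Llama-3-8B-Instruct",
#       messages=[{"role": "user", "content": prompt}],
#       max_tokens=500,
#   )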

# llm_offline_engine:
#   cache_dir: ${cache_dir}
#   sampling_kwargs:
#     max_tokens: 500
#     n: 1
#     temperature: 0.6
#     top_p: 0.9
#   model: meta-llama/Meta-Llama-3-8B-Instruct
#   batch_size: 100
#   engine_kwargs:
#     seed: ${seed}
#     max_model_len: 8192
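
# The commented `engine_kwargs` above (`seed`, `max_model_len`) match vLLM's
# `LLM` constructor, so the offline engine is presumably vLLM — a sketch under
# that assumption (`prompts` is hypothetical):
#
#   from vllm import LLM, SamplingParams
#   llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", seed=42, max_model_len=8192)
#   outputs = llm.generate(prompts,
#                          SamplingParams(max_tokens=500, n=1, temperature=0.6, top_p=0.9))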

gnn_model:
  conv_layer: SAGEConv  # any layer class from `torch_geometric.nn.conv`
  hidden_channels: 64
  num_layers: 4
  dropout: 0.1
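
# A minimal PyTorch Geometric sketch matching these settings, assuming the
# stock `GraphSAGE` wrapper around `SAGEConv`. `in_channels`/`out_channels`
# are illustrative: GIST-Embedding-v0 produces 768-dim vectors, and PubMed
# node classification has 3 classes.
#
#   from torch_geometric.nn import GraphSAGE
#   gnn = GraphSAGE(in_channels=768, hidden_channels=64, num_layers=4,
#                   out_channels=3, dropout=0.1)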

gnn_trainer:
  epochs: 500
  early_stopping_patience: 50
  lr: 0.0031622776601683794  # 10**-2.5
  weight_decay: 0.00001  # 10**-5
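
# Trainer settings as they would feed a standard PyTorch optimizer — a sketch,
# assuming Adam and early stopping on a validation metric (both assumptions,
# not stated in this file):
#
#   import torch
#   optimizer = torch.optim.Adam(gnn.parameters(), lr=10**-2.5, weight_decay=1e-5)
#   # train for up to 500 epochs; stop once the validation metric has not
#   # improved for 50 consecutive epochs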