# gpt2_rl_hyperparameter_search.yaml
# Free-text description of this job configuration.
description: 'Hyperparameter Search for RL on Graph (H2)'
# Where the jobs are submitted: backend service, cluster name, and `vc`
# (presumably the virtual cluster -- confirm against the submission tool docs).
target:
  service: amlk8s
  name: itpeastusv100cl2
  vc: resrchvc
# Runtime environment: base container image plus the setup commands that
# install the extra Python packages on top of it.
environment:
  image: pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel
  setup:
    - pip install --user numpy
    - pip install --user networkx
    - pip install --user numba
    - pip install --user stable-baselines3==1.0
    - pip install --user transformers
    - pip install --user gym
    # The PyPI name `sklearn` is a deprecated placeholder whose install now
    # fails; the actual distribution is `scikit-learn` (import name is still
    # `sklearn`).
    - pip install --user scikit-learn
    - pip install --user pandas
    - pip install --user sentencepiece
    - pip install --user spacy==3.0.3
    # Model wheel matching spacy 3.0.x, fetched from the canonical GitHub
    # release host rather than a third-party mirror domain.
    - pip install --user https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.0.0/en_core_web_lg-3.0.0-py3-none-any.whl
# Directory holding the job's source code. $CONFIG_DIR is presumably expanded
# by the submission tool to this file's directory -- confirm.
code:
  local_dir: $CONFIG_DIR/h2_code
# Data and model files: a local directory paired with a remote directory name.
# The command in the `search` section reads its inputs via $$AMLT_DATA_DIR.
data:
  local_dir: $CONFIG_DIR/h2_data_models/
  remote_dir: h2_data_models
# Grid search over sim_threshold x top_k x horizon.
# 4 * 2 * 2 = 16 combinations, which equals max_trials below.
search:
  job_template:
    name: 'search_gpt2_turing_{auto:s}'
    sku: G1
    command:
      # Folded scalar: the lines below are joined with single spaces into one
      # command line. {top_k}, {horizon} and {sim_threshold} are placeholders
      # named after the entries in `params`.
      - >-
        python graph_rl.py
        --input_file $$AMLT_DATA_DIR/experiments_data/turing_leakages
        --output_file $$AMLT_OUTPUT_DIR/h2_turing_gpt2
        --cls_model $$AMLT_DATA_DIR/roberta
        --gen_model $$AMLT_DATA_DIR/gpt2_medium
        --batch_size 64
        --top_k {top_k}
        --steps 5000
        --horizon {horizon}
        --sim_threshold {sim_threshold}
  type: grid
  max_trials: 16
  params:
    - name: sim_threshold
      spec: discrete
      values:
        - 0.5
        - 0.6
        - 0.7
        - 0.8
    - name: top_k
      spec: discrete
      values:
        - 200
        - 300
    - name: horizon
      spec: discrete
      values:
        - 200
        - 300