-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathargs.py
More file actions
31 lines (26 loc) · 796 Bytes
/
args.py
File metadata and controls
31 lines (26 loc) · 796 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import json
from types import SimpleNamespace
def parse_json_args(filename):
    """Load run arguments from a JSON file, layered over built-in defaults.

    Only keys that already exist in the defaults table below are taken from
    the file; any other key present in the JSON document is silently ignored.

    Args:
        filename: Path to a JSON file whose top-level value is an object.

    Returns:
        A ``SimpleNamespace`` with one attribute per default key, holding the
        value from the file where one was given and the default otherwise.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    args = {
        "model_id": "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        "dataset_id": "alpaca",
        "seed": 123456,
        "precision": "32",
        "strategy": "axonn",
        "tp_dimensions": [],
        "global_batch_size": 4,
        "gradient_acc_steps": 1,
        "log_interval": 1,
        "num_epochs": 1,
        "stop_iteration": -1,
        "random_init": False,
        "compile": False,
        "tokens_to_generate": 512,
    }

    # JSON text is UTF-8; decode it explicitly rather than relying on the
    # platform's locale-dependent default encoding.
    with open(filename, encoding="utf-8") as f:
        user_args = json.load(f)

    # Override defaults only for keys we recognise (set intersection of the
    # two key views), so unknown keys never become attributes.
    args.update((k, user_args[k]) for k in args.keys() & user_args.keys())

    # Convert dict to an attribute-style object.
    return SimpleNamespace(**args)