-
Notifications
You must be signed in to change notification settings - Fork 204
Expand file tree
/
Copy pathfile_dataset.py
More file actions
77 lines (59 loc) · 2.25 KB
/
file_dataset.py
File metadata and controls
77 lines (59 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import threading
from llm_trainer import FileDataset, TrainerTools
from constant import data_root_dir
from modelscope import dataset_snapshot_download
class FileDatasetBase(FileDataset):
    """Dataset of remote data files that are materialised locally on demand.

    Indexing the dataset returns the local path of the requested file.
    Along the way the main process downloads the file if it is missing,
    starts a background download of the following file, and removes the
    preceding file from the local data directory.
    """

    # ModelScope dataset repository every file is downloaded from.
    _DATASET_ID = 'qibin0506/Cortex-3.0-data'

    def __init__(self, file_names: list):
        """Store the ordered list of file names served by this dataset.

        Args:
            file_names: File names relative to ``data_root_dir()``, in the
                order in which they are expected to be consumed.
        """
        self.file_names = file_names

    def __len__(self) -> int:
        """Return the number of files in the dataset."""
        return len(self.file_names)

    def __getitem__(self, idx) -> str:
        """Return the local path of file ``idx``, downloading it if needed.

        Args:
            idx: Position of the file in ``file_names``. Negative values
                count from the end, as with a regular sequence.

        Returns:
            The path of the file inside ``data_root_dir()``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Normalise in-range negative indices so that the "next"/"previous"
        # bookkeeping below refers to the real neighbours of the file.
        # Out-of-range values are left untouched and raise IndexError below.
        count = len(self.file_names)
        if -count <= idx < 0:
            idx += count

        file_name = self.file_names[idx]
        root_dir = data_root_dir()
        # os.path.join yields the same path as plain concatenation when the
        # root already ends with a separator, and a correct one when it
        # does not (the download below always writes *inside* root_dir).
        file_path = os.path.join(root_dir, file_name)

        parallel = TrainerTools().parallel
        is_main = parallel.is_main_process

        # Download the current file.
        if not os.path.exists(file_path):
            if is_main:
                self._download(file_name, root_dir)

            # Non-main processes wait here until the main one has finished.
            # NOTE(review): presumably a cross-process barrier, which relies
            # on every process observing the same os.path.exists() result
            # above; a prefetch still in flight for this file could break
            # that assumption (and trigger a second download) - confirm.
            parallel.wait()

        # Prefetch the next file in the background.
        if is_main and idx < count - 1:
            next_file = self.file_names[idx + 1]
            if not os.path.exists(os.path.join(root_dir, next_file)):
                threading.Thread(
                    target=self._download,
                    args=(next_file, root_dir),
                ).start()

        # Delete the previous file.
        if is_main and idx > 0:
            prev_path = os.path.join(root_dir, self.file_names[idx - 1])
            try:
                os.remove(prev_path)
            except FileNotFoundError:
                # Already gone (never downloaded or removed earlier).
                pass

        return file_path

    def _download(self, file_name: str, root_dir: str) -> None:
        """Download a single file of the dataset repository into ``root_dir``."""
        dataset_snapshot_download(
            dataset_id=self._DATASET_ID,
            allow_file_pattern=[file_name],
            local_dir=root_dir,
        )
class PretrainFileDataset(FileDatasetBase):
    """File dataset listing the pretraining data files."""

    def __init__(self):
        """Register the two pretraining files, in consumption order."""
        pretrain_files = ['pretrain_data_0.npy', 'pretrain_data_1.npy']
        super().__init__(pretrain_files)
class MidtrainFileDataset(FileDatasetBase):
    """File dataset listing the mid-training data file."""

    def __init__(self):
        """Register the single mid-training file."""
        midtrain_files = ['midtrain_data_0.npy']
        super().__init__(midtrain_files)
class SFTFileDataset(FileDatasetBase):
    """File dataset listing the SFT data file."""

    def __init__(self):
        """Register the single SFT file."""
        sft_files = ['sft_data.npy']
        super().__init__(sft_files)
class PPOFileDataset(FileDatasetBase):
    """File dataset listing the PPO data file."""

    def __init__(self):
        """Register the single PPO file."""
        ppo_files = ['ppo_data.npy']
        super().__init__(ppo_files)