# train.py
import os
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback,
)
from datasets import load_from_disk
from huggingface_hub import HfFolder, login
import argparse
from sconf import Config
import evaluate
import numpy as np

# Log in to the Hugging Face Hub using the HF_TOKEN environment variable,
# falling back to the token cached locally by `huggingface-cli login`.
login(token=os.getenv("HF_TOKEN") or HfFolder.get_token())
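
# The script expects a config.yaml exposing the sections read below (model,
# training, output). The sketch here is inferred from the keys this file
# accesses; the values are illustrative assumptions, not the repository's
# actual settings:
#
#   model:
#     name: google/flan-t5-base      # any seq2seq checkpoint
#     device_map: auto
#     use_cache: false
#   training:
#     num_train_epochs: 3
#     per_device_train_batch_size: 8
#     per_device_eval_batch_size: 8
#     gradient_accumulation_steps: 1
#     eval_accumulation_steps: 1
#     fp16: true
#     fp16_full_eval: true
#     learning_rate: 5e-5
#     lr_scheduler_type: linear
#     eval_strategy: steps
#     eval_steps: 500
#     save_strategy: steps
#     save_steps: 500
#     save_total_limit: 2
#     logging_steps: 100
#     report_to: none
#     push_to_hub: false
#     private_repo: true
#     strategy: every_save
#     early_stopping_patience: 3
#   output:
#     dir: outputs/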
def main(config_path="config.yaml", processed_data_dir="data/tokenized"):
    # Load config using sconf
    config = Config(config_path)

    # Load model and tokenizer
    print("Loading the model...")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        config.model.name,
        device_map=config.model.device_map,
        use_cache=config.model.use_cache,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.model.name)

    # Load the tokenized train/validation splits produced by preprocessing
    print("Loading preprocessed datasets...")
    train_dataset = load_from_disk(os.path.join(processed_data_dir, "train"))
    valid_dataset = load_from_disk(os.path.join(processed_data_dir, "validation"))

    # Output directory for checkpoints and logs
    output_dir = config.output.dir

    # Hugging Face token (from the environment), used when pushing to the Hub
    hf_token = os.getenv("HF_TOKEN")
    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=config.training.num_train_epochs,
        per_device_train_batch_size=config.training.per_device_train_batch_size,
        per_device_eval_batch_size=config.training.per_device_eval_batch_size,
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        eval_accumulation_steps=config.training.eval_accumulation_steps,
        fp16=config.training.fp16,
        fp16_full_eval=config.training.fp16_full_eval,
        learning_rate=config.training.learning_rate,
        lr_scheduler_type=config.training.lr_scheduler_type,
        eval_strategy=config.training.eval_strategy,
        eval_steps=config.training.eval_steps,
        save_strategy=config.training.save_strategy,
        save_steps=config.training.save_steps,
        save_total_limit=config.training.save_total_limit,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        logging_steps=config.training.logging_steps,
        report_to=config.training.report_to,
        push_to_hub=config.training.push_to_hub,
        hub_private_repo=config.training.private_repo,
        hub_strategy=config.training.strategy,
        hub_token=hf_token,
    )
    # Trainer with early stopping on eval_loss
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=config.training.early_stopping_patience
            )
        ],
    )

    # Train
    print("Starting training...")
    trainer.train()
    print("✅ Training complete!")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a seq2seq model.")
    parser.add_argument(
        "--tokenized_dataset",
        type=str,
        default="data/tokenized",
        help="Path to the preprocessed (tokenized) dataset directory",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config.yaml",
        help="Path to the config file",
    )
    args = parser.parse_args()

    main(config_path=args.config, processed_data_dir=args.tokenized_dataset)
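
# Example invocation, assuming the tokenized dataset and config.yaml sit at the
# argparse defaults above and HF_TOKEN is set for Hub authentication:
#
#   HF_TOKEN=<your-token> python train.py --config config.yaml --tokenized_dataset data/tokenized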