Conversation
Move batching and shuffling logic from SFTConfig into iterator functions. train_sft now accepts Iterable[List[Trajectory]] instead of individual trajectories, simplifying the API and making batch management more explicit. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
src/art/types.py
Outdated
|
|
||
|
|
||
| class SFTConfig(pydantic.BaseModel): | ||
| learning_rate: float = 1e-4 |
There was a problem hiding this comment.
Remove custom_lr_schedule
Make learning_rate: float | list[float]
src/art/dev/train.py
Outdated
| Used to identify where assistant turns begin (train on responses only). | ||
| """ | ||
|
|
||
| instruction_part: str |
There was a problem hiding this comment.
We probably can keep this class as empty?
Unsure if instruction_part and response_part is a good fit for experimental feature
src/art/local/backend.py
Outdated
| batch_size = 2 # Default to 2 for SFT | ||
|
|
||
| # Determine learning rates | ||
| if config.custom_lr_schedule and len(config.custom_lr_schedule) > 0: |
There was a problem hiding this comment.
- Refactor/Remove
custom_lr_schedule.learning_rateisfloat | list[float] - Add validation for
num_learning_rate==num_batches
src/art/unsloth/service.py
Outdated
|
|
||
| # Save checkpoint after training | ||
| # Name checkpoint by final training step: starting_step + num_batches | ||
| final_step = get_step_from_dir(self.output_dir) + len(sft_batches) |
There was a problem hiding this comment.
Checkpoint step should be still incremented by 1.
Checkpoint step != Gradient step
src/art/utils/model_config.py
Outdated
| response_part="<|im_start|>assistant\n", | ||
| ), | ||
| # Qwen 3 models (with thinking tokens) | ||
| "Qwen/Qwen3-8B": ModelConfig( |
There was a problem hiding this comment.
- How we decide to support all of this model?
- Prefer to keep it simple and start with model that's widely use in OpenPipe Platform and ART?
- Research Qwen chat template, iirc
<think></think>only show up at the last turn. We may need to remove<think></think>inresponse_partin Qwen.
There was a problem hiding this comment.
I kept only OpenPipe/Qwen3-14B-Instruct for now because it’s the only model with a custom chat template. All other mainstream models should be recognized by the detect_chat_template_parts function.
Also, I don’t feel strongly about this, but I did some research and didn’t find good arguments for using different default learning rates for different models. The general consensus online seems to be to start with 2e-4 with a linear/cosine scheduler.
| progress_bar.close() | ||
|
|
||
|
|
||
| def iterate_file( |
There was a problem hiding this comment.
Have iterate_file take in epoch
See the following PR for reference
| yield _parse_jsonl_line(line) | ||
|
|
||
|
|
||
| async def train_sft_from_file( |
There was a problem hiding this comment.
Modify this so user can have the training continue running after closing their laptop.
- Iterate_file(file, epoch)
- Calculate lr
- Call train_sft
3.1 Write to local disk
3.2 Upload to wandb artifact
3.3 Call train_sft API
3.4 Monitor training status
Resolved conflicts: - pyproject.toml: kept tinker deps and newer weave version from sft branch - src/art/backend.py: kept Protocol signatures from main, added _train_sft method - src/art/serverless/backend.py: kept SFT imports (Trajectory, SFTConfig) Co-Authored-By: Claude Opus 4.5 <[email protected]>
src/art/local/backend.py
Outdated
| from ..utils.model_config import get_instruction_response_parts | ||
|
|
||
| # Get instruction/response parts (from config or auto-detect) | ||
| instruction_part = dev_config.get("instruction_part", None) |
There was a problem hiding this comment.
Remove this? dev_config no longer have instruction_part and response_part
| """Model-specific configuration for chat templates and training defaults.""" | ||
|
|
||
|
|
||
| def detect_chat_template_parts( |
There was a problem hiding this comment.
We need to add support to handle OpenPipe/Qwen3-14B-Instruct?
src/art/utils/sft.py
Outdated
| return | ||
|
|
||
| # Prepare dataset: shuffle per epoch, concatenate, and calculate learning rates | ||
| all_trajectories, learning_rates = prepare_sft_dataset( |
There was a problem hiding this comment.
Won’t this blow up memory for large files?
E.g 1GB file train for 10 epoch will take at least 10GB itself outside of the system overhead itself.
We need to maintain the trajectory as iterator (streaming) from reading from file to calling train_sft to ensure we don't materialize the entire 10GB of memory. Using iterate_file with epochs should address this.
There was a problem hiding this comment.
Yeah, it was not yet optimized for large file streaming, just added it in the latest PR.
A summary of how it works:
-
Count rows without loading data
get_file_row_count(file_path)→ Scans file, counts non-empty lines
This gives usrow_countto calculate the LR schedule without loading data into memory. -
Calculate learning rate schedule upfront
total_trajectories = row_count × epochs
total_batches = ceil(total_trajectories / batch_size)
warmup_steps = total_batches × warmup_ratio
full_schedule = create_lr_schedule(total_batches, peak_lr, ...)
learning_rates = full_schedule[initial_step:] # Slice for resuming
The full LR schedule (one value per batch) is pre-computed as a list.
- Create a streaming trajectory generator
trajectories = iterate_file(file_path, epochs, shuffle_buffer_size, initial_skip)
This returns a generator (not a list). It:
- Reads the file line by line
- Uses buffer-based shuffling (fills buffer of 10k items, randomly pops one)
- Repeats for each epoch with a different random seed (seed + epoch)
-
Pass generator + LR list to backend
await model.train_sft(trajectories, config) # trajectories is a generator -
Backend batches and tokenizes on-the-fly
create_sft_batches()generator:- Collects trajectories into batches of size batch_size
- Tokenizes each batch immediately
- Yields SFTBatch objects
-
Producer thread feeds queue
It may look ugly, but in the local backend we run training in a subprocess, and we cannot send a generator to this subprocess. We can't just callunsloth/service.train_sft()and pass a generator of batches directly.
There are two options: -
Break down unsloth
service.train_sftinto smaller functions (setup, training, cleanup), iterate over batches on the client side and send each batch object to a training function individually. -
Use a queue — create a queue, put batches into it from a producer thread, and pass the queue to train_sft.
I went with the second approach. -
Service trains batch by batch
while batch := queue.get():
forward pass → loss → backward pass → optimizer step
yield metrics
No description provided.