WIP: SFT (local backend)#530

Open
Kovbo wants to merge 59 commits into main from sft-local-backend

Conversation

@Kovbo Kovbo commented Jan 22, 2026

No description provided.

@Kovbo Kovbo requested a review from angkywilliam January 22, 2026 02:48
@Kovbo Kovbo marked this pull request as ready for review January 22, 2026 21:42
src/art/types.py Outdated


class SFTConfig(pydantic.BaseModel):
learning_rate: float = 1e-4
Collaborator

Remove custom_lr_schedule
Make learning_rate: float | list[float]
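A minimal sketch of that field shape, assuming pydantic v2 and keeping the current 1e-4 default (only the field change is shown):

import pydantic


class SFTConfig(pydantic.BaseModel):
    # A single float is a constant peak LR; a list supplies an explicit
    # per-batch schedule, replacing custom_lr_schedule.
    learning_rate: float | list[float] = 1e-4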

Used to identify where assistant turns begin (train on responses only).
"""

instruction_part: str
Collaborator

Can we keep this class empty for now?
Unsure whether instruction_part and response_part are a good fit for an experimental feature.

batch_size = 2 # Default to 2 for SFT

# Determine learning rates
if config.custom_lr_schedule and len(config.custom_lr_schedule) > 0:
Collaborator

  1. Refactor/remove custom_lr_schedule; learning_rate becomes float | list[float].
  2. Add validation that the number of learning rates equals the number of batches.
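Not the existing API, just a sketch of how that validation could look once the batch count is known (the function and argument names are placeholders):

def resolve_learning_rates(
    learning_rate: float | list[float], num_batches: int
) -> list[float]:
    # Expand a scalar LR into a constant per-batch list, or validate an
    # explicit per-batch schedule against the number of batches.
    if isinstance(learning_rate, float):
        return [learning_rate] * num_batches
    if len(learning_rate) != num_batches:
        raise ValueError(
            f"Expected {num_batches} learning rates, got {len(learning_rate)}"
        )
    return list(learning_rate)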


# Save checkpoint after training
# Name checkpoint by final training step: starting_step + num_batches
final_step = get_step_from_dir(self.output_dir) + len(sft_batches)
Collaborator

The checkpoint step should still be incremented by 1.
Checkpoint step != gradient step.
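If I read the suggestion right, the checkpoint name would advance by one per train_sft call rather than by the number of batches, something like:

# One checkpoint per train_sft call, regardless of how many
# gradient steps the batches produced.
final_step = get_step_from_dir(self.output_dir) + 1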

response_part="<|im_start|>assistant\n",
),
# Qwen 3 models (with thinking tokens)
"Qwen/Qwen3-8B": ModelConfig(
Collaborator

  1. How did we decide to support all of these models?
  2. Prefer to keep it simple and start with the models that are widely used in the OpenPipe Platform and ART?
  3. Research the Qwen chat template; iirc <think></think> only shows up in the last turn. We may need to remove <think></think> from response_part for Qwen.
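For point 3, a hypothetical illustration of what that entry could look like without the think tokens, mirroring the ChatML-style entries above (values are assumptions, not the current config):

"Qwen/Qwen3-8B": ModelConfig(
    instruction_part="<|im_start|>user\n",
    # No <think> tokens here, so the marker matches every assistant turn,
    # not only the final one where the template keeps the thinking block.
    response_part="<|im_start|>assistant\n",
),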

Collaborator Author

I kept only OpenPipe/Qwen3-14B-Instruct for now because it’s the only model with a custom chat template. All other mainstream models should be recognized by the detect_chat_template_parts function.

Also, I don’t feel strongly about this, but I did some research and didn’t find good arguments for using different default learning rates for different models. The general consensus online seems to be to start with 2e-4 with a linear/cosine scheduler.
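For context, one way the response marker can be auto-derived from a tokenizer's chat template; this is only an illustration of the idea, not necessarily how detect_chat_template_parts is implemented (model name chosen arbitrarily):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
messages = [{"role": "user", "content": "hi"}]

# Render the same conversation with and without the generation prompt;
# the suffix added by the second rendering is the assistant marker,
# e.g. "<|im_start|>assistant\n" for ChatML-style templates.
without_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
with_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
response_part = with_prompt[len(without_prompt):]
print(repr(response_part))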

progress_bar.close()


def iterate_file(
Collaborator

Have iterate_file take in an epoch parameter.
See the following PR for reference.
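A minimal sketch of an epoch-aware iterate_file, assuming JSONL input (the epochs parameter name and the json.loads parsing stand in for the existing _parse_jsonl_line helper):

import json
from collections.abc import Iterator


def iterate_file(file_path: str, epochs: int = 1) -> Iterator[dict]:
    # Re-read the file once per epoch so nothing is held in memory
    # beyond the current line.
    for _epoch in range(epochs):
        with open(file_path) as f:
            for line in f:
                line = line.strip()
                if line:
                    yield json.loads(line)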

yield _parse_jsonl_line(line)


async def train_sft_from_file(
Collaborator
@angkywilliam angkywilliam Jan 24, 2026

Modify this so the user can keep training running after closing their laptop (rough skeleton below):

  1. iterate_file(file, epoch)
  2. Calculate the LR schedule
  3. Call train_sft:
    3.1 Write to local disk
    3.2 Upload to a wandb artifact
    3.3 Call the train_sft API
    3.4 Monitor training status
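A rough skeleton of that flow under the suggestions above; every attribute and helper here (epochs on the config, build_lr_schedule) is a placeholder, not the existing API:

async def train_sft_from_file(model, file_path: str, config) -> None:
    # 1. Stream rows for the requested number of epochs (iterate_file as
    #    sketched earlier; "epochs" on the config is a placeholder field).
    trajectories = iterate_file(file_path, epochs=config.epochs)

    # 2. Calculate the per-batch LR schedule up front and store it as
    #    learning_rate: list[float], per the earlier suggestion.
    config.learning_rate = build_lr_schedule(config)  # placeholder helper

    # 3. Hand off to the backend so training outlives the local session:
    #    write data to disk, upload it as a wandb artifact, call the
    #    train_sft API, then monitor the remote training status.
    await model.train_sft(trajectories, config)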

Kovbo and others added 4 commits January 26, 2026 19:45
Resolved conflicts:
- pyproject.toml: kept tinker deps and newer weave version from sft branch
- src/art/backend.py: kept Protocol signatures from main, added _train_sft method
- src/art/serverless/backend.py: kept SFT imports (Trajectory, SFTConfig)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
@Kovbo Kovbo changed the title SFT (local backend) WIP: SFT (local backend) Feb 2, 2026
from ..utils.model_config import get_instruction_response_parts

# Get instruction/response parts (from config or auto-detect)
instruction_part = dev_config.get("instruction_part", None)
Collaborator

Remove this? dev_config no longer has instruction_part and response_part.

"""Model-specific configuration for chat templates and training defaults."""


def detect_chat_template_parts(
Collaborator

Do we need to add support for handling OpenPipe/Qwen3-14B-Instruct here?

return

# Prepare dataset: shuffle per epoch, concatenate, and calculate learning rates
all_trajectories, learning_rates = prepare_sft_dataset(
Collaborator

Won’t this blow up memory for large files?
E.g. a 1GB file trained for 10 epochs will take at least 10GB on its own, before any system overhead.

We need to keep the trajectories as an iterator (streaming) all the way from reading the file to calling train_sft, so we never materialize the entire 10GB in memory. Using iterate_file with epochs should address this.

Collaborator Author

Yeah, it was not yet optimized for large-file streaming; I just added that in the latest PR.

A summary of how it works (a condensed sketch of the queue hand-off follows at the end):

  1. Count rows without loading data
    get_file_row_count(file_path) scans the file and counts non-empty lines.
    This gives us row_count to calculate the LR schedule without loading the data into memory.

  2. Calculate the learning rate schedule upfront

    total_trajectories = row_count × epochs
    total_batches = ceil(total_trajectories / batch_size)
    warmup_steps = total_batches × warmup_ratio

    full_schedule = create_lr_schedule(total_batches, peak_lr, ...)
    learning_rates = full_schedule[initial_step:]  # Slice for resuming

    The full LR schedule (one value per batch) is pre-computed as a list.

  3. Create a streaming trajectory generator
    trajectories = iterate_file(file_path, epochs, shuffle_buffer_size, initial_skip)
    This returns a generator (not a list). It:
    • Reads the file line by line
    • Uses buffer-based shuffling (fills a buffer of 10k items, randomly pops one)
    • Repeats for each epoch with a different random seed (seed + epoch)

  4. Pass the generator + LR list to the backend
    await model.train_sft(trajectories, config)  # trajectories is a generator

  5. Backend batches and tokenizes on the fly
    The create_sft_batches() generator:
    • Collects trajectories into batches of size batch_size
    • Tokenizes each batch immediately
    • Yields SFTBatch objects

  6. Producer thread feeds a queue
    It may look ugly, but in the local backend we run training in a subprocess, and we cannot send a generator to that subprocess, so we can't just call unsloth/service.train_sft() and pass a generator of batches directly.
    There were two options:
    1. Break unsloth's service.train_sft into smaller functions (setup, training, cleanup), iterate over batches on the client side, and send each batch object to a training function individually.
    2. Use a queue: create a queue, put batches into it from a producer thread, and pass the queue to train_sft.
    I went with the second approach.

  7. The service trains batch by batch

    while batch := queue.get():
        forward pass → loss → backward pass → optimizer step
        yield metrics
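A condensed sketch of the queue hand-off in steps 6–7, with dummy batches standing in for create_sft_batches() and the real training step; the bounded queue keeps memory usage flat:

import queue
import threading


def feed_batches(batch_iter, q: queue.Queue) -> None:
    # Producer thread: drain the (possibly huge) batch generator into a
    # bounded queue, then signal completion with a sentinel.
    for batch in batch_iter:
        q.put(batch)
    q.put(None)


def train_from_queue(q: queue.Queue):
    # Consumer (the training service): pull batches until the sentinel,
    # do one training step per batch, yield metrics as it goes.
    while (batch := q.get()) is not None:
        yield {"num_items": len(batch)}  # stand-in for forward/backward/step


# Demo with dummy batches; in the PR the producer is fed by
# create_sft_batches() and the consumer runs inside the unsloth service.
batches = ([i, i + 1] for i in range(0, 6, 2))
batch_queue: queue.Queue = queue.Queue(maxsize=4)
threading.Thread(target=feed_batches, args=(batches, batch_queue), daemon=True).start()
for metrics in train_from_queue(batch_queue):
    print(metrics)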
