143 changes: 129 additions & 14 deletions train_gpt.py
@@ -231,7 +231,7 @@ def sparse_comms_start(idxes_np, N, rank, world, send_idxes_buffer):
)
send_counts = torch.from_numpy(insertion_points[1:] - insertion_points[:-1])
# zero-out own send-count - we won't send our own gradient rows to ourselves as it's a waste:
# in sparse_comms_merge_gradients, we'll use the slice of the gradient that already includes them as the base tensor
# in sparse_comms_merge_gradients, we'll use the slice of the gradient that already includes them as the base tensorwhy
send_counts[rank] = 0

# remove indexes owned by our rank from the send list
@@ -1456,7 +1456,7 @@ def get_bigram_hash(x):
out[1:] = torch.bitwise_xor(rand_int_1 * out[1:], rand_int_2 * out[:-1]) % mod
return out

def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True, yield_cpu: bool = False):
# align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
@@ -1522,21 +1522,29 @@ def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_l
_cum_lengths = _cum_lengths.to(dtype=torch.int32)
_bigram_inputs = get_bigram_hash(_inputs)

new_params = yield (
_inputs.to(device="cuda", non_blocking=True),
_targets.to(device="cuda", non_blocking=True),
_cum_lengths.to(device="cuda", non_blocking=True),
_bigram_inputs.to(device="cuda", non_blocking=True),
_bigram_inputs.numpy(),
)
if yield_cpu:
new_params = yield (
_inputs.pin_memory(),
_targets.pin_memory(),
_cum_lengths.pin_memory(),
_bigram_inputs.pin_memory(),
_bigram_inputs.numpy(),
)
else:
new_params = yield (
_inputs.to(device="cuda", non_blocking=True),
_targets.to(device="cuda", non_blocking=True),
_cum_lengths.to(device="cuda", non_blocking=True),
_bigram_inputs.to(device="cuda", non_blocking=True),
_bigram_inputs.numpy(),
)

if new_params is not None:
# makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world_size * grad_accum_steps"
num_tokens = new_num_tokens // new_grad_accum_steps
max_seq_len = new_max_seq_len
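
The protocol above lets a caller drive the generator directly. A minimal sketch, not part of the diff (the file pattern and batch sizes here are made up), showing the first fetch and a schedule change pushed through .send():

# Illustrative only; assumes it runs where distributed_data_generator is in scope.
gen = distributed_data_generator("data/train_*.bin", num_tokens=8192, max_seq_len=1024,
                                 grad_accum_steps=1, yield_cpu=True)
# First batch: gen.send(None) is equivalent to next(gen); with yield_cpu=True the four
# tensors come back pinned in host memory, plus the bigram hashes as a numpy array.
inputs, targets, cum_lengths, bigram_inputs, bigram_np = gen.send(None)
# Later, a non-None .send() delivers new (num_tokens, max_seq_len, grad_accum_steps),
# which take effect when the generator builds its next batch.
inputs, targets, cum_lengths, bigram_inputs, bigram_np = gen.send((16384, 2048, 1))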

# -----------------------------------------------------------------------------
# Training Management

Expand Down Expand Up @@ -1927,7 +1935,112 @@ def nvidia_smi():
########################################
# Training and validation #
########################################
train_loader = distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, TRAINING_STAGES[0].train_max_seq_len, grad_accum_steps=grad_accum_steps)
class PrefetchLoader:
    """Overlaps CPU data prep AND H2D transfers with GPU compute via double-buffering.
    - Background thread: CPU-only data loading (no CUDA ops, avoids NCCL conflicts)
    - Main thread: H2D on a dedicated copy stream with GPU-side event sync
    - Pipeline: H2D for batch N+1 runs on copy stream while compute for batch N
      runs on default stream, using separate DMA and compute engines."""
    def __init__(self, gen):
        self.gen = gen
        self._cpu_result = None
        self._error = None
        self._ready = threading.Event()
        self._go = threading.Event()
        self._device = torch.cuda.current_device()
        self._copy_stream = torch.cuda.Stream(device=self._device)
        self._copy_event = torch.cuda.Event()
        self._pending_gpu = None  # GPU tensors with H2D in flight on copy stream
        self._worker_running = False
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()

    def _worker(self):
        """Background thread: CPU data prep only, no CUDA ops."""
        while True:
            self._go.wait()
            self._go.clear()
            try:
                self._cpu_result = self.gen.send(None)
            except Exception as e:
                self._error = e
            self._ready.set()

    def _h2d_async(self, cpu_tensors):
        """Start H2D on copy stream from main thread. Returns GPU tensor refs immediately."""
        with torch.cuda.stream(self._copy_stream):
            gpu = tuple(
                t if isinstance(t, np.ndarray)
                else t.to(device="cuda", non_blocking=True)
                for t in cpu_tensors
            )
            self._copy_event.record(self._copy_stream)
        return gpu

    @staticmethod
    def _mark_stream(gpu_tensors, stream):
        """Tell caching allocator these tensors are used on stream (prevents premature reuse)."""
        for t in gpu_tensors:
            if isinstance(t, torch.Tensor):
                t.record_stream(stream)

    def _drain_worker(self):
        """Wait for and consume worker result if running."""
        if self._worker_running:
            with torch.profiler.record_function("PrefetchLoader::wait_worker"):
                self._ready.wait()
            self._ready.clear()
            self._worker_running = False
            if self._error is not None:
                raise RuntimeError("PrefetchLoader worker died") from self._error

    def _kick_worker(self):
        """Start background CPU prep for next batch."""
        self._go.set()
        self._worker_running = True

    def _enter_pipeline(self):
        """Wait for next CPU batch and start its H2D on copy stream (pipeline fill)."""
        self._drain_worker()
        self._pending_gpu = self._h2d_async(self._cpu_result)
        self._kick_worker()

    def send(self, value):
        with torch.profiler.record_function("PrefetchLoader::send"):
            default_stream = torch.cuda.current_stream()

            if value is not None:
                # Stage change: discard pipeline, synchronous path
                self._pending_gpu = None
                self._drain_worker()
                result = self._h2d_async(self.gen.send(value))
                default_stream.wait_event(self._copy_event)
                self._mark_stream(result, default_stream)
                self._kick_worker()
                self._enter_pipeline()
                return result

            if self._pending_gpu is not None:
                # Steady state: H2D for this batch already in flight on copy stream
                with torch.profiler.record_function("PrefetchLoader::wait_h2d"):
                    default_stream.wait_event(self._copy_event)
                result = self._pending_gpu
                self._pending_gpu = None
                # Tell allocator default stream uses these (allocated on copy stream)
                self._mark_stream(result, default_stream)
                # Start H2D for next batch (overlaps with caller's GPU compute)
                self._enter_pipeline()
                return result

            # Bootstrap (very first call)
            result = self._h2d_async(self.gen.send(None))
            default_stream.wait_event(self._copy_event)
            self._mark_stream(result, default_stream)
            self._kick_worker()
            self._enter_pipeline()
            return result

train_loader = PrefetchLoader(distributed_data_generator(args.train_files, TRAINING_STAGES[0].batch_size, TRAINING_STAGES[0].train_max_seq_len, grad_accum_steps=grad_accum_steps, yield_cpu=True))
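
The class above keeps three batches in flight at once: the background thread prepares batch N+2 on the CPU, the copy stream transfers batch N+1, and the default stream computes on batch N. The CUDA-side core of that scheme can be reduced to a standalone sketch (illustrative only, hypothetical names, not code from this PR):

# Minimal double-buffer H2D sketch; pinned_cpu_batch is assumed to be a pinned CPU tensor.
import torch

copy_stream = torch.cuda.Stream()
copy_done = torch.cuda.Event()

def start_h2d(pinned_cpu_batch):
    # Queue the host-to-device copy on the side stream so it overlaps with compute
    # already enqueued on the default stream (the copy runs on the DMA engine).
    with torch.cuda.stream(copy_stream):
        gpu_batch = pinned_cpu_batch.to("cuda", non_blocking=True)
        copy_done.record(copy_stream)
    return gpu_batch  # returned immediately; the copy may still be in flight

def consume(gpu_batch):
    s = torch.cuda.current_stream()
    s.wait_event(copy_done)     # compute waits for the copy, not for the whole copy stream
    gpu_batch.record_stream(s)  # tell the caching allocator this stream also uses the tensor,
                                # so its memory is not recycled while the default stream reads it
    return gpu_batch.float().sum()  # stand-in for model compute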

gc.collect()

@@ -1937,6 +2050,7 @@ def nvidia_smi():
t0 = time.perf_counter()
# begin training
train_steps = training_schedule.total_steps

for step in range(train_steps + 1):
last_step = (step == train_steps)
training_manager.advance_schedule(step)
@@ -1973,9 +2087,10 @@ def nvidia_smi():
# the last step only has the validation loop, so break to avoid training
break

# --------------- TRAINING SECTION -----------------
# --------------- TRAINING SECTION ----------------
for idx in range(grad_accum_steps):
inputs, targets, cum_seqlens, bigram_inputs, bigram_cpu = train_loader.send(training_manager.train_loader_send_args)
send_args = training_manager.train_loader_send_args if idx == 0 else None
inputs, targets, cum_seqlens, bigram_inputs, bigram_cpu = train_loader.send(send_args)
training_manager.sparse_index_update(step, bigram_cpu)
loss = model(inputs, targets, cum_seqlens, bigram_inputs, training_manager.get_forward_args()) * grad_scale
training_manager.sparse_index_share(step)
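
Note that a non-None argument to PrefetchLoader.send() routes into the synchronous stage-change path, so the loop above forwards training_manager.train_loader_send_args only on the first micro-step and passes None for the rest, keeping later micro-steps on the prefetched path. Whether the overlap actually happens can be checked with a short profiler run; a sketch under assumed step counts (not part of the PR):

# Illustrative profiling harness: exports a chrome trace in which the
# PrefetchLoader::wait_h2d range should be near-zero once the pipeline is in steady state.
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(5):  # a handful of micro-steps is enough to see the steady state
        inputs, targets, cum_seqlens, bigram_inputs, bigram_cpu = train_loader.send(None)
        ...  # forward/backward for one micro-step, as in the loop above
prof.export_chrome_trace("prefetch_overlap.json")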
@@ -1989,4 +2104,4 @@ def nvidia_smi():

print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
dist.destroy_process_group()
dist.destroy_process_group()