Forked from 3outeille/pipeline_parallel.py
Created October 2, 2024 17:20
Self contained example of how pipeline parallel works (AFAB and 1F1B) in 200 LOC
#VERBOSE=0 torchrun --nproc_per_node 3 self_contained_pp_LOC.py
import os, random, numpy as np, torch, torch.nn as nn, torch.distributed as dist, torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
STEP, local_rank, world_size, verbose = 0, int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]), os.environ.get("VERBOSE", "0") == "1"
def set_all_seed(seed):
    for module in [random, np.random]: module.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
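# MicroBatchDataLoader: tokenizes the dataset to seq_length + 1 tokens and shifts by one token
# in collate_batch to build (input_ids, target_ids) pairs laid out as (seq_len, batch).
# The global batch is split per data-parallel rank and then into micro-batches of
# micro_batch_size, which are the units that flow through the pipeline.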
class MicroBatchDataLoader(DataLoader):
    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
        self.global_batch_size, self.micro_batch_size, self.data_parallel_size, self.seq_length = global_batch_size, micro_batch_size, data_parallel_size, seq_length
        self.local_batch_size = self.global_batch_size // self.data_parallel_size
        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = load_dataset(dataset_name, split=split)
        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
        dist.barrier()
        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=DistributedSampler(self.dataset, num_replicas=data_parallel_size, rank=0, shuffle=False), shuffle=False)
    def collate_batch(self, batch_data):
        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
        batch_size, seq_len = batch_input_ids.shape
        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
class ProcessGroupManager:
    def __init__(self, pp_rank, pp_world_size):
        self.pp_rank, self.pp_world_size = pp_rank, pp_world_size
        self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else (self.pp_rank + 1) % self.pp_world_size
        self.pp_prev_rank = None if self.pp_rank == 0 else (self.pp_rank - 1) % self.pp_world_size
        self.is_pipeline_last_stage = self.pp_rank == self.pp_world_size - 1
        self.is_pipeline_first_stage = self.pp_rank == 0
        self.pp_group = dist.new_group(list(range(self.pp_world_size)))
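# PipelineParallel: loads the full pretrained model, then keeps only this stage's share of it:
# the embedding on the first stage, a contiguous slice of decoder layers (distribute_layers
# spreads num_hidden_layers as evenly as possible across stages), and the final norm plus
# lm_head on the last stage. Slots that belong to other stages become nn.Identity().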
class PipelineParallel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        base_model = AutoModelForCausalLM.from_pretrained(model_name, config=self.config)
        layer_distribution = self.distribute_layers(self.config.num_hidden_layers)
        self.embed_tokens = base_model.model.embed_tokens if process_group_manager.is_pipeline_first_stage else nn.Identity()
        self.decoder_layers = nn.ModuleDict({str(i): base_model.model.layers[i] for i in layer_distribution})
        self.norm = base_model.model.norm if process_group_manager.is_pipeline_last_stage else nn.Identity()
        self.lm_head = base_model.lm_head if process_group_manager.is_pipeline_last_stage else nn.Identity()
        del base_model
    def distribute_layers(self, num_layers):
        layers_per_gpu = [num_layers // process_group_manager.pp_world_size + (1 if i < num_layers % process_group_manager.pp_world_size else 0) for i in range(process_group_manager.pp_world_size)]
        start_layer = sum(layers_per_gpu[:process_group_manager.pp_rank])
        return list(range(start_layer, start_layer + layers_per_gpu[process_group_manager.pp_rank]))
    def forward(self, batch, device):
        x = batch["hidden_states"].to(device) if batch["hidden_states"] is not None else batch["input_ids"].to(device)
        x = self.embed_tokens(x)
        for layer in self.decoder_layers.values():
            x = layer(x, position_ids=batch["position_index"].to(device))[0]
        x = self.norm(x)
        return self.lm_head(x)
    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        if input_tensor is not None: input_tensor.retain_grad()
        if output_tensor_grad is None:
            output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format)
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False)
        return input_tensor.grad if input_tensor is not None else None
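# communicate: blocking point-to-point exchange between adjacent stages. 'recv_forward' /
# 'send_forward' move activations down the pipeline, 'recv_backward' / 'send_backward' move
# gradients back up. The first stage never receives forward activations and the last stage
# never sends them (and symmetrically for the backward direction).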
def communicate(tensor=None, shapes=None, dtype=None, operation='send_forward'):
    global STEP
    if operation == 'recv_forward':
        if process_group_manager.is_pipeline_first_stage: return None
        tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
        src = process_group_manager.pp_prev_rank
    elif operation == 'send_forward':
        if process_group_manager.is_pipeline_last_stage: return
        dest = process_group_manager.pp_next_rank
    elif operation == 'recv_backward':
        if process_group_manager.is_pipeline_last_stage: return None
        tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
        src = process_group_manager.pp_next_rank
    elif operation == 'send_backward':
        if process_group_manager.is_pipeline_first_stage: return
        dest = process_group_manager.pp_prev_rank
    is_send = operation.startswith('send')
    peer_rank = dest if is_send else src
    op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
    if verbose: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{process_group_manager.pp_rank}", flush=True)
    [req.wait() for req in dist.batch_isend_irecv([op])]
    torch.cuda.synchronize()
    if verbose: STEP += 1
    return tensor if not is_send else None
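# bidirectional_communicate: fused send + receive with the same neighbor, used in the 1F1B
# steady state so a stage can ship its forward output and pull the matching output gradient
# (or ship an input gradient and pull the next forward input) in one batched P2P call.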
def bidirectional_communicate(op, send_tensor, recv_shapes, dtype, device):
    global STEP
    is_fwd = (op == 'send_fwd_recv_bwd')
    if (is_fwd and process_group_manager.is_pipeline_last_stage) or (not is_fwd and process_group_manager.is_pipeline_first_stage): return None
    peer_rank = process_group_manager.pp_next_rank if is_fwd else process_group_manager.pp_prev_rank
    recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
    if verbose: print(f"{op} | sending {'next' if is_fwd else 'prev'} {process_group_manager.pp_rank} -> {peer_rank} | " f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {process_group_manager.pp_rank} | " f"STEP {STEP=} | RANK:{process_group_manager.pp_rank}", flush=True)
    [req.wait() for req in reqs]
    torch.cuda.synchronize()
    if verbose: STEP += 1
    return recv_tensor
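# 1F1B schedule: each stage runs (pp_world_size - pp_rank - 1) warmup forward passes, then
# alternates one forward with one backward in the steady state, and finally drains the
# outstanding backwards in the cooldown phase. A stage keeps at most warmup + 1 in-flight
# micro-batches alive, which is what makes 1F1B cheaper in activation memory than AFAB.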
def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
    num_warmup_microbatches = min(process_group_manager.pp_world_size - process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches)
    num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches
    logging_loss, input_tensors, output_tensors = 0.0, [], []
    def _forward_step(input_tensor):
        batch = next(iter(data_loader))
        batch["hidden_states"] = input_tensor
        output_tensor = model.forward(batch, device)
        if process_group_manager.is_pipeline_last_stage:
            output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
            nonlocal logging_loss
            logging_loss += output_tensor.item()
        return output_tensor
    for _ in range(num_warmup_microbatches): # Warmup forward passes
        input_tensor = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_forward')
        output_tensor = _forward_step(input_tensor)
        communicate(tensor=output_tensor, operation='send_forward')
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
    if num_microbatches_remaining > 0:
        input_tensor = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_forward')
    for i in range(num_microbatches_remaining): # 1F1B steady state
        output_tensor = _forward_step(input_tensor)
        output_tensor_grad = bidirectional_communicate('send_fwd_recv_bwd', output_tensor, tensor_shapes, torch.float32, device)
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        if i == num_microbatches_remaining - 1: # last iteration
            input_tensor = None
            communicate(tensor=input_tensor_grad, operation='send_backward')
        else:
            input_tensor = bidirectional_communicate('send_bwd_recv_fwd', input_tensor_grad, tensor_shapes, torch.float32, device)
    for _ in range(num_warmup_microbatches): # Cooldown backward passes
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        output_tensor_grad = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_backward')
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        communicate(tensor=input_tensor_grad, operation='send_backward')
    return logging_loss
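# AFAB (all-forward-all-backward) schedule: run the forward pass for every micro-batch first,
# keeping every input/output pair, then run all the backward passes. Simpler than 1F1B, but
# activation memory grows with the number of micro-batches.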
def pipeline_parallel_afab(model, data_loader, tensor_shapes, device):
    logging_loss, input_tensors, output_tensors = 0.0, [], []
    for _ in range(data_loader.num_local_micro_batches): # All forward passes
        input_tensor = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_forward')
        batch = next(iter(data_loader))
        batch["hidden_states"] = input_tensor
        output_tensor = model.forward(batch, device)
        communicate(tensor=output_tensor, operation='send_forward')
        if process_group_manager.is_pipeline_last_stage:
            output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
            logging_loss += output_tensor.item()
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
    for _ in range(data_loader.num_local_micro_batches): # All backward passes
        output_tensor_grad = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_backward')
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        communicate(tensor=input_tensor_grad, operation='send_backward')
    return logging_loss
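# Driver: pure pipeline parallelism (data_parallel_size=1) with one stage per torchrun
# process. Swap pipeline_parallel_1f1b for pipeline_parallel_afab in the training loop to
# compare the two schedules; only the last stage computes the loss, so only that rank prints.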
os.environ["TOKENIZERS_PARALLELISM"] = "false"
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS = 10, 6, 2, 1e-4, 20, 1800
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
process_group_manager = ProcessGroupManager(pp_rank=local_rank, pp_world_size=world_size)
set_all_seed(seed=42)
model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, 1, SEQ_LEN, "roneneldan/TinyStories", "HuggingFaceTB/SmolLM-360M-Instruct", num_samples=NUM_SAMPLES)
tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, model.config.hidden_size)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
trained_tokens, step = 0, 0
tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
while trained_tokens < MAX_TOKENS:
    optimizer.zero_grad()
    loss = pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device) # loss = pipeline_parallel_afab(model, data_loader, tensor_shapes, device)
    optimizer.step()
    trained_tokens += tokens_per_step
    step += 1
    if process_group_manager.is_pipeline_last_stage:
        print(f"Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
⚡️ Link to the LinkedIn post.