Merge pull request karpathy#225 from otaviogood/grad_accum
Fix for gradient_accumulation_steps training slow
karpathy authored Apr 18, 2023
2 parents d9f4735 + a6a708c commit 21f9bff
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion config/train_gpt2.py
@@ -10,7 +10,7 @@
 # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
 batch_size = 12
 block_size = 1024
-gradient_accumulation_steps = 5
+gradient_accumulation_steps = 5 * 8
 
 # this makes total number of tokens be 300B
 max_iters = 600000
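The comment in this config carries the arithmetic behind the new default. As a quick sanity check, a standalone Python sketch (not part of the diff; the values are taken from this config):

# hypothetical check of the numbers quoted in the config comment
batch_size = 12
block_size = 1024
gradient_accumulation_steps = 5 * 8  # 5 micro-steps per GPU across 8 GPUs
print(gradient_accumulation_steps * batch_size * block_size)  # 491520, matching "491,520"

With the old value of 5 the same product would be only 61,440 tokens per iteration, which is why train.py previously scaled it by 8 in the single-GPU path (see the train.py hunk below).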
1 change: 1 addition & 0 deletions config/train_shakespeare_char.py
@@ -14,6 +14,7 @@
 wandb_run_name = 'mini-gpt'
 
 dataset = 'shakespeare_char'
+gradient_accumulation_steps = 1
 batch_size = 64
 block_size = 256 # context of up to 256 previous characters
 
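Setting gradient_accumulation_steps = 1 here matters because the single-GPU path in train.py no longer multiplies it by 8 implicitly (see the train.py hunk below). A rough check of what one iteration of the character-level config now covers, as plain Python with values from this config:

gradient_accumulation_steps = 1  # explicit, so nothing scales it behind the scenes
batch_size = 64
block_size = 256
print(gradient_accumulation_steps * batch_size * block_size)  # 16384 tokens per iteration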
10 changes: 7 additions & 3 deletions train.py
@@ -45,7 +45,7 @@
 wandb_run_name = 'gpt2' # 'run' + str(time.time())
 # data
 dataset = 'openwebtext'
-gradient_accumulation_steps = 5 # used to simulate larger batch sizes
+gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
 batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
 block_size = 1024
 # model
@@ -84,16 +84,20 @@
     init_process_group(backend=backend)
     ddp_rank = int(os.environ['RANK'])
     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    ddp_world_size = int(os.environ['WORLD_SIZE'])
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
     seed_offset = ddp_rank # each process gets a different seed
+    assert gradient_accumulation_steps % torch.cuda.device_count() == 0
+    gradient_accumulation_steps //= torch.cuda.device_count()
 else:
     # if not ddp, we are running on a single gpu, and one process
     master_process = True
     seed_offset = 0
-    gradient_accumulation_steps *= 8 # simulate 8 gpus
-print("total number of tokens per iteration:", batch_size * block_size * gradient_accumulation_steps)
+    ddp_world_size = 1
+tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
+print(f"tokens per iteration will be: {tokens_per_iter:,}")
 
 if master_process:
     os.makedirs(out_dir, exist_ok=True)
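Taken together, gradient_accumulation_steps now means the total number of micro-steps per iteration across all GPUs, and each DDP rank takes an equal share. A minimal sketch of that bookkeeping (a hypothetical helper, not code from the repo; it takes a ddp_world_size argument where train.py queries torch.cuda.device_count()):

def tokens_per_iteration(gradient_accumulation_steps, batch_size, block_size, ddp_world_size=1):
    # split the global accumulation count evenly across ranks, as the new train.py does
    assert gradient_accumulation_steps % ddp_world_size == 0
    steps_per_rank = gradient_accumulation_steps // ddp_world_size
    # every rank contributes batch_size * block_size tokens per micro-step
    return steps_per_rank * ddp_world_size * batch_size * block_size

print(tokens_per_iteration(5 * 8, 12, 1024, ddp_world_size=8))  # 491520
print(tokens_per_iteration(5 * 8, 12, 1024, ddp_world_size=1))  # 491520

The two prints show the point of the fix: tokens per iteration no longer depend on how many GPUs the run uses, and single-GPU configs such as train_shakespeare_char.py no longer inherit the old implicit *8 that made them do far more work per iteration than intended.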
