Skip to content

Commit

Permalink
Use bf16 only if supported
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed May 17, 2023
1 parent 7fe4a09 commit eb33b8b
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
real_data = True
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = True # use PyTorch 2.0 to compile the model to be faster
profile = False # use pytorch profiler, or just simple benchmarking?
exec(open('configurator.py').read()) # overrides from command line or config file
Expand Down
2 changes: 1 addition & 1 deletion sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
Expand Down

0 comments on commit eb33b8b

Please sign in to comment.