add the bias option to config, default it to True for now
karpathy committed Jan 27, 2023
1 parent 2bf07a3 commit cc5444e
Showing 1 changed file with 9 additions and 8 deletions.
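For context, a minimal illustration (not part of the commit) of what the new flag toggles: with bias=False, nn.Linear simply registers no bias parameter, which is where the small parameter and speed savings mentioned below come from.

import torch.nn as nn

# one attention-sized projection, with and without the bias term
with_bias = nn.Linear(768, 3 * 768, bias=True)      # weight (2304, 768) plus bias (2304,)
without_bias = nn.Linear(768, 3 * 768, bias=False)  # weight only

print(sum(p.numel() for p in with_bias.parameters()))     # 1771776
print(sum(p.numel() for p in without_bias.parameters()))  # 1769472, i.e. 2304 fewer parameters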
17 changes: 9 additions & 8 deletions model.py
@@ -25,7 +25,7 @@ def new_gelu(x):
 class LayerNorm(nn.Module):
     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

-    def __init__(self, ndim, bias=True):
+    def __init__(self, ndim, bias):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(ndim))
         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
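The forward pass is not part of this hunk; the sketch below shows how the optional bias could be consumed, consistent with the constructor above (F.layer_norm accepts bias=None and then skips the additive term). The forward body and the 1e-5 eps are assumptions, not lines from the diff.

import torch
import torch.nn.functional as F
from torch import nn

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        # when bias is False, no bias parameter is registered at all
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, x):
        # passing bias=None to F.layer_norm drops the additive term entirely
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)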
@@ -39,9 +39,9 @@ def __init__(self, config):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         # key, query, value projections for all heads, but in a batch
-        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
         # output projection
-        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
         # regularization
         self.attn_dropout = nn.Dropout(config.dropout)
         self.resid_dropout = nn.Dropout(config.dropout)
@@ -76,8 +76,8 @@ class MLP(nn.Module):

     def __init__(self, config):
         super().__init__()
-        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
-        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
         self.dropout = nn.Dropout(config.dropout)

     def forward(self, x):
@@ -91,9 +91,9 @@ class Block(nn.Module):

     def __init__(self, config):
         super().__init__()
-        self.ln_1 = LayerNorm(config.n_embd, bias=False)
+        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
         self.attn = CausalSelfAttention(config)
-        self.ln_2 = LayerNorm(config.n_embd, bias=False)
+        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
         self.mlp = MLP(config)

     def forward(self, x):
@@ -109,6 +109,7 @@ class GPTConfig:
     n_head: int = 12
     n_embd: int = 768
     dropout: float = 0.1
+    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

 class GPT(nn.Module):

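A hedged usage sketch of the new field, assuming `from model import GPT, GPTConfig` and that the remaining GPTConfig fields keep their defaults:

from model import GPT, GPTConfig

config = GPTConfig(bias=True)      # GPT-2 style: biases in all Linears and LayerNorms
model = GPT(config)

config_nb = GPTConfig(bias=False)  # per the inline comment above: a bit better and faster
model_nb = GPT(config_nb)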
@@ -123,7 +124,7 @@ def __init__(self, config):
             wpe = nn.Embedding(config.block_size, config.n_embd),
             drop = nn.Dropout(config.dropout),
             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-            ln_f = LayerNorm(config.n_embd, bias=False),
+            ln_f = LayerNorm(config.n_embd, bias=config.bias),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # with weight tying when using torch.compile() some warnings get generated:
