Added classifier
penenadpi committed Jun 25, 2023
1 parent 2ba7184 commit 75c863a
Showing 1 changed file with 97 additions and 0 deletions.
97 changes: 97 additions & 0 deletions bots/nenad/joke_bot.py
@@ -350,6 +350,103 @@ def train(self, jokes_dataset_path='./shortjokes.csv', BATCH_SIZE = 16, EPOCHS =
    def load_state(self, models_folder="trained_models", weights_file="gpt2_joker_0.pt"):
        model_path = os.path.join(models_folder, weights_file)
        self.model.load_state_dict(torch.load(model_path))



# Assumes the surrounding file already imports os, torch, torch.nn as nn,
# tqdm, torch.optim.Adam, and transformers' GPT2Model / GPT2Tokenizer.
class SimpleGPT2SequenceClassifier(nn.Module):
    def __init__(self, hidden_size=768, num_classes=10, max_seq_len=128, gpt_model_name='gpt2'):
        super(SimpleGPT2SequenceClassifier, self).__init__()
        self.gpt2model = GPT2Model.from_pretrained(gpt_model_name)
        # A single linear head over the flattened sequence of GPT-2 hidden states.
        self.fc1 = nn.Linear(hidden_size * max_seq_len, num_classes)
        # argmax over the logits returns a zero-based class index, so the keys
        # run 0..9 and map to the rating labels "1".."10".
        self.labels_map = {i: str(i + 1) for i in range(num_classes)}

    def forward(self, input_id, mask):
        # gpt_out: (batch_size, max_seq_len, hidden_size) last hidden states.
        gpt_out, _ = self.gpt2model(input_ids=input_id, attention_mask=mask, return_dict=False)
        batch_size = gpt_out.shape[0]
        # Flatten the whole sequence before the linear head.
        linear_output = self.fc1(gpt_out.view(batch_size, -1))
        return linear_output
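
The train helper below wraps its input in a RatingDataset that this commit does not define; it is assumed to exist elsewhere in the repository. A minimal sketch of what it might look like, assuming train_data is an iterable of (joke_text, rating) pairs with ratings from 1 to 10:

class RatingDataset(torch.utils.data.Dataset):
    # Hypothetical sketch, not part of this commit: the real RatingDataset
    # is assumed to be defined elsewhere in the repository.
    def __init__(self, train_data, max_seq_len=128):
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.padding_side = "left"
        tokenizer.pad_token = tokenizer.eos_token
        self.encodings = [
            tokenizer(" ".join(text.lower().split()), padding='max_length',
                      max_length=max_seq_len, truncation=True, return_tensors="pt")
            for text, _ in train_data
        ]
        # Shift ratings 1..10 down to class indices 0..9 so they line up with
        # CrossEntropyLoss targets and labels_map.
        self.labels = [rating - 1 for _, rating in train_data]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Each encoding holds input_ids and attention_mask of shape (1, max_seq_len);
        # the squeeze(1) in train removes the extra dimension after batching.
        return self.encodings[idx], self.labels[idx]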


def train(model, train_data, learning_rate=1e-5, epochs=1):
    train_set = RatingDataset(train_data)

    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    for epoch_num in range(epochs):
        total_acc_train = 0
        total_loss_train = 0

        for train_input, train_label in tqdm(train_dataloader):
            train_label = train_label.to(device)
            mask = train_input['attention_mask'].to(device)
            input_id = train_input["input_ids"].squeeze(1).to(device)

            model.zero_grad()

            output = model(input_id, mask)

            batch_loss = criterion(output, train_label)
            total_loss_train += batch_loss.item()

            acc = (output.argmax(dim=1) == train_label).sum().item()
            total_acc_train += acc

            batch_loss.backward()
            optimizer.step()

    torch.save(model.state_dict(), "./trained_models/gpt2-text-classifier-model.pt")
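
A sketch of how the training path would be invoked, assuming train_data in the (joke_text, rating) format described above:

# Hypothetical usage of the training path.
classifier = SimpleGPT2SequenceClassifier(hidden_size=768, num_classes=10, max_seq_len=128)
train(classifier, train_data, learning_rate=1e-5, epochs=1)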



def load_state(model, models_folder="trained_models", weights_file="gpt2-text-classifier-model.pt"):
    # Module-level helper: takes the model explicitly, like train and rate_joke.
    model_path = os.path.join(models_folder, weights_file)
    model.load_state_dict(torch.load(model_path))
    model.eval()


def rate_joke(model, joke):
    fixed_text = " ".join(joke.lower().split())
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    model_input = tokenizer(fixed_text, padding='max_length', max_length=128, truncation=True, return_tensors="pt")

    # Keep the inputs on the same device as the model's weights.
    device = next(model.parameters()).device
    mask = model_input['attention_mask'].to(device)
    input_id = model_input["input_ids"].to(device)

    with torch.no_grad():
        output = model(input_id, mask)

    # Zero-based argmax index -> rating label "1".."10".
    pred_label = model.labels_map[output.argmax(dim=1).item()]
    return pred_label
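
The matching inference path, assuming the checkpoint written by train above:

# Hypothetical usage of the inference path.
classifier = SimpleGPT2SequenceClassifier()
load_state(classifier)  # reads trained_models/gpt2-text-classifier-model.pt
print(rate_joke(classifier, "Why did the chicken cross the road?"))  # e.g. "7"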



if __name__ == "__main__":
    bot_nenad = Bot()
