KnowLog_finetune_pair.py
import argparse
import json
import logging
import math
import os
import random

from sentence_transformers import losses
from sentence_transformers import LoggingHandler, SentenceTransformer, InputExample
from sentence_transformers.evaluation import LabelAccuracyEvaluator
from torch.utils.data import DataLoader

random.seed(1)

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

# pin training to the first GPU
os.environ["CUDA_VISIBLE_DEVICES"] = '0'


def parse_args():
    args = argparse.ArgumentParser()
    # dataset and training arguments
    args.add_argument("-train_data", "--train_data", type=str,
                      default="./datasets/tasks/LDSM/hw_switch_train.json", help="train dataset")
    args.add_argument("-dev_data", "--dev_data", type=str,
                      default="./datasets/tasks/LDSM/hw_switch_dev.json", help="dev dataset")
    args.add_argument("-test_data", "--test_data", type=str,
                      default="./datasets/tasks/LDSM/hw_switch_test.json", help="test dataset")
    args.add_argument("-pretrain_model", "--pretrain_model", type=str,
                      default="bert-base-uncased", help="the path of the pretrained model to finetune")
    args.add_argument("-epoch", "--epoch", type=int,
                      default=10, help="number of epochs")
    args.add_argument("-batch_size", "--batch_size", type=int,
                      default=8, help="batch size")
    args.add_argument("-outfolder", "--outfolder", type=str,
                      default="./output/knowlog_finetune", help="folder name to save the models")
    return args.parse_args()


def read_json(file):
    """Load a whole JSON file and return the parsed content."""
    with open(file, 'r') as f:
        content = json.load(f)
    return content
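

# The loaders in evaluate() below index each record as item[0][0], item[0][1]
# and item[1], which implies a pair-classification JSON layout along these
# lines (an illustrative sketch, not copied from the repository's data files):
#     [
#         [["log message text", "matching description"], 1],
#         [["another log message", "unrelated description"], 0]
#     ]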


def evaluate(args):
    model_save_path = args.outfolder
    train_batch_size = args.batch_size
    num_epochs = args.epoch

    train_data = read_json(args.train_data)
    dev_data = read_json(args.dev_data)
    test_data = read_json(args.test_data)

    # load model
    model = SentenceTransformer(args.pretrain_model)

    # load dataset: each item is a ([sentence_a, sentence_b], label) pair
    train_samples = []
    dev_samples = []
    test_samples = []
    for item in train_data:
        train_samples.append(InputExample(texts=[item[0][0], item[0][1]], label=item[1]))
    for item in dev_data:
        dev_samples.append(InputExample(texts=[item[0][0], item[0][1]], label=item[1]))
    for item in test_data:
        test_samples.append(InputExample(texts=[item[0][0], item[0][1]], label=item[1]))

    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
    # the evaluation splits do not need shuffling
    dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=train_batch_size)
    test_dataloader = DataLoader(test_samples, shuffle=False, batch_size=train_batch_size)

    # loss: softmax classifier over the concatenated pair embeddings (2 labels)
    train_loss = losses.SoftmaxLoss(model=model,
                                    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
                                    num_labels=2)

    warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train steps for warm-up
    logging.info("Warmup-steps: {}".format(warmup_steps))

    dev_evaluator = LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss, name='dev')
    test_evaluator = LabelAccuracyEvaluator(test_dataloader, softmax_model=train_loss, name='test')

    # finetune on train, monitoring dev accuracy during training
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=dev_evaluator,
              epochs=num_epochs,
              evaluation_steps=10000,
              warmup_steps=warmup_steps,
              output_path=model_save_path,
              )

    # the stock SentenceTransformer.fit accepts a single evaluator; the
    # original `evaluator2=test_evaluator` argument would need a patched
    # library, so the test set is scored once after training here instead
    test_evaluator(model, output_path=model_save_path)


if __name__ == '__main__':
    args = parse_args()
    evaluate(args)
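
# Example invocation (the paths simply mirror the argparse defaults above;
# point them at wherever your LDSM pair data actually lives):
#     python KnowLog_finetune_pair.py \
#         --train_data ./datasets/tasks/LDSM/hw_switch_train.json \
#         --dev_data ./datasets/tasks/LDSM/hw_switch_dev.json \
#         --test_data ./datasets/tasks/LDSM/hw_switch_test.json \
#         --pretrain_model bert-base-uncased \
#         --epoch 10 --batch_size 8 \
#         --outfolder ./output/knowlog_finetune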