BDataset
类是是为BERT模型预训练任务设计的PyTorch数据集类,它负责将原始文本数据转换成模型训练所需的格式。这个类在数据处理流程中起到了桥梁的作用,确保数据以正确的方式输入到模型中。
- 数据封装:
BDataset
类封装了两个文本序列、它们的对应标签以及词汇索引映射。 - 数据预处理:实现数据预处理的逻辑,包括文本分词、未知词替换、序列截断和填充,以确保输入数据符合模型的输入要求。
- 掩码操作:为了模拟BERT的预训练过程,
BDataset
类在数据中随机应用掩码操作。这是通过随机选择一定比例的词,并用特殊的[MASK]
标记替换它们来实现的。这一步骤对于训练BERT模型理解上下文至关重要。 - 段索引和序列构建: 构建段索引,用于区分句子对中的句子A和句子B。同时,将特殊标记(如[CLS]和[SEP])纳入序列,以符合BERT模型的输入格式。
注:这些特殊标记符的含义可以在数据预处理中找到对应含义。
class BDataset(Dataset):
def __init__(self, all_text1, all_text2, all_label, max_len, word_2_index):
self.all_text1 = all_text1
self.all_text2 = all_text2
self.all_label = all_label
self.max_len = max_len
self.word_2_index = word_2_index
def __getitem__(self, index):
text1 = self.all_text1[index]
text2 = self.all_text2[index]
lable = self.all_label[index]
unk_idx = self.word_2_index["[UNK]"]
text1_idx = [self.word_2_index.get(i, unk_idx) for i in text1][:62]
text2_idx = [self.word_2_index.get(i, unk_idx) for i in text2][:63]
mask_val = [0] * self.max_len
text_idx = [self.word_2_index["[CLS]"]] + text1_idx + [self.word_2_index["[SEP]"]] + text2_idx + [
self.word_2_index["[SEP]"]]
seg_idx = [0] + [0] * len(text1_idx) + [0] + [1] * len(text2_idx) + [1] + [2] * (self.max_len - len(text_idx))
for i, v in enumerate(text_idx):
if v in [self.word_2_index["[CLS]"], self.word_2_index["[SEP]"], self.word_2_index["[UNK]"]]:
continue
if random.random() < 0.15:
r = random.random()
if r < 0.8:
text_idx[i] = self.word_2_index["[MASK]"]
mask_val[i] = v
elif r > 0.9:
other_idx = random.randint(6, len(self.word_2_index) - 1)
text_idx[i] = other_idx
mask_val[i] = v
text_idx = text_idx + [self.word_2_index["[PAD]"]] * (self.max_len - len(text_idx))
return torch.tensor(text_idx), torch.tensor(lable), torch.tensor(mask_val), torch.tensor(seg_idx)
def __len__(self):
return len(self.all_label)
在这个过程中,我遇到了越界
错误,由于目的是为了实现一个BERT
的训练流程,所以我仅仅是添加了一个越界
检测纠正代码。