Skip to content

Latest commit

 

History

History
62 lines (45 loc) · 3 KB

3.1 数据集定义.md

File metadata and controls

62 lines (45 loc) · 3 KB

1 数据集定义

BDataset 类是是为BERT模型预训练任务设计的PyTorch数据集类,它负责将原始文本数据转换成模型训练所需的格式。这个类在数据处理流程中起到了桥梁的作用,确保数据以正确的方式输入到模型中。

1.1 功能

  • 数据封装:BDataset 类封装了两个文本序列、它们的对应标签以及词汇索引映射。
  • 数据预处理:实现数据预处理的逻辑,包括文本分词、未知词替换、序列截断和填充,以确保输入数据符合模型的输入要求。
  • 掩码操作:为了模拟BERT的预训练过程,BDataset 类在数据中随机应用掩码操作。这是通过随机选择一定比例的词,并用特殊的[MASK]标记替换它们来实现的。这一步骤对于训练BERT模型理解上下文至关重要。
  • 段索引和序列构建: 构建段索引,用于区分句子对中的句子A和句子B。同时,将特殊标记(如[CLS]和[SEP])纳入序列,以符合BERT模型的输入格式。

注:这些特殊标记符的含义可以在数据预处理中找到对应含义。

class BDataset(Dataset):
    def __init__(self, all_text1, all_text2, all_label, max_len, word_2_index):
        self.all_text1 = all_text1
        self.all_text2 = all_text2
        self.all_label = all_label
        self.max_len = max_len
        self.word_2_index = word_2_index

    def __getitem__(self, index):
        text1 = self.all_text1[index]
        text2 = self.all_text2[index]

        lable = self.all_label[index]
        unk_idx = self.word_2_index["[UNK]"]
        text1_idx = [self.word_2_index.get(i, unk_idx) for i in text1][:62]
        text2_idx = [self.word_2_index.get(i, unk_idx) for i in text2][:63]

        mask_val = [0] * self.max_len

        text_idx = [self.word_2_index["[CLS]"]] + text1_idx + [self.word_2_index["[SEP]"]] + text2_idx + [
            self.word_2_index["[SEP]"]]
        seg_idx = [0] + [0] * len(text1_idx) + [0] + [1] * len(text2_idx) + [1] + [2] * (self.max_len - len(text_idx))

        for i, v in enumerate(text_idx):
            if v in [self.word_2_index["[CLS]"], self.word_2_index["[SEP]"], self.word_2_index["[UNK]"]]:
                continue

            if random.random() < 0.15:
                r = random.random()
                if r < 0.8:
                    text_idx[i] = self.word_2_index["[MASK]"]

                    mask_val[i] = v

                elif r > 0.9:
                    other_idx = random.randint(6, len(self.word_2_index) - 1)
                    text_idx[i] = other_idx
                    mask_val[i] = v

        text_idx = text_idx + [self.word_2_index["[PAD]"]] * (self.max_len - len(text_idx))

        return torch.tensor(text_idx), torch.tensor(lable), torch.tensor(mask_val), torch.tensor(seg_idx)

    def __len__(self):
        return len(self.all_label)

在这个过程中,我遇到了越界错误,由于目的是为了实现一个BERT的训练流程,所以我仅仅是添加了一个越界检测纠正代码。