රෙට්රෝපුහුණු දත්ත කට්ටලය

අපිළඟම අසල්වැසියන් ප්රධාන වටිනාකම් දත්ත ගබඩාවෙන් පූර්ව ලබා ගෙන RETRO ආකෘතිය පුහුණු කිරීම සඳහා දත්ත කට්ටලය නිර්මාණය කරමු.

15import json
16from pathlib import Path
17
18import numpy as np
19import torch
20from torch.utils.data import Dataset as PyTorchDataset
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset, TextDataset
24from labml_nn.transformers.retro.database import RetroIndex

දත්තසමුදාය සාදන්න

  • chunk_len කුටියේ දිග වේ
  • chunks_per_sample පුහුණු නියැදියකට කුට්ටි ගණන
  • skip_range සාම්පල දෙකක් අතර මඟ හැරීමට උපරිම අක්ෂර ගණන වේ. අපි මඟ කිහිපයක් චරිත අතර සාම්පල බවට වග බලා ගන්න සාම්පල පෙලගැසී නොමැත සමග හොඳින් කැබලි දී දත්ත සමුදාය
27def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):

පෙළගොනුව පූරණය කරන්න

39    dataset = TextFileDataset(
40        lab.get_data_path() / 'tiny_shakespeare.txt',
41        list,
42        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

එහිපුහුණු කොටස

45    text = dataset.train

අසල්වැසියන්ලබා ගැනීම සඳහා දර්ශකය පටවන්න

48    index = RetroIndex()

ආදානනියැදි ඕෆ්සෙට්

51    sample_offsets = []

පෙළසඳහා කර්සරය

53    i = 0
54    while i < len(text):

අසල්වැසියන්සමඟ පෙලගැසී නොමැති බවට වග බලා ගැනීම සඳහා චරිත කිහිපයක් මඟ හරින්න

56        skip = np.random.randint(skip_range)
57        i += skip

අපිපෙළෙහි අවසානයට පැමිණ ඇත්නම් නවත්වන්න

60        if i + chunks_per_sample * chunk_len > len(text):
61            break

ඕෆ්සෙට්එකතු කරන්න

64        sample_offsets.append(i)

කර්සරයවැඩි කරන්න

67        i += chunks_per_sample * chunk_len

සාම්පලසඳහා

70    samples = []

නියැදිඕෆ්සෙට් හරහා නැවත කරන්න

72    for i in monit.iterate('Gather Neighbors', sample_offsets):

අමතරචරිතයක් ඇතුළුව නියැදිය ලබා ගන්න (අනාවැකිය සඳහා)

74        sample = text[i: i + chunks_per_sample * chunk_len + 1]

ආදානය

76        src = sample[:-1]

එයකුට්ටි වලට කඩන්න

78        chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]

කුට්ටියඕෆ්සෙට්

80        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]

ළඟමඅසල්වැසියන් ලබා ගන්න

83        neighbor_offsets = index(chunks, chunk_offsets)

අසල්වැසියාගේපෙළ ලබා ගන්න. අසල්වැසියාගේ දිග මෙන් දෙගුණයක් වේ chunk_len

86        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]

සාම්පලලැයිස්තුවට එකතු කරන්න

89        samples.append((sample[:-1], sample[1:], neighbors))

JSONහි සාම්පල සුරකින්න. අපගේ දත්ත කට්ටලය කුඩා බැවින් අපට සංකීර්ණ දත්ත කට්ටල ගබඩා කිරීමේ යාන්ත්රණ හෝ පූර්ව ටෝකීකරණය භාවිතා කිරීමට අවශ්ය නොවේ.

94    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
95        f.write(json.dumps(samples))

දත්තකට්ටලය

විසින්නිර්මාණය කරන ලද දත්ත කට්ටලය පටවන PyTorch දත්ත කට්ටලය මෙයයි build_dataset .

98class Dataset(PyTorchDataset):
  • file_path සුරකින ලද JSON ගොනුවේ මාර්ගයයි
  • tds වේ TextDataset
  • 105    def __init__(self, file_path: Path, tds: TextDataset):
    111        self.tds = tds

    සාම්පලපූරණය කරන්න

    113        with open(str(file_path), 'r') as f:
    114            self.samples = json.loads(f.read())

    සාම්පලගණන

    116    def __len__(self):
    120        return len(self.samples)

    නියැදියක්ලබා ගන්න

    122    def __getitem__(self, idx: int):

    නියැදියලබා ගන්න

    127        s = self.samples[idx]

    ටෝකනයිස්කරන්න

    129        src = self.tds.text_to_i(s[0])
    130        tgt = self.tds.text_to_i(s[1])
    131        neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunks]) for chunks in s[2]])

    133        return src, tgt, neighbors

    136if __name__ == '__main__':
    137    build_dataset()