රෙට්රෝපුහුණුව

මෙය RETROසඳහා පුහුණු කේතය වේ.

View Run

16import torch
17from torch import nn
18from torch.utils.data import DataLoader, RandomSampler
19
20from labml import monit, lab, tracker, experiment, logger
21from labml.logger import Text
22from labml_helpers.datasets.text import TextFileDataset
23from labml_nn.optimizers.noam import Noam
24from labml_nn.transformers.retro import model as retro
25from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
26from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder

නියැදිකරු

මෙමපන්තිය ආකෘතියකින් කෑදරකමින් සාම්පල.

29class Sampler:
  • device ආකෘතියේ උපාංගය වේ
  • model රෙට්රෝ ප්රකාරය
  • tds යනු පෙළ දත්ත සමුදාය (අසල්වැසියා කුට්ටි ලබා ගැනීමට භාවිතා කරයි)
  • chunk_len යනු කුට්ටියක දිග
36    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
43        self.chunk_len = chunk_len
44        self.tds = tds
45        self.model = model
46        self.device = device
49        self.index = RetroIndex()

ලබාදී ඇති කුට්ටියක ළඟම අසල්වැසියන් ලබා ගන්න

51    def retrieve_nearest_neighbours(self, chunk: str):

ළඟමඅසල්වාසීන්ගේ හිලව් ලබා ගන්න

57        neighbor_offsets = self.index([chunk], None)

අසල්වැසියන්ලබා ගන්න (අසල්වැසියාගේ දිග සමාන chunk_len * 2 )

60        text = self.tds.train
61        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]

64        return neighbors

ලබාදී ඇති විමසුමෙන් නියැදි පෙළ

66    def sample(self, prompt: str, sample_len: int):

ආසන්නතමඅසල්වැසියන් නූල් ලෙස ගබඩා කිරීම

72        neighbors_str = []

නියැදිපෙළ

75        sampled = ''

නියැදි sample_len ටෝකන

78        for i in range(sample_len):

අපදැනටමත් ලබා ගෙන ඇති ප්රමාණයට වඩා වැඩි නියැදි කුට්ටි තිබේ නම්, අසල්වැසියන් නැවත ලබා ගත යුතුය

81            while len(neighbors_str) < len(prompt) // self.chunk_len:

අපඅසල්වැසියන් ලබා නොගත් අවසාන කුට්ටිය ලබා ගන්න

83                off = len(neighbors_str) * self.chunk_len
84                chunk = prompt[off: off + self.chunk_len]

ළඟමඅසල්වැසියන් ලබා ගන්න

86                neighbors_str.append(self.retrieve_nearest_neighbours(chunk))

ආදානයටෝකෙන්කරණය කරන්න

89            src = self.tds.text_to_i(prompt)

ලබාගත් අසල්වැසියන් ටෝකීස් කරන්න

91            neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk]) for chunk in neighbors_str])

ආකෘතියටසමාන උපාංගයකට ඒවා ගෙනයන්න

94            src = src.to(self.device)
95            neighbors = neighbors.to(self.device)

ආදර්ශප්රතිදානය ලබා ගන්න

98            res = self.model(src[None, :], neighbors[None, :, :, :])

කෑදරකමින්අවසාන ටෝකනය සාම්පලය

101            token = res[0, -1, :].argmax(dim=-1)

නියැදිටෝකන් පෙළ විමසුමට සහ නියැදි පෙළට එක් කරන්න

104            prompt += self.tds.itos[token.item()]
105            sampled += self.tds.itos[token.item()]

108        return sampled

රෙට්රොපුහුණුකරු

111class Trainer:
116    def __init__(self, device: torch.device, model: retro.RetroModel,
117                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
124        self.optimizer = optimizer
125        self.device = device
126        self.dataloader = dataloader
127        self.model = model
128        self.loss_func = nn.CrossEntropyLoss()

එපෝච්සඳහා ආකෘතිය පුහුණු කරන්න

130    def __call__(self):

පුහුණුදත්ත හරහා නැවත භාවිතා කරන්න

136        for i, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):

උපාංගයවෙත දත්ත ගෙනයන්න

138            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)

ඉදිරිසාමාර්ථය

141            res = self.model(src, neighbors)

අලාභයගණනය කරන්න

143            loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))

අනුක්රමිකඉවත්

146            self.optimizer.zero_grad()

පසුගාමීපාස්

148            loss.backward()

ආකෘතියප්රශස්ත කරන්න

150            self.optimizer.step()

පුහුණුසංඛ්යාලේඛන සුරකින්න සහ ගෝලීය පියවර කවුන්ටරය වැඩි කරන්න

153            tracker.save({'loss.train': loss})
154            tracker.add_global_step(len(src))

කුඩාආකෘතියක් සාදන්න සහ පුහුණු කරන්න

157def train():

අත්හදාබැලීමක් සාදන්න

163    experiment.create(name='retro_small')

GPUඋපාංගය

166    device = torch.device('cuda:0')

කුඩාෂේක්ස්පියර් දත්ත කට්ටලය පටවන්න

169    tds = TextFileDataset(
170        lab.get_data_path() / 'tiny_shakespeare.txt',
171        list,
172        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
175    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

දත්තකාරකය සාදන්න

178    train_dl = DataLoader(train_dataset,
179                          batch_size=4,
180                          sampler=RandomSampler(train_dataset, replacement=True))

අධිපරාමිතීන්

183    chunk_len = 16
184    d_model = 128
185    d_ff = 512
186    n_heads = 16
187    d_k = 16

ළඟමඅසල්වැසියාගේ එන්කෝඩරය සාදන්න

190    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)

ආකෘතියසාදන්න

192    model = RetroModel(tds.n_tokens, d_model, 6,
193                       {3, 5},
194                       chunk_len, n_heads, d_k, d_ff,
195                       encoder=nearest_neighbor_encoder)

උපාංගයවෙත ආකෘතිය ගෙනයන්න

197    model = model.to(device)

ප්රශස්තකරණයසාදන්න

199    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)

සාදන්න Trainer

201    trainer = Trainer(device, model, train_dl, optimizer)

සාදන්න Sampler

203    sampler = Sampler(device, model, tds, chunk_len)

205    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න

208    experiment.add_pytorch_models(model=model)

අත්හදාබැලීම ආරම්භ කරන්න

211    with experiment.start():

32 Epochs සඳහා දුම්රිය

213        for epoch in monit.loop(32):

දුම්රිය

215            trainer()

නවරේඛාවක් මුද්රණය කරන්න

217            tracker.new_line()

වෙතින්නියැදිය prompt

219            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
220                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])

ආකෘතිසුරකින්න

222            experiment.save_checkpoint()

226if __name__ == '__main__':
227    train()