ට්රාන්ස්ෆෝමර්එක්ස්එල් අත්හදා බැලීම

මෙයට්රාන්ස්ෆෝමර් xl ආකෘතියක් පුහුණු කිරීම සඳහා කරන ලද පයිටෝර්ච් අත්හදා බැලීමකි.

11from typing import List
12
13import torch
14import torch.nn as nn
15from labml.logger import Text
16
17from labml import experiment, tracker, monit, logger
18from labml.configs import option
19from labml_helpers.metrics.simple_state import SimpleStateModule
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex, hook_model_outputs
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer

ස්වයංක්රීයප්රතිගාමී ආකෘතිය

26class AutoregressiveModel(Module):
31    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
32        super().__init__()

ටෝකන්කාවැද්දීම මොඩියුලය

34        self.src_embed = nn.Embedding(n_vocab, d_model)

ට්රාන්ස්ෆෝමර්

36        self.transformer = transformer

අවසන්ස්ථරය

38        self.generator = nn.Linear(d_model, n_vocab)

වෙස්මුහුණු

40        self.mask_x = None
41        self.mask_mem = None
43    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):

මතකයේදිග

45        m_len = len(mem[0]) if mem else 0

ටෝකනසඳහා පසුකාලීන වෙස් මුහුණක් සාදන්න

47        if self.mask_x is None or self.mask_x.shape[0] < len(x):
48            from labml_nn.transformers.utils import subsequent_mask
49            self.mask_x = subsequent_mask(len(x)).to(x.device)

මතකයසඳහා සියලු (සම්පූර්ණ දෘශ්යතාව) වෙස්මුහුණක් සාදන්න

51        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
52            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

මතකයක්තිබේ නම් වෙස් මුහුණු සංයුක්ත කරන්න

55        if m_len:
56            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

පසුකාලීනආවරණ වෙනත් ආකාරයකින් භාවිතා කරන්න

58        else:
59            mask = self.mask_x[:len(x), :len(x)]

ටෝකන්කාවැද්දීම්

62        x = self.src_embed(x)

ට්රාන්ස්ෆෝමරයහරහා එය ධාවනය කරන්න

64        res, mem = self.transformer(x, mem, mask)

ඊළඟටෝකනයේ පිවිසුම් ජනනය කරන්න

66        res = self.generator(res)

68        return res, mem

වින්යාසකිරීම්

අපිඅත්හදා බැලීම ආරම්භ කරන විට පෙරනිමි වින්යාස කළ හැකි අතර එය අධික ලෙස ධාවනය වනු ඇත

71class Configs(NLPAutoRegressionConfigs):
78    model: AutoregressiveModel

ටෝකන්කාවැද්දීමේ ප්රමාණය

81    d_model: int = 128

අවධානයයොමු ප්රධානීන් ගණන

83    heads: int = 4

අතහැරදැමීමේ සම්භාවිතාව

85    dropout: float = 0.0

FFNසැඟවුණු ස්ථරයේ විශේෂාංග ගණන

87    d_ff: int = 256

ට්රාන්ස්ෆෝමර්ස්ථර ගණන

89    n_layers: int = 6

තබාගත යුතු මතකයන් ගණන

91    mem_len: int = 128

පුහුණුවසහ වලංගු කිරීම අතර මාරුවීමේදී මතකයන් පවත්වා ගැනීම සඳහා රාජ්ය මොඩියුලය

93    memory = SimpleStateModule()
95    def init(self):

ට්රැකර්වින්යාසයන් සකසන්න

97        tracker.set_scalar("accuracy.*", True)
98        tracker.set_scalar("loss.*", True)

මොඩියුලප්රතිදානයන් ලොග් කිරීමට කොක්කක් එක් කරන්න

100        hook_model_outputs(self.mode, self.model, 'model')

මෙයපුහුණුව සහ වලංගු කිරීම සඳහා නිරවද්යතා මෙට්රික් සංඛ්යාන සහ මතකයන් වෙනම තබා ගනී.

102        self.state_modules = [self.accuracy, self.memory]

මතකයන්සංකෝචනය කර උපරිම මතකයන් තබා ගැනීම සඳහා පැරණි mem_len මතකයන් ඉවත් කරන්න.

104    def merge_memory(self, old_mem, new_mem):

එයවින්යාස කර ඇත්නම් මතකය භාවිතා නොකිරීමට

111        if self.mem_len == 0:
112            return []

පැරණිමතකය සමඟ සංයුක්ත වන්න

115        if old_mem:
116            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
117        else:
118            mem = new_mem

පැරණිමතකයන් ඉවත් කරන්න

121        if len(mem[0]) > self.mem_len:
122            mem = [m[-self.mem_len:] for m in mem]

125        return mem

පුහුණුව/වලංගුකිරීමේ පියවර

127    def step(self, batch: any, batch_idx: BatchIndex):

උපාංගයවෙත දත්ත ගෙනයන්න

133        data, target = batch[0].to(self.device), batch[1].to(self.device)

පුහුණුප්රකාරයේදී ගෝලීය පියවර යාවත්කාලීන කරන්න (සැකසූ ටෝකන ගණන)

136        if self.mode.is_train:
137            tracker.add_global_step(data.shape[0] * data.shape[1])

ආකෘතිප්රතිදානයන් ග්රහණය කර ගත යුතුද යන්න

140        with self.mode.update(is_log_activations=batch_idx.is_last):

මතකයන්ලබා ගන්න

142            mem = self.memory.get()

ආකෘතියධාවනය කරන්න

144            output, new_mem = self.model(data, mem)

මතකයඒකාබද්ධ කරන්න

146            mem = self.merge_memory(mem, new_mem)

මතකයන්යාවත්කාලීන කරන්න

148            self.memory.set(mem)

හරස්එන්ට්රොපි අලාභය ගණනය කර ලොග් කරන්න

151        loss = self.loss_func(output, target)
152        tracker.add("loss.", loss)

ගණනයකිරීම සහ ලොග් කිරීමේ නිරවද්යතාවය

155        self.accuracy(output, target)
156        self.accuracy.track()

ආකෘතියපුහුණු කරන්න

159        if self.mode.is_train:

අනුක්රමිකගණනය කරන්න

161            loss.backward()

ක්ලිප්අනුක්රමික

163            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

ප්රශස්තිකරණපියවර ගන්න

165            self.optimizer.step()

සෑමයුගලයකම අවසාන කණ්ඩායමේ ආදර්ශ පරාමිතීන් සහ අනුක්රමික ලොග් කරන්න

167            if batch_idx.is_last:
168                tracker.add('model', self.model)

අනුක්රමිකඉවත්

170            self.optimizer.zero_grad()

ලුහුබැඳඇති ප්රමිතික සුරකින්න

173        tracker.save()

පුහුණුවඅතරතුර වරින් වර සාම්පල ජනනය කිරීම සඳහා නියැදි කිරීමේ කාර්යය

175    def sample(self):

විමසුමක්ආරම්භ කිරීම

181        prompt = self.prompt

මුද්රණයසඳහා ප්රතිදානය එකතු කරන්න

183        log = [(prompt, Text.subtle)]

මතකය

185        mem = []

සාම්පල25 ටෝකන

187        for i in monit.iterate('Sample', 25):

විමසුමටෝකෙන්කරන්න

189            data = self.text.text_to_i(prompt).unsqueeze(-1)

උපාංගයවෙත ගෙන යන්න

191            data = data.to(self.device)

ආදර්ශප්රතිදානය ලබා ගන්න

193            output, new_mem = self.model(data, mem)

ආදර්ශඅනාවැකිය ලබා ගන්න (කෑදර)

195            output = output.argmax(dim=-1).squeeze(1)

විමසුමටඅනාවැකිය එක් කරන්න

197            prompt += self.prompt_separator + self.text.itos[output[-1]]

ඊළඟපුනරාවර්තනයේදී අවසාන චරිතය ආකෘතියට පමණක් පෝෂණය කරන්න, විවේකය මතකයන් ලෙස ඉදිරියට යනු ඇත

199            prompt = prompt[-1:]

ලොග්වීම සඳහා අනාවැකිය එක් කරන්න

201            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

මතකයයාවත්කාලීන කරන්න

203            mem = self.merge_memory(mem, new_mem)

නියැදිප්රතිදානය මුද්රණය කරන්න

206        logger.log(log)

ස්වයංක්රීයප්රතිගාමී ආකෘතිය ආරම්භ කරන්න

209@option(Configs.model)
210def autoregressive_model(c: Configs):
214    from labml_nn.transformers.xl import RelativeMultiHeadAttention
215    from labml_nn.transformers.feed_forward import FeedForward
216    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
217        TransformerXLLayer(d_model=c.d_model,
218                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
219                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
220                           dropout_prob=c.dropout), c.n_layers))
221    return m.to(c.device)

අත්හදාබැලීම ක්රියාත්මක කරන්න

224def main():

අත්හදාබැලීම සාදන්න

229    experiment.create(name="transformer_xl", comment='')

වින්යාසසාදන්න

231    conf = Configs()

වින්යාසයන්පූරණය කරන්න

233    experiment.configs(conf,

අභිබවායාම සඳහා වින්යාසයන් පිළිබඳ ශබ්දකෝෂයක්

235                       {'tokenizer': 'character',
236                        'text': 'tiny_shakespeare',
237                        'optimizer.learning_rate': 1.,
238                        'optimizer.optimizer': 'Noam',
239                        'prompt': 'It is',
240                        'prompt_separator': '',
241
242                        'train_loader': 'sequential_train_loader',
243                        'valid_loader': 'sequential_valid_loader',
244
245                        'seq_len': 2,
246                        'mem_len': 32,
247                        'epochs': 128,
248                        'batch_size': 32,
249                        'inner_iterations': 25,
250                        })

ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න

253    experiment.add_pytorch_models({'model': conf.model})

අත්හදාබැලීම ආරම්භ කරන්න

256    with experiment.start():

TrainValidConfigs.run

258        conf.run()

262if __name__ == '__main__':
263    main()