මෙයස්වයංක්රීය ප්රතිගාමී සඳහා සරල ට්රාන්ස්ෆෝමර් ආකෘතියක් පුහුණු කරයි. ස්ථාන-නැණවත් පෝෂක ජාලයසඳහා අපි විවිධ ප්රභේද උත්සාහ කරමු.
මෙය labml.configs
මොඩියුලය භාවිතා නොකරන සරල ක්රියාත්මක කිරීමකි. හුරුපුරුදු නොවන පාඨකයන්ට පහසු කිරීම සඳහා සරල ක්රියාත්මක කිරීමක් ලිවීමට අපි තීරණය කළා.
20import dataclasses
21
22import torch
23from labml_helpers.module import Module
24from torch import nn
25from torch.utils.data import Dataset, DataLoader
26
27from labml import experiment, lab, tracker, monit, logger
28from labml.logger import Text
29from labml.utils.download import download_file
30from labml_nn.experiments.nlp_autoregression import transpose_batch
31from labml_nn.optimizers.noam import Noam
32from labml_nn.transformers import Encoder, MultiHeadAttention
33from labml_nn.transformers.feed_forward import FeedForward
34from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
35from labml_nn.transformers.utils import subsequent_mask38class AutoregressiveModel(Module):43 def __init__(self, src_embed: Module, encoder: Encoder, generator: Module):
44 super().__init__()ටෝකන්කාවැද්දීම මොඩියුලය
46 self.src_embed = src_embedට්රාන්ස්ෆෝමර්පදනම් කරගත් එන්කෝඩරය
48 self.encoder = encoderඊළඟටෝකන් උත්පාදන ස්තරය; මෙය ඊළඟ ටෝකනයේ පිවිසුම් ලබා දෙයි
51 self.generator = generatorමෙයපළමු ඇමතුමෙන් ආරම්භ කෙරේ
53 self.src_mask = None55 def forward(self, src: torch.Tensor):ට්රාන්ස්ෆෝමරයඅතීත ටෝකන කෙරෙහි පමණක් අවධානය යොමු කළ හැකි වන පරිදි පසුකාලීන වෙස්මුහුණක් සාදන්න.
57 if self.src_mask is None or self.src_mask.size(0) != len(src):
58 self.src_mask = subsequent_mask(len(src)).to(src.device)ටෝකනකාවැද්දීම (src
) සහ ට්රාන්ස්ෆෝමරය හරහා එය ක්රියාත්මක කරන්න
60 res = self.encoder(self.src_embed(src), self.src_mask)ඊළඟටෝකනයේ පිවිසුම් ජනනය කරන්න
62 return self.generator(res)65@dataclasses.dataclass
66class Configs:70 d_model: int = 512
71 seq_len: int = 128
72 batch_size: int = 32
73 n_layers: int = 6
74 n_heads: int = 8
75 dropout: float = 0.1
76 d_ff: int = 2048
77 glu_variant: str = 'GLU'
78 epochs: int = 5
79 grad_norm_clip: float = 0.582class TinyShakespeareDataset(Dataset):87 def __init__(self, seq_len: int):පෙළගොනුවේ පිහිටීම
89 path = lab.get_data_path() / 'tiny_shakespeare.txt'ගොනුවබාගන්න
91 download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)බාගතකළ ගොනුව කියවන්න
93 with open(str(path), 'r') as f:
94 text = f.read()අක්ෂරඋපුටා ගන්න
97 chars = list(set(text))අක්ෂරය(පූර්ණ සංඛ්යා) සිතියමට
99 self.stoi = {c: i for i, c in enumerate(chars)}Idසිට අක්ෂර සිතියම
101 self.itos = {i: c for i, c in enumerate(chars)}පුහුණුනියැදියක දිග
103 self.seq_len = seq_lenහැඳුනුම්පත්ආතතියක ස්වරූපයෙන් දත්ත
105 self.data = self.text_to_i(text)පෙළහැඳුනුම්පත්වල ආතතියකට පරිවර්තනය කරන්න
107 def text_to_i(self, text: str):111 return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)113 def __len__(self):119 return len(self.data) - self.seq_len - 1නියැදියක්ආපසු ලබා දෙන්න
121 def __getitem__(self, idx):125 return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]128class Trainer:133 def __init__(self, configs: Configs):උපාංගයලබා ගන්න
135 self.device = torch.device('cpu')
136 if torch.cuda.is_available():
137 self.device = torch.device('cuda:0')දත්තසමුදාය ආරම්භ කරන්න
139 self.dataset = TinyShakespeareDataset(configs.seq_len)දත්තසමුදාය ආරම්භ කරන්න
141 self.dataloader = DataLoader(self.dataset,
142 batch_size=configs.batch_size,
143 collate_fn=transpose_batch,
144 shuffle=True)ගේට්ටුරේඛීය ඒකකය සමඟ FFN
148 if configs.glu_variant == 'GLU':
149 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)බිලීනියර්සැඟවුණු තට්ටුවක් සහිත FFN
152 elif configs.glu_variant == 'Bilinear':
153 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)RelUගේට්ටුව සමඟ FFN
156 elif configs.glu_variant == 'ReGLU':
157 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)GELUගේට්ටුව සහිත FFN
160 elif configs.glu_variant == 'GEGLU':
161 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)ස්විෂ් ගේට්ටුව සමඟ එෆ්එෆ්එන්
165 elif configs.glu_variant == 'SwiGLU':
166 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)RelUසක්රිය කිරීම සමඟ FFN
169 elif configs.glu_variant == 'ReLU':
170 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())RelUසක්රිය කිරීම සමඟ FFN
173 elif configs.glu_variant == 'GELU':
174 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
175 else:
176 raise ValueError(f'Unknown variant {configs.glu_variant}')විවිධඅක්ෂර ගණන
179 n_chars = len(self.dataset.stoi)බහු-හිස අවධානය මොඩියුලය ආරම්භ කරන්න
182 mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)ට්රාන්ස්ෆෝමර් කොටස ආරම්භ කරන්න
184 transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
185 feed_forward=ffn, dropout_prob=configs.dropout)පිවිසුම්උත්පාදනය කිරීම සඳහා කාවැද්දීමේ ස්ථරයක් (ස්ථාවර ස්ථානීය කේතීකරණයක් සහිත) ට්රාන්ස්ෆෝමර් එන්කෝඩරය සහ රේඛීය තට්ටුවක් සහිත ආකෘතිය ආරම්භ කරන්න.
191 self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
192 Encoder(transformer_layer, configs.n_layers),
193 nn.Linear(configs.d_model, n_chars))ආකෘතියවත්මන් උපාංගයට ගෙනයන්න
196 self.model.to(self.device)Noam ප්රශස්තකරණය ආරම්භ කරන්න
199 self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)හරස්එන්ට්රොපි නැතිවීම
202 self.loss_func = nn.CrossEntropyLoss()පුහුණු එපොච් ගණන; අපගේ දත්ත කට්ටල අර්ථ දැක්වීම තනි යුගයකින් දත්ත seq_len
වේලාවන් පුනරාවර්තනය කරන බව සලකන්න
205 self.epochs = configs.epochsශ්රේණියේක්ලිපින් සම්මතය
207 self.grad_norm_clip = configs.grad_norm_clipට්රැකර්වින්යාසයන් සකසන්න
210 tracker.set_scalar("loss.*", True)212 def sample(self):විමසුමක්ආරම්භ කිරීම
218 prompt = 'It is'මුද්රණයසඳහා ප්රතිදානය එකතු කරන්න
220 log = [(prompt, Text.subtle)]සාම්පල25 ටෝකන
222 for i in monit.iterate('Sample', 25):විමසුමටෝකෙන්කරන්න
224 data = self.dataset.text_to_i(prompt).unsqueeze(-1)
225 data = data.to(self.device)ආදර්ශප්රතිදානය ලබා ගන්න
227 output = self.model(data)ආදර්ශඅනාවැකිය ලබා ගන්න (කෑදර)
229 output = output.argmax(dim=-1).squeeze()විමසුමටඅනාවැකිය එක් කරන්න
231 prompt += self.dataset.itos[output[-1].item()]ලොග්වීම සඳහා අනාවැකිය එක් කරන්න
233 log += [(self.dataset.itos[output[-1].item()], Text.value)]නියැදිප්රතිදානය මුද්රණය කරන්න
236 logger.log(log)238 def train(self):ලබාදී ඇති එපොච් සංඛ්යාව සඳහා ලූප්
244 for _ in monit.loop(self.epochs):මිනිබැච්හරහා නැවත ක්රියාත්මක කරන්න
246 for i, batch in monit.enum('Train', self.dataloader):උපාංගයවෙත දත්ත ගෙනයන්න
248 data, target = batch[0].to(self.device), batch[1].to(self.device)පුහුණුකරන ලද අක්ෂර ගණන ලෙස ට්රැකර් පියවර සකසන්න
251 tracker.add_global_step(data.shape[0] * data.shape[1])පුහුණුවසඳහා ආදර්ශ තත්වය සකසන්න
254 self.model.train()ආකෘතියතක්සේරු කරන්න
256 output = self.model(data)අලාභයගණනය කරන්න
259 loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))අලාභයලොග් කරන්න
261 tracker.add("loss.train", loss)අනුක්රමිකගණනය කරන්න
264 loss.backward()ක්ලිප්අනුක්රමික
266 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)ප්රශස්තිකරණපියවර ගන්න
268 self.optimizer.step()ආදර්ශපරාමිතීන් සහ අනුක්රමික ලොග් කරන්න
270 if (i + 1) % 100 == 0:
271 tracker.add('model', self.model)අනුක්රමිකඉවත්
273 self.optimizer.zero_grad()නියැදියක්ජනනය කරන්න
276 if (i + 1) % 100 == 0:
277 self.model.eval()
278 with torch.no_grad():
279 self.sample()ලුහුබැඳඇති ප්රමිතික සුරකින්න
282 if (i + 1) % 10 == 0:
283 tracker.save()ආකෘතියසුරකින්න
286 experiment.save_checkpoint()289def main():අත්හදාබැලීම සාදන්න
291 experiment.create(name="glu_variants")වින්යාසසාදන්න
293 configs = Configs()වින්යාසයන්පූරණය කරන්න
295 experiment.configs(dataclasses.asdict(configs))පුහුණුකරුසාදන්න
298 trainer = Trainer(configs)පුහුණුවසහ පැටවීම සඳහා ආකෘති සකසන්න
300 experiment.add_pytorch_models({'model': trainer.model})අත්හදාබැලීම ආරම්භ කරන්න
303 with experiment.start():ආකෘතියපුහුණු කරන්න
305 trainer.train()
306
307
308if __name__ == '__main__':
309 main()