මෙය වෙස් භාෂා ආකෘතියක්පුහුණු කිරීම සඳහා කරන ලද පයිටෝර්ච් අත්හදා බැලීමකි.
11from typing import List
12
13import torch
14from torch import nn
15
16from labml import experiment, tracker, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers import Encoder, Generator
24from labml_nn.transformers import TransformerConfigs
25from labml_nn.transformers.mlm import MLM28class TransformerMLM(nn.Module):encoder
ට්රාන්ස්ෆෝමර් එන්කෝඩරය src_embed
යනු ටෝකන් කාවැද්දීමේ මොඩියුලය (ස්ථානීය කේතීකරණ සමඟ) generator
යනු පිවිසුම් ලබා දෙන අවසාන පූර්ණ සම්බන්ධිත ස්ථරයයි . 33 def __init__(self, *, encoder: Encoder, src_embed: Module, generator: Generator):40 super().__init__()
41 self.generator = generator
42 self.src_embed = src_embed
43 self.encoder = encoder45 def forward(self, x: torch.Tensor):ස්ථානීයකේතන ක්රම සමඟ ටෝකන් කාවැද්දීම් ලබා ගන්න
47 x = self.src_embed(x)ට්රාන්ස්ෆෝමර්එන්කෝඩරය
49 x = self.encoder(x, None)ප්රතිදානයසඳහා ලොගින් වන්න
51 y = self.generator(x)ප්රතිලාභප්රති results ල (දෙවන අගය රාජ්ය සඳහා වේ, මන්ද අපගේ පුහුණුකරු RNs සමඟ ද භාවිතා කරයි)
55 return y, Noneමෙයඋරුම වන්නේ අප මෙහි නැවත භාවිතා කරන දත්ත නල මාර්ග ක්රියාත්මක කිරීම් ඇති NLPAutoRegressionConfigs
බැවිනි. අපි අභිරුචි පුහුණු පියවරක් MLM පෝරමයක් ක්රියාත්මක කර ඇත්තෙමු.
58class Configs(NLPAutoRegressionConfigs):එම්එල්එම්ආකෘතිය
69 model: TransformerMLMට්රාන්ස්ෆෝමර්
71 transformer: TransformerConfigsටෝකනගණන
74 n_tokens: int = 'n_tokens_mlm'වෙස්මූස්නොකළ යුතු ටෝකන
76 no_mask_tokens: List[int] = []ටෝකනයක්ආවරණ කිරීමේ සම්භාවිතාව
78 masking_prob: float = 0.15අහඹුටෝකනයකින් වෙස්මුහුණ ප්රතිස්ථාපනය කිරීමේ සම්භාවිතාව
80 randomize_prob: float = 0.1වෙස්මුහුණමුල් ටෝකනය සමඟ ප්රතිස්ථාපනය කිරීමේ සම්භාවිතාව
82 no_change_prob: float = 0.1[MASK]
ටෝකනය
87 mask_token: int[PADDING]
ටෝකනය
89 padding_token: intනියැදියවෙත විමසන්න
92 prompt: str = [
93 "We are accounted poor citizens, the patricians good.",
94 "What authority surfeits on would relieve us: if they",
95 "would yield us but the superfluity, while it were",
96 "wholesome, we might guess they relieved us humanely;",
97 "but they think we are too dear: the leanness that",
98 "afflicts us, the object of our misery, is as an",
99 "inventory to particularise their abundance; our",
100 "sufferance is a gain to them Let us revenge this with",
101 "our pikes, ere we become rakes: for the gods know I",
102 "speak this in hunger for bread, not in thirst for revenge.",
103 ]105 def init(self):[MASK]
ටෝකනය
111 self.mask_token = self.n_tokens - 1[PAD]
ටෝකනය
113 self.padding_token = self.n_tokens - 2116 self.mlm = MLM(padding_token=self.padding_token,
117 mask_token=self.mask_token,
118 no_mask_tokens=self.no_mask_tokens,
119 n_tokens=self.n_tokens,
120 masking_prob=self.masking_prob,
121 randomize_prob=self.randomize_prob,
122 no_change_prob=self.no_change_prob)නිරවද්යතාවමෙට්රික් (සමාන ලේබල් නොසලකා හරින්න [PAD]
)
125 self.accuracy = Accuracy(ignore_index=self.padding_token)හරස්එන්ට්රොපිය නැතිවීම (සමාන ලේබල් නොසලකා හරින්න [PAD]
)
127 self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)129 super().init()131 def step(self, batch: any, batch_idx: BatchIndex):ආදානයඋපාංගයට ගෙනයන්න
137 data = batch[0].to(self.device)පුහුණුප්රකාරයේදී ගෝලීය පියවර යාවත්කාලීන කරන්න (සැකසූ ටෝකන ගණන)
140 if self.mode.is_train:
141 tracker.add_global_step(data.shape[0] * data.shape[1])වෙස්මූඩ්ආදානය සහ ලේබල ලබා ගන්න
144 with torch.no_grad():
145 data, labels = self.mlm(data)ආකෘතිප්රතිදානයන් ග්රහණය කර ගත යුතුද යන්න
148 with self.mode.update(is_log_activations=batch_idx.is_last):ආදර්ශප්රතිදානයන් ලබා ගන්න. ආර්එන්එස් භාවිතා කරන විට එය ප්රාන්ත සඳහා ටූල් එකක් නැවත ලබා දෙයි. මෙය තවමත් ක්රියාත්මක නොවේ.
152 output, *_ = self.model(data)අලාභයගණනය කර ලොග් කරන්න
155 loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
156 tracker.add("loss.", loss)ගණනයකිරීම සහ ලොග් කිරීමේ නිරවද්යතාවය
159 self.accuracy(output, labels)
160 self.accuracy.track()ආකෘතියපුහුණු කරන්න
163 if self.mode.is_train:අනුක්රමිකගණනය කරන්න
165 loss.backward()ක්ලිප්අනුක්රමික
167 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)ප්රශස්තිකරණපියවර ගන්න
169 self.optimizer.step()සෑමයුගලයකම අවසාන කණ්ඩායමේ ආදර්ශ පරාමිතීන් සහ අනුක්රමික ලොග් කරන්න
171 if batch_idx.is_last:
172 tracker.add('model', self.model)අනුක්රමිකඉවත්
174 self.optimizer.zero_grad()ලුහුබැඳඇති ප්රමිතික සුරකින්න
177 tracker.save()179 @torch.no_grad()
180 def sample(self):පුරවාඇති දත්ත සඳහා හිස් ටෙන්සර් [PAD]
.
186 data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)විමසීම්එකින් එක එකතු කරන්න
188 for i, p in enumerate(self.prompt):ටෝකන්දර්ශක ලබා ගන්න
190 d = self.text.text_to_i(p)ටෙන්සරයටඑකතු කරන්න
192 s = min(self.seq_len, len(d))
193 data[:s, i] = d[:s]ටෙන්සරයවත්මන් උපාංගයට ගෙන යන්න
195 data = data.to(self.device)වෙස්ආදාන සහ ලේබල ලබා ගන්න
198 data, labels = self.mlm(data)ආදර්ශප්රතිදානයන් ලබා ගන්න
200 output, *_ = self.model(data)ජනනයකරන ලද සාම්පල මුද්රණය කරන්න
203 for j in range(data.shape[1]):මුද්රණයෙන්ප්රතිදානය එකතු කරන්න
205 log = []එක්එක් ටෝකනය සඳහා
207 for i in range(len(data)):ලේබලයනොමැති නම් [PAD]
209 if labels[i, j] != self.padding_token:අනාවැකියලබා ගන්න
211 t = output[i, j].argmax().item()එයමුද්රණය කළ හැකි චරිතයක් නම්
213 if t < len(self.text.itos):නිවැරදිඅනාවැකිය
215 if t == labels[i, j]:
216 log.append((self.text.itos[t], Text.value))වැරදිඅනාවැකිය
218 else:
219 log.append((self.text.itos[t], Text.danger))එයමුද්රණය කළ හැකි චරිතයක් නොවේ නම්
221 else:
222 log.append(('*', Text.danger))ලේබලය [PAD]
(නොකැඩූ) නම් මුල් පිටපත මුද්රණය කරන්න.
224 elif data[i, j] < len(self.text.itos):
225 log.append((self.text.itos[data[i, j]], Text.subtle))මුද්රණය
228 logger.log(log) ඇතුළුව [PAD]
සහ ටෝකන ගණන [MASK]
231@option(Configs.n_tokens)
232def n_tokens_mlm(c: Configs):236 return c.text.n_tokens + 2239@option(Configs.transformer)
240def _transformer_configs(c: Configs):අපගේ වින්යාසගත කළ හැකි ට්රාන්ස්ෆෝමර් ක්රියාත්මක කිරීම භාවිතා කරමු
247 conf = TransformerConfigs()කාවැද්දීම්සහ පිවිසුම් උත්පාදනය සඳහා වචන මාලාව ප්රමාණ සකසන්න
249 conf.n_src_vocab = c.n_tokens
250 conf.n_tgt_vocab = c.n_tokensකාවැද්දීමප්රමාණය
252 conf.d_model = c.d_model255 return confවර්ගීකරණආකෘතිය සාදන්න
258@option(Configs.model)
259def _model(c: Configs):263 m = TransformerMLM(encoder=c.transformer.encoder,
264 src_embed=c.transformer.src_embed,
265 generator=c.transformer.generator).to(c.device)
266
267 return m270def main():අත්හදාබැලීම සාදන්න
272 experiment.create(name="mlm")වින්යාසසාදන්න
274 conf = Configs()වින්යාසයන්අභිබවා යන්න
276 experiment.configs(conf, {කණ්ඩායම්ප්රමාණය
278 'batch_size': 64,අනුපිළිවෙලදිග . වේගයෙන් පුහුණු කිරීම සඳහා අපි කෙටි අනුක්රමික දිගක් භාවිතා කරමු. එසේ නොමැතිනම් එය සදහටම පුහුණු කිරීමට ගත වේ.
281 'seq_len': 32,1024එපොච් සඳහා දුම්රිය.
284 'epochs': 1024,එක් යුගයකට වරක් පුහුණුව සහ වලංගු කිරීම අතර මාරු වන්න
287 'inner_iterations': 1,ට්රාන්ස්ෆෝමර්වින්යාසයන් (පෙරනිමි ලෙස)
290 'd_model': 128,
291 'transformer.ffn.d_ff': 256,
292 'transformer.n_heads': 8,
293 'transformer.n_layers': 6,නෝම් ප්රශස්තකරණය භාවිතා කරන්න
296 'optimizer.optimizer': 'Noam',
297 'optimizer.learning_rate': 1.,
298 })ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න
301 experiment.add_pytorch_models({'model': conf.model})අත්හදාබැලීම ආරම්භ කරන්න
304 with experiment.start():පුහුණුධාවනය
306 conf.run()310if __name__ == '__main__':
311 main()