මෙය PonderNet කඩදාසි PyTorch ක්රියාත්මක කිරීම: පොන්ඩර් වෙත ඉගෙනීම .
PonderNetආදානය මත පදනම්ව ගණනය කිරීම අනුවර්තනය කරයි. ආදානය මත පදනම්ව පුනරාවර්තන ජාලයක් සඳහා ගත යුතු පියවර ගණන එය වෙනස් කරයි. පොන්ඩර්නෙට් මෙය ඉගෙන ගන්නේ අවසානය සිට අවසානය දක්වා වූ ශ්රේණියේ සම්භවයක් සමඟය.
පොන්ඩර්නෙට්හි ආකෘතියේ පියවර ශ්රිතයක් ඇත
ආදානය කොතැනද, ප්රාන්තය ද , පියවරෙන් පියවර පුරෝකථනය ද, වර්තමාන පියවරේදී නතර කිරීමේ සම්භාවිතාව (නතර කිරීම).
ඕනෑම ස්නායුක ජාලයක් විය හැකිය (උදා: LSTM, MLP, GRU, අවධානය යොමු කිරීමේ ස්ථරය).
පියවරෙන්පියවර නතර කිරීමේ කොන්දේසි විරහිත සම්භාවිතාව නම්,
පෙරකිසිදු පියවරකින් නතර නොවී පියවරෙන් පියවර නතර කිරීමේ සම්භාවිතාව මෙයයි .
අනුමානයඅතරතුර, අපි නතර කිරීමේ සම්භාවිතාව මත පදනම්ව නියැදීමෙන් නතර කර අවසාන නිමැවුම ලෙස අඩක් ස්ථරයේ පුරෝකථනය ලබා ගනිමු.
පුහුණුවඅතරතුර, අපි සියලු ස්ථරවලින් අනාවැකි ලබා ගන්නා අතර ඒවා එක් එක් සඳහා පාඩු ගණනය කරමු. ඉන්පසු එක් එක් ස්ථරයේ නතර වීමේ සම්භාවිතාව මත පදනම්ව අලාභයේ බර තැබූ සාමාන්යය ගන්න .
පියවරශ්රිතය පරිත්යාග කරන ලද උපරිම පියවර ගණනකට අදාළ වේ.
පොන්ඩර්නෙට්හි සමස්ත අලාභය වන්නේ
ඉලක්කය හා අනාවැකිය අතර සාමාන්ය පාඩු කාර්යය වේ.
යනු කුල්බැක් - ලයිබ්ලර් අපසරනයයි.
යනු පරාමිතිකරණය කරන ලද ජ්යාමිතික බෙදා හැරීමයි . සමග කිසිදු සම්බන්ධයක් නැත ; අපි හුදෙක් කඩදාසි ලෙස එම අංකනය ඇලී සිටිති. .
නියාමනයකිරීමේ අලාභය ජාලය පියවර ගැනීම සඳහා නැඹුරුව ඇති අතර සියලු පියවර සඳහා ශුන්ය නොවන සම්භාවිතාවන් දිරිගන්වයි; එනම් ගවේෂණය ප්රවර්ධනය කරයි.
Pority Task පිළිබඳ පොන්ඩර්නෙට් පුහුණු experiment.py
කිරීම සඳහා පුහුණු කේතය මෙන්න.
65from typing import Tuple
66
67import torch
68from torch import nn
69
70from labml_helpers.module import Moduleමෙයපියවර ශ්රිතය ලෙස GRU සෛලයක් භාවිතා කරන සරල ආකෘතියකි.
මෙමආකෘතිය ආදාන දෛශිකයක් වන Parity Task සඳහා වේ n_elems
. දෛශිකයේ සෑම මූලද්රව්යයක්ම එක්කෝ 0
, 1
නැතහොත් -1
ප්රතිදානය යනු සමානාත්මතාවයයි - 1
s ගණන ඔත්තේ නම් සත්ය වන ද්විමය අගයකි සහ වෙනත් ආකාරයකින් අසත්යය.
ආකෘතියේපුරෝකථනය යනු සමානාත්මතාවයේ ලොග් සම්භාවිතාවයි.
73class ParityPonderGRU(Module):n_elems
ආදාන දෛශිකයේ මූලද්රව්ය ගණන n_hidden
GRU හි රාජ්ය දෛශික ප්රමාණය වේ max_steps
පියවර උපරිම සංඛ්යාව 87 def __init__(self, n_elems: int, n_hidden: int, max_steps: int):93 super().__init__()
94
95 self.max_steps = max_steps
96 self.n_hidden = n_hiddenGRU
100 self.gru = nn.GRUCell(n_elems, n_hidden)අපට ආදානය ලෙස සහ ආදානය ලෙස ස්තරයක් භාවිතා කළ හැකි නමුත් සරල බව සඳහා අපි මෙය සමඟ ගියෙමු.
104 self.output_layer = nn.Linear(n_hidden, 1)106 self.lambda_layer = nn.Linear(n_hidden, 1)
107 self.lambda_prob = nn.Sigmoid()ගණනයඇත්තටම අනුමාන කාලය නතර කරන බව එසේ අනුමානය තුළ සකස් කිරීමට විකල්පයක්
109 self.is_halt = Falsex
හැඩයේ ආදානය වේ [batch_size, n_elems]
මෙමtensors හතරක් tuple ප්රතිදානය:
1. හැඩයේ ආතතියෙන් [N, batch_size]
2. හැඩයේ ආතතියක් තුළ [N, batch_size]
- සමානාත්මතාවයේ ලොග් සම්භාවිතාව 3. හැඩය [batch_size]
4. මෙම ගණනය පියවර නතර කරන ලදී [batch_size]
එහිදී හැඩය
111 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:124 batch_size = x.shape[0]අපටආරම්භක තත්වය ලැබේ
127 h = x.new_zeros((x.shape[0], self.n_hidden))
128 h = self.gru(x, h)ගබඩාකිරීමට ලැයිස්තු සහ
131 p = []
132 y = []134 un_halted_prob = h.new_ones((batch_size,))සාම්පලගණනය කිරීම නවතා දමා ඇති නඩත්තු කිරීම සඳහා දෛශිකයක්
137 halted = h.new_zeros((batch_size,))හා එහිදී ගණනය පියවර නතර කරන ලදී
139 p_m = h.new_zeros((batch_size,))
140 y_m = h.new_zeros((batch_size,))පියවර සඳහා අවධාරණය කරන්න
143 for n in range(1, self.max_steps + 1):අවසානපියවර සඳහා නතර කිරීමේ සම්භාවිතාව
145 if n == self.max_steps:
146 lambda_n = h.new_ones(h.shape[0])148 else:
149 lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]151 y_n = self.output_layer(h)[:, 0]154 p_n = un_halted_prob * lambda_nයාවත්කාලීනකරන්න
156 un_halted_prob = un_halted_prob * (1 - lambda_n)නතරකිරීමේ සම්භාවිතාව මත පදනම්ව නතර කරන්න
159 halt = torch.bernoulli(lambda_n) * (1 - halted)එකතුකරන්න සහ
162 p.append(p_n)
163 y.append(y_n)යාවත්කාලීනකිරීම සහ වර්තමාන පියවරේදී නවතා දැමූ දේ මත පදනම්ව
166 p_m = p_m * (1 - halt) + p_n * halt
167 y_m = y_m * (1 - halt) + y_n * haltනතරකරන ලද සාම්පල යාවත්කාලීන කරන්න
170 halted = halted + haltඊළඟතත්වය ලබා ගන්න
172 h = self.gru(x, h)සියලුමසාම්පල නතර කර ඇත්නම් ගණනය කිරීම නවත්වන්න
175 if self.is_halt and halted.sum() == batch_size:
176 break179 return torch.stack(p), torch.stack(y), p_m, y_m182class ReconstructionLoss(Module):loss_func
පාඩු ශ්රිතය වේ 191 def __init__(self, loss_func: nn.Module):195 super().__init__()
196 self.loss_func = loss_funcp
හැඩයේ ආතතියෙන් යුක්ත වේ [N, batch_size]
y_hat
හැඩයේ ආතතියෙන් යුක්ත වේ [N, batch_size, ...]
y
හැඩයේ ඉලක්කයයි [batch_size, ...]
198 def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):එකතුව
206 total_loss = p.new_tensor(0.)නැවතනැවත කරන්න
208 for n in range(p.shape[0]):එක් එක් නියැදිය සහ ඒවායේ මධ්යන්යය සඳහා
210 loss = (p[n] * self.loss_func(y_hat[n], y)).mean()සම්පූර්ණඅලාභයට එකතු කරන්න
212 total_loss = total_loss + loss215 return total_loss
යනු කුල්බැක් - ලයිබ්ලර් අපසරනයයි.
යනු පරාමිතිකරණය කරන ලද ජ්යාමිතික බෙදා හැරීමයි . සමග කිසිදු සම්බන්ධයක් නැත ; අපි හුදෙක් කඩදාසි ලෙස එම අංකනය ඇලී සිටිති. .
විධිමත්කිරීමේ අලාභය ජාලය පියවර ගැනීම කෙරෙහි නැඹුරුවීම සහ සියලු පියවර සඳහා ශුන්ය නොවන සම්භාවිතාවන් දිරිගැන්වීම; එනම් ගවේෂණය ප්රවර්ධනය කරයි.
218class RegularizationLoss(Module):lambda_p
යනු - ජ්යාමිතික ව්යාප්තියේ සාර්ථක සම්භාවිතාව max_steps
ඉහළම ; අපි මෙය පූර්ව ගණනය කිරීමට භාවිතා කරමු 234 def __init__(self, lambda_p: float, max_steps: int = 1_000):239 super().__init__()ගණනයකිරීම සඳහා හිස් දෛශිකය
242 p_g = torch.zeros((max_steps,))244 not_halted = 1.නැවතනැවත කරන්න max_steps
246 for k in range(max_steps):248 p_g[k] = not_halted * lambda_pයාවත්කාලීනකරන්න
250 not_halted = not_halted * (1 - lambda_p)සුරකින්න
253 self.p_g = nn.Parameter(p_g, requires_grad=False)එල්. එල්-අපසරනය අහිමි
256 self.kl_div = nn.KLDivLoss(reduction='batchmean')p
හැඩයේ ආතතියෙන් යුක්ත වේ [N, batch_size]
258 def forward(self, p: torch.Tensor):p
වෙත සම්ප්රේෂණය කරන්න [batch_size, N]
263 p = p.transpose(0, 1)දක්වා ලබා ගැනීමට හා කණ්ඩායම මානයක් හරහා එය පුළුල්
265 p_g = self.p_g[None, :p.shape[1]].expand_as(p)කේඑල්-අපසරනයගණනය කරන්න. Pytorch KL-අපසරනය ක්රියාත්මක කිරීම ලොග් සම්භාවිතාව පිළිගනී.
270 return self.kl_div(p.log(), p_g)