පොන්ඩර්නෙට්: මෙනෙහි කිරීමට ඉගෙනීම

මෙය PonderNet කඩදාසි PyTorch ක්රියාත්මක කිරීම: පොන්ඩර් වෙත ඉගෙනීම .

PonderNetආදානය මත පදනම්ව ගණනය කිරීම අනුවර්තනය කරයි. ආදානය මත පදනම්ව පුනරාවර්තන ජාලයක් සඳහා ගත යුතු පියවර ගණන එය වෙනස් කරයි. පොන්ඩර්නෙට් මෙය ඉගෙන ගන්නේ අවසානය සිට අවසානය දක්වා වූ ශ්රේණියේ සම්භවයක් සමඟය.

පොන්ඩර්නෙට්හි ආකෘතියේ පියවර ශ්රිතයක් ඇත

ආදානය කොතැනද, ප්රාන්තය ද , පියවරෙන් පියවර පුරෝකථනය ද, වර්තමාන පියවරේදී නතර කිරීමේ සම්භාවිතාව (නතර කිරීම).

ඕනෑම ස්නායුක ජාලයක් විය හැකිය (උදා: LSTM, MLP, GRU, අවධානය යොමු කිරීමේ ස්ථරය).

පියවරෙන්පියවර නතර කිරීමේ කොන්දේසි විරහිත සම්භාවිතාව නම්,

පෙරකිසිදු පියවරකින් නතර නොවී පියවරෙන් පියවර නතර කිරීමේ සම්භාවිතාව මෙයයි .

අනුමානයඅතරතුර, අපි නතර කිරීමේ සම්භාවිතාව මත පදනම්ව නියැදීමෙන් නතර කර අවසාන නිමැවුම ලෙස අඩක් ස්ථරයේ පුරෝකථනය ලබා ගනිමු.

පුහුණුවඅතරතුර, අපි සියලු ස්ථරවලින් අනාවැකි ලබා ගන්නා අතර ඒවා එක් එක් සඳහා පාඩු ගණනය කරමු. ඉන්පසු එක් එක් ස්ථරයේ නතර වීමේ සම්භාවිතාව මත පදනම්ව අලාභයේ බර තැබූ සාමාන්යය ගන්න .

පියවරශ්රිතය පරිත්යාග කරන ලද උපරිම පියවර ගණනකට අදාළ වේ.

පොන්ඩර්නෙට්හි සමස්ත අලාභය වන්නේ

ඉලක්කය හා අනාවැකිය අතර සාමාන්ය පාඩු කාර්යය වේ.

යනු කුල්බැක් - ලයිබ්ලර් අපසරනයයි.

යනු පරාමිතිකරණය කරන ලද ජ්යාමිතික බෙදා හැරීමයි . සමග කිසිදු සම්බන්ධයක් නැත ; අපි හුදෙක් කඩදාසි ලෙස එම අංකනය ඇලී සිටිති. .

නියාමනයකිරීමේ අලාභය ජාලය පියවර ගැනීම සඳහා නැඹුරුව ඇති අතර සියලු පියවර සඳහා ශුන්ය නොවන සම්භාවිතාවන් දිරිගන්වයි; එනම් ගවේෂණය ප්රවර්ධනය කරයි.

Pority Task පිළිබඳ පොන්ඩර්නෙට් පුහුණු experiment.py කිරීම සඳහා පුහුණු කේතය මෙන්න.

View Run

65from typing import Tuple
66
67import torch
68from torch import nn
69
70from labml_helpers.module import Module

සමානාත්මතාකාර්ය සඳහා GRU සමඟ පොන්ඩර්නෙට්

මෙයපියවර ශ්රිතය ලෙස GRU සෛලයක් භාවිතා කරන සරල ආකෘතියකි.

මෙමආකෘතිය ආදාන දෛශිකයක් වන Parity Task සඳහා වේ n_elems . දෛශිකයේ සෑම මූලද්රව්යයක්ම එක්කෝ 0 , 1 නැතහොත් -1 ප්රතිදානය යනු සමානාත්මතාවයයි - 1 s ගණන ඔත්තේ නම් සත්ය වන ද්විමය අගයකි සහ වෙනත් ආකාරයකින් අසත්යය.

ආකෘතියේපුරෝකථනය යනු සමානාත්මතාවයේ ලොග් සම්භාවිතාවයි.

73class ParityPonderGRU(Module):
  • n_elems ආදාන දෛශිකයේ මූලද්රව්ය ගණන
  • n_hidden GRU හි රාජ්ය දෛශික ප්රමාණය වේ
  • max_steps පියවර උපරිම සංඛ්යාව
87    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
93        super().__init__()
94
95        self.max_steps = max_steps
96        self.n_hidden = n_hidden

GRU

100        self.gru = nn.GRUCell(n_elems, n_hidden)

අපට ආදානය ලෙස සහ ආදානය ලෙස ස්තරයක් භාවිතා කළ හැකි නමුත් සරල බව සඳහා අපි මෙය සමඟ ගියෙමු.

104        self.output_layer = nn.Linear(n_hidden, 1)

106        self.lambda_layer = nn.Linear(n_hidden, 1)
107        self.lambda_prob = nn.Sigmoid()

ගණනයඇත්තටම අනුමාන කාලය නතර කරන බව එසේ අනුමානය තුළ සකස් කිරීමට විකල්පයක්

109        self.is_halt = False
  • x හැඩයේ ආදානය වේ [batch_size, n_elems]

මෙමtensors හතරක් tuple ප්රතිදානය:

1. හැඩයේ ආතතියෙන් [N, batch_size] 2. හැඩයේ ආතතියක් තුළ [N, batch_size] - සමානාත්මතාවයේ ලොග් සම්භාවිතාව 3. හැඩය [batch_size] 4. මෙම ගණනය පියවර නතර කරන ලදී [batch_size] එහිදී හැඩය

111    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

124        batch_size = x.shape[0]

අපටආරම්භක තත්වය ලැබේ

127        h = x.new_zeros((x.shape[0], self.n_hidden))
128        h = self.gru(x, h)

ගබඩාකිරීමට ලැයිස්තු සහ

131        p = []
132        y = []

134        un_halted_prob = h.new_ones((batch_size,))

සාම්පලගණනය කිරීම නවතා දමා ඇති නඩත්තු කිරීම සඳහා දෛශිකයක්

137        halted = h.new_zeros((batch_size,))

හා එහිදී ගණනය පියවර නතර කරන ලදී

139        p_m = h.new_zeros((batch_size,))
140        y_m = h.new_zeros((batch_size,))

පියවර සඳහා අවධාරණය කරන්න

143        for n in range(1, self.max_steps + 1):

අවසානපියවර සඳහා නතර කිරීමේ සම්භාවිතාව

145            if n == self.max_steps:
146                lambda_n = h.new_ones(h.shape[0])

148            else:
149                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]

151            y_n = self.output_layer(h)[:, 0]

154            p_n = un_halted_prob * lambda_n

යාවත්කාලීනකරන්න

156            un_halted_prob = un_halted_prob * (1 - lambda_n)

නතරකිරීමේ සම්භාවිතාව මත පදනම්ව නතර කරන්න

159            halt = torch.bernoulli(lambda_n) * (1 - halted)

එකතුකරන්න සහ

162            p.append(p_n)
163            y.append(y_n)

යාවත්කාලීනකිරීම සහ වර්තමාන පියවරේදී නවතා දැමූ දේ මත පදනම්ව

166            p_m = p_m * (1 - halt) + p_n * halt
167            y_m = y_m * (1 - halt) + y_n * halt

නතරකරන ලද සාම්පල යාවත්කාලීන කරන්න

170            halted = halted + halt

ඊළඟතත්වය ලබා ගන්න

172            h = self.gru(x, h)

සියලුමසාම්පල නතර කර ඇත්නම් ගණනය කිරීම නවත්වන්න

175            if self.is_halt and halted.sum() == batch_size:
176                break

179        return torch.stack(p), torch.stack(y), p_m, y_m

ප්රතිසංස්කරණඅලාභය

ඉලක්කය හා අනාවැකිය අතර සාමාන්ය පාඩු කාර්යය වේ.

182class ReconstructionLoss(Module):
  • loss_func පාඩු ශ්රිතය වේ
191    def __init__(self, loss_func: nn.Module):
195        super().__init__()
196        self.loss_func = loss_func
  • p හැඩයේ ආතතියෙන් යුක්ත වේ [N, batch_size]
  • y_hat හැඩයේ ආතතියෙන් යුක්ත වේ [N, batch_size, ...]
  • y හැඩයේ ඉලක්කයයි [batch_size, ...]
198    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):

එකතුව

206        total_loss = p.new_tensor(0.)

නැවතනැවත කරන්න

208        for n in range(p.shape[0]):

එක් එක් නියැදිය සහ ඒවායේ මධ්යන්යය සඳහා

210            loss = (p[n] * self.loss_func(y_hat[n], y)).mean()

සම්පූර්ණඅලාභයට එකතු කරන්න

212            total_loss = total_loss + loss

215        return total_loss

අලාභයවිධිමත් කිරීම

යනු කුල්බැක් - ලයිබ්ලර් අපසරනයයි.

යනු පරාමිතිකරණය කරන ලද ජ්යාමිතික බෙදා හැරීමයි . සමග කිසිදු සම්බන්ධයක් නැත ; අපි හුදෙක් කඩදාසි ලෙස එම අංකනය ඇලී සිටිති. .

විධිමත්කිරීමේ අලාභය ජාලය පියවර ගැනීම කෙරෙහි නැඹුරුවීම සහ සියලු පියවර සඳහා ශුන්ය නොවන සම්භාවිතාවන් දිරිගැන්වීම; එනම් ගවේෂණය ප්රවර්ධනය කරයි.

218class RegularizationLoss(Module):
  • lambda_p යනු - ජ්යාමිතික ව්යාප්තියේ සාර්ථක සම්භාවිතාව
  • max_steps ඉහළම ; අපි මෙය පූර්ව ගණනය කිරීමට භාවිතා කරමු
234    def __init__(self, lambda_p: float, max_steps: int = 1_000):
239        super().__init__()

ගණනයකිරීම සඳහා හිස් දෛශිකය

242        p_g = torch.zeros((max_steps,))

244        not_halted = 1.

නැවතනැවත කරන්න max_steps

246        for k in range(max_steps):

248            p_g[k] = not_halted * lambda_p

යාවත්කාලීනකරන්න

250            not_halted = not_halted * (1 - lambda_p)

සුරකින්න

253        self.p_g = nn.Parameter(p_g, requires_grad=False)

එල්. එල්-අපසරනය අහිමි

256        self.kl_div = nn.KLDivLoss(reduction='batchmean')
  • p හැඩයේ ආතතියෙන් යුක්ත වේ [N, batch_size]
258    def forward(self, p: torch.Tensor):

p වෙත සම්ප්රේෂණය කරන්න [batch_size, N]

263        p = p.transpose(0, 1)

දක්වා ලබා ගැනීමට හා කණ්ඩායම මානයක් හරහා එය පුළුල්

265        p_g = self.p_g[None, :p.shape[1]].expand_as(p)

කේඑල්-අපසරනයගණනය කරන්න. Pytorch KL-අපසරනය ක්රියාත්මක කිරීම ලොග් සම්භාවිතාව පිළිගනී.

270        return self.kl_div(p.log(), p_g)