මෙය PyTorch ක්රියාත්මක කිරීම/නිබන්ධනයකි ඩයිනමික් රවුටින් කැප්සියුල අතර .
කැප්සියුලජාලය යනු ස්නායුක ජාල ගෘහ නිර්මාණ ශිල්පයක් වන අතර එය කැප්සියුල ලෙස ලක්ෂණ කාවැද්දීම සහ ඡන්දය ප්රකාශ කිරීමේ යාන්ත්රණයක් සමඟ ඊළඟ ස්ථරයට ගමන් කරයි.
ආකෘතිවලවෙනත් ක්රියාත්මක කිරීම් වලදී මෙන් නොව, අපි නියැදියක් ඇතුළත් කර ඇත, මන්ද මොඩියුල සමඟ සමහර සංකල්ප තේරුම් ගැනීමට අපහසු බැවිනි. MNIST දත්ත සමුදාය වර්ගීකරණය කිරීම සඳහා කැප්සියුල භාවිතා කරන ආකෘතියක් සඳහා වියුක්ත කේතය මෙයයි
කැප්සියුලනෙට්වර්ක් හි මූලික මොඩියුලවල ක්රියාත්මක කිරීම් මෙම ගොනුව සතුව ඇත.
කඩදාසිසමඟ මා සතුව තිබූ ව්යාකූලතා පැහැදිලි කිරීම සඳහා මම ජින්ඩොන්ග්වාන්ග්/පයිටෝච්-කැප්සුලෙනෙට් භාවිතා කළෙමි.
MNISTදත්ත කට්ටලය පිළිබඳ කැප්සියුල ජාලයක් පුහුණු කිරීම සඳහා සටහන් පොතක් මෙන්න.
33import torch.nn as nn
34import torch.nn.functional as F
35import torch.utils.data
36
37from labml_helpers.module import Moduleමෙයසමීකරණයෙන් ලබා දී ඇති කඩදාසි වලින් ස්කොෂ් ශ්රිතයකි.
එක් එක් ප්රමාණයට වඩා කුඩා දිගක් ඇති කැප්සියුල හැකිලෙන අතර, සියලු කැප්සියුල වල දිග සාමාන්යකරණය කරයි.
40class Squash(Module):55 def __init__(self, epsilon=1e-8):
56 super().__init__()
57 self.epsilon = epsilon හි s
හැඩය [batch_size, n_capsules, n_features]
59 def forward(self, s: torch.Tensor):65 s2 = (s ** 2).sum(dim=-1, keepdims=True)අපිඑය ශුන්ය බවට පත් නොවන බවට වග බලා ගන්න කිරීමට ගණනය විට epsilon එකතු. මෙය ශුන්ය බවට පත් වුවහොත් එය nan
සාරධර්ම ලබා දීම ආරම්භ කරන අතර පුහුණුව අසමත් වේ.
71 return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))කඩදාසිවල විස්තර කර ඇති මාර්ගගත කිරීමේ යාන්ත්රණය මෙයයි. ඔබේ මාදිලිවල බහු රවුටින් ස්ථර භාවිතා කළ හැකිය.
මෙයමෙම ස්තරය සඳහා ගණනය කිරීම සහ ක්රියා පටිපාටිය 1හි විස්තර කර ඇති රවුටින් ඇල්ගොරිතම ඒකාබද්ධ කරයි.
74class Router(Module): in_caps
යනු කැප්සියුල ගණන in_d
වන අතර පහත ස්ථරයෙන් කැප්සියුලයකට ඇති ලක්ෂණ ගණන වේ. out_caps
මෙම ස්තරය සඳහා සමාන වේ. out_d
iterations
යනු කඩදාසි වලින් සංකේතවත් කරන මාර්ගගත කිරීමේ පුනරාවර්තන සංඛ්යාවකි.
85 def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):92 super().__init__()
93 self.in_caps = in_caps
94 self.out_caps = out_caps
95 self.iterations = iterations
96 self.softmax = nn.Softmax(dim=1)
97 self.squash = Squash()මෙයබර අනුකෘතියකි . මෙම ස්ථරයේ එක් එක් කැප්සියුලයට පහළ ස්ථරයේ එක් එක් කැප්සියුලය සිතියම්
ගත කරයි101 self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True) හැඩය u
වේ [batch_size, n_capsules, n_features]
. මේවා පහළ ස්ථරයේ සිට කැප්සියුල වේ.
103 def forward(self, u: torch.Tensor):මෙන්න මෙම ස්ථරයේ කැප්සියුල දර්ශකය සඳහා භාවිතා වන අතර පහත ස්ථරයේ (පෙර) කැප්සියුල දර්ශකය කිරීමට භාවිතා කරයි.
112 u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)ආරම්භකපිවිසුම් යනු කැප්සියුලය සමඟ සම්බන්ධ විය යුතු ලොග් පූර්ව සම්භාවිතාවයි. අපි මේවා ශුන්යයෙන් ආරම්භ කරමු
117 b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)
118
119 v = Noneපුනරාවර්තනයකරන්න
122 for i in range(self.iterations):සොෆ්ට්මැක්ස්රවුටින්
124 c = self.softmax(b)126 s = torch.einsum('bij,bijm->bjm', c, u_hat)128 v = self.squash(s)130 a = torch.einsum('bjm,bijm->bij', v, u_hat)132 b = b + a
133
134 return vඑක්එක් නිමැවුම් කැප්සියුලය සඳහා වෙනම ආන්තික අලාභයක් භාවිතා වන අතර සම්පූර්ණ අලාභය ඒවායේ එකතුවයි. එක් එක් නිමැවුම් කැප්සියුලයේ දිග යනු ආදානයේ පන්තිය පවතින සම්භාවිතාවයි.
එක්එක් නිමැවුම් කැප්සියුලය හෝ පන්තිය සඳහා නැතිවීම,
යනු පන්තිය පවතින අතර වෙනත් ආකාරයකින් නම්. අලාභයේ පළමු සංරචකය වන්නේ පන්තිය නොමැති විට වන අතර දෙවන සංරචකය වන්නේ පන්තිය පවතී නම් වේ. අනාවැකි අන්තයට යාම වළක්වා ගැනීමට මෙම භාවිතා කරයි. කඩදාසි තුළ සිටීමට හා වීමට සකසා ඇත.
පහතට-බර පුහුණු ආරම්භක අදියර තුළ පහත වැටීම සිට සියලු කරල් දිග නතර කිරීම සඳහා භාවිතා වේ.
137class MarginLoss(Module):157 def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
158 super().__init__()
159
160 self.m_negative = m_negative
161 self.m_positive = m_positive
162 self.lambda_ = lambda_
163 self.n_labels = n_labels v
, මෙම squashed ප්රතිදානය කැප්සියුල වේ. මෙය හැඩය ඇත [batch_size, n_labels, n_features]
; එනම්, එක් එක් ලේබලය සඳහා කැප්සියුලයක් ඇත.
labels
ලේබල, සහ හැඩය ඇත [batch_size]
.
165 def forward(self, v: torch.Tensor, labels: torch.Tensor):173 v_norm = torch.sqrt((v ** 2).sum(dim=-1)) labels
හැඩයේ එක් උණුසුම් කේතනය කරන ලද ලේබල [batch_size, n_labels]
177 labels = torch.eye(self.n_labels, device=labels.device)[labels] loss
හැඩය ඇත [batch_size, n_labels]
. අපි සියලු සඳහා ගණනය කිරීම සමාන්තරකරණය කර ඇත.
183 loss = labels * F.relu(self.m_positive - v_norm) + \
184 self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)187 return loss.sum(dim=-1).mean()