කණ්ඩායම්සාමාන්යකරණය

මෙය සමූහ සාමාන්යකරණය කිරීමේ කඩදාසි PyTorch ක්රියාත්මක කිරීමයි.

කණ්ඩායම් සාමාන්යකරණය ප්රමාණවත් තරම් විශාල කණ්ඩායම් ප්රමාණ සඳහා හොඳින් ක්රියා කරන නමුත් කුඩා කණ්ඩායම් ප්රමාණ සඳහා හොඳින් නොවේ, මන්ද එය කණ්ඩායමට වඩා සාමාන්යකරණය කරයි. උපාංගවල මතක ධාරිතාව නිසා විශාල කණ්ඩායම් ප්රමාණ සහිත විශාල ආකෘති පුහුණු කිරීම කළ නොහැක.

මෙමලිපිය සමූහ සාමාන්යකරණය හඳුන්වා දෙයි, එය කණ්ඩායමක් ලෙස එකට විශේෂාංග සමූහයක් සාමාන්යකරණය කරයි. මෙය පදනම් වී ඇත්තේ SIFT සහ HOG වැනි සම්භාව්ය ලක්ෂණ කණ්ඩායම් අනුව ලක්ෂණ බව නිරීක්ෂණය කිරීම මත ය. විශේෂාංග නාලිකා කණ්ඩායම් වලට බෙදීමට සහ එක් එක් කණ්ඩායම තුළ ඇති සියලුම නාලිකා වෙන වෙනම සාමාන්යකරණය කිරීමට ලිපිය යෝජනා කරයි.

සූත්රගතකිරීම

සියලුමසාමාන්යකරණ ස්ථර පහත දැක්වෙන ගණනය කිරීම් මගින් අර්ථ දැක්විය හැකිය.

කණ්ඩායමනියෝජනය කරන ටෙන්සරය කොහේද , තනි අගයක දර්ශකය වේ. නිදසුනක් ලෙස, එය 2D රූප වන විට කණ්ඩායම, විශේෂාංග නාලිකාව, සිරස් ඛණ්ඩාංක සහ තිරස් ඛණ්ඩාංක තුළ රූපය සුචිගත කිරීම සඳහා 4-d දෛශිකයකි. ඒවා මධ්යන්ය හා සම්මත අපගමනය වේ.

යනු දර්ශකය සඳහා මධ්යන්ය හා සම්මත අපගමනය ගණනය කරනු ලබන දර්ශක සමූහයයි. සියලු දෙනාටම එක හා සමාන වන කට්ටලයේ ප්රමාණය වේ.

කණ්ඩායම් සාමාන්යකරණය, ස්ථර සාමාන්යකරණයසහ අවස්ථාසාමාන්යකරණයසඳහා අර්ථ දැක්වීම වෙනස් වේ.

කණ්ඩායම් සාමාන්යකරණය

එකමවිශේෂාංග නාලිකාව බෙදා ගන්නා අගයන් එකට සාමාන්යකරණය වේ.

ස්ථරය සාමාන්යකරණය

කණ්ඩායමේඑකම නියැදියක අගයන් එකට සාමාන්යකරණය වේ.

උදාහරණයක් සාමාන්යකරණය

එකමනියැදියක සහ එකම විශේෂාංග නාලිකාවේ අගයන් එකට සාමාන්යකරණය වේ.

කණ්ඩායම්සාමාන්යකරණය

කණ්ඩායම් ගණන කොතැනද සහ නාලිකා ගණන වේ.

සමූහසාමාන්යකරණය එකම නියැදියක සහ එකම නාලිකා සමූහයේ අගයන් සාමාන්යකරණය කරයි.

මෙන්න CIFAR තියෙන්නේ 10 උදාහරණයක් සාමාන්යකරණය භාවිතා කරන වර්ගීකරණය ආකෘතිය .

Open In Colab View Run

85import torch
86from torch import nn
87
88from labml_helpers.module import Module

කණ්ඩායම්සාමාන්යකරණය ස්ථරය

91class GroupNorm(Module):
  • groups යනු කණ්ඩායම් ගණන විශේෂාංග වලට බෙදා ඇත
  • channels යනු ආදානයේ ඇති විශේෂාංග ගණන
  • eps සංඛ්යාත්මක ස්ථායිතාව සඳහා භාවිතා වේ
  • affine සාමාන්යකරණය කළ අගය පරිමාණය කර මාරු කළ යුතුද යන්නයි
96    def __init__(self, groups: int, channels: int, *,
97                 eps: float = 1e-5, affine: bool = True):
104        super().__init__()
105
106        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
107        self.groups = groups
108        self.channels = channels
109
110        self.eps = eps
111        self.affine = affine

පරිමාණයසහ මාරුව සඳහා සහ පරාමිතීන් සාදන්න

113        if self.affine:
114            self.scale = nn.Parameter(torch.ones(channels))
115            self.shift = nn.Parameter(torch.zeros(channels))

x හැඩයේ ආතන්ය [batch_size, channels, *] වේ. * ඕනෑම සංඛ්යාවක් (සමහරවිට 0) මානයන් දක්වයි. උදාහරණයක් ලෙස, රූපයක් (2D) සංකෝචනය තුළ මෙය වනු ඇත [batch_size, channels, height, width]

117    def forward(self, x: torch.Tensor):

මුල්හැඩය තබා ගන්න

125        x_shape = x.shape

කණ්ඩායම්ප්රමාණය ලබා ගන්න

127        batch_size = x_shape[0]

විශේෂාංගගණන සමාන බව තහවුරු කර ගැනීම සඳහා සනීපාරක්ෂාව පරීක්ෂා කරන්න

129        assert self.channels == x.shape[1]

නැවතහැඩගස්වා ගන්න [batch_size, groups, n]

132        x = x.view(batch_size, self.groups, -1)

අවසානමානය හරහා මධ්යන්යය ගණනය කරන්න; එනම් එක් එක් නියැදිය සහ නාලිකා කණ්ඩායම සඳහා මාධ්යයන්

136        mean = x.mean(dim=[-1], keepdim=True)

අවසානමානය හරහා වර්ග කළ මධ්යන්යය ගණනය කරන්න; එනම් එක් එක් නියැදිය සහ නාලිකා කණ්ඩායම සඳහා මාධ්යයන්

139        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

එක්එක් නියැදිය සහ විශේෂාංග කණ්ඩායම සඳහා විචලතාව

142        var = mean_x2 - mean ** 2

සාමාන්‍යකරන්න

147        x_norm = (x - mean) / torch.sqrt(var + self.eps)

නාලිකාවඅනුව පරිමාණය සහ මාරුව

151        if self.affine:
152            x_norm = x_norm.view(batch_size, self.channels, -1)
153            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

මුල්පිටපතට නැවත හැඩගස්වා නැවත පැමිණීම

156        return x_norm.view(x_shape)

සරලපරීක්ෂණය

159def _test():
163    from labml.logger import inspect
164
165    x = torch.zeros([2, 6, 2, 4])
166    inspect(x.shape)
167    bn = GroupNorm(2, 6)
168
169    x = bn(x)
170    inspect(x.shape)

174if __name__ == '__main__':
175    _test()