FNet: ෆූරියර් පරිණාමනය සමඟ ටෝකන මිශ්ර කිරීම

මෙයකඩදාසි FNet හි PyTorch ක්රියාත්මක කිරීම: ෆූරියර් ට්රාන්ස්ෆෝම් සමඟ ටෝකන මිශ්ර කිරීම.

මෙමකඩදාසි ස්වයං අවධානය ස්තරය ප්රතිස්ථාපනය කරයි ෆූරියර් දෙකක් ටෝකන මිශ්ර කිරීමට පරිවර්තනය කරයි. මෙය ස්වයං අවධානයට වඩා කාර්යක්ෂම වේ. ස්වයං අවධානයෙන් මෙය භාවිතා කිරීමේ නිරවද්යතාව නැතිවීම GLUE මිණුම් දණ්ඩේ BERT සඳහා 92% ක් පමණ වේ.

ෆූරියර්පරිණාමන දෙකක් සමඟ ටෝකන මිශ්ර කිරීම

අපිFourier සැඟවුණු මානයක් (කාවැද්දීම මානයක්) හා පසුව අනුක්රමය මානයක් ඔස්සේ පරිණාමනය අදාළ වේ.

කාවැද්දීමේ ආදානය කොතැනද, ෆූරියර් පරිණාමනය සඳහා පෙනී සිටින අතර සංකීර්ණ සංඛ්යා වල සැබෑ සංරචකය නියෝජනය කරයි.

PyTorchමත ක්රියාත්මක කිරීම සඳහා මෙය ඉතා සරලයි - කේත 1 පේළියක් පමණි. කඩදාසි යෝජනා කරන්නේ පෙර සැකසූ ඩීඑෆ්ටී න්යාසයක් භාවිතා කිරීම සහ ෆූරියර් පරිවර්තනය ලබා ගැනීම සඳහා අනුකෘති ගුණ කිරීම ය.

AG News වර්ගීකරණය කිරීම සඳහා FNet පදනම් කරගත් ආකෘතියක් භාවිතා කිරීම සඳහා පුහුණු කේතය මෙන්න.

41from typing import Optional
42
43import torch
44from torch import nn

FNet- ටෝකන මිශ්ර කරන්න

මෙමමොඩියුලය සරලව ක්රියාත්මක කරයි

මෙමමොඩියුලයේ ව්යුහය සම්මත අවධානය මොඩියුලයකට සමාන වන අතර එමඟින් අපට එය ප්රතිස්ථාපනය කළ හැකිය.

47class FNetMix(nn.Module):

සාමාන්ය අවධානය යොමු කිරීමේ මොඩියුලය සඳහා විවිධ ටෝකන කාවැද්දීම් සහ වෙස් මුහුණක් සමඟ පෝෂණය කළ හැකිය.

අපිඑකම ක්රියාකාරී අත්සන අනුගමනය කරන අතර එමඟින් අපට එය කෙලින්ම ප්රතිස්ථාපනය කළ හැකිය.

FNetමිශ්ර කිරීම සඳහා, සහ ආවරණ කිරීම කළ නොහැක. query ( සහ key සහ value ) හැඩය වේ [seq_len, batch_size, d_model] .

60    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):

,, සහ සියල්ලම ටෝකන් මිශ්ර කිරීම සඳහා සමාන විය යුතුය

72        assert query is key and key is value

ටෝකන්මිශ්ර කිරීම ආවරණ සඳහා සහය නොදක්වයි. එනම් සියලුම ටෝකන මගින් අනෙකුත් සියලුම ටෝකන කාවැද්දීම් දකිනු ඇත.

74        assert mask is None

පැහැදිලිකම x සඳහා පවරන්න

77        x = query

සැඟවුණු(කාවැද්දීම) මානය ඔස්සේ ෆූරියර් පරිණාමනය යොදන්න

ෆූරියර්පරිණාමයේ ප්රතිදානය සංකීර්ණ සංඛ්යාවල ආතතියකි.

84        fft_hidden = torch.fft.fft(x, dim=2)

අනුක්රමිකමානය ඔස්සේ ෆූරියර් පරිණාමනය යොදන්න

87        fft_seq = torch.fft.fft(fft_hidden, dim=0)

සැබෑසංරචකය ලබා ගන්න

91        return torch.real(fft_seq)