ස්ථාවර විසරණය යූ-නෙට් සඳහා ට්රාන්ස්ෆෝමර්

මෙය ලබා දෙන යූ-නෙට් හි භාවිතා කරන ට්රාන්ස්ෆෝමර් මොඩියුලය ක්රියාත්මක කරයි

අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.

19from typing import Optional
20
21import torch
22import torch.nn.functional as F
23from torch import nn

අවකාශීය ට්රාන්ස්ෆෝමර්

26class SpatialTransformer(nn.Module):
  • channels විශේෂාංග සිතියමේ නාලිකා ගණන
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • n_layers ට්රාන්ස්ෆෝමර් ස්ථර ගණන
  • d_cond යනු කොන්දේසි සහිත කාවැද්දිවල ප්රමාණයයි
31    def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
38        super().__init__()

ආරම්භක කණ්ඩායම් සාමාන්යකරණය

40        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)

මූලික කැටි ගැසිම

42        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

ට්රාන්ස්ෆෝමර් ස්ථර

45        self.transformer_blocks = nn.ModuleList(
46            [BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond) for _ in range(n_layers)]
47        )

අවසාන කැටි ගැසිම

50        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
  • x හැඩයේ විශේෂාංග සිතියමයි[batch_size, channels, height, width]
  • cond හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]
  • 52    def forward(self, x: torch.Tensor, cond: torch.Tensor):

    හැඩය ලබා ගන්න[batch_size, channels, height, width]

    58        b, c, h, w = x.shape

    අවශේෂ සම්බන්ධතාවය සඳහා

    60        x_in = x

    සාමාන්‍ය කරන්න

    62        x = self.norm(x)

    මූලික කැටි ගැසිම

    64        x = self.proj_in(x)

    සිට සම්ප්රේෂණය කර නැවත හැඩගස්වා[batch_size, channels, height, width] ගන්න[batch_size, height * width, channels]

    67        x = x.permute(0, 2, 3, 1).view(b, h * w, c)

    ට්රාන්ස්ෆෝමර් ස්ථර යොදන්න

    69        for block in self.transformer_blocks:
    70            x = block(x, cond)

    නැවත හැඩගස්වා සිට සම්ප්රේෂණය[batch_size, height * width, channels] කරන්න[batch_size, channels, height, width]

    73        x = x.view(b, h, w, c).permute(0, 3, 1, 2)

    අවසාන කැටි ගැසිම

    75        x = self.proj_out(x)

    අවශේෂ එකතු කරන්න

    77        return x + x_in

    ට්රාන්ස්ෆෝමර් ස්ථරය

    80class BasicTransformerBlock(nn.Module):
    • d_model ආදාන කාවැද්දීමේ ප්රමාණයයි
    • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
    • d_head අවධානය යොමු හිසෙහි ප්රමාණයයි
    • d_cond යනු කොන්දේසි සහිත කාවැද්දීම් වල ප්රමාණයයි
    85    def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
    92        super().__init__()

    ස්වයං අවධානය ස්ථරය හා පෙර-සම්මත ස්ථරය

    94        self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
    95        self.norm1 = nn.LayerNorm(d_model)

    හරස් අවධානය ස්ථරය සහ පෙර-සම්මත ස්ථරය

    97        self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
    98        self.norm2 = nn.LayerNorm(d_model)

    Feed-ඉදිරි ජාලය සහ පෙර-සම්මත ස්ථරය

    100        self.ff = FeedForward(d_model)
    101        self.norm3 = nn.LayerNorm(d_model)
    • x හැඩයේ ආදාන කාවැද්දීම් වේ[batch_size, height * width, d_model]
  • cond හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]
  • 103    def forward(self, x: torch.Tensor, cond: torch.Tensor):

    ස්වයං අවධානය

    109        x = self.attn1(self.norm1(x)) + x

    කන්ඩිෂනේෂන් සමඟ හරස් අවධානය

    111        x = self.attn2(self.norm2(x), cond=cond) + x

    Feed-ඉදිරි ජාලය

    113        x = self.ff(self.norm3(x)) + x

    115        return x

    හරස් අවධානය ස්ථරය

    කොන්දේසි සහිත කාවැද්දීම් නිශ්චිතව දක්වා නොමැති විට මෙය ස්වයං අවධානයට යොමු වේ.

    118class CrossAttention(nn.Module):
    125    use_flash_attention: bool = False
    • d_model ආදාන කාවැද්දීමේ ප්රමාණයයි
    • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
    • d_head අවධානය යොමු හිසෙහි ප්රමාණයයි
    • d_cond යනු කොන්දේසි සහිත කාවැද්දීම් වල ප්රමාණයයි
    • is_inplace මතකය ඉතිරි කර ගැනීම සඳහා අවධානය softmax ගණනය inplace ඉටු කිරීමට යන්න නියම
    127    def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
    136        super().__init__()
    137
    138        self.is_inplace = is_inplace
    139        self.n_heads = n_heads
    140        self.d_head = d_head

    අවධානය පරිමාණ සාධකය

    143        self.scale = d_head ** -0.5

    විමසුම්, යතුර සහ අගය සිතියම්

    146        d_attn = d_head * n_heads
    147        self.to_q = nn.Linear(d_model, d_attn, bias=False)
    148        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
    149        self.to_v = nn.Linear(d_cond, d_attn, bias=False)

    අවසාන රේඛීය ස්ථරය

    152        self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))

    සැකසුම ෆ්ලෑෂ් අවධානය. ෆ්ලෑෂ් අවධානය භාවිතා කරනු ලබන්නේ එය ස්ථාපනය කර ඇත්නම් සහCrossAttention.use_flash_attention එය සකසා ඇත්නම් පමණිTrue .

    157        try:

    ක්ලෝනකරණය කිරීමෙන් ඔබට ෆ්ලෑෂ් අවධානය ස්ථාපනය කළ හැකිය Github repo, https://github.com/HazyResearch/flash-attention ඉන්පසු ධාවනයpython setup.py install

    161            from flash_attn.flash_attention import FlashAttention
    162            self.flash = FlashAttention()

    පරිමාණ තිත් නිෂ්පාදන අවධානය සඳහා පරිමාණය සකසන්න.

    164            self.flash.softmax_scale = self.scale

    එය ස්ථාපනය කර නොමැතිNone නම් සකසන්න

    166        except ImportError:
    167            self.flash = None
    • x හැඩයේ ආදාන කාවැද්දීම් වේ[batch_size, height * width, d_model]
  • cond හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]
  • 169    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):

    None අපි ස්වයං අවධානය යොමුcond කරන්නේ නම්

    176        has_cond = cond is not None
    177        if not has_cond:
    178            cond = x

    විමසුම, යතුර සහ අගය දෛශික ලබා ගන්න

    181        q = self.to_q(x)
    182        k = self.to_k(cond)
    183        v = self.to_v(cond)

    ෆ්ලෑෂ් අවධානය ලබා ගත හැකි නම් සහ හිස ප්රමාණය අඩු හෝ සමාන නම් භාවිතා කරන්න128

    186        if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
    187            return self.flash_attention(q, k, v)

    එසේ නොමැති නම්, සාමාන්ය අවධානයට වැටීම

    189        else:
    190            return self.normal_attention(q, k, v)

    ෆ්ලෑෂ් අවධානය

    • q හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
    • k හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
    • v හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
    192    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

    අනුක්රමික අක්ෂය ඔස්සේ කණ්ඩායම් ප්රමාණය සහ මූලද්රව්ය ගණන ලබා ගන්න (width * height )

    202        batch_size, seq_len, _ = q.shape

    ෆ්ලෑෂ් අවධානය සඳහාv දෛශිකq k , හැඩයේ තනි ටෙන්සරයක් ලබා ගැනීමට[batch_size, seq_len, 3, n_heads * d_head]

    206        qkv = torch.stack((q, k, v), dim=2)

    හිස් බෙදන්න

    208        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)

    ෆ්ලෑෂ් අවධානය හිස් ප්රමාණ සඳහා ක්රියා කරයි32 128 ,64 සහ, එබැවින් මෙම ප්රමාණයට සරිලන පරිදි හිස් පෑඩ් කළ යුතුය.

    212        if self.d_head <= 32:
    213            pad = 32 - self.d_head
    214        elif self.d_head <= 64:
    215            pad = 64 - self.d_head
    216        elif self.d_head <= 128:
    217            pad = 128 - self.d_head
    218        else:
    219            raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')

    හිස් පෑඩ් කරන්න

    222        if pad:
    223            qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)

    අවධානය ගණනය කරන්න මෙය හැඩයේ ආතතිකයක් ලබා දෙයි[batch_size, seq_len, n_heads, d_padded]

    228        out, _ = self.flash(qkv)

    අමතර හිස ප්රමාණය කපා

    230        out = out[:, :, :, :self.d_head]

    නැවත හැඩගස්වන්න[batch_size, seq_len, n_heads * d_head]

    232        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)

    රේඛීය ස්ථරයක්[batch_size, height * width, d_model] සමඟ සිතියම

    235        return self.to_out(out)

    සාමාන්ය අවධානය

    • q හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
    • k හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
    • v හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
    237    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

    හැඩයේ හිස් වලට බෙදන්න[batch_size, seq_len, n_heads, d_head]

    247        q = q.view(*q.shape[:2], self.n_heads, -1)
    248        k = k.view(*k.shape[:2], self.n_heads, -1)
    249        v = v.view(*v.shape[:2], self.n_heads, -1)

    අවධානය ගණනය කරන්න

    252        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale

    සොෆ්ට්මැක්ස් ගණනය කරන්න

    256        if self.is_inplace:
    257            half = attn.shape[0] // 2
    258            attn[half:] = attn[half:].softmax(dim=-1)
    259            attn[:half] = attn[:half].softmax(dim=-1)
    260        else:
    261            attn = attn.softmax(dim=-1)

    අවධානය ප්රතිදානය ගණනය

    265        out = torch.einsum('bhij,bjhd->bihd', attn, v)

    නැවත හැඩගස්වන්න[batch_size, height * width, n_heads * d_head]

    267        out = out.reshape(*out.shape[:2], -1)

    රේඛීය ස්ථරයක්[batch_size, height * width, d_model] සමඟ සිතියම

    269        return self.to_out(out)

    Feed-ඉදිරි ජාලය

    272class FeedForward(nn.Module):
    • d_model ආදාන කාවැද්දීමේ ප්රමාණයයි
    • d_mult සැඟවුණු ස්ථර ප්රමාණය සඳහා බහුකාර්ය සාධකයකි
    277    def __init__(self, d_model: int, d_mult: int = 4):
    282        super().__init__()
    283        self.net = nn.Sequential(
    284            GeGLU(d_model, d_model * d_mult),
    285            nn.Dropout(0.),
    286            nn.Linear(d_model * d_mult, d_model)
    287        )
    289    def forward(self, x: torch.Tensor):
    290        return self.net(x)

    Glu සක්රිය කිරීම

    293class GeGLU(nn.Module):
    300    def __init__(self, d_in: int, d_out: int):
    301        super().__init__()

    ඒකාබද්ධ රේඛීය ප්රක්ෂේපණ සහ

    303        self.proj = nn.Linear(d_in, d_out * 2)
    305    def forward(self, x: torch.Tensor):

    ලබා ගන්න

    307        x, gate = self.proj(x).chunk(2, dim=-1)

    309        return x * F.gelu(gate)