18from typing import Optional, List
19
20import torch
21
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion

නියැදි ඇල්ගොරිතම සඳහා මූලික පන්තිය

25class DiffusionSampler:
29    model: LatentDiffusion
  • model ශබ්දය පුරෝකථනය කිරීමේ ආකෘතියයි
31    def __init__(self, model: LatentDiffusion):
35        super().__init__()

ආකෘතිය සකසන්න

37        self.model = model

ආකෘතිය පුහුණු කරන ලද පියවර ගණන ලබා ගන්න

39        self.n_steps = model.n_steps

ලබා ගන්න

  • x හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
  • t හැඩයෙන් යුක්ත වේ[batch_size]
  • c හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, emb_size]
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේ
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේ
41    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
42                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):

පරිමාණය විට

55        if uncond_cond is None or uncond_scale == 1.:
56            return self.model(x, t, c)

අනුපිටපත් සහ

59        x_in = torch.cat([x] * 2)
60        t_in = torch.cat([t] * 2)

සංයුක්ත සහ

62        c_in = torch.cat([uncond_cond, c])

ලබා ගන්න

64        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)

ගණනය කරන්න

67        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)

70        return e_t

නියැදි ලූප

  • shape ස්වරූපයෙන් ජනනය කරන ලද රූපවල හැඩය[batch_size, channels, height, width]
  • cond කොන්දේසි සහිත කාවැද්දීම් වේ
  • temperature යනු ශබ්දයේ උෂ්ණත්වය (අහඹු ශබ්දය මෙයින් ගුණ කරනු ලැබේ)
  • x_last වේ. සපයා නොමැති නම් අහඹු ශබ්දය භාවිතා කරනු ඇත.
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේ
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේ
  • skip_steps මඟ හැරීමට කාල පියවර ගණන වේ.
72    def sample(self,
73               shape: List[int],
74               cond: torch.Tensor,
75               repeat_noise: bool = False,
76               temperature: float = 1.,
77               x_last: Optional[torch.Tensor] = None,
78               uncond_scale: float = 1.,
79               uncond_cond: Optional[torch.Tensor] = None,
80               skip_steps: int = 0,
81               ):
95        raise NotImplementedError()

පින්තාරු ලූප

  • x හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
  • cond කොන්දේසි සහිත කාවැද්දීම් වේ
  • t_start සිට ආරම්භ කිරීමට නියැදි පියවර වේ,
  • orig යනු මුල් රූපයයි ගුප්ත පිටුව අපි පැල්ලම් කරන.
  • mask මුල් රූපය තබා ගැනීම සඳහා වෙස්මුහුණ වේ.
  • orig_noise මුල් රූපයට එකතු කළ යුතු ස්ථාවර ශබ්දය.
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේ
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේ
97    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
98              orig: Optional[torch.Tensor] = None,
99              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
100              uncond_scale: float = 1.,
101              uncond_cond: Optional[torch.Tensor] = None,
102              ):
116        raise NotImplementedError()

වෙතින් නියැදිය

  • x0 හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
  • index යනු කාල පියවර දර්ශකයයි
  • noise ශබ්දය,
118    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
126        raise NotImplementedError()