Denoising විසරණය සම්භාවිතාව ආකෘති (DDPM) නියැදීම

සරල DDPM ක්රියාත්මක කිරීම සඳහා අපගේ DDPM ක්රියාත්මක කිරීම වෙත යොමු වන්න. කාලසටහන් ආදිය සඳහා අපි එකම අංකන භාවිතා කරමු.

16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler

ඩීඩීපීඑම් නියැදි

මෙය DiffusionSampler මූලික පන්තිය පුළුල් කරයි.

පියවරෙන් පියවර නියැදීමෙන් ශබ්දය නැවත නැවතත් ඉවත් කිරීමෙන් ඩීඩීපීඑම් සාම්පල රූප,

26class DDPMSampler(DiffusionSampler):
49    model: LatentDiffusion
  • model ශබ්දය පුරෝකථනය කිරීමේ ආකෘතියයි
51    def __init__(self, model: LatentDiffusion):
55        super().__init__(model)

නියැදි පියවර

58        self.time_steps = np.asarray(list(range(self.n_steps)))
59
60        with torch.no_grad():

62            alpha_bar = self.model.alpha_bar

කාලසටහන

64            beta = self.model.beta

66            alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])

69            self.sqrt_alpha_bar = alpha_bar ** .5

71            self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .5

73            self.sqrt_recip_alpha_bar = alpha_bar ** -.5

75            self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .5

78            variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)

කලම්ප ලඝු-සටහන

80            self.log_var = torch.log(torch.clamp(variance, min=1e-20))

82            self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)

84            self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)

නියැදි ලූප

  • shape ස්වරූපයෙන් ජනනය කරන ලද රූපවල හැඩය[batch_size, channels, height, width]
  • cond කොන්දේසි සහිත කාවැද්දීම් වේ
  • temperature යනු ශබ්දයේ උෂ්ණත්වය (අහඹු ශබ්දය මෙයින් ගුණ කරනු ලැබේ)
  • x_last වේ. සපයා නොමැති නම් අහඹු ශබ්දය භාවිතා කරනු ඇත.
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේ
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේ
  • skip_steps මඟ හැරීමට කාල පියවර ගණන වේ. අපි නියැදීම ආරම්භ කරමු. එවිටx_last.
86    @torch.no_grad()
87    def sample(self,
88               shape: List[int],
89               cond: torch.Tensor,
90               repeat_noise: bool = False,
91               temperature: float = 1.,
92               x_last: Optional[torch.Tensor] = None,
93               uncond_scale: float = 1.,
94               uncond_cond: Optional[torch.Tensor] = None,
95               skip_steps: int = 0,
96               ):

උපාංගය සහ කණ්ඩායම් ප්රමාණය ලබා ගන්න

113        device = self.model.device
114        bs = shape[0]

ලබා ගන්න

117        x = x_last if x_last is not None else torch.randn(shape, device=device)

නියැදි කිරීමට කාල පියවර

120        time_steps = np.flip(self.time_steps)[skip_steps:]

නියැදි ලූපය

123        for step in monit.iterate('Sample', time_steps):

පියවර වේලාව

125            ts = x.new_full((bs,), step, dtype=torch.long)

නියැදිය

128            x, pred_x0, e_t = self.p_sample(x, cond, ts, step,
129                                            repeat_noise=repeat_noise,
130                                            temperature=temperature,
131                                            uncond_scale=uncond_scale,
132                                            uncond_cond=uncond_cond)

ආපසු

135        return x

වෙතින් නියැදිය

  • x හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
  • c හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, emb_size]
  • t හැඩයෙන් යුක්ත වේ[batch_size]
  • step යනු සංඛ්යාංකයක් ලෙස පියවරයි: පුනරාවර්ත_ශබ්දය: කණ්ඩායමේ සියලුම සාම්පල සඳහා ශබ්දය සමාන විය යුතුද යන්න නිශ්චිතව දක්වා ඇත
  • temperature යනු ශබ්දයේ උෂ්ණත්වය (අහඹු ශබ්දය මෙයින් ගුණ කරනු ලැබේ)
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේ
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේ
137    @torch.no_grad()
138    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
139                 repeat_noise: bool = False,
140                 temperature: float = 1.,
141                 uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):

ලබා ගන්න

157        e_t = self.get_eps(x, t, c,
158                           uncond_scale=uncond_scale,
159                           uncond_cond=uncond_cond)

කණ්ඩායම් ප්රමාණය ලබා ගන්න

162        bs = x.shape[0]

165        sqrt_recip_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step])

167        sqrt_recip_m1_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step])

ධාරාව සමඟ ගණනය කරන්න

172        x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t

175        mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])

177        mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])

ගණනය කරන්න

183        mean = mean_x0_coef * x0 + mean_xt_coef * x

185        log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])

(අවසාන පියවර නියැදි ක්රියාවලිය) විට ශබ්දය එකතු නොකරන්න. step එය0 කවදාද යන්න සලකන්න)

189        if step == 0:
190            noise = 0

කණ්ඩායමේ සියලුම සාම්පල සඳහා එකම ශබ්දය භාවිතා කරන්නේ නම්

192        elif repeat_noise:
193            noise = torch.randn((1, *x.shape[1:]))

එක් එක් නියැදිය සඳහා විවිධ ශබ්ද

195        else:
196            noise = torch.randn(x.shape)

උෂ්ණත්වය අනුව ශබ්දය ගුණ කරන්න

199        noise = noise * temperature

වෙතින් නියැදිය,

204        x_prev = mean + (0.5 * log_var).exp() * noise

207        return x_prev, x0, e_t

වෙතින් නියැදිය

  • x0 හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
  • index යනු කාල පියවර දර්ශකයයි
  • noise ශබ්දය,
209    @torch.no_grad()
210    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):

අහඹු ශබ්දය, ශබ්දය නිශ්චිතව දක්වා නොමැති නම්

222        if noise is None:
223            noise = torch.randn_like(x0)

වෙතින් නියැදිය

226        return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise