Denoising විසරණ ව්යංග ආකෘති (DDIM) නියැදීම

මෙය කඩදාසි වලින් DDIM නියැදීම ක්රියාත්මක කරයි Denoising Diffusion Implicit ආකෘති Denoising

16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler

ඩීඩීඅයිඑම් නියැදිකරු

මෙය DiffusionSampler මූලික පන්තිය පුළුල් කරයි.

ඩීඩීපීඑම් සාම්පල රූප භාවිතා කරමින් පියවරෙන් පියවර නියැදීමෙන් ශබ්දය නැවත නැවතත් ඉවත් කිරීමෙන්,

අහඹු ශබ්දය යනු කොතැනද, දිග අනුක්රමයකි, සහ

ඩීඩීඅයිඑම් කඩදාසි වල ඩීඩීපීඑම් වෙතින් සඳහන් වන බව සලකන්න.

26class DDIMSampler(DiffusionSampler):
52    model: LatentDiffusion
  • model ශබ්දය පුරෝකථනය කිරීමේ ආකෘතියයි
  • n_steps DDIM නියැදි පියවර ගණන,
  • ddim_discretize උපුටා ගන්නේ කෙසේද යන්න නියම කරයි. එය එක්කෝuniform හෝ විය හැකියquad .
  • ddim_eta ගණනය කිරීමට භාවිතා වේ. නියැදි ක්රියාවලිය තීරණය කරයි.
54    def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
63        super().__init__(model)

පියවර ගණන,

65        self.n_steps = model.n_steps

ඒකාකාරව බෙදා හැරීමට ගණනය කරන්න

68        if ddim_discretize == 'uniform':
69            c = self.n_steps // n_steps
70            self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1

චතුරස්රාකාර ලෙස බෙදා හැරීමට ගණනය කරන්න

72        elif ddim_discretize == 'quad':
73            self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
74        else:
75            raise NotImplementedError(ddim_discretize)
76
77        with torch.no_grad():

ලබා ගන්න

79            alpha_bar = self.model.alpha_bar

82            self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)

84            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)

86            self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])

91            self.ddim_sigma = (ddim_eta *
92                               ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
93                                (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)

96            self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5

නියැදි ලූප

  • shape ස්වරූපයෙන් ජනනය කරන ලද රූපවල හැඩය[batch_size, channels, height, width]
  • cond කොන්දේසි සහිත කාවැද්දීම් වේ
  • temperature යනු ශබ්දයේ උෂ්ණත්වය (අහඹු ශබ්දය මෙයින් ගුණ කරනු ලැබේ)
  • x_last වේ. සපයා නොමැති නම් අහඹු ශබ්දය භාවිතා කරනු ඇත.
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේ
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේ
  • skip_steps මඟ හැරීමට කාල පියවර ගණන වේ. අපි නියැදීම ආරම්භ කරමු. එවිටx_last.
98    @torch.no_grad()
99    def sample(self,
100               shape: List[int],
101               cond: torch.Tensor,
102               repeat_noise: bool = False,
103               temperature: float = 1.,
104               x_last: Optional[torch.Tensor] = None,
105               uncond_scale: float = 1.,
106               uncond_cond: Optional[torch.Tensor] = None,
107               skip_steps: int = 0,
108               ):

උපාංගය සහ කණ්ඩායම් ප්රමාණය ලබා ගන්න

125        device = self.model.device
126        bs = shape[0]

ලබා ගන්න

129        x = x_last if x_last is not None else torch.randn(shape, device=device)

නියැදි කිරීමට කාල පියවර

132        time_steps = np.flip(self.time_steps)[skip_steps:]
133
134        for i, step in monit.enum('Sample', time_steps):

ලැයිස්තුවේ දර්ශකය

136            index = len(time_steps) - i - 1

පියවර වේලාව

138            ts = x.new_full((bs,), step, dtype=torch.long)

නියැදිය

141            x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index,
142                                            repeat_noise=repeat_noise,
143                                            temperature=temperature,
144                                            uncond_scale=uncond_scale,
145                                            uncond_cond=uncond_cond)

ආපසු

148        return x

නියැදිය

  • x හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
  • c හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, emb_size]
  • t හැඩයෙන් යුක්ත වේ[batch_size]
  • step යනු පූර්ණ සංඛ්යාවක් ලෙස පියවරයි
  • index ලැයිස්තුවේ දර්ශකය වේ
  • repeat_noise කණ්ඩායමේ සියලුම සාම්පල සඳහා ශබ්දය සමාන විය යුතුද යන්න නිශ්චිතව දක්වා ඇත
  • temperature යනු ශබ්දයේ උෂ්ණත්වය (අහඹු ශබ්දය මෙයින් ගුණ කරනු ලැබේ)
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේ
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේ
150    @torch.no_grad()
151    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
152                 repeat_noise: bool = False,
153                 temperature: float = 1.,
154                 uncond_scale: float = 1.,
155                 uncond_cond: Optional[torch.Tensor] = None):

ලබා ගන්න

172        e_t = self.get_eps(x, t, c,
173                           uncond_scale=uncond_scale,
174                           uncond_cond=uncond_cond)

ගණනය කර පුරෝකථනය කර ඇත

177        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
178                                                      temperature=temperature,
179                                                      repeat_noise=repeat_noise)

182        return x_prev, pred_x0, e_t

ලබා දී ඇති නියැදිය

184    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
185                               temperature: float,
186                               repeat_noise: bool):

192        alpha = self.ddim_alpha[index]

194        alpha_prev = self.ddim_alpha_prev[index]

196        sigma = self.ddim_sigma[index]

198        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

සඳහා වත්මන් අනාවැකිය,

202        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)

දිශාව යොමු කරයි

205        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

ශබ්දයක් එකතු නොවේ, විට

208        if sigma == 0.:
209            noise = 0.

කණ්ඩායමේ සියලුම සාම්පල සඳහා එකම ශබ්දය භාවිතා කරන්නේ නම්

211        elif repeat_noise:
212            noise = torch.randn((1, *x.shape[1:]), device=x.device)

එක් එක් නියැදිය සඳහා විවිධ ශබ්ද

214        else:
215            noise = torch.randn(x.shape, device=x.device)

උෂ්ණත්වය අනුව ශබ්දය ගුණ කරන්න

218        noise = noise * temperature

227        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

230        return x_prev, pred_x0

වෙතින් නියැදිය

  • x0 හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
  • index යනු කාල පියවර දර්ශකයයි
  • noise ශබ්දය,
232    @torch.no_grad()
233    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):

අහඹු ශබ්දය, ශබ්දය නිශ්චිතව දක්වා නොමැති නම්

246        if noise is None:
247            noise = torch.randn_like(x0)

වෙතින් නියැදිය

252        return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise

පින්තාරු ලූප

  • x හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
  • cond කොන්දේසි සහිත කාවැද්දීම් වේ
  • t_start සිට ආරම්භ කිරීමට නියැදි පියවර වේ,
  • orig යනු මුල් රූපයයි ගුප්ත පිටුව අපි පැල්ලම් කරන. මෙය සපයා නොමැති නම්, එය රූප පරිවර්තනයට රූපයක් වනු ඇත.
  • mask මුල් රූපය තබා ගැනීම සඳහා වෙස්මුහුණ වේ.
  • orig_noise මුල් රූපයට එකතු කළ යුතු ස්ථාවර ශබ්දය.
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේ
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේ
254    @torch.no_grad()
255    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
256              orig: Optional[torch.Tensor] = None,
257              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
258              uncond_scale: float = 1.,
259              uncond_cond: Optional[torch.Tensor] = None,
260              ):

කණ්ඩායම් ප්රමාණය ලබා ගන්න

276        bs = x.shape[0]

නියැදි කිරීමට කාල පියවර

279        time_steps = np.flip(self.time_steps[:t_start])
280
281        for i, step in monit.enum('Paint', time_steps):

ලැයිස්තුවේ දර්ශකය

283            index = len(time_steps) - i - 1

පියවර වේලාව

285            ts = x.new_full((bs,), step, dtype=torch.long)

නියැදිය

288            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
289                                    uncond_scale=uncond_scale,
290                                    uncond_cond=uncond_cond)

වෙස් ගත් ප්රදේශය මුල් රූපය සමඟ ප්රතිස්ථාපනය කරන්න

293            if orig is not None:

ගුප්ත අවකාශයේ මුල් රූපය සඳහා ලබා ගන්න

295                orig_t = self.q_sample(orig, index, noise=orig_noise)

වෙස්ගත් ප්රදේශය ප්රතිස්ථාපනය කරන්න

297                x = orig_t * mask + x * (1 - mask)

300        return x