සරල DDPM ක්රියාත්මක කිරීම සඳහා අපගේ DDPM ක්රියාත්මක කිරීම වෙත යොමු වන්න. කාලසටහන් ආදිය සඳහා අපි එකම අංකන භාවිතා කරමු.
16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSamplerමෙය DiffusionSampler
මූලික පන්තිය පුළුල් කරයි.
පියවරෙන් පියවර නියැදීමෙන් ශබ්දය නැවත නැවතත් ඉවත් කිරීමෙන් ඩීඩීපීඑම් සාම්පල රූප,
26class DDPMSampler(DiffusionSampler):49 model: LatentDiffusionmodel
ශබ්දය පුරෝකථනය කිරීමේ ආකෘතියයි51 def __init__(self, model: LatentDiffusion):55 super().__init__(model)නියැදි පියවර
58 self.time_steps = np.asarray(list(range(self.n_steps)))
59
60 with torch.no_grad():62 alpha_bar = self.model.alpha_barකාලසටහන
64 beta = self.model.beta66 alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])69 self.sqrt_alpha_bar = alpha_bar ** .571 self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .573 self.sqrt_recip_alpha_bar = alpha_bar ** -.575 self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .578 variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)කලම්ප ලඝු-සටහන
80 self.log_var = torch.log(torch.clamp(variance, min=1e-20))82 self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)84 self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)shape
ස්වරූපයෙන් ජනනය කරන ලද රූපවල හැඩය[batch_size, channels, height, width]
cond
කොන්දේසි සහිත කාවැද්දීම් වේtemperature
යනු ශබ්දයේ උෂ්ණත්වය (අහඹු ශබ්දය මෙයින් ගුණ කරනු ලැබේ)x_last
වේ. සපයා නොමැති නම් අහඹු ශබ්දය භාවිතා කරනු ඇත.uncond_scale
යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේuncond_cond
හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේskip_steps
මඟ හැරීමට කාල පියවර ගණන වේ. අපි නියැදීම ආරම්භ කරමු. එවිටx_last
ය.86 @torch.no_grad()
87 def sample(self,
88 shape: List[int],
89 cond: torch.Tensor,
90 repeat_noise: bool = False,
91 temperature: float = 1.,
92 x_last: Optional[torch.Tensor] = None,
93 uncond_scale: float = 1.,
94 uncond_cond: Optional[torch.Tensor] = None,
95 skip_steps: int = 0,
96 ):උපාංගය සහ කණ්ඩායම් ප්රමාණය ලබා ගන්න
113 device = self.model.device
114 bs = shape[0]ලබා ගන්න
117 x = x_last if x_last is not None else torch.randn(shape, device=device)නියැදි කිරීමට කාල පියවර
120 time_steps = np.flip(self.time_steps)[skip_steps:]නියැදි ලූපය
123 for step in monit.iterate('Sample', time_steps):පියවර වේලාව
125 ts = x.new_full((bs,), step, dtype=torch.long)නියැදිය
128 x, pred_x0, e_t = self.p_sample(x, cond, ts, step,
129 repeat_noise=repeat_noise,
130 temperature=temperature,
131 uncond_scale=uncond_scale,
132 uncond_cond=uncond_cond)ආපසු
135 return xx
හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
c
හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, emb_size]
t
හැඩයෙන් යුක්ත වේ[batch_size]
step
යනු සංඛ්යාංකයක් ලෙස පියවරයි: පුනරාවර්ත_ශබ්දය: කණ්ඩායමේ සියලුම සාම්පල සඳහා ශබ්දය සමාන විය යුතුද යන්න නිශ්චිතව දක්වා ඇතtemperature
යනු ශබ්දයේ උෂ්ණත්වය (අහඹු ශබ්දය මෙයින් ගුණ කරනු ලැබේ)uncond_scale
යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේuncond_cond
හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේ137 @torch.no_grad()
138 def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
139 repeat_noise: bool = False,
140 temperature: float = 1.,
141 uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):ලබා ගන්න
157 e_t = self.get_eps(x, t, c,
158 uncond_scale=uncond_scale,
159 uncond_cond=uncond_cond)කණ්ඩායම් ප්රමාණය ලබා ගන්න
162 bs = x.shape[0]165 sqrt_recip_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step])167 sqrt_recip_m1_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step])172 x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t175 mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])177 mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])183 mean = mean_x0_coef * x0 + mean_xt_coef * x185 log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])(අවසාන පියවර නියැදි ක්රියාවලිය) විට ශබ්දය එකතු නොකරන්න. step
එය0
කවදාද යන්න සලකන්න)
189 if step == 0:
190 noise = 0කණ්ඩායමේ සියලුම සාම්පල සඳහා එකම ශබ්දය භාවිතා කරන්නේ නම්
192 elif repeat_noise:
193 noise = torch.randn((1, *x.shape[1:]))එක් එක් නියැදිය සඳහා විවිධ ශබ්ද
195 else:
196 noise = torch.randn(x.shape)උෂ්ණත්වය අනුව ශබ්දය ගුණ කරන්න
199 noise = noise * temperature204 x_prev = mean + (0.5 * log_var).exp() * noise207 return x_prev, x0, e_t
x0
හැඩයෙන් යුක්ත වේ[batch_size, channels, height, width]
index
යනු කාල පියවර දර්ශකයයිnoise
ශබ්දය,209 @torch.no_grad()
210 def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):අහඹු ශබ්දය, ශබ්දය නිශ්චිතව දක්වා නොමැති නම්
222 if noise is None:
223 noise = torch.randn_like(x0)වෙතින් නියැදිය
226 return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise