විමසුමක් සමඟ ස්ථාවර විසරණය භාවිතා කරමින් තීන්ත ආලේපන රූප

11import argparse
12from pathlib import Path
13from typing import Optional
14
15import torch
16
17from labml import lab, monit
18from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
19from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
20from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
21from labml_nn.diffusion.stable_diffusion.util import load_model, save_images, load_img, set_seed

රූපයේ පින්තාරු කිරීමේ පන්තිය

24class InPaint:
28    model: LatentDiffusion
29    sampler: DiffusionSampler
  • checkpoint_path යනු මුරපොලේ මාර්ගයයි
  • ddim_steps නියැදි පියවර ගණන වේ
  • ddim_eta DDIM නියැදි නියතය
31    def __init__(self, *, checkpoint_path: Path,
32                 ddim_steps: int = 50,
33                 ddim_eta: float = 0.0):
39        self.ddim_steps = ddim_steps
42        self.model = load_model(checkpoint_path)

උපාංගය ලබා ගන්න

44        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

ආකෘතිය උපාංගයට ගෙන යන්න

46        self.model.to(self.device)
49        self.sampler = DDIMSampler(self.model,
50                                   n_steps=ddim_steps,
51                                   ddim_eta=ddim_eta)
  • dest_path ජනනය කරන ලද රූප ගබඩා කිරීමේ මාර්ගයයි
  • orig_img යනු පරිවර්තනය කිරීමට රූපයයි
  • strength මුල් රූපයේ කොපමණ ප්රමාණයක් සංරක්ෂණය නොකළ යුතුද යන්න නියම කරයි
  • batch_size යනු කණ්ඩායමක් තුළ ජනනය කළ යුතු රූප ගණන
  • prompt සමඟ රූප ජනනය කිරීමේ විමසුම
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශ පරිමාණයයි. මෙය භාවිතා වේ
53    @torch.no_grad()
54    def __call__(self, *,
55                 dest_path: str,
56                 orig_img: str,
57                 strength: float,
58                 batch_size: int = 3,
59                 prompt: str,
60                 uncond_scale: float = 5.0,
61                 mask: Optional[torch.Tensor] = None,
62                 ):

විමසුම් කණ්ඩායමක් සාදන්න

73        prompts = batch_size * [prompt]

රූපය පටවන්න

75        orig_image = load_img(orig_img).to(self.device)

ගුප්ත අවකාශයේ රූපය කේතනය කර එහිbatch_size පිටපත් සාදන්න

77        orig = self.model.autoencoder_encode(orig_image).repeat(batch_size, 1, 1, 1)

සපයා නොමැතිmask නම්, රූපයේ පහළ භාගය ආරක්ෂා කර ගැනීම සඳහා අපි නියැදි ආවරණයක් සකස් කරමු

80        if mask is None:
81            mask = torch.zeros_like(orig, device=self.device)
82            mask[:, :, mask.shape[2] // 2:, :] = 1.
83        else:
84            mask = mask.to(self.device)

ශබ්දය මුල් රූපය විසුරුවා හරින්න

86        orig_noise = torch.randn(orig.shape, device=self.device)

මුල් පිටපත විසුරුවා හැරීමට පියවර ගණන ලබා ගන්න

89        assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
90        t_index = int(strength * self.ddim_steps)

AMP වාහන වාත්තු

93        with torch.cuda.amp.autocast():

කොන්දේසි විරහිත පරිමාණය තුළ හිස් විමසීම් සඳහා කාවැද්දීම් ලබා නොගනී (කන්ඩිෂනේෂන් නැත).

95            if uncond_scale != 1.0:
96                un_cond = self.model.get_text_conditioning(batch_size * [""])
97            else:
98                un_cond = None

කඩිනම් කාවැද්දීම් ලබා ගන්න

100            cond = self.model.get_text_conditioning(prompts)

මුල් රූපයට ශබ්දය එක් කරන්න

102            x = self.sampler.q_sample(orig, t_index, noise=orig_noise)

වෙස්ගත් ප්රදේශය ආරක්ෂා කර ගනිමින් is ෝෂාකාරී රූපයෙන් ප්රතිනිර්මාණය කරන්න

104            x = self.sampler.paint(x, cond, t_index,
105                                   orig=orig,
106                                   mask=mask,
107                                   orig_noise=orig_noise,
108                                   uncond_scale=uncond_scale,
109                                   uncond_cond=un_cond)
111            images = self.model.autoencoder_decode(x)

පින්තූර සුරකින්න

114        save_images(images, dest_path, 'paint_')

CLI

117def main():
121    parser = argparse.ArgumentParser()
122
123    parser.add_argument(
124        "--prompt",
125        type=str,
126        nargs="?",
127        default="a painting of a cute monkey playing guitar",
128        help="the prompt to render"
129    )
130
131    parser.add_argument(
132        "--orig-img",
133        type=str,
134        nargs="?",
135        help="path to the input image"
136    )
137
138    parser.add_argument("--batch_size", type=int, default=4, help="batch size", )
139    parser.add_argument("--steps", type=int, default=50, help="number of sampling steps")
140
141    parser.add_argument("--scale", type=float, default=5.0,
142                        help="unconditional guidance scale: "
143                             "eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
144
145    parser.add_argument("--strength", type=float, default=0.75,
146                        help="strength for noise: "
147                             " 1.0 corresponds to full destruction of information in init image")
148
149    opt = parser.parse_args()
150    set_seed(42)
151
152    in_paint = InPaint(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
153                       ddim_steps=opt.steps)
154
155    with monit.section('Generate'):
156        in_paint(dest_path='outputs',
157                 orig_img=opt.orig_img,
158                 strength=opt.strength,
159                 batch_size=opt.batch_size,
160                 prompt=opt.prompt,
161                 uncond_scale=opt.scale)

165if __name__ == "__main__":
166    main()