ස්ථාවර විසරණය සඳහා උපයෝගිතා කාර්යයන්

11import os
12import random
13from pathlib import Path
14
15import PIL
16import numpy as np
17import torch
18from PIL import Image
19
20from labml import monit
21from labml.logger import inspect
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
24from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
25from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel

අහඹු බීජ සකසන්න

28def set_seed(seed: int):
32    random.seed(seed)
33    np.random.seed(seed)
34    torch.manual_seed(seed)
35    torch.cuda.manual_seed_all(seed)

පැටවීමේ LatentDiffusion ආකෘතිය

38def load_model(path: Path = None) -> LatentDiffusion:

ස්වයංක්රීය එන්කෝඩරය ආරම්භ කරන්න

44    with monit.section('Initialize autoencoder'):
45        encoder = Encoder(z_channels=4,
46                          in_channels=3,
47                          channels=128,
48                          channel_multipliers=[1, 2, 4, 4],
49                          n_resnet_blocks=2)
50
51        decoder = Decoder(out_channels=3,
52                          z_channels=4,
53                          channels=128,
54                          channel_multipliers=[1, 2, 4, 4],
55                          n_resnet_blocks=2)
56
57        autoencoder = Autoencoder(emb_channels=4,
58                                  encoder=encoder,
59                                  decoder=decoder,
60                                  z_channels=4)

CLIP පෙළ කාවැද්දීම ආරම්භ කරන්න

63    with monit.section('Initialize CLIP Embedder'):
64        clip_text_embedder = CLIPTextEmbedder()

යූ-නෙට් ආරම්භ කරන්න

67    with monit.section('Initialize U-Net'):
68        unet_model = UNetModel(in_channels=4,
69                               out_channels=4,
70                               channels=320,
71                               attention_levels=[0, 1, 2],
72                               n_res_blocks=2,
73                               channel_multipliers=[1, 2, 4, 4],
74                               n_heads=8,
75                               tf_layers=1,
76                               d_cond=768)

ගුප්ත විසරණය ආකෘතිය ආරම්භ කරන්න

79    with monit.section('Initialize Latent Diffusion model'):
80        model = LatentDiffusion(linear_start=0.00085,
81                                linear_end=0.0120,
82                                n_steps=1000,
83                                latent_scaling_factor=0.18215,
84
85                                autoencoder=autoencoder,
86                                clip_embedder=clip_text_embedder,
87                                unet_model=unet_model)

මුරපොල පූරණය කරන්න

90    with monit.section(f"Loading model from {path}"):
91        checkpoint = torch.load(path, map_location="cpu")

ආදර්ශ තත්වය සකසන්න

94    with monit.section('Load state'):
95        missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)

ප්රතිදානය නිදොස්කරණය

98    inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
99            _expand=True)

102    model.eval()
103    return model

රූපයක් පූරණය කරන්න

මෙය ගොනුවකින් රූපයක් පටවන අතර පයිටෝච් ටෙන්සරයක් නැවත ලබා දෙයි.

  • path යනු රූපයේ මාර්ගයයි
106def load_img(path: str):

විවෘත රූපය

115    image = Image.open(path).convert("RGB")

රූපයේ ප්රමාණය ලබා ගන්න

117    w, h = image.size

32 ක බහුයකට වෙනස් කරන්න

119    w = w - w % 32
120    h = h - h % 32
121    image = image.resize((w, h), resample=PIL.Image.LANCZOS)

[-1, 1] සඳහා නොම්මර එකේ සහ සිතියම බවට පරිවර්තනය කරන්න[0, 255]

123    image = np.array(image).astype(np.float32) * (2. / 255.0) - 1

හැඩයට සම්ප්රේෂණය කරන්න[batch_size, channels, height, width]

125    image = image[None].transpose(0, 3, 1, 2)

පන්දම බවට පරිවර්තනය කරන්න

127    return torch.from_numpy(image)

පින්තූර සුරකින්න

  • images හැඩයේ රූප සහිත ටෙන්සරයයි[batch_size, channels, height, width]
  • dest_path පින්තූර සුරැකීමට ෆෝල්ඩරයයි
  • prefix ගොනු නාම වලට එකතු කිරීමට උපසර්ගය වේ
  • img_format රූප ආකෘතිය වේ
130def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):

ගමනාන්ත ෆෝල්ඩරය සාදන්න

141    os.makedirs(dest_path, exist_ok=True)

සිතියම්ගත රූප[0, 1] අවකාශයට සහ ක්ලිප්

144    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)

වෙත සම්ප්රේෂණය[batch_size, height, width, channels] කර අංකයට පරිවර්තනය කරන්න

146    images = images.cpu().permute(0, 2, 3, 1).numpy()

පින්තූර සුරකින්න

149    for i, img in enumerate(images):
150        img = Image.fromarray((255. * img).astype(np.uint8))
151        img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)