11import os
12import random
13from pathlib import Path
14
15import PIL
16import numpy as np
17import torch
18from PIL import Image
19
20from labml import monit
21from labml.logger import inspect
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
24from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
25from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel28def set_seed(seed: int):32 random.seed(seed)
33 np.random.seed(seed)
34 torch.manual_seed(seed)
35 torch.cuda.manual_seed_all(seed)LatentDiffusion
ආකෘතිය38def load_model(path: Path = None) -> LatentDiffusion:ස්වයංක්රීය එන්කෝඩරය ආරම්භ කරන්න
44 with monit.section('Initialize autoencoder'):
45 encoder = Encoder(z_channels=4,
46 in_channels=3,
47 channels=128,
48 channel_multipliers=[1, 2, 4, 4],
49 n_resnet_blocks=2)
50
51 decoder = Decoder(out_channels=3,
52 z_channels=4,
53 channels=128,
54 channel_multipliers=[1, 2, 4, 4],
55 n_resnet_blocks=2)
56
57 autoencoder = Autoencoder(emb_channels=4,
58 encoder=encoder,
59 decoder=decoder,
60 z_channels=4)CLIP පෙළ කාවැද්දීම ආරම්භ කරන්න
63 with monit.section('Initialize CLIP Embedder'):
64 clip_text_embedder = CLIPTextEmbedder()යූ-නෙට් ආරම්භ කරන්න
67 with monit.section('Initialize U-Net'):
68 unet_model = UNetModel(in_channels=4,
69 out_channels=4,
70 channels=320,
71 attention_levels=[0, 1, 2],
72 n_res_blocks=2,
73 channel_multipliers=[1, 2, 4, 4],
74 n_heads=8,
75 tf_layers=1,
76 d_cond=768)ගුප්ත විසරණය ආකෘතිය ආරම්භ කරන්න
79 with monit.section('Initialize Latent Diffusion model'):
80 model = LatentDiffusion(linear_start=0.00085,
81 linear_end=0.0120,
82 n_steps=1000,
83 latent_scaling_factor=0.18215,
84
85 autoencoder=autoencoder,
86 clip_embedder=clip_text_embedder,
87 unet_model=unet_model)මුරපොල පූරණය කරන්න
90 with monit.section(f"Loading model from {path}"):
91 checkpoint = torch.load(path, map_location="cpu")ආදර්ශ තත්වය සකසන්න
94 with monit.section('Load state'):
95 missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)ප්රතිදානය නිදොස්කරණය
98 inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
99 _expand=True)102 model.eval()
103 return modelමෙය ගොනුවකින් රූපයක් පටවන අතර පයිටෝච් ටෙන්සරයක් නැවත ලබා දෙයි.
path
යනු රූපයේ මාර්ගයයි106def load_img(path: str):විවෘත රූපය
115 image = Image.open(path).convert("RGB")රූපයේ ප්රමාණය ලබා ගන්න
117 w, h = image.size32 ක බහුයකට වෙනස් කරන්න
119 w = w - w % 32
120 h = h - h % 32
121 image = image.resize((w, h), resample=PIL.Image.LANCZOS)[-1, 1]
සඳහා නොම්මර එකේ සහ සිතියම බවට පරිවර්තනය කරන්න[0, 255]
123 image = np.array(image).astype(np.float32) * (2. / 255.0) - 1හැඩයට සම්ප්රේෂණය කරන්න[batch_size, channels, height, width]
125 image = image[None].transpose(0, 3, 1, 2)පන්දම බවට පරිවර්තනය කරන්න
127 return torch.from_numpy(image)images
හැඩයේ රූප සහිත ටෙන්සරයයි[batch_size, channels, height, width]
dest_path
පින්තූර සුරැකීමට ෆෝල්ඩරයයිprefix
ගොනු නාම වලට එකතු කිරීමට උපසර්ගය වේimg_format
රූප ආකෘතිය වේ130def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):ගමනාන්ත ෆෝල්ඩරය සාදන්න
141 os.makedirs(dest_path, exist_ok=True)සිතියම්ගත රූප[0, 1]
අවකාශයට සහ ක්ලිප්
144 images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)වෙත සම්ප්රේෂණය[batch_size, height, width, channels]
කර අංකයට පරිවර්තනය කරන්න
146 images = images.cpu().permute(0, 2, 3, 1).numpy()පින්තූර සුරකින්න
149 for i, img in enumerate(images):
150 img = Image.fromarray((255. * img).astype(np.uint8))
151 img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)