CIFAR10 මත දර්ශන ට්රාන්ස්ෆෝමර් (VIT) පුහුණු කරන්න

View Run

13from labml import experiment
14from labml.configs import option
15from labml_nn.experiments.cifar10 import CIFAR10Configs
16from labml_nn.transformers import TransformerConfigs

වින්යාසකිරීම්

සියලුමදත්ත කට්ටල ආශ්රිත වින්යාසයන්, ප්රශස්තකරණය සහ පුහුණු ලූපයක් නිර්වචනය කරන අපි භාවිතා CIFAR10Configs කරමු.

19class Configs(CIFAR10Configs):
29    transformer: TransformerConfigs

පැච්එකක ප්රමාණය

32    patch_size: int = 4

වර්ගීකරණහිසෙහි සැඟවුණු ස්ථරයේ ප්රමාණය

34    n_hidden_classification: int = 2048

කර්තව්යයේපන්ති ගණන

36    n_classes: int = 10

ට්රාන්ස්ෆෝමර්වින්යාස සාදන්න

39@option(Configs.transformer)
40def _transformer():
44    return TransformerConfigs()

ආකෘතියසාදන්න

47@option(Configs.model)
48def _vit(c: Configs):
52    from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
53        PatchEmbeddings

ට්රාන්ස්ෆෝමර් මානකරණ සිට ට්රාන්ස්ෆෝමර් ප්රමාණය

56    d_model = c.transformer.d_model

දර්ශනට්රාන්ස්ෆෝමරයක් සාදන්න

58    return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
59                             PatchEmbeddings(d_model, c.patch_size, 3),
60                             LearnedPositionalEmbeddings(d_model),
61                             ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)
64def main():

අත්හදාබැලීම සාදන්න

66    experiment.create(name='ViT', comment='cifar10')

වින්යාසයන්සාදන්න

68    conf = Configs()

වින්යාසයන්පූරණය කරන්න

70    experiment.configs(conf, {

ප්රශස්තකරණය

72        'optimizer.optimizer': 'Adam',
73        'optimizer.learning_rate': 2.5e-4,

ට්රාන්ස්ෆෝමර්කාවැද්දීමේ ප්රමාණය

76        'transformer.d_model': 512,

ඊපොච්සහ කණ්ඩායම් ප්රමාණය පුහුණු කිරීම

79        'epochs': 32,
80        'train_batch_size': 64,

පුහුණුකිරීම සඳහා CIFAR 10 රූප

83        'train_dataset': 'cifar10_train_augmented',

CIFARවර්ධනය කරන්න එපා 10 වලංගු කිරීම සඳහා රූප

85        'valid_dataset': 'cifar10_valid_no_augment',
86    })

ඉතිරිකිරීම/පැටවීම සඳහා ආකෘතිය සකසන්න

88    experiment.add_pytorch_models({'model': conf.model})

අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න

90    with experiment.start():
91        conf.run()

95if __name__ == '__main__':
96    main()