බරප්රමිතිකරණය සහ කණ්ඩායම්-නාලිකා සාමාන්යකරණය උත්සාහ කිරීම සඳහා CIFAR10 අත්හදා බැලීම

12import torch.nn as nn
13
14from labml import experiment
15from labml.configs import option
16from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
17from labml_nn.normalization.batch_channel_norm import BatchChannelNorm
18from labml_nn.normalization.weight_standardization.conv2d import Conv2d

CIFA-10වර්ගීකරණය සඳහා VGG ආකෘතිය

මෙය සාමාන්ය VGG විලාසිතාවේ ගෘහ නිර්මාණ ශිල්පයෙන්ලබා ගනී.

21class Model(CIFAR10VGGModel):
28    def conv_block(self, in_channels, out_channels) -> nn.Module:
29        return nn.Sequential(
30            Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
31            BatchChannelNorm(out_channels, 32),
32            nn.ReLU(inplace=True),
33        )
35    def __init__(self):
36        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])

ආකෘතියසාදන්න

39@option(CIFAR10Configs.model)
40def _model(c: CIFAR10Configs):
44    return Model().to(c.device)
47def main():

අත්හදාබැලීම සාදන්න

49    experiment.create(name='cifar10', comment='weight standardization')

වින්යාසයන්සාදන්න

51    conf = CIFAR10Configs()

වින්යාසයන්පූරණය කරන්න

53    experiment.configs(conf, {
54        'optimizer.optimizer': 'Adam',
55        'optimizer.learning_rate': 2.5e-4,
56        'train_batch_size': 64,
57    })

අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න

59    with experiment.start():
60        conf.run()

64if __name__ == '__main__':
65    main()