LLM.INT8 () ප්රමාණකරණයභාවිතා කරමින් GPT-Neox වෙතින් පෙළ ජනනය කරන්නේ කෙසේද යන්න මෙයින් පෙන්වයි.
මේසඳහා 24GB මතකයක් සහිත GPU එකක් අවශ්ය වේ.
15import torch
16from torch import nn
17
18from labml import monit
19from labml_nn.neox.model import LayerGenerator
20from labml_nn.neox.samples.generate import PROMPT, infer
21from labml_nn.neox.utils import get_tokens, print_tokens
22from labml_nn.neox.utils.cache import get_cache25def generate():31 cache = get_cache()
32 cache.set('use_cache', True)උපාංගය
35 device = torch.device('cuda:0')පාවෙන16 හි ස්ථර CPU තුළට පටවන්න. අපි ස්ථර පසුව int8 බවට පරිවර්තනය කරමු, මන්ද ස්ථර GPU වෙත පැටවීමෙන් පසු පියාසර කිරීම CUDA මතක ඛණ්ඩනය වීමට හේතු වේ (3GB පමණ මතකය කැබලි වීම නිසා අහිමි විය හැක).
40 layer_generator = LayerGenerator(is_clone_layers=True,
41 dtype=torch.float16,
42 device=torch.device('cpu'),
43 is_llm_int8=False,
44 )
45 layers = list(layer_generator.load())මෙයCUDA මතක ඛණ්ඩනය අඩු කරයි
48 for layer in monit.iterate('Convert to int8', layers, is_children_silent=True):
49 layer_generator.post_load_prepare(layer,
50 device=device,
51 is_llm_int8=True,
52 llm_int8_threshold=6.0,
53 )
54 layer.to(device)nn.Sequential
ආකෘතිය සාදන්න
57 model = nn.Sequential(*layers)නිදොස්කරණයසඳහා හැඹිලි සහ මුද්රිත මතක සාරාංශය පැහැදිලි කරන්න
60 torch.cuda.empty_cache()
61 print(torch.cuda.memory_summary())ටෝකන්හැඳුනුම්පත් ලබා ගන්න
64 ids = get_tokens(PROMPT)ආකෘතියධාවනය කරන්න. අපි අර්ථ දක්වා ඇති infer
ශ්රිතය භාවිතා කරමු generate.py
68 cache.set('state_ids', (None, 1))
69 with monit.section('Infer'):
70 next_token = infer(model, ids, device)[-1]පුරෝකථනයකළ ටෝකනය එක් කරන්න
73 ids += [next_token]ටෝකන100 ක් පුරෝකථනය කරන්න
76 for i in range(1, 100):හැඹිලිසක්රිය කිරීම් භාවිතා කිරීමට රාජ්යය සකසන්න
78 cache.set('state_ids', (i, i + 1))ඊළඟටෝකනය ලබා ගන්න. පෙර ටෝකන වල යතුර/අගය යුගල හැඹිලි කරන නිසා අපි ආකෘතියට අවසාන ටෝකනය පමණක් පෝෂණය කරන බව සලකන්න.
81 with monit.section('Infer'):
82 next_token = infer(model, [next_token], device)[-1]පුරෝකථනයකළ ටෝකනය එක් කරන්න
84 ids += [next_token]මුද්රණය
86 print_tokens(ids, [ids])90if __name__ == '__main__':
91 generate()