මෙයවේ දත්ත සමුදාය සාදන්න සහ RETRO ආකෘතියසඳහා ළඟම අසල්වැසියන් ලබා ගනී.
කඩදාසිScanN පුස්තකාලය භාවිතා කර ඇති අතර දත්ත සමුදාය සඳහා අපි FaISS පුස්තකාලය භාවිතා කරමු.
16from typing import List, Optional
17
18import faiss
19import numpy as np
20import torch
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset
24from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddingschunk_len
යනු කුට්ටියක දිග (අක්ෂර ගණන) batch_size
ගණනය කිරීමේදී භාවිතා කළ යුතු කණ්ඩායම් ප්රමාණයයි d_emb
FAISS දර්ශකයේ තෝරා ගැනීම සඳහා කාවැද්දීම් ලැයිස්තු වල විශේෂාංග ගණන වේ n_centeroids
දර්ශකයේ ලැයිස්තු ගණන code_size
දර්ශකයේ කේතනය කරන ලද දෛශික ප්රමාණය n_probe
විමර්ශනය කළ යුතු ලැයිස්තු ගණන 27def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
28 code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):දත්තසමුදාය පෙළ ගොනුව පූරණය කරන්න
43 dataset = TextFileDataset(
44 lab.get_data_path() / 'tiny_shakespeare.txt',
45 list,
46 url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')පුහුණුදත්ත ලබා ගන්න (නූලක්)
49 text = dataset.trainපෙළකැබලි වලට බෙදන්න chunk_length
52 chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]එක්එක් කුට්ටි වල ඕෆ්සෙට් ලබා ගන්න
54 chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])කුට්ටිගණන
56 n_chunks = len(chunks)ලබාගැනීම සඳහා බර්ට් ආරම්භ කරන්න
59 bert = BERTChunkEmbeddings(torch.device('cuda:0'))එක්එක් පුනරාවර්තනයේ කුට්ටි batch_size
ගණන සැකසීම මගින් කුට්ටි කාවැද්දීම් ලබා ගන්න
62 chunk_emb = []
63 for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
64 chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())ඒවාතනි ආතනයකට ඒකාබද්ධ කරන්න
66 chunk_emb = torch.cat(chunk_emb, dim=0).numpy()FAISS දර්ශකය සාදන්න
69 quantizer = faiss.IndexFlatL2(d_emb)
70 index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
71 index.nprobe = n_probeකුට්ටම්දර්ශකවල අහඹු නියැදියක් ලබා ගන්න
74 random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)යතුරුගබඩා කිරීම සඳහා දර්ශකය පුහුණු කරන්න
77 with monit.section('Train index'):
78 index.train(chunk_emb[random_sample])ප්රමාණයෙන්කාණ්ඩවල දර්ශකයට කුට්ටි එකතු කරන්න 1024
81 for s in monit.iterate('Index', range(0, n_chunks, 1024)):
82 e = min(s + 1024, n_chunks)දර්ශකයටඑකතු කරන්න
84 index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])දර්ශකයසුරකින්න
87 with monit.section('Save'):
88 faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))91class RetroIndex:chunk_len
කුටියේ දිග වේ n_probe
විමර්ශනය කළ යුතු ලැයිස්තු ගණන n_neighbors
ලබා ගැනීමට අසල්වැසියන් සංඛ්යාව n_extra
අපි විමසුම් කුට්ටියේ සමග අතිච්ඡාදනය අසල්වැසියන් ඉවත් කරනු ඇත සිට ලබා ගැනීමට අමතර අසල්වැසියන් සංඛ්යාව වේ exclude_neighbor_span
අතිච්ඡාදනය සඳහා පරීක්ෂා කිරීමේදී වළක්වා ගත යුතු අමතර පෙළ දිග වේ96 def __init__(self, chunk_len: int = 16, n_probe: int = 8,
97 n_neighbors: int = 2, n_extra: int = 2,
98 exclude_neighbor_span: int = 8):108 self.n_neighbors = n_neighbors
109 self.chunk_len = chunk_len
110 self.exclude_neighbor_span = exclude_neighbor_span
111 self.n_extra = n_extraලබාගැනීම සඳහා බර්ට් ආරම්භ කරන්න
114 self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))දත්තසමුදාය පටවන්න
116 with monit.section('Load index'):
117 self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
118 self.index.nprobe = n_probeඅසල්වාසීන්ගේතනතුරු ලබා දී ඇති neighbor_offsets
අතර විමසුම් කුට්ටියේ පිහිටීම වේ offset
.
120 def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):127 return [n for n in neighbor_offsets
128 if n < offset - (self.chunk_len + self.exclude_neighbor_span)
129 or n > offset + (self.chunk_len + self.exclude_neighbor_span)]131 def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):විමසුම්කුට්ටි ලබා ගන්න
137 emb = self.bert(query_chunks).cpu()දත්තගබඩාවෙන් n_neighbors + n_extra
ළඟම අසල්වැසියන් ලබා ගන්න
140 distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)විමසුමකුට්ටියේ ඕෆ්සෙට් අතිච්ඡාදනය කුට්ටි සිදු පෙරහන ලබා දී තිබේ නම්
143 if offsets is not None:
144 neighbor_offsets = [self.filter_neighbors(off, n_off)
145 for off, n_off in zip(offsets, neighbor_offsets)]n_neighbors
පෙරීමෙන් පසු ආසන්නතම දේ ලබා ගන්න
148 neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]151 return neighbor_offsets155if __name__ == '__main__':
156 build_database()