ළඟමඅසල්වැසියා ලබා ගැනීම සඳහා දත්ත සමුදාය

මෙයවේ දත්ත සමුදාය සාදන්න සහ RETRO ආකෘතියසඳහා ළඟම අසල්වැසියන් ලබා ගනී.

කඩදාසිScanN පුස්තකාලය භාවිතා කර ඇති අතර දත්ත සමුදාය සඳහා අපි FaISS පුස්තකාලය භාවිතා කරමු.

16from typing import List, Optional
17
18import faiss
19import numpy as np
20import torch
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset
24from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings

දත්තසමුදාය ගොඩනැගීම

  • chunk_len යනු කුට්ටියක දිග (අක්ෂර ගණන)
  • batch_size ගණනය කිරීමේදී භාවිතා කළ යුතු කණ්ඩායම් ප්රමාණයයි
  • d_emb FAISS දර්ශකයේ තෝරා ගැනීම සඳහා කාවැද්දීම් ලැයිස්තු වල විශේෂාංග ගණන වේ
  • n_centeroids දර්ශකයේ ලැයිස්තු ගණන
  • code_size දර්ශකයේ කේතනය කරන ලද දෛශික ප්රමාණය
  • n_probe විමර්ශනය කළ යුතු ලැයිස්තු ගණන
  • `n_train'යනු දර්ශකය පුහුණු කිරීම සඳහා යතුරු ගණන වේ
27def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
28                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):

දත්තසමුදාය පෙළ ගොනුව පූරණය කරන්න

43    dataset = TextFileDataset(
44        lab.get_data_path() / 'tiny_shakespeare.txt',
45        list,
46        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

පුහුණුදත්ත ලබා ගන්න (නූලක්)

49    text = dataset.train

පෙළකැබලි වලට බෙදන්න chunk_length

52    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]

එක්එක් කුට්ටි වල ඕෆ්සෙට් ලබා ගන්න

54    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])

කුට්ටිගණන

56    n_chunks = len(chunks)

ලබාගැනීම සඳහා බර්ට් ආරම්භ කරන්න

59    bert = BERTChunkEmbeddings(torch.device('cuda:0'))

එක්එක් පුනරාවර්තනයේ කුට්ටි batch_size ගණන සැකසීම මගින් කුට්ටි කාවැද්දීම් ලබා ගන්න

62    chunk_emb = []
63    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
64        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())

ඒවාතනි ආතනයකට ඒකාබද්ධ කරන්න

66    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()

FAISS දර්ශකය සාදන්න

69    quantizer = faiss.IndexFlatL2(d_emb)
70    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
71    index.nprobe = n_probe

කුට්ටම්දර්ශකවල අහඹු නියැදියක් ලබා ගන්න

74    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)

යතුරුගබඩා කිරීම සඳහා දර්ශකය පුහුණු කරන්න

77    with monit.section('Train index'):
78        index.train(chunk_emb[random_sample])

ප්රමාණයෙන්කාණ්ඩවල දර්ශකයට කුට්ටි එකතු කරන්න 1024

81    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
82        e = min(s + 1024, n_chunks)

දර්ශකයටඑකතු කරන්න

84        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s: e])

දර්ශකයසුරකින්න

87    with monit.section('Save'):
88        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))

ළඟමඅසල්වැසියන් ලබා ගැනීම සඳහා දර්ශකය

91class RetroIndex:
  • chunk_len කුටියේ දිග වේ
  • n_probe විමර්ශනය කළ යුතු ලැයිස්තු ගණන
  • n_neighbors ලබා ගැනීමට අසල්වැසියන් සංඛ්යාව
  • n_extra අපි විමසුම් කුට්ටියේ සමග අතිච්ඡාදනය අසල්වැසියන් ඉවත් කරනු ඇත සිට ලබා ගැනීමට අමතර අසල්වැසියන් සංඛ්යාව වේ
  • exclude_neighbor_span අතිච්ඡාදනය සඳහා පරීක්ෂා කිරීමේදී වළක්වා ගත යුතු අමතර පෙළ දිග වේ
96    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
97                 n_neighbors: int = 2, n_extra: int = 2,
98                 exclude_neighbor_span: int = 8):
108        self.n_neighbors = n_neighbors
109        self.chunk_len = chunk_len
110        self.exclude_neighbor_span = exclude_neighbor_span
111        self.n_extra = n_extra

ලබාගැනීම සඳහා බර්ට් ආරම්භ කරන්න

114        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))

දත්තසමුදාය පටවන්න

116        with monit.section('Load index'):
117            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
118            self.index.nprobe = n_probe

විමසුමසමඟ අතිච්ඡාදනය වන අසල්වැසියන් පෙරහන් කරන්න

අසල්වාසීන්ගේතනතුරු ලබා දී ඇති neighbor_offsets අතර විමසුම් කුට්ටියේ පිහිටීම වේ offset .

120    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
127        return [n for n in neighbor_offsets
128                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
129                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]

ළඟමඅසල්වැසියන් ලබා ගන්න

131    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):

විමසුම්කුට්ටි ලබා ගන්න

137        emb = self.bert(query_chunks).cpu()

දත්තගබඩාවෙන් n_neighbors + n_extra ළඟම අසල්වැසියන් ලබා ගන්න

140        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)

විමසුමකුට්ටියේ ඕෆ්සෙට් අතිච්ඡාදනය කුට්ටි සිදු පෙරහන ලබා දී තිබේ නම්

143        if offsets is not None:
144            neighbor_offsets = [self.filter_neighbors(off, n_off)
145                                for off, n_off in zip(offsets, neighbor_offsets)]

n_neighbors පෙරීමෙන් පසු ආසන්නතම දේ ලබා ගන්න

148        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]

151        return neighbor_offsets

155if __name__ == '__main__':
156    build_database()