K-nnසෙවීම සඳහා FAISS දර්ශකය සාදන්න

අපටදර්ශකය තැනීමට අවශ්යයි . අපි ගබඩා කර මතකයේ සිතියම් ගත කළ අරා. අපි FAISS භාවිතා කිරීමට ආසන්නතම සොයා. FAISS දර්ශක සහ අපි එය විමසන්න .

15from typing import Optional
16
17import faiss
18import numpy as np
19import torch
20
21from labml import experiment, monit, lab
22from labml.utils.pytorch import get_modules
23from labml_nn.transformers.knn.train_model import Configs

දුම්රිය ආකෘතියෙන්සුරකින ලද අත්හදා බැලීමක් පටවන්න.

26def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):

වින්යාසවස්තුව සාදන්න

32    conf = Configs()

අත්හදාබැලීමේදී භාවිතා කරන අභිරුචි වින්යාසයන් පූරණය කරන්න

34    conf_dict = experiment.load_configs(run_uuid)

අපටආහාර ඉදිරි ස්ථරයට යෙදවුම් ලබා ගත යුතුය,

36    conf_dict['is_save_ff_input'] = True

මෙමඅත්හදා බැලීම ඇගයීමක් පමණි; එනම් කිසිවක් ලුහුබැඳ හෝ ගැලවීම නැත

39    experiment.evaluate()

වින්යාසයන්ආරම්භ කරන්න

41    experiment.configs(conf, conf_dict)

ඉතිරිකිරීම/පැටවීම සඳහා ආකෘති සකසන්න

43    experiment.add_pytorch_models(get_modules(conf))

සිටපැටවීම සඳහා අත්හදා බැලීම නියම කරන්න

45    experiment.load(run_uuid, checkpoint)

අත්හදාබැලීම ආරම්භ කරන්න; මෙය ඇත්ත වශයෙන්ම ආකෘති පටවන විට

48    experiment.start()
49
50    return conf

අරාතුළ ඒවා එකතු කර සුරකින්න

ඔබගේදත්ත කට්ටලයේ ප්රමාණය අනුව මෙම අරා විශාල ඉඩක් (ගිගාබයිට් සිය ගණනක් වුවද) ගන්නා බව සලකන්න.

53def gather_keys(conf: Configs):

මානයන්

62    d_model = conf.transformer.d_model

පුහුණුදත්ත පැටවුම

64    data_loader = conf.trainer.data_loader

සන්දර්භයන්ගණන; එනම් පුහුණු දත්තවල ටෝකන ගණන us ණ එකක්.

සඳහා
67    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

සඳහාඅංකුර අරාව

69    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='w+', shape=(n_keys, d_model))

සඳහාඅංකුර අරාව

71    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int, mode='w+', shape=(n_keys, 1))

එකතු කරන ලද යතුරු ගණන

74    added = 0
75    with torch.no_grad():

දත්තහරහා ලූප

77        for i, batch in monit.enum("Collect data", data_loader, is_children_silent=True):

ඉලක්කගත ලේබල

79            vals = batch[1].view(-1, 1)

ආදානදත්ත ආකෘතියේ උපාංගය වෙත ගෙන ගියේය

81            data = batch[0].to(conf.device)

ආකෘතියධාවනය කරන්න

83            _ = conf.model(data)

ලබාගන්න

85            keys = conf.model.ff_input.view(-1, d_model)

යතුරුසුරකින්න, මතකයේ සිතියම්ගත කළ අංකන අරාව තුළ

87            keys_store[added: added + keys.shape[0]] = keys.cpu()

මතකසිතියම්ගත කළ අංකන අරාව තුළ අගයන් සුරකින්න

89            vals_store[added: added + keys.shape[0]] = vals

එකතුකරන ලද යතුරු සංඛ්යාව වැඩි කිරීම

91            added += keys.shape[0]

FAISSදර්ශකය ගොඩනඟ

ආරම්භකිරීම, වේගවත් සෙවීමසහ මතක අඩිපාර නිබන්ධන අඩුකිරීම FAISS භාවිතය පිළිබඳ වැඩිදුර දැන ගැනීමට FAISS ඔබට උපකාර කරනු ඇත.

94def build_index(conf: Configs, n_centeroids: int = 2048, code_size: int = 64, n_probe: int = 8, n_train: int = 200_000):

මානයන්

104    d_model = conf.transformer.d_model

පුහුණුදත්ත පැටවුම

106    data_loader = conf.trainer.data_loader

සන්දර්භයන්ගණන; එනම් පුහුණු දත්තවල ටෝකන ගණන us ණ එකක්.

සඳහා
109    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

පූර්ණදෛශික ගබඩා නොකරන සම්පීඩනය සමඟ වෙරෙනෝයි සෛල මත පදනම් වූ වේගවත් සෙවුමක් සහිත සූචකයක් සාදන්න.

113    quantizer = faiss.IndexFlatL2(d_model)
114    index = faiss.IndexIVFPQ(quantizer, d_model, n_centeroids, code_size, 8)
115    index.nprobe = n_probe

මතකයසිතියම් ගත කළ අංකිත යතුරු අරාව පූරණය කරන්න

118    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))

දර්ශකයපුහුණු කිරීම සඳහා යතුරු අහඹු නියැදියක් තෝරන්න

121    random_sample = np.random.choice(np.arange(n_keys), size=[min(n_train, n_keys)], replace=False)
122
123    with monit.section('Train index'):

යතුරුගබඩා කිරීම සඳහා දර්ශකය පුහුණු කරන්න

125        index.train(keys_store[random_sample])

දර්ශකයටයතුරු එකතු කරන්න;

128    for s in monit.iterate('Index', range(0, n_keys, 1024)):
129        e = min(s + 1024, n_keys)

131        keys = keys_store[s:e]

133        idx = np.arange(s, e)

දර්ශකයටඑකතු කරන්න

135        index.add_with_ids(keys, idx)
136
137    with monit.section('Save'):

දර්ශකයසුරකින්න

139        faiss.write_index(index, str(lab.get_data_path() / 'faiss.index'))
142def main():

අත්හදාබැලීම පූරණය කරන්න. ආකෘතිය පුහුණුකිරීමෙන් ඔබ uuid ධාවනය සමඟ uuid ධාවනය කරන්න.

145    conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')

ඇගයීම්ප්රකාරයට ආකෘතිය සකසන්න

147    conf.model.eval()

එකතුකරන්න

150    gather_keys(conf)

වේගවත්සෙවීම සඳහා ඒවා දර්ශකයට එක් කරන්න

152    build_index(conf)
153
154
155if __name__ == '__main__':
156    main()