11from typing import Optional, List
12
13import faiss
14import numpy as np
15import torch
16
17from labml import monit, lab
18from labml.logger import inspect
19from labml_nn.transformers.knn.train_model import Configs22def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray, vals_store: np.ndarray, n_tokens: int):ප්රති. ල නැවත හැඩගැස්වීම සඳහා විමසුම් වල හැඩය සුරකින්න
31 queries_shape = queries.shapeවිමසුම්වල batch
සහ sequence
මානයන් සමතලා කරන්න
34 queries = queries.view(-1, queries_shape[-1]) ඒ අතර ළඟම අසල්වැසියන් 10 දෙනෙකු සොයා ගන්න . distance
FAISS විසින් ලබා දී ඇති දුර ප්රමාණය වන අතර idx
, එහි දර්ශකය keys_store
වේ.
38 distance, idx = index.search(queries.numpy(), 10)ලබාගන්න
41 keys_found = queries.new_tensor(keys_store[idx])ලබාගන්න
43 vals_found = torch.tensor(vals_store[idx]).squeeze(-1)සාමාන්යකරණයකරන ලද දෛශික අතර කොසයින් සමානතාවය ගණනය කිරීමට අපි යන්නෙමු
සාමාන්යකරන්න
48 keys_found_n = keys_found / torch.sqrt((keys_found ** 2).sum(-1, keepdims=True) + 1e-10)සාමාන්යකරන්න
50 queries_n = queries / torch.sqrt((queries ** 2).sum(-1, keepdims=True) + 1e-10)තිත්නිෂ්පාදන, හෝ කොසයින් සමානකම ලබා ගන්න
53 dot_prod = (keys_found_n * queries_n.unsqueeze(1)).sum(-1)ටෝකන්-නැණවත්පිවිසුම්
56 logits_token = dot_prod.new_zeros(queries.shape[0], n_tokens)ආසන්නතමඅසල්වැසියන් මත පදනම්ව ටෝකන් ලොග්ස් විසිරීම සහ රැස් කිරීම
58 _ = logits_token.scatter_(dim=1, index=vals_found, src=dot_prod, reduce='add')පිවිසුම්නැවත සකස් කරන්න
61 logits_token = logits_token.reshape(queries_shape[0], queries_shape[1], -1)
62
63 return logits_tokenඅපි -NN අනාවැකිය හා ට්රාන්ස්ෆෝමර් අනාවැකිය මත ඒකාබද්ධ වලංගු අහිමි ගණනය. -NN ආකෘතියට ලබා දී ඇති බර ලබා දෙනු ලැබේ knn_weight
. එය පඩි ලැයිස්තුවක් වන අතර අපි එක් එක් සඳහා වලංගු කිරීමේ අලාභය ගණනය කරමු.
66def validation_loss(knn_weights: List[float], last_n: Optional[int], conf: Configs, index: faiss.IndexFlatL2,
67 keys_store: np.ndarray, vals_store: np.ndarray):එක්එක් සඳහා පාඩු ලැයිස්තුව knn_weights
77 losses = [[] for _ in knn_weights]එක්එක් කාණ්ඩයේ සාම්පල ගණන
79 n_samples = []
80 with torch.no_grad():වලංගුදත්ත හරහා නැවත
82 for i, batch in monit.enum("Validation", conf.validator.data_loader, is_children_silent=True):දත්තසහ ඉලක්ක ලේබල ලබා ගන්න
84 data, target = batch[0].to(conf.device), batch[1].to(conf.device)ආකෘතියධාවනය කර අනාවැකි ලබා ගන්න
86 res = conf.model(data)-NN අනාවැකි ලබා ගන්න
88 res_knn = knn(conf.model.ff_input.cpu(), index, keys_store, vals_store, conf.n_tokens)
89 res_knn = res_knn.to(conf.device)මෙය last_n
ටෝකන සඳහා වන අලාභය පමණක් ගණනය කිරීමයි. මෙය වැදගත් වන්නේ ට්රාන්ස්ෆෝමර් ආකෘතියේ පළමු අනාවැකි (අනුක්රමය ඔස්සේ) දෙස බැලීමට අතීත ටෝකන ඉතා අල්පය.
94 if last_n:
95 res = res[-last_n:]
96 res_knn = res_knn[-last_n:]
97 target = target[-last_n:]සාම්පලගණන
100 n_s = res.shape[0] * data.shape[1]
101 n_samples.append(n_s)එක්එක් සඳහා ලකුණු ගණනය කරන්න knn_weights
.
104 for i, c in enumerate(knn_weights):අලාභයගණනය කරන්න
106 loss = conf.loss_func(res_knn * c + (1 - c) * res, target)
107 losses[i].append(loss * n_s)
108
109 return losses, n_samples112def load_index(conf: Configs, n_probe: int = 8):මානයන්
117 d_model = conf.transformer.d_modelපුහුණුදත්ත පැටවුම
119 data_loader = conf.trainer.data_loader122 n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1FAISSදර්ශකය පැටවීම
125 with monit.section('Load index'):
126 index = faiss.read_index(str(lab.get_data_path() / 'faiss.index'))පරීක්ෂණසඳහා සෛල ගණන සකසන්න
128 index.nprobe = n_probeමතකයසිතියම් ගත කළ අංකුර අරා පටවන්න
131 keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))
132 vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int, mode='r', shape=(n_keys, 1))
133
134 return index, keys_store, vals_store137def main():
138 from labml_nn.transformers.knn.build_index import load_experimentඅත්හදාබැලීම පූරණය කරන්න. ආකෘතිය පුහුණුකිරීමෙන් ඔබ uuid ධාවනය සමඟ uuid ධාවනය කරන්න.
141 conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')ඇගයීම්ප්රකාරයට ආකෘතිය සකසන්න
143 conf.model.eval()පැටවුම්දර්ශකය
146 index, keys_store, vals_store = load_index(conf)-NN අනාවැකිය ලබා දී ඇති බර ලැයිස්තුව. එක් එක් බර සඳහා වලංගු කිරීමේ අලාභය අපි ඇගයීමට ලක් කරන්නෙමු
149 knn_weights = [i / 20 for i in range(10)]වලංගුකිරීමේ අලාභය තක්සේරු කරන්න
151 losses, n_samples = validation_loss(knn_weights, None, conf, index, keys_store, vals_store)එක්එක් සඳහා පාඩු ප්රතිදානය knn_weights
කරන්න.
153 inspect({c: np.sum(losses[i]) / np.sum(n_samples) for i, c in enumerate(knn_weights)})
154
155
156if __name__ == '__main__':
157 main()