k-ළඟමඅසල්වැසියාගේ භාෂා ආකෘතිය තක්සේරු කරන්න

11from typing import Optional, List
12
13import faiss
14import numpy as np
15import torch
16
17from labml import monit, lab
18from labml.logger import inspect
19from labml_nn.transformers.knn.train_model import Configs

-NN ලබා ගැනීමට

මෙන්නඅපි විමසුම් ලෙස, යතුරු ලෙස සහ අගයන් ලෙස සඳහන් කරමු.

22def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray, vals_store: np.ndarray, n_tokens: int):

ප්රති. ල නැවත හැඩගැස්වීම සඳහා විමසුම් වල හැඩය සුරකින්න

31    queries_shape = queries.shape

විමසුම්වල batch සහ sequence මානයන් සමතලා කරන්න

34    queries = queries.view(-1, queries_shape[-1])

ඒ අතර ළඟම අසල්වැසියන් 10 දෙනෙකු සොයා ගන්න . distance FAISS විසින් ලබා දී ඇති දුර ප්රමාණය වන අතර idx , එහි දර්ශකය keys_store වේ.

38    distance, idx = index.search(queries.numpy(), 10)

ලබාගන්න

41    keys_found = queries.new_tensor(keys_store[idx])

ලබාගන්න

43    vals_found = torch.tensor(vals_store[idx]).squeeze(-1)

සාමාන්යකරණයකරන ලද දෛශික අතර කොසයින් සමානතාවය ගණනය කිරීමට අපි යන්නෙමු

සාමාන්‍යකරන්න

48    keys_found_n = keys_found / torch.sqrt((keys_found ** 2).sum(-1, keepdims=True) + 1e-10)

සාමාන්‍යකරන්න

50    queries_n = queries / torch.sqrt((queries ** 2).sum(-1, keepdims=True) + 1e-10)

තිත්නිෂ්පාදන, හෝ කොසයින් සමානකම ලබා ගන්න

53    dot_prod = (keys_found_n * queries_n.unsqueeze(1)).sum(-1)

ටෝකන්-නැණවත්පිවිසුම්

56    logits_token = dot_prod.new_zeros(queries.shape[0], n_tokens)

ආසන්නතමඅසල්වැසියන් මත පදනම්ව ටෝකන් ලොග්ස් විසිරීම සහ රැස් කිරීම

58    _ = logits_token.scatter_(dim=1, index=vals_found, src=dot_prod, reduce='add')

පිවිසුම්නැවත සකස් කරන්න

61    logits_token = logits_token.reshape(queries_shape[0], queries_shape[1], -1)
62
63    return logits_token

වලංගුකිරීමේ අලාභය ගණනය කරන්න

අපි -NN අනාවැකිය හා ට්රාන්ස්ෆෝමර් අනාවැකිය මත ඒකාබද්ධ වලංගු අහිමි ගණනය. -NN ආකෘතියට ලබා දී ඇති බර ලබා දෙනු ලැබේ knn_weight . එය පඩි ලැයිස්තුවක් වන අතර අපි එක් එක් සඳහා වලංගු කිරීමේ අලාභය ගණනය කරමු.

66def validation_loss(knn_weights: List[float], last_n: Optional[int], conf: Configs, index: faiss.IndexFlatL2,
67                    keys_store: np.ndarray, vals_store: np.ndarray):

එක්එක් සඳහා පාඩු ලැයිස්තුව knn_weights

77    losses = [[] for _ in knn_weights]

එක්එක් කාණ්ඩයේ සාම්පල ගණන

79    n_samples = []
80    with torch.no_grad():

වලංගුදත්ත හරහා නැවත

82        for i, batch in monit.enum("Validation", conf.validator.data_loader, is_children_silent=True):

දත්තසහ ඉලක්ක ලේබල ලබා ගන්න

84            data, target = batch[0].to(conf.device), batch[1].to(conf.device)

ආකෘතියධාවනය කර අනාවැකි ලබා ගන්න

86            res = conf.model(data)

-NN අනාවැකි ලබා ගන්න

88            res_knn = knn(conf.model.ff_input.cpu(), index, keys_store, vals_store, conf.n_tokens)
89            res_knn = res_knn.to(conf.device)

මෙය last_n ටෝකන සඳහා වන අලාභය පමණක් ගණනය කිරීමයි. මෙය වැදගත් වන්නේ ට්රාන්ස්ෆෝමර් ආකෘතියේ පළමු අනාවැකි (අනුක්රමය ඔස්සේ) දෙස බැලීමට අතීත ටෝකන ඉතා අල්පය.

94            if last_n:
95                res = res[-last_n:]
96                res_knn = res_knn[-last_n:]
97                target = target[-last_n:]

සාම්පලගණන

100            n_s = res.shape[0] * data.shape[1]
101            n_samples.append(n_s)

එක්එක් සඳහා ලකුණු ගණනය කරන්න knn_weights .

104            for i, c in enumerate(knn_weights):

අලාභයගණනය කරන්න

106                loss = conf.loss_func(res_knn * c + (1 - c) * res, target)
107                losses[i].append(loss * n_s)
108
109    return losses, n_samples

දර්ශකයපූරණය කරන්න

112def load_index(conf: Configs, n_probe: int = 8):

මානයන්

117    d_model = conf.transformer.d_model

පුහුණුදත්ත පැටවුම

119    data_loader = conf.trainer.data_loader

සන්දර්භයන්ගණන; එනම් පුහුණු දත්තවල ටෝකන ගණන us ණ එකක්.

සඳහා
122    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

FAISSදර්ශකය පැටවීම

125    with monit.section('Load index'):
126        index = faiss.read_index(str(lab.get_data_path() / 'faiss.index'))

පරීක්ෂණසඳහා සෛල ගණන සකසන්න

128    index.nprobe = n_probe

මතකයසිතියම් ගත කළ අංකුර අරා පටවන්න

131    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))
132    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int, mode='r', shape=(n_keys, 1))
133
134    return index, keys_store, vals_store
137def main():
138    from labml_nn.transformers.knn.build_index import load_experiment

අත්හදාබැලීම පූරණය කරන්න. ආකෘතිය පුහුණුකිරීමෙන් ඔබ uuid ධාවනය සමඟ uuid ධාවනය කරන්න.

141    conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')

ඇගයීම්ප්රකාරයට ආකෘතිය සකසන්න

143    conf.model.eval()

පැටවුම්දර්ශකය

146    index, keys_store, vals_store = load_index(conf)

-NN අනාවැකිය ලබා දී ඇති බර ලැයිස්තුව. එක් එක් බර සඳහා වලංගු කිරීමේ අලාභය අපි ඇගයීමට ලක් කරන්නෙමු

149    knn_weights = [i / 20 for i in range(10)]

වලංගුකිරීමේ අලාභය තක්සේරු කරන්න

151    losses, n_samples = validation_loss(knn_weights, None, conf, index, keys_store, vals_store)

එක්එක් සඳහා පාඩු ප්රතිදානය knn_weights කරන්න.

153    inspect({c: np.sum(losses[i]) / np.sum(n_samples) for i, c in enumerate(knn_weights)})
154
155
156if __name__ == '__main__':
157    main()