පෙළකැබලි වල බර්ට් කාවැද්දීම්

RETRO ආකෘතියසඳහා කුට්ටි වල BERT කාවැද්දීම් ලබා ගැනීමේ කේතය මෙයයි.

13from typing import List
14
15import torch
16from transformers import BertTokenizer, BertModel
17
18from labml import lab, monit

බර්ට්කාවැද්දීම්

දීඇති පෙළ කුට්ටියක් සඳහා මෙම පන්තිය BERT කාවැද්දීම් ජනනය කරයි . සියලුම ටෝකන වල BERT කාවැද්දීම් වල සාමාන්යය වේ.

21class BERTChunkEmbeddings:
29    def __init__(self, device: torch.device):
30        self.device = device
33        with monit.section('Load BERT tokenizer'):
34            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
35                                                           cache_dir=str(
36                                                               lab.get_data_path() / 'cache' / 'bert-tokenizer'))

HuggingFace වෙතින් බර්ට් ආකෘතිය පටවන්න

39        with monit.section('Load BERT model'):
40            self.model = BertModel.from_pretrained("bert-base-uncased",
41                                                   cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))

ආකෘතියවෙත ගෙන යන්න device

44            self.model.to(device)

මෙමක්රියාත්මක කිරීමේදී, අපි ස්ථාවර ටෝකන සංඛ්යාවක් සමඟ කුට්ටි සාදන්නේ නැත. එක් හේතුවක් නම්, මෙම ක්රියාත්මක කිරීම චරිත මට්ටමේ ටෝකන භාවිතා කිරීම සහ BERT එහි උප වචන ටෝකනයිසර් භාවිතා කිරීමයි.

එබැවින්මෙම ක්රමය මඟින් අර්ධ ටෝකන නොමැති බවට වග බලා ගැනීම සඳහා පා truncate වේ.

නිදසුනක්වශයෙන්, කෙළවරේ අර්ධ වචන (අර්ධ උප වචන ටෝකන) සහිත කුට්ටියක් සමාන s a popular programming la විය හැකිය. වඩා හොඳ බර්ට් කාවැද්දීම් ලබා ගැනීම සඳහා අපි ඒවා ඉවත් කරමු. කලින් සඳහන් කළ පරිදි, ටෝකනීකරණයෙන් පසු අපි කුට්ටි කඩා දැමුවහොත් මෙය අවශ්ය නොවේ.

46    @staticmethod
47    def _trim_chunk(chunk: str):

තීරුවයිට්ස්පේස්

61        stripped = chunk.strip()

වචනකඩන්න

63        parts = stripped.split()

පළමුහා අවසාන කෑලි ඉවත් කරන්න

65        stripped = stripped[len(parts[0]):-len(parts[-1])]

වයිට්ස්පේස්ඉවත් කරන්න

68        stripped = stripped.strip()

හිස්ආපසු මුල් string නම්

71        if not stripped:
72            return chunk

එසේනොමැතිනම්, ඉවත් කරන ලද නූල් ආපසු ලබා දෙන්න

74        else:
75            return stripped

කුට්ටිලැයිස්තුවක් සඳහා ලබා ගන්න.

77    def __call__(self, chunks: List[str]):

අපටඅනුක්රමික ගණනය කිරීමට අවශ්ය නැත

83        with torch.no_grad():

කුට්ටිකපන්න

85            trimmed_chunks = [self._trim_chunk(c) for c in chunks]

බර්ට්ටෝකනයිසර් සමඟ කුට්ටි ටෝකන්ට් කරන්න

88            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)

ටෝකන්හැඳුනුම්පත්, අවධානය ආවරණ සහ ටෝකන් වර්ග උපාංගයට ගෙන යන්න

91            input_ids = tokens['input_ids'].to(self.device)
92            attention_mask = tokens['attention_mask'].to(self.device)
93            token_type_ids = tokens['token_type_ids'].to(self.device)

ආකෘතියතක්සේරු කරන්න

95            output = self.model(input_ids=input_ids,
96                                attention_mask=attention_mask,
97                                token_type_ids=token_type_ids)

ටෝකන්කාවැද්දීම් ලබා ගන්න

100            state = output['last_hidden_state']

සාමාන්යටෝකන කාවැද්දීම් ගණනය කරන්න. ටෝකනය හිස් පෑඩ් 0 නම් අවධානය යොමු කිරීමේ ආවරණ බව සලකන්න. කුට්ටි විවිධ දිග බැවින් අපට හිස් ටෝකන ලැබේ.

104            emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)

107            return emb

BERTකාවැද්දීම් පරීක්ෂා කිරීමට කේතය

110def _test():
114    from labml.logger import inspect

ආරම්භකරන්න

117    device = torch.device('cuda:0')
118    bert = BERTChunkEmbeddings(device)

නියැදිය

121    text = ["Replace me by any text you'd like.",
122            "Second sentence"]

බර්ට්ටෝකනයිසර් පරීක්ෂා කරන්න

125    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)
126
127    inspect(encoded_input, _expand=True)

බර්ට්ආකෘති ප්රතිදානයන් පරීක්ෂා කරන්න

130    output = bert.model(input_ids=encoded_input['input_ids'].to(device),
131                        attention_mask=encoded_input['attention_mask'].to(device),
132                        token_type_ids=encoded_input['token_type_ids'].to(device))
133
134    inspect({'last_hidden_state': output['last_hidden_state'],
135             'pooler_output': output['pooler_output']},
136            _expand=True)

ටෝකන්id වලින් පෙළ ප්රතිනිර්මාණය කිරීම පරීක්ෂා කරන්න

139    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
140    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)

කුට්ටිකාවැද්දීම් ලබා ගන්න

143    inspect(bert(text))

147if __name__ == '__main__':
148    _test()