Module llmflex.Models.Cores.utils
Expand source code
from ...Prompts.prompt_template import PromptTemplate
from typing import List, Optional, Any, Literal, Tuple, Iterator, Dict
def add_newline_char_to_stopwords(stop: List[str]) -> List[str]:
    """Return the stop words together with newline-prefixed variants of them.

    Empty strings are dropped. Every remaining stop word that does not already
    start with a new line character gets an extra variant with ``'\\n'`` in front.

    Args:
        stop (List[str]): List of stop words.

    Returns:
        List[str]: De-duplicated stop words including the newline-prefixed variants. The order is not guaranteed.
    """
    words = [word for word in stop if word != '']
    prefixed = ['\n' + word for word in words if not word.startswith('\n')]
    # Going through a set removes duplicates; ordering is therefore arbitrary.
    return list(set(words + prefixed))
def get_stop_words(stop: Optional[List[str]], tokenizer: Any,
        add_newline_version: bool = True, tokenizer_type: Literal['transformers', 'llamacpp', 'openai'] = 'transformers') -> List[str]:
    """Adding necessary stop words such as EOS token and multiple newline characters.

    The list passed in is never modified; a new list is always returned.

    Args:
        stop (Optional[List[str]]): List of stop words, if anything other than a list is given, an empty list will be assumed.
        tokenizer (Any): Tokenizer to get the EOS token.
        add_newline_version (bool, optional): Whether to use add_newline_char_to_stopwords function. Defaults to True.
        tokenizer_type (Literal['transformers', 'llamacpp', 'openai'], optional): Type of tokenizer. Defaults to 'transformers'.

    Returns:
        List[str]: Updated list of stop words, without empty strings or duplicates. The order is not guaranteed.

    Raises:
        ValueError: If tokenizer_type is not one of the supported values.
    """
    # Copy so that appending the EOS token below does not leak into the caller's list.
    stop = list(stop) if isinstance(stop, list) else []

    if tokenizer_type == 'transformers':
        eos_token = tokenizer.eos_token
    elif tokenizer_type == 'llamacpp':
        eos_token = tokenizer.detokenize(tokens=[tokenizer.token_eos()]).decode()
    elif tokenizer_type == 'openai':
        eos_token = tokenizer.decode(tokens=[tokenizer.eot_token])
    else:
        # Previously this fell through and failed later with an unbound local variable.
        raise ValueError(
            f"Unsupported tokenizer_type '{tokenizer_type}'; expected 'transformers', 'llamacpp' or 'openai'."
        )

    if eos_token is not None and eos_token not in stop:
        stop.append(eos_token)
    if '\n\n\n' not in stop:
        stop.append('\n\n\n')

    if add_newline_version:
        return add_newline_char_to_stopwords(stop)
    return list({word for word in stop if word != ''})
def find_roots(text: str, stop: List[str], stop_len: List[int]) -> Tuple[str, str]:
    """Helper for hiding stop words while streaming text in some custom llm classes. Not intended to be used alone.

    Two situations are handled:

    1. A complete stop word is present in the text: the text is cut right before
       the stop word that occurs earliest and that stop word is returned as the root.
    2. No complete stop word is present, but the text ends with the beginning of
       one: the longest such ending (over all stop words) is removed from the text
       and returned as the root, so the caller can hold it back until more text arrives.

    Args:
        text (str): Output of the model.
        stop (List[str]): List of stop words.
        stop_len (List[int]): List of the lengths of the stop words.

    Returns:
        Tuple[str, str]: Curated output of the model, potential root of stop words (empty string if there is none).
    """
    # Complete stop word: cut at the earliest occurrence in the text, not at the
    # first stop word in the list. Empty stop words are ignored.
    hits = [(text.find(word), word) for word in stop if word and word in text]
    if hits:
        position, word = min(hits, key=lambda hit: hit[0])
        return text[:position], word

    # Partial stop word: look for the longest prefix of any stop word that the text
    # ends with. Holding back only the shortest match would let part of a stop word
    # with repeated leading characters (e.g. '\n\n\n') slip into the output.
    root = ''
    for word, length in zip(stop, stop_len):
        # Only sizes longer than the best root found so far are worth checking.
        for size in range(min(length, len(text)), len(root), -1):
            if text.endswith(word[:size]):
                root = word[:size]
                break

    if root:
        text = text[:-len(root)]
    return text, root
def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut the text right before the stop word that appears first in it.

    Args:
        text (str): Text to strip.
        stop (List[str]): List of stop words.

    Returns:
        str: Text up to (not including) the earliest stop word, or the unchanged text if no stop word is found.
    """
    first_word = None
    first_pos = -1
    for word in stop:
        pos = text.find(word)
        if pos == -1:
            continue
        # Strict comparison keeps the earlier list entry when two words share a position.
        if first_word is None or pos < first_pos:
            first_word, first_pos = word, pos

    if first_word is None:
        return text
    return text.partition(first_word)[0]
def textgen_iterator(text_generator: Iterator[str], stop: List[str]) -> Iterator[str]:
    """Wrap a text generator so that stop words are not passed on to the consumer.

    Each incoming chunk is appended to the text seen so far and checked with
    find_roots. Text that could be the start of a stop word is held back until
    the next chunk. Once a complete stop word has been seen, only empty strings
    are yielded for the remaining chunks.

    Args:
        text_generator (Iterator[str]): Text generator to transform.
        stop (List[str]): Stop words.

    Yields:
        Iterator[str]: Text generator with stop words applied.
    """
    emitted = ''    # everything that has been handed to the consumer so far
    confirmed = ''  # text known to be free of stop words
    held_back = ''  # trailing part that may be the beginning of a stop word
    stopped = False
    lengths = [len(word) for word in stop]

    for chunk in text_generator:
        confirmed, held_back = find_roots(confirmed + held_back + chunk, stop, lengths)
        if held_back in stop:
            # A complete stop word has been generated; nothing more is released.
            stopped = True
        fresh = confirmed.removeprefix(emitted)
        emitted += fresh
        yield '' if stopped else fresh

    # Release what was held back unless it turned out to be a real stop word.
    yield '' if held_back in stop else held_back
def detect_prompt_template_by_id(model_id: str) -> str:
    """Guess the prompt template preset from keywords in the model ID.

    Finetune keywords are checked before base model keywords. Within each group,
    the keyword that appears earliest in the lower-cased model ID wins.

    Args:
        model_id (str): Huggingface ID of the model.

    Returns:
        str: Prompt template preset, 'Default' if no keyword matches.
    """
    finetunes = {
        'hermes': 'ChatML',
        'nous': 'ChatML',
        'wizardlm': 'Vicuna',
        'openchat': 'OpenChat',
        'zephyr': 'Zephyr',
        'solar': 'Llama2',
    }
    base = {
        'llama-3': 'Llama3',
        'llama-2': 'Llama2',
        'mistral': 'Llama2',
        'mixtral': 'Llama2',
    }
    lowered = model_id.lower()

    # Finetune names take precedence over base model names.
    for table in (finetunes, base):
        best_key = None
        best_pos = -1
        for key in table:
            pos = lowered.find(key)
            if pos != -1 and (best_key is None or pos < best_pos):
                best_key, best_pos = key, pos
        if best_key is not None:
            return table[best_key]

    return 'Default'
def detect_prompt_template_by_jinja(jinja_template: str) -> str:
    """Detect if the jinja template given is the same as one of the presets.

    A preset matches when the given template is contained in the preset's template,
    or when the preset defines 'keywords' and all of them occur in the given template.
    Presets are checked in the iteration order of the presets dictionary and the
    first match is returned.

    Args:
        jinja_template (str): Jinja template to test.

    Returns:
        str: Prompt template preset, 'Default' if nothing matches or no usable template is given.
    """
    # An empty string is a substring of every preset template and would wrongly match
    # the first preset; a non-string would make the `in` test raise TypeError.
    if not isinstance(jinja_template, str) or not jinja_template.strip():
        return 'Default'

    from ...Prompts.prompt_template import presets

    for name, preset in presets.items():
        if jinja_template in preset['template']:
            return name
        if 'keywords' in preset and all(kw in jinja_template for kw in preset['keywords']):
            return name
    return 'Default'
def _resolve_chat_template(tokenizer: Any) -> Tuple[Optional[str], bool]:
    """Pick the jinja template to inspect and report whether the tokenizer set it explicitly."""
    explicit = getattr(tokenizer, 'chat_template', None)
    if isinstance(explicit, dict):
        # A dictionary of named templates: use the one stored under 'default'.
        explicit = explicit.get('default')
    if isinstance(explicit, str) and explicit:
        return explicit, True
    # No usable explicit template; fall back to the tokenizer's default template
    # if the tokenizer has one.
    return getattr(tokenizer, 'default_chat_template', None), False


def get_prompt_template_by_jinja(model_id: str, tokenizer: Any) -> PromptTemplate:
    """Getting the appropriate prompt template given the huggingface tokenizer.

    The template is chosen in this order:

    1. A preset recognised from the tokenizer's jinja chat template.
    2. If the tokenizer explicitly sets a chat template (a string, or a dictionary
       with a 'default' entry) that matches no preset, a prompt template built
       directly from that jinja template.
    3. Otherwise a preset guessed from the model ID.

    NOTE(review): previously a plain-string ``chat_template`` was discarded in favour
    of ``default_chat_template``; it is now treated as an explicit template, the same
    way a dictionary's 'default' entry already was. Confirm this matches the intent.

    Args:
        model_id (str): Repo ID of the tokenizer.
        tokenizer (Any): Huggingface tokenizer.

    Returns:
        PromptTemplate: The prompt template object.
    """
    jinja, priority = _resolve_chat_template(tokenizer)

    preset = detect_prompt_template_by_jinja(jinja) if jinja else 'Default'
    if preset != 'Default':
        return PromptTemplate.from_preset(preset)

    if priority:
        # The tokenizer ships its own template and no preset matches it: use it as is.
        eos_token = tokenizer.eos_token
        return PromptTemplate(
            template=jinja,
            eos_token=eos_token,
            bos_token=tokenizer.bos_token,
            stop=[] if eos_token is None else [eos_token],
        )

    return PromptTemplate.from_preset(detect_prompt_template_by_id(model_id))
def list_local_models() -> List[Dict[str, str]]:
    """Check what you have in your local model cache directory.

    Looks at the 'hub' directory under the configured 'hf_home' and lists the files
    of one snapshot for every 'models--*' entry. Entries that have no 'snapshots'
    directory or no snapshot inside it are skipped instead of aborting the listing.

    Returns:
        List[Dict[str, str]]: List of dictionaries of model details, each with 'repo_id' (the repo ID) and 'files' (list of file names in the snapshot).
    """
    import os
    from ...utils import get_config

    model_dir = os.path.join(get_config()['hf_home'], 'hub')
    models = []
    for entry in os.listdir(model_dir):
        if not entry.startswith('models--'):
            continue
        snapshots_dir = os.path.join(model_dir, entry, 'snapshots')
        if not os.path.isdir(snapshots_dir):
            # Nothing to list for this entry.
            continue
        snapshots = [name for name in os.listdir(snapshots_dir) if '.DS_Store' not in name]
        if not snapshots:
            continue
        repo_id = entry.removeprefix('models--').replace('--', '/')
        # Only the first listed snapshot is inspected, as before.
        files = os.listdir(os.path.join(snapshots_dir, snapshots[0]))
        models.append(dict(repo_id=repo_id, files=files))
    return models
Functions
def add_newline_char_to_stopwords(stop: List[str]) ‑> List[str]-
Create a duplicate of the stop words and add a new line character as a prefix to each of them if their prefixes are not new line characters.
Args
stop:List[str]- List of stop words.
Returns
List[str]- New version of the list of stop words, with new line characters.
Expand source code
def add_newline_char_to_stopwords(stop: List[str]) -> List[str]: """Create a duplicate of the stop words and add a new line character as a prefix to each of them if their prefixes are not new line characters. Args: stop (List[str]): List of stop words. Returns: List[str]: New version of the list of stop words, with new line characters. """ stop = list(filter(lambda x: x != '', stop)) new = stop.copy() for i in stop: if not i.startswith('\n'): new.append('\n' + i) new = list(set(new)) return new def detect_prompt_template_by_id(model_id: str) ‑> str-
Guess the prompt format for the model by model ID.
Args
model_id:str- Huggingface ID of the model.
Returns
str- Prompt template preset.
Expand source code
def detect_prompt_template_by_id(model_id: str) -> str: """Guess the prompt format for the model by model ID. Args: model_id (str): Huggingface ID of the model. Returns: str: Prompt template preset. """ finetunes = dict( hermes = 'ChatML', nous = 'ChatML', wizardlm = 'Vicuna', openchat = 'OpenChat', zephyr = 'Zephyr', solar = 'Llama2' ) base = { 'llama-3': 'Llama3', 'llama-2': 'Llama2', 'mistral': 'Llama2', 'mixtral': 'Llama2' } id_lower = model_id.lower() # Check if it is in the finetune list keys = list(map(lambda x: (x, id_lower.find(x)), finetunes.keys())) keys.sort(key=lambda x: x[1]) keys = list(filter(lambda x: x[1]!=-1, keys)) if len(keys) != 0: return finetunes[keys[0][0]] # Check if in the base list keys = list(map(lambda x: (x, id_lower.find(x)), base.keys())) keys.sort(key=lambda x: x[1]) keys = list(filter(lambda x: x[1]!=-1, keys)) if len(keys) != 0: return base[keys[0][0]] return 'Default' def detect_prompt_template_by_jinja(jinja_template: str) ‑> str-
Detect if the jinja template given is the same as one of the presets.
Args
jinja_template:str- Jinja template to test.
Returns
str- Prompt template preset.
Expand source code
def detect_prompt_template_by_jinja(jinja_template: str) -> str: """Detect if the jinja template given is the same as one of the presets. Args: jinja_template (str): Jinja template to test. Returns: str: Prompt template preset. """ from ...Prompts.prompt_template import presets for k, v in presets.items(): if jinja_template in v['template']: return k if 'keywords' in v.keys(): if all(kw in jinja_template for kw in v['keywords']): return k return 'Default' def enforce_stop_tokens(text: str, stop: List[str]) ‑> str-
Strip text with the given stop words.
Args
text:str- Text to strip.
stop:List[str]- List of stop words.
Returns
str- Stripped text.
Expand source code
def enforce_stop_tokens(text: str, stop: List[str]) -> str: """Strip text with the given stop words. Args: text (str): Text to strip. stop (List[str]): List of stop words. Returns: str: Stripped text. """ stop_pos = list(map(lambda x: text.find(x), stop)) stop_map = list(zip(stop, stop_pos)) stop_map = list(filter(lambda x: x[1] != -1, stop_map)) if len(stop_map) != 0: stop_map.sort(key=lambda x: x[1]) stop_word = stop_map[0][0] return text.split(sep=stop_word)[0] else: return text def find_roots(text: str, stop: List[str], stop_len: List[int]) ‑> Tuple[str, str]-
Helper function that keeps stop words from showing up while streaming words in some custom LLM classes. Not intended to be used alone.
Args
text:str- Output of the model.
stop:List[str]- List of stop words.
stop_len:List[int]- List of the lengths of the stop words.
Returns
Tuple[str, str]- Curated output of the model, potential root of stop words.
Expand source code
def find_roots(text: str, stop: List[str], stop_len: List[int]) -> Tuple[str, str]: """This function is a helper function for stopping stop words from showing up while doing work streaming in some custom llm classes. Not intended to be used alone. Args: text (str): Output of the model. stop (List[str]): List of stop words. stop_len (List[int]): List of the lengths of the stop words. Returns: Tuple[str, str]: Curated output of the model, potential root of stop words. """ root = '' for w in stop: if w in text: return text.split(w)[0], w for i, w in enumerate(stop): for j in range(stop_len[i]): if text[-(j + 1):]==w[:j+1]: root = w[:j+1] break if root: break text = text[:-len(root)] if root else text return text, root def get_prompt_template_by_jinja(model_id: str, tokenizer: Any) ‑> PromptTemplate-
Getting the appropriate prompt template given the huggingface tokenizer.
Args
model_id:str- Repo ID of the tokenizer.
tokenizer:Any- Huggingface tokenizer.
Returns
PromptTemplate- The prompt template object.
Expand source code
def get_prompt_template_by_jinja(model_id: str, tokenizer: Any) -> PromptTemplate: """Getting the appropriate prompt template given the huggingface tokenizer. Args: model_id (str): Repo ID of the tokenizer. tokenizer (Any): Huggingface tokenizer. Returns: PromptTemplate: The prompt template object. """ if tokenizer.chat_template is not None: jinja = tokenizer.chat_template if not isinstance(jinja, str): if isinstance(jinja, dict): jinja = jinja.get('default') if jinja: priority = True else: jinja = tokenizer.default_chat_template priority = False else: jinja = tokenizer.default_chat_template priority = False else: jinja = tokenizer.default_chat_template priority = False else: jinja = tokenizer.default_chat_template priority = False prompt_template = detect_prompt_template_by_jinja(jinja) prompt_template = detect_prompt_template_by_id(model_id) if ((prompt_template == 'Default') and not priority) else prompt_template if (priority and (prompt_template == 'Default')): prompt_template = PromptTemplate(template=jinja, eos_token=tokenizer.eos_token, bos_token=tokenizer.bos_token, stop=[] if tokenizer.eos_token is None else [tokenizer.eos_token]) else: prompt_template = PromptTemplate.from_preset(prompt_template) return prompt_template def get_stop_words(stop: Optional[List[str]], tokenizer: Any, add_newline_version: bool = True, tokenizer_type: Literal['transformers', 'llamacpp', 'openai'] = 'transformers') ‑> List[str]-
Adding necessary stop words such as EOS token and multiple newline characters.
Args
stop:Optional[List[str]]- List of stop words, if None is given, an empty list will be assumed.
tokenizer:Any- Tokenizer to get the EOS token.
add_newline_version:bool, optional- Whether to use add_newline_char_to_stopwords function. Defaults to True.
tokenizer_type:Literal['transformers', 'llamacpp', 'openai'], optional- Type of tokenizer. Defaults to 'transformers'.
Returns
List[str]- Updated list of stop words.
Expand source code
def get_stop_words(stop: Optional[List[str]], tokenizer: Any, add_newline_version: bool = True, tokenizer_type: Literal['transformers', 'llamacpp', 'openai'] = 'transformers') -> List[str]: """Adding necessary stop words such as EOS token and multiple newline characters. Args: stop (Optional[List[str]]): List of stop words, if None is given, an empty list will be assumed. tokenizer (Any): Tokenizer to get the EOS token. add_newline_version (bool, optional): Whether to use add_newline_char_to_stopwords function. Defaults to True. tokenizer_type (Literal['transformers', 'llamacpp', 'openai'], optional): Type of tokenizer. Defaults to 'transformers'. Returns: List[str]: Updated list of stop words. """ stop = stop if isinstance(stop, list) else [] if tokenizer_type == 'transformers': eos_token = tokenizer.eos_token elif tokenizer_type == 'llamacpp': eos_token = tokenizer.detokenize(tokens=[tokenizer.token_eos()]).decode() elif tokenizer_type == 'openai': eos_token = tokenizer.decode(tokens=[tokenizer.eot_token]) if ((eos_token is not None) & (eos_token not in stop)): stop.append(eos_token) if '\n\n\n' not in stop: stop.append('\n\n\n') if add_newline_version: return add_newline_char_to_stopwords(stop) else: stop = list(filter(lambda x: x != '', stop)) return list(set(stop)) def list_local_models() ‑> List[Dict[str, str]]-
Check what you have in your local model cache directory.
Returns
List[Dict[str, str]]- List of dictionaries of model details.
Expand source code
def list_local_models() -> List[Dict[str, str]]: """Check what you have in your local model cache directory. Returns: List[Dict[str, str]]: List of dictionarys of model details. """ import os from ...utils import get_config model_dir = os.path.join(get_config()['hf_home'], 'hub') repos = list(filter(lambda x: x.startswith('models--'), os.listdir(model_dir))) repo_dirs = list(map(lambda x: os.path.join(model_dir, x, 'snapshots'), repos)) repos = list(map(lambda x: x.removeprefix('models--').replace('--', '/'), repos)) repo_dirs = list(map(lambda x: os.path.join(x, list(filter(lambda y: '.DS_Store' not in y, os.listdir(x)))[0]), repo_dirs)) repos = list(zip(repos, repo_dirs)) repos = list(map(lambda x: dict(repo_id=x[0], files=os.listdir(x[1])), repos)) return repos def textgen_iterator(text_generator: Iterator[str], stop: List[str]) ‑> Iterator[str]-
Make a text generator stop before spitting out the stop words.
Args
text_generator:Iterator[str]- Text generator to transform.
stop:List[str]- Stop words.
Yields
Iterator[str]- Text generator with stop words applied.
Expand source code
def textgen_iterator(text_generator: Iterator[str], stop: List[str]) -> Iterator[str]: """Make a text generator stop before spitting out the stop words. Args: text_generator (Iterator[str]): Text generator to transform. stop (List[str]): Stop words. Yields: Iterator[str]: Text generator with stop words applied. """ text, output, root = '', '', '' cont = True stop_len = list(map(len, stop)) for i in text_generator: temp = text + root + i text, root = find_roots(temp, stop, stop_len) if root in stop: cont = False token = text.removeprefix(output) output += token if cont: yield token else: yield '' if root not in stop: yield root else: yield ''