Module llmflex.Embeddings.hf_embeddings_server
Expand source code
import os
class HuggingFaceEmbeddingsServer:
    """Minimal Flask server exposing a sentence-transformers model over HTTP.

    Endpoints (registered when ``run`` is called):
        /embeddings: encode a list of texts and return the vectors as JSON.
        /info: return static metadata about the loaded model.
    """

    def __init__(self, model_id: str = 'thenlper/gte-small', default_batch_size: int = 128, **kwargs) -> None:
        """Initialising the model server.

        Args:
            model_id (str, optional): Huggingface repo id. Defaults to 'thenlper/gte-small'.
            default_batch_size (int, optional): Default batch size for encoding if not specified on the client side. Defaults to 128.
            **kwargs: Extra keyword arguments passed straight to ``SentenceTransformer``.
        """
        from flask import Flask
        from ..utils import get_config
        from sentence_transformers import SentenceTransformer
        # Point the model caches at the configured directories BEFORE the
        # model is loaded, so the download/load honours them. Read the
        # config once instead of twice.
        config = get_config()
        os.environ['SENTENCE_TRANSFORMERS_HOME'] = config['st_home']
        os.environ['HF_HOME'] = config['hf_home']
        os.environ['TOKENIZERS_PARALLELISM'] = 'true'
        self.model = SentenceTransformer(model_id, **kwargs)
        self.default_batch_size = default_batch_size
        # Static metadata served verbatim by the /info endpoint.
        self.info = dict(
            model_id=model_id,
            embedding_dimension=self.model.get_sentence_embedding_dimension(),
            max_seq_length=self.model.max_seq_length,
            device=str(self.model.device),
            default_batch_size=default_batch_size
        )
        self.app = Flask(__name__)

    def run(self, **kwargs) -> None:
        """Start the server.

        Args:
            **kwargs: Keyword arguments passed straight to ``flask.Flask.run``
                (e.g. ``host``, ``port``, ``debug``).
        """
        import gc  # hoisted out of the request handler
        from flask import request, jsonify
        import torch

        # POST is accepted alongside the original GET: a GET request with a
        # JSON body is non-standard and some HTTP clients cannot send one.
        @self.app.route('/embeddings', methods=['GET', 'POST'])
        def get_embeddings():
            # request.json raises (or is None) when the body is missing or
            # not valid JSON; get_json(silent=True) returns None instead,
            # and we fall back to an empty dict so .get() is always safe.
            args_dict = request.get_json(silent=True) or {}
            input_texts = args_dict.get('input_texts', [])
            batch_size = args_dict.get('batch_size', self.default_batch_size)
            normalize_embeddings = args_dict.get('normalize_embeddings', True)
            embeddings = self.model.encode(
                input_texts,
                batch_size=batch_size,
                normalize_embeddings=normalize_embeddings
            ).tolist()
            # Release accelerator memory between requests; on CPU-only
            # hosts fall back to a plain garbage-collection pass.
            if self.model.device.type == 'mps':
                torch.mps.empty_cache()
            elif self.model.device.type == 'cuda':
                torch.cuda.empty_cache()
            else:
                gc.collect()
            return jsonify(embeddings)

        @self.app.route('/info', methods=['GET'])
        def get_info():
            return jsonify(self.info)

        self.app.run(**kwargs)
Classes
class HuggingFaceEmbeddingsServer(model_id: str = 'thenlper/gte-small', default_batch_size: int = 128, **kwargs)
Initialising the model server.
Args
model_id : str, optional — Huggingface repo id. Defaults to 'thenlper/gte-small'.
default_batch_size : int, optional — Default batch size for encoding if not specified on the client side. Defaults to 128.
Expand source code
class HuggingFaceEmbeddingsServer: def __init__(self, model_id: str = 'thenlper/gte-small', default_batch_size: int = 128, **kwargs) -> None: """Initialising the model server. Args: model_id (str, optional): Huggingface repo id. Defaults to 'thenlper/gte-small'. default_batch_size (int, optional): Default batch size for encoding if not specified on the client side. Defaults to 128. """ from flask import Flask from ..utils import get_config from sentence_transformers import SentenceTransformer os.environ['SENTENCE_TRANSFORMERS_HOME'] = get_config()['st_home'] os.environ['HF_HOME'] = get_config()['hf_home'] os.environ['TOKENIZERS_PARALLELISM'] = 'true' self.model = SentenceTransformer(model_id, **kwargs) self.default_batch_size = default_batch_size self.info = dict( model_id=model_id, embedding_dimension=self.model.get_sentence_embedding_dimension(), max_seq_length=self.model.max_seq_length, device=str(self.model.device), default_batch_size=default_batch_size ) self.app = Flask(__name__) def run(self, **kwargs) -> None: """Start the server. """ from flask import request, jsonify import torch @self.app.route('/embeddings', methods=['GET']) def get_embeddings(): args_dict = request.json input_texts = args_dict.get('input_texts') batch_size = args_dict.get('batch_size', self.default_batch_size) normalize_embddings = args_dict.get('normalize_embeddings', True) embeddings = self.model.encode(input_texts, batch_size=batch_size, normalize_embeddings=normalize_embddings).tolist() if self.model.device.type == 'mps': torch.mps.empty_cache() elif self.model.device.type == 'cuda': torch.cuda.empty_cache() else: import gc gc.collect() return jsonify(embeddings) @self.app.route('/info', methods=['GET']) def get_info(): return jsonify(self.info) self.app.run(**kwargs)Methods
def run(self, **kwargs) -> None
Start the server.
Expand source code
def run(self, **kwargs) -> None: """Start the server. """ from flask import request, jsonify import torch @self.app.route('/embeddings', methods=['GET']) def get_embeddings(): args_dict = request.json input_texts = args_dict.get('input_texts') batch_size = args_dict.get('batch_size', self.default_batch_size) normalize_embddings = args_dict.get('normalize_embeddings', True) embeddings = self.model.encode(input_texts, batch_size=batch_size, normalize_embeddings=normalize_embddings).tolist() if self.model.device.type == 'mps': torch.mps.empty_cache() elif self.model.device.type == 'cuda': torch.cuda.empty_cache() else: import gc gc.collect() return jsonify(embeddings) @self.app.route('/info', methods=['GET']) def get_info(): return jsonify(self.info) self.app.run(**kwargs)