Source code for ragoon.datasets

# -*- coding: utf-8 -*-
# Copyright (c) Louis Brulé Naudet. All Rights Reserved.
# This software may be used and distributed according to the terms of License Agreement.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent.futures

import datasets

from typing import (
    IO,
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Type,
    Tuple,
    Union,
    Mapping,
    TypeVar,
    Callable,
    Optional,
    Sequence,
)

from datasets import load_dataset
from tqdm import tqdm

from ragoon._logger import Logger

logger = Logger()


def dataset_loader(
    name: str,
    streaming: Optional[bool] = True,
    split: Optional[Union[str, List[str]]] = None
) -> Optional[
    Union[
        datasets.Dataset,
        datasets.DatasetDict,
        datasets.IterableDataset,
        datasets.IterableDatasetDict,
    ]
]:
    """
    Load a single dataset, logging (instead of raising) any failure.

    This is a thin wrapper around ``datasets.load_dataset`` that never
    propagates an exception, which makes it usable as a worker function
    for a thread pool (see ``load_datasets``).

    Parameters
    ----------
    name : str
        Name of the dataset to be loaded.

    streaming : bool, optional
        Determines if datasets are streamed. Default is True.

    split : Optional[Union[str, List[str]]], optional
        Which split of the data to load. If None, will return a dict with
        all splits (typically datasets.Split.TRAIN and datasets.Split.TEST).
        If given, will return a single Dataset. Splits can be combined and
        specified like in tensorflow-datasets.

    Returns
    -------
    dataset : datasets.Dataset or None
        The object returned by ``load_dataset`` (a single dataset, or a
        dict of splits when ``split`` is None), or ``None`` when loading
        failed.

    Notes
    -----
    No exception is raised to the caller: any ``Exception`` raised by
    ``load_dataset`` is caught, reported through the module logger, and
    signalled by a ``None`` return value.
    """
    try:
        return load_dataset(
            name,
            streaming=streaming,
            split=split
        )

    except Exception as exc:
        # Best-effort loader: report the failure and let the caller decide
        # what to do with the missing dataset.
        logger.error(f"Error loading dataset {name}: {exc}")

        return None


def load_datasets(
    req: list,
    streaming: Optional[bool] = False,
    split: Optional[Union[str, List[str]]] = None,
) -> list:
    """
    Load several datasets concurrently and return them as a list.

    Each name in ``req`` is submitted to a thread pool that runs
    ``dataset_loader``. Datasets that fail to load are logged and left out
    of the result; the remaining ones are returned in the same order as
    their names appear in ``req``.

    Parameters
    ----------
    req : list
        A list containing the names of datasets to be downloaded.

    streaming : bool, optional
        Determines if datasets are streamed. Default is False.

    split : Optional[Union[str, List[str]]], optional
        Split forwarded to ``dataset_loader`` for every dataset. Default is
        None, in which case each loaded element is a dict of splits rather
        than a single dataset.

    Returns
    -------
    datasets_list : list
        A list containing loaded datasets as per the requested names
        provided in 'req'. Names whose loading failed are skipped, so the
        list may be shorter than ``req``.

    Notes
    -----
    Loading errors are not raised: they are reported through the module
    logger and the corresponding dataset is omitted from the result.

    Examples
    --------
    >>> req = [
    ...     "louisbrulenaudet/code-artisanat",
    ...     "louisbrulenaudet/code-action-sociale-familles",
    ...     # ...
    ... ]

    >>> # split="train" assumes each requested dataset exposes that split.
    >>> datasets_list = load_datasets(
    ...     req=req,
    ...     streaming=True,
    ...     split="train",
    ... )

    >>> dataset = datasets.concatenate_datasets(
    ...     datasets_list
    ... )
    """
    names = list(req)

    # One slot per requested name so results can be stored by position:
    # futures complete in arbitrary order, but callers get request order.
    loaded = [None] * len(names)

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_index = {
            executor.submit(dataset_loader, name, streaming, split): index
            for index, name in enumerate(names)
        }

        for future in tqdm(
            concurrent.futures.as_completed(future_to_index),
            total=len(names)
        ):
            index = future_to_index[future]

            try:
                loaded[index] = future.result()

            except Exception as exc:
                logger.error(
                    f"Error processing dataset {names[index]}: {exc}"
                )

    # Compare against None explicitly: an empty dataset is falsy but is
    # still a successfully loaded dataset and must not be dropped.
    return [dataset for dataset in loaded if dataset is not None]