# Source code for docarray.array.mixins.group

import random
from collections import defaultdict
from typing import Dict, Any, TYPE_CHECKING, Generator, List

import numpy as np

from docarray.helper import dunder_get

if TYPE_CHECKING:  # pragma: no cover
    from docarray import DocumentArray


class GroupMixin:
    """These helpers yield groups of :class:`DocumentArray` from a source :class:`DocumentArray`."""

    def split_by_tag(self, tag: str) -> Dict[Any, 'DocumentArray']:
        """Split the `DocumentArray` into multiple :class:`DocumentArray` according to
        the value of `tag` in the :attr:`tags` of each `Document`.

        :param tag: the name of the tag to split on; dunder notation such as
            ``a__b`` can be used to address nested tags.
        :return: a dict where Documents with the same value on `tag` are grouped
            together; their order is preserved from the original
            :class:`DocumentArray`.

        .. note::
            Documents whose :attr:`tags` do not contain the specified `tag` are
            skipped; if no Document contains it, an empty dict is returned.
        """
        from docarray import DocumentArray

        rv = defaultdict(DocumentArray)
        for doc in self:
            if '__' in tag:
                # dunder notation, e.g. 'a__b', addresses a nested tag
                value = dunder_get(doc.tags, tag)
            elif tag in doc.tags:
                value = doc.tags[tag]
            else:
                continue
            rv[value].append(doc)
        return dict(rv)
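
    # A minimal usage sketch of `split_by_tag` (the tag name 'category' and
    # its values are illustrative, not part of this module):
    #
    #     from docarray import Document, DocumentArray
    #
    #     da = DocumentArray(
    #         Document(tags={'category': c}) for c in ('fruit', 'tool', 'fruit')
    #     )
    #     groups = da.split_by_tag('category')
    #     assert len(groups['fruit']) == 2 and len(groups['tool']) == 1
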
    def batch(
        self,
        batch_size: int,
        shuffle: bool = False,
        show_progress: bool = False,
    ) -> Generator['DocumentArray', None, None]:
        """
        Creates a `Generator` that yields `DocumentArray` of size `batch_size`
        until `self` is fully traversed. Note that the last batch may be smaller
        than `batch_size`.

        :param batch_size: Size of each generated batch (except the last one,
            which may be smaller)
        :param shuffle: If set, shuffle the Documents before dividing into minibatches.
        :param show_progress: If set, show a progress bar when batching documents.
        :yield: a Generator of `DocumentArray`, each of length `batch_size`
        """
        from rich.progress import track

        if not (isinstance(batch_size, int) and batch_size > 0):
            raise ValueError('`batch_size` should be a positive integer')

        N = len(self)
        ix = list(range(N))
        n_batches = int(np.ceil(N / batch_size))

        if shuffle:
            # shuffle index positions rather than the DocumentArray itself
            random.shuffle(ix)

        for i in track(
            range(n_batches),
            description='Batching documents',
            disable=not show_progress,
        ):
            yield self[ix[i * batch_size : (i + 1) * batch_size]]
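
    # A minimal usage sketch of `batch` (the counts are illustrative):
    # batching 10 Documents with `batch_size=3` yields minibatches of
    # sizes 3, 3, 3 and 1.
    #
    #     from docarray import Document, DocumentArray
    #
    #     da = DocumentArray(Document() for _ in range(10))
    #     sizes = [len(b) for b in da.batch(batch_size=3)]
    #     assert sizes == [3, 3, 3, 1]
    #
    # With `shuffle=True` the membership of each minibatch is randomized,
    # but the sequence of sizes stays the same.
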
    def batch_ids(
        self,
        batch_size: int,
        shuffle: bool = False,
    ) -> Generator[List[str], None, None]:
        """
        Creates a `Generator` that yields lists of Document IDs of size
        `batch_size` until `self` is fully traversed. Note that the last
        batch may be smaller than `batch_size`.

        :param batch_size: Size of each generated batch (except the last one,
            which may be smaller)
        :param shuffle: If set, shuffle the Documents before dividing into minibatches.
        :yield: a Generator of `list` of IDs, each of length `batch_size`
        """
        if not (isinstance(batch_size, int) and batch_size > 0):
            raise ValueError('`batch_size` should be a positive integer')

        N = len(self)
        ix = self[:, 'id']
        n_batches = int(np.ceil(N / batch_size))

        if shuffle:
            random.shuffle(ix)

        for i in range(n_batches):
            yield ix[i * batch_size : (i + 1) * batch_size]
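

# A minimal end-to-end sketch of `batch_ids`, runnable as a script; the
# DocumentArray below is illustrative. Compared to `batch`, `batch_ids`
# yields plain lists of Document IDs instead of `DocumentArray` slices,
# which is cheaper when only the IDs are needed.
if __name__ == '__main__':
    from docarray import Document, DocumentArray

    da = DocumentArray(Document() for _ in range(5))

    for id_batch in da.batch_ids(batch_size=2, shuffle=True):
        # every yielded item is a list of string IDs, at most batch_size long
        assert all(isinstance(_id, str) for _id in id_batch)
        assert 1 <= len(id_batch) <= 2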