Source code for docarray.array.mixins.evaluation

import warnings
from typing import Optional, Union, TYPE_CHECKING, Callable, List, Dict, Tuple, Any

from functools import wraps

import numpy as np
from collections import defaultdict, Counter

from docarray.score import NamedScore

if TYPE_CHECKING:  # pragma: no cover
    from docarray import Document, DocumentArray
    from docarray.array.mixins.embed import CollateFnType
    from docarray.typing import ArrayType, AnyDNN


def _evaluate_deprecation(f):
    """Raises a deprecation warning if the user executes the evaluate function with
    the old interface and adjusts the input to fit the new interface."""

    @wraps(f)
    def func(*args, **kwargs):
        if len(args) > 1:
            if not (
                isinstance(args[1], Callable)
                or isinstance(args[1], str)
                or isinstance(args[1], list)
            ):
                kwargs['ground_truth'] = args[1]
                args = [args[0]] + list(args[2:])
                warnings.warn(
                    'The `other` argument of `evaluate()` was transferred from a '
                    'positional argument to the keyword argument `ground_truth`. '
                    'Using it as a positional argument is deprecated and will be '
                    'removed soon.',
                    DeprecationWarning,
                )
        for old_key, new_key in zip(
            ['other', 'metric', 'metric_name'],
            ['ground_truth', 'metrics', 'metric_names'],
        ):
            if old_key in kwargs:
                kwargs[new_key] = kwargs[old_key]
                warnings.warn(
                    f'`{old_key}` is renamed to `{new_key}` in `evaluate()`, the '
                    f'usage of `{old_key}` is deprecated and will be removed soon.',
                    DeprecationWarning,
                )

        # convert `metrics` and `metric_names` to lists
        list_warning_msg = (
            'The attribute `%s` now accepts a list instead of a single element. '
            'Passing a single element is deprecated and support for it will be '
            'removed soon.'
        )
        if len(args) > 1:
            if type(args[1]) is str:
                args = list(args)
                args[1] = [args[1]]
                warnings.warn(list_warning_msg % 'metrics', DeprecationWarning)
        for key in ['metrics', 'metric_names']:
            if key in kwargs and type(kwargs[key]) is str:
                kwargs[key] = [kwargs[key]]
                warnings.warn(list_warning_msg % key, DeprecationWarning)
        return f(*args, **kwargs)

    return func
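

# Editor's sketch (not part of the original module): how the shim above rewrites
# old-style keyword arguments before they reach the wrapped function.
# `_dummy_evaluate` is a hypothetical stand-in for the real `evaluate` method and
# merely echoes what it receives after the rewrite.
@_evaluate_deprecation
def _dummy_evaluate(self, metrics=None, ground_truth=None, metric_names=None, **kwargs):
    return metrics, ground_truth, metric_names


# Old interface: single-string `metric`/`metric_name` plus `other` ...
_metrics, _gt, _names = _dummy_evaluate(
    'placeholder-self', metric='precision_at_k', metric_name='p@k', other='placeholder-gt'
)
# ... arrives in the new interface, with a DeprecationWarning per remapped argument:
# _metrics == ['precision_at_k'], _gt == 'placeholder-gt', _names == ['p@k']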


class EvaluationMixin:
    """A mixin that provides ranking evaluation functionality to DocumentArrayLike objects"""

    @_evaluate_deprecation
    def evaluate(
        self,
        metrics: List[Union[str, Callable[..., float]]],
        ground_truth: Optional['DocumentArray'] = None,
        hash_fn: Optional[Callable[['Document'], str]] = None,
        metric_names: Optional[List[str]] = None,
        strict: bool = True,
        label_tag: str = 'label',
        num_relevant_documents_per_label: Optional[Dict[Any, int]] = None,
        **kwargs,
    ) -> Dict[str, float]:
        """
        Compute ranking evaluation metrics for a given `DocumentArray` when compared
        with a ground truth.

        If one provides a `ground_truth` DocumentArray that is structurally identical
        to `self`, this function compares the `matches` of the documents inside the
        `DocumentArray` to this `ground_truth`.
        Alternatively, one can directly annotate the documents by adding labels in the
        form of tags with the key specified in the `label_tag` attribute. Those tags
        need to be added to `self` as well as to the documents in the matches
        properties.

        This method will fill the `evaluations` field of Documents inside this
        `DocumentArray` and will return the average of the computations.

        :param metrics: List of metric names or metric functions to be computed
        :param ground_truth: The ground truth `DocumentArray` that this
            `DocumentArray` is compared to.
        :param hash_fn: For the evaluation against a `ground_truth` DocumentArray,
            this function is used for generating hashes which are used to compare the
            documents. If not given, ``Document.id`` is used.
        :param metric_names: If provided, the results of the metrics computation will
            be stored in the `evaluations` field of each Document under these names.
            If not provided, the names will be derived from the metric function names.
        :param strict: If set, the left- and right-hand sides are required to be fully
            aligned: same length and matching Documents at each position. This
            prevents you from accidentally evaluating on irrelevant matches.
        :param label_tag: Specifies the tag which contains the labels.
        :param num_relevant_documents_per_label: Some metrics, e.g., recall@k, require
            the number of relevant documents. To apply those to a labeled dataset,
            one can provide a dictionary which maps each label to the total number of
            documents with this label.
        :param kwargs: Additional keyword arguments to be passed to the metric
            functions.
        :return: A dictionary which stores for each metric name the average
            evaluation score.
        """
        if len(self) == 0:
            raise ValueError('It is not possible to evaluate an empty DocumentArray')
        if ground_truth and len(ground_truth) > 0 and ground_truth[0].matches:
            ground_truth_type = 'matches'
        elif label_tag in self[0].tags:
            if ground_truth:
                warnings.warn(
                    'A ground_truth attribute is provided but does not '
                    'contain matches. The labels are used instead and '
                    'ground_truth is ignored.'
                )
            ground_truth = self
            ground_truth_type = 'labels'
        else:
            raise RuntimeError(
                'Could not find proper ground truth data. Either labels or the '
                'ground_truth attribute with matches is required'
            )

        if strict:
            self._check_length(len(ground_truth))

        if hash_fn is None:
            hash_fn = lambda d: d.id

        metric_fns = []
        for metric in metrics:
            if callable(metric):
                metric_fns.append(metric)
            elif isinstance(metric, str):
                from docarray.math import evaluation

                metric_fns.append(getattr(evaluation, metric))

        if not metric_names:
            metric_names = [metric_fn.__name__ for metric_fn in metric_fns]

        if len(metric_names) != len(metrics):
            raise ValueError(
                'Could not match metric names to the metrics since the number of '
                'metric names does not match the number of metrics'
            )

        results = defaultdict(list)
        caller_max_rel = kwargs.pop('max_rel', None)
        for d, gd in zip(self, ground_truth):
            if caller_max_rel:
                max_rel = caller_max_rel
            elif ground_truth_type == 'labels':
                if num_relevant_documents_per_label:
                    max_rel = num_relevant_documents_per_label.get(
                        d.tags[label_tag], None
                    )
                    if max_rel is None:
                        raise ValueError(
                            '`num_relevant_documents_per_label` misses the label '
                            + str(d.tags[label_tag])
                        )
                else:
                    max_rel = None
            else:
                max_rel = len(gd.matches)

            if strict and hash_fn(d) != hash_fn(gd):
                raise ValueError(
                    f'Document {d} from the left-hand side and '
                    f'{gd} from the right-hand side are not hashed to the same value. '
                    f'This means your left and right DocumentArray may not be aligned; '
                    f'or it means your `hash_fn` is badly designed.'
                )
            if not d.matches or not gd.matches:
                raise ValueError(
                    f'Document {d!r} or {gd!r} has no matches, please check your Document'
                )

            targets = gd.matches

            if ground_truth_type == 'matches':
                desired = {hash_fn(m) for m in targets}
                if len(desired) != len(targets):
                    warnings.warn(
                        f'{hash_fn!r} may not be valid, as it maps multiple Documents '
                        f'into the same hash. Evaluation results may be affected'
                    )
                binary_relevance = [
                    1 if hash_fn(m) in desired else 0 for m in d.matches
                ]
            elif ground_truth_type == 'labels':
                binary_relevance = [
                    1 if m.tags[label_tag] == d.tags[label_tag] else 0 for m in targets
                ]
            else:
                raise RuntimeError(
                    'Could not identify which kind of ground truth '
                    'information is provided to evaluate the matches.'
                )

            for metric_name, metric_fn in zip(metric_names, metric_fns):
                if 'max_rel' in metric_fn.__code__.co_varnames:
                    kwargs['max_rel'] = max_rel
                r = metric_fn(binary_relevance, **kwargs)
                d.evaluations[metric_name] = NamedScore(
                    value=r, op_name=str(metric_fn), ref_id=d.id
                )
                results[metric_name].append(r)

        return {
            metric_name: float(np.mean(values))
            for metric_name, values in results.items()
        }
    def embed_and_evaluate(
        self,
        metrics: List[Union[str, Callable[..., float]]],
        index_data: Optional['DocumentArray'] = None,
        ground_truth: Optional['DocumentArray'] = None,
        metric_names: Optional[List[str]] = None,
        strict: bool = True,
        label_tag: str = 'label',
        embed_models: Optional[Union['AnyDNN', Tuple['AnyDNN', 'AnyDNN']]] = None,
        embed_funcs: Optional[Union[Callable, Tuple[Callable, Callable]]] = None,
        device: str = 'cpu',
        batch_size: Union[int, Tuple[int, int]] = 256,
        collate_fns: Union[
            Optional['CollateFnType'],
            Tuple[Optional['CollateFnType'], Optional['CollateFnType']],
        ] = None,
        distance: Union[
            str, Callable[['ArrayType', 'ArrayType'], 'np.ndarray']
        ] = 'cosine',
        limit: Optional[Union[int, float]] = 20,
        normalization: Optional[Tuple[float, float]] = None,
        exclude_self: bool = False,
        use_scipy: bool = False,
        num_worker: int = 1,
        match_batch_size: int = 100_000,
        query_sample_size: int = 1_000,
        **kwargs,
    ) -> Dict[str, float]:  # average score for each metric
        """
        Computes ranking evaluation metrics for a given `DocumentArray`. This
        function does embedding and matching in the same turn. Thus, you don't need
        to call ``embed`` and ``match`` before it. Instead, it embeds the documents
        in `self` (and `index_data` when provided) and computes the nearest
        neighbours itself. This might be done in batches for the `index_data` object
        to reduce the memory consumption of the evaluation process. The evaluation
        itself can be done against a `ground_truth` DocumentArray or on the basis of
        labels like it is possible with the :func:`evaluate` function.

        :param metrics: List of metric names or metric functions to be computed
        :param index_data: The other DocumentArray to match against. If not given,
            `self` will be matched against itself. This means that every document in
            `self` will be compared to all other documents in `self` to determine
            the nearest neighbors.
        :param ground_truth: The ground truth `DocumentArray` that this
            `DocumentArray` is compared to.
        :param metric_names: If provided, the results of the metrics computation will
            be stored in the `evaluations` field of each Document under these names.
            If not provided, the names will be derived from the metric function names.
        :param strict: If set, the left- and right-hand sides are required to be fully
            aligned: same length and matching Documents at each position. This
            prevents you from accidentally evaluating on irrelevant matches.
        :param label_tag: Specifies the tag which contains the labels.
        :param embed_models: One or two embedding models written in Keras / PyTorch /
            Paddle for embedding `self` and `index_data`.
        :param embed_funcs: As an alternative to embedding models, custom embedding
            functions can be provided.
        :param device: The computational device for `embed_models` and for the
            matching; can be either `cpu` or `cuda`.
        :param batch_size: Number of documents in a batch for embedding.
        :param collate_fns: For each embedding function the respective collate
            function creates a mini-batch of input(s) from the given `DocumentArray`.
            If not provided, a default built-in collate_fn uses the `tensors` of the
            documents to create input batches.
        :param distance: The distance metric.
        :param limit: The maximum number of matches; when not given, defaults to 20.
        :param normalization: A tuple [a, b] to be used with min-max normalization;
            the min distance will be rescaled to `a`, the max distance will be
            rescaled to `b`, and all values will be rescaled into range `[a, b]`.
        :param exclude_self: If set, Documents in ``index_data`` with the same ``id``
            as the left-hand values will not be considered as matches.
        :param use_scipy: If set, use ``scipy`` as the computation backend. Note,
            ``scipy`` does not support distance on sparse matrix.
        :param num_worker: Specifies the number of workers for the execution of the
            match function.
        :param match_batch_size: The number of documents which are embedded and
            matched at once. Set this to a lower value if you experience high memory
            consumption.
        :param query_sample_size: For a large number of documents in `self` the
            evaluation becomes infeasible, especially if `index_data` is large.
            Therefore, queries are sampled if the number of documents in `self`
            exceeds `query_sample_size`. Usually, this has only a small impact on the
            mean metric values returned by this function. To prevent sampling, you
            can set `query_sample_size` to None.
        :param kwargs: Additional keyword arguments to be passed to the metric
            functions.
        :return: A dictionary which stores for each metric name the average
            evaluation score.
        """
        from docarray import Document, DocumentArray

        if not query_sample_size:
            query_sample_size = len(self)

        query_data = self
        only_one_dataset = not index_data
        apply_sampling = len(self) > query_sample_size

        if only_one_dataset:
            # if the user does not provide a separate set of documents for indexing,
            # the matching is done on the documents themselves
            copy_flag = (
                apply_sampling
                or (type(embed_funcs) is tuple)
                or ((embed_funcs is None) and (type(embed_models) is tuple))
            )
            index_data = DocumentArray(self, copy=True) if copy_flag else self

        if apply_sampling:
            rng = np.random.default_rng()
            query_data = DocumentArray(
                rng.choice(self, size=query_sample_size, replace=False)
            )

        if ground_truth and apply_sampling:
            ground_truth = DocumentArray(
                [ground_truth[d.id] for d in query_data if d.id in ground_truth]
            )
            if len(ground_truth) != len(query_data):
                raise ValueError(
                    'The DocumentArray provided in the ground_truth attribute does '
                    'not contain all the documents in self.'
                )

        index_data_labels = None
        if not ground_truth:
            if label_tag not in query_data[0].tags:
                raise ValueError(
                    'Either a ground_truth `DocumentArray` or labels are '
                    'required for the evaluation.'
                )
            if label_tag not in index_data[0].tags:
                raise ValueError(
                    'The `DocumentArray` provided in `index_data` misses labels.'
                )
            index_data_labels = dict()
            for id_value, tags in zip(index_data[:, 'id'], index_data[:, 'tags']):
                index_data_labels[id_value] = tags[label_tag]

        if embed_funcs is None:  # derive embed function from embed model
            if embed_models is None:
                raise RuntimeError(
                    'For embedding the documents you need to provide either '
                    'embedding model(s) or embedding function(s)'
                )
            else:
                if type(embed_models) is not tuple:
                    embed_models = (embed_models, embed_models)
                embed_args = [
                    {
                        'embed_model': model,
                        'device': device,
                        'batch_size': batch_size,
                        'collate_fn': collate_fns[i]
                        if type(collate_fns) is tuple
                        else collate_fns,
                    }
                    for i, (model, docs) in enumerate(
                        zip(embed_models, (query_data, index_data))
                    )
                ]
        else:
            if type(embed_funcs) is not tuple:
                # use the same embedding function for queries and index
                embed_funcs = (embed_funcs, embed_funcs)

        # embed queries:
        if embed_funcs:
            embed_funcs[0](query_data)
        else:
            query_data.embed(**embed_args[0])

        for doc in query_data:
            doc.matches.clear()

        local_queries = DocumentArray(
            [Document(id=doc.id, embedding=doc.embedding) for doc in query_data]
        )

        def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray):
            global_matches.extend(local_matches)
            global_matches = sorted(
                global_matches,
                key=lambda x: x.scores[distance].value,
            )[:limit]
            return DocumentArray(global_matches)

        for batch in index_data.batch(match_batch_size):
            if (
                apply_sampling
                or (batch.embeddings is None)
                or (batch[0].embedding[0] == 0)
            ):
                if embed_funcs:
                    embed_funcs[1](batch)
                else:
                    batch.embed(**embed_args[1])
            local_queries.match(
                batch,
                limit=limit,
                metric=distance,
                normalization=normalization,
                exclude_self=exclude_self,
                use_scipy=use_scipy,
                num_worker=num_worker,
                device=device,
                batch_size=int(len(batch) / num_worker) if num_worker > 1 else None,
                only_id=True,
            )
            for doc in local_queries:
                query_data[doc.id, 'matches'] = fuse_matches(
                    query_data[doc.id].matches, doc.matches
                )
            batch.embeddings = None

        # set labels if necessary
        if not ground_truth:
            for i, doc in enumerate(query_data):
                new_matches = DocumentArray()
                for m in doc.matches:
                    m.tags = {label_tag: index_data_labels[m.id]}
                    new_matches.append(m)
                query_data[doc.id, 'matches'] = new_matches

        if (
            not (ground_truth and len(ground_truth) > 0 and ground_truth[0].matches)
            and label_tag in index_data[0].tags
        ):
            num_relevant_documents_per_label = dict(
                Counter([d.tags[label_tag] for d in index_data])
            )
            if only_one_dataset and exclude_self:
                for k in num_relevant_documents_per_label:
                    num_relevant_documents_per_label[k] -= 1
        else:
            num_relevant_documents_per_label = None

        metrics_resp = query_data.evaluate(
            ground_truth=ground_truth,
            metrics=metrics,
            metric_names=metric_names,
            strict=strict,
            label_tag=label_tag,
            num_relevant_documents_per_label=num_relevant_documents_per_label,
            **kwargs,
        )
        return metrics_resp
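

# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original module). It assumes only that
# `docarray` is installed and that 'precision_at_k' / 'recall_at_k' exist in
# `docarray.math.evaluation`; the ids, labels, and the random embedding function
# below are made up for illustration, and the printed values are indicative
# rather than exact.
# ---------------------------------------------------------------------------
import numpy as np
from docarray import Document, DocumentArray


def _random_embed(da: 'DocumentArray') -> None:
    """Hypothetical stand-in for a real encoder: assigns fixed-size random vectors."""
    rng = np.random.default_rng(42)
    da.embeddings = rng.random((len(da), 8)).astype('float32')


queries = DocumentArray(Document(id=f'q{i}', tags={'label': i % 2}) for i in range(4))
index = DocumentArray(Document(id=f'd{i}', tags={'label': i % 2}) for i in range(20))

# 1) `evaluate()` on precomputed matches, using labels as ground truth:
for q in queries:
    q.matches = index[:5]  # pretend these are the nearest neighbours of `q`
print(queries.evaluate(metrics=['precision_at_k'], label_tag='label'))
# -> e.g. {'precision_at_k': 0.5}; each query also gets an entry in `.evaluations`

# 2) `embed_and_evaluate()` does the embedding and matching itself:
print(
    queries.embed_and_evaluate(
        metrics=['recall_at_k'],
        index_data=index,
        embed_funcs=_random_embed,  # used for both `queries` and `index`
        distance='cosine',
        limit=5,
    )
)
# -> e.g. {'recall_at_k': ...}; with random embeddings the score is roughly chance level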