Source code for docarray.array.storage.weaviate.find

from typing import (
    TYPE_CHECKING,
    TypeVar,
    Sequence,
    List,
    Dict,
    Optional,
    Union,
)

import numpy as np

from docarray import Document, DocumentArray
from docarray.math import ndarray
from docarray.math.helper import EPSILON
from docarray.math.ndarray import to_numpy_array
from docarray.score import NamedScore

if TYPE_CHECKING:  # pragma: no cover
    import tensorflow
    import torch

    WeaviateArrayType = TypeVar(
        'WeaviateArrayType',
        np.ndarray,
        tensorflow.Tensor,
        torch.Tensor,
        Sequence[float],
    )


[docs]class FindMixin: def _find_similar_vectors( self, query: 'WeaviateArrayType', limit=10, filter: Optional[Dict] = None, additional: Optional[List] = None, sort: Optional[Union[Dict, List]] = None, query_params: Optional[Dict] = None, ): """Returns a subset of documents by the given vector. :param query: input supported to be stored in Weaviate. This includes any from the list '[np.ndarray, tensorflow.Tensor, torch.Tensor, Sequence[float]]' :param limit: number of retrieved items :param filter: the input filter to apply in each stored document :param additional: Optional Weaviate flags for meta data :param sort: sort parameters performed on matches performed on results :param query_params: additional parameters applied to the query outside of the where clause :return: a `DocumentArray` containing the `Document` objects that verify the filter. """ query = to_numpy_array(query) is_all_zero = np.all(query == 0) if is_all_zero: query = query + EPSILON query_dict = {'vector': query} if query_params: query_dict.update(query_params) _additional = ['id', 'distance'] if additional: _additional = _additional + additional query_builder = ( self._client.query.get(self._class_name, '_serialized') .with_additional(_additional) .with_limit(limit) .with_near_vector(query_dict) ) if filter is not None: query_builder = query_builder.with_where(filter) if sort is not None: query_builder = query_builder.with_sort(sort) results = query_builder.do() if 'errors' in results: errors = '\n'.join(map(lambda error: error['message'], results['errors'])) raise ValueError( f'find failed, please check your filter query. Errors: \n{errors}' ) found_results = results.get('data', {}).get('Get', {}).get(self._class_name, []) # The serialized document is stored in results['data']['Get'][self._class_name] docs = [] for result in found_results: doc = Document.from_base64(result['_serialized'], **self._serialize_config) distance = result['_additional']['distance'] doc.scores['distance'] = NamedScore(value=distance) certainty = result['_additional'].get('certainty', None) if certainty is not None: doc.scores['weaviate_certainty'] = NamedScore(value=certainty) doc.tags['wid'] = result['_additional']['id'] if additional: for add in additional: doc.tags[f'{add}'] = result['_additional'][add] docs.append(doc) return DocumentArray(docs) def _filter( self, filter: Dict, limit: Optional[Union[int, float]] = 20, additional: Optional[List] = None, sort: Optional[Union[Dict, List]] = None, ) -> 'DocumentArray': """Returns a subset of documents by filtering by the given filter (Weaviate `where` filter). :param filter: the input filter to apply in each stored document :param limit: number of retrieved items :param additional: Optional Weaviate flags for meta data :param sort: sort parameters performed on matches performed on results :return: a `DocumentArray` containing the `Document` objects that verify the filter. """ if not filter: return self _additional = ['id'] if additional: _additional = _additional + additional query_builder = ( self._client.query.get(self._class_name, '_serialized') .with_additional(_additional) .with_where(filter) .with_limit(limit) ) if sort: query_builder = query_builder.with_sort(sort) results = query_builder.do() docs = [] if 'errors' in results: errors = '\n'.join(map(lambda error: error['message'], results['errors'])) raise ValueError( f'filter failed, please check your filter query. Errors: \n{errors}' ) found_results = results.get('data', {}).get('Get', {}).get(self._class_name, []) # The serialized document is stored in results['data']['Get'][self._class_name] for result in found_results: doc = Document.from_base64(result['_serialized'], **self._serialize_config) doc.tags['wid'] = result['_additional']['id'] if additional: for add in additional: doc.tags[f'{add}'] = result['_additional'][add] docs.append(doc) return DocumentArray(docs) def _find( self, query: 'WeaviateArrayType', limit: int = 10, filter: Optional[Dict] = None, additional: Optional[List] = None, sort: Optional[Union[Dict, List]] = None, query_params: Optional[Dict] = None, **kwargs, ) -> List['DocumentArray']: """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be stored in Weaviate. This includes any from the list '[np.ndarray, tensorflow.Tensor, torch.Tensor, Sequence[float]]' :param limit: number of retrieved items :param filter: filter query used for pre-filtering :param additional: Optional Weaviate flags for meta data :param sort: sort parameters performed on matches performed on results :param query_params: additional parameters applied to the query outside of the where clause :return: DocumentArray containing the closest documents to the query if it is a single query, otherwise a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. Note: Weaviate returns `certainty` values. To get cosine similarities one needs to use `cosine_sim = 2*certainty - 1` as explained here: https://weaviate.io/developers/weaviate/current/more-resources/faq.html#q-how-do-i-get-the-cosine-similarity-from-weaviates-certainty """ num_rows, _ = ndarray.get_array_rows(query) if num_rows == 1: return [ self._find_similar_vectors( query, limit=limit, additional=additional, filter=filter, sort=sort, query_params=query_params, ) ] else: closest_docs = [] for q in query: da = self._find_similar_vectors( q, limit=limit, additional=additional, filter=filter, sort=sort, query_params=query_params, ) closest_docs.append(da) return closest_docs