Source code for docarray.array.storage.qdrant.find

from abc import abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, Union

from docarray import Document, DocumentArray
from docarray.math import ndarray
from docarray.score import NamedScore
from qdrant_client.http import models
from qdrant_client.http.models.models import Distance

if TYPE_CHECKING:  # pragma: no cover
    # This branch is only evaluated by static type checkers (TYPE_CHECKING is
    # False at runtime), so the names below exist purely for annotations and
    # must be referenced as string annotations elsewhere in the module.
    import numpy as np
    import tensorflow
    import torch
    from qdrant_client import QdrantClient

    # Constrained type variable listing the array-like inputs accepted as a
    # query: a NumPy array, a TensorFlow tensor, a PyTorch tensor or a plain
    # sequence of floats.
    QdrantArrayType = TypeVar(
        'QdrantArrayType',
        np.ndarray,
        tensorflow.Tensor,
        torch.Tensor,
        Sequence[float],
    )


class FindMixin:
    """Search and filter helpers for a Qdrant-backed document store.

    The host class must supply the ``client``, ``collection_name``,
    ``serialize_config`` and ``distance`` properties declared below, as well
    as a ``_map_embedding`` method, which is called by
    ``_find_similar_vectors`` but is not defined in this mixin.
    """

    @property
    @abstractmethod
    def client(self) -> 'QdrantClient':
        """Client object used to issue ``search`` and ``scroll`` requests."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def collection_name(self) -> str:
        """Name of the collection that is searched and scrolled."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def serialize_config(self) -> dict:
        """Keyword arguments forwarded to ``Document.from_base64``."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def distance(self) -> 'Distance':
        """Distance whose lower-cased name prefixes the score key."""
        raise NotImplementedError()

    def _find_similar_vectors(
        self,
        q: 'QdrantArrayType',
        limit: int = 10,
        filter: Optional[Dict] = None,
        search_params: Optional[Dict] = None,
        **kwargs,
    ):
        """Search the collection for the nearest neighbours of one query.

        :param q: a single query embedding; it is converted with
            ``self._map_embedding`` before being sent to the client.
        :param limit: number of retrieved items
        :param filter: filter query, passed to the client as ``query_filter``
        :param search_params: additional parameters of the search; when given
            they are wrapped in ``models.SearchParams``
        :param kwargs: accepted for signature compatibility and ignored
        :return: a ``DocumentArray`` with one Document per hit, each carrying
            the hit score under the key ``'<distance>_similarity'``.
        """
        query_vector = self._map_embedding(q)

        search_result = self.client.search(
            self.collection_name,
            query_vector=query_vector,
            query_filter=filter,
            search_params=None
            if not search_params
            else models.SearchParams(**search_params),
            limit=limit,
            append_payload=['_serialized'],
        )

        docs = []
        for hit in search_result:
            # Each point stores the whole Document, base64-encoded, in its
            # payload under '_serialized'.
            doc = Document.from_base64(
                hit.payload['_serialized'], **self.serialize_config
            )
            doc.scores[f'{self.distance.lower()}_similarity'] = NamedScore(
                value=hit.score
            )
            docs.append(doc)

        return DocumentArray(docs)

    def _find(
        self,
        query: 'QdrantArrayType',
        limit: int = 10,
        filter: Optional[Dict] = None,
        search_params: Optional[Dict] = None,
        **kwargs,
    ) -> List['DocumentArray']:
        """Returns approximate nearest neighbors given a batch of input queries.

        :param query: input supported to be used in Qdrant.
        :param limit: number of retrieved items
        :param filter: filter query used for pre-filtering
        :param search_params: additional parameters of the search
        :param kwargs: accepted for signature compatibility and ignored
        :return: a list of DocumentArrays containing the closest Document
            objects for each of the queries in `query`.
        """
        num_rows, _ = ndarray.get_array_rows(query)

        if num_rows == 1:
            # A single query is forwarded as-is rather than iterated over.
            return [
                self._find_similar_vectors(
                    query, limit=limit, filter=filter, search_params=search_params
                )
            ]

        return [
            self._find_similar_vectors(
                q, limit=limit, filter=filter, search_params=search_params
            )
            for q in query
        ]

    def _find_with_filter(
        self, filter: Optional[Dict], limit: Optional[Union[int, float]] = 10
    ):
        """Scroll the collection and decode the points matching ``filter``.

        :param filter: keyword arguments for ``models.Filter``; ``None`` means
            no filtering is applied.
        :param limit: number of retrieved items
        :return: a ``DocumentArray`` with the decoded Documents.
        """
        # `filter` is declared Optional, so only unpack it when it is given;
        # `models.Filter(**None)` would raise a TypeError.
        scroll_filter = models.Filter(**filter) if filter is not None else None

        list_of_points, _offset = self.client.scroll(
            collection_name=self.collection_name,
            scroll_filter=scroll_filter,
            with_payload=True,
            limit=limit,
        )

        da = DocumentArray()
        for result in list_of_points[:limit]:
            doc = Document.from_base64(
                result.payload['_serialized'], **self.serialize_config
            )
            da.append(doc)

        return da

    def _filter(
        self, filter: Optional[Dict], limit: Optional[Union[int, float]] = 10
    ) -> 'DocumentArray':
        """Returns a subset of documents by filtering by the given filter (`Qdrant` filter).

        :param limit: number of retrieved items
        :param filter: filter query used for filtering.
            For more information: https://docs.docarray.org/advanced/document-store/qdrant/#qdrant
        :return: a `DocumentArray` containing the `Document` objects that verify the filter.
        """
        return self._find_with_filter(filter, limit=limit)