# Source code for docarray.array.storage.qdrant.backend

import copy
import uuid
from abc import abstractmethod
from dataclasses import dataclass, field, asdict
from typing import (
    Optional,
    TYPE_CHECKING,
    Union,
    Dict,
    Iterable,
    List,
    Tuple,
)

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models.models import (
    Distance,
    CreateCollection,
    PointsList,
    PointStruct,
    HnswConfigDiff,
    VectorParams,
)

from docarray import Document
from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap
from docarray.array.storage.qdrant.helper import DISTANCES
from docarray.helper import dataclass_from_dict, random_identity
from docarray.math.helper import EPSILON

if TYPE_CHECKING:  # pragma: no cover
    from docarray.typing import DocumentArraySourceType, ArrayType


@dataclass
class QdrantConfig:
    """Configuration for the Qdrant storage backend.

    Only ``n_dim`` is mandatory; every other field has a default.
    """

    # Size of the vectors stored in the collection.
    n_dim: int
    # Name of the distance metric; validated against ``DISTANCES`` by the backend.
    distance: str = 'cosine'
    # Collection to use; the backend generates a random name when this is empty.
    collection_name: Optional[str] = None
    # Stored by the backend as ``_list_like``.
    list_like: bool = True
    # Connection settings, forwarded unchanged to ``QdrantClient``.
    host: Optional[str] = field(default="localhost")
    port: Optional[int] = field(default=6333)
    grpc_port: Optional[int] = field(default=6334)
    prefer_grpc: Optional[bool] = field(default=False)
    api_key: Optional[str] = field(default=None)
    https: Optional[bool] = field(default=None)
    # Stored by the backend as ``_serialize_config``.
    serialize_config: Dict = field(default_factory=dict)
    # NOTE(review): presumably the page size used when scrolling through the
    # collection -- not used in this module, confirm against the callers.
    scroll_batch_size: int = 64
    # HNSW index parameters, forwarded to ``HnswConfigDiff``.
    ef_construct: Optional[int] = None
    full_scan_threshold: Optional[int] = None
    m: Optional[int] = None
    # Extra payload columns as ``name -> type``; normalized by the backend and
    # used to create payload indexes.
    columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None
    # NOTE(review): not used in this module -- semantics defined by the callers.
    root_id: bool = True
class BackendMixin(BaseBackendMixin):
    """Backend mixin that connects a DocumentArray to a Qdrant collection.

    Concrete classes must provide the ``client``, ``collection_name`` and
    ``distance`` properties declared abstract below.
    """

    @property
    @abstractmethod
    def client(self) -> 'QdrantClient':
        """Return the Qdrant client used for all requests."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def collection_name(self) -> str:
        """Return the name of the collection holding the documents."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def distance(self) -> 'Distance':
        """Return the distance metric of the collection."""
        raise NotImplementedError()

    @classmethod
    def _tmp_collection_name(cls) -> str:
        """Return a random collection name (a UUID4 in hex form)."""
        return uuid.uuid4().hex

    # Maps column types accepted in ``QdrantConfig.columns`` to the payload
    # schema type passed to Qdrant and to a Python converter.
    TYPE_MAP = {
        'int': TypeMap(type='integer', converter=int),
        'float': TypeMap(type='float', converter=float),
        # 'int' is not a Qdrant payload schema type; use 'integer', the same
        # schema type as the 'int' column above.
        'bool': TypeMap(type='integer', converter=bool),
        'str': TypeMap(type='keyword', converter=str),
        'text': TypeMap(type='text', converter=str),
        'geo': TypeMap(type='geo', converter=dict),
    }

    def _init_storage(
        self,
        docs: Optional['DocumentArraySourceType'] = None,
        config: Optional[Union[QdrantConfig, Dict]] = None,
        **kwargs,
    ):
        """Initialize qdrant storage.

        :param docs: the list of documents to initialize to
        :param config: the config object used to initialize connection to
            qdrant server
        :param kwargs: extra keyword arguments
        :raises ValueError: if ``config`` is empty, or if ``config.distance``
            is not one of the supported distances
        """
        # Work on a copy so the caller's config object is never mutated.
        config = copy.deepcopy(config)
        self._schemas = None

        if not config:
            raise ValueError('Empty config is not allowed for Qdrant storage')
        elif isinstance(config, dict):
            config = dataclass_from_dict(QdrantConfig, config)

        if config.distance not in DISTANCES:
            supported = ', '.join(DISTANCES.keys())
            raise ValueError(
                f'Invalid distance parameter, must be one of: {supported}'
            )

        # From here on the collection name is always a non-empty string.
        if not config.collection_name:
            config.collection_name = self._tmp_collection_name()

        self._n_dim = config.n_dim
        self._distance = config.distance
        self._serialize_config = config.serialize_config

        self._client = QdrantClient(
            host=config.host,
            port=config.port,
            prefer_grpc=config.prefer_grpc,
            grpc_port=config.grpc_port,
            api_key=config.api_key,
            https=config.https,
        )

        self._config = config
        self._list_like = config.list_like
        self._config.columns = self._normalize_columns(self._config.columns)

        self._initialize_qdrant_schema()

        super()._init_storage(**kwargs)

        # Without docs we simply attach to the (possibly pre-existing)
        # collection. The name check is kept for parity with the original
        # condition; it is always true at this point.
        if docs is None and config.collection_name:
            return

        # To align with Sqlite behavior; if `docs` is not `None` and table name
        # is provided, :class:`DocumentArraySqlite` will clear the existing
        # table and load the given `docs`
        self.clear()
        if isinstance(docs, Iterable):
            self.extend(docs)
        elif isinstance(docs, Document):
            self.append(docs)

    def _ensure_unique_config(
        self,
        config_root: dict,
        config_subindex: dict,
        config_joined: dict,
        subindex_name: str,
    ) -> dict:
        """Give a subindex its own collection name unless one was configured.

        :param config_root: config of the root index (not read here)
        :param config_subindex: config explicitly provided for the subindex
        :param config_joined: merged config; updated in place and returned
        :param subindex_name: name of the subindex, used as a suffix
        :return: ``config_joined``
        """
        if 'collection_name' not in config_subindex:
            config_joined['collection_name'] = (
                config_joined['collection_name'] + '_subindex_' + subindex_name
            )
        return config_joined

    def _initialize_qdrant_schema(self):
        """Create the collection and its payload indexes if it is missing."""
        if not self._collection_exists(self.collection_name):
            hnsw_config = HnswConfigDiff(
                ef_construct=self._config.ef_construct,
                full_scan_threshold=self._config.full_scan_threshold,
                m=self._config.m,
            )
            self.client.recreate_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.n_dim,
                    distance=self.distance,
                ),
                hnsw_config=hnsw_config,
            )

            for col, coltype in self._config.columns.items():
                if coltype == 'text':
                    # Text columns get a dedicated index with a word tokenizer.
                    field_schema = models.TextIndexParams(
                        type="text",
                        tokenizer=models.TokenizerType.WORD,
                    )
                else:
                    field_schema = self._map_type(coltype)
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=col,
                    field_schema=field_schema,
                )

    def _collection_exists(self, collection_name):
        """Return whether ``collection_name`` exists on the server."""
        resp = self.client.get_collections()
        return any(
            collection.name == collection_name for collection in resp.collections
        )

    @staticmethod
    def _map_id(doc_id: str):
        """Translate a document id into a UUID string usable as a point id.

        :param doc_id: the document id
        :return: the id as a canonical UUID string
        """
        # if doc_id is a random ID in hex format, just translate back to UUID str
        # otherwise, create UUID5 from doc_id
        try:
            return str(uuid.UUID(hex=doc_id))
        except ValueError:
            return str(uuid.uuid5(uuid.NAMESPACE_URL, doc_id))

    def __getstate__(self):
        """Return the instance state without the client object."""
        d = dict(self.__dict__)
        # The client is rebuilt in ``__setstate__``; tolerate its absence on a
        # partially initialised instance.
        d.pop('_client', None)
        return d

    def __setstate__(self, state):
        """Restore the instance state and rebuild the client from the config."""
        self.__dict__ = state
        config = state['_config']
        self._client = QdrantClient(
            host=config.host,
            port=config.port,
            prefer_grpc=config.prefer_grpc,
            grpc_port=config.grpc_port,
            api_key=config.api_key,
            https=config.https,
        )

    def _get_offset2ids_meta(self) -> List[str]:
        """Load the persisted offset-to-id list from the meta collection.

        :return: the stored list, or an empty list when the meta collection,
            the meta point or the ``offset2id`` key does not exist
        """
        if not self._collection_exists(self.collection_name_meta):
            return []
        records = self.client.retrieve(self.collection_name_meta, ids=[1])
        if not records:
            # The collection exists but the meta point was never written.
            return []
        payload = records[0].payload or {}
        return payload.get('offset2id', [])

    def _update_offset2ids_meta(self):
        """Persist the current offset-to-id list into the meta collection."""
        if not self._collection_exists(self.collection_name_meta):
            self.client.recreate_collection(
                collection_name=self.collection_name_meta,
                vectors_config={},  # no vectors
            )

        # The whole mapping is stored in the payload of a single point, id 1.
        self.client.upsert(
            collection_name=self.collection_name_meta,
            points=[
                PointStruct(
                    id=1, payload={"offset2id": self._offset2ids.ids}, vector={}
                )
            ],
            wait=True,
        )

    def _map_embedding(self, embedding: 'ArrayType') -> List[float]:
        """Convert an embedding into a flat list of floats.

        :param embedding: the embedding to convert; ``None`` is replaced by a
            random vector of size ``n_dim``
        :return: the embedding as a list of Python floats
        """
        if embedding is None:
            embedding = np.random.rand(self.n_dim)
        else:
            from docarray.math.ndarray import to_numpy_array

            embedding = to_numpy_array(embedding)
            if embedding.ndim > 1:
                embedding = np.asarray(embedding).squeeze()
            if embedding.ndim == 0:  # scalar
                embedding = np.array([embedding])

        # An all-zero vector is shifted by EPSILON.
        # NOTE(review): presumably because such a vector is a problem for
        # cosine distance -- confirm.
        if np.all(embedding == 0):
            embedding = embedding + EPSILON
        return embedding.astype(float).tolist()