Source code for docarray.array.storage.milvus.backend

import copy
import uuid
from typing import Optional, TYPE_CHECKING, Union, Dict, Iterable, List, Tuple
from dataclasses import dataclass, field
import re

import numpy as np
from pymilvus import (
    connections,
    Collection,
    FieldSchema,
    DataType,
    CollectionSchema,
    has_collection,
    loading_progress,
)

from docarray import Document, DocumentArray
from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap
from docarray.helper import dataclass_from_dict, _safe_cast_int
from docarray.score import NamedScore

if TYPE_CHECKING:
    from docarray.typing import (
        DocumentArraySourceType,
    )


ID_VARCHAR_LEN = 1024
SERIALIZED_VARCHAR_LEN = (
    65_535  # 65_535 is the maximum that Milvus allows for a VARCHAR field
)
COLUMN_VARCHAR_LEN = 1024
OFFSET_VARCHAR_LEN = 1024


def _always_true_expr(primary_key: str) -> str:
    """
    Returns a Milvus expression that is always true, thus allowing for the retrieval of all entries in a Collection.
    Assumes that the primary key is of type DataType.VARCHAR.

    :param primary_key: the name of the primary key
    :return: a Milvus expression that is always true for that primary key
    """
    return f'({primary_key} in ["1"]) or ({primary_key} not in ["1"])'
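# A hedged, doctest-style illustration of the expression this helper builds for
# the primary key field 'document_id' used by this backend (shown for reference
# only, not executed here):
#
#   >>> _always_true_expr('document_id')
#   '(document_id in ["1"]) or (document_id not in ["1"])'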


def _ids_to_milvus_expr(ids):
    ids = ['"' + _id + '"' for _id in ids]
    return '[' + ','.join(ids) + ']'
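# For reference, a sketch of the expression format this produces (the ids
# themselves are made-up examples):
#
#   >>> _ids_to_milvus_expr(['doc1', 'doc2'])
#   '["doc1","doc2"]'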


def _batch_list(l: List, batch_size: int):
    """Iterates over a list in batches of size batch_size"""
    if batch_size < 1:
        yield l
        return
    l_len = len(l)
    for ndx in range(0, l_len, batch_size):
        yield l[ndx : min(ndx + batch_size, l_len)]
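# Illustrative only: how batching behaves for a small list, and for the sentinel
# batch_size of -1, which yields the whole list at once:
#
#   >>> list(_batch_list([0, 1, 2, 3, 4], 2))
#   [[0, 1], [2, 3], [4]]
#   >>> list(_batch_list([0, 1, 2, 3, 4], -1))
#   [[0, 1, 2, 3, 4]]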


def _sanitize_collection_name(name):
    """Removes all chars that are not allowed in a Milvus collection name.
    Thus, it removes all chars that are not alphanumeric or an underscore.

    :param name: the collection name to sanitize
    :return: the sanitized collection name.
    """
    return ''.join(
        re.findall('[a-zA-Z0-9_]', name)
    )  # remove everything that is not a letter, number or underscore
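# A quick sketch of the sanitization (the input name is a made-up example):
#
#   >>> _sanitize_collection_name('my-collection.v2')
#   'mycollectionv2'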


@dataclass
class MilvusConfig:
    n_dim: int
    collection_name: str = None
    host: str = 'localhost'
    port: Optional[Union[str, int]] = 19530  # 19530 for gRPC, 9091 for HTTP
    distance: str = 'IP'  # metric_type in milvus
    index_type: str = 'HNSW'
    index_params: Dict = field(
        default_factory=lambda: {
            'M': 4,
            'efConstruction': 200,
        }
    )  # passed to milvus at index creation time. The default assumes 'HNSW' index type
    collection_config: Dict = field(
        default_factory=dict
    )  # passed to milvus at collection creation time
    serialize_config: Dict = field(default_factory=dict)
    consistency_level: str = 'Session'
    batch_size: int = -1
    columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None
    list_like: bool = True
    root_id: bool = True

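# A minimal sketch of building a config; only ``n_dim`` is required, the
# collection name here is an arbitrary example, and all other fields fall back
# to the defaults above:
#
#   >>> cfg = MilvusConfig(n_dim=128, collection_name='example_collection')
#   >>> cfg.distance, cfg.index_type, cfg.port
#   ('IP', 'HNSW', 19530)
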
class BackendMixin(BaseBackendMixin):

    TYPE_MAP = {
        'str': TypeMap(type=DataType.VARCHAR, converter=str),
        'float': TypeMap(
            type=DataType.DOUBLE, converter=float
        ),  # it doesn't like DataType.FLOAT type, perhaps because python floats are double precision?
        'double': TypeMap(type=DataType.DOUBLE, converter=float),
        'int': TypeMap(type=DataType.INT64, converter=_safe_cast_int),
        'bool': TypeMap(type=DataType.BOOL, converter=bool),
    }

    def _init_storage(
        self,
        _docs: Optional['DocumentArraySourceType'] = None,
        config: Optional[Union[MilvusConfig, Dict]] = None,
        **kwargs,
    ):
        config = copy.deepcopy(config)
        if not config:
            raise ValueError('Empty config is not allowed for Milvus storage')
        elif isinstance(config, dict):
            config = dataclass_from_dict(MilvusConfig, config)

        if config.collection_name is None:
            id = uuid.uuid4().hex
            config.collection_name = 'docarray__' + id
        self._list_like = config.list_like
        self._config = config
        self._config.columns = self._normalize_columns(self._config.columns)

        self._connection_alias = (
            f'docarray_{config.host}_{config.port}_{uuid.uuid4().hex}'
        )
        connections.connect(
            alias=self._connection_alias, host=config.host, port=config.port
        )

        self._collection = self._create_or_reuse_collection()
        self._offset2id_collection = self._create_or_reuse_offset2id_collection()

        self._build_index()
        super()._init_storage(**kwargs)

        # To align with Sqlite behavior; if `docs` is not `None` and table name
        # is provided, :class:`DocumentArraySqlite` will clear the existing
        # table and load the given `docs`
        if _docs is None:
            return

        self.clear()

        if isinstance(_docs, Iterable):
            self.extend(_docs)
        else:
            if isinstance(_docs, Document):
                self.append(_docs)

    def _create_or_reuse_collection(self):
        if has_collection(self._config.collection_name, using=self._connection_alias):
            return Collection(
                self._config.collection_name, using=self._connection_alias
            )

        document_id = FieldSchema(
            name='document_id',
            dtype=DataType.VARCHAR,
            max_length=ID_VARCHAR_LEN,
            is_primary=True,
        )
        embedding = FieldSchema(
            name='embedding', dtype=DataType.FLOAT_VECTOR, dim=self._config.n_dim
        )
        serialized = FieldSchema(
            name='serialized',
            dtype=DataType.VARCHAR,
            max_length=SERIALIZED_VARCHAR_LEN,
        )

        additional_columns = []
        for col, coltype in self._config.columns.items():
            mapped_type = self._map_type(coltype)
            if mapped_type == DataType.VARCHAR:
                field_ = FieldSchema(
                    name=col, dtype=mapped_type, max_length=COLUMN_VARCHAR_LEN
                )
            else:
                field_ = FieldSchema(name=col, dtype=mapped_type)
            additional_columns.append(field_)

        schema = CollectionSchema(
            fields=[document_id, embedding, serialized, *additional_columns],
            description='DocumentArray collection schema',
        )
        return Collection(
            name=self._config.collection_name,
            schema=schema,
            using=self._connection_alias,
            **self._config.collection_config,
        )

    def _build_index(self):
        index_params = {
            'metric_type': self._config.distance,
            'index_type': self._config.index_type,
            'params': self._config.index_params,
        }
        self._collection.create_index(
            field_name='embedding', index_params=index_params
        )

    def _create_or_reuse_offset2id_collection(self):
        if has_collection(
            self._config.collection_name + '_offset2id', using=self._connection_alias
        ):
            return Collection(
                self._config.collection_name + '_offset2id',
                using=self._connection_alias,
            )

        document_id = FieldSchema(
            name='document_id', dtype=DataType.VARCHAR, max_length=ID_VARCHAR_LEN
        )
        offset = FieldSchema(
            name='offset',
            dtype=DataType.VARCHAR,
            max_length=OFFSET_VARCHAR_LEN,
            is_primary=True,
        )
        dummy_vector = FieldSchema(
            name='dummy_vector', dtype=DataType.FLOAT_VECTOR, dim=1
        )

        schema = CollectionSchema(
            fields=[offset, document_id, dummy_vector],
            description='offset2id for DocumentArray',
        )

        return Collection(
            name=self._config.collection_name + '_offset2id',
            schema=schema,
            using=self._connection_alias,
            # **self._config.collection_config,  # we probably don't want to apply the same config here
        )

    def _ensure_unique_config(
        self,
        config_root: dict,
        config_subindex: dict,
        config_joined: dict,
        subindex_name: str,
    ) -> dict:
        if 'collection_name' not in config_subindex:
            config_joined['collection_name'] = _sanitize_collection_name(
                config_joined['collection_name'] + '_subindex_' + subindex_name
            )
        return config_joined

    def _doc_to_milvus_payload(self, doc):
        return self._docs_to_milvus_payload([doc])

    def _docs_to_milvus_payload(self, docs: 'Iterable[Document]'):
        extra_columns = [
            [self._map_column(doc.tags.get(col), col_type) for doc in docs]
            for col, col_type in self._config.columns.items()
        ]
        return [
            [doc.id for doc in docs],
            [self._map_embedding(doc.embedding) for doc in docs],
            [doc.to_base64(**self._config.serialize_config) for doc in docs],
            *extra_columns,
        ]

    @staticmethod
    def _docs_from_query_response(response):
        return DocumentArray(
            [Document.from_base64(d['serialized']) for d in response]
        )

    @staticmethod
    def _docs_from_search_response(responses, distance: str) -> 'List[DocumentArray]':
        das = []
        for r in responses:
            da = []
            for hit in r:
                doc = Document.from_base64(hit.entity.get('serialized'))
                doc.scores[distance] = NamedScore(value=hit.score)
                da.append(doc)
            das.append(DocumentArray(da))
        return das

    def _update_kwargs_from_config(self, field_to_update, **kwargs):
        kwargs_field_value = kwargs.get(field_to_update, None)
        config_field_value = getattr(self._config, field_to_update, None)
        if (
            kwargs_field_value is not None or config_field_value is None
        ):  # no need to update
            return kwargs

        kwargs[field_to_update] = config_field_value
        return kwargs

    def _map_embedding(self, embedding):
        if embedding is not None:
            from docarray.math.ndarray import to_numpy_array

            embedding = to_numpy_array(embedding)

            if embedding.ndim > 1:
                embedding = np.asarray(embedding).squeeze()
        else:
            embedding = np.zeros(self._config.n_dim)
        return embedding

    def __getstate__(self):
        d = dict(self.__dict__)
        del d['_collection']
        del d['_offset2id_collection']
        return d

    def __setstate__(self, state):
        self.__dict__ = state
        connections.connect(
            alias=self._connection_alias,
            host=self._config.host,
            port=self._config.port,
        )
        self._collection = self._create_or_reuse_collection()
        self._offset2id_collection = self._create_or_reuse_offset2id_collection()

    def __enter__(self):
        _ = super().__enter__()
        self._collection.load()
        self._offset2id_collection.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._collection.release()
        self._offset2id_collection.release()
        super().__exit__(exc_type, exc_val, exc_tb)

    def loaded_collection(self, collection=None):
        """
        Context manager to load a collection and release it after the context is exited.
        If the collection is already loaded when entering, it will not be released while exiting.

        :param collection: the collection to load. If None, the main collection of this indexer is used.
        :return: Context manager for the provided collection.
        """

        class LoadedCollectionManager:
            def __init__(self, coll, connection_alias):
                self._collection = coll
                self._loaded_when_enter = False
                self._connection_alias = connection_alias

            def __enter__(self):
                self._loaded_when_enter = (
                    loading_progress(
                        self._collection.name, using=self._connection_alias
                    )['loading_progress']
                    != '0%'
                )
                self._collection.load()
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                if not self._loaded_when_enter:
                    self._collection.release()

        return LoadedCollectionManager(
            collection if collection else self._collection, self._connection_alias
        )
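
# A minimal usage sketch, assuming a Milvus server is reachable at the default
# host/port; ``storage='milvus'`` selects this backend when constructing a
# DocumentArray, and the ``with`` block maps to ``__enter__``/``__exit__`` above,
# which load and release the underlying collections:
#
#   >>> from docarray import Document, DocumentArray
#   >>> da = DocumentArray(storage='milvus', config={'n_dim': 3})
#   >>> with da:
#   ...     da.append(Document(embedding=[1.0, 2.0, 3.0]))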