Source code for docarray.document.mixins.multimodal

import base64
import typing

from docarray.dataclasses.types import (
    is_multimodal,
    _is_field,
    AttributeTypeError,
)
from docarray.dataclasses.enums import DocumentMetadata, AttributeType

if typing.TYPE_CHECKING:
    from docarray import Document, DocumentArray


class MultiModalMixin:
    """Mixin that maps dataclass-style multi-modal objects onto a Document.

    The mapping is recorded in ``self._metadata`` under
    ``DocumentMetadata.MULTI_MODAL_SCHEMA``: one entry per dataclass field,
    holding its ``attribute_type``, a ``type`` name and, for fields stored as
    chunks, the ``position`` of the corresponding chunk.
    """

    @property
    def is_multimodal(self) -> bool:
        """
        Return true if this Document can be represented by a class wrapped
        by :meth:`docarray.dataclasses.types.dataclass`.
        """
        return DocumentMetadata.MULTI_MODAL_SCHEMA in self._metadata

    @classmethod
    def _from_dataclass(cls, obj) -> 'Document':
        """Build a root Document from a multi-modal dataclass instance.

        Fields whose value is ``None`` are skipped. Plain ``str``/``int``/
        ``float``/``bool`` fields and ``List``/``Iterable`` of those are
        stored in ``tags``; ``bytes`` fields are stored in ``tags`` as
        base64 text. Every other supported field becomes a chunk of the root
        Document (iterables become one chunk whose own chunks are the
        elements).

        :param obj: the dataclass instance to convert.
        :return: the root Document with ``tags``, chunks and schema set.
        :raises TypeError: if ``obj`` is not a multi-modal dataclass instance
            or a generic annotation is not supported.
        """
        if not is_multimodal(obj):
            raise TypeError(
                f'Object {type(obj).__name__} is not a `docarray.dataclass` instance'
            )

        # Imported locally, as in the original (presumably to avoid a
        # circular import between `docarray` and this mixin module).
        from docarray import Document

        root = Document()
        tags = {}
        multi_modal_schema = {}
        for key, field in obj.__dataclass_fields__.items():
            attribute = getattr(obj, key)
            if attribute is None:
                continue

            if field.type in [str, int, float, bool] and not _is_field(field):
                tags[key] = attribute
                multi_modal_schema[key] = {
                    'attribute_type': AttributeType.PRIMITIVE,
                    'type': field.type.__name__,
                }
            elif field.type == bytes and not _is_field(field):
                # Raw bytes are stored in tags as base64-encoded text.
                tags[key] = base64.b64encode(attribute).decode()
                multi_modal_schema[key] = {
                    'attribute_type': AttributeType.PRIMITIVE,
                    'type': field.type.__name__,
                }
            elif isinstance(field.type, typing._GenericAlias):
                # NOTE(review): relies on the private `typing._GenericAlias`
                # and `_name`; builtin generics such as `list[str]` may not
                # match this check - confirm the supported Python versions.
                if field.type._name in ['List', 'Iterable']:
                    sub_type = field.type.__args__[0]
                    if sub_type in [str, int, float, bool]:
                        tags[key] = attribute
                        multi_modal_schema[key] = {
                            'attribute_type': AttributeType.ITERABLE_PRIMITIVE,
                            'type': f'{field.type._name}[{sub_type.__name__}]',
                        }
                    else:
                        try:
                            attribute_type = cls._get_attribute_type_from_obj_type(
                                sub_type, field
                            )
                        except AttributeTypeError as err:
                            # Chain the cause so the original failure stays
                            # visible in the traceback.
                            raise TypeError(
                                f'Unsupported type annotation inside Iterable: {sub_type}'
                            ) from err

                        if attribute_type == AttributeType.DOCUMENT:
                            attribute_type = AttributeType.ITERABLE_DOCUMENT
                        elif attribute_type == AttributeType.NESTED:
                            attribute_type = AttributeType.ITERABLE_NESTED

                        # One container chunk per iterable field; each element
                        # becomes a chunk of that container.
                        chunk = Document()
                        for element in attribute:
                            doc, _ = cls._from_obj(element, sub_type, field)
                            chunk.chunks.append(doc)
                        multi_modal_schema[key] = {
                            'attribute_type': attribute_type,
                            'type': f'{field.type._name}[{sub_type.__name__}]',
                            'position': len(root.chunks),
                        }
                        root.chunks.append(chunk)
                else:
                    raise TypeError(
                        f'Unsupported type annotation on field `{field.type._name}`'
                    )
            else:
                doc, attribute_type = cls._from_obj(attribute, field.type, field)
                multi_modal_schema[key] = {
                    'attribute_type': attribute_type,
                    'type': field.type.__name__,
                    # Position is taken before appending, so it is the index
                    # the new chunk will occupy.
                    'position': len(root.chunks),
                }
                root.chunks.append(doc)

        # TODO: may have to modify this?
        root.tags = tags
        root._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA] = multi_modal_schema

        return root

    def _get_mm_attr_postion(self, attr):
        """Return the chunk position recorded in the schema for ``attr``.

        :param attr: name of the multi-modal attribute.
        :return: the chunk index as an ``int``.
        :raises ValueError: if this Document is not multi-modal, if ``attr``
            is not in the schema, or if the schema entry has no position.
        """
        if not self.is_multimodal:
            raise ValueError(
                'the Document does not correspond to a Multi Modal Document'
            )

        schema = self._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA]
        if attr not in schema:
            raise ValueError(
                f'the Document schema does not contain attribute `{attr}`, typo?'
            )

        pos = schema[attr].get('position')
        if pos is None:
            raise ValueError(
                f'attribute {attr} is not a valid multi modal attribute.'
                f' One possible cause is the usage of a non-supported type in the dataclass definition.'
            )
        return int(pos)

    def get_multi_modal_attribute(self, attribute: str) -> 'DocumentArray':
        """Return the Document(s) stored for a multi-modal attribute.

        :param attribute: name of the multi-modal attribute.
        :return: a DocumentArray wrapping the single chunk for Document and
            nested attributes, or the chunks of the container chunk for
            iterable attributes.
        :raises ValueError: if the attribute has no chunk position or its
            attribute type is not Document-like.
        """
        from docarray import DocumentArray

        position = self._get_mm_attr_postion(attribute)
        attribute_type = self._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA][attribute][
            'attribute_type'
        ]
        if attribute_type in [AttributeType.DOCUMENT, AttributeType.NESTED]:
            return DocumentArray([self.chunks[position]])
        elif attribute_type in [
            AttributeType.ITERABLE_DOCUMENT,
            AttributeType.ITERABLE_NESTED,
        ]:
            return self.chunks[position].chunks
        else:
            raise ValueError(
                f'Invalid attribute {attribute}: must be a Document attribute or nested dataclass'
            )

    def set_multi_modal_attribute(
        self, attribute: str, value: typing.Union['Document', 'DocumentArray']
    ):
        """Replace the Document(s) stored for a multi-modal attribute.

        :param attribute: name of the multi-modal attribute.
        :param value: assigned as the chunk itself for Document and nested
            attributes, or as the chunks of the container chunk for iterable
            attributes.
        :raises ValueError: if the attribute has no chunk position or its
            attribute type is not Document-like.
        """
        position = self._get_mm_attr_postion(attribute)
        attribute_type = self._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA][attribute][
            'attribute_type'
        ]
        if attribute_type in [AttributeType.DOCUMENT, AttributeType.NESTED]:
            self.chunks[position] = value
        elif attribute_type in [
            AttributeType.ITERABLE_DOCUMENT,
            AttributeType.ITERABLE_NESTED,
        ]:
            self.chunks[position].chunks = value
        else:
            raise ValueError(
                f'Invalid attribute {attribute}: must be a Document attribute or nested dataclass'
            )

    @classmethod
    def _from_obj(cls, obj, obj_type, field) -> typing.Tuple['Document', AttributeType]:
        """Convert a single field value into a Document.

        :param obj: the field value.
        :param obj_type: the annotated type of the value.
        :param field: the dataclass field describing the value.
        :return: the Document and its attribute type (``NESTED`` when
            ``obj_type`` is a multi-modal dataclass, ``DOCUMENT`` otherwise).
        :raises AttributeTypeError: if the type is neither a multi-modal
            dataclass nor backed by a docarray field.
        """
        attribute_type = AttributeType.DOCUMENT

        if is_multimodal(obj_type):
            doc = cls(obj)
            attribute_type = AttributeType.NESTED
        elif _is_field(field):
            doc = field.setter(obj)
        else:
            raise AttributeTypeError(f'Unsupported type annotation {obj_type}')
        return doc, attribute_type

    @staticmethod
    def _get_attribute_type_from_obj_type(obj_type, field) -> AttributeType:
        """Classify a field type without converting any value.

        :param obj_type: the annotated type to classify.
        :param field: the dataclass field carrying the annotation.
        :return: ``NESTED`` for multi-modal dataclasses, ``DOCUMENT`` for
            docarray fields.
        :raises AttributeTypeError: if the type is neither of those.
        """
        if is_multimodal(obj_type):
            attribute_type = AttributeType.NESTED
        elif _is_field(field):
            attribute_type = AttributeType.DOCUMENT
        else:
            raise AttributeTypeError(f'Unsupported type annotation {obj_type}')
        return attribute_type

    def _has_multimodal_attr(self, attr):
        """Return whether ``attr`` is listed in this Document's schema.

        ``_data`` is read through ``super().__getattribute__`` so that a
        missing ``_data`` raises ``AttributeError`` directly instead of
        re-entering :meth:`__getattr__` (which calls this method).
        """
        try:
            data = super().__getattribute__('_data')
            has_data = bool(data)
        except AttributeError:
            return False

        # Reuse the object fetched above; a missing `_metadata` is treated the
        # same as an unset one.
        has_metadata = has_data and getattr(data, '_metadata', None) is not None
        return (
            has_metadata
            and self.is_multimodal
            and attr in self._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA]
        )

    def __getattr__(self, attr):
        """Resolve schema attributes to their Document(s).

        Iterable attributes return the DocumentArray of elements; Document
        and nested attributes return the single Document.

        :raises AttributeError: if ``attr`` is not in the schema, or is in the
            schema but is not stored as a chunk (e.g. primitive fields kept
            in ``tags``).
        """
        if not self._has_multimodal_attr(attr):
            raise AttributeError(f'{self.__class__.__name__} has no attribute `{attr}`')

        try:
            mm_attr_da = self.get_multi_modal_attribute(attr)
        except ValueError as err:
            # `__getattr__` must signal failure with AttributeError so that
            # `hasattr()` and `getattr(obj, name, default)` keep working;
            # schema entries without a chunk position raise ValueError above.
            raise AttributeError(str(err)) from err

        attr_type = self._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA][attr][
            'attribute_type'
        ]
        if attr_type in [
            AttributeType.ITERABLE_DOCUMENT,
            AttributeType.ITERABLE_NESTED,
            AttributeType.ITERABLE_PRIMITIVE,
        ]:
            return mm_attr_da
        return mm_attr_da[0]

    def __setattr__(self, attr, value):
        """Route assignments to schema attributes into the matching chunk.

        Any other attribute is set with ``object.__setattr__``.
        """
        if self._has_multimodal_attr(attr):
            self.set_multi_modal_attribute(attr, value)
        else:
            object.__setattr__(self, attr, value)