Source code for docarray.base

import copy as cp
import dataclasses
from dataclasses import fields
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, Tuple, Dict
from docarray.dataclasses import is_multimodal

from docarray.helper import typename

if TYPE_CHECKING:  # pragma: no cover
    from docarray.typing import T


@lru_cache()
def _get_fields(dc):
    """Return the field names of the dataclass ``dc`` as a list.

    Names appear in the order :func:`dataclasses.fields` yields them.
    The result is memoized with ``lru_cache``, so ``dc`` must be hashable
    and repeated calls hand back the very same list object; callers should
    treat it as read-only.

    :param dc: a dataclass (type or instance accepted by ``fields``).
    :return: list of field names.
    """
    names = [spec.name for spec in fields(dc)]
    return names


class BaseDCType:
    """Thin wrapper that stores all of its state in a dataclass instance.

    The backing dataclass instance lives in ``self._data`` and is created from
    ``_data_class``.  This class reads several attributes it does not define
    itself (``_unresolved_fields_dest``, ``_post_init_fields``,
    ``_from_dataclass`` and ``to_bytes``); they are expected to be supplied by
    subclasses or mixins.
    """

    # Dataclass type used as backing storage.  It is called as
    # ``_data_class(self, **kwargs)``, i.e. the wrapper is passed as the first
    # positional argument.  ``None`` here, so a subclass has to override it.
    _data_class = None

    def __init__(
        self: 'T',
        _obj: Optional['T'] = None,
        copy: bool = False,
        field_resolver: Optional[Dict[str, str]] = None,
        unknown_fields_handler: str = 'catch',
        **kwargs,
    ):
        """Build the object from another object, a dict and/or keyword arguments.

        :param _obj: source to initialise from. An instance of the same type
            has its data shared (or deep-copied when ``copy`` is true); a
            ``dict`` is merged into ``kwargs`` (its entries win over
            same-named keyword arguments); an object accepted by
            ``is_multimodal`` is converted through ``_from_dataclass``.
        :param copy: deep-copy the data of ``_obj`` instead of sharing it.
            Only consulted when ``_obj`` is an instance of the same type.
        :param field_resolver: mapping from a given keyword name to the field
            name it should be stored under. Only applied in the fallback path,
            i.e. after the first construction attempt raised ``TypeError``.
        :param unknown_fields_handler: what to do with keyword arguments that
            are not fields of ``_data_class``: ``'raise'`` raises
            ``AttributeError``; ``'catch'`` stores them in the attribute named
            by ``_unresolved_fields_dest``; any other value discards them.
        :param kwargs: field values.
        :raises AttributeError: on unknown fields when
            ``unknown_fields_handler == 'raise'``.
        :raises ValueError: if no data object could be built from the inputs.
        """
        self._data = None

        # Step 1: seed ``_data`` (or ``kwargs``) from the positional object.
        if isinstance(_obj, type(self)):
            if copy:
                self.copy_from(_obj)
            else:
                # No copy requested: both wrappers now reference the same data.
                self._data = _obj._data
        elif isinstance(_obj, dict):
            kwargs.update(_obj)
        elif is_multimodal(_obj):
            self._data = type(self)._from_dataclass(_obj)._data

        # Step 2: apply keyword arguments on top of whatever step 1 produced.
        if kwargs:
            try:
                if self._data is not None:
                    dest = self._unresolved_fields_dest
                    if dest in kwargs:
                        # Merge into the existing container rather than
                        # replacing it wholesale.
                        getattr(self, dest).update(kwargs[dest])
                        kwargs.pop(dest)
                    self._data = dataclasses.replace(self._data, **kwargs)
                else:
                    self._data = self._data_class(self, **kwargs)
            except TypeError as ex:
                # The dataclass rejected the arguments (typically an unknown
                # keyword).  Either fail loudly or retry with known fields only.
                if unknown_fields_handler == 'raise':
                    raise AttributeError('unknown attributes') from ex

                if field_resolver:
                    kwargs = {
                        field_resolver.get(k, k): v for k, v in kwargs.items()
                    }

                _fields = _get_fields(self._data_class)
                _unknown_kwargs = None
                _unresolved = set(kwargs).difference(_fields)

                if _unresolved:
                    # Split the unknown entries off so the retry below only
                    # sees real fields.
                    _unknown_kwargs = {k: kwargs[k] for k in _unresolved}
                    for k in _unresolved:
                        kwargs.pop(k)

                if self._data is not None:
                    self._data = dataclasses.replace(self._data, **kwargs)
                else:
                    self._data = self._data_class(self, **kwargs)

                if _unknown_kwargs and unknown_fields_handler == 'catch':
                    getattr(self, self._unresolved_fields_dest).update(
                        _unknown_kwargs
                    )

            # Fields listed in ``_post_init_fields`` are additionally assigned
            # on the wrapper itself via ``setattr``.
            for k in self._post_init_fields:
                if k in kwargs:
                    setattr(self, k, kwargs[k])

        # Step 3: nothing was given at all -> default-construct the data.
        if not _obj and not kwargs and self._data is None:
            self._data = self._data_class(self)

        # Reached e.g. with a truthy ``_obj`` of an unsupported type and no
        # keyword arguments.
        if self._data is None:
            raise ValueError(
                f'Failed to initialize {typename(self)} from obj={_obj}, kwargs={kwargs}'
            )

    def copy_from(self: 'T', other: 'T') -> None:
        """Overwrite self by copying from another :class:`Document`.

        The data of ``other`` is deep-copied, so later changes to either
        object do not affect the other one.

        :param other: the other Document to copy from
        """
        self._data = cp.deepcopy(other._data)

    def clear(self) -> None:
        """Clear all non-empty fields of this :class:`Document`.

        Every name reported by :attr:`non_empty_fields` is assigned ``None``
        on the underlying data object.
        NOTE(review): any mapping of ``None`` back to a field default would
        have to happen inside the data class - confirm.
        """
        for f in self.non_empty_fields:
            setattr(self._data, f, None)

    def pop(self, *fields) -> None:
        """Clear some fields of this :class:`Document` by assigning ``None``.

        Names that are not attributes of this object are silently ignored.

        :param fields: field names to clear.
        """
        for f in fields:
            if hasattr(self, f):
                setattr(self._data, f, None)

    @property
    def non_empty_fields(self) -> Tuple[str, ...]:
        """Get all non-empty fields of this :class:`Document`.

        The value is taken as-is from ``_non_empty_fields`` of the underlying
        data object.

        :return: field names in a tuple.
        """
        return self._data._non_empty_fields

    @property
    def nbytes(self) -> int:
        """Return the size in bytes of the serialized form ``bytes(self)``.

        :return: number of bytes
        """
        return len(bytes(self))

    def __hash__(self):
        # Hash follows the underlying data, consistent with ``__eq__`` below.
        return hash(self._data)

    def __repr__(self):
        content = str(self.non_empty_fields)
        # Prefer an ``id`` attribute when the object has one; otherwise fall
        # back to the interpreter-level identity.
        content += f' at {getattr(self, "id", id(self))}'
        return f'<{self.__class__.__name__} {content.strip()}>'

    def __bytes__(self):
        return self.to_bytes()

    def __eq__(self, other):
        # Only objects of exactly the same type can be equal; then equality is
        # delegated to the underlying data objects.
        if type(self) is type(other):
            return self._data == other._data
        return False