import itertools
from typing import (
    TYPE_CHECKING,
    Union,
    Sequence,
    overload,
    Any,
    List,
)
import numpy as np
from docarray import Document
from docarray.helper import typename
if TYPE_CHECKING:  # pragma: no cover
    from docarray.typing import (
        DocumentArrayIndexType,
        DocumentArraySingletonIndexType,
        DocumentArrayMultipleIndexType,
        DocumentArrayMultipleAttributeType,
        DocumentArraySingleAttributeType,
    )
[docs]class SetItemMixin:
    """Provides helper function to allow advanced indexing for `__setitem__`"""
    @overload
    def __setitem__(
        self,
        index: 'DocumentArrayMultipleAttributeType',
        value: List[List['Any']],
    ):
        ...
    @overload
    def __setitem__(
        self,
        index: 'DocumentArraySingleAttributeType',
        value: List['Any'],
    ):
        ...
    @overload
    def __setitem__(
        self,
        index: 'DocumentArraySingletonIndexType',
        value: 'Document',
    ):
        ...
    @overload
    def __setitem__(
        self,
        index: 'DocumentArrayMultipleIndexType',
        value: Sequence['Document'],
    ):
        ...
    def __setitem__(
        self,
        index: 'DocumentArrayIndexType',
        value: Union['Document', Sequence['Document']],
    ):
        from docarray.helper import check_root_id
        if getattr(self, '_is_subindex', None):
            check_root_id(self, value)
        self._update_subindices_set(index, value)
        # set by offset
        # allows da[1] = Document()
        if isinstance(index, (int, np.generic)) and not isinstance(index, bool):
            self._set_doc_by_offset(int(index), value)
        elif isinstance(index, str):
            # set by traversal paths
            # allows da['@m,c] = [m1, m2, ..., mn, c1, c2, ..., cp]
            if index.startswith('@'):
                self._set_doc_value_pairs_nested(self.traverse_flat(index[1:]), value)
            # set by ID
            # allows da['id_123'] = Document()
            else:
                self._set_doc(index, value)
        # set by slice
        # allows da[1:3] = [d1, d2]
        elif isinstance(index, slice):
            self._set_docs_by_slice(index, value)
        # flatten and set
        # allows da[...] = [d1, d2,..., dn]
        elif index is Ellipsis:
            self._set_doc_value_pairs(self.flatten(), value)
        # index is sequence
        elif isinstance(index, Sequence):
            # allows da[idx1, idx2] = value
            if isinstance(index, tuple) and len(index) == 2:
                self._set_by_pair(index[0], index[1], value)
            # allows da[True, False, True, True]
            elif isinstance(index[0], bool):
                self._set_by_mask(index, value)
            # allows da[id1, id2, id3] = [d1, d2, d3]
            elif isinstance(index[0], (int, str)):
                for si, _val in zip(index, value):
                    self[si] = _val  # leverage existing setter
            else:
                raise IndexError(
                    f'{index} should be either a sequence of bool, int or str'
                )
        # set by ndarray
        elif isinstance(index, np.ndarray):
            index = index.squeeze()
            if index.ndim == 1:
                self[index.tolist()] = value  # leverage existing setter
            else:
                raise IndexError(
                    f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}'
                )
        else:
            raise IndexError(f'Unsupported index type {typename(index)}: {index}')
    def _set_by_pair(self, idx1, idx2, value):
        if isinstance(idx1, str) and not idx1.startswith('@'):
            # second is an ID
            # allows da[id1, id2] = [d1, d2]
            if isinstance(idx2, str) and idx2 in self:
                self._set_doc_value_pairs((self[idx1], self[idx2]), value)
            # second is an attribute
            # allows da[id, attr] = attr_value
            elif isinstance(idx2, str) and hasattr(self[idx1], idx2):
                self._set_doc_attr_by_id(idx1, idx2, value)
            # second is a list of attributes:
            # allows da[id, [attr1, attr2, attr3]] = [v1, v2, v3]
            elif (
                isinstance(idx2, Sequence)
                and all(isinstance(attr, str) for attr in idx2)
                and all(hasattr(self[idx1], attr) for attr in idx2)
            ):
                for attr, _v in zip(idx2, value):
                    self._set_doc_attr_by_id(idx1, attr, _v)
            else:
                raise IndexError(f'`{idx2}` is neither a valid id nor attribute name')
        elif isinstance(idx1, int):
            # second is an offset
            # allows da[offset1, offset2] = [d1, d2]
            if isinstance(idx2, int):
                self._set_doc_value_pairs((self[idx1], self[idx2]), value)
            # second is an attribute
            # allows da[offset, attr] = value
            elif isinstance(idx2, str) and hasattr(self[idx1], idx2):
                self._set_doc_attr_by_offset(idx1, idx2, value)
            # second is a list of attributes
            # allows da[offset, [attr1, attr2, attr3]] = [v1, v2, v3]
            elif (
                isinstance(idx2, Sequence)
                and all(isinstance(attr, str) for attr in idx2)
                and all(hasattr(self[idx1], attr) for attr in idx2)
            ):
                for attr, _v in zip(idx2, value):
                    self._set_doc_attr_by_offset(idx1, attr, _v)
            else:
                raise IndexError(f'`{idx2}` must be an attribute or list of attributes')
        # allows da[sequence/slice/ellipsis/traversal_path, attributes] = [v1, v2, ...]
        elif (
            isinstance(idx1, (slice, Sequence))
            or idx1 is Ellipsis
            or (isinstance(idx1, str) and idx1.startswith('@'))
        ):
            self._set_docs_attributes(idx1, idx2, value)
        else:
            raise IndexError(f'Unsupported first index type {typename(idx1)}: {idx1}')
    def _set_by_mask(self, mask: List[bool], value):
        _selected = itertools.compress(self, mask)
        self._set_doc_value_pairs(_selected, value)
    def _set_docs_attributes(self, index, attributes, value):
        if isinstance(attributes, str):
            # a -> [a]
            # [a, a] -> [a, a]
            attributes = (attributes,)
            value = (value,)
        if isinstance(index, str) and index.startswith('@'):
            self._set_docs_attributes_traversal_paths(index, attributes, value)
        elif index is Ellipsis:
            _docs = self[index]
            for _a, _v in zip(attributes, value):
                if _a == 'tensor':
                    _docs.tensors = _v
                elif _a == 'embedding':
                    _docs.embeddings = _v
                else:
                    if not isinstance(_v, (list, tuple)):
                        for _d in _docs:
                            setattr(_d, _a, _v)
                    else:
                        for _d, _vv in zip(_docs, _v):
                            setattr(_d, _a, _vv)
            self._set_doc_value_pairs_nested(_docs, _docs)
        else:
            _docs = self[index]
            if not _docs:
                return
            for _a, _v in zip(attributes, value):
                if _a in ('tensor', 'embedding'):
                    if _a == 'tensor':
                        _docs.tensors = _v
                    elif _a == 'embedding':
                        _docs.embeddings = _v
                    for _d in _docs:
                        self._set_doc(_d.id, _d)
                else:
                    if not isinstance(_v, (list, tuple)):
                        for _d in _docs:
                            self._set_doc_attr_by_id(_d.id, _a, _v)
                    else:
                        for _d, _vv in zip(_docs, _v):
                            self._set_doc_attr_by_id(_d.id, _a, _vv)
    def _set_docs_attributes_traversal_paths(
        self, traversal_paths: str, attributes, value
    ):
        _docs = self[traversal_paths]
        if not _docs:
            return
        for _a, _v in zip(attributes, value):
            if _a == 'tensor':
                _docs.tensors = _v
            elif _a == 'embedding':
                _docs.embeddings = _v
            else:
                if not isinstance(_v, (list, tuple)):
                    for _d in _docs:
                        setattr(_d, _a, _v)
                else:
                    for _d, _vv in zip(_docs, _v):
                        setattr(_d, _a, _vv)
        self._set_doc_value_pairs_nested(_docs, _docs)