Skip to content

Module vexpresso.retrievers#

View Source
from vexpresso.retrievers.base import BaseRetriever, RetrievalOutput

from vexpresso.retrievers.faiss import FaissRetriever

from vexpresso.retrievers.np import Retriever

__all__ = [

    "BaseRetriever",

    "Retriever",

    "RetrievalOutput",

    "FaissRetriever",

]

Sub-modules#

Classes#

BaseRetriever#

class BaseRetriever(
    /,
    *args,
    **kwargs
)
View Source
class BaseRetriever(metaclass=abc.ABCMeta):

    SUPPORTED_TYPES = [np.dtype]

    @abc.abstractmethod

    def retrieve(

        self, query_embeddings: np.ndarray, embeddings: List[Any], *args, **kwargs

    ) -> Union[List[RetrievalOutput], RetrievalOutput]:

        """

        Queries embeddings with query embedding vector and returns nearest embeddings and their corresponding ids

        Args:

            query_embeddings (np.ndarray): query, used to find nearest embeddings in set.

            embeddings (List[Any]): embeddings set, query is compared to this.

        Returns:

            Union[List[QueryOutput], QueryOutput]: dataclasses containing returned embeddings and corresponding indices.

            When this has more than one entry, that means that the call was batched

        """

Descendants#

  • vexpresso.retrievers.FaissRetriever
  • vexpresso.retrievers.Retriever

Class variables#

SUPPORTED_TYPES

Methods#

retrieve#

def retrieve(
    self,
    query_embeddings: numpy.ndarray,
    embeddings: List[Any],
    *args,
    **kwargs
) -> Union[List[vexpresso.retrievers.base.RetrievalOutput], vexpresso.retrievers.base.RetrievalOutput]

Queries embeddings with query embedding vector and returns nearest embeddings and their corresponding ids

Parameters:

Name Type Description Default
query_embeddings np.ndarray query, used to find nearest embeddings in set. None
embeddings List[Any] embeddings set, query is compared to this. None

Returns:

Type Description
Union[List[QueryOutput], QueryOutput] dataclasses containing returned embeddings and corresponding indices.
When this has more than one entry, that means that the call was batched
View Source
    @abc.abstractmethod

    def retrieve(

        self, query_embeddings: np.ndarray, embeddings: List[Any], *args, **kwargs

    ) -> Union[List[RetrievalOutput], RetrievalOutput]:

        """

        Queries embeddings with query embedding vector and returns nearest embeddings and their corresponding ids

        Args:

            query_embeddings (np.ndarray): query, used to find nearest embeddings in set.

            embeddings (List[Any]): embeddings set, query is compared to this.

        Returns:

            Union[List[QueryOutput], QueryOutput]: dataclasses containing returned embeddings and corresponding indices.

            When this has more than one entry, that means that the call was batched

        """

FaissRetriever#

class FaissRetriever(

)
View Source
class FaissRetriever(BaseRetriever):

    def __init__(self):

        self._faiss = None

        try:

            import faiss  # noqa

            self._faiss = faiss

        except ImportError:

            raise ImportError(

                "Could not import faiss python package."

                "Please install it with `pip install faiss-cpu` or `faiss-gpu`."

            )

        self.index = None

    def _setup_index(self, embeddings: np.ndarray):

        self.index = self._faiss.IndexFlatL2(embeddings.shape[1])  # noqa

        self.index.add(embeddings.astype(np.float32))

    def retrieve(

        self,

        query_embeddings: np.ndarray,

        embeddings: np.ndarray,

        k: int = 4,

    ) -> List[RetrievalOutput]:

        if not isinstance(embeddings, np.ndarray):

            embeddings = np.array(embeddings)

        query_embeddings = np.array(query_embeddings)

        self._setup_index(embeddings)

        distances, indices = self.index.search(query_embeddings.astype(np.float32), k=k)

        out = []

        for idx in range(indices.shape[0]):

            query_output = RetrievalOutput(

                embeddings[indices[idx]],

                indices[idx],

                scores=distances[idx],

                query_embeddings=query_embeddings,

            )

            out.append(query_output)

        return out

Ancestors (in MRO)#

  • vexpresso.retrievers.BaseRetriever

Class variables#

SUPPORTED_TYPES

Methods#

retrieve#

def retrieve(
    self,
    query_embeddings: numpy.ndarray,
    embeddings: numpy.ndarray,
    k: int = 4
) -> List[vexpresso.retrievers.base.RetrievalOutput]

Queries embeddings with query embedding vector and returns nearest embeddings and their corresponding ids

Parameters:

Name Type Description Default
query_embeddings np.ndarray query, used to find nearest embeddings in set. None
embeddings List[Any] embeddings set, query is compared to this. None

Returns:

Type Description
Union[List[QueryOutput], QueryOutput] dataclasses containing returned embeddings and corresponding indices.
When this has more than one entry, that means that the call was batched
View Source
    def retrieve(

        self,

        query_embeddings: np.ndarray,

        embeddings: np.ndarray,

        k: int = 4,

    ) -> List[RetrievalOutput]:

        if not isinstance(embeddings, np.ndarray):

            embeddings = np.array(embeddings)

        query_embeddings = np.array(query_embeddings)

        self._setup_index(embeddings)

        distances, indices = self.index.search(query_embeddings.astype(np.float32), k=k)

        out = []

        for idx in range(indices.shape[0]):

            query_output = RetrievalOutput(

                embeddings[indices[idx]],

                indices[idx],

                scores=distances[idx],

                query_embeddings=query_embeddings,

            )

            out.append(query_output)

        return out

RetrievalOutput#

class RetrievalOutput(
    embeddings: Any,
    indices: Union[numpy.ndarray, Iterable[int]],
    scores: Union[numpy.ndarray, Iterable[float]],
    query_embeddings: Optional[Any] = None
)

RetrievalOutput(embeddings: Any, indices: Union[numpy.ndarray, Iterable[int]], scores: Union[numpy.ndarray, Iterable[float]], query_embeddings: Optional[Any] = None)

View Source
@dataclass

class RetrievalOutput:

    embeddings: Any

    indices: Union[np.ndarray, Iterable[int]]

    scores: Union[np.ndarray, Iterable[float]]

    query_embeddings: Optional[Any] = None

Class variables#

query_embeddings

Retriever#

class Retriever(
    similarity_fn: str = 'cosine'
)
View Source
class Retriever(BaseRetriever):

    def __init__(self, similarity_fn: str = "cosine"):

        self.similarity_fn = get_similarity_fn(similarity_fn)

    def _get_similarities(

        self,

        query_embeddings: np.ndarray,

        embeddings: np.ndarray,

    ):

        if not is_batched(query_embeddings):

            query_embeddings = np.expand_dims(query_embeddings, axis=0)

        similarities = self.similarity_fn(query_embeddings, embeddings)

        if not is_batched(similarities):

            similarities = np.expand_dims(similarities, 0)

        return similarities

    def _get_top_k(

        self,

        query_embeddings: np.ndarray,

        embeddings: np.ndarray,

        k: int = 1,

    ):

        similarities = self._get_similarities(query_embeddings, embeddings)

        top_indices = np.flip(

            np.argsort(similarities, axis=-1)[:, -k:], axis=-1

        )  # B X k

        return top_indices, similarities

    def retrieve(

        self,

        query_embeddings: np.ndarray,

        embeddings: List[Any],

        k: int = 4,

    ) -> List[RetrievalOutput]:

        embeddings = np.array(embeddings)

        query_embeddings = np.array(query_embeddings)

        top_indices, similarities = self._get_top_k(query_embeddings, embeddings, k)

        # move to list for consistency w/ single and batch calls

        out = []

        for idx in range(top_indices.shape[0]):

            query_output = RetrievalOutput(

                embeddings[top_indices[idx]],

                top_indices[idx],

                scores=similarities[idx],

                query_embeddings=query_embeddings[idx],

            )

            out.append(query_output)

        return out

Ancestors (in MRO)#

  • vexpresso.retrievers.BaseRetriever

Class variables#

SUPPORTED_TYPES

Methods#

retrieve#

def retrieve(
    self,
    query_embeddings: numpy.ndarray,
    embeddings: List[Any],
    k: int = 4
) -> List[vexpresso.retrievers.base.RetrievalOutput]

Queries embeddings with query embedding vector and returns nearest embeddings and their corresponding ids

Parameters:

Name Type Description Default
query_embeddings np.ndarray query, used to find nearest embeddings in set. None
embeddings List[Any] embeddings set, query is compared to this. None

Returns:

Type Description
Union[List[QueryOutput], QueryOutput] dataclasses containing returned embeddings and corresponding indices.
When this has more than one entry, that means that the call was batched
View Source
    def retrieve(

        self,

        query_embeddings: np.ndarray,

        embeddings: List[Any],

        k: int = 4,

    ) -> List[RetrievalOutput]:

        embeddings = np.array(embeddings)

        query_embeddings = np.array(query_embeddings)

        top_indices, similarities = self._get_top_k(query_embeddings, embeddings, k)

        # move to list for consistency w/ single and batch calls

        out = []

        for idx in range(top_indices.shape[0]):

            query_output = RetrievalOutput(

                embeddings[top_indices[idx]],

                top_indices[idx],

                scores=similarities[idx],

                query_embeddings=query_embeddings[idx],

            )

            out.append(query_output)

        return out