Module vexpresso.retrievers.faiss
View Source
from typing import List
import numpy as np
from vexpresso.retrievers.base import BaseRetriever, RetrievalOutput
class FaissRetriever(BaseRetriever):
    """Nearest-neighbour retriever backed by a faiss ``IndexFlatL2``.

    The index is rebuilt from the ``embeddings`` passed to every ``retrieve``
    call; ``self.index`` only holds the index built by the most recent call.
    """

    def __init__(self):
        """Import faiss on construction so it is only required when this class is used.

        Raises:
            ImportError: if the ``faiss`` package cannot be imported.
        """
        try:
            import faiss  # noqa
        except ImportError as err:
            raise ImportError(
                "Could not import faiss python package. "
                "Please install it with `pip install faiss-cpu` or `faiss-gpu`."
            ) from err
        self._faiss = faiss
        self.index = None

    def _setup_index(self, embeddings: np.ndarray):
        """Build a new flat L2 index over *embeddings* and store it in ``self.index``.

        Args:
            embeddings: 2-D array whose second axis is the vector dimension.
        """
        # Hand faiss a C-contiguous float32 buffer; ``astype`` alone would keep
        # a Fortran-ordered layout, which older faiss wrappers reject.
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        self.index = self._faiss.IndexFlatL2(vectors.shape[1])  # noqa
        self.index.add(vectors)

    def retrieve(
        self,
        query_embeddings: np.ndarray,
        embeddings: np.ndarray,
        k: int = 4,
    ) -> List[RetrievalOutput]:
        """Return the ``k`` nearest ``embeddings`` for each query vector.

        A new flat L2 index is built over ``embeddings`` on every call.

        Args:
            query_embeddings (np.ndarray): query vector(s), shape ``(d,)`` or
                ``(n_queries, d)``; array-likes are converted with numpy.
            embeddings (np.ndarray): the set to search, shape ``(n, d)``;
                array-likes are converted with numpy.
            k (int): number of neighbours to return per query. Defaults to 4.

        Returns:
            List[RetrievalOutput]: one entry per query row, holding the matched
            embeddings, their indices into ``embeddings`` and the squared L2
            distances reported by faiss as ``scores`` (smaller is closer).
            Fewer than ``k`` matches are returned when ``embeddings`` has
            fewer than ``k`` rows.
        """
        # Array-likes are accepted; a single 1-D vector is treated as one row.
        embeddings = np.atleast_2d(np.asarray(embeddings))
        query_embeddings = np.atleast_2d(np.array(query_embeddings))

        # Index a float32 copy but gather results from ``embeddings`` so the
        # returned vectors keep the caller's dtype.
        self._setup_index(np.ascontiguousarray(embeddings, dtype=np.float32))
        distances, indices = self.index.search(
            np.ascontiguousarray(query_embeddings, dtype=np.float32), k=k
        )

        out = []
        for row_indices, row_distances in zip(indices, distances):
            # faiss pads each row with -1 when k exceeds the number of indexed
            # vectors; drop those slots instead of letting -1 wrap around to
            # the last embedding.
            found = row_indices >= 0
            row_indices = row_indices[found]
            out.append(
                RetrievalOutput(
                    embeddings[row_indices],
                    row_indices,
                    scores=row_distances[found],
                    # NOTE(review): every output receives the whole query batch,
                    # as before, not just its own row; confirm this is intended.
                    query_embeddings=query_embeddings,
                )
            )
        return out
Classes
FaissRetriever
class FaissRetriever(
)
View Source
class FaissRetriever(BaseRetriever):
    """Nearest-neighbour retriever backed by a faiss ``IndexFlatL2``.

    The index is rebuilt from the ``embeddings`` passed to every ``retrieve``
    call; ``self.index`` only holds the index built by the most recent call.
    """

    def __init__(self):
        """Import faiss on construction so it is only required when this class is used.

        Raises:
            ImportError: if the ``faiss`` package cannot be imported.
        """
        try:
            import faiss  # noqa
        except ImportError as err:
            raise ImportError(
                "Could not import faiss python package. "
                "Please install it with `pip install faiss-cpu` or `faiss-gpu`."
            ) from err
        self._faiss = faiss
        self.index = None

    def _setup_index(self, embeddings: np.ndarray):
        """Build a new flat L2 index over *embeddings* and store it in ``self.index``.

        Args:
            embeddings: 2-D array whose second axis is the vector dimension.
        """
        # Hand faiss a C-contiguous float32 buffer; ``astype`` alone would keep
        # a Fortran-ordered layout, which older faiss wrappers reject.
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        self.index = self._faiss.IndexFlatL2(vectors.shape[1])  # noqa
        self.index.add(vectors)

    def retrieve(
        self,
        query_embeddings: np.ndarray,
        embeddings: np.ndarray,
        k: int = 4,
    ) -> List[RetrievalOutput]:
        """Return the ``k`` nearest ``embeddings`` for each query vector.

        A new flat L2 index is built over ``embeddings`` on every call.

        Args:
            query_embeddings (np.ndarray): query vector(s), shape ``(d,)`` or
                ``(n_queries, d)``; array-likes are converted with numpy.
            embeddings (np.ndarray): the set to search, shape ``(n, d)``;
                array-likes are converted with numpy.
            k (int): number of neighbours to return per query. Defaults to 4.

        Returns:
            List[RetrievalOutput]: one entry per query row, holding the matched
            embeddings, their indices into ``embeddings`` and the squared L2
            distances reported by faiss as ``scores`` (smaller is closer).
            Fewer than ``k`` matches are returned when ``embeddings`` has
            fewer than ``k`` rows.
        """
        # Array-likes are accepted; a single 1-D vector is treated as one row.
        embeddings = np.atleast_2d(np.asarray(embeddings))
        query_embeddings = np.atleast_2d(np.array(query_embeddings))

        # Index a float32 copy but gather results from ``embeddings`` so the
        # returned vectors keep the caller's dtype.
        self._setup_index(np.ascontiguousarray(embeddings, dtype=np.float32))
        distances, indices = self.index.search(
            np.ascontiguousarray(query_embeddings, dtype=np.float32), k=k
        )

        out = []
        for row_indices, row_distances in zip(indices, distances):
            # faiss pads each row with -1 when k exceeds the number of indexed
            # vectors; drop those slots instead of letting -1 wrap around to
            # the last embedding.
            found = row_indices >= 0
            row_indices = row_indices[found]
            out.append(
                RetrievalOutput(
                    embeddings[row_indices],
                    row_indices,
                    scores=row_distances[found],
                    # NOTE(review): every output receives the whole query batch,
                    # as before, not just its own row; confirm this is intended.
                    query_embeddings=query_embeddings,
                )
            )
        return out
Ancestors (in MRO)
- vexpresso.retrievers.base.BaseRetriever
Class variables
SUPPORTED_TYPES
Methods
retrieve
def retrieve(
self,
query_embeddings: numpy.ndarray,
embeddings: numpy.ndarray,
k: int = 4
) -> List[vexpresso.retrievers.base.RetrievalOutput]
Queries the embedding set with one or more query vectors and returns the nearest embeddings together with their corresponding indices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
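| query_embeddings | np.ndarray | query vector(s), used to find the nearest embeddings in the set. | required |
| embeddings | np.ndarray | embedding set that each query is compared against. | required |
| k | int | number of nearest embeddings to return per query. | 4 |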
Returns:
| Type | Description |
|---|---|
| List[RetrievalOutput] | one dataclass per query, containing the returned embeddings, their corresponding indices, and their squared L2 distances (scores). The list holds more than one entry when the call was batched. |
View Source
def retrieve(
    self,
    query_embeddings: np.ndarray,
    embeddings: np.ndarray,
    k: int = 4,
) -> List[RetrievalOutput]:
    """Return the ``k`` nearest ``embeddings`` for each query vector.

    A new flat L2 index is built over ``embeddings`` on every call.

    Args:
        query_embeddings (np.ndarray): query vector(s), shape ``(d,)`` or
            ``(n_queries, d)``; array-likes are converted with numpy.
        embeddings (np.ndarray): the set to search, shape ``(n, d)``;
            array-likes are converted with numpy.
        k (int): number of neighbours to return per query. Defaults to 4.

    Returns:
        List[RetrievalOutput]: one entry per query row, holding the matched
        embeddings, their indices into ``embeddings`` and the squared L2
        distances reported by faiss as ``scores`` (smaller is closer).
        Fewer than ``k`` matches are returned when ``embeddings`` has
        fewer than ``k`` rows.
    """
    # Array-likes are accepted; a single 1-D vector is treated as one row.
    embeddings = np.atleast_2d(np.asarray(embeddings))
    query_embeddings = np.atleast_2d(np.array(query_embeddings))

    # Index a C-contiguous float32 copy but gather results from
    # ``embeddings`` so the returned vectors keep the caller's dtype.
    self._setup_index(np.ascontiguousarray(embeddings, dtype=np.float32))
    distances, indices = self.index.search(
        np.ascontiguousarray(query_embeddings, dtype=np.float32), k=k
    )

    out = []
    for row_indices, row_distances in zip(indices, distances):
        # faiss pads each row with -1 when k exceeds the number of indexed
        # vectors; drop those slots instead of letting -1 wrap around to
        # the last embedding.
        found = row_indices >= 0
        row_indices = row_indices[found]
        out.append(
            RetrievalOutput(
                embeddings[row_indices],
                row_indices,
                scores=row_distances[found],
                # NOTE(review): every output receives the whole query batch,
                # as before, not just its own row; confirm this is intended.
                query_embeddings=query_embeddings,
            )
        )
    return out