Skip to content

Commit

Permalink
Fix errors when the concept is empty (#1158)
Browse files Browse the repository at this point in the history
When the concept has no positive or negative examples, always predict
0.5.

Fixes #1153
  • Loading branch information
dsmilkov authored Feb 2, 2024
1 parent 9185105 commit 2df9fd5
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
33 changes: 27 additions & 6 deletions lilac/concepts/concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from joblib import Parallel, delayed
from pydantic import BaseModel, field_validator
from sklearn.base import clone
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, roc_auc_score
from sklearn.model_selection import KFold
from sklearn.utils.validation import check_is_fitted

from ..embeddings.embedding import get_embed_fn
from ..signal import TextEmbeddingSignal, get_signal_cls
Expand Down Expand Up @@ -140,6 +142,15 @@ class ConceptMetrics(BaseModel):
overall: OverallScore


def _is_fitted(model: LogisticRegression) -> bool:
  """Return whether `model` has already been fitted.

  Delegates to scikit-learn's `check_is_fitted` and turns the `NotFittedError` it raises for an
  unfitted estimator into a plain boolean, so callers can branch instead of catching.
  """
  try:
    check_is_fitted(model)
  except NotFittedError:
    return False
  else:
    return True


@dataclasses.dataclass
class LogisticEmbeddingModel:
"""A model that uses logistic regression with embeddings."""
Expand All @@ -155,7 +166,10 @@ def __post_init__(self) -> None:

def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
"""Get the scores for the provided embeddings."""
y_probs = self._model.predict_proba(embeddings)[:, 1]
if _is_fitted(self._model):
y_probs = self._model.predict_proba(embeddings)[:, 1]
else:
y_probs = np.ones(len(embeddings)) * 0.5
# Map [0, threshold, 1] to [0, 0.5, 1].
power = np.log(self._threshold) / np.log(0.5)
return y_probs**power
Expand All @@ -173,7 +187,9 @@ def _setup_training(
def fit(self, embeddings: np.ndarray, labels: list[bool]) -> None:
"""Fit the model to the provided embeddings and labels."""
label_set = set(labels)
if len(label_set) < 2:
if len(label_set) == 0:
return
elif len(label_set) < 2:
dim = embeddings.shape[1]
random_vector = np.random.randn(dim).astype(np.float32)
random_vector /= np.linalg.norm(random_vector)
Expand Down Expand Up @@ -206,7 +222,10 @@ def _fit_and_score(
if len(set(y_train)) < 2:
return np.array([]), np.array([])
model.fit(X_train, y_train)
y_pred = model.predict_proba(X_test)[:, 1]
if _is_fitted(model):
y_pred = model.predict_proba(X_test)[:, 1]
else:
y_pred = np.ones_like(y_test) * 0.5
return y_test, y_pred

# Compute the metrics for each validation fold in parallel.
Expand Down Expand Up @@ -298,7 +317,11 @@ def score_embeddings(self, draft: DraftId, embeddings: np.ndarray) -> np.ndarray

def coef(self, draft: DraftId = DRAFT_MAIN) -> np.ndarray:
"""Get the coefficients of the underlying ML model."""
return self._get_logistic_model(draft)._model.coef_.reshape(-1)
model = self._get_logistic_model(draft)
if _is_fitted(model._model):
return model._model.coef_.reshape(-1)
else:
return np.zeros(0)

def _get_logistic_model(self, draft: DraftId = DRAFT_MAIN) -> LogisticEmbeddingModel:
"""Get the logistic model for the provided draft."""
Expand Down Expand Up @@ -345,8 +368,6 @@ def _compute_embeddings(self, concept: Concept) -> None:
concept_embeddings: dict[str, np.ndarray] = {}

examples = concept.data.items()
if not examples:
raise ValueError(f'Cannot sync concept "{concept.concept_name}". It has no examples.')

# Compute the embeddings for the examples with cache miss.
texts_of_missing_embeddings: dict[str, str] = {}
Expand Down
14 changes: 14 additions & 0 deletions lilac/concepts/db_concept_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,17 @@ def test_embedding_not_found_in_map(

with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'):
model_db.sync(model.namespace, model.concept_name, model.embedding_name)

def test_empty_concept(
  self, concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB]
) -> None:
  """Syncing the model of a concept that has no examples leaves the model in sync."""
  concept_db = concept_db_cls()
  model_db = model_db_cls(concept_db)

  # Create the concept but never add any examples to it.
  concept_db.create(namespace='test', name='test_concept', type=ConceptType.TEXT)
  created = model_db.create('test', 'test_concept', embedding_name='test_embedding')
  synced = model_db.sync(created.namespace, created.concept_name, created.embedding_name)
  # Make sure the model is in sync.
  assert model_db.in_sync(synced) is True

0 comments on commit 2df9fd5

Please sign in to comment.