Skip to content

Commit

Permalink
Update Zero-shot Classification Task (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi authored Oct 29, 2024
1 parent a247b4e commit 737ec9f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 19 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ Evaluates the quality of the learned representations in retrieving the <i>k</i>
using the recall@k metric. This is applicable to any number of pairs of modalities at once, depending on memory constraints.
</td>
</tr>
<tr>
<td>

Zero-shot Classification
</td>
<td>
Evaluates the ability of a pre-trained encoder-based multimodal model to predict classes that were not explicitly seen
during training. The new classes are given as text prompts, and the query modality can be any of the supported modalities.
Binary and multi-class classification tasks are supported.
</td>
</tr>
</table>

## Components
Expand Down
43 changes: 29 additions & 14 deletions mmlearn/tasks/zero_shot_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,13 @@ def evaluation_step(
query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
query_embeddings = query_embeddings[matching_indices]

logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
if self.all_dataset_info[dataset_index]["num_classes"] == 2:
softmax_output = _safe_matmul(
query_embeddings, class_embeddings
).softmax(dim=-1)
logits = softmax_output[:, 1] - softmax_output[:, 0]
else:
logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
targets = batch[Modalities.get_modality(query_modality).target][
matching_indices
]
Expand Down Expand Up @@ -233,27 +239,36 @@ def _create_metrics(
num_classes: int, top_k: List[int], prefix: str, postfix: str
) -> MetricCollection:
"""Create a collection of classification metrics."""
task_type = "binary" if num_classes == 2 else "multiclass"
acc_metrics = (
{
f"top{k}_accuracy": Accuracy(
task=task_type, num_classes=num_classes, top_k=k, average="micro"
)
for k in top_k
}
if num_classes > 2
else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
)
return MetricCollection(
{
"precision": Precision(
task="multiclass", num_classes=num_classes, average="macro"
task=task_type,
num_classes=num_classes,
average="macro" if num_classes > 2 else "micro",
),
"recall": Recall(
task="multiclass", num_classes=num_classes, average="macro"
task=task_type,
num_classes=num_classes,
average="macro" if num_classes > 2 else "micro",
),
"f1_score_macro": F1Score(
task="multiclass", num_classes=num_classes, average="macro"
task=task_type,
num_classes=num_classes,
average="macro" if num_classes > 2 else "micro",
),
"aucroc": AUROC(task="multiclass", num_classes=num_classes),
**{
f"top{k}_accuracy": Accuracy(
task="multiclass",
num_classes=num_classes,
top_k=k,
average="micro",
)
for k in top_k
},
"aucroc": AUROC(task=task_type, num_classes=num_classes),
**acc_metrics,
},
prefix=prefix,
postfix=postfix,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ datasets:

dataloader:
test:
batch_size: 64
batch_size: 128
num_workers: 4

task:
Expand All @@ -153,15 +153,14 @@ task:
task_specs:
- top_k: [1]
query_modality: rgb
run_on_validation: false
run_on_test: true
run_on_validation: False
run_on_test: True
compute_validation_loss: False
compute_test_loss: False

trainer:
precision: 16-mixed
deterministic: False
benchmark: True
deterministic: True
sync_batchnorm: False # set to True if using DDP with batchnorm
log_every_n_steps: 100

Expand Down

0 comments on commit 737ec9f

Please sign in to comment.