Add end-to-end HuggingFace Example #758

Draft · wants to merge 11 commits into base: language
47 changes: 47 additions & 0 deletions configs/language/pubmedqa.yaml
@@ -0,0 +1,47 @@
---
trainer:
  class_path: eva.Trainer
  init_args:
    n_runs: &N_RUNS ${oc.env:N_RUNS, 5}
    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, meta-llama/Llama-3.2-1B}/online/pubmedqa}
    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500}
    checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best}
    callbacks:
      - class_path: eva.callbacks.ConfigurationLogger
    logger:
      - class_path: lightning.pytorch.loggers.TensorBoardLogger
        init_args:
          save_dir: *OUTPUT_ROOT
          name: ""
model:
  class_path: eva.language.models.TextModule
  init_args:
    prompt: "Respond to the question with a single digit only: 0 for no, 1 for yes, or 2 for maybe. Do not include any words, explanations, or additional characters—only the digit."
    model:
      class_path: eva.language.models.HuggingFaceTextModel
      init_args:
        model_name_or_path: "meta-llama/Llama-3.2-1B"
        model_kwargs:
          max_new_tokens: 1
    # TODO: implement metrics for language models
    # metrics:
    #   common:
    #     - class_path: eva.metrics.MulticlassClassificationMetrics
    #       init_args:
    #         num_classes: 3
    postprocess: null
data:
  class_path: eva.DataModule
  init_args:
    datasets:
      val:
        class_path: eva.language.datasets.PubMedQA
        init_args: &DATASET_ARGS
          root: ${oc.env:DATA_ROOT, ./data/pubmedqa}
          split: null
          download: true
    dataloaders:
      val:
        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 1}
        num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 1}
        shuffle: false
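
All of the ${oc.env:VAR, default} entries above are OmegaConf-style interpolations, so every setting in this config can be overridden from the environment without touching the file. A minimal sketch of how the resolution behaves, assuming only the `omegaconf` package (which ships the built-in `oc.env` resolver):

import os
from omegaconf import OmegaConf

# With DATA_ROOT unset, the default after the comma is used.
cfg = OmegaConf.create({"root": "${oc.env:DATA_ROOT,./data/pubmedqa}"})
print(cfg.root)  # ./data/pubmedqa

# Interpolations resolve lazily on access, so exported variables win.
os.environ["DATA_ROOT"] = "/mnt/datasets/pubmedqa"
print(cfg.root)  # /mnt/datasets/pubmedqa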
1,258 changes: 1,249 additions & 9 deletions pdm.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -84,6 +84,8 @@ all = [
     "scipy>=1.14.0",
     "monai>=1.3.2",
     "datasets>=3.2.0",
+    "litellm>=1.61.8",
+    "vllm>=0.5.1",
 ]

[project.scripts]
23 changes: 23 additions & 0 deletions src/eva/core/interface/interface.py
@@ -77,3 +77,26 @@ def predict_fit(
         """
         self.predict(trainer=trainer, model=model, data=data)
         self.fit(trainer=trainer, model=model, data=data)
+
+    def validate(
+        self,
+        trainer: eva_trainer.Trainer,
+        model: modules.ModelModule,
+        data: datamodules.DataModule,
+    ) -> None:
+        """Perform model validation out-of-place without running fit.
+
+        This method is useful when the model is already trained or does not
+        require further training (e.g., large language models) and you only
+        want to measure performance.
+
+        Args:
+            trainer: The base trainer to use but not modify.
+            model: The model module to use but not modify.
+            data: The data module containing the validation data.
+        """
+        trainer.validate_only(
+            model=model,
+            datamodule=data,
+            verbose=True,
+        )
29 changes: 29 additions & 0 deletions src/eva/core/trainers/functional.py
@@ -131,3 +131,32 @@ def infer_model(
         datamodule=datamodule,
         return_predictions=return_predictions,
     )
+
+
+def run_validation(
+    base_trainer: eva_trainer.Trainer,
+    base_model: modules.ModelModule,
+    datamodule: datamodules.DataModule,
+    verbose: bool = True,
+) -> _EVALUATE_OUTPUT:
+    """Validates a model out-of-place without fitting first.
+
+    Note:
+        This function clones the base model and trainer, so that the inputs
+        are not modified.
+
+    Args:
+        base_trainer: The base trainer to clone.
+        base_model: The model module to clone.
+        datamodule: The data module to validate on.
+        verbose: Whether to print the validation metrics after validation.
+
+    Returns:
+        The validation metrics produced by `trainer.validate`.
+    """
+    trainer, model = _utils.clone(base_trainer, base_model)
+    return trainer.validate(
+        model=model,
+        datamodule=datamodule,
+        verbose=verbose,
+    )
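
`run_validation` leans on `_utils.clone` so that every evaluation session starts from pristine trainer and model state. A rough sketch of the same out-of-place pattern, with `copy.deepcopy` standing in for eva's clone helper (an assumption; the real helper may rebuild the objects from their configuration instead):

import copy

def validate_out_of_place(base_trainer, base_model, datamodule):
    # Deep-copy both objects so validation side effects (device moves,
    # callback and logger state) never leak back into the caller's objects.
    trainer, model = copy.deepcopy(base_trainer), copy.deepcopy(base_model)
    return trainer.validate(model=model, datamodule=datamodule)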
23 changes: 23 additions & 0 deletions src/eva/core/trainers/trainer.py
@@ -113,3 +113,26 @@ def run_evaluation_session(
             n_runs=self.n_runs,
             verbose=self.n_runs > 1,
         )
+
+    def validate_only(
+        self,
+        model: modules.ModelModule,
+        datamodule: datamodules.DataModule,
+        verbose: bool = True,
+    ) -> None:
+        """Runs validation out-of-place, without fitting or testing first.
+
+        Note that the trainer and model are cloned before validating, so
+        the inputs are not modified in-place.
+
+        Args:
+            model: The model to validate.
+            datamodule: The datamodule to validate on.
+            verbose: Whether to print the validation metrics to stdout.
+        """
+        functional.run_validation(
+            base_trainer=self,
+            base_model=model,
+            datamodule=datamodule,
+            verbose=verbose,
+        )
16 changes: 8 additions & 8 deletions src/eva/language/data/datasets/classification/pubmedqa.py
@@ -1,7 +1,7 @@
 """PubMedQA dataset class."""

 import os
-from typing import Any, Dict, List, Literal
+from typing import Dict, List, Literal

 import torch
 from datasets import Dataset, load_dataset
@@ -118,15 +118,15 @@ def load_target(self, index: int) -> torch.Tensor:
         )

     @override
-    def load_metadata(self, index: int) -> Dict[str, Any]:
+    def load_metadata(self, index: int) -> Dict[str, str]:
         sample = self.dataset[index]
         return {
-            "year": sample["YEAR"],
-            "labels": sample["LABELS"],
-            "meshes": sample["MESHES"],
-            "long_answer": sample["LONG_ANSWER"],
-            "reasoning_required": sample["reasoning_required_pred"],
-            "reasoning_free": sample["reasoning_free_pred"],
+            "year": sample.get("YEAR") or "",
+            "labels": sample.get("LABELS") or "",
+            "meshes": sample.get("MESHES") or "",
+            "long_answer": sample.get("LONG_ANSWER") or "",
+            "reasoning_required": sample.get("reasoning_required_pred") or "",
+            "reasoning_free": sample.get("reasoning_free_pred") or "",
         }

     @override
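
The switch from sample[key] to sample.get(key) or "" covers both missing keys and keys that are present but null in the raw record, which is what allows the return type to narrow from Dict[str, Any] to Dict[str, str]. Note that dict.get(key, "") alone would not be enough:

sample = {"YEAR": None}            # field present, but null in the raw record
print(sample.get("YEAR", ""))      # None: the default only covers missing keys
print(sample.get("YEAR") or "")    # "": also normalizes explicit None values
print(sample.get("MESHES") or "")  # "": missing keys work as before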
12 changes: 12 additions & 0 deletions src/eva/language/models/__init__.py
@@ -0,0 +1,12 @@
"""Language Models API."""

from eva.language.models import networks, wrappers
from eva.language.models.networks import TextModule
from eva.language.models.wrappers import HuggingFaceTextModel, LiteLLMTextModel, VLLMTextModel

__all__ = ["networks",
"wrappers",
"TextModule",
"HuggingFaceTextModel",
"LiteLLMTextModel",
'VLLMTextModel']
5 changes: 5 additions & 0 deletions src/eva/language/models/networks/__init__.py
@@ -0,0 +1,5 @@
"""Language Networks API."""

from eva.language.models.networks.module import TextModule

__all__ = ["TextModule"]
83 changes: 83 additions & 0 deletions src/eva/language/models/networks/module.py
@@ -0,0 +1,83 @@
"""LLM Text Module for Inference."""

from typing import Any, List

from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from typing_extensions import override

from eva.core.metrics import structs as metrics_lib
from eva.core.models.modules import module
from eva.core.models.modules.typings import INPUT_BATCH
from eva.core.models.modules.utils import batch_postprocess


class TextModule(module.ModelModule):
"""Text-based LLM module for inference.

Uses LLM wrappers for text generation.
Supports evaluation using configurable metrics and post-processing. # TODO: Add support
"""

def __init__(
self,
model: nn.Module,
prompt: str,
metrics: metrics_lib.MetricsSchema | None = None,
postprocess: batch_postprocess.BatchPostProcess | None = None,
) -> None:
"""Initializes the text inference module.

Args:
model: An LLM wrapper (PyTorch-compatible) for text generation.
prompt: The prompt to use for generating text.
metrics: Metrics schema for evaluation.
postprocess: A helper function to post-process model outputs before evaluation.
"""
super().__init__(metrics=metrics, postprocess=postprocess)

self.model = model
self.prompt = prompt

@override
def forward(self, prompts: str, *args: Any, **kwargs: Any) -> List[str]:
"""Generates text responses for a batch of prompts.

Args:
prompts: List of input texts to generate responses.
args: Additional arguments.
kwargs: Additional keyword arguments.

Returns:
List of generated responses.
"""
return self.model.generate(prompts)

@override
def validation_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
"""Validation step that runs batch inference and evaluates metrics.

Args:
batch: An input batch.
args: Additional arguments.
kwargs: Additional keyword arguments.

Returns:
Dictionary with predictions, ground truth, and evaluation metrics.
"""
return self._batch_step(batch)

def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
"""Runs inference on a batch and evaluates model predictions.

Args:
batch: A batch containing 'QUESTION', 'CONTEXTS', 'final_decision', etc.

Returns:
Dictionary with predictions, ground truth, and evaluation metrics.
"""
data, targets, metadata = INPUT_BATCH(*batch)
message = self.prompt + str(data) + "\nAnswer: "
predictions = self(message)
# TODO: Add support for evaluation metrics
return {"predictions": predictions, "targets": targets, "metadata": metadata}
7 changes: 7 additions & 0 deletions src/eva/language/models/wrappers/__init__.py
@@ -0,0 +1,7 @@
"""Language Model Wrappers API."""

from eva.language.models.wrappers.huggingface import HuggingFaceTextModel
from eva.language.models.wrappers.litellm import LiteLLMTextModel
from eva.language.models.wrappers.vllm import VLLMTextModel

__all__ = ["HuggingFaceTextModel", "LiteLLMTextModel", "VLLMTextModel"]
55 changes: 55 additions & 0 deletions src/eva/language/models/wrappers/huggingface.py
@@ -0,0 +1,55 @@
"""LLM wrapper for HuggingFace `transformers` models."""

from typing import Any, Dict, Literal

import transformers
from typing_extensions import override

from eva.core.models.wrappers import base


class HuggingFaceTextModel(base.BaseModel):
"""Wrapper class for loading HuggingFace `transformers` models using pipelines."""

def __init__(
self,
model_name_or_path: str,
task: Literal["text-generation", "text-classification"] = "text-generation",
model_kwargs: Dict[str, Any] | None = None,
) -> None:
"""Initializes the model.

Args:
model_name_or_path: The model name or path to load the model from.
This can be a local path or a model name from the `HuggingFace`
model hub.
task: The pipeline task. Defaults to "text-generation".
model_kwargs: Additional arguments for configuring the pipeline.
"""
super().__init__()

self._model_name_or_path = model_name_or_path
self._task = task
self._model_kwargs = model_kwargs or {}

self.load_model()

@override
def load_model(self) -> None:
"""Loads the model as a Hugging Face pipeline."""
self._pipeline = transformers.pipeline(
task=self._task, model=self._model_name_or_path, **self._model_kwargs
)

def generate(self, prompt: str, **generate_kwargs) -> Any:
"""Generates text using the pipeline.

Args:
prompt: The input prompt for the model.
generate_kwargs: Additional generation parameters (e.g., max_length).

Returns:
The generated text as a string.
"""
output = self._pipeline(prompt, return_full_text=False, **generate_kwargs)
return output[0]["generated_text"] if isinstance(output, list) else output
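
Example usage, mirroring configs/language/pubmedqa.yaml above. On first use this downloads the weights from the HuggingFace hub, and gated models such as meta-llama/Llama-3.2-1B additionally require an authenticated token, so treat this as a sketch rather than a ready-to-run test:

from eva.language.models import HuggingFaceTextModel

model = HuggingFaceTextModel(
    model_name_or_path="meta-llama/Llama-3.2-1B",
    task="text-generation",
    model_kwargs={"max_new_tokens": 1},
)
print(model.generate("Is water wet? Answer 0 for no or 1 for yes.\nAnswer: "))  # e.g. "1"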