
TCAV with a DeBERTa model #1465

Open
elenagaz opened this issue May 13, 2024 · 2 comments

Comments

@elenagaz

Hi, I want to test a classification model of type DeBERTa and definitely want to use TCAV for my evaluation.
I have gone through the glue_models.py file and glue_demo.py, but I cannot seem to find where exactly the TCAV implementation is added to a new demo.

Is there any documentation on where and how the TCAV can be added?
And are there any prerequisites for it to work?

Thanks in advance.

@RyanMullins
Member

RyanMullins commented May 13, 2024

Hi @elenagaz!

Thanks for your interest in LIT. TCAV is included as one of LIT's default interpreters, but it is only compatible with certain model APIs, specifically classification models. Compatibility checks happen at initialization time for any LitApp (either a server or in a notebook context), but incompatibilities fail silently, so misconfigurations can be hard to diagnose.

To summarize the linked function, at least one model needs to 1) produce a MulticlassPreds result, 2) correlate that prediction with a Gradients field and an Embeddings field, and 3) provide a CategoryLabel field with the ground truth to compare the prediction against. All four of these fields need to be in the model's output spec, which also means they are returned in the JSON objects from calls to Model.predict().
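As a rough illustration of the requirement above (this is not LIT's actual code, which inspects lit_nlp.api.types objects rather than strings), the compatibility condition boils down to the output spec containing all four field types. A minimal sketch, assuming the spec is represented as a mapping from field names to type names:

```python
# Hypothetical sketch of the TCAV compatibility condition described above.
REQUIRED_TYPES = {"MulticlassPreds", "Gradients", "Embeddings", "CategoryLabel"}

def is_tcav_compatible(output_spec: dict) -> bool:
    """True if the spec contains every field type TCAV needs."""
    present = set(output_spec.values())
    return REQUIRED_TYPES <= present

# Example: a spec that should satisfy the check (field names are illustrative).
spec = {
    "probas": "MulticlassPreds",
    "token_grads": "Gradients",
    "cls_emb": "Embeddings",
    "label": "CategoryLabel",
}
```

If any one of the four is missing, the check fails and the TCAV option simply never appears in the UI.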

I know the above can be a lot to process, so if there's any way you can share your model classes and server script (e.g., a Gist or PR), I'm happy to take a look and provide feedback about why you might not be seeing the TCAV option and how to make it work.

@elenagaz
Author

elenagaz commented May 21, 2024

Hi @RyanMullins

Yes, it does incorporate a few things.

As I wrote before, I have used the glue_models.py example included in the code. I made some changes to be able to use a DeBERTa model, but as you pointed out, the output_spec has proven to be a bit difficult.

The model I am using is a classification model that takes an input such as ‘I would like to test this model.’ but returns 9 probabilities, unlike the glue_models I have seen.

This is the way I input the data when using my model with the glue_models.py structure.

class ModerationModelSpec(ModerationModel):
    def __init__(self, *args, **kw):
        super().__init__(
            *args,
            prompt="prompt",
            label_names=['M', 'M2', 'B', 'OK', 'L', 'X', 'P', 'V', 'V2'],
            labels=["1", "0"],
            **kw)

    def output_spec(self) -> Spec:
        ret = super().output_spec()
        ret["probas"] = lit_types.MulticlassPreds(
            parent=self.config.label_names,
            vocab=self.config.labels)
        return ret
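One thing that may be worth double-checking in the snippet above: as far as I know, lit_nlp.api.types.MulticlassPreds expects `vocab` to be the list of class names that the probabilities index into, and `parent` to be the name of the input-spec field holding the ground-truth label, so the two arguments may be swapped here. A hedged sketch of the intended usage (the field name "label" is an assumption):

```python
# Sketch only: assumes the input spec has a CategoryLabel field named "label".
ret["probas"] = lit_types.MulticlassPreds(
    vocab=self.config.label_names,  # the 9 class names, e.g. ['M', 'M2', ...]
    parent="label",                 # hypothetical ground-truth field in the input spec
)
```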

This would be the output - with the input: ‘I would like to test this model.’
Label: OK - Probability: 0.9922
Label: P - Probability: 0.0020
Label: M - Probability: 0.0019
Label: V - Probability: 0.0010
Label: V2 - Probability: 0.0007
Label: B - Probability: 0.0007
Label: L - Probability: 0.0006
Label: X - Probability: 0.0004
Label: M2 - Probability: 0.0003

The raw, unedited probabilities look like this: tensor([1.9488e-03, 3.2813e-04, 6.1552e-04, 9.9222e-01, 7.1795e-04, 4.4982e-04, 1.9912e-03, 9.9053e-04, 7.3782e-04], grad_fn=&lt;SqueezeBackward0&gt;)
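For reference, pairing those raw probabilities with the label_names list from the snippet above and sorting roughly reproduces the ranked view shown earlier (a plain-Python sketch, no torch required):

```python
# Pair each label with its raw probability and rank descending.
labels = ['M', 'M2', 'B', 'OK', 'L', 'X', 'P', 'V', 'V2']
probs = [1.9488e-03, 3.2813e-04, 6.1552e-04, 9.9222e-01, 7.1795e-04,
         4.4982e-04, 1.9912e-03, 9.9053e-04, 7.3782e-04]
ranked = sorted(zip(labels, probs), key=lambda lp: lp[1], reverse=True)
for label, p in ranked:
    print(f"Label: {label} - Probability: {p:.4f}")
```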

I particularly have trouble computing the grads, and I think it is because of the get_target_scores() method.

If you have any input on how I could edit the input or output specification, I would appreciate it.
