
Add support to dropout. #29

Open

aliciafmachado wants to merge 6 commits into main
Conversation

@aliciafmachado aliciafmachado commented Sep 8, 2024

  • Add dropout.
  • Add basic tests for dropout and trainer with dropout enabled.
  • Add example in app.
  • Add evalMode flag so that dropout can be disabled during eval.

Intended to resolve Issue: #1

@aliciafmachado aliciafmachado changed the title from "Add more tests to dropout and pass flag to computeTransformer to disable dropout during evaluation." to "Add support to dropout." on Sep 8, 2024
@aliciafmachado aliciafmachado left a comment

There are some commits going back and forth on some things that I did not fully understand at first, so feel free to squash the commits before merging to main to avoid confusion. Otherwise, I can recreate the pull request and fix the commit history.

I also have a few questions / discussion topics:

  1. I added support for dropout, but we need something to manage random seeds so that we can seed properly. Should we create an issue for that?
  2. I tried to add dropout following the T5 architecture, but I decided not to add it after the FF layer or on the output. For the FF, I don't think it makes sense since we have a single layer and we already apply dropout before the residual connection after the FF network. For the output, I don't see any additional computation after leaving the stack, so another dropout there would only add noise (I also did not see an extra dropout on the output in the haiku implementation linked in the issue requesting dropout). A sketch of this placement follows below.
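A minimal sketch of the placement described in point 2, written with plain tfjs tensors rather than the repo's GTensor types (the helper name and its arguments are illustrative, not code from this PR): dropout is applied to a sub-layer's output before the residual add and layer norm, and no extra dropout is appended to the stack output.

```ts
import * as tf from '@tensorflow/tfjs';

// Illustrative only: dropout sits between the sub-layer (e.g. the FF network)
// and the residual connection; layer norm would follow the residual add.
function subLayerWithDropout(
  x: tf.Tensor,
  subLayer: (t: tf.Tensor) => tf.Tensor,
  dropoutRate: number
): tf.Tensor {
  const y = subLayer(x);
  // tf.dropout is inverted dropout: kept activations are scaled by 1 / (1 - rate).
  const yDropped = dropoutRate > 0 ? tf.dropout(y, dropoutRate) : y;
  return tf.add(x, yDropped);
}
```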

@aliciafmachado aliciafmachado marked this pull request as ready for review September 8, 2024 16:06
@iislucas iislucas left a comment

Looks great, a few small things.

```diff
@@ -225,13 +229,20 @@ function gelu(x: tf.Tensor) {
 export function computeAttnHead(
   spec: AttnHeadComputeSpec,
   params: AttnHeadParams<TensorKind>,
-  seqInput: GTensor<'batch' | 'pos' | 'inputRep'>
+  seqInput: GTensor<'batch' | 'pos' | 'inputRep'>,
+  evalMode: boolean = false
```
Let's drop the evalMode flag, and just depend on the spec having dropoutRate set differently at eval vs. inference time.
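A sketch of what that could look like at the call sites (the spec field names are taken from the diff and the test snippet later in this PR; `params` and `seqInput` are assumed to be in scope as in the hunk above):

```ts
// No evalMode flag: the same compute spec shape is used, differing only in dropoutRate.
const trainSpec: AttnHeadComputeSpec = { residuals: true, dropoutRate: 0.1 };
const evalSpec: AttnHeadComputeSpec = { residuals: true, dropoutRate: 0.0 };

// Training forward pass: dropout active.
const trainOut = computeAttnHead(trainSpec, params, seqInput);
// Eval forward pass: a dropoutRate of 0 makes dropout a no-op.
const evalOut = computeAttnHead(evalSpec, params, seqInput);
```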

```ts
export function dropout<G extends string, D extends G>(
  dropoutRate: number,
  g: GTensor<G>,
  deterministic: boolean,
```
Let's remove deterministic, and just check if the rate is 0.
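A minimal sketch of the suggested shape, assuming GTensor exposes applyPointWiseTfFn as used in the FF code below; the optional numeric seed parameter is an illustration, not this PR's API:

```ts
import * as tf from '@tensorflow/tfjs';

// Sketch only: without a `deterministic` flag, a dropoutRate of 0 just returns
// the input unchanged, which is what eval-time callers would pass.
export function dropout<G extends string>(
  dropoutRate: number,
  g: GTensor<G>,
  seed?: number
): GTensor<G> {
  if (dropoutRate === 0) {
    return g;
  }
  // tf.dropout implements inverted dropout: kept values are scaled by 1 / (1 - rate).
  return g.applyPointWiseTfFn((x: tf.Tensor) => tf.dropout(x, dropoutRate, undefined, seed));
}
```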

```ts
let unNormedSeqOuput = inputToFF
  .contract(ff.w, ['inputRepToFF'])
  .pointwiseAdd(ff.bIn)
  .applyPointWiseTfFn(gelu)
  .pointwiseAdd(ff.bOut);

// Dropout before layer norm and residual connection.
let unNormedSeqOuputAfterDropout = unNormedSeqOuput;
```
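Presumably the placeholder above is then replaced by a conditional dropout call roughly along these lines (a sketch based on the dropout signature quoted earlier, not the actual diff):

```ts
// Sketch only: drop out the FF output before the layer norm and residual connection.
if (spec.dropoutRate > 0 && !evalMode) {
  unNormedSeqOuputAfterDropout = dropout(
    spec.dropoutRate, unNormedSeqOuput, /* deterministic= */ evalMode);
}
```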
Let's use this: https://github.com/Shivanandroy/simpleT5 as the reference for where to put it for T5. And maybe name this function computeT5AttnHead; then later we can make a GPT-2 one.

@iislucas: sg!

```ts
const layerSpec: transformer.TransformerParamLayerSpec = {
  nHeads: 1,
  hasPosEncoding: true,
  computeSpec: { residuals: true, dropoutRate: 0.1 },
```
@iislucas iislucas Sep 12, 2024
Maybe also add a test for a dropout rate of 1, and check that the loss doesn't decrease.
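A rough sketch of such a test (jasmine-style, mirroring the spec snippet above; computeLossOverSteps is a hypothetical helper standing in for whatever the existing trainer tests use to run a few steps and record the loss):

```ts
it('does not reduce the loss when dropoutRate is 1', async () => {
  const layerSpec: transformer.TransformerParamLayerSpec = {
    nHeads: 1,
    hasPosEncoding: true,
    computeSpec: { residuals: true, dropoutRate: 1 },
  };
  // Hypothetical helper: runs n training steps and returns the loss at each step.
  const losses = await computeLossOverSteps(layerSpec, 10);
  // With every activation dropped, the model receives no signal, so the loss
  // should not go down.
  expect(losses[losses.length - 1]).toBeGreaterThanOrEqual(losses[0]);
});
```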

@aliciafmachado: Done.

@aliciafmachado
Will rebase once #36 is submitted, and then pass a generator so that the dropout is reproducible; then you can take a second look, @iislucas.
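Illustration of the reproducibility goal, using the hypothetical seed parameter from the earlier dropout sketch (whether the final API takes a numeric seed or a generator object depends on #36):

```ts
// Same inputs and same seed should produce the same dropout mask,
// making training runs reproducible.
const a = dropout(0.1, g, /* seed= */ 42);
const b = dropout(0.1, g, /* seed= */ 42);
// a and b are expected to be identical.
```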
