diff --git a/animated-transformer/src/lib/seqtasks/tiny_worlds_train.script.ts b/animated-transformer/src/lib/seqtasks/tiny_worlds_train.script.ts
index 097c09a..8481e11 100644
--- a/animated-transformer/src/lib/seqtasks/tiny_worlds_train.script.ts
+++ b/animated-transformer/src/lib/seqtasks/tiny_worlds_train.script.ts
@@ -38,6 +38,7 @@ import {
   transformerLastTokenLogits,
   transformerLastTokenCrossEntropyLoss,
   transformerAccuracy,
+  transformerAllTokensCrossEntropyLoss
 } from '../transformer/transformer_gtensor';
 import {
   TinyWorldTask,
@@ -52,6 +53,7 @@ import {
   singleNextTokenIdxOutputPrepFn,
   prepareBasicTaskTokenRep,
   BasicTaskTokenRep,
+  prepareTargetsTensor
 } from '../tokens/token_gemb';
 import { layer } from '@tensorflow/tfjs-vis/dist/show/model';
 import { example } from 'yargs';
@@ -114,7 +116,7 @@ function* dataGenerator(task: TinyWorldTask, batchNum: number, batchSize: number
 function unbindedLossFn(
   batchId: number,
   batchInput: string[][],
-  batchOutput: string[][],
+  batchOutput: string[][], // target continuations of batchInput
   tokenRep: BasicTaskTokenRep,
   transformerConfig: TransformerConfig,
   decoderParamsTree: TransformerParams
@@ -128,7 +130,7 @@ function unbindedLossFn(
     batchInput
   );
   let singleNextTokenIdx = singleNextTokenIdxOutputPrepFn(tokenRep, batchOutput);
-  let entropyLoss: tf.Scalar = transformerLastTokenCrossEntropyLoss(
+  let lastTokenEntropyLoss: tf.Scalar = transformerLastTokenCrossEntropyLoss(
     computation,
     decoderParamsTree.tokenEmbedding,
     singleNextTokenIdx
@@ -139,14 +141,19 @@ function unbindedLossFn(
     singleNextTokenIdx
   );
 
+  let targetIdxs = prepareTargetsTensor(tokenRep, batchInput, batchOutput);
+  let fullEntropyLoss = transformerAllTokensCrossEntropyLoss(computation, decoderParamsTree.tokenEmbedding, targetIdxs);
+
   if (batchId % printEveryNBatches === 0) {
     console.log(
       `batch: ${batchId} `.padEnd(15) +
-        ('entropyLoss: ' + entropyLoss.arraySync().toFixed(8)).padEnd(25) +
-        ('accuracy: ' + accuracy.arraySync().toFixed(8)).padEnd(25)
+        ('lastTokenEntropyLoss: ' + lastTokenEntropyLoss.arraySync().toFixed(8)).padEnd(25) +
+        ('fullEntropyLoss: ' + fullEntropyLoss.arraySync().toFixed(8)).padEnd(25) +
+        ('accuracy: ' + accuracy.arraySync().toFixed(8)).padEnd(25)
     );
   }
-  return entropyLoss;
+  // return lastTokenEntropyLoss;
+  return fullEntropyLoss;
 }
 
 function run() {
diff --git a/animated-transformer/src/lib/tokens/token_gemb.ts b/animated-transformer/src/lib/tokens/token_gemb.ts
index a9f6729..9eb535d 100644
--- a/animated-transformer/src/lib/tokens/token_gemb.ts
+++ b/animated-transformer/src/lib/tokens/token_gemb.ts
@@ -258,6 +258,19 @@ export function singleNextTokenIdxOutputPrepFn(
   );
 }
 
+/** Token ids of each input sequence, with the id of its first output (next)
+ * token appended; used as the per-position prediction targets. */
+export function prepareTargetsTensor(
+  tokenRep: BasicTaskTokenRep,
+  inputSeqs: string[][],
+  outputSeqs: string[][]
+): GTensor<'batch' | 'pos'> {
+  const targetIdxs = inputSeqs.map((inputSeq, i) =>
+    inputSeq.concat(outputSeqs[i][0]).map((token) => tokenRep.tokenToIdx[token])
+  );
+  return new GTensor(tf.tensor2d(targetIdxs, undefined, 'int32'), ['batch', 'pos']);
+}
+
 export function padInputSeqStart(
   paddingToken: string,
   maxInputLength: number,
diff --git a/animated-transformer/src/lib/transformer/transformer_gtensor.ts b/animated-transformer/src/lib/transformer/transformer_gtensor.ts
index 74ce68a..deac1b9 100644
--- a/animated-transformer/src/lib/transformer/transformer_gtensor.ts
+++ b/animated-transformer/src/lib/transformer/transformer_gtensor.ts
@@ -416,6 +416,51 @@ export function transformerLastTokenCrossEntropyLoss(
   // return loss.tensor;
 }
 
+/** Return the logits for every position of the sequence.
+ *
+ * params: transformer parameters.
+ * tokenEmb: embeddings for all tokens.
+ */
+export function transformerLogits(
+  params: TransformerComputation,
+  tokenEmb: GTensor<'tokenId' | 'inputRep'>
+): GTensor<'batch' | 'pos' | 'tokenId'> {
+  const lastLayer = params.layers[params.layers.length - 1];
+  const seqOutput = lastLayer.seqOuput;
+  const logits = seqOutput.contract(tokenEmb, ['inputRep']);
+  return logits;
+}
+
+/**
+ * Returns the cross-entropy loss of predicting the target token at every
+ * position, averaged over all positions and all examples in the batch.
+ */
+export function transformerAllTokensCrossEntropyLoss(
+  params: TransformerComputation,
+  tokenEmb: GTensor<'tokenId' | 'inputRep'>,
+  targetTokenIdxs: GTensor<'batch' | 'pos'>
+): tf.Scalar {
+  const logits = transformerLogits(params, tokenEmb);
+
+  const logProbs = logits.softmax('tokenId').log();
+  const oneHotToken = new GTensor(oneHot(targetTokenIdxs.tensor, tokenEmb.dim.tokenId.size), [
+    'batch',
+    'pos',
+    'tokenId',
+  ]);
+
+  const crossEntropy = logProbs.pointwiseMul(oneHotToken);
+
+  // Negate and average the summed target log-probabilities over batch * pos predictions.
+  const numPredictions = tf.scalar(-1 * targetTokenIdxs.dim.batch.size * targetTokenIdxs.dim.pos.size);
+
+  return (
+    crossEntropy
+      .sumOverDims(['batch', 'pos', 'tokenId'])
+      ._tfScalarDiv(numPredictions).tensor as tf.Scalar
+  );
+}
+
 /** Batch compute the top prediction from the last token of a transformer.
  *
  * params: transformer parameters.
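
Usage sketch (illustrative only, not part of the patch): the all-tokens loss is the mean negative log-probability of the target token over every position and every example, i.e. -(1 / (batchSize * posSize)) * sum over (b, p) of log softmax(logits[b, p])[targetIdxs[b, p]]. Assuming a tokenRep, decoderParamsTree, and computation built the same way as in unbindedLossFn (the token strings and batch below are made up), the new functions would be exercised roughly as:

    // Hypothetical batch; every token must exist in tokenRep.tokenToIdx.
    const batchInput = [['a', 'b', 'c'], ['b', 'c', 'a']];
    const batchOutput = [['d'], ['d']];
    // Target token ids with dims ['batch', 'pos'].
    const targetIdxs = prepareTargetsTensor(tokenRep, batchInput, batchOutput);
    // Scalar loss; targetIdxs' 'pos' size must match the 'pos' size of the
    // logits from transformerLogits(computation, decoderParamsTree.tokenEmbedding).
    const loss = transformerAllTokensCrossEntropyLoss(
      computation,
      decoderParamsTree.tokenEmbedding,
      targetIdxs
    );
    console.log(loss.arraySync()); // expected to be near ln(vocab size) at initialization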