Let models/multi_head_attention use policy-controlled dtypes,
avoid hard-coding float32. For the default policy, there is no change.

PiperOrigin-RevId: 574825445
arnoegw authored and tensorflower-gardener committed Oct 19, 2023
1 parent b07efe7 commit a20623a
Showing 2 changed files with 71 additions and 21 deletions.
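For context, here is a minimal sketch (not part of this commit) of what "policy-controlled dtypes" means in Keras: under a mixed-precision policy, a layer's compute_dtype and variable_dtype diverge, so a cast hard-coded to float32 would no longer match the activations. The ScaledScores layer below is hypothetical and only mirrors the cast that layers.py now performs.

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

class ScaledScores(tf.keras.layers.Layer):  # hypothetical demo layer
  def call(self, scores, keys):
    # Cast to the policy's compute dtype instead of hard-coding float32.
    return scores * tf.math.rsqrt(
        tf.cast(tf.shape(keys)[-1], self.compute_dtype))

layer = ScaledScores()
print(layer.compute_dtype)   # bfloat16 under mixed_bfloat16
print(layer.variable_dtype)  # float32; variables stay full precision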
3 changes: 1 addition & 2 deletions tensorflow_gnn/models/multi_head_attention/layers.py
@@ -454,13 +454,12 @@ def convolve(self,
pass
elif self._score_scaling == "rsqrt_dim":
attention_coefficients *= tf.math.rsqrt(
-     tf.cast(tf.shape(keys)[-1], tf.float32))
+     tf.cast(tf.shape(keys)[-1], self.compute_dtype))
elif self._score_scaling == "trainable_elup1":
if self._score_scaling_weight is None:
self._score_scaling_weight = self.add_weight(
name="score_scaling",
shape=[self._num_heads, 1],
-     dtype=tf.float32,
initializer=tf.keras.initializers.Constant(0.0),
trainable=True,
)
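The dtype=tf.float32 deletion above works because add_weight without an explicit dtype defaults to the layer's variable dtype, which mixed-precision policies keep at float32. A small sketch under that assumption (the Demo layer is illustrative, not library code):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

class Demo(tf.keras.layers.Layer):  # illustrative only
  def build(self, input_shape):
    # No dtype argument: the weight picks up the policy's variable_dtype.
    self.score_scaling = self.add_weight(
        name="score_scaling", shape=[2, 1],
        initializer=tf.keras.initializers.Constant(0.0), trainable=True)

layer = Demo()
layer.build(tf.TensorShape([4]))
print(layer.score_scaling.dtype)  # float32, not float16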
89 changes: 70 additions & 19 deletions tensorflow_gnn/models/multi_head_attention/layers_test.py
@@ -20,6 +20,7 @@
import os

from absl.testing import parameterized
+ import numpy as np
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import multi_head_attention
@@ -34,12 +35,29 @@ class ReloadModel(int, enum.Enum):

class MultiHeadAttentionTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(("", False), ("TransformAfter", True))
def testBasic(self, transform_values_after_pooling):
def setUp(self):
super().setUp()
# tf.test.TestCase neglects to reset this between tests.
tf.keras.mixed_precision.set_global_policy("float32")

@parameterized.named_parameters(
("", False),
("F64", False, "float64"),
("MF16", False, "mixed_float16"),
("MBF16", False, "mixed_bfloat16"),
("TransformAfter", True),
("TransformAfterMBF16", True, "mixed_bfloat16"))
def testBasic(self,
transform_values_after_pooling,
mixed_precision_policy_name=None):
"""Tests that a single-headed MHA is correct given predefined weights."""
# NOTE: Many following tests use minor variations of the explicit
# construction of weights and results introduced here.

+ if mixed_precision_policy_name:
+   tf.keras.mixed_precision.set_global_policy(mixed_precision_policy_name)
+ mixed_precision_policy = tf.keras.mixed_precision.global_policy()

# Construct a graph with three nodes 0, 1, 2, and six edges:
# a cycle 0->1->2->0 (let's call it clockwise)
# and the reverse cycle 0->2->1->0 (counterclockwise).
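The new setUp matters because the mixed-precision policy is process-global state shared across tests; a rough sketch of the behavior it resets:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
policy = tf.keras.mixed_precision.global_policy()
print(policy.name, policy.compute_dtype, policy.variable_dtype)
# mixed_float16 float16 float32
tf.keras.mixed_precision.set_global_policy("float32")  # the reset in setUp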
@@ -65,6 +83,9 @@ def testBasic(self, transform_values_after_pooling):
_ = conv(gt_input, edge_set_name="edges") # Build weights.
weights = {v.name: v for v in conv.trainable_weights}
self.assertLen(weights, 6)
+ for k, v in weights.items():
+   with self.subTest(f"dtype check for weight '{k}'"):
+     self.assertDTypeEqual(v, mixed_precision_policy.variable_dtype)

weights["multi_head_attention_conv/query/kernel:0"].assign(
# The space of attention computation of the single head has dimension 3.
@@ -86,19 +107,19 @@
log20 = tf.math.log(20.).numpy()
log2 = tf.math.log(2.).numpy()
# Using an inverse scaling factor to cancel out the score scaling.
- inverse_scaling_factor = tf.math.sqrt(3.)
+ inverse_scaling_factor = tf.math.sqrt(3.).numpy()
weights["multi_head_attention_conv/key_node/kernel:0"].assign(
# ... the key vectors of node 1 and 2, resp., are \sqrt(3) *
# [log(2), log(20), 0.] and [0., log(2), log(20)]. Therefore, the node 0
# query vector [0., 1., 0.] dot-product on the key vectors of node 1
# and 2 will give \sqrt(3) * log(20) (favored) and \sqrt(3) * log(2)
# (not favored), and the \sqrt(3) is canceled out after scaling.
- inverse_scaling_factor * [
+ inverse_scaling_factor * np.array([
[log20, 0., log2],
[log2, log20, 0.],
[0., log2, log20],
[0., 0., 0.],
- ])
+ ], dtype=mixed_precision_policy.variable_dtype))
weights["multi_head_attention_conv/key_node/bias:0"].assign([0., 0., 0.])

# The attention coefficients are computed by the dot-product of transformed
@@ -141,9 +162,15 @@
[0., 0., 2.3], # Node 0.
[0., 0., 3.1], # Node 1.
[0., 0., 1.2], # Node 2.
- ])
+ ], dtype=mixed_precision_policy.compute_dtype)
self.assertAllEqual(got.shape, (3, 3))
- self.assertAllClose(got, want, atol=.0001)
+ self.assertEqual(got.dtype, want.dtype)
+ self.assertAllCloseAccordingToType(
+     got, want,
+     rtol=0., atol=1e-8,
+     float_rtol=0., float_atol=.0001,
+     half_rtol=0., half_atol=.01,
+     bfloat16_rtol=0., bfloat16_atol=.04)

# For node states with more than one feature dimension, MultiHeadAttention
# works in parallel on the vectors from the innermost dimension, so we can
@@ -160,9 +187,15 @@
[[0., 0., 2.3], [0., 0., 9.6]],
[[0., 0., 3.1], [0., 0., 3.9]],
[[0., 0., 1.2], [0., 0., 6.3]],
- ])
+ ], dtype=mixed_precision_policy.compute_dtype)
self.assertAllEqual(got_2.shape, (3, 2, 3))
- self.assertAllClose(got_2, want_2, atol=.0001)
+ self.assertEqual(got_2.dtype, want_2.dtype)
+ self.assertAllCloseAccordingToType(
+     got_2, want_2,
+     rtol=0., atol=1e-8,
+     float_rtol=0., float_atol=.0001,
+     half_rtol=0., half_atol=.01,
+     bfloat16_rtol=0.015, bfloat16_atol=0.)  # NOTE: Accepts 9.5? for 9.6.
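Both assertions above move from assertAllClose to assertAllCloseAccordingToType, which buckets tolerances by dtype; bfloat16 keeps only about 8 bits of mantissa, so float32-grade tolerances would make the mixed-precision variants flaky. A self-contained sketch with made-up numbers:

import tensorflow as tf

class ToleranceDemo(tf.test.TestCase):

  def testBfloat16Drift(self):
    want = tf.constant([2.3, 3.1, 1.2], tf.bfloat16)
    got = want * tf.constant(1.01, tf.bfloat16)  # simulate ~1% drift
    # Fails under the float32 tolerance, passes under the bfloat16 one.
    self.assertAllCloseAccordingToType(
        got, want, float_atol=.0001, bfloat16_rtol=0., bfloat16_atol=.04)

if __name__ == "__main__":
  tf.test.main()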

def testAttentionActivation(self):
"""Tests that a single-headed MHA correctly applies attention activations."""
@@ -288,9 +321,20 @@ def get_conv(attention_activation=None):
self.assertAllEqual(got.shape, (3, 3))
self.assertAllClose(got, want, atol=.0001)

- def testScoreScalingTypes(self):
+ # The trainable_* type involves a manually created variable, so we make
+ # sure to test mixed precision policies for it.
+ @parameterized.named_parameters(
+     ("",),
+     ("F64", "float64"),
+     ("MF16", "mixed_float16"),
+     ("MBF16", "mixed_bfloat16"))
+ def testScoreScalingTypes(self, mixed_precision_policy_name=None):
"""Tests that the different types of score scaling are applied correctly."""

+ if mixed_precision_policy_name:
+   tf.keras.mixed_precision.set_global_policy(mixed_precision_policy_name)
+ mixed_precision_policy = tf.keras.mixed_precision.global_policy()

# The same test graph as in the testBasic above.
gt_input = _get_test_bidi_cycle_graph(
tf.constant([
Expand All @@ -317,6 +361,9 @@ def get_conv(score_scaling=None):
self.assertLen(weights, 7)
else:
self.assertLen(weights, 6)
+ for k, v in weights.items():
+   with self.subTest(f"dtype check for weight '{k}'"):
+     self.assertDTypeEqual(v, mixed_precision_policy.variable_dtype)

# If we're using "trainable_elup1", set the initial value to a known
# constant.
@@ -386,10 +433,12 @@

return conv

+ # Define these with full Python precision; downcast later.
named_scalings = {
"none": 1.0,
"rsqrt_dim": 1.0 / math.sqrt(3.0),
"trainable_elup1": 1.0 + tf.keras.activations.elu(-1.234),
"trainable_elup1": (1.0 + tf.keras.activations.elu(
tf.constant(-1.234, tf.float64))).numpy(),
}
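For reference, elu(x) = exp(x) - 1 for x < 0, so the "trainable_elup1" entry reduces to exp(-1.234) ≈ 0.2912; a quick check in plain Python:

import math

w = -1.234                          # the known constant assigned to the weight
scale = 1.0 + (math.exp(w) - 1.0)   # 1 + elu(w), positive for any w
print(round(scale, 4))              # 0.2912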

for scaling_name, scaling_factor in named_scalings.items():
@@ -399,15 +448,17 @@

# Since the transformed values are just the identity matrix, we recover
# the attention weights for each query.
- w = tf.math.exp(scaling_factor).numpy()
- want = tf.constant([
-     [0.0, w, 1.0],
-     [w, 0.0, 1.0],
-     [1.0, w, 0.0],
- ]) / tf.constant(
-     w + 1.0, dtype=tf.float32)
+ w = tf.math.exp(tf.constant(scaling_factor, tf.float64)).numpy()
+ want = tf.cast(
+     tf.constant([
+         [0.0, w, 1.0],
+         [w, 0.0, 1.0],
+         [1.0, w, 0.0],
+     ], tf.float64) / (w + 1.0),
+     dtype=mixed_precision_policy.compute_dtype)
self.assertAllEqual(got.shape, (3, 3))
- self.assertAllClose(got, want, atol=0.0001)
+ self.assertEqual(got.dtype, want.dtype)
+ self.assertAllCloseAccordingToType(got, want)
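To see where `want` comes from, assume each query gives raw score 1 to its favored in-edge and 0 to the other; after scaling by a factor s, softmax yields [w, 1] / (w + 1) with w = e^s, matching the rows above:

import numpy as np

s = 1.0 / np.sqrt(3.0)  # e.g. the "rsqrt_dim" factor for key dimension 3
w = np.exp(s)
weights = np.exp([s, 0.0]) / np.sum(np.exp([s, 0.0]))
print(np.allclose(weights, [w / (w + 1.0), 1.0 / (w + 1.0)]))  # True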

def testNoTransformKeys(self):
"""Tests that the no key transformation variant of MHA is correct."""
