Masking appears to do absolutely nothing in keras 3.0.5.
Code to reproduce:
```python
import numpy as np
import tensorflow as tf
import keras
# import tf_keras as keras  # uncomment this to restore working masking functionality


# MyLayer accepts 2 inputs, both of which may be masked.
class MyLayer(keras.layers.Layer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set after the base-class constructor, which initializes layer state.
        self.supports_masking = True

    def compute_mask(self, inputs, mask=None):
        if isinstance(mask, list):
            mask = mask[1]
        return [mask, mask]

    def call(self, inputs, mask=None, **kwargs):
        input_a, input_b = inputs
        # My layer would typically do things here; just pass input_b through to test.
        output = input_b[..., 0]
        # In keras < 3 I was getting a list of masks; in keras 3.0.5 mask is always None.
        if mask is not None:
            print("APPLYING MASK!")
            assert isinstance(mask, list)
            mask = mask[1]
            if mask is not None:
                output = tf.where(mask, output, tf.zeros_like(output))
        output = tf.expand_dims(output, axis=-1)
        return output


# Dummy model to test masking
input_a = keras.layers.Input(shape=(None, 2))
input_b = keras.layers.Input(shape=(None, 2))
input_a_masked = keras.layers.Masking()(input_a)
input_b_masked = keras.layers.Masking()(input_b)
y = MyLayer()([input_a_masked, input_b_masked])
model = keras.models.Model(inputs=[input_a, input_b], outputs=[y])
model.summary()

# Dummy data
inputs_a = np.random.random([32, 10, 2]).astype(np.float32)
inputs_a[:, 7:, :] = 0.0  # padded timesteps, to be masked
inputs_b = np.random.random([32, 20, 2]).astype(np.float32)
inputs_b[:, 15:, :] = 0.0  # padded timesteps, to be masked
output = model([inputs_a, inputs_b])
```
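For reference, here is a NumPy-only sketch of what `MyLayer` should compute on the second input when masking works. It assumes the documented behavior of `keras.layers.Masking` with the default `mask_value=0.0`: a timestep is masked when all of its features equal the mask value, and masked timesteps are zeroed in the layer's output. The seeded generator is an addition for reproducibility. Because the padded timesteps are already zero, the result is identical whether or not the mask is applied, which is why the printout, not the output values, is the signal in this repro.

```python
import numpy as np

rng = np.random.default_rng(0)
inputs_b = rng.random((32, 20, 2)).astype(np.float32)
inputs_b[:, 15:, :] = 0.0  # padded timesteps, as in the repro

# Mask as computed by Masking(mask_value=0.0): keep a timestep if any feature != 0.
mask_b = np.any(inputs_b != 0.0, axis=-1)  # shape (32, 20)

# Masking also zeroes out masked timesteps in its output.
masked_b = inputs_b * mask_b[..., None]

# What MyLayer returns when it receives mask=None (the buggy path).
unmasked_out = masked_b[..., 0][..., None]
# What MyLayer returns when the mask of input_b is applied.
masked_out = np.where(mask_b, masked_b[..., 0], 0.0)[..., None]

print(int(mask_b[0].sum()))  # 15 valid timesteps per sequence
print(np.array_equal(unmasked_out, masked_out))  # True: values alone can't reveal the bug
```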
The `print("APPLYING MASK!")` line is never reached, i.e. `call()` always receives `mask=None`.
Installing `tf-keras==2.16.0` and replacing the `keras` import with `import tf_keras as keras` (commented out above) restores masking functionality.
Thanks for the report. This seems to be a niche bug that is triggered by the presence of `**kwargs` in the `call()` signature; if you remove it, masking works. I have fixed the bug on our side as well.