Skip to content

Commit

Permalink
Fix fm layer
Browse files Browse the repository at this point in the history
  • Loading branch information
GrogusBall committed Mar 11, 2021
1 parent 6a125a5 commit 3da1912
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 74 deletions.
62 changes: 31 additions & 31 deletions deep_recommenders/layers/fm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@Author: Wang Yao
@Date: 2020-12-03 18:01:05
@LastEditors: Wang Yao
@LastEditTime: 2021-02-26 15:17:34
@LastEditTime: 2021-03-10 19:14:47
"""
from typing import Optional, Union, Text

Expand All @@ -25,8 +25,8 @@ def __init__(self,
super(FM, self).__init__(**kwargs)

self._factors = factors
self._kernel_init = kernel_init
self._kernel_regu = kernel_regu
self._kernel_init = tf.keras.initializers.get(kernel_init)
self._kernel_regu = tf.keras.regularizers.get(kernel_regu)

if (self._factors is not None) and (self._factors <= 0):
raise ValueError("`factors` should be bigger than 0. "
Expand All @@ -38,49 +38,49 @@ def build(self, input_shape):
if self._factors is not None:
self._kernel = self.add_weight(
shape=(last_dim, self._factors),
initializer=tf.keras.initializers.get(self._kernel_init),
regularizer=tf.keras.regularizers.get(self._kernel_regu),
initializer=self._kernel_init,
regularizer=self._kernel_regu,
trainable=True,
name="kernel"
)
self.built = True

def call(self, x: tf.Tensor):

if tf.keras.backend.ndim(x) != 3:
raise ValueError("`x` dim should be 3. Got `x` dim = {}".format(
tf.keras.backend.ndim(x)))

if self._factors is None:
embed_x = x
square_embed_x = tf.pow(x, 2)

if tf.keras.backend.ndim(x) != 3:
raise ValueError("When `factors` is None, `x` dim should be 3. "
"Got `x` dim = {}".format(tf.keras.backend.ndim(x)))

x_sum = tf.reduce_sum(x, axis=1)
x_square_sum = tf.reduce_sum(tf.pow(x, 2), axis=1)
else:
embed_x = tf.matmul(x, self._kernel)
square_embed_x = tf.matmul(tf.pow(x, 2), tf.pow(self._kernel, 2))
if tf.keras.backend.ndim(x) != 2:
raise ValueError("When `factors` is not None, `x` dim should be 2. "
"Got `x` dim = {}".format(tf.keras.backend.ndim(x)))

x_sum = tf.linalg.matmul(x, self._kernel, a_is_sparse=True)
x_square_sum = tf.linalg.matmul(
tf.pow(x, 2), tf.pow(self._kernel, 2), a_is_sparse=True)

outputs = 0.5 * tf.reduce_sum(
tf.subtract(
tf.pow(tf.reduce_sum(embed_x, axis=1, keepdims=True), 2),
tf.reduce_sum(square_embed_x, axis=1, keepdims=True)
), axis=-1, keepdims=False
)
tf.pow(x_sum, 2),
x_square_sum
), axis=1, keepdims=True)

return outputs

def get_config(self):
if self._factors is None:
config = {
"factors":
self._factors,
}
else:
config = {
"factors":
self._factors,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_init),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regu),
}
config = {
"factors":
self._factors,
"kernel_init":
tf.keras.initializers.serialize(self._kernel_init),
"kernel_regu":
tf.keras.regularizers.serialize(self._kernel_regu),
}
base_config = super(FM, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

103 changes: 60 additions & 43 deletions deep_recommenders/tests/test_layer_fm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@Author: Wang Yao
@Date: 2021-02-26 15:15:08
@LastEditors: Wang Yao
@LastEditTime: 2021-02-26 15:16:51
@LastEditTime: 2021-03-11 15:19:01
"""
import sys
sys.dont_write_bytecode = True
Expand All @@ -24,86 +24,103 @@ def test_invalid_factors(self):
""" 测试 factors <= 0 """
with self.assertRaisesRegexp(ValueError,
r"should be bigger than 0"):
x = np.random.random((10, 12, 5)).astype(np.float32) # pylint: disable=no-member
x = np.random.random((10, 5)).astype(np.float32) # pylint: disable=no-member
layer = FM(factors=-1)
layer(x)

def test_int_factors(self):
""" 测试 factors = int """
x = np.asarray([[
[1.0, 0.0, 0.0],
def test_x_invalid_dim(self):
""" 测试 x dim invalid """
with self.assertRaisesRegexp(ValueError,
r"`x` dim should be 2."):
x = np.random.normal(size=(3, 10, 5))
fm = FM(factors=2)
fm(x)
with self.assertRaisesRegexp(ValueError,
r"`x` dim should be 3."):
x = np.random.normal(size=(10, 5))
fm = FM(factors=None)
fm(x)

def test_factors(self):
""" 测试 factors """
x = np.asarray([
[1.0, 1.0, 0.0],
[0.0, 1.0, 0.0],
]]).astype(np.float32)
layer = FM(factors=2, kernel_init="ones")
output = layer(x)
]).astype(np.float32)

factors = np.asarray([
[1.0, 1.0],
[1.0, 1.0],
[1.0, 1.0]
]).astype(np.float32)

x_sum = x @ factors
x_square_sum = np.power(x, 2) @ np.power(factors, 2)
expected_outputs = 0.5 * np.sum(np.power(x_sum, 2) - x_square_sum, axis=1, keepdims=True)

fm = FM(factors=2, kernel_init="ones")
outputs = fm(x)

self.evaluate(tf.compat.v1.global_variables_initializer())
self.assertAllClose(np.asarray([[2.]]).astype(np.float32), output)
self.assertAllClose(outputs, expected_outputs)

def test_none_factors(self):
""" 测试 factors = None """
x = np.random.random((10, 12, 5)).astype(np.float32) # pylint: disable=no-member
layer = FM(factors=None)
layer(x)

def test_x_invalid_dim(self):
""" 测试 x dim invalid """
with self.assertRaisesRegexp(ValueError,
r"`x` dim should be 3."):
x = np.random.random((10, 60)).astype(np.float32) # pylint: disable=no-member
layer = FM()
layer(x)
x = np.asarray([
[1.0, 1.0, 0.0],
[0.0, 1.0, 1.0]])

fm_factors = FM(factors=2)
fm_factors_outputs = fm_factors(x)

embeddings_martix = fm_factors.get_weights()[0]

def test_outputs_with_diff_factors(self):
""" 测试 factors = None 和 factors = 10 输出是否相等 """
x = np.random.random((10, 12, 5)).astype(np.float32) # pylint: disable=no-member
factors = 5
fm_nofactors_x = tf.gather(embeddings_martix, [[0, 1], [1, 2]])

def identity(shape, dtype=None):
return np.eye(shape[-1])

layer_factors_none = FM(factors=None)
layer_factors_5 = FM(factors=factors, kernel_init=identity)
layer_factors_none_output = layer_factors_none(x)
layer_factors_5_output = layer_factors_5(x)
fm_nofactors = FM(factors=None)
fm_nofactors_outputs = fm_nofactors(fm_nofactors_x)

self.evaluate(tf.compat.v1.global_variables_initializer())
self.assertAllClose(layer_factors_none_output, layer_factors_5_output)
self.assertAllClose(fm_nofactors_outputs, fm_factors_outputs)


def test_train_model(self):
""" 测试训练模型 """

def get_model():
inputs = tf.keras.layers.Input(shape=(12, 5,))
inputs = tf.keras.layers.Input(shape=(10, 5,))
x = FM(factors=None)(inputs)
logits = tf.keras.layers.Dense(units=1)(x)
print(x)
logits = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, logits)
return model

model = get_model()
random_input = np.random.uniform(size=(10, 12, 5))
random_output = np.random.uniform(size=(10,))
random_inputs = np.random.uniform(size=(32, 10, 5))
random_outputs = np.random.uniform(size=(32,))
model.compile(loss="mse")
model.fit(random_input, random_output, verbose=0)
model.fit(random_inputs, random_outputs, verbose=0)

def test_save_model(self):
""" 测试保存模型 """

def get_model():
inputs = tf.keras.layers.Input(shape=(12, 5,))
inputs = tf.keras.layers.Input(shape=(10, 5,))
x = FM(factors=None)(inputs)
logits = tf.keras.layers.Dense(units=1)(x)
logits = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, logits)
return model

model = get_model()
random_input = np.random.uniform(size=(10, 12, 5))
random_input = np.random.uniform(size=(32, 10, 5))
model_pred = model.predict(random_input)

with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "fm_model")
path = os.path.join(tmp, "fm")
model.save(
path,
options=tf.saved_model.SaveOptions(namespace_whitelist=["Addons"]))
options=tf.saved_model.SaveOptions(namespace_whitelist=["FM"]))
loaded_model = tf.keras.models.load_model(path)
loaded_pred = loaded_model.predict(random_input)
for model_layer, loaded_layer in zip(model.layers, loaded_model.layers):
Expand All @@ -112,4 +129,4 @@ def get_model():


if __name__ == "__main__":
tf.test.main()
tf.test.main()

0 comments on commit 3da1912

Please sign in to comment.