Skip to content

Commit

Permalink
Align with CPU implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mmikolajcz committed Nov 27, 2023
1 parent f72ba94 commit 9c8cbd6
Showing 1 changed file with 60 additions and 12 deletions.
72 changes: 60 additions & 12 deletions tests/layer_tests/tensorflow_tests/test_tf_Multinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,56 @@ def _prepare_input(self, inputs_dict, kwargs):
inputs_dict["probs"] = kwargs["input"]
return inputs_dict

def create_tf_multinomial_net(self, global_seed, op_seed, logits_shape, input_type, out_type):
def create_tf_multinomial_net_shape(
self, global_seed, op_seed, logits_shape, input_type, out_type
):
tf.compat.v1.reset_default_graph()
# Configuration required to make multinomial randomness predictable across devices, results depends on TF parallel execution.
session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
session_conf = tf.compat.v1.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
)
# Create the graph and model
with tf.compat.v1.Session(config=session_conf) as sess:
probs = tf.compat.v1.placeholder(input_type, logits_shape, "probs")
num_samples = tf.compat.v1.placeholder(tf.int32, [], "num_samples")
if global_seed is not None:
tf.random.set_seed(global_seed)
tf.raw_ops.Multinomial(logits=tf.math.log(probs), num_samples=num_samples, seed=global_seed, seed2=op_seed, output_dtype=out_type)
tf.raw_ops.ZerosLike(
x=tf.raw_ops.Multinomial(
logits=tf.math.log(probs),
num_samples=num_samples,
seed=global_seed,
seed2=op_seed,
output_dtype=out_type,
)
)

tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None

def create_tf_multinomial_net_exact(
self, global_seed, op_seed, logits_shape, input_type, out_type
):
tf.compat.v1.reset_default_graph()
# Configuration required to make multinomial randomness predictable across devices, results depends on TF parallel execution.
session_conf = tf.compat.v1.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
)
# Create the graph and model
with tf.compat.v1.Session(config=session_conf) as sess:
probs = tf.compat.v1.placeholder(input_type, logits_shape, "probs")
num_samples = tf.compat.v1.placeholder(tf.int32, [], "num_samples")
if global_seed is not None:
tf.random.set_seed(global_seed)
tf.raw_ops.Multinomial(
logits=tf.math.log(probs),
num_samples=num_samples,
seed=global_seed,
seed2=op_seed,
output_dtype=out_type,
)

tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
Expand All @@ -32,32 +71,39 @@ def create_tf_multinomial_net(self, global_seed, op_seed, logits_shape, input_ty

@pytest.mark.parametrize("out_type", [tf.int32, tf.int64])
@pytest.mark.parametrize(
("input", "num_samples", "seed"),
("input", "num_samples", "seed", "test_type"),
[
(
np.array([[0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]], dtype=np.float32),
1024,
[32465, 48971],
"exact",
),
(
np.array([[0.001, 0.001, 0.1, 0.9], [5, 10, 1e-5, 256], [1, 1e-5, 1e-5, 1e-5]], dtype=np.float64),
np.array(
[
[0.001, 0.001, 0.1, 0.9],
[5, 10, 1e-5, 256],
[1, 1e-5, 1e-5, 1e-5],
],
dtype=np.float64,
),
256,
[32465, 48971],
"shape",
),
(
np.array([[1, 1, 1, 1]], dtype=np.float16),
1024,
[1, 1],
),
(np.array([[1, 1, 1, 1]], dtype=np.float16), 1024, [1, 1], "shape"),
(
np.array([[1, 2, 3, 4], [4, 3, 2, 1], [1, 0, 0, 0]], dtype=np.float32),
1,
[78132, None],
"shape",
),
(
np.array([[7, 7, 7, 7], [7, 7, 7, 7], [7, 7, 7, 7]], dtype=np.float32),
1024,
[32465, None],
"shape",
),
],
)
Expand All @@ -69,6 +115,7 @@ def test_multinomial_basic(
num_samples,
seed,
out_type,
test_type,
ie_device,
precision,
ir_version,
Expand All @@ -78,8 +125,9 @@ def test_multinomial_basic(
):
if ie_device == "GPU":
pytest.skip("Multinomial is not supported on GPU")
net = getattr(self, f"create_tf_multinomial_net_{test_type}")
self._test(
*self.create_tf_multinomial_net(
*net(
global_seed=seed[0],
op_seed=seed[1],
logits_shape=input.shape,
Expand All @@ -92,5 +140,5 @@ def test_multinomial_basic(
ir_version=ir_version,
use_new_frontend=use_new_frontend,
use_old_api=use_old_api,
kwargs_to_prepare_input={"input": input, "num_samples": num_samples}
kwargs_to_prepare_input={"input": input, "num_samples": num_samples},
)

0 comments on commit 9c8cbd6

Please sign in to comment.