Commit e982a7b
cleanup of deprecated test methods
PiperOrigin-RevId: 716245406
tf-transform-team authored and tfx-copybara committed Jan 16, 2025
1 parent a14023c commit e982a7b
Showing 10 changed files with 41 additions and 31 deletions.
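This commit is a mechanical migration: unittest.TestCase.assertRaisesRegexp was renamed to assertRaisesRegex in Python 3.2, the old spelling survived only as a deprecated alias, and the alias was removed outright in Python 3.12, so every call site below swaps in the current name (several hunks also reflow the call so the closing "):" lands on its own line). A minimal sketch of the pattern, assuming only the standard library; the test class and message are illustrative, not taken from this commit:

import unittest


class AliasMigrationSketch(unittest.TestCase):
  """Hypothetical test showing the rename; not part of tensorflow_transform."""

  def test_raises_with_matching_message(self):
    # Deprecated spelling (removed in Python 3.12):
    #   with self.assertRaisesRegexp(ValueError, 'bad value'):
    # Current spelling (available since Python 3.2); the second argument is a
    # regex applied to str(exception) with re.search, so a substring suffices.
    with self.assertRaisesRegex(ValueError, 'bad value'):
      raise ValueError('bad value: 42')


if __name__ == '__main__':
  unittest.main()

Behavior is unchanged; only the assertion's name (and formatting) differs, which is why the hunks below touch no test logic.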
5 changes: 3 additions & 2 deletions tensorflow_transform/beam/analyzer_cache_test.py
@@ -55,8 +55,9 @@ def test_validate_dataset_keys(self):
     })
 
     for key in {analyzer_cache.DatasetKey(k) for k in ('^foo^', 'foo 1')}:
-      with self.assertRaisesRegexp(
-          ValueError, 'Dataset key .* does not match allowed pattern:'):
+      with self.assertRaisesRegex(
+          ValueError, 'Dataset key .* does not match allowed pattern:'
+      ):
         analyzer_cache.validate_dataset_keys({key})
 
   @test_case.named_parameters(
10 changes: 6 additions & 4 deletions tensorflow_transform/beam/bucketize_integration_test.py
@@ -418,8 +418,9 @@ def no_assert():
 
     assertion = no_assert()
     if input_dtype == tf.float16:
-      assertion = self.assertRaisesRegexp(
-          TypeError, '.*DataType float16 not in list of allowed values.*')
+      assertion = self.assertRaisesRegex(
+          TypeError, '.*DataType float16 not in list of allowed values.*'
+      )
 
     with assertion:
       self.assertAnalyzeAndTransformResults(
@@ -504,8 +505,9 @@ def no_assert():
 
     assertion = no_assert()
     if input_dtype == tf.float16:
-      assertion = self.assertRaisesRegexp(
-          TypeError, '.*DataType float16 not in list of allowed values.*')
+      assertion = self.assertRaisesRegex(
+          TypeError, '.*DataType float16 not in list of allowed values.*'
+      )
 
     with assertion:
       self.assertAnalyzeAndTransformResults(
18 changes: 11 additions & 7 deletions tensorflow_transform/beam/impl_test.py
@@ -1463,8 +1463,9 @@ def preprocessing_fn(inputs):
     expected_data = [{'x_scaled': float('nan')}]
     expected_metadata = tft.DatasetMetadata.from_feature_spec(
         {'x_scaled': tf.io.FixedLenFeature([], tf.float32)})
-    with self.assertRaisesRegexp(  # pylint: disable=g-error-prone-assert-raises
-        ValueError, 'output_min must be less than output_max'):
+    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
+        ValueError, 'output_min must be less than output_max'
+    ):
       self.assertAnalyzeAndTransformResults(input_data, input_metadata,
                                             preprocessing_fn, expected_data,
                                             expected_metadata)
@@ -4656,8 +4657,9 @@ def preprocessing_fn(inputs):
                                       preprocessing_fn, expected_outputs)
 
   def testEmptySchema(self):
-    with self.assertRaisesRegexp(  # pylint: disable=g-error-prone-assert-raises
-        ValueError, 'The input metadata is empty.'):
+    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
+        ValueError, 'The input metadata is empty.'
+    ):
       self.assertAnalyzeAndTransformResults(
           input_data=[{'x': x} for x in range(5)],
           input_metadata=tft.DatasetMetadata.from_feature_spec({}),
@@ -4785,10 +4787,12 @@ def preprocessing_fn(inputs):
                                       preprocessing_fn, expected_outputs)
 
   def test_preprocessing_fn_returns_wrong_type(self):
-    with self.assertRaisesRegexp(  # pylint: disable=g-error-prone-assert-raises
-        ValueError, r'A `preprocessing_fn` must return a '
+    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
+        ValueError,
+        r'A `preprocessing_fn` must return a '
         r'Dict\[str, Union\[tf.Tensor, tf.SparseTensor, tf.RaggedTensor\]\]. '
-        'Got: Tensor.*'):
+        'Got: Tensor.*',
+    ):
       self.assertAnalyzeAndTransformResults(
           input_data=[{'f1': 0}],
           input_metadata=tft.DatasetMetadata.from_feature_spec(
2 changes: 1 addition & 1 deletion tensorflow_transform/beam/tft_beam_io/transform_fn_io_test.py
@@ -96,7 +96,7 @@ def mock_write_metadata_expand(unused_self, unused_metadata):
 
     with mock.patch.object(transform_fn_io.beam_metadata_io.WriteMetadata,
                            'expand', mock_write_metadata_expand):
-      with self.assertRaisesRegexp(ArithmeticError, 'Some error'):
+      with self.assertRaisesRegex(ArithmeticError, 'Some error'):
         _ = ((saved_model_dir_pcoll, object())
              | transform_fn_io.WriteTransformFn(transform_output_dir))
 
4 changes: 2 additions & 2 deletions tensorflow_transform/coders/csv_coder_test.py
@@ -253,7 +253,7 @@ def test_constructor_error(self,
                             error_type=ValueError,
                             **kwargs):
     schema = schema_utils.schema_from_feature_spec(feature_spec)
-    with self.assertRaisesRegexp(error_type, error_msg):
+    with self.assertRaisesRegex(error_type, error_msg):
       csv_coder.CsvCoder(columns, schema, **kwargs)
 
   @test_case.named_parameters(*_ENCODE_ERROR_CASES)
@@ -266,7 +266,7 @@ def test_encode_error(self,
                        **kwargs):
     schema = schema_utils.schema_from_feature_spec(feature_spec)
     coder = csv_coder.CsvCoder(columns, schema, **kwargs)
-    with self.assertRaisesRegexp(error_type, error_msg):
+    with self.assertRaisesRegex(error_type, error_msg):
       coder.encode(instance)
 
   def test_picklable(self):
2 changes: 1 addition & 1 deletion tensorflow_transform/coders/example_proto_coder_test.py
@@ -373,7 +373,7 @@ def test_encode_error(self,
                        error_type=ValueError,
                        **kwargs):
     schema = schema_utils.schema_from_feature_spec(feature_spec)
-    with self.assertRaisesRegexp(error_type, error_msg):
+    with self.assertRaisesRegex(error_type, error_msg):
       coder = example_proto_coder.ExampleProtoCoder(schema, **kwargs)
       coder.encode(instance)
 
4 changes: 2 additions & 2 deletions tensorflow_transform/graph_tools_test.py
@@ -602,7 +602,7 @@ def testInitializableGraphAnalyzerConstructorRaises(
     tensors = create_graph_fn()
     replaced_tensors_ready = [(tensors[name], ready)
                               for name, ready in replaced_tensors_ready.items()]
-    with self.assertRaisesRegexp(ValueError, error_msg_regex):
+    with self.assertRaisesRegex(ValueError, error_msg_regex):
       graph_tools.InitializableGraphAnalyzer(graph,
                                              {x: tensors[x] for x in feeds},
                                              replaced_tensors_ready)
@@ -639,7 +639,7 @@ def testInitializableGraphAnalyzerReadyToRunRaises(
         tensors[name], ready) for name, ready in replaced_tensors_ready.items()]
     graph_analyzer = graph_tools.InitializableGraphAnalyzer(
         graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready)
-    with self.assertRaisesRegexp(ValueError, error_msg_regex):
+    with self.assertRaisesRegex(ValueError, error_msg_regex):
       tensor = tensors[fetch]
       graph_analyzer.ready_to_run(tensor)
 
9 changes: 5 additions & 4 deletions tensorflow_transform/mappers_test.py
@@ -162,9 +162,9 @@ def testNGramsWithRepeatedTokensPerRow(self):
   def testNGramsBadSizes(self):
     string_tensor = tf.constant(['abc', 'def', 'fghijklm', 'z', ''])
     tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter='')
-    with self.assertRaisesRegexp(ValueError, 'Invalid ngram_range'):
+    with self.assertRaisesRegex(ValueError, 'Invalid ngram_range'):
       mappers.ngrams(tokenized_tensor, (0, 5), separator='')
-    with self.assertRaisesRegexp(ValueError, 'Invalid ngram_range'):
+    with self.assertRaisesRegex(ValueError, 'Invalid ngram_range'):
       mappers.ngrams(tokenized_tensor, (6, 5), separator='')
 
   def testNGramsBagOfWordsEmpty(self):
@@ -837,8 +837,9 @@ def testApplyBucketsWithInterpolationAllNanBoundariesRaises(self):
     with self.test_session() as sess:
       x = tf.constant([float('-inf'), float('nan'), 0.0, 1.0])
       boundaries = tf.constant([[float('nan'), float('nan'), float('nan')]])
-      with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
-                                   'num_boundaries'):
+      with self.assertRaisesRegex(
+          tf.errors.InvalidArgumentError, 'num_boundaries'
+      ):
         sess.run(mappers.apply_buckets_with_interpolation(x, boundaries))
 
   def testApplyBucketsWithInterpolationRaises(self):
12 changes: 7 additions & 5 deletions tensorflow_transform/test_case_test.py
@@ -64,15 +64,17 @@ def testAssertDataCloseOrEqual(self):
                                    'd': ('second', 2.0000001)},
                                   {'e': 2,
                                    'f': 3}])
-    with self.assertRaisesRegexp(AssertionError, r'len\(.*\) != len\(\[\]\)'):
+    with self.assertRaisesRegex(AssertionError, r'len\(.*\) != len\(\[\]\)'):
       self.assertDataCloseOrEqual([{'a': 1}], [])
-    with self.assertRaisesRegexp(
+    with self.assertRaisesRegex(
         AssertionError,
-        re.compile('Element counts were not equal.*: Row 0', re.DOTALL)):
+        re.compile('Element counts were not equal.*: Row 0', re.DOTALL),
+    ):
       self.assertDataCloseOrEqual([{'a': 1}], [{'b': 1}])
-    with self.assertRaisesRegexp(
+    with self.assertRaisesRegex(
         AssertionError,
-        re.compile('Not equal to tolerance.*: Row 0, key a', re.DOTALL)):
+        re.compile('Not equal to tolerance.*: Row 0, key a', re.DOTALL),
+    ):
       self.assertDataCloseOrEqual([{'a': 1}], [{'a': 2}])
 
   @test_case.parameters((1, 'a'), (2, 'b'))
6 changes: 3 additions & 3 deletions tensorflow_transform/tf_utils_test.py
@@ -774,7 +774,7 @@ def test_same_shape_exceptions(self, x_input, y_input, x_shape, y_shape,
     x = tf.compat.v1.placeholder(tf.int32, x_shape)
     y = tf.compat.v1.placeholder(tf.int32, y_shape)
     with tf.compat.v1.Session() as sess:
-      with self.assertRaisesRegexp(exception_cls, error_string):
+      with self.assertRaisesRegex(exception_cls, error_string):
         sess.run(tf_utils.assert_same_shape(x, y), {x: x_input, y: y_input})
 
   @test_case.named_parameters(test_case.FUNCTION_HANDLERS)
@@ -1965,7 +1965,7 @@ def test_sparse_indices(self):
     x = tf.compat.v1.sparse_placeholder(tf.int64, shape=[None, None])
     key = tf.compat.v1.sparse_placeholder(tf.string, shape=[None, None])
     with tf.compat.v1.Session() as sess:
-      with self.assertRaisesRegexp(exception_cls, error_string):
+      with self.assertRaisesRegex(exception_cls, error_string):
         sess.run(tf_utils.reduce_batch_minus_min_and_max_per_key(x, key),
                  feed_dict={x: value, key: key_value})
 
@@ -2000,7 +2000,7 @@ def test_convert_sparse_indices(self):
         dense_shape=[4, 2, 5])
 
     with tf.compat.v1.Session() as sess:
-      with self.assertRaisesRegexp(exception_cls, error_string):
+      with self.assertRaisesRegex(exception_cls, error_string):
         sess.run(tf_utils._validate_and_get_dense_value_key_inputs(sparse1,
                                                                    sparse2),
                  feed_dict={sparse1: sparse_value1, sparse2: sparse_value2})
