Add retry logic to each batch method of the GCS IO (apache#33539)
* Add retry logic to each batch method of the GCS IO

Transient errors can occur when writing many shards to GCS, and right now
the GCS IO has no retry logic in place:

https://github.com/apache/beam/blob/a06454a2/sdks/python/apache_beam/io/gcp/gcsio.py#L269

This means that in such cases the entire bundle of elements fails; Beam then
retries the whole bundle itself, and fails the job once the bundle retry limit
is exceeded.

This change adds new logic to retry only failed requests, and uses the typical
exponential backoff strategy.
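
To make the approach concrete, here is a minimal standalone sketch of the
pattern (not the Beam implementation from the diff below; the
`run_batch_with_retry`, `submit`, and `is_retryable` names are illustrative
only): keep the requests that failed with a retryable status and resubmit just
those, backing off exponentially with jitter between attempts.

```python
import random
import time


def run_batch_with_retry(requests, submit, is_retryable, max_attempts=5):
  """Runs submit(request) for each request, retrying only the failures.

  Simplified sketch of the idea, not Beam's actual code.
  """
  responses = [None] * len(requests)
  pending = list(enumerate(requests))
  for attempt in range(max_attempts):
    still_failing = []
    for index, request in pending:
      response = submit(request)
      responses[index] = response  # keep the latest response per request
      if is_retryable(response):
        still_failing.append((index, request))
    pending = still_failing
    if not pending:
      break
    # Exponential backoff with jitter: ~1s, ~2s, ~4s, ... capped at 32s.
    time.sleep(min(2**attempt, 32) * random.uniform(0.5, 1.5))
  return responses
```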

Note that this change accesses a private method (`_predicate`) of the retry
object, which we could only avoid by copying that logic over here. But the
existing code already accesses the private `_responses` property, so this is
arguably not a big deal:

https://github.com/apache/beam/blob/b4c3a4ff/sdks/python/apache_beam/io/gcp/gcsio.py#L297
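
For context, `_predicate` is the callable a `google.api_core.retry.Retry`
object uses to decide whether an exception counts as transient. A quick
illustration of the behavior we rely on, assuming the stock `DEFAULT_RETRY`
from `google-cloud-storage`:

```python
from google.api_core.exceptions import NotFound
from google.api_core.exceptions import TooManyRequests
from google.cloud.storage.retry import DEFAULT_RETRY

# 429 Too Many Requests is transient, so the predicate asks for a retry...
assert DEFAULT_RETRY._predicate(TooManyRequests('rate limited'))
# ...while 404 Not Found is permanent and should surface immediately.
assert not DEFAULT_RETRY._predicate(NotFound('no such object'))
```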

Existing (unresolved) issue in the GCS client library:

googleapis/python-storage#1277

* Catch correct exception type in `_batch_with_retry`

A `RetryError` will always be raised here, since the retry decorator catches
all HTTP-related exceptions internally and re-raises them as `RetryError` once
retries are exhausted.
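
For reference, a small sketch of how `google.api_core` surfaces `RetryError`
once the retry deadline expires (assuming a recent `google-api-core` where
`timeout` is the newer name of the older `deadline` argument; the parameter
values here are arbitrary):

```python
from google.api_core import exceptions
from google.api_core import retry

quick_retry = retry.Retry(
    predicate=retry.if_exception_type(exceptions.TooManyRequests),
    initial=0.01,  # short delays so the example finishes quickly
    maximum=0.02,
    timeout=0.05,  # give up after ~50ms
)


@quick_retry
def always_throttled():
  raise exceptions.TooManyRequests('rate limited')


try:
  always_throttled()
except exceptions.RetryError:
  # The decorator swallowed the TooManyRequests errors and raised RetryError
  # once the deadline was exceeded; this is the exception `_batch_with_retry`
  # now catches.
  print('deadline exceeded')
```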

* Update changelog with GCSIO retry logic fix
sadovnychyi authored Jan 10, 2025
1 parent b649235 commit e5defdd
Showing 3 changed files with 122 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -68,6 +68,7 @@

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Upgraded to protobuf 4 (Java) ([#33192](https://github.com/apache/beam/issues/33192)).
+* [GCSIO] Added retry logic to each batch method of the GCS IO (Python) ([#33539](https://github.com/apache/beam/pull/33539))

## Breaking Changes

75 changes: 53 additions & 22 deletions sdks/python/apache_beam/io/gcp/gcsio.py
@@ -35,8 +35,10 @@
from typing import Optional
from typing import Union

+from google.api_core.exceptions import RetryError
from google.cloud import storage
from google.cloud.exceptions import NotFound
+from google.cloud.exceptions import from_http_response
from google.cloud.storage.fileio import BlobReader
from google.cloud.storage.fileio import BlobWriter
from google.cloud.storage.retry import DEFAULT_RETRY
@@ -264,9 +266,45 @@ def delete(self, path):
    except NotFound:
      return

+  def _batch_with_retry(self, requests, fn):
+    current_requests = [*enumerate(requests)]
+    responses = [None for _ in current_requests]
+
+    @self._storage_client_retry
+    def run_with_retry():
+      current_batch = self.client.batch(raise_exception=False)
+      with current_batch:
+        for _, request in current_requests:
+          fn(request)
+      last_retryable_exception = None
+      for (i, current_pair), response in zip(
+          [*current_requests], current_batch._responses
+      ):
+        responses[i] = response
+        should_retry = (
+            response.status_code >= 400 and
+            self._storage_client_retry._predicate(from_http_response(response)))
+        if should_retry:
+          last_retryable_exception = from_http_response(response)
+        else:
+          current_requests.remove((i, current_pair))
+      if last_retryable_exception:
+        raise last_retryable_exception
+
+    try:
+      run_with_retry()
+    except RetryError:
+      pass
+
+    return responses
+
+  def _delete_batch_request(self, path):
+    bucket_name, blob_name = parse_gcs_path(path)
+    bucket = self.client.bucket(bucket_name)
+    bucket.delete_blob(blob_name)
+
  def delete_batch(self, paths):
    """Deletes the objects at the given GCS paths.
-    Warning: any exception during batch delete will NOT be retried.
    Args:
      paths: List of GCS file path patterns or Dict with GCS file path patterns
@@ -285,16 +323,11 @@ def delete_batch(self, paths):
        current_paths = paths[s:s + MAX_BATCH_OPERATION_SIZE]
      else:
        current_paths = paths[s:]
-      current_batch = self.client.batch(raise_exception=False)
-      with current_batch:
-        for path in current_paths:
-          bucket_name, blob_name = parse_gcs_path(path)
-          bucket = self.client.bucket(bucket_name)
-          bucket.delete_blob(blob_name)

+      responses = self._batch_with_retry(
+          current_paths, self._delete_batch_request)
      for i, path in enumerate(current_paths):
        error_code = None
-        resp = current_batch._responses[i]
+        resp = responses[i]
        if resp.status_code >= 400 and resp.status_code != 404:
          error_code = resp.status_code
        final_results.append((path, error_code))
@@ -334,9 +367,16 @@ def copy(self, src, dest):
        source_generation=src_generation,
        retry=self._storage_client_retry)

+  def _copy_batch_request(self, pair):
+    src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
+    dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
+    src_bucket = self.client.bucket(src_bucket_name)
+    src_blob = src_bucket.blob(src_blob_name)
+    dest_bucket = self.client.bucket(dest_bucket_name)
+    src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
+
  def copy_batch(self, src_dest_pairs):
    """Copies the given GCS objects from src to dest.
-    Warning: any exception during batch copy will NOT be retried.
    Args:
      src_dest_pairs: list of (src, dest) tuples of gs://<bucket>/<name> files
@@ -354,20 +394,11 @@ def copy_batch(self, src_dest_pairs):
        current_pairs = src_dest_pairs[s:s + MAX_BATCH_OPERATION_SIZE]
      else:
        current_pairs = src_dest_pairs[s:]
-      current_batch = self.client.batch(raise_exception=False)
-      with current_batch:
-        for pair in current_pairs:
-          src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
-          dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
-          src_bucket = self.client.bucket(src_bucket_name)
-          src_blob = src_bucket.blob(src_blob_name)
-          dest_bucket = self.client.bucket(dest_bucket_name)
-
-          src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)

+      responses = self._batch_with_retry(
+          current_pairs, self._copy_batch_request)
      for i, pair in enumerate(current_pairs):
        error_code = None
-        resp = current_batch._responses[i]
+        resp = responses[i]
        if resp.status_code >= 400:
          error_code = resp.status_code
        final_results.append((pair[0], pair[1], error_code))
68 changes: 68 additions & 0 deletions sdks/python/apache_beam/io/gcp/gcsio_test.py
@@ -482,6 +482,74 @@ def test_copy(self):
        'gs://gcsio-test/non-existent',
        'gs://gcsio-test/non-existent-destination')

+  @staticmethod
+  def _fake_batch_responses(status_codes):
+    return mock.Mock(
+        __enter__=mock.Mock(),
+        __exit__=mock.Mock(),
+        _responses=[
+            mock.Mock(
+                **{
+                    'json.return_value': {
+                        'error': {
+                            'message': 'error'
+                        }
+                    },
+                    'request.method': 'BATCH',
+                    'request.url': 'contentid://None',
+                },
+                status_code=code,
+            ) for code in status_codes
+        ],
+    )
+
+  @mock.patch('apache_beam.io.gcp.gcsio.MAX_BATCH_OPERATION_SIZE', 3)
+  @mock.patch('time.sleep', mock.Mock())
+  def test_copy_batch(self):
+    src_dest_pairs = [
+        (f'gs://source_bucket/file{i}.txt', f'gs://dest_bucket/file{i}.txt')
+        for i in range(7)
+    ]
+    gcs_io = gcsio.GcsIO(
+        storage_client=mock.Mock(
+            batch=mock.Mock(
+                side_effect=[
+                    self._fake_batch_responses([200, 404, 429]),
+                    self._fake_batch_responses([429]),
+                    self._fake_batch_responses([429]),
+                    self._fake_batch_responses([200]),
+                    self._fake_batch_responses([200, 429, 200]),
+                    self._fake_batch_responses([200]),
+                    self._fake_batch_responses([200]),
+                ]),
+        ))
+    results = gcs_io.copy_batch(src_dest_pairs)
+    expected = [
+        ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', None),
+        ('gs://source_bucket/file1.txt', 'gs://dest_bucket/file1.txt', 404),
+        ('gs://source_bucket/file2.txt', 'gs://dest_bucket/file2.txt', None),
+        ('gs://source_bucket/file3.txt', 'gs://dest_bucket/file3.txt', None),
+        ('gs://source_bucket/file4.txt', 'gs://dest_bucket/file4.txt', None),
+        ('gs://source_bucket/file5.txt', 'gs://dest_bucket/file5.txt', None),
+        ('gs://source_bucket/file6.txt', 'gs://dest_bucket/file6.txt', None),
+    ]
+    self.assertEqual(results, expected)
+
+  @mock.patch('time.sleep', mock.Mock())
+  @mock.patch('time.monotonic', mock.Mock(side_effect=[0, 120]))
+  def test_copy_batch_timeout_exceeded(self):
+    src_dest_pairs = [
+        ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt')
+    ]
+    gcs_io = gcsio.GcsIO(
+        storage_client=mock.Mock(
+            batch=mock.Mock(side_effect=[self._fake_batch_responses([429])])))
+    results = gcs_io.copy_batch(src_dest_pairs)
+    expected = [
+        ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', 429),
+    ]
+    self.assertEqual(results, expected)
+
  def test_copytree(self):
    src_dir_name = 'gs://gcsio-test/source/'
    dest_dir_name = 'gs://gcsio-test/dest/'