diff --git a/baseplate/sidecars/live_data_watcher.py b/baseplate/sidecars/live_data_watcher.py
index 7d7c2a853..376c78694 100644
--- a/baseplate/sidecars/live_data_watcher.py
+++ b/baseplate/sidecars/live_data_watcher.py
@@ -127,12 +127,12 @@ def _load_from_s3(data: bytes) -> bytes:
     # We can't assume that every caller of this method will be using prefix sharding on
     # their S3 objects.
     num_file_shards = loader_config.get("num_file_shards", 1)
-    logger.error(num_file_shards)
+
     # If the num_file_shards key is present, we may have multiple copies of the same manifest
     # uploaded so fetch one randomly using a randomly generated prefix.
     # Generate a random number from 0 to num_file_shards exclusive to use as prefix.
     file_key_prefix = random.randrange(num_file_shards)
-    logger.error(file_key_prefix)
+
     # If 0 is generated, don't append a prefix, fetch the file with no prefix
     # since we always upload one file without a prefix.
     if file_key_prefix == 0:
diff --git a/tests/unit/sidecars/live_data_watcher_tests.py b/tests/unit/sidecars/live_data_watcher_tests.py
index d8e8729b5..9e1ef4517 100644
--- a/tests/unit/sidecars/live_data_watcher_tests.py
+++ b/tests/unit/sidecars/live_data_watcher_tests.py
@@ -18,26 +18,27 @@ class NodeWatcherTests(unittest.TestCase):
     mock_s3 = mock_s3()
+    bucket_name = "test_bucket"
+    s3_client = boto3.client(
+        "s3",
+        region_name="us-east-1",
+    )
+    default_file_key = "test_file_key"
 
     def setUp(self):
         self.mock_s3.start()
-        bucket_name = "test_bucket"
         s3_data = {"foo_encrypted": "bar_encrypted"}
-        s3_client = boto3.client(
-            "s3",
-            region_name="us-east-1",
-        )
-        s3_client.create_bucket(Bucket=bucket_name)
-        default_file_key = "test_file_key"
+
+        self.s3_client.create_bucket(Bucket=self.bucket_name)
         for file_shard_num in range(NUM_FILE_SHARDS):
             if file_shard_num == 0:
                 # The first copy should just be the original file.
-                sharded_file_key = default_file_key
+                sharded_file_key = self.default_file_key
             else:
                 # All other copies should include the sharded prefix.
-                sharded_file_key = str(file_shard_num) + "/" + default_file_key
-            s3_client.put_object(
-                Bucket=bucket_name,
+                sharded_file_key = str(file_shard_num) + "/" + self.default_file_key
+            self.s3_client.put_object(
+                Bucket=self.bucket_name,
                 Key=sharded_file_key,
                 Body=json.dumps(s3_data).encode(),
                 SSECustomerKey="test_decryption_key",
@@ -63,12 +64,18 @@ def test_s3_load_type_on_change(self):
         self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
         self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)
 
+        # For additional good measure, let's also fetch the file from S3 and validate the contents.
+        obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=self.default_file_key)
+        actual_data = obj["Body"].read().decode("utf-8")
+        assert json.loads(actual_data) == json.loads(expected_content)
+
     def test_s3_load_type_sharded_on_change(self):
         dest = self.output_dir.joinpath("data.txt")
         inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)
 
+        # Include num_file_shards as a key.
         new_content = b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1", "num_file_shards": 5}'
-        expected_content = b'{"foo_encrypted": "bar_encrypteds"}'
+        expected_content = b'{"foo_encrypted": "bar_encrypted"}'
 
         # For safe measure, run this 20 times. It should succeed every time.
         # We've uploaded 5 files to S3 in setUp() and num_file_shards=5 in the ZK node so we should be fetching one of these 5 files randomly (and successfully) - and all should have the same content.
@@ -78,6 +85,19 @@ def test_s3_load_type_sharded_on_change(self):
         self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
         self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)
 
+        # For additional good measure, let's also fetch the files from S3 and validate the contents.
+        obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=self.default_file_key)
+        actual_data = obj["Body"].read().decode("utf-8")
+        assert json.loads(actual_data) == json.loads(expected_content)
+
+        # Assert that all copies of the file are fetchable and contain the same
+        # data as the original.
+        for i in range(1, NUM_FILE_SHARDS):
+            file_key = str(i) + "/" + self.default_file_key
+            obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=file_key)
+            actual_data = obj["Body"].read().decode("utf-8")
+            assert json.loads(actual_data) == json.loads(expected_content)
+
     def test_on_change(self):
         dest = self.output_dir.joinpath("data.txt")
         inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)
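For readers following the diff, here is a minimal sketch of the prefix-sharding scheme that the comments in _load_from_s3 describe. The helper names (sharded_keys, pick_shard_key) are illustrative only and are not part of baseplate:

    import random

    def sharded_keys(file_key, num_file_shards):
        # Shard 0 is the bare key; every other copy carries an "N/" prefix,
        # mirroring how setUp() uploads NUM_FILE_SHARDS copies of the same data.
        return [file_key if n == 0 else f"{n}/{file_key}" for n in range(num_file_shards)]

    def pick_shard_key(file_key, num_file_shards=1):
        # Mirrors the watcher logic: randrange(n) yields 0..n-1, and 0 means
        # "fetch the unprefixed file", since one copy is always uploaded without a prefix.
        prefix = random.randrange(num_file_shards)
        return file_key if prefix == 0 else f"{prefix}/{file_key}"

    # With num_file_shards=5 (the value in the ZK node used by the sharded test),
    # any of the five keys may be chosen; all copies hold the same content.
    assert pick_shard_key("test_file_key", 5) in sharded_keys("test_file_key", 5)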
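A note on the added assertions: the decoded S3 body is a str while json.loads() returns a dict, so both sides are parsed before comparing. A small self-contained check of that reasoning:

    import json

    body = b'{"foo_encrypted": "bar_encrypted"}'  # shape of the S3 object body in these tests
    assert body.decode("utf-8") != json.loads(body)              # str vs. dict: never equal
    assert json.loads(body.decode("utf-8")) == json.loads(body)  # compare parsed values instead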