Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Siddharth Manoj committed Nov 30, 2023
1 parent 046df38 commit 658c1fb
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
4 changes: 2 additions & 2 deletions baseplate/sidecars/live_data_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def _load_from_s3(data: bytes) -> bytes:
# We can't assume that every caller of this method will be using prefix sharding on
# their S3 objects.
num_file_shards = loader_config.get("num_file_shards", 1)
logger.error(num_file_shards)

# If the num_file_shards key is present, multiple copies of the same manifest may have been
# uploaded, so fetch one of them at random by generating a random prefix.
# The prefix is a random integer from 0 (inclusive) to num_file_shards (exclusive).
file_key_prefix = random.randrange(num_file_shards)
logger.error(file_key_prefix)

# If 0 is generated, don't append a prefix; fetch the file with no prefix,
# since we always upload one file without a prefix.
if file_key_prefix == 0:
Expand Down
44 changes: 32 additions & 12 deletions tests/unit/sidecars/live_data_watcher_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,27 @@

class NodeWatcherTests(unittest.TestCase):
mock_s3 = mock_s3()
bucket_name = "test_bucket"
s3_client = boto3.client(
"s3",
region_name="us-east-1",
)
default_file_key = "test_file_key"

def setUp(self):
self.mock_s3.start()
bucket_name = "test_bucket"
s3_data = {"foo_encrypted": "bar_encrypted"}
s3_client = boto3.client(
"s3",
region_name="us-east-1",
)
s3_client.create_bucket(Bucket=bucket_name)
default_file_key = "test_file_key"

self.s3_client.create_bucket(Bucket=self.bucket_name)
for file_shard_num in range(NUM_FILE_SHARDS):
if file_shard_num == 0:
# The first copy should just be the original file.
sharded_file_key = default_file_key
sharded_file_key = self.default_file_key
else:
# All other copies should include the sharded prefix.
sharded_file_key = str(file_shard_num) + "/" + default_file_key
s3_client.put_object(
Bucket=bucket_name,
sharded_file_key = str(file_shard_num) + "/" + self.default_file_key
self.s3_client.put_object(
Bucket=self.bucket_name,
Key=sharded_file_key,
Body=json.dumps(s3_data).encode(),
SSECustomerKey="test_decryption_key",
Expand All @@ -63,12 +64,18 @@ def test_s3_load_type_on_change(self):
self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)

# For additional good measure, let's also fetch the file from S3 and validate the contents.
obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=self.default_file_key)
actual_data = obj["Body"].read().decode("utf-8")
assert actual_data == json.loads(expected_content)

def test_s3_load_type_sharded_on_change(self):
dest = self.output_dir.joinpath("data.txt")
inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)

# Include num_file_shards as a key.
new_content = b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1", "num_file_shards": 5}'
expected_content = b'{"foo_encrypted": "bar_encrypteds"}'
expected_content = b'{"foo_encrypted": "bar_encrypted"}'

# For good measure, run this 20 times. It should succeed every time.
# We've uploaded 5 files to S3 in setUp() and num_file_shards=5 in the ZK node so we should be fetching one of these 5 files randomly (and successfully) - and all should have the same content.
Expand All @@ -78,6 +85,19 @@ def test_s3_load_type_sharded_on_change(self):
self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)

# For additional good measure, let's also fetch the files from S3 and validate the contents.
obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=self.default_file_key)
actual_data = obj["Body"].read().decode("utf-8")
assert actual_data == json.loads(expected_content)

# Assert that all copies of the file are fetchable and contain the same
# data as the original.
for i in range(1, NUM_FILE_SHARDS):
file_key = str(i) + "/" + self.default_file_key
obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=file_key)
actual_data = obj["Body"].read().decode("utf-8")
assert actual_data == json.loads(expected_content)

def test_on_change(self):
dest = self.output_dir.joinpath("data.txt")
inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)
Expand Down

0 comments on commit 658c1fb

Please sign in to comment.