Skip to content

Commit

Permalink
- Updated gitignore, added basic unit tests, and streamlined code some.
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaellh0079 committed Aug 15, 2023
1 parent f3a9f69 commit 6f72758
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 14 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
.idea/
dist/
node_modules/
package.zip
__pycache__/

26 changes: 12 additions & 14 deletions src_py/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import boto3
import concurrent.futures

s3_client = boto3.client('s3')

ed_pub_account_id = os.getenv('EDPUB_ACCOUNT_ID')
source_bucket = os.getenv('EDPUB_BUCKET')
destination_bucket = os.getenv('DAAC_BUCKET')
Expand All @@ -24,7 +24,7 @@ def get_keys(paginator, bucket):
return src_keys


def scan_ed_pub():
def scan_ed_pub(s3_client):
paginator = s3_client.get_paginator('list_objects_v2')
src_keys = get_keys(paginator, source_bucket)
missing_keys = src_keys.difference(get_keys(paginator, destination_bucket))
Expand All @@ -34,7 +34,12 @@ def scan_ed_pub():
futures = []
for key in missing_keys:
futures.append(
executor.submit(transfer_wrapper, key)
executor.submit(
s3_client.put_object,
Bucket=destination_bucket,
Body=s3_client.get_object(Bucket=source_bucket, Key=key).get('Body').read(),
Key=key
)
)

for future in concurrent.futures.as_completed(futures):
Expand All @@ -43,15 +48,7 @@ def scan_ed_pub():
return responses


def transfer_wrapper(key):
return s3_client.put_object(
Bucket=destination_bucket,
Body=s3_client.get_object(Bucket=source_bucket, Key=key).get('Body').read(),
Key=key
)


def handle_s3_event_message(event):
def handle_s3_event_message(event, s3_client):
object_key = event.get('Records')[0].get('s3').get('object').get('key')
return s3_client.copy_object(
Bucket=destination_bucket,
Expand All @@ -64,10 +61,11 @@ def handle_s3_event_message(event):


def handler(event, context):
s3_client = boto3.client('s3')
if event.get('Records', None):
ret = handle_s3_event_message(event)
ret = handle_s3_event_message(event, s3_client)
else:
ret = scan_ed_pub()
ret = scan_ed_pub(s3_client)

return ret

Expand Down
Empty file added tests/__init__.py
Empty file.
63 changes: 63 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import math
import unittest
from unittest.mock import MagicMock

from src_py.main import get_keys, scan_ed_pub


class FakePaginator:
def __init__(self, key_count, page_size, start_at=0):
self.key_count = key_count
self.page_size = page_size
self.page_count = int(math.ceil(self.key_count / self.page_size))
self.start_at = start_at
self.keys = [{'Key': f'key_{x + self.start_at}'} for x in range(self.start_at, self.key_count + self.start_at)]

def paginate(self, **kwargs):
index = 0
for x in range(self.page_count):
yield {'Contents': self.keys[index:index + self.page_size]}
index += self.page_size


class TestMain(unittest.TestCase):
def test_fake_paginator_1(self):
key_count = 5
page_size = 2
fp = FakePaginator(key_count, page_size)
self.assertEqual(fp.key_count, key_count)
self.assertEqual(fp.page_size, page_size)
self.assertEqual(fp.start_at, 0)
self.assertEqual(len(fp.keys), key_count)

results = []
for x in fp.paginate():
results.append(x)

self.assertEqual(len(results), 3)

def test_fake_paginator_2(self):
key_count = 5
page_size = 2
start_at = 5
fp = FakePaginator(key_count, page_size, start_at)
self.assertEqual(fp.key_count, key_count)
self.assertEqual(fp.page_size, page_size)
self.assertEqual(fp.start_at, start_at)
self.assertEqual(len(fp.keys), key_count)

results = []
for x in fp.paginate():
results.append(x)

self.assertEqual(len(results), 3)

def test_get_keys(self):
fake_bucket = 'fake_bucket_1'
fake_paginator_1 = FakePaginator(5, 2)
keys = get_keys(fake_paginator_1, fake_bucket)

self.assertSetEqual(keys, set(x.get('Key') for x in fake_paginator_1.keys))

def test_scan_ed_pub(self):
scan_ed_pub(MagicMock())

0 comments on commit 6f72758

Please sign in to comment.