Skip to content

Commit

Permalink
Merge pull request #206 from 4dn-dcic/0.8.4
Browse files Browse the repository at this point in the history
0.8.4
  • Loading branch information
carlvitzthum authored Jun 12, 2019
2 parents 83acc35 + 4d87165 commit f0b4853
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 3 deletions.
111 changes: 109 additions & 2 deletions tests/tibanna/unicorn/test_ec2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tibanna.utils import create_jobid
from tibanna.exceptions import (
MissingFieldInInputJsonException,
MalFormattedInputJsonException,
EC2InstanceLimitException,
EC2InstanceLimitWaitException
)
Expand Down Expand Up @@ -37,6 +38,108 @@ def test_args_missing_field():
assert 'output_S3_bucket' in str(ex.value)


def test_args_parse_input_files():
input_dict = {'args': {'input_files': {"file1": "s3://somebucket/somekey"},
'output_S3_bucket': 'somebucket',
'cwl_main_filename': 'main.cwl',
'cwl_directory_url': 'someurl',
'app_name': 'someapp'}}
args = Args(**input_dict['args'])
args.fill_default()
assert hasattr(args, 'input_files')
assert 'file1' in args.input_files
assert 'bucket_name' in args.input_files['file1']
assert 'object_key' in args.input_files['file1']
assert args.input_files['file1']['bucket_name'] == 'somebucket'
assert args.input_files['file1']['object_key'] == 'somekey'


def test_args_parse_input_files2():
input_dict = {'args': {'input_files': {"file1": [["s3://somebucket/somekey1",
"s3://somebucket/somekey2"],
["s3://somebucket/somekey3",
"s3://somebucket/somekey4"]]},
'output_S3_bucket': 'somebucket',
'cwl_main_filename': 'main.cwl',
'cwl_directory_url': 'someurl',
'app_name': 'someapp'}}
args = Args(**input_dict['args'])
args.fill_default()
assert hasattr(args, 'input_files')
assert 'file1' in args.input_files
assert 'bucket_name' in args.input_files['file1']
assert 'object_key' in args.input_files['file1']
assert args.input_files['file1']['bucket_name'] == 'somebucket'
assert isinstance(args.input_files['file1']['object_key'], list)
assert len(args.input_files['file1']['object_key']) == 2
assert isinstance(args.input_files['file1']['object_key'][0], list)
assert len(args.input_files['file1']['object_key'][0]) == 2
assert isinstance(args.input_files['file1']['object_key'][1], list)
assert len(args.input_files['file1']['object_key'][1]) == 2
assert args.input_files['file1']['object_key'][0][0] == 'somekey1'
assert args.input_files['file1']['object_key'][0][1] == 'somekey2'
assert args.input_files['file1']['object_key'][1][0] == 'somekey3'
assert args.input_files['file1']['object_key'][1][1] == 'somekey4'


def test_args_parse_input_files3():
input_dict = {'args': {'input_files': {"file1": ["s3://somebucket/somekey1",
"s3://somebucket/somekey2"]},
'output_S3_bucket': 'somebucket',
'cwl_main_filename': 'main.cwl',
'cwl_directory_url': 'someurl',
'app_name': 'someapp'}}
args = Args(**input_dict['args'])
args.fill_default()
assert hasattr(args, 'input_files')
assert 'file1' in args.input_files
assert 'bucket_name' in args.input_files['file1']
assert 'object_key' in args.input_files['file1']
assert args.input_files['file1']['bucket_name'] == 'somebucket'
assert isinstance(args.input_files['file1']['object_key'], list)
assert len(args.input_files['file1']['object_key']) == 2
assert args.input_files['file1']['object_key'][0] == 'somekey1'
assert args.input_files['file1']['object_key'][1] == 'somekey2'


def test_args_parse_input_files_format_error():
input_dict = {'args': {'input_files': {"file1": "somerandomstr"},
'output_S3_bucket': 'somebucket',
'cwl_main_filename': 'main.cwl',
'cwl_directory_url': 'someurl',
'app_name': 'someapp'}}
args = Args(**input_dict['args'])
with pytest.raises(MalFormattedInputJsonException) as ex:
args.fill_default()
assert ex
assert 'S3 url must begin with' in str(ex.value)


def test_args_parse_input_files_format_error2():
input_dict = {'args': {'input_files': {"file1": ["s3://somebucket/somekey1",
"s3://otherbucket/somekey2"]},
'output_S3_bucket': 'somebucket',
'cwl_main_filename': 'main.cwl',
'cwl_directory_url': 'someurl',
'app_name': 'someapp'}}
args = Args(**input_dict['args'])
with pytest.raises(MalFormattedInputJsonException) as ex:
args.fill_default()
assert ex
assert 'bucket' in str(ex.value)


def test_parse_command():
input_dict = {'args': {'command': ['command1', 'command2', 'command3'],
'output_S3_bucket': 'somebucket',
'language': 'shell',
'container_image': 'someimage',
'app_name': 'someapp'}}
args = Args(**input_dict['args'])
args.fill_default()
assert args.command == 'command1; command2; command3'


def test_config():
input_dict = {'config': {'log_bucket': 'tibanna-output', 'shutdown_min': 30}}
cfg = Config(**input_dict['config'])
Expand Down Expand Up @@ -124,6 +227,7 @@ def test_execution_benchmark():
s3.delete_objects(Bucket='tibanna-output',
Delete={'Objects': [{'Key': randomstr}]})


def test_get_file_size():
randomstr = 'test-' + create_jobid()
s3 = boto3.client('s3')
Expand All @@ -135,6 +239,7 @@ def test_get_file_size():
s3.delete_objects(Bucket='tibanna-output',
Delete={'Objects': [{'Key': randomstr}]})


def test_get_input_size_in_bytes():
randomstr = 'test-' + create_jobid()
s3 = boto3.client('s3')
Expand All @@ -155,6 +260,7 @@ def test_get_input_size_in_bytes():
s3.delete_objects(Bucket='tibanna-output',
Delete={'Objects': [{'Key': randomstr}]})


def test_update_config_ebs_size():
"""ebs_size is given as the 'x' format. The total estimated ebs_size is smaller than 10"""
randomstr = 'test-' + create_jobid()
Expand All @@ -176,6 +282,7 @@ def test_update_config_ebs_size():
s3.delete_objects(Bucket='tibanna-output',
Delete={'Objects': [{'Key': randomstr}]})


def test_update_config_ebs_size2():
"""ebs_size is given as the 'x' format. The total estimated ebs_size is larger than 10"""
randomstr = 'test-' + create_jobid()
Expand All @@ -192,12 +299,12 @@ def test_update_config_ebs_size2():
execution = Execution(input_dict)
execution.input_size_in_bytes = execution.get_input_size_in_bytes()
execution.update_config_ebs_size()
assert execution.cfg.ebs_size > 18
assert execution.cfg.ebs_size < 19
assert execution.cfg.ebs_size == 19
# cleanup afterwards
s3.delete_objects(Bucket='tibanna-output',
Delete={'Objects': [{'Key': randomstr}]})


def test_unicorn_input_missing_field():
"""app_name that doesn't exist in benchmark, without instance type, mem, cpu info"""
input_dict = {'args': {'input_files': {}, 'app_name': 'app_name_not_in_benchmark',
Expand Down
2 changes: 1 addition & 1 deletion tibanna/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Version information."""

# The following line *must* be the last in the module, exactly as formatted:
__version__ = "0.8.3"
__version__ = "0.8.4"
55 changes: 55 additions & 0 deletions tibanna/ec2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import boto3
import copy
import re
from .utils import (
printlog,
does_key_exist,
Expand All @@ -23,6 +24,7 @@
)
from .exceptions import (
MissingFieldInInputJsonException,
MalFormattedInputJsonException,
EC2LaunchException,
EC2InstanceLimitException,
EC2InstanceLimitWaitException,
Expand Down Expand Up @@ -137,6 +139,8 @@ def fill_default(self):
self.singularity = False
if not hasattr(self, 'app_name'):
self.app_name = ''
# input file format check and parsing
self.parse_input_files()
# check workflow info is there and fill in default
errmsg_template = "field %s is required in args for language %s"
if self.language == 'wdl':
Expand Down Expand Up @@ -183,6 +187,53 @@ def fill_default(self):
if not self.cwl_directory_local and not self.cwl_directory_url:
errmsg = "either %s or %s must be provided in args" % ('cwl_directory_url', 'cwl_directory_local')
raise MissingFieldInInputJsonException(errmsg)
# reformat command
self.parse_command()

def parse_command(self):
"""if command is a list, conert it to a string"""
if hasattr(self, 'command'):
if isinstance(self.command, list):
self.command = '; '.join(self.command)
elif not isinstance(self.command, str):
raise MalFormattedInputJsonException("command must be either a string or a list")

def parse_input_files(self):
"""checking format for input files and converting s3:// style string into
bucket_name and object_key"""
if hasattr(self, 'input_files'):
if not isinstance(self.input_files, dict):
errmsg = "'input_files' must be provided as a dictionary (key-value pairs)"
raise MalFormattedInputJsonException(errmsg)
for ip, v in self.input_files.items():
if isinstance(v, str):
bucket_name, object_key = self.parse_s3_url(v)
self.input_files[ip] = {'bucket_name': bucket_name, 'object_key': object_key}
elif isinstance(v, list):
buckets = flatten(run_on_nested_arrays1(v, self.parse_s3_url, **{'bucket_only': True}))
if len(set(buckets)) != 1:
errmsg = "All the input files corresponding to a single input file argument " + \
"must be from the same bucket."
raise MalFormattedInputJsonException(errmsg)
object_keys = run_on_nested_arrays1(v, self.parse_s3_url, **{'key_only': True})
self.input_files[ip] = {'bucket_name': buckets[0], 'object_key': object_keys}
elif isinstance(v, dict) and 'bucket_name' in v and 'object_key' in v:
pass
else:
errmsg = "Each input_file value must be either a string starting with 's3://'" + \
" or a dictionary with 'bucket_name' and 'object_key' as keys"
raise MalFormattedInputJsonException(errmsg)

def parse_s3_url(self, url, bucket_only=False, key_only=False):
if not url.startswith('s3://'):
raise MalFormattedInputJsonException("S3 url must begin with 's3://'")
bucket_name = re.sub('^s3://', '', url).split('/')[0]
object_key = re.sub('^s3://' + bucket_name + '/', '', url)
if bucket_only:
return bucket_name
if key_only:
return object_key
return bucket_name, object_key

def as_dict(self):
return copy.deepcopy(self.__dict__)
Expand Down Expand Up @@ -359,6 +410,10 @@ def auto_calculate_ebs_size(self):
if isinstance(self.cfg.ebs_size, str) and self.cfg.ebs_size.endswith('x'):
multiplier = float(self.cfg.ebs_size.rstrip('x'))
self.cfg.ebs_size = multiplier * self.total_input_size_in_gb
if round(self.cfg.ebs_size) < self.cfg.ebs_size:
self.cfg.ebs_size = round(self.cfg.ebs_size) + 1
else:
self.cfg.ebs_size = round(self.cfg.ebs_size)
if self.cfg.ebs_size < 10:
self.cfg.ebs_size = 10

Expand Down
4 changes: 4 additions & 0 deletions tibanna/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ class EC2InstanceLimitWaitException(Exception):

class MissingFieldInInputJsonException(Exception):
pass


class MalFormattedInputJsonException(Exception):
pass

0 comments on commit f0b4853

Please sign in to comment.