diff --git a/luigi/contrib/lsf_runner.py b/luigi/contrib/lsf_runner.py old mode 100755 new mode 100644 index 5a6c8b5699..f483e7bf45 --- a/luigi/contrib/lsf_runner.py +++ b/luigi/contrib/lsf_runner.py @@ -28,7 +28,7 @@ except ImportError: import pickle import logging -import tarfile +from luigi.safe_extractor import SafeExtractor def do_work_on_compute_node(work_dir): @@ -52,10 +52,8 @@ def extract_packages_archive(work_dir): curdir = os.path.abspath(os.curdir) os.chdir(work_dir) - tar = tarfile.open(package_file) - for tarinfo in tar: - tar.extract(tarinfo) - tar.close() + extractor = SafeExtractor(work_dir) + extractor.safe_extract(package_file) if '' not in sys.path: sys.path.insert(0, '') diff --git a/luigi/contrib/sge_runner.py b/luigi/contrib/sge_runner.py index f0621fb475..2600f2d6dc 100755 --- a/luigi/contrib/sge_runner.py +++ b/luigi/contrib/sge_runner.py @@ -36,7 +36,7 @@ import sys import pickle import logging -import tarfile +from luigi.safe_extractor import SafeExtractor def _do_work_on_compute_node(work_dir, tarball=True): @@ -64,10 +64,8 @@ def _extract_packages_archive(work_dir): curdir = os.path.abspath(os.curdir) os.chdir(work_dir) - tar = tarfile.open(package_file) - for tarinfo in tar: - tar.extract(tarinfo) - tar.close() + extractor = SafeExtractor(work_dir) + extractor.safe_extract(package_file) if '' not in sys.path: sys.path.insert(0, '') diff --git a/luigi/safe_extractor.py b/luigi/safe_extractor.py new file mode 100644 index 0000000000..f106a14c37 --- /dev/null +++ b/luigi/safe_extractor.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module provides a class `SafeExtractor` that offers a secure way to extract tar files while +mitigating path traversal vulnerabilities, which can occur when files inside the archive are +crafted to escape the intended extraction directory. + +The `SafeExtractor` ensures that the extracted file paths are validated before extraction to +prevent malicious archives from extracting files outside the intended directory. + +Classes: + SafeExtractor: A class to securely extract tar files with protection against path traversal attacks. + +Usage Example: + extractor = SafeExtractor("/desired/directory") + extractor.safe_extract("archive.tar") +""" + +import os +import tarfile + + +class SafeExtractor: + """ + A class to safely extract tar files, ensuring that no path traversal + vulnerabilities are exploited. + + Attributes: + path (str): The directory to extract files into. + + Methods: + _is_within_directory(directory, target): + Checks if a target path is within a given directory. + + safe_extract(tar_path, members=None, \\*, numeric_owner=False): + Safely extracts the contents of a tar file to the specified directory. + """ + + def __init__(self, path="."): + """ + Initializes the SafeExtractor with the specified directory path. + + Args: + path (str): The directory to extract files into. Defaults to the current directory. + """ + self.path = path + + @staticmethod + def _is_within_directory(directory, target): + """ + Checks if a target path is within a given directory. + + Args: + directory (str): The directory to check against. + target (str): The target path to check. + + Returns: + bool: True if the target path is within the directory, False otherwise. + """ + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + prefix = os.path.commonprefix([abs_directory, abs_target]) + return prefix == abs_directory + + def safe_extract(self, tar_path, members=None, *, numeric_owner=False): + """ + Safely extracts the contents of a tar file to the specified directory. + + Args: + tar_path (str): The path to the tar file to extract. + members (list, optional): A list of members to extract. Defaults to None. + numeric_owner (bool, optional): If True, only the numeric owner will be used. Defaults to False. + + Raises: + RuntimeError: If a path traversal attempt is detected. + """ + with tarfile.open(tar_path, 'r') as tar: + for member in tar.getmembers(): + member_path = os.path.join(self.path, member.name) + if not self._is_within_directory(self.path, member_path): + raise RuntimeError("Attempted Path Traversal in Tar File") + tar.extractall(self.path, members, numeric_owner=numeric_owner) diff --git a/test/safe_extractor_test.py b/test/safe_extractor_test.py new file mode 100644 index 0000000000..e14367438e --- /dev/null +++ b/test/safe_extractor_test.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Safe Extractor Test +============= + +Tests for the Safe Extractor class in luigi.safe_extractor module. +""" + +import os +import shutil +import tarfile +import tempfile +import unittest + +from luigi.safe_extractor import SafeExtractor + + +class TestSafeExtract(unittest.TestCase): + """ + Unit test class for testing the SafeExtractor module. + """ + + def setUp(self): + """Set up a temporary directory for test files.""" + self.temp_dir = tempfile.mkdtemp() + self.test_file_template = 'test_file_{}.txt' + self.tar_file_name = 'test.tar' + self.tar_file_name_with_traversal = f'traversal_{self.tar_file_name}' + + def tearDown(self): + """Clean up the temporary directory after each test.""" + shutil.rmtree(self.temp_dir) + + def create_test_tar(self, tar_path, file_count=1, with_traversal=False): + """ + Create a tar file containing test files. + + Args: + tar_path (str): Path where the tar file will be created. + file_count (int): Number of test files to include. + with_traversal (bool): If True, creates a tar file with path traversal vulnerability. + """ + # Default content for the test files + file_contents = [f'This is {self.test_file_template.format(i)}' for i in range(file_count)] + + with tarfile.open(tar_path, 'w') as tar: + for i in range(file_count): + file_name = self.test_file_template.format(i) + file_path = os.path.join(self.temp_dir, file_name) + + # Write content to each test file + with open(file_path, 'w') as f: + f.write(file_contents[i]) + + # If path traversal is enabled, create malicious paths + archive_name = f'../../{file_name}' if with_traversal else file_name + + # Add the file to the tar archive + tar.add(file_path, arcname=archive_name) + + def verify_extracted_files(self, file_count): + """ + Verify that the correct files were extracted and their contents match expectations. + + Args: + file_count (int): Number of files to verify. + """ + for i in range(file_count): + file_name = self.test_file_template.format(i) + file_path = os.path.join(self.temp_dir, file_name) + + # Check if the file exists + self.assertTrue(os.path.exists(file_path), f"File {file_name} does not exist.") + + # Check if the file content is correct + with open(file_path, 'r') as f: + content = f.read() + expected_content = f'This is {file_name}' + self.assertEqual(content, expected_content, f"Content mismatch in {file_name}.") + + def test_safe_extract(self): + """Test normal safe extraction of tar files.""" + tar_path = os.path.join(self.temp_dir, self.tar_file_name) + + # Create a tar file with 3 files + self.create_test_tar(tar_path, file_count=3) + + # Initialize SafeExtractor and perform extraction + extractor = SafeExtractor(self.temp_dir) + extractor.safe_extract(tar_path) + + # Verify that all 3 files were extracted correctly + self.verify_extracted_files(3) + + def test_safe_extract_with_traversal(self): + """Test safe extraction for tar files with path traversal (should raise an error).""" + tar_path = os.path.join(self.temp_dir, self.tar_file_name_with_traversal) + + # Create a tar file with a path traversal file + self.create_test_tar(tar_path, file_count=1, with_traversal=True) + + # Initialize SafeExtractor and expect RuntimeError due to path traversal + extractor = SafeExtractor(self.temp_dir) + with self.assertRaises(RuntimeError): + extractor.safe_extract(tar_path) + + +if __name__ == '__main__': + unittest.main()