diff --git a/src/fibad/download.py b/src/fibad/download.py index faeaf7c..ac38b0c 100644 --- a/src/fibad/download.py +++ b/src/fibad/download.py @@ -1,182 +1,19 @@ -import contextlib import datetime +import itertools import logging -import os +import time import urllib.request from pathlib import Path -from typing import Union +from threading import Lock, Thread +from typing import Optional, Union -import numpy as np from astropy.table import Table, hstack import fibad.downloadCutout.downloadCutout as dC -# These are the fields that are allowed to vary across the locations -# input from the catalog fits file. Other values for HSC cutout server -# must be provided by config. -variable_fields = ["tract", "ra", "dec"] - logger = logging.getLogger(__name__) -@contextlib.contextmanager -def working_directory(path: Path): - """ - Context Manager to change our working directory. - Supports downloadCutouts which always writes to cwd. - - Parameters - ---------- - path : Path - Path that we change `Path.cwd()` while we are active. - """ - old_cwd = Path.cwd() - os.chdir(path) - try: - yield - finally: - os.chdir(old_cwd) - - -def run(config): - """ - Main entrypoint for downloading cutouts from HSC for use with fibad - - Parameters - ---------- - config : dict - Runtime configuration as a nested dictionary - """ - - config = config.get("download", {}) - - logger.info("Download command Start") - - fits_file = config.get("fits_file", "") - logger.info(f"Reading in fits catalog: {fits_file}") - # Filter the fits file for the fields we want - column_names = ["object_id"] + variable_fields - locations = filterfits(fits_file, column_names) - - # TODO slice up the locations to multiplex across connections if necessary, but for now - # we simply mask off a few - offset = config.get("offset", 0) - end = offset + config.get("num_sources", 10) - locations = locations[offset:end] - - # Make a list of rects to pass to downloadCutout - rects = create_rects(locations, offset=0, default=rect_from_config(config)) - - # Configure global parameters for the downloader - dC.set_max_connections(num=config.get("max_connections", 2)) - - logger.info("Requesting cutouts") - # pass the rects to the cutout downloader - download_cutout_group( - rects=rects, - cutout_dir=config.get("cutout_dir"), - user=config["username"], - password=config["password"], - retrywait=config.get("retry_wait", 30), - retries=config.get("retries", 3), - timeout=config.get("timeout", 3600), - chunksize=config.get("chunk_size", 990), - ) - - logger.info("Done") - - -# TODO add error checking -def filterfits(filename: str, column_names: list[str]) -> Table: - """Read a fits file with the required column names for making cutouts - - The easiest way to make such a fits file is to select from the main HSC catalog - - Parameters - ---------- - filename : str - The fits file to read in - column_names : list[str] - The columns that are filtered out - - Returns - ------- - Table - Returns an astropy table containing only the fields specified in column_names - """ - t = Table.read(filename) - columns = [t[column] for column in column_names] - return hstack(columns, uniq_col_name="{table_name}", table_names=column_names) - - -def rect_from_config(config: dict) -> dC.Rect: - """Takes our runtime config and loads cutout config - common to all cutouts into a prototypical Rect for downloading - - Parameters - ---------- - config : dict - Runtime config, only the download section - - Returns - ------- - dC.Rect - A single rectangle with fields `sw`, `sh`, `filter`, 
`rerun`, and `type` populated from the config - """ - return dC.Rect.create( - sw=config["sw"], - sh=config["sh"], - filter=config["filter"], - rerun=config["rerun"], - type=config["type"], - ) - - -def create_rects(locations: Table, offset: int = 0, default: dC.Rect = None) -> list[dC.Rect]: - """Create the rects we will need to pass to the downloader. - One Rect per location in our list of sky locations. - - Rects are created with all fields in the default rect pre-filled - - Offset here is to allow multiple downloads on different sections of the source list - without file name clobbering during the download phase. The offset is intended to be - the index of the start of the locations table within some larger fits file. - - Parameters - ---------- - locations : Table - Table containing ra, dec locations in the sky - offset : int, optional - Index to start the `lineno` field in the rects at, by default 0. The purpose of this is to allow - multiple downloads on different sections of a larger source list without file name clobbering during - the download phase. This is important because `lineno` in a rect can becomes a file name parameter - The offset is intended to be the index of the start of the locations table within some larger fits - file. - default : dC.Rect, optional - The default Rect that contains properties common to all sky locations, by default None - - Returns - ------- - list[dC.Rect] - Rects populated with sky locations from the table - """ - rects = [] - for index, location in enumerate(locations): - args = {field: location[field] for field in variable_fields} - args["lineno"] = index + offset - args["tract"] = str(args["tract"]) - # Sets the file name on the rect to be the object_id, also includes other rect fields - # which are interpolated at save time, and are native fields of dc.Rect. - # - # This name is also parsed by FailedChunkCollector.hook to identify the object_id, so don't - # change it without updating code there too. - args["name"] = f"{location['object_id']}_{{type}}_{{ra:.5f}}_{{dec:+.5f}}_{{tract}}_{{filter}}" - rect = dC.Rect.create(default=default, **args) - rects.append(rect) - - return rects - - class DownloadStats: """Subsytem for keeping statistics on downloads: @@ -185,7 +22,8 @@ class DownloadStats: Can be used as a context manager for pretty printing. 
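+
+    Stats are shared across all download threads: each thread registers itself by entering
+    the context manager, and a daemon watcher thread logs the combined totals every
+    print_interval_s seconds. A hypothetical usage sketch, mirroring download_thread():
+
+        stats = DownloadStats(print_interval_s=30)
+        with stats as hook:
+            dC.download(rects, user=user, password=password, onmemory=False, request_hook=hook)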
""" - def __init__(self): + def __init__(self, print_interval_s=30): + self.lock = Lock() self.stats = { "request_duration": datetime.timedelta(), # Time from request sent to first byte from the server "response_duration": datetime.timedelta(), # Total time spent recieving and processing a response @@ -194,11 +32,32 @@ def __init__(self): "snapshots": 0, # Number of fits snapshots downloaded } + # Reference count active threads and whether we've started + self.active_threads = 0 + self.num_threads = 0 + self.data_start = None + + # How often the watcher thread prints (seconds) + self.print_interval_s = print_interval_s + + # Start our watcher thread to print stats to the log + self.watcher_thread = Thread( + target=self._watcher_thread, name="stats watcher thread", args=(logging.INFO,), daemon=True + ) + self.watcher_thread.start() + def __enter__(self): + # Count how many threads are using stats + with self.lock: + self.active_threads += 1 + self.num_threads += 1 + return self.hook def __exit__(self, exc_type, exc_value, traceback): - self._print_stats(logging.INFO) + # Count how many threads are using stats + with self.lock: + self.active_threads -= 1 def _stat_accumulate(self, name: str, value: Union[int, datetime.timedelta]): """Accumulate a sum into the global stats dict @@ -212,6 +71,13 @@ def _stat_accumulate(self, name: str, value: Union[int, datetime.timedelta]): """ self.stats[name] += value + def _watcher_thread(self, log_level): + # Simple polling loop to print + while self.active_threads != 0 or not self.data_start: + if self.data_start: + self._print_stats(log_level) + time.sleep(self.print_interval_s) + def _print_stats(self, log_level): """Print the accumulated stats including bandwidth calculated from duration and sizes @@ -221,22 +87,45 @@ def _print_stats(self, log_level): If you use this class as a context manager, the end of context will output a newline, perserving the last line of stats in your terminal """ - total_dur_s = (self.stats["request_duration"] + self.stats["response_duration"]).total_seconds() - resp_s = self.stats["response_duration"].total_seconds() - down_rate_mb_s = (self.stats["response_size_bytes"] / (1024**2)) / resp_s if resp_s != 0 else 0 + def _div(num, denom, default=0.0): + return num / denom if denom != 0 else default - req_s = self.stats["request_duration"].total_seconds() - up_rate_mb_s = (self.stats["request_size_bytes"] / (1024**2)) / req_s if req_s != 0 else 0 + with self.lock: + now = datetime.datetime.now() - snapshot_rate = self.stats["snapshots"] / total_dur_s if total_dur_s != 0 else 0 + wall_clock_dur_s = (now - self.data_start).total_seconds() if self.data_start else 0 - stats_message = f"Stats: Duration: {total_dur_s:.2f} s, " + # This is the duration across all threads added up + total_dur_s = (self.stats["request_duration"] + self.stats["response_duration"]).total_seconds() + + resp_s = self.stats["response_duration"].total_seconds() + down_rate_mb_s = _div(self.stats["response_size_bytes"] / (1024**2), resp_s) + down_rate_mb_s_overall = _div(self.stats["response_size_bytes"] / (1024**2), wall_clock_dur_s) + + req_s = self.stats["request_duration"].total_seconds() + up_rate_mb_s = _div(self.stats["request_size_bytes"] / (1024**2), req_s) + + snapshot_rate = _div(self.stats["snapshots"], wall_clock_dur_s) + snapshot_rate_thread = _div(self.stats["snapshots"], total_dur_s) + + connnection_efficiency = _div(total_dur_s, wall_clock_dur_s * self.num_threads) + + thread_avg_dur = _div(total_dur_s, self.num_threads) + + stats_message = 
"Overall stats: " + stats_message += f"Wall-clock Duration: {wall_clock_dur_s:.2f} s, " stats_message += f"Files: {self.stats['snapshots']}, " + stats_message += f"Download rate: {down_rate_mb_s_overall:.2f} MB/s, " + stats_message += f"File rate: {snapshot_rate:.2f} files/s, " + stats_message += f"Conn eff: {connnection_efficiency:.2f}" + logger.log(log_level, stats_message) + + stats_message = f"Per Thread Averages ({self.num_threads} threads): " + stats_message += f"Duration: {thread_avg_dur:.2f} s, " stats_message += f"Upload: {up_rate_mb_s:.2f} MB/s, " stats_message += f"Download: {down_rate_mb_s:.2f} MB/s, " - stats_message += f"File rate: {snapshot_rate:.2f} files/s" - + stats_message += f"File rate: {snapshot_rate_thread:.2f} files/s, " logger.log(log_level, stats_message) def hook( @@ -264,148 +153,478 @@ def hook( chunk_size : int The number of cutout files recieved in this request """ - now = datetime.datetime.now() - self._stat_accumulate("request_duration", response_start - request_start) - self._stat_accumulate("response_duration", now - response_start) - self._stat_accumulate("request_size_bytes", len(request.data)) - self._stat_accumulate("response_size_bytes", response_size) - self._stat_accumulate("snapshots", chunk_size) + with self.lock: + if not self.data_start: + self.data_start = request_start - self._print_stats(logging.INFO) + self._stat_accumulate("request_duration", response_start - request_start) + self._stat_accumulate("response_duration", now - response_start) + self._stat_accumulate("request_size_bytes", len(request.data)) + self._stat_accumulate("response_size_bytes", response_size) + self._stat_accumulate("snapshots", chunk_size) -class FailedChunkCollector: - """Collection system for chunks of sky locations where the request for a chunk of cutouts failed. +class Downloader: + """Class with primarily static methods to namespace downloader related constants and functions.""" - Keeps track of all variable_fields plus object_id for failed chunks + # These are the fields that are allowed to vary across the locations + # input from the catalog fits file. Other values for HSC cutout server + # must be provided by config. + VARIABLE_FIELDS = ["tract", "ra", "dec"] - save() dumps these chunks using astropy.table.Table.write() + # These are the column names we retain when writing a rect out to the manifest.fits file + RECT_COLUMN_NAMES = VARIABLE_FIELDS + ["filter", "sw", "sh", "rerun", "type"] - """ + MANIFEST_FILE_NAME = "manifest.fits" - def __init__(self, filepath: Path, **kwargs): - """_summary_ + @staticmethod + def run(config): + """ + Main entrypoint for downloading cutouts from HSC for use with fibad Parameters ---------- - filepath : Path - File to read in if we are resuming a download, and where to save the failed chunks after. - If the file does not exist yet an empty state is initialized. - - **kwargs : dict - Keyword args passed to astropy.table.Table.read() and write() in the case that a file is used. 
-            Should only be used to control file format, not read/write semantics
+        config : dict
+            Runtime configuration as a nested dictionary
         """
-        self.__dict__.update({key: [] for key in variable_fields + ["object_id"]})
-        self.seen_object_ids = set()
-        self.filepath = filepath.resolve()
-        self.format_kwargs = kwargs
-
-        # If there is a failed chunk file from a previous run,
-        # Read it in to initialize us
-        if filepath.exists():
-            prev_failed_chunks = Table.read(filepath)
-            for key in variable_fields + ["object_id"]:
-                column_as_list = prev_failed_chunks[key].data.tolist()
-                self.__dict__[key] += column_as_list
+        config = config.get("download", {})
 
-            logger.debug(f"Adding object ID :{self.object_id} to failed list")
+        logger.info("Download command Start")
 
-        self.seen_object_ids = {id for id in self.object_id}
+        fits_file = Path(config.get("fits_file", "")).resolve()
+        logger.info(f"Reading in fits catalog: {fits_file}")
+        # Filter the fits file for the fields we want
+        column_names = ["object_id"] + Downloader.VARIABLE_FIELDS
+        locations = Downloader.filterfits(fits_file, column_names)
 
-        self.count = len(self.seen_object_ids)
-        logger.debug(f"Failed chunk handler initialized with {self.count} objects")
+        # If offset/length specified, filter to that length
+        offset = config.get("offset", 0)
+        num_sources = config.get("num_sources", None)
+        if num_sources is not None:
+            locations = locations[offset : offset + num_sources]
 
-    def __enter__(self):
-        return self.hook
+        cutout_path = Path(config.get("cutout_dir")).resolve()
+        logger.info(f"Downloading cutouts to {cutout_path}")
 
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.save()
+        # Make a list of rects to pass to downloadCutout
+        rects = Downloader.create_rects(
+            locations, offset=0, default=Downloader.rect_from_config(config), path=cutout_path
+        )
+
+        # Prune any previously downloaded rects from our list using the manifest from the previous download
+        rects = Downloader._prune_downloaded_rects(cutout_path, rects)
+
+        # Early return if there is nothing to download.
+        if len(rects) == 0:
+            logger.info("Download already complete according to manifest.")
+            return
+
+        # Create thread objects for each of our worker threads
+        num_threads = config.get("concurrent_connections", 2)
+        if num_threads > 5:
+            raise RuntimeError("This client opens at most 5 concurrent connections.")
+
+        # If we are using more than one connection, cut the list of rectangles into
+        # batches, one batch for each thread.
+        # TODO: Remove this in favor of itertools.batched() when we no longer support python < 3.12.
+        def _batched(iterable, n):
+            """Brazenly copied and pasted from the python 3.12 documentation.
+            This is a dodgy version of a new itertools function in Python 3.12 called itertools.batched()
+            """
+            if n < 1:
+                raise ValueError("n must be at least one")
+            iterator = iter(iterable)
+            while batch := tuple(itertools.islice(iterator, n)):
+                yield batch
+
+        # Use a ceiling on the batch size so no remainder batch is left over without a thread,
+        # then recompute the thread count from the batches actually produced.
+        batch_size = -(-len(rects) // num_threads)
+        thread_rects = list(_batched(rects, batch_size)) if num_threads != 1 else [rects]
+        num_threads = len(thread_rects)
 
+        # Empty dictionaries for the threads to create download manifests in
+        thread_manifests = [dict() for _ in range(num_threads)]
+
+        shared_thread_args = (
+            config["username"],
+            config["password"],
+            DownloadStats(print_interval_s=config.get("stats_print_interval", 30)),
+        )
+
+        shared_thread_kwargs = {
+            "retrywait": config.get("retry_wait", 30),
+            "retries": config.get("retries", 3),
+            "timeout": config.get("timeout", 3600),
+            "chunksize": config.get("chunk_size", 990),
+        }
+
+        download_threads = [
+            Thread(
+                target=Downloader.download_thread,
+                name=f"thread_{i}",
+                daemon=True,  # daemon so these threads will die when the main thread is interrupted
+                args=(thread_rects[i],)  # rects
+                + shared_thread_args  # username, password, download stats
+                + (i, thread_manifests[i]),  # thread_num, manifest
+                kwargs=shared_thread_kwargs,
+            )
+            for i in range(num_threads)
+        ]
 
-    def hook(self, rects: list[dC.Rect], exception: Exception, attempts: int):
-        """Called when dc.Download fails to download a chunk of rects
+        try:
+            logger.info(f"Started {len(download_threads)} request threads")
+            [thread.start() for thread in download_threads]
+            [thread.join() for thread in download_threads]
+        finally:  # Ensure manifest is written even when we get a KeyboardInterrupt during download
+            Downloader.write_manifest(thread_manifests, cutout_path)
+
+        logger.info("Done")
+
+    @staticmethod
+    def _prune_downloaded_rects(cutout_path: Path, rects: list[dC.Rect]) -> list[dC.Rect]:
+        """Prunes already downloaded rects using the manifest in `cutout_path`. `rects` passed in is
+        mutated by this operation
 
         Parameters
         ----------
+        cutout_path : Path
+            Where on the filesystem to find the manifest
         rects : list[dC.Rect]
-            The list of rect objects that were requested from the server
-        exception : Exception
-            The exception that was thrown on the final attempt to request this chunk
-        attempts : int
-            The number of attempts that were made to request this chunk
+            List of rects from which we want to prune previously downloaded rects
+
+        Returns
+        -------
+        list[dC.Rect]
+            Returns `rects` that was passed in. This is only to enable explicit style at the call site.
+            `rects` is mutated by this function.
+
+        Raises
+        ------
+        RuntimeError
+            When there is an issue reading the manifest file, or the manifest file corresponds to a different
+            set of cutouts than the current download being attempted
+        """
+        # Read in any prior manifest.
+        prior_manifest = Downloader.read_manifest(cutout_path)
+
+        # If we found a manifest, we are resuming a download
+        if len(prior_manifest) != 0:
+            # Filter rects to figure out which ones are completely downloaded.
+            # This operation consumes prior_manifest in the process
+            rects[:] = [rect for rect in rects if Downloader._keep_rect(rect, prior_manifest)]
+
+            # If prior_manifest was not completely consumed, then the earlier download attempted
+            # some sky locations which would not be included in the current download, and we have
+            # a problem.
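+            # (Most often this means the catalog fits file, offset, or num_sources config
+            # changed between runs.)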
+            if len(prior_manifest) != 0:
+                raise RuntimeError(
+                    f"""{cutout_path/Downloader.MANIFEST_FILE_NAME} describes a download with
+sky locations that would not be downloaded in the download currently being attempted. Are you sure you are
+resuming the correct download? Deleting the manifest and cutout files will start the download from scratch"""
+                )
+
+        return rects
+
+    @staticmethod
+    def _keep_rect(location_rect: dC.Rect, prior_manifest: dict[dC.Rect, str]) -> bool:
+        """Private helper function to prune_downloaded_rects which operates the inner loop
+        of the prune function, and allows it to be written as a list comprehension.
+
+        This function decides, element by element, whether each rect we want to download has
+        already been downloaded in a prior download, given the manifest from that prior download.
+
+        Parameters
+        ----------
+        location_rect : dC.Rect
+            A rectangle on the sky that we are considering downloading.
+
+        prior_manifest : dict[dC.Rect,str]
+            The manifest of the prior download. This object is slowly consumed by repeated calls
+            to this function. When the return value is False, all manifest entries corresponding to the
+            passed in location_rect have been removed.
+
+        Returns
+        -------
+        bool
+            Whether this sky location `location_rect` should be included in the download
+        """
+        # Keep any location rect if the manifest passed has nothing in it.
+        if len(prior_manifest) == 0:
+            return True
+
+        keep_rect = False
+        for filter_rect in location_rect.explode():
+            # Consume any matching manifest entry, keep the rect if
+            # 1) The manifest entry doesn't exist -> pop returns None
+            # 2) The manifest entry contains "Attempted" for the filename -> The corresponding file wasn't
+            #    successfully downloaded
+            matching_manifest_entry = prior_manifest.pop(filter_rect, None)
+            if matching_manifest_entry is None or matching_manifest_entry == "Attempted":
+                keep_rect = True
+
+        return keep_rect
+
+    @staticmethod
+    def write_manifest(thread_manifests: list[dict[dC.Rect, str]], file_path: Path):
+        """Write out manifest fits file that is an inventory of the download.
+        The manifest fits file should have columns object_id, ra, dec, tract, filter, filename
+
+        If the filename is the string "Attempted", a download attempt was made but did not succeed.
+        If the object is not present in the manifest, no download was attempted.
+        If the object is present in the manifest and the filename is an actual file name, that file
+        exists and was downloaded successfully.
+
+        This file respects the existence of other manifest files in the directory and operates additively.
+        If a manifest file is present from an earlier download, this function will read that manifest in,
+        and include the entire content of that manifest in addition to the manifests passed in.
+
+        The format of the manifest file has the following columns
+
+        object_id: The object ID from the original catalog
+        filename: The file name where the file can be found OR the string "Attempted" indicating the download
+            did not complete successfully.
+        tract: The HSC tract ID number. This either comes from the catalog or is the tract ID returned by the
+            cutout server for downloaded files.
+
+        ra: Right ascension in degrees of the center of the cutout box
+        dec: Declination in degrees of the center of the cutout box
+        filter: The name of the filter requested
+        sw: Semi-width of the cutout box in degrees
+        sh: Semi-height of the cutout box in degrees
+        rerun: The data release in use e.g. pdr3_wide
+        type: coadd, warp, or other values allowed by the HSC docs
+
+        Parameters
+        ----------
+        thread_manifests : list[dict[dC.Rect,str]]
+            Manifests mapping rects -> Filename or status message. Each manifest came from a separate thread.
+        file_path : Path
+            Full path to the location where the manifest file ought to be written. The manifest file will be
+            named manifest.fits
+        """
+        logger.info("Assembling download manifest")
+        # Start building a combined manifest from all threads from the ground truth of the prior manifest
+        # in this directory, which we will be overwriting.
+        combined_manifest = Downloader.read_manifest(file_path)
 
-        for rect in rects:
-            # Relies on the name format set up in create_rects to work properly
-            object_id = int(rect.name.split("_")[0])
+        # Combine all thread manifests with the prior manifest, so that the current status of a downloaded
+        # rect overwrites any status from the prior run (which is no longer relevant).
+        for manifest in thread_manifests:
+            combined_manifest.update(manifest)
 
-            if object_id not in self.seen_object_ids:
-                self.seen_object_ids.add(object_id)
+        logger.info(f"Writing out download manifest with {len(combined_manifest)} entries.")
 
-                self.object_id.append(object_id)
+        # Convert the combined manifest into an astropy table by building a dict of {column_name: column_data}
+        # for all the fields in a rect, plus our object_id and filename.
+        column_names = Downloader.RECT_COLUMN_NAMES + ["filename", "object_id"]
+        columns = {column_name: [] for column_name in column_names}
 
-                for key in variable_fields:
-                    self.__dict__[key].append(rect.__dict__[key])
+        for rect, msg in combined_manifest.items():
+            # This parsing relies on the name format set up in create_rects to work properly
+            # We parse the object_id from rect.name in case the filename is "Attempted" because the
+            # download did not finish.
+            rect_filename = Path(rect.name).name
+            object_id = int(rect_filename.split("_")[0])
+            columns["object_id"].append(object_id)
 
-                self.count += 1
-        logger.debug(f"Failed chunk handler processed {len(rects)} rects and is now of size {self.count}")
+            # Remove the leading path from the filename if any.
+            filename = Path(msg).name
+            columns["filename"].append(filename)
 
-    def save(self):
-        """
-        Saves the current set of failed locations to the path specified.
-        If no failed locations were saved by the hook, this function does nothing.
+            for key in Downloader.RECT_COLUMN_NAMES:
+                columns[key].append(rect.__dict__[key])
+
+        manifest_table = Table(columns)
+        manifest_table.write(file_path / Downloader.MANIFEST_FILE_NAME, overwrite=True, format="fits")
+
+        logger.info("Finished writing download manifest")
+
+    @staticmethod
+    def read_manifest(file_path: Path) -> dict[dC.Rect, str]:
+        """Read the manifest.fits file from the given directory and return its contents as a dictionary with
+        downloadCutout.Rectangles as keys and filenames as values.
+
+        If no manifest file is found, an empty dict is returned.
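+
+        The rects used as keys are rebuilt via `create_rects` with `RECT_COLUMN_NAMES`, so
+        `Rect.__eq__` and `Rect.__hash__` make them compare equal to rects generated for the
+        same sky locations in the current download attempt.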
+
+        Parameters
+        ----------
+        file_path : Path
+            Where to find the manifest file
+
+        Returns
+        -------
+        dict[dC.Rect, str]
+            A dictionary containing all the rects in the manifest and all the filenames, or empty dict if no
+            manifest is found.
         """
-        if self.count == 0:
-            return
+        filename = file_path / Downloader.MANIFEST_FILE_NAME
+        if filename.exists():
+            manifest_table = Table.read(filename, format="fits")
+            rects = Downloader.create_rects(
+                locations=manifest_table, fields=Downloader.RECT_COLUMN_NAMES, path=file_path
+            )
+            return {rect: fname for rect, fname in zip(rects, manifest_table["filename"])}
         else:
-            # convert our class-member-based representation to an astropy table.
-            for key in variable_fields + ["object_id"]:
-                self.__dict__[key] = np.array(self.__dict__[key])
-
-            missed = Table({key: self.__dict__[key] for key in variable_fields + ["object_id"]})
-
-            # note that the choice to do overwrite=True here and to read in the entire fits file in
-            # ___init__() is necessary because snapshots corresponding to the same object may cross
-            # chunk boundaries decided by dC.download.
-            #
-            # Since we are de-duplicating rects by object_id, we need to read in all rects from a prior
-            # run, and we therefore replace the file we were passed.
-            missed.write(self.filepath, overwrite=True, **self.format_kwargs)
-
-
-def download_cutout_group(rects: list[dC.Rect], cutout_dir: Union[str, Path], user, password, **kwargs):
-    """Download cutouts to the given directory
-
-    Calls downloadCutout.download, so supports long lists of rects beyond the limits of the HSC web API
-
-    Parameters
-    ----------
-    rects : list[dC.Rect]
-        The rects we would like to download
-    cutout_dir : Union[str, Path]
-        The directory to put the files
-    user : string
-        Username for HSC's download service to use
-    password : string
-        Password for HSC's download service to use
-    **kwargs: dict
-        Additonal arguments for downloadCutout.download. See downloadCutout.download for details
-    """
+            return {}
+
+    @staticmethod
+    def download_thread(
+        rects: list[dC.Rect],
+        user: str,
+        password: str,
+        stats: DownloadStats,
+        thread_num: int,
+        manifest: dict[dC.Rect, str],
+        **kwargs,
+    ):
+        """Download cutouts to the given directory. Called in its own thread with an id number.
+
+        Calls downloadCutout.download, so supports long lists of rects beyond the limits of the HSC web API
 
-    with working_directory(Path(cutout_dir)):
-        with (
-            DownloadStats() as stats_hook,
-            FailedChunkCollector(Path("failed_locations.fits"), format="fits") as failed_chunk_hook,
-        ):
+        Parameters
+        ----------
+        rects : list[dC.Rect]
+            The rects we would like to download
+        user : string
+            Username for HSC's download service to use
+        password : string
+            Password for HSC's download service to use
+        stats : DownloadStats
+            Instance of DownloadStats to use for stats tracking.
+        thread_num : int
+            The ID number of this thread, sequential from zero to num_threads-1
+        manifest : dict[dC.Rect, str]
+            A dictionary from dC.Rect to filename which we will fill in as we download rects. This is the
+            chief returned piece of data from each thread.
+        **kwargs: dict
+            Additional arguments for downloadCutout.download. See downloadCutout.download for details
+        """
+        logger.info(f"Thread {thread_num} starting download of {len(rects)} rects")
+        with stats as stats_hook:
             dC.download(
                 rects,
                 user=user,
                 password=password,
                 onmemory=False,
                 request_hook=stats_hook,
-                failed_chunk_hook=failed_chunk_hook,
-                resume=True,
+                manifest=manifest,
                 **kwargs,
             )
+
+    # TODO add error checking
+    @staticmethod
+    def filterfits(filename: Path, column_names: list[str]) -> Table:
+        """Read a fits file with the required column names for making cutouts
+
+        The easiest way to make such a fits file is to select from the main HSC catalog
+
+        Parameters
+        ----------
+        filename : Path
+            The fits file to read in
+        column_names : list[str]
+            The columns that are selected from the file and returned in the astropy Table.
+
+        Returns
+        -------
+        Table
+            Returns an astropy table containing only the fields specified in column_names
+        """
+        t = Table.read(filename)
+        columns = [t[column] for column in column_names]
+        return hstack(columns, uniq_col_name="{table_name}", table_names=column_names)
+
+    @staticmethod
+    def rect_from_config(config: dict) -> dC.Rect:
+        """Takes our runtime config and loads cutout config
+        common to all cutouts into a prototypical Rect for downloading
+
+        Parameters
+        ----------
+        config : dict
+            Runtime config, only the download section
+
+        Returns
+        -------
+        dC.Rect
+            A single rectangle with fields `sw`, `sh`, `filter`, `rerun`, and `type` populated from the config
+        """
+        return dC.Rect.create(
+            sw=config["sw"],
+            sh=config["sh"],
+            filter=config["filter"],
+            rerun=config["rerun"],
+            type=config["type"],
+        )
+
+    @staticmethod
+    def create_rects(
+        locations: Table,
+        path: Path,
+        offset: int = 0,
+        default: dC.Rect = None,
+        fields: Optional[list[str]] = None,
+    ) -> list[dC.Rect]:
+        """Create the rects we will need to pass to the downloader.
+        One Rect per location in our list of sky locations.
+
+        Rects are created with all fields in the default rect pre-filled
+
+        Offset here is to allow multiple downloads on different sections of the source list
+        without file name clobbering during the download phase. The offset is intended to be
+        the index of the start of the locations table within some larger fits file.
+
+        Parameters
+        ----------
+        locations : Table
+            Table containing ra, dec locations in the sky
+        path : Path
+            Directory where the cutout files ought to live. Used to generate file names on the rect object.
+        offset : int, optional
+            Index to start the `lineno` field in the rects at, by default 0. The purpose of this is to allow
+            multiple downloads on different sections of a larger source list without file name clobbering
+            during the download phase. This is important because `lineno` in a rect can become a file name
+            parameter. The offset is intended to be the index of the start of the locations table within some
+            larger fits file.
+        default : dC.Rect, optional
+            The default Rect that contains properties common to all sky locations, by default None
+
+        fields : list[str], optional
+            Default fields to pull from the locations table. If not provided, defaults to
+            ["tract", "ra", "dec"]
+
+        Returns
+        -------
+        list[dC.Rect]
+            Rects populated with sky locations from the table
+        """
+        rects = []
+        fields = fields if fields else Downloader.VARIABLE_FIELDS
+        for index, location in enumerate(locations):
+            args = {field: location[field] for field in fields}
+            args["lineno"] = index + offset
+            args["tract"] = str(args["tract"])
+            # Sets the file name on the rect to be the object_id, also includes other rect fields
+            # which are interpolated at save time, and are native fields of dc.Rect.
+            args["name"] = str(
+                path / f"{location['object_id']}_{{type}}_{{ra:.5f}}_{{dec:+.5f}}_{{tract}}_{{filter}}"
+            )
+            rect = dC.Rect.create(default=default, **args)
+            rects.append(rect)
+
+        # We sort rects here so they end up tract,ra,dec ordered across all requests made in all threads
+        # Threads do their own sorting prior to each chunked request in downloadCutout.py; however
+        # sorting at this stage will allow a greater number of rects that are co-located in the sky
+        # to end up in the same thread and same chunk.
+        rects.sort()
+
+        return rects
diff --git a/src/fibad/downloadCutout/downloadCutout.py b/src/fibad/downloadCutout/downloadCutout.py
index d20a46d..a0c1706 100644
--- a/src/fibad/downloadCutout/downloadCutout.py
+++ b/src/fibad/downloadCutout/downloadCutout.py
@@ -7,9 +7,7 @@
 import datetime
 import errno
 import getpass
-import hashlib
 import io
-import json
 import logging
 import math
 import os
@@ -21,11 +19,8 @@
 import urllib.request
 import urllib.response
 from collections.abc import Generator
-from pathlib import Path
 from typing import IO, Any, Callable, Optional, Union, cast
 
-import toml
-
 __all__ = []
 
@@ -477,14 +472,40 @@ def explode(self) -> list["Rect"]:
         else:
             return [Rect.create(default=self)]
 
+    # Static field list used by __eq__ and __hash__
+    immutable_fields = ["ra", "dec", "sw", "sh", "filter", "type", "rerun", "image", "variance", "mask"]
+
+    def __eq__(self, obj) -> bool:
+        """Define equality on Rects by sky location, size, filter, type, rerun, and image/mask/variance
+        state. This allows rects to be used as keys in dictionaries while ignoring transient fields such as
+        lineno, or fields that may be incorrect/changed during the download process like tract or name.
+
+        This is a compromise between
+        1) Dataclass's unsafe_hash=True which would hash all fields
+        and
+        2) Making the dataclass frozen which would affect some of the mutability used to alter
+        lineno, tract, and name
+
+        Note that this makes equality on Rects mean "the cutout API should return the same data",
+        rather than "literally all data members are the same"
+
+        Parameters
+        ----------
+        obj : Rect
+            The rect to compare to self
+
+        Returns
+        -------
+        bool
+            True if the Rects are equal
+        """
+        return all([self.__dict__[field] == obj.__dict__[field] for field in Rect.immutable_fields])
 
-class RectEncoder(json.JSONEncoder):
-    # TODO this needs to be implemented on a subclass of JSONEncoder
-    # And it needs to do something very particular in order to work.
-    def default(self, obj):
-        if isinstance(obj, Rect):
-            return obj.__dict__
-        return json.JSONEncoder.default(self, obj)
+    def __hash__(self):
+        """Define a hash function on Rects. Outside of hash collisions, this function attempts to have the
+        same semantics as Rect.__eq__(). Look at Rect.__eq__() for further details.
+        """
+        return hash(tuple([self.__dict__[field] for field in Rect.immutable_fields]))
 
 
 @export
@@ -1008,9 +1029,6 @@ def download(
     Some important (but entirely optional!)
    keyword args processed later in the download callstack are listed below.
    Anything urllib.request.urlopen will accept is fair game too!
 
-    resume : bool
-        Whether to attempt to resume an ongoing download from filesystem data in onmemory=False mode.
-        Default: False. See _download() for greater detail.
     chunksize : int
         The number of rects to include in a single http request. Default 990 rects.
         See _download() for greater detail.
@@ -1061,10 +1079,8 @@ def _download(
     *,
     onmemory: bool,
     chunksize: int = 990,
-    resume: bool = False,
     retries: int = 3,
     retrywait: int = 30,
-    failed_chunk_hook: Optional[Callable[[list[Rect], Exception, int], Any]] = None,
     **kwargs__download_chunk,
 ) -> Optional[list[list]]:
     """
@@ -1089,26 +1105,12 @@
        If `onmemory` is False, downloaded cut-outs are written to files in the current working directory.
     chunksize: int, optional
         Number of cutout lines to pack into a single request. Defaults to 990 if unspecified.
-    resume: bool, optional
-        When `onmemory == True`, uses resume data in the current working directory continue a failed download.
-        Noop when onmemory=False. Defaults to False if unspecified.
-
-        Passing resume=True is safe when no resume data exists.
-        _download() will simply start downloading from the beginning of rects.
     retries: int, optional
         Number of attempts to make to fetch each chunk. Defaults to 3 if unspecified.
     retrywait: int, optional
         Base number of seconds to wait between retries. Retry waits are computed using an exponential
        backoff where the retry time for attempts is calculated as retrywait * (2 ** attempt) seconds,
        with attempt=0 for the first wait.
-
-    failed_chunk_hook: Callable[[list[Rect], Exception, int], Any]
-        Hook which is called every time a chunk fails `retries` time. The arguments to the hook are
-        the rects in the failed chunk, the exception encountered while making the last request, and
-        the number of attempts.
-
-        If this function raises, the entire download stops, but otherwise the download will ocntinue
-
     kwargs__download_chunk: dict, optional
         Additional keyword args are passed through to _download_chunk()
 
@@ -1148,174 +1150,56 @@ def _download(
 
     datalist: list[tuple[int, dict, bytes]] = []
 
-    failed_rect_index = 0
-
-    start_rect_index = 0
-    if not onmemory and resume:
-        start_rect_index = _read_resume_data(exploded_rects)
+    # Chunk loop
+    for i in range(0, len(exploded_rects), chunksize):
+        # Retry loop
+        for attempt in range(0, retries):
+            try:
+                ret = _download_chunk(
+                    exploded_rects[i : i + chunksize],
+                    user,
+                    password,
+                    onmemory=onmemory,
+                    **kwargs__download_chunk,
+                )
+                break
+            except KeyboardInterrupt:
+                logger.critical("Keyboard Interrupt received.")
+                raise
+            except Exception as exception:
+                # Humans count attempts from 1, this loop counts from zero.
+                logger.warning(
+                    f"Attempt {attempt + 1} of {retries} to request rects [{i}:{i+chunksize}] has error:"
+                )
+                logger.warning(exception)
 
-    try:
-        # Chunk loop
-        for i in range(start_rect_index, len(exploded_rects), chunksize):
-            # Retry loop
-            for attempt in range(0, retries):
-                try:
-                    ret = _download_chunk(
-                        exploded_rects[i : i + chunksize],
-                        user,
-                        password,
-                        onmemory=onmemory,
-                        **kwargs__download_chunk,
-                    )
-                    break
-                except KeyboardInterrupt:
-                    logger.critical("Keyboard Interrupt recieved.")
-                    failed_rect_index = i
-                    raise
-                except Exception as exception:
-                    # Humans count attempts from 1, this loop counts from zero.
- logger.warning( - f"Attempt {attempt + 1} of {retries} to request rects [{i}:{i+chunksize}] has error:" - ) - logger.warning(exception) - - # If the final attempt on this chunk fails, we try to call the failed_chunk_hook - if attempt + 1 == retries: - if failed_chunk_hook is not None: - rect_chunk = [rect for rect, idx in exploded_rects[i : i + chunksize]] - failed_chunk_hook(rects=rect_chunk, exception=exception, attempts=retries) - # If no hook provided, or if the provided hook doesn't raise, we continue the download - break - # Otherwise do exponential backoff and try again - else: - backoff = retrywait * (2**attempt) - if backoff != 0: - logger.info(f"Retrying in {backoff} seconds... ") - time.sleep(backoff) - logger.info("Retrying now") - continue - if onmemory: - datalist += cast(list, ret) - - # Retries have failed or we are being killed - except (Exception, KeyboardInterrupt): - # Write out resume data if we're saving to filesystem and there's been any progress - if (not onmemory) and failed_rect_index != 0: - _write_resume_data(exploded_rects, failed_rect_index) - - # Reraise so exception can reach top level, very important for KeyboardInterrupt - raise + # Otherwise wait for exponential backoff and try again + else: + backoff = retrywait * (2**attempt) + if backoff != 0: + logger.info(f"Retrying in {backoff} seconds... ") + time.sleep(backoff) + logger.info("Retrying now") + continue + if onmemory: + datalist += cast(list, ret) if onmemory: returnedlist: list[list[tuple[dict, bytes]]] = [[] for i in range(len(rects))] for index, metadata, data in datalist: returnedlist[index].append((metadata, data)) - # On success we remove resume data - if not onmemory and resume and os.path.exists(resume_data_filename): - os.remove(resume_data_filename) - return returnedlist if onmemory else None -# TODO multiple connections resume data will need to be instanced by connection -# That will require some interface so the connection number can make it here -resume_data_filename = "resume_download.toml" - - -def _read_resume_data(rects: list[Rect]) -> int: - """Read the resume data from the current working directory - - Parameters - ---------- - rects : list[Rect] - List of rects we intend to process, needed for checksum to ensure the download we are resuming - is the same one that output resume data. - - Returns - ------- - Returns an integer specifying what index in the rect list the resumeing download should start. - If no resume data is found, 0 is returned. - - Raises - ------ - RuntimeError - "No resume data found in " when the resume file could not be found in cwd. - RuntimeError - "Resume data in corrupt" when the file is not a toml file containing keys - 'checksum' and 'start_rect_index' - RuntimeError - "Resume data failed checksum ..." when the rect list has changed from when the resume data file was - written - """ - # Load resume data so we start at the appropriate chunk. - if not os.path.exists(resume_data_filename): - return 0 - - logger.info(f"Resuming failed download from {Path.cwd() / resume_data_filename}") - with open(resume_data_filename, "r") as f: - resumedata = toml.load(f) - if "start_rect_index" not in resumedata or "checksum" not in resumedata: - raise RuntimeError(f"Resume data in {Path.cwd() / resume_data_filename} corrupt.") - - start_rect_index = resumedata["start_rect_index"] - - checksum = _calc_rect_list_checksum(rects[0:start_rect_index]) - if resumedata["checksum"] != checksum: - message = f"""Resume data failed checksum. 
- Has the list of sky locations changed? If so, remove {Path.cwd() / resume_data_filename}"""
-            raise RuntimeError(message)
-
-    return start_rect_index
-
-
-def _write_resume_data(rects: list[Rect], failed_rect_index: int) -> None:
-    """Write resume data
-
-    Parameters
-    ----------
-    rects : list[Rect]
-        List of Rects we were intending to download, needed to write the checksum into the resume data
-    failed_rect_index : int
-        The index of the beginning of the first chunk of rects to fail.
-    """
-    logger.info("Writing resume data")
-    # Output enough information that we can retry/resume assuming same dir but,
-    # whatever was DL'ed in current chunk is corrupt
-    resumedata = {
-        "start_rect_index": failed_rect_index,
-        "checksum": _calc_rect_list_checksum(rects[0:failed_rect_index]),
-    }
-    with open(resume_data_filename, mode="w") as f:
-        toml.dump(resumedata, f)
-    logger.info("Done writing resume data")
-
-
-def _calc_rect_list_checksum(rects: list[Rect]) -> str:
-    """
-    Calculate a sha256 checksum of a list of Rects for the purpose of identifying tha list in the context of
-    a resumed download
-
-    The method is to dump the list of Rects to JSON and sha256 the JSON.
-
-    Parameters
-    ----------
-    rects : list[Rect]
-        List of rects that we will checksum
-
-    Returns
-    -------
-    str
-        Sha256 hex digest of the list of rects.
-    """
-    byte_string = json.dumps(rects, sort_keys=True, cls=RectEncoder).encode("utf-8")
-    return hashlib.sha256(byte_string).hexdigest()
-
-
 def _download_chunk(
     rects: list[tuple[Rect, Any]],
     user: str,
     password: str,
+    manifest: Optional[dict[Rect, str]] = None,
     *,
     onmemory: bool,
     request_hook: Optional[
@@ -1338,6 +1222,9 @@ def _download_chunk(
         Username.
     password
         Password.
+    manifest
+        Dictionary from Rect to filename. If provided, this function will fill it in as it downloads.
+        If download of a file fails, the file's entry will read "Attempted" for the filename.
     onmemory
         Return `datalist` on memory.
         If `onmemory` is False, downloaded cut-outs are written to files.
@@ -1388,6 +1275,11 @@ def _download_chunk(
     # Set timeout to 1 hour if no timeout was set higher up
     kwargs_urlopen.setdefault("timeout", 3600)
 
+    # Set all manifest entries to indicate an attempt was made.
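+    # A successful download overwrites the placeholder with the real filename further below,
+    # so anything still reading "Attempted" afterward marks a file that failed to download.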
+ if manifest is not None: + for rect, _ in rects: + manifest[rect] = "Attempted" + with get_connection_semaphore(): request_started = datetime.datetime.now() with urllib.request.urlopen(req, **kwargs_urlopen) as fin: @@ -1420,6 +1312,8 @@ def _download_chunk( os.makedirs(dirname, exist_ok=True) with open(filename, "wb") as fout: _splice(fitem, fout) + if manifest is not None: + manifest[rect] = filename if request_hook: request_hook(req, request_started, response_started, response_size, len(rects)) diff --git a/src/fibad/fibad.py b/src/fibad/fibad.py index 4635634..07389fd 100644 --- a/src/fibad/fibad.py +++ b/src/fibad/fibad.py @@ -147,9 +147,9 @@ def download(self, **kwargs): """ See Fibad.download.run() """ - from .download import run + from .download import Downloader - return run(config=self.config, **kwargs) + return Downloader.run(config=self.config, **kwargs) def predict(self, **kwargs): """ diff --git a/src/fibad/fibad_default_config.toml b/src/fibad/fibad_default_config.toml index d1da052..ec68156 100644 --- a/src/fibad/fibad_default_config.toml +++ b/src/fibad/fibad_default_config.toml @@ -22,11 +22,10 @@ sh = "22asec" filter = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"] type = "coadd" rerun = "pdr3_wide" -max_connections = 2 +concurrent_connections = 1 +stats_print_interval = 30 fits_file = "./catalog.fits" cutout_dir = "./data" -offset = 0 -num_sources = 500 # These control the downloader's HTTP requests and retries # `retry_wait` How long to wait before retrying a failed HTTP request in seconds. Default 30s