Code cleanup (#5234)
Co-authored-by: keewis <keewis@users.noreply.github.com>
Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
Co-authored-by: Stephan Hoyer <shoyer@google.com>
4 people authored May 13, 2021
1 parent 4067c01 commit 1a7b285
Showing 48 changed files with 377 additions and 514 deletions.
45 changes: 19 additions & 26 deletions xarray/backends/api.py
@@ -102,12 +102,11 @@ def _get_default_engine_netcdf():
 
 def _get_default_engine(path: str, allow_remote: bool = False):
     if allow_remote and is_remote_uri(path):
-        engine = _get_default_engine_remote_uri()
+        return _get_default_engine_remote_uri()
     elif path.endswith(".gz"):
-        engine = _get_default_engine_gz()
+        return _get_default_engine_gz()
     else:
-        engine = _get_default_engine_netcdf()
-    return engine
+        return _get_default_engine_netcdf()
 
 
 def _validate_dataset_names(dataset):
@@ -282,7 +281,7 @@ def _chunk_ds(
 
     mtime = _get_mtime(filename_or_obj)
     token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
-    name_prefix = "open_dataset-%s" % token
+    name_prefix = f"open_dataset-{token}"
 
     variables = {}
     for name, var in backend_ds.variables.items():
@@ -295,8 +294,7 @@
             name_prefix=name_prefix,
             token=token,
         )
-    ds = backend_ds._replace(variables)
-    return ds
+    return backend_ds._replace(variables)
 
 
 def _dataset_from_backend_dataset(
@@ -308,12 +306,10 @@ def _dataset_from_backend_dataset(
     overwrite_encoded_chunks,
     **extra_tokens,
 ):
-    if not (isinstance(chunks, (int, dict)) or chunks is None):
-        if chunks != "auto":
-            raise ValueError(
-                "chunks must be an int, dict, 'auto', or None. "
-                "Instead found %s. " % chunks
-            )
+    if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}:
+        raise ValueError(
+            f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
+        )
 
     _protect_dataset_variables_inplace(backend_ds, cache)
     if chunks is None:
@@ -331,9 +327,8 @@
     ds.set_close(backend_ds._close)
 
     # Ensure source filename always stored in dataset object (GH issue #2550)
-    if "source" not in ds.encoding:
-        if isinstance(filename_or_obj, str):
-            ds.encoding["source"] = filename_or_obj
+    if "source" not in ds.encoding and isinstance(filename_or_obj, str):
+        ds.encoding["source"] = filename_or_obj
 
     return ds
 
@@ -515,7 +510,6 @@ def open_dataset(
         **decoders,
         **kwargs,
     )
-
     return ds
 
 
@@ -1015,8 +1009,8 @@ def to_netcdf(
         elif engine != "scipy":
             raise ValueError(
                 "invalid engine for creating bytes with "
-                "to_netcdf: %r. Only the default engine "
-                "or engine='scipy' is supported" % engine
+                f"to_netcdf: {engine!r}. Only the default engine "
+                "or engine='scipy' is supported"
             )
         if not compute:
             raise NotImplementedError(
@@ -1037,7 +1031,7 @@
     try:
         store_open = WRITEABLE_STORES[engine]
     except KeyError:
-        raise ValueError("unrecognized engine for to_netcdf: %r" % engine)
+        raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}")
 
     if format is not None:
         format = format.upper()
@@ -1049,9 +1043,8 @@
     autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
     if autoclose and engine == "scipy":
         raise NotImplementedError(
-            "Writing netCDF files with the %s backend "
-            "is not currently supported with dask's %s "
-            "scheduler" % (engine, scheduler)
+            f"Writing netCDF files with the {engine} backend "
+            f"is not currently supported with dask's {scheduler} scheduler"
        )
 
     target = path_or_file if path_or_file is not None else BytesIO()
@@ -1061,7 +1054,7 @@
             kwargs["invalid_netcdf"] = invalid_netcdf
         else:
             raise ValueError(
-                "unrecognized option 'invalid_netcdf' for engine %s" % engine
+                f"unrecognized option 'invalid_netcdf' for engine {engine}"
             )
     store = store_open(target, mode, format, group, **kwargs)
 
@@ -1203,7 +1196,7 @@ def save_mfdataset(
     Data variables:
         a        (time) float64 0.0 0.02128 0.04255 0.06383 ... 0.9574 0.9787 1.0
     >>> years, datasets = zip(*ds.groupby("time.year"))
-    >>> paths = ["%s.nc" % y for y in years]
+    >>> paths = [f"{y}.nc" for y in years]
     >>> xr.save_mfdataset(datasets, paths)
     """
     if mode == "w" and len(set(paths)) < len(paths):
@@ -1215,7 +1208,7 @@
         if not isinstance(obj, Dataset):
             raise TypeError(
                 "save_mfdataset only supports writing Dataset "
-                "objects, received type %s" % type(obj)
+                f"objects, received type {type(obj)}"
             )
 
     if groups is None:
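A side note on the dominant pattern in api.py: %-interpolation is replaced by f-strings throughout. A minimal sketch of the equivalence, using a hypothetical engine value rather than anything from this diff:

    engine = "scipy"
    old = "unrecognized engine for to_netcdf: %r" % engine
    new = f"unrecognized engine for to_netcdf: {engine!r}"
    # the !r conversion applies repr(), exactly matching %r
    assert old == new == "unrecognized engine for to_netcdf: 'scipy'"
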
3 changes: 1 addition & 2 deletions xarray/backends/cfgrib_.py
@@ -90,8 +90,7 @@ def get_dimensions(self):
 
     def get_encoding(self):
         dims = self.get_dimensions()
-        encoding = {"unlimited_dims": {k for k, v in dims.items() if v is None}}
-        return encoding
+        return {"unlimited_dims": {k for k, v in dims.items() if v is None}}
 
 
 class CfgribfBackendEntrypoint(BackendEntrypoint):
7 changes: 3 additions & 4 deletions xarray/backends/common.py
@@ -69,9 +69,8 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
             base_delay = initial_delay * 2 ** n
             next_delay = base_delay + np.random.randint(base_delay)
             msg = (
-                "getitem failed, waiting %s ms before trying again "
-                "(%s tries remaining). Full traceback: %s"
-                % (next_delay, max_retries - n, traceback.format_exc())
+                f"getitem failed, waiting {next_delay} ms before trying again "
+                f"({max_retries - n} tries remaining). Full traceback: {traceback.format_exc()}"
            )
            logger.debug(msg)
            time.sleep(1e-3 * next_delay)
@@ -336,7 +335,7 @@ def set_dimensions(self, variables, unlimited_dims=None):
             if dim in existing_dims and length != existing_dims[dim]:
                 raise ValueError(
                     "Unable to update size for existing dimension"
-                    "%r (%d != %d)" % (dim, length, existing_dims[dim])
+                    f"{dim!r} ({length} != {existing_dims[dim]})"
                 )
             elif dim not in existing_dims:
                 is_unlimited = dim in unlimited_dims
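The robust_getitem hunk reformats the log message but keeps the retry arithmetic. A standalone sketch of that exponential-backoff schedule, using the defaults from the signature above (initial_delay=500 ms, max_retries=6); the print is purely illustrative:

    import numpy as np

    initial_delay = 500  # milliseconds, per the signature above
    max_retries = 6
    for n in range(max_retries):
        base_delay = initial_delay * 2 ** n  # doubles on every attempt
        next_delay = base_delay + np.random.randint(base_delay)  # add jitter
        print(f"attempt {n}: sleep {next_delay} ms "
              f"(range {base_delay}..{2 * base_delay - 1})")
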
15 changes: 6 additions & 9 deletions xarray/backends/h5netcdf_.py
@@ -37,8 +37,7 @@
 class H5NetCDFArrayWrapper(BaseNetCDF4Array):
     def get_array(self, needs_lock=True):
         ds = self.datastore._acquire(needs_lock)
-        variable = ds.variables[self.variable_name]
-        return variable
+        return ds.variables[self.variable_name]
 
     def __getitem__(self, key):
         return indexing.explicit_indexing_adapter(
@@ -102,7 +101,7 @@ def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=Fal
         if group is None:
             root, group = find_root_and_group(manager)
         else:
-            if not type(manager) is h5netcdf.File:
+            if type(manager) is not h5netcdf.File:
                 raise ValueError(
                     "must supply a h5netcdf.File if the group "
                     "argument is provided"
@@ -233,11 +232,9 @@ def get_dimensions(self):
         return self.ds.dimensions
 
     def get_encoding(self):
-        encoding = {}
-        encoding["unlimited_dims"] = {
-            k for k, v in self.ds.dimensions.items() if v is None
-        }
-        return encoding
+        return {
+            "unlimited_dims": {k for k, v in self.ds.dimensions.items() if v is None}
+        }
 
     def set_dimension(self, name, length, is_unlimited=False):
         if is_unlimited:
@@ -266,9 +263,9 @@ def prepare_variable(
                 "h5netcdf does not yet support setting a fill value for "
                 "variable-length strings "
                 "(https://github.com/shoyer/h5netcdf/issues/37). "
-                "Either remove '_FillValue' from encoding on variable %r "
+                f"Either remove '_FillValue' from encoding on variable {name!r} "
                 "or set {'dtype': 'S1'} in encoding to use the fixed width "
-                "NC_CHAR type." % name
+                "NC_CHAR type."
             )
 
         if dtype is str:
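Both this file and netCDF4_.py below rewrite "not type(m) is X" as "type(m) is not X". The two spellings are equivalent ("is not" is a single operator, and "not x is y" negates the same identity test), so only readability changes; a sketch with a plain dict as a hypothetical stand-in for the manager object:

    manager = {}  # hypothetical stand-in for the file manager
    assert (type(manager) is not list) == (not type(manager) is list)

    # unlike isinstance, the exact-type check deliberately rejects subclasses
    class MyDict(dict):
        pass

    assert type(MyDict()) is not dict
    assert isinstance(MyDict(), dict)
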
2 changes: 1 addition & 1 deletion xarray/backends/locks.py
@@ -167,7 +167,7 @@ def locked(self):
         return any(lock.locked for lock in self.locks)
 
     def __repr__(self):
-        return "CombinedLock(%r)" % list(self.locks)
+        return f"CombinedLock({list(self.locks)!r})"
 
 
 class DummyLock:
63 changes: 28 additions & 35 deletions xarray/backends/netCDF4_.py
@@ -122,25 +122,23 @@ def _encode_nc4_variable(var):
 def _check_encoding_dtype_is_vlen_string(dtype):
     if dtype is not str:
         raise AssertionError(  # pragma: no cover
-            "unexpected dtype encoding %r. This shouldn't happen: please "
-            "file a bug report at github.com/pydata/xarray" % dtype
+            f"unexpected dtype encoding {dtype!r}. This shouldn't happen: please "
+            "file a bug report at github.com/pydata/xarray"
         )
 
 
 def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False):
     if nc_format == "NETCDF4":
-        datatype = _nc4_dtype(var)
-    else:
-        if "dtype" in var.encoding:
-            encoded_dtype = var.encoding["dtype"]
-            _check_encoding_dtype_is_vlen_string(encoded_dtype)
-            if raise_on_invalid_encoding:
-                raise ValueError(
-                    "encoding dtype=str for vlen strings is only supported "
-                    "with format='NETCDF4'."
-                )
-        datatype = var.dtype
-    return datatype
+        return _nc4_dtype(var)
+    if "dtype" in var.encoding:
+        encoded_dtype = var.encoding["dtype"]
+        _check_encoding_dtype_is_vlen_string(encoded_dtype)
+        if raise_on_invalid_encoding:
+            raise ValueError(
+                "encoding dtype=str for vlen strings is only supported "
+                "with format='NETCDF4'."
+            )
+    return var.dtype
 
 
 def _nc4_dtype(var):
@@ -178,7 +176,7 @@ def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group):
                     ds = create_group(ds, key)
                 else:
                     # wrap error to provide slightly more helpful message
-                    raise OSError("group not found: %s" % key, e)
+                    raise OSError(f"group not found: {key}", e)
     return ds
 
 
@@ -203,7 +201,7 @@ def _force_native_endianness(var):
     # if endian exists, remove it from the encoding.
     var.encoding.pop("endian", None)
     # check to see if encoding has a value for endian its 'native'
-    if not var.encoding.get("endian", "native") == "native":
+    if var.encoding.get("endian", "native") != "native":
         raise NotImplementedError(
             "Attempt to write non-native endian type, "
             "this is not supported by the netCDF4 "
@@ -270,8 +268,8 @@ def _extract_nc4_variable_encoding(
     invalid = [k for k in encoding if k not in valid_encodings]
     if invalid:
         raise ValueError(
-            "unexpected encoding parameters for %r backend: %r. Valid "
-            "encodings are: %r" % (backend, invalid, valid_encodings)
+            f"unexpected encoding parameters for {backend!r} backend: {invalid!r}. Valid "
+            f"encodings are: {valid_encodings!r}"
        )
     else:
         for k in list(encoding):
@@ -282,10 +280,8 @@
 
 
 def _is_list_of_strings(value):
-    if np.asarray(value).dtype.kind in ["U", "S"] and np.asarray(value).size > 1:
-        return True
-    else:
-        return False
+    arr = np.asarray(value)
+    return arr.dtype.kind in ["U", "S"] and arr.size > 1
 
 
 class NetCDF4DataStore(WritableCFDataStore):
@@ -313,7 +309,7 @@ def __init__(
         if group is None:
             root, group = find_root_and_group(manager)
         else:
-            if not type(manager) is netCDF4.Dataset:
+            if type(manager) is not netCDF4.Dataset:
                 raise ValueError(
                     "must supply a root netCDF4.Dataset if the group "
                     "argument is provided"
@@ -417,25 +413,22 @@ def open_store_variable(self, name, var):
         return Variable(dimensions, data, attributes, encoding)
 
     def get_variables(self):
-        dsvars = FrozenDict(
+        return FrozenDict(
             (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items()
         )
-        return dsvars
 
     def get_attrs(self):
-        attrs = FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs())
-        return attrs
+        return FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs())
 
     def get_dimensions(self):
-        dims = FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items())
-        return dims
+        return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items())
 
     def get_encoding(self):
-        encoding = {}
-        encoding["unlimited_dims"] = {
-            k for k, v in self.ds.dimensions.items() if v.isunlimited()
-        }
-        return encoding
+        return {
+            "unlimited_dims": {
+                k for k, v in self.ds.dimensions.items() if v.isunlimited()
+            }
+        }
 
     def set_dimension(self, name, length, is_unlimited=False):
         dim_length = length if not is_unlimited else None
Expand Down Expand Up @@ -473,9 +466,9 @@ def prepare_variable(
"netCDF4 does not yet support setting a fill value for "
"variable-length strings "
"(https://github.com/Unidata/netcdf4-python/issues/730). "
"Either remove '_FillValue' from encoding on variable %r "
f"Either remove '_FillValue' from encoding on variable {name!r} "
"or set {'dtype': 'S1'} in encoding to use the fixed width "
"NC_CHAR type." % name
"NC_CHAR type."
)

encoding = _extract_nc4_variable_encoding(
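The _is_list_of_strings rewrite above also stops converting the value to an array twice. A sketch of the simplified predicate with a single np.asarray call, plus a few illustrative checks (assuming NumPy's usual dtype kinds, "U" for unicode and "S" for bytes):

    import numpy as np

    def _is_list_of_strings(value):
        arr = np.asarray(value)  # convert once, reuse for dtype and size
        return arr.dtype.kind in ["U", "S"] and arr.size > 1

    assert _is_list_of_strings(["a", "b"])  # unicode array of size 2
    assert not _is_list_of_strings("abc")   # scalar string, size 1
    assert not _is_list_of_strings([1, 2])  # integer dtype
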
2 changes: 0 additions & 2 deletions xarray/backends/netcdf3.py
@@ -125,8 +125,6 @@ def is_valid_nc3_name(s):
     """
     if not isinstance(s, str):
         return False
-    if not isinstance(s, str):
-        s = s.decode("utf-8")
     num_bytes = len(s.encode("utf-8"))
     return (
         (unicodedata.normalize("NFC", s) == s)
5 changes: 1 addition & 4 deletions xarray/backends/plugins.py
@@ -87,10 +87,7 @@ def build_engines(pkg_entrypoints):
     backend_entrypoints.update(external_backend_entrypoints)
     backend_entrypoints = sort_backends(backend_entrypoints)
     set_missing_parameters(backend_entrypoints)
-    engines = {}
-    for name, backend in backend_entrypoints.items():
-        engines[name] = backend()
-    return engines
+    return {name: backend() for name, backend in backend_entrypoints.items()}
 
 
 @functools.lru_cache(maxsize=1)
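The build_engines hunk collapses a dict-building loop into a comprehension; both construct the same mapping of instantiated backends. A sketch with toy callables standing in for real entrypoint classes (the names here are hypothetical):

    backend_entrypoints = {"netcdf4": dict, "zarr": list}  # toy entrypoints

    engines = {}
    for name, backend in backend_entrypoints.items():
        engines[name] = backend()  # instantiate each entrypoint

    assert engines == {
        name: backend() for name, backend in backend_entrypoints.items()
    }
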
2 changes: 1 addition & 1 deletion xarray/backends/pydap_.py
@@ -47,7 +47,7 @@ def _getitem(self, key):
         result = robust_getitem(array, key, catch=ValueError)
         # in some cases, pydap doesn't squeeze axes automatically like numpy
         axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types))
-        if result.ndim + len(axis) != array.ndim and len(axis) > 0:
+        if result.ndim + len(axis) != array.ndim and axis:
             result = np.squeeze(result, axis)
 
         return result
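The pydap hunk leans on a tuple being falsy exactly when it is empty, so the bare "axis" test agrees with "len(axis) > 0" for every possible key; a quick check:

    for axis in [(), (0,), (0, 2)]:
        assert bool(axis) == (len(axis) > 0)
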
2 changes: 1 addition & 1 deletion xarray/backends/rasterio_.py
@@ -389,7 +389,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc
             # the filename is probably an s3 bucket rather than a regular file
             mtime = None
         token = tokenize(filename, mtime, chunks)
-        name_prefix = "open_rasterio-%s" % token
+        name_prefix = f"open_rasterio-{token}"
         result = result.chunk(chunks, name_prefix=name_prefix, token=token)
 
     # Make the file closeable
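The open_rasterio hunk builds a dask task-name prefix from a deterministic hash. A sketch of how dask.base.tokenize behaves (assumes dask is installed; the file name and chunk sizes here are hypothetical):

    from dask.base import tokenize

    token = tokenize("example.tif", None, {"x": 256})
    name_prefix = f"open_rasterio-{token}"
    # equal inputs hash to equal tokens, so graph keys stay stable
    assert token == tokenize("example.tif", None, {"x": 256})
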
