From a87dafc4f4a27e658f08e3ef1abe974a4d8e2a6d Mon Sep 17 00:00:00 2001 From: Lars Gebraad Date: Mon, 31 Jan 2022 21:42:42 +0100 Subject: [PATCH] Cleaned up the files --- notebooks/Tutorial on using simpleSVGD.ipynb | 23 +- src/simpleSVGD/__init__.py | 47 +--- src/simpleSVGD/_version.py | 157 +++++++---- src/simpleSVGD/kernels.py | 21 ++ versioneer.py | 280 +++++++++++-------- 5 files changed, 300 insertions(+), 228 deletions(-) create mode 100644 src/simpleSVGD/kernels.py diff --git a/notebooks/Tutorial on using simpleSVGD.ipynb b/notebooks/Tutorial on using simpleSVGD.ipynb index 24ee3a5..b3c90a7 100644 --- a/notebooks/Tutorial on using simpleSVGD.ipynb +++ b/notebooks/Tutorial on using simpleSVGD.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "bdf3a8dd", "metadata": {}, "outputs": [], @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 2, "id": "e9a71a8f", "metadata": {}, "outputs": [], @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 3, "id": "86d380f7", "metadata": {}, "outputs": [], @@ -120,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 4, "id": "bf6126cd", "metadata": {}, "outputs": [ @@ -1092,7 +1092,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -1104,7 +1104,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f715ee898ebe4b16b2f31c34fbb5ba41", + "model_id": "ac6bd9d8d69440beb4efca369a548f70", "version_major": 2, "version_minor": 0 }, @@ -1131,10 +1131,7 @@ " Himmelblau_grad, \n", " n_iter=130,\n", " # AdaGrad parameters\n", - " stepsize=1e-1,\n", - " alpha=0.9,\n", - " fudge_factor=1e-3,\n", - " historical_grad=1,\n", + " stepsize=1e0,\n", " animate=True,\n", " background=background,\n", " figure=figure,\n", @@ -1151,7 +1148,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 5, "id": "5ac5e7d1", "metadata": { "scrolled": false @@ -2125,7 +2122,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -2195,7 +2192,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.11" + "version": "3.8.12" } }, "nbformat": 4, diff --git a/src/simpleSVGD/__init__.py b/src/simpleSVGD/__init__.py index 2edeec1..d96197a 100644 --- a/src/simpleSVGD/__init__.py +++ b/src/simpleSVGD/__init__.py @@ -1,29 +1,16 @@ -from enum import auto import numpy as _numpy import tqdm.auto as _tqdm_auto -from scipy.spatial.distance import pdist as _pdist, squareform as _squareform import matplotlib.pyplot as _plt import matplotlib.figure as _figure from typing import Callable as _Callable, List as _List, Tuple as _Tuple +from .kernels import rbf_kernel as _rbf_kernel +from .helpers import TorchWrapper as _TorchWrapper -def _rbf_kernel(theta, h=-1): - """Radial basis function kernel.""" - sq_dist = _pdist(theta) - pairwise_dists = _squareform(sq_dist) ** 2 - if h < 0: # if h < 0, using median trick - h = _numpy.median(pairwise_dists) - h = _numpy.sqrt(0.5 * h / _numpy.log(theta.shape[0] + 1)) - # compute the rbf kernel - Kxy = _numpy.exp(-pairwise_dists / h ** 2 / 2) +from . import _version - dxkxy = -_numpy.matmul(Kxy, theta) - sumkxy = _numpy.sum(Kxy, axis=1) - for i in range(theta.shape[1]): - dxkxy[:, i] = dxkxy[:, i] + _numpy.multiply(theta[:, i], sumkxy) - dxkxy = dxkxy / (h ** 2) - return (Kxy, dxkxy) +__version__ = _version.get_versions()["version"] def update( @@ -33,9 +20,6 @@ def update( n_iter: int = 1000, stepsize: float = 1e-3, bandwidth: float = -1, - alpha: float = 0.9, - fudge_factor=1e-3, - historical_grad=1, # All following parameter only concern animation animate: bool = False, figure: _figure.Figure = None, @@ -68,10 +52,6 @@ def update( more likely to produce good results, but will slow the algorithm down. Default is -1. - alpha - Parameter with which to dampen gradient changes in the target function - during SVGD updating. Default is 0.9. - animate A boolean to animate the algorithm. Only works for functions of at least two dimensions. Default is False. @@ -127,7 +107,7 @@ def update( axis.set_aspect(1) figure.canvas.draw() - _plt.pause(0.00001) + _plt.pause(1e-5) # The Try/Except allows on to interrupt the algorithm using CTRL+C while # still getting x0_updated at the point of interruption. @@ -143,17 +123,7 @@ def update( 0 ] - # adagrad - if iter == 0: - historical_grad = historical_grad + grad_theta ** 2 - else: - historical_grad = alpha * historical_grad + (1 - alpha) * ( - grad_theta ** 2 - ) - adj_grad = _numpy.divide( - grad_theta, fudge_factor + _numpy.sqrt(historical_grad) - ) - x0_updated = x0_updated + stepsize * adj_grad + x0_updated = x0_updated - stepsize * grad_theta if animate: scatter.set_offsets( @@ -165,7 +135,7 @@ def update( ) ) figure.canvas.draw() - _plt.pause(0.00001) + _plt.pause(1e-5) except KeyboardInterrupt: pass @@ -183,6 +153,3 @@ def grd(m: _numpy.array) -> _numpy.array: ).T return grd - -from . import _version -__version__ = _version.get_versions()['version'] diff --git a/src/simpleSVGD/_version.py b/src/simpleSVGD/_version.py index b51f1ad..576e0e9 100644 --- a/src/simpleSVGD/_version.py +++ b/src/simpleSVGD/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -59,17 +58,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) process = None @@ -77,10 +77,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except OSError: e = sys.exc_info()[1] @@ -115,15 +118,21 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -182,7 +191,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -191,7 +200,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -199,24 +208,31 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -233,8 +249,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): GITS = ["git.cmd", "git.exe"] TAG_PREFIX_REGEX = r"\*" - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -242,11 +257,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", - "%s%s" % (tag_prefix, TAG_PREFIX_REGEX)], - cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s%s" % (tag_prefix, TAG_PREFIX_REGEX), + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -261,8 +284,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -302,17 +324,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -321,10 +342,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -373,8 +396,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -403,8 +425,7 @@ def render_pep440_branch(pieces): rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -432,7 +453,7 @@ def render_pep440_pre(pieces): tag_version, post_version = pep440_split_post(pieces["closest-tag"]) rendered = tag_version if post_version is not None: - rendered += ".post%d.dev%d" % (post_version+1, pieces["distance"]) + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) else: rendered += ".post0.dev%d" % (pieces["distance"]) else: @@ -565,11 +586,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -593,9 +616,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -609,8 +636,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -619,13 +645,16 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -639,6 +668,10 @@ def get_versions(): except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/src/simpleSVGD/kernels.py b/src/simpleSVGD/kernels.py new file mode 100644 index 0000000..6a182e7 --- /dev/null +++ b/src/simpleSVGD/kernels.py @@ -0,0 +1,21 @@ +import numpy as _numpy +from scipy.spatial.distance import pdist as _pdist, squareform as _squareform + + +def rbf_kernel(theta, h=-1): + """Radial basis function kernel.""" + sq_dist = _pdist(theta) + pairwise_dists = _squareform(sq_dist) ** 2 + if h < 0: # if h < 0, using median trick + h = _numpy.median(pairwise_dists) + h = _numpy.sqrt(0.5 * h / _numpy.log(theta.shape[0] + 1)) + + # compute the rbf kernel + Kxy = _numpy.exp(-pairwise_dists / h ** 2 / 2) + + dxkxy = -_numpy.matmul(Kxy, theta) + sumkxy = _numpy.sum(Kxy, axis=1) + for i in range(theta.shape[1]): + dxkxy[:, i] = dxkxy[:, i] + _numpy.multiply(theta[:, i], sumkxy) + dxkxy = dxkxy / (h ** 2) + return (Kxy, dxkxy) diff --git a/versioneer.py b/versioneer.py index b4cd1d6..d70f31b 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,3 @@ - # Version: 0.21 """The Versioneer - like a rocketeer, but for versions. @@ -309,11 +308,13 @@ def get_root(): setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -326,8 +327,10 @@ def get_root(): me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(my_path), versioneer_py)) + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py) + ) except NameError: pass return root @@ -372,15 +375,16 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" HANDLERS.setdefault(vcs, {})[method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) process = None @@ -388,10 +392,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except OSError: e = sys.exc_info()[1] @@ -414,7 +421,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, return stdout, process.returncode -LONG_VERSION_PY['git'] = r''' +LONG_VERSION_PY[ + "git" +] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -1116,7 +1125,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1125,7 +1134,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1133,24 +1142,31 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -1167,8 +1183,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): GITS = ["git.cmd", "git.exe"] TAG_PREFIX_REGEX = r"\*" - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1176,11 +1191,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", - "%s%s" % (tag_prefix, TAG_PREFIX_REGEX)], - cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s%s" % (tag_prefix, TAG_PREFIX_REGEX), + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -1195,8 +1218,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -1236,17 +1258,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -1255,10 +1276,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1331,15 +1354,21 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -1368,11 +1397,13 @@ def versions_from_file(filename): contents = f.read() except OSError: raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) @@ -1381,8 +1412,7 @@ def versions_from_file(filename): def write_to_version_file(filename, versions): """Write the given version number to the given _version.py file.""" os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) @@ -1414,8 +1444,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1444,8 +1473,7 @@ def render_pep440_branch(pieces): rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1473,7 +1501,7 @@ def render_pep440_pre(pieces): tag_version, post_version = pep440_split_post(pieces["closest-tag"]) rendered = tag_version if post_version is not None: - rendered += ".post%d.dev%d" % (post_version+1, pieces["distance"]) + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) else: rendered += ".post0.dev%d" % (pieces["distance"]) else: @@ -1606,11 +1634,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -1634,9 +1664,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } class VersioneerBadRootError(Exception): @@ -1659,8 +1693,9 @@ def get_versions(verbose=False): handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1714,9 +1749,13 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } def get_version(): @@ -1769,6 +1808,7 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version # we override "build_py" in both distutils and setuptools @@ -1787,8 +1827,8 @@ def run(self): # setup.py egg_info -> ? # we override different "build_py" commands for both environments - if 'build_py' in cmds: - _build_py = cmds['build_py'] + if "build_py" in cmds: + _build_py = cmds["build_py"] elif "setuptools" in sys.modules: from setuptools.command.build_py import build_py as _build_py else: @@ -1803,14 +1843,14 @@ def run(self): # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py - if 'build_ext' in cmds: - _build_ext = cmds['build_ext'] + if "build_ext" in cmds: + _build_ext = cmds["build_ext"] elif "setuptools" in sys.modules: from setuptools.command.build_ext import build_ext as _build_ext else: @@ -1830,14 +1870,15 @@ def run(self): return # now locate _version.py in the new build/ directory and replace # it with an updated value - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_ext"] = cmd_build_ext if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1858,17 +1899,21 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if 'py2exe' in sys.modules: # py2exe enabled? + if "py2exe" in sys.modules: # py2exe enabled? from py2exe.distutils_buildexe import py2exe as _py2exe class cmd_py2exe(_py2exe): @@ -1884,18 +1929,22 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["py2exe"] = cmd_py2exe # we override different "sdist" commands for both environments - if 'sdist' in cmds: - _sdist = cmds['sdist'] + if "sdist" in cmds: + _sdist = cmds["sdist"] elif "setuptools" in sys.modules: from setuptools.command.sdist import sdist as _sdist else: @@ -1919,8 +1968,10 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) + write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) + cmds["sdist"] = cmd_sdist return cmds @@ -1980,11 +2031,9 @@ def do_setup(): root = get_root() try: cfg = get_config_from_root(root) - except (OSError, configparser.NoSectionError, - configparser.NoOptionError) as e: + except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: if isinstance(e, (OSError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) + print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -1993,15 +2042,18 @@ def do_setup(): print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") if os.path.exists(ipy): try: with open(ipy, "r") as f: @@ -2049,8 +2101,10 @@ def do_setup(): else: print(" 'versioneer.py' already in MANIFEST.in") if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) + print( + " appending versionfile_source ('%s') to MANIFEST.in" + % cfg.versionfile_source + ) with open(manifest_in, "a") as f: f.write("include %s\n" % cfg.versionfile_source) else: