diff --git a/notebooks/Tutorial on using simpleSVGD.ipynb b/notebooks/Tutorial on using simpleSVGD.ipynb
index 24ee3a5..b3c90a7 100644
--- a/notebooks/Tutorial on using simpleSVGD.ipynb
+++ b/notebooks/Tutorial on using simpleSVGD.ipynb
@@ -10,7 +10,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"id": "bdf3a8dd",
"metadata": {},
"outputs": [],
@@ -30,7 +30,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 2,
"id": "e9a71a8f",
"metadata": {},
"outputs": [],
@@ -89,7 +89,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 3,
"id": "86d380f7",
"metadata": {},
"outputs": [],
@@ -120,7 +120,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 4,
"id": "bf6126cd",
"metadata": {},
"outputs": [
@@ -1092,7 +1092,7 @@
{
"data": {
"text/html": [
- ""
+ ""
],
"text/plain": [
""
@@ -1104,7 +1104,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "f715ee898ebe4b16b2f31c34fbb5ba41",
+ "model_id": "ac6bd9d8d69440beb4efca369a548f70",
"version_major": 2,
"version_minor": 0
},
@@ -1131,10 +1131,7 @@
" Himmelblau_grad, \n",
" n_iter=130,\n",
" # AdaGrad parameters\n",
- " stepsize=1e-1,\n",
- " alpha=0.9,\n",
- " fudge_factor=1e-3,\n",
- " historical_grad=1,\n",
+ " stepsize=1e0,\n",
" animate=True,\n",
" background=background,\n",
" figure=figure,\n",
@@ -1151,7 +1148,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 5,
"id": "5ac5e7d1",
"metadata": {
"scrolled": false
@@ -2125,7 +2122,7 @@
{
"data": {
"text/html": [
- ""
+ ""
],
"text/plain": [
""
@@ -2195,7 +2192,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.11"
+ "version": "3.8.12"
}
},
"nbformat": 4,
diff --git a/src/simpleSVGD/__init__.py b/src/simpleSVGD/__init__.py
index 2edeec1..d96197a 100644
--- a/src/simpleSVGD/__init__.py
+++ b/src/simpleSVGD/__init__.py
@@ -1,29 +1,16 @@
-from enum import auto
import numpy as _numpy
import tqdm.auto as _tqdm_auto
-from scipy.spatial.distance import pdist as _pdist, squareform as _squareform
import matplotlib.pyplot as _plt
import matplotlib.figure as _figure
from typing import Callable as _Callable, List as _List, Tuple as _Tuple
+from .kernels import rbf_kernel as _rbf_kernel
+from .helpers import TorchWrapper as _TorchWrapper
-def _rbf_kernel(theta, h=-1):
- """Radial basis function kernel."""
- sq_dist = _pdist(theta)
- pairwise_dists = _squareform(sq_dist) ** 2
- if h < 0: # if h < 0, using median trick
- h = _numpy.median(pairwise_dists)
- h = _numpy.sqrt(0.5 * h / _numpy.log(theta.shape[0] + 1))
- # compute the rbf kernel
- Kxy = _numpy.exp(-pairwise_dists / h ** 2 / 2)
+from . import _version
- dxkxy = -_numpy.matmul(Kxy, theta)
- sumkxy = _numpy.sum(Kxy, axis=1)
- for i in range(theta.shape[1]):
- dxkxy[:, i] = dxkxy[:, i] + _numpy.multiply(theta[:, i], sumkxy)
- dxkxy = dxkxy / (h ** 2)
- return (Kxy, dxkxy)
+__version__ = _version.get_versions()["version"]
def update(
@@ -33,9 +20,6 @@ def update(
n_iter: int = 1000,
stepsize: float = 1e-3,
bandwidth: float = -1,
- alpha: float = 0.9,
- fudge_factor=1e-3,
- historical_grad=1,
# All following parameter only concern animation
animate: bool = False,
figure: _figure.Figure = None,
@@ -68,10 +52,6 @@ def update(
more likely to produce good results, but will slow the algorithm down.
Default is -1.
- alpha
- Parameter with which to dampen gradient changes in the target function
- during SVGD updating. Default is 0.9.
-
animate
A boolean to animate the algorithm. Only works for functions of at
least two dimensions. Default is False.
@@ -127,7 +107,7 @@ def update(
axis.set_aspect(1)
figure.canvas.draw()
- _plt.pause(0.00001)
+ _plt.pause(1e-5)
# The Try/Except allows on to interrupt the algorithm using CTRL+C while
# still getting x0_updated at the point of interruption.
@@ -143,17 +123,7 @@ def update(
0
]
- # adagrad
- if iter == 0:
- historical_grad = historical_grad + grad_theta ** 2
- else:
- historical_grad = alpha * historical_grad + (1 - alpha) * (
- grad_theta ** 2
- )
- adj_grad = _numpy.divide(
- grad_theta, fudge_factor + _numpy.sqrt(historical_grad)
- )
- x0_updated = x0_updated + stepsize * adj_grad
+ x0_updated = x0_updated - stepsize * grad_theta
if animate:
scatter.set_offsets(
@@ -165,7 +135,7 @@ def update(
)
)
figure.canvas.draw()
- _plt.pause(0.00001)
+ _plt.pause(1e-5)
except KeyboardInterrupt:
pass
@@ -183,6 +153,3 @@ def grd(m: _numpy.array) -> _numpy.array:
).T
return grd
-
-from . import _version
-__version__ = _version.get_versions()['version']
diff --git a/src/simpleSVGD/_version.py b/src/simpleSVGD/_version.py
index b51f1ad..576e0e9 100644
--- a/src/simpleSVGD/_version.py
+++ b/src/simpleSVGD/_version.py
@@ -1,4 +1,3 @@
-
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
@@ -59,17 +58,18 @@ class NotThisMethod(Exception):
def register_vcs_handler(vcs, method): # decorator
"""Create decorator to mark a method as the handler of a VCS."""
+
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
HANDLERS[vcs][method] = f
return f
+
return decorate
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
- env=None):
+def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
process = None
@@ -77,10 +77,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
try:
dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
- process = subprocess.Popen([command] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None))
+ process = subprocess.Popen(
+ [command] + args,
+ cwd=cwd,
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=(subprocess.PIPE if hide_stderr else None),
+ )
break
except OSError:
e = sys.exc_info()[1]
@@ -115,15 +118,21 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
- return {"version": dirname[len(parentdir_prefix):],
- "full-revisionid": None,
- "dirty": False, "error": None, "date": None}
+ return {
+ "version": dirname[len(parentdir_prefix) :],
+ "full-revisionid": None,
+ "dirty": False,
+ "error": None,
+ "date": None,
+ }
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
- print("Tried directories %s but none started with prefix %s" %
- (str(rootdirs), parentdir_prefix))
+ print(
+ "Tried directories %s but none started with prefix %s"
+ % (str(rootdirs), parentdir_prefix)
+ )
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@@ -182,7 +191,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
- tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
+ tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
@@ -191,7 +200,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
- tags = {r for r in refs if re.search(r'\d', r)}
+ tags = {r for r in refs if re.search(r"\d", r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
@@ -199,24 +208,31 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
- r = ref[len(tag_prefix):]
+ r = ref[len(tag_prefix) :]
# Filter out refs that exactly match prefix or that don't start
# with a number once the prefix is stripped (mostly a concern
# when prefix is '')
- if not re.match(r'\d', r):
+ if not re.match(r"\d", r):
continue
if verbose:
print("picking %s" % r)
- return {"version": r,
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": None,
- "date": date}
+ return {
+ "version": r,
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False,
+ "error": None,
+ "date": date,
+ }
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
- return {"version": "0+unknown",
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": "no suitable tags", "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False,
+ "error": "no suitable tags",
+ "date": None,
+ }
@register_vcs_handler("git", "pieces_from_vcs")
@@ -233,8 +249,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
GITS = ["git.cmd", "git.exe"]
TAG_PREFIX_REGEX = r"\*"
- _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
+ _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
@@ -242,11 +257,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty",
- "--always", "--long",
- "--match",
- "%s%s" % (tag_prefix, TAG_PREFIX_REGEX)],
- cwd=root)
+ describe_out, rc = runner(
+ GITS,
+ [
+ "describe",
+ "--tags",
+ "--dirty",
+ "--always",
+ "--long",
+ "--match",
+ "%s%s" % (tag_prefix, TAG_PREFIX_REGEX),
+ ],
+ cwd=root,
+ )
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
@@ -261,8 +284,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
- branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
- cwd=root)
+ branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root)
# --abbrev-ref was added in git-1.6.3
if rc != 0 or branch_name is None:
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
@@ -302,17 +324,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
- git_describe = git_describe[:git_describe.rindex("-dirty")]
+ git_describe = git_describe[: git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
- mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
+ mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo:
# unparsable. Maybe git-describe is misbehaving?
- pieces["error"] = ("unable to parse git-describe output: '%s'"
- % describe_out)
+ pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
return pieces
# tag
@@ -321,10 +342,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
- pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
- % (full_tag, tag_prefix))
+ pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
+ full_tag,
+ tag_prefix,
+ )
return pieces
- pieces["closest-tag"] = full_tag[len(tag_prefix):]
+ pieces["closest-tag"] = full_tag[len(tag_prefix) :]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
@@ -373,8 +396,7 @@ def render_pep440(pieces):
rendered += ".dirty"
else:
# exception #1
- rendered = "0+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
+ rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
@@ -403,8 +425,7 @@ def render_pep440_branch(pieces):
rendered = "0"
if pieces["branch"] != "master":
rendered += ".dev0"
- rendered += "+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
+ rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
@@ -432,7 +453,7 @@ def render_pep440_pre(pieces):
tag_version, post_version = pep440_split_post(pieces["closest-tag"])
rendered = tag_version
if post_version is not None:
- rendered += ".post%d.dev%d" % (post_version+1, pieces["distance"])
+ rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
else:
rendered += ".post0.dev%d" % (pieces["distance"])
else:
@@ -565,11 +586,13 @@ def render_git_describe_long(pieces):
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
- return {"version": "unknown",
- "full-revisionid": pieces.get("long"),
- "dirty": None,
- "error": pieces["error"],
- "date": None}
+ return {
+ "version": "unknown",
+ "full-revisionid": pieces.get("long"),
+ "dirty": None,
+ "error": pieces["error"],
+ "date": None,
+ }
if not style or style == "default":
style = "pep440" # the default
@@ -593,9 +616,13 @@ def render(pieces, style):
else:
raise ValueError("unknown style '%s'" % style)
- return {"version": rendered, "full-revisionid": pieces["long"],
- "dirty": pieces["dirty"], "error": None,
- "date": pieces.get("date")}
+ return {
+ "version": rendered,
+ "full-revisionid": pieces["long"],
+ "dirty": pieces["dirty"],
+ "error": None,
+ "date": pieces.get("date"),
+ }
def get_versions():
@@ -609,8 +636,7 @@ def get_versions():
verbose = cfg.verbose
try:
- return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
- verbose)
+ return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
except NotThisMethod:
pass
@@ -619,13 +645,16 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
- for _ in cfg.versionfile_source.split('/'):
+ for _ in cfg.versionfile_source.split("/"):
root = os.path.dirname(root)
except NameError:
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to find root of source tree",
- "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to find root of source tree",
+ "date": None,
+ }
try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@@ -639,6 +668,10 @@ def get_versions():
except NotThisMethod:
pass
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to compute version", "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to compute version",
+ "date": None,
+ }
diff --git a/src/simpleSVGD/kernels.py b/src/simpleSVGD/kernels.py
new file mode 100644
index 0000000..6a182e7
--- /dev/null
+++ b/src/simpleSVGD/kernels.py
@@ -0,0 +1,21 @@
+import numpy as _numpy
+from scipy.spatial.distance import pdist as _pdist, squareform as _squareform
+
+
+def rbf_kernel(theta, h=-1):
+ """Radial basis function kernel."""
+ sq_dist = _pdist(theta)
+ pairwise_dists = _squareform(sq_dist) ** 2
+ if h < 0: # if h < 0, using median trick
+ h = _numpy.median(pairwise_dists)
+ h = _numpy.sqrt(0.5 * h / _numpy.log(theta.shape[0] + 1))
+
+ # compute the rbf kernel
+ Kxy = _numpy.exp(-pairwise_dists / h ** 2 / 2)
+
+ dxkxy = -_numpy.matmul(Kxy, theta)
+ sumkxy = _numpy.sum(Kxy, axis=1)
+ for i in range(theta.shape[1]):
+ dxkxy[:, i] = dxkxy[:, i] + _numpy.multiply(theta[:, i], sumkxy)
+ dxkxy = dxkxy / (h ** 2)
+ return (Kxy, dxkxy)
diff --git a/versioneer.py b/versioneer.py
index b4cd1d6..d70f31b 100644
--- a/versioneer.py
+++ b/versioneer.py
@@ -1,4 +1,3 @@
-
# Version: 0.21
"""The Versioneer - like a rocketeer, but for versions.
@@ -309,11 +308,13 @@ def get_root():
setup_py = os.path.join(root, "setup.py")
versioneer_py = os.path.join(root, "versioneer.py")
if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)):
- err = ("Versioneer was unable to run the project root directory. "
- "Versioneer requires setup.py to be executed from "
- "its immediate directory (like 'python setup.py COMMAND'), "
- "or in a way that lets it use sys.argv[0] to find the root "
- "(like 'python path/to/setup.py COMMAND').")
+ err = (
+ "Versioneer was unable to run the project root directory. "
+ "Versioneer requires setup.py to be executed from "
+ "its immediate directory (like 'python setup.py COMMAND'), "
+ "or in a way that lets it use sys.argv[0] to find the root "
+ "(like 'python path/to/setup.py COMMAND')."
+ )
raise VersioneerBadRootError(err)
try:
# Certain runtime workflows (setup.py install/develop in a setuptools
@@ -326,8 +327,10 @@ def get_root():
me_dir = os.path.normcase(os.path.splitext(my_path)[0])
vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0])
if me_dir != vsr_dir:
- print("Warning: build in %s is using versioneer.py from %s"
- % (os.path.dirname(my_path), versioneer_py))
+ print(
+ "Warning: build in %s is using versioneer.py from %s"
+ % (os.path.dirname(my_path), versioneer_py)
+ )
except NameError:
pass
return root
@@ -372,15 +375,16 @@ class NotThisMethod(Exception):
def register_vcs_handler(vcs, method): # decorator
"""Create decorator to mark a method as the handler of a VCS."""
+
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
HANDLERS.setdefault(vcs, {})[method] = f
return f
+
return decorate
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
- env=None):
+def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
process = None
@@ -388,10 +392,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
try:
dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
- process = subprocess.Popen([command] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None))
+ process = subprocess.Popen(
+ [command] + args,
+ cwd=cwd,
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=(subprocess.PIPE if hide_stderr else None),
+ )
break
except OSError:
e = sys.exc_info()[1]
@@ -414,7 +421,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
return stdout, process.returncode
-LONG_VERSION_PY['git'] = r'''
+LONG_VERSION_PY[
+ "git"
+] = r'''
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
@@ -1116,7 +1125,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
- tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
+ tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
@@ -1125,7 +1134,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
- tags = {r for r in refs if re.search(r'\d', r)}
+ tags = {r for r in refs if re.search(r"\d", r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
@@ -1133,24 +1142,31 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
- r = ref[len(tag_prefix):]
+ r = ref[len(tag_prefix) :]
# Filter out refs that exactly match prefix or that don't start
# with a number once the prefix is stripped (mostly a concern
# when prefix is '')
- if not re.match(r'\d', r):
+ if not re.match(r"\d", r):
continue
if verbose:
print("picking %s" % r)
- return {"version": r,
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": None,
- "date": date}
+ return {
+ "version": r,
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False,
+ "error": None,
+ "date": date,
+ }
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
- return {"version": "0+unknown",
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": "no suitable tags", "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False,
+ "error": "no suitable tags",
+ "date": None,
+ }
@register_vcs_handler("git", "pieces_from_vcs")
@@ -1167,8 +1183,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
GITS = ["git.cmd", "git.exe"]
TAG_PREFIX_REGEX = r"\*"
- _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
+ _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
@@ -1176,11 +1191,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty",
- "--always", "--long",
- "--match",
- "%s%s" % (tag_prefix, TAG_PREFIX_REGEX)],
- cwd=root)
+ describe_out, rc = runner(
+ GITS,
+ [
+ "describe",
+ "--tags",
+ "--dirty",
+ "--always",
+ "--long",
+ "--match",
+ "%s%s" % (tag_prefix, TAG_PREFIX_REGEX),
+ ],
+ cwd=root,
+ )
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
@@ -1195,8 +1218,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
- branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
- cwd=root)
+ branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root)
# --abbrev-ref was added in git-1.6.3
if rc != 0 or branch_name is None:
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
@@ -1236,17 +1258,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
- git_describe = git_describe[:git_describe.rindex("-dirty")]
+ git_describe = git_describe[: git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
- mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
+ mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo:
# unparsable. Maybe git-describe is misbehaving?
- pieces["error"] = ("unable to parse git-describe output: '%s'"
- % describe_out)
+ pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
return pieces
# tag
@@ -1255,10 +1276,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
- pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
- % (full_tag, tag_prefix))
+ pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
+ full_tag,
+ tag_prefix,
+ )
return pieces
- pieces["closest-tag"] = full_tag[len(tag_prefix):]
+ pieces["closest-tag"] = full_tag[len(tag_prefix) :]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
@@ -1331,15 +1354,21 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
- return {"version": dirname[len(parentdir_prefix):],
- "full-revisionid": None,
- "dirty": False, "error": None, "date": None}
+ return {
+ "version": dirname[len(parentdir_prefix) :],
+ "full-revisionid": None,
+ "dirty": False,
+ "error": None,
+ "date": None,
+ }
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
- print("Tried directories %s but none started with prefix %s" %
- (str(rootdirs), parentdir_prefix))
+ print(
+ "Tried directories %s but none started with prefix %s"
+ % (str(rootdirs), parentdir_prefix)
+ )
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@@ -1368,11 +1397,13 @@ def versions_from_file(filename):
contents = f.read()
except OSError:
raise NotThisMethod("unable to read _version.py")
- mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON",
- contents, re.M | re.S)
+ mo = re.search(
+ r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S
+ )
if not mo:
- mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON",
- contents, re.M | re.S)
+ mo = re.search(
+ r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S
+ )
if not mo:
raise NotThisMethod("no version_json in _version.py")
return json.loads(mo.group(1))
@@ -1381,8 +1412,7 @@ def versions_from_file(filename):
def write_to_version_file(filename, versions):
"""Write the given version number to the given _version.py file."""
os.unlink(filename)
- contents = json.dumps(versions, sort_keys=True,
- indent=1, separators=(",", ": "))
+ contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": "))
with open(filename, "w") as f:
f.write(SHORT_VERSION_PY % contents)
@@ -1414,8 +1444,7 @@ def render_pep440(pieces):
rendered += ".dirty"
else:
# exception #1
- rendered = "0+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
+ rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
@@ -1444,8 +1473,7 @@ def render_pep440_branch(pieces):
rendered = "0"
if pieces["branch"] != "master":
rendered += ".dev0"
- rendered += "+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
+ rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
@@ -1473,7 +1501,7 @@ def render_pep440_pre(pieces):
tag_version, post_version = pep440_split_post(pieces["closest-tag"])
rendered = tag_version
if post_version is not None:
- rendered += ".post%d.dev%d" % (post_version+1, pieces["distance"])
+ rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
else:
rendered += ".post0.dev%d" % (pieces["distance"])
else:
@@ -1606,11 +1634,13 @@ def render_git_describe_long(pieces):
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
- return {"version": "unknown",
- "full-revisionid": pieces.get("long"),
- "dirty": None,
- "error": pieces["error"],
- "date": None}
+ return {
+ "version": "unknown",
+ "full-revisionid": pieces.get("long"),
+ "dirty": None,
+ "error": pieces["error"],
+ "date": None,
+ }
if not style or style == "default":
style = "pep440" # the default
@@ -1634,9 +1664,13 @@ def render(pieces, style):
else:
raise ValueError("unknown style '%s'" % style)
- return {"version": rendered, "full-revisionid": pieces["long"],
- "dirty": pieces["dirty"], "error": None,
- "date": pieces.get("date")}
+ return {
+ "version": rendered,
+ "full-revisionid": pieces["long"],
+ "dirty": pieces["dirty"],
+ "error": None,
+ "date": pieces.get("date"),
+ }
class VersioneerBadRootError(Exception):
@@ -1659,8 +1693,9 @@ def get_versions(verbose=False):
handlers = HANDLERS.get(cfg.VCS)
assert handlers, "unrecognized VCS '%s'" % cfg.VCS
verbose = verbose or cfg.verbose
- assert cfg.versionfile_source is not None, \
- "please set versioneer.versionfile_source"
+ assert (
+ cfg.versionfile_source is not None
+ ), "please set versioneer.versionfile_source"
assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix"
versionfile_abs = os.path.join(root, cfg.versionfile_source)
@@ -1714,9 +1749,13 @@ def get_versions(verbose=False):
if verbose:
print("unable to compute version")
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None, "error": "unable to compute version",
- "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to compute version",
+ "date": None,
+ }
def get_version():
@@ -1769,6 +1808,7 @@ def run(self):
print(" date: %s" % vers.get("date"))
if vers["error"]:
print(" error: %s" % vers["error"])
+
cmds["version"] = cmd_version
# we override "build_py" in both distutils and setuptools
@@ -1787,8 +1827,8 @@ def run(self):
# setup.py egg_info -> ?
# we override different "build_py" commands for both environments
- if 'build_py' in cmds:
- _build_py = cmds['build_py']
+ if "build_py" in cmds:
+ _build_py = cmds["build_py"]
elif "setuptools" in sys.modules:
from setuptools.command.build_py import build_py as _build_py
else:
@@ -1803,14 +1843,14 @@ def run(self):
# now locate _version.py in the new build/ directory and replace
# it with an updated value
if cfg.versionfile_build:
- target_versionfile = os.path.join(self.build_lib,
- cfg.versionfile_build)
+ target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build)
print("UPDATING %s" % target_versionfile)
write_to_version_file(target_versionfile, versions)
+
cmds["build_py"] = cmd_build_py
- if 'build_ext' in cmds:
- _build_ext = cmds['build_ext']
+ if "build_ext" in cmds:
+ _build_ext = cmds["build_ext"]
elif "setuptools" in sys.modules:
from setuptools.command.build_ext import build_ext as _build_ext
else:
@@ -1830,14 +1870,15 @@ def run(self):
return
# now locate _version.py in the new build/ directory and replace
# it with an updated value
- target_versionfile = os.path.join(self.build_lib,
- cfg.versionfile_build)
+ target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build)
print("UPDATING %s" % target_versionfile)
write_to_version_file(target_versionfile, versions)
+
cmds["build_ext"] = cmd_build_ext
if "cx_Freeze" in sys.modules: # cx_freeze enabled?
from cx_Freeze.dist import build_exe as _build_exe
+
# nczeczulin reports that py2exe won't like the pep440-style string
# as FILEVERSION, but it can be used for PRODUCTVERSION, e.g.
# setup(console=[{
@@ -1858,17 +1899,21 @@ def run(self):
os.unlink(target_versionfile)
with open(cfg.versionfile_source, "w") as f:
LONG = LONG_VERSION_PY[cfg.VCS]
- f.write(LONG %
- {"DOLLAR": "$",
- "STYLE": cfg.style,
- "TAG_PREFIX": cfg.tag_prefix,
- "PARENTDIR_PREFIX": cfg.parentdir_prefix,
- "VERSIONFILE_SOURCE": cfg.versionfile_source,
- })
+ f.write(
+ LONG
+ % {
+ "DOLLAR": "$",
+ "STYLE": cfg.style,
+ "TAG_PREFIX": cfg.tag_prefix,
+ "PARENTDIR_PREFIX": cfg.parentdir_prefix,
+ "VERSIONFILE_SOURCE": cfg.versionfile_source,
+ }
+ )
+
cmds["build_exe"] = cmd_build_exe
del cmds["build_py"]
- if 'py2exe' in sys.modules: # py2exe enabled?
+ if "py2exe" in sys.modules: # py2exe enabled?
from py2exe.distutils_buildexe import py2exe as _py2exe
class cmd_py2exe(_py2exe):
@@ -1884,18 +1929,22 @@ def run(self):
os.unlink(target_versionfile)
with open(cfg.versionfile_source, "w") as f:
LONG = LONG_VERSION_PY[cfg.VCS]
- f.write(LONG %
- {"DOLLAR": "$",
- "STYLE": cfg.style,
- "TAG_PREFIX": cfg.tag_prefix,
- "PARENTDIR_PREFIX": cfg.parentdir_prefix,
- "VERSIONFILE_SOURCE": cfg.versionfile_source,
- })
+ f.write(
+ LONG
+ % {
+ "DOLLAR": "$",
+ "STYLE": cfg.style,
+ "TAG_PREFIX": cfg.tag_prefix,
+ "PARENTDIR_PREFIX": cfg.parentdir_prefix,
+ "VERSIONFILE_SOURCE": cfg.versionfile_source,
+ }
+ )
+
cmds["py2exe"] = cmd_py2exe
# we override different "sdist" commands for both environments
- if 'sdist' in cmds:
- _sdist = cmds['sdist']
+ if "sdist" in cmds:
+ _sdist = cmds["sdist"]
elif "setuptools" in sys.modules:
from setuptools.command.sdist import sdist as _sdist
else:
@@ -1919,8 +1968,10 @@ def make_release_tree(self, base_dir, files):
# updated value
target_versionfile = os.path.join(base_dir, cfg.versionfile_source)
print("UPDATING %s" % target_versionfile)
- write_to_version_file(target_versionfile,
- self._versioneer_generated_versions)
+ write_to_version_file(
+ target_versionfile, self._versioneer_generated_versions
+ )
+
cmds["sdist"] = cmd_sdist
return cmds
@@ -1980,11 +2031,9 @@ def do_setup():
root = get_root()
try:
cfg = get_config_from_root(root)
- except (OSError, configparser.NoSectionError,
- configparser.NoOptionError) as e:
+ except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e:
if isinstance(e, (OSError, configparser.NoSectionError)):
- print("Adding sample versioneer config to setup.cfg",
- file=sys.stderr)
+ print("Adding sample versioneer config to setup.cfg", file=sys.stderr)
with open(os.path.join(root, "setup.cfg"), "a") as f:
f.write(SAMPLE_CONFIG)
print(CONFIG_ERROR, file=sys.stderr)
@@ -1993,15 +2042,18 @@ def do_setup():
print(" creating %s" % cfg.versionfile_source)
with open(cfg.versionfile_source, "w") as f:
LONG = LONG_VERSION_PY[cfg.VCS]
- f.write(LONG % {"DOLLAR": "$",
- "STYLE": cfg.style,
- "TAG_PREFIX": cfg.tag_prefix,
- "PARENTDIR_PREFIX": cfg.parentdir_prefix,
- "VERSIONFILE_SOURCE": cfg.versionfile_source,
- })
-
- ipy = os.path.join(os.path.dirname(cfg.versionfile_source),
- "__init__.py")
+ f.write(
+ LONG
+ % {
+ "DOLLAR": "$",
+ "STYLE": cfg.style,
+ "TAG_PREFIX": cfg.tag_prefix,
+ "PARENTDIR_PREFIX": cfg.parentdir_prefix,
+ "VERSIONFILE_SOURCE": cfg.versionfile_source,
+ }
+ )
+
+ ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py")
if os.path.exists(ipy):
try:
with open(ipy, "r") as f:
@@ -2049,8 +2101,10 @@ def do_setup():
else:
print(" 'versioneer.py' already in MANIFEST.in")
if cfg.versionfile_source not in simple_includes:
- print(" appending versionfile_source ('%s') to MANIFEST.in" %
- cfg.versionfile_source)
+ print(
+ " appending versionfile_source ('%s') to MANIFEST.in"
+ % cfg.versionfile_source
+ )
with open(manifest_in, "a") as f:
f.write("include %s\n" % cfg.versionfile_source)
else: