Skip to content

Commit

Permalink
Restructured installation
Browse files Browse the repository at this point in the history
  • Loading branch information
larsgeb committed Jan 28, 2022
1 parent 4536efa commit b8d3dd3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 18 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[build-system]
requires = [
"setuptools>=42",
"wheel"
]
build-backend = "setuptools.build_meta"
13 changes: 10 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
long_description_content_type="text/markdown",
url="https://github.com/larsgeb/simpleSVGD",
project_urls={
"Bug Tracker": "https://github.com/larsgeb/simpleSVGD/issues", },
packages=setuptools.find_packages(),
"Bug Tracker": "https://github.com/larsgeb/simpleSVGD/issues",
},
package_dir={"": "src"},
packages=setuptools.find_packages(where="src", exclude=["test"]),
classifiers=[
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3.7",
Expand All @@ -23,5 +25,10 @@
],
python_requires=">=3.7",
install_requires=["numpy", "tqdm", "scipy", "matplotlib"],
extras_require={"dev": ["black", "pytest", ]},
extras_require={
"dev": [
"black",
"pytest",
]
},
)
57 changes: 42 additions & 15 deletions simpleSVGD/__init__.py → src/simpleSVGD/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ def _rbf_kernel(theta, h=-1):
return (Kxy, dxkxy)


def update(
x0, gradient_fn, n_iter=1000, stepsize=1e-3, bandwidth=-1, alpha=0.9, h=-1
):
def update(x0, gradient_fn, n_iter=1000, stepsize=1e-3, bandwidth=-1, alpha=0.9, h=-1):
# Check input
if x0 is None or gradient_fn is None:
raise ValueError("x0 or gradient_fn cannot be None!")
Expand All @@ -47,10 +45,10 @@ def update(
if iter == 0:
historical_grad = historical_grad + grad_theta ** 2
else:
historical_grad = alpha * historical_grad + \
(1 - alpha) * (grad_theta ** 2)
adj_grad = _np.divide(grad_theta, fudge_factor +
_np.sqrt(historical_grad))
historical_grad = alpha * historical_grad + (1 - alpha) * (
grad_theta ** 2
)
adj_grad = _np.divide(grad_theta, fudge_factor + _np.sqrt(historical_grad))
theta = theta + stepsize * adj_grad
except KeyboardInterrupt:
pass
Expand All @@ -59,11 +57,19 @@ def update(


def update_visual(
x0, lnprob, n_iter=1000, stepsize=1e-3, bandwidth=-1, alpha=0.9, figure=None, dimensions_to_plot=[0, 1]
x0,
lnprob,
n_iter=1000,
stepsize=1e-3,
bandwidth=-1,
alpha=0.9,
figure=None,
dimensions_to_plot=[0, 1],
background=None,
):

if figure is None:
figure = _plt.figure(figsize=(4, 4))
figure = _plt.figure(figsize=(8, 8))

axis = _plt.gca()

Expand All @@ -77,9 +83,26 @@ def update_visual(
fudge_factor = 1e-6
historical_grad = 0

if background is not None:
x1s, x2s, background_image = background

axis.contour(
x1s,
x2s,
_np.exp(-background_image),
levels=20,
alpha=0.5,
zorder=0,
)

scatter = axis.scatter(
theta[:, dimensions_to_plot[0]], theta[:, dimensions_to_plot[1]]
)

if background is not None:
_plt.xlim([x1s.min(), x1s.max()])
_plt.ylim([x2s.min(), x2s.max()])

axis.set_aspect(1)

figure.canvas.draw()
Expand All @@ -97,15 +120,19 @@ def update_visual(
if iter == 0:
historical_grad = historical_grad + grad_theta ** 2
else:
historical_grad = alpha * historical_grad + \
(1 - alpha) * (grad_theta ** 2)
adj_grad = _np.divide(grad_theta, fudge_factor +
_np.sqrt(historical_grad))
historical_grad = alpha * historical_grad + (1 - alpha) * (
grad_theta ** 2
)
adj_grad = _np.divide(grad_theta, fudge_factor + _np.sqrt(historical_grad))
theta = theta + stepsize * adj_grad

scatter.set_offsets(
_np.hstack((theta[:, dimensions_to_plot[0], None],
theta[:, dimensions_to_plot[1], None]))
_np.hstack(
(
theta[:, dimensions_to_plot[0], None],
theta[:, dimensions_to_plot[1], None],
)
)
)
figure.canvas.draw()
_plt.pause(0.00001)
Expand Down

0 comments on commit b8d3dd3

Please sign in to comment.