diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b5a3c46 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>=42", + "wheel" +] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/setup.py b/setup.py index ba428b3..2876de2 100755 --- a/setup.py +++ b/setup.py @@ -13,8 +13,10 @@ long_description_content_type="text/markdown", url="https://github.com/larsgeb/simpleSVGD", project_urls={ - "Bug Tracker": "https://github.com/larsgeb/simpleSVGD/issues", }, - packages=setuptools.find_packages(), + "Bug Tracker": "https://github.com/larsgeb/simpleSVGD/issues", + }, + package_dir={"": "src"}, + packages=setuptools.find_packages(where="src", exclude=["test"]), classifiers=[ "Development Status :: 4 - Beta", "Programming Language :: Python :: 3.7", @@ -23,5 +25,10 @@ ], python_requires=">=3.7", install_requires=["numpy", "tqdm", "scipy", "matplotlib"], - extras_require={"dev": ["black", "pytest", ]}, + extras_require={ + "dev": [ + "black", + "pytest", + ] + }, ) diff --git a/simpleSVGD/__init__.py b/src/simpleSVGD/__init__.py similarity index 66% rename from simpleSVGD/__init__.py rename to src/simpleSVGD/__init__.py index cabb504..e6f1b19 100644 --- a/simpleSVGD/__init__.py +++ b/src/simpleSVGD/__init__.py @@ -23,9 +23,7 @@ def _rbf_kernel(theta, h=-1): return (Kxy, dxkxy) -def update( - x0, gradient_fn, n_iter=1000, stepsize=1e-3, bandwidth=-1, alpha=0.9, h=-1 -): +def update(x0, gradient_fn, n_iter=1000, stepsize=1e-3, bandwidth=-1, alpha=0.9, h=-1): # Check input if x0 is None or gradient_fn is None: raise ValueError("x0 or gradient_fn cannot be None!") @@ -47,10 +45,10 @@ def update( if iter == 0: historical_grad = historical_grad + grad_theta ** 2 else: - historical_grad = alpha * historical_grad + \ - (1 - alpha) * (grad_theta ** 2) - adj_grad = _np.divide(grad_theta, fudge_factor + - _np.sqrt(historical_grad)) + historical_grad = alpha * historical_grad + (1 - alpha) * ( + grad_theta ** 2 + ) + adj_grad = _np.divide(grad_theta, fudge_factor + _np.sqrt(historical_grad)) theta = theta + stepsize * adj_grad except KeyboardInterrupt: pass @@ -59,11 +57,19 @@ def update( def update_visual( - x0, lnprob, n_iter=1000, stepsize=1e-3, bandwidth=-1, alpha=0.9, figure=None, dimensions_to_plot=[0, 1] + x0, + lnprob, + n_iter=1000, + stepsize=1e-3, + bandwidth=-1, + alpha=0.9, + figure=None, + dimensions_to_plot=[0, 1], + background=None, ): if figure is None: - figure = _plt.figure(figsize=(4, 4)) + figure = _plt.figure(figsize=(8, 8)) axis = _plt.gca() @@ -77,9 +83,26 @@ def update_visual( fudge_factor = 1e-6 historical_grad = 0 + if background is not None: + x1s, x2s, background_image = background + + axis.contour( + x1s, + x2s, + _np.exp(-background_image), + levels=20, + alpha=0.5, + zorder=0, + ) + scatter = axis.scatter( theta[:, dimensions_to_plot[0]], theta[:, dimensions_to_plot[1]] ) + + if background is not None: + _plt.xlim([x1s.min(), x1s.max()]) + _plt.ylim([x2s.min(), x2s.max()]) + axis.set_aspect(1) figure.canvas.draw() @@ -97,15 +120,19 @@ def update_visual( if iter == 0: historical_grad = historical_grad + grad_theta ** 2 else: - historical_grad = alpha * historical_grad + \ - (1 - alpha) * (grad_theta ** 2) - adj_grad = _np.divide(grad_theta, fudge_factor + - _np.sqrt(historical_grad)) + historical_grad = alpha * historical_grad + (1 - alpha) * ( + grad_theta ** 2 + ) + adj_grad = _np.divide(grad_theta, fudge_factor + _np.sqrt(historical_grad)) theta = theta + stepsize * adj_grad scatter.set_offsets( - _np.hstack((theta[:, dimensions_to_plot[0], None], - theta[:, dimensions_to_plot[1], None])) + _np.hstack( + ( + theta[:, dimensions_to_plot[0], None], + theta[:, dimensions_to_plot[1], None], + ) + ) ) figure.canvas.draw() _plt.pause(0.00001)