Skip to content

Commit

Permalink
Cleaned up the files
Browse files Browse the repository at this point in the history
  • Loading branch information
larsgeb committed Jan 31, 2022
1 parent 2aaafa2 commit a87dafc
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 228 deletions.
23 changes: 10 additions & 13 deletions notebooks/Tutorial on using simpleSVGD.ipynb

Large diffs are not rendered by default.

47 changes: 7 additions & 40 deletions src/simpleSVGD/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,16 @@
from enum import auto
import numpy as _numpy
import tqdm.auto as _tqdm_auto
from scipy.spatial.distance import pdist as _pdist, squareform as _squareform
import matplotlib.pyplot as _plt
import matplotlib.figure as _figure
from typing import Callable as _Callable, List as _List, Tuple as _Tuple

from .kernels import rbf_kernel as _rbf_kernel
from .helpers import TorchWrapper as _TorchWrapper

def _rbf_kernel(theta, h=-1):
"""Radial basis function kernel."""
sq_dist = _pdist(theta)
pairwise_dists = _squareform(sq_dist) ** 2
if h < 0: # if h < 0, using median trick
h = _numpy.median(pairwise_dists)
h = _numpy.sqrt(0.5 * h / _numpy.log(theta.shape[0] + 1))

# compute the rbf kernel
Kxy = _numpy.exp(-pairwise_dists / h ** 2 / 2)
from . import _version

dxkxy = -_numpy.matmul(Kxy, theta)
sumkxy = _numpy.sum(Kxy, axis=1)
for i in range(theta.shape[1]):
dxkxy[:, i] = dxkxy[:, i] + _numpy.multiply(theta[:, i], sumkxy)
dxkxy = dxkxy / (h ** 2)
return (Kxy, dxkxy)
__version__ = _version.get_versions()["version"]


def update(
Expand All @@ -33,9 +20,6 @@ def update(
n_iter: int = 1000,
stepsize: float = 1e-3,
bandwidth: float = -1,
alpha: float = 0.9,
fudge_factor=1e-3,
historical_grad=1,
# All following parameter only concern animation
animate: bool = False,
figure: _figure.Figure = None,
Expand Down Expand Up @@ -68,10 +52,6 @@ def update(
more likely to produce good results, but will slow the algorithm down.
Default is -1.
alpha
Parameter with which to dampen gradient changes in the target function
during SVGD updating. Default is 0.9.
animate
A boolean to animate the algorithm. Only works for functions of at
least two dimensions. Default is False.
Expand Down Expand Up @@ -127,7 +107,7 @@ def update(
axis.set_aspect(1)

figure.canvas.draw()
_plt.pause(0.00001)
_plt.pause(1e-5)

# The Try/Except allows on to interrupt the algorithm using CTRL+C while
# still getting x0_updated at the point of interruption.
Expand All @@ -143,17 +123,7 @@ def update(
0
]

# adagrad
if iter == 0:
historical_grad = historical_grad + grad_theta ** 2
else:
historical_grad = alpha * historical_grad + (1 - alpha) * (
grad_theta ** 2
)
adj_grad = _numpy.divide(
grad_theta, fudge_factor + _numpy.sqrt(historical_grad)
)
x0_updated = x0_updated + stepsize * adj_grad
x0_updated = x0_updated - stepsize * grad_theta

if animate:
scatter.set_offsets(
Expand All @@ -165,7 +135,7 @@ def update(
)
)
figure.canvas.draw()
_plt.pause(0.00001)
_plt.pause(1e-5)

except KeyboardInterrupt:
pass
Expand All @@ -183,6 +153,3 @@ def grd(m: _numpy.array) -> _numpy.array:
).T

return grd

from . import _version
__version__ = _version.get_versions()['version']
Loading

0 comments on commit a87dafc

Please sign in to comment.