Skip to content

Commit

Permalink
Use where instead of changing matrices (for jax later)
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianPfaff committed Nov 29, 2023
1 parent d3806e3 commit e79a755
Show file tree
Hide file tree
Showing 50 changed files with 3,327 additions and 17 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/merge_linter_changes.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: Auto Merge PRs

on:
pull_request:
types:
- opened
- synchronize
- labeled
- unlabeled
- edited

jobs:
auto-merge:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Auto-merge PR if criteria are met
run: |
if [[ "${{ github.event.pull_request.base.ref }}" != "main" && "${{ github.event.pull_request.title }}" == "[MegaLinter]"* ]]; then
echo "Criteria met. Auto-merging the PR."
curl -X PUT "https://api.github.com/repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/merge" \
-H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
-H "Accept: application/vnd.github.v3+json"
else
echo "Criteria not met. Not auto-merging."
fi
Empty file added ModelNet10-SO3.zip
Empty file.
Empty file added ModelNet10.zip
Empty file.
40 changes: 40 additions & 0 deletions add_pylint_disables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import re

def add_pylint_disable(filename):
with open(filename, 'r') as file:
content = file.read()

def replacer(match):
import_statement = match.group(1)
if '# pylint: disable=redefined-builtin' not in import_statement:
return f'# pylint: disable=redefined-builtin\n{import_statement}'
return import_statement

pattern_single = re.compile(r'(from pyrecest\.backend import.*\b(?:all|any|sum)\b.*)')
pattern_multi = re.compile(r'(from pyrecest\.backend import \([\s\S]*?\b(?:all|any|sum)\b[\s\S]*?\))', re.MULTILINE)

content = re.sub(pattern_single, replacer, content)
content = re.sub(pattern_multi, replacer, content)

# Ensure there's a newline at the end of the file
content = content.rstrip('\n') + '\n'

with open(filename, 'w') as file:
file.write(content)


def add_import_statements_and_replace(root_dir):
script_path = os.path.abspath(__file__) # Get the path of this script
for subdir, dirs, files in os.walk(root_dir):
# Skip hidden directories and directories starting with _
dirs[:] = [d for d in dirs if not d.startswith(('.', '_'))]
for file in files:
if file.endswith('.py') and not file == os.path.basename(script_path):
file_path = os.path.join(subdir, file)
add_pylint_disable(file_path)


# Specify the directory to start the search from
root_directory = './pyrecest'
add_import_statements_and_replace(root_directory)
39 changes: 39 additions & 0 deletions change_files_for_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os

def add_import_statements_and_replace(root_dir):
items = [
"meshgrid",
]
script_path = os.path.abspath(__file__) # Get the path of this script
for subdir, dirs, files in os.walk(root_dir):
# Skip hidden directories and directories starting with _
dirs[:] = [d for d in dirs if not d.startswith(('.', '_'))]
for file in files:
# Only process Python files
if file.endswith('.py') and not file.startswith('.'):
file_path = os.path.join(subdir, file)
if file_path == script_path:
# Skip processing this script file
continue
with open(file_path, 'r') as f:
file_content = f.read()
updated_content = file_content
# Check if any item is in the file content
import_statements = []
for item in items:
if f'np.{item}' in file_content:
import_statements.append(f'from pyrecest.backend import {item}')
updated_content = updated_content.replace(f'np.{item}', item)
# Prepend import statements if not already present
for import_statement in import_statements:
if import_statement not in updated_content:
updated_content = f'{import_statement}\n{updated_content}'
# Write updated content back to file if any changes were made
if updated_content != file_content:
with open(file_path, 'w') as f:
f.write(updated_content)
print(f'Updated {file_path}')

# Specify the directory to start the search from
root_directory = './pyrecest'
add_import_statements_and_replace(root_directory)
55 changes: 55 additions & 0 deletions grid_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np

def generate_gaussian_like_grid_s3_v1(n_lat, n_lon):
# Generate Gaussian-like latitudes
beta = np.linspace(-np.pi / 2, np.pi / 2, n_lat + 1)
latitudes = np.arcsin(0.5 * (np.sin(beta[:-1]) + np.sin(beta[1:])))

# Generate longitudes for each latitude
longitudes = []
for i, lat in enumerate(latitudes):
n_lon_i = int(n_lon * np.cos(lat))
longitudes.append(np.linspace(0, 2 * np.pi, n_lon_i, endpoint=False))

# Generate Cartesian coordinates for the Gaussian-like grid on S^3
grid_points = []
for lat, lon_set in zip(latitudes, longitudes):
for lon in lon_set:
x = np.cos(lat) * np.cos(lon)
y = np.cos(lat) * np.sin(lon)
z = np.sin(lat)
w = np.sqrt(1 - x**2 - y**2 - z**2)
grid_points.append([w, x, y, z])

return np.array(grid_points)

n_lat = 10
n_lon = 20
grid_points_s3 = generate_gaussian_like_grid_s3_v1(n_lat, n_lon)

def generate_gaussian_grid_on_S3_v2(n_theta, n_phi, sigma_theta, sigma_phi):
theta = np.linspace(0, np.pi, n_theta)
phi = np.linspace(0, 2 * np.pi, n_phi)

theta_weights = np.exp(-0.5 * (theta - np.pi/2)**2 / sigma_theta**2)
theta_weights /= np.sum(theta_weights)
phi_weights = np.exp(-0.5 * phi**2 / sigma_phi**2)
phi_weights /= np.sum(phi_weights)

theta_grid, phi_grid = np.meshgrid(theta, phi, indexing='ij')
x = np.sin(theta_grid) * np.cos(phi_grid)
y = np.sin(theta_grid) * np.sin(phi_grid)
z = np.cos(theta_grid)

weights = np.outer(theta_weights, phi_weights)
grid_points = np.stack([x, y, z], axis=-1)

return grid_points, weights

n_theta = 10
n_phi = 10
sigma_theta = np.pi / 8
sigma_phi = np.pi / 8

grid_points, weights = generate_gaussian_grid_on_S3_v2(n_theta, n_phi, sigma_theta, sigma_phi)
print(grid_points)
146 changes: 146 additions & 0 deletions lcd_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Based on Matlab code by UDH

import numpy as np
from scipy.optimize import minimize
from scipy.interpolate import UnivariateSpline
from scipy.integrate import quad

import matplotlib.pyplot as plt

def glcd(x, y, wxy, SX, SY):
assert x.ndim == 1, 'x must be a vector'
assert y.ndim == 1, 'y must be a vector'
assert wxy.ndim == 1, 'y must be a vector'

xb = np.concatenate((x, y))
res = minimize(objective_function, xb, args=(wxy, SX, SY), method='L-BFGS-B', jac=gradient_function, options={'disp': False, 'ftol': 1e-22, 'maxiter': 5000})
x, y = res.x[:len(x)], res.x[len(x):]
return x, y, res.fun

def objective_function(xb, wxy, SX, SY):
x, y = xb[:len(wxy)], xb[len(wxy):]
lambda_ = np.array([1000, 1000, 10, 10, 10])
f = distance_measure_gaussian_numeric(wxy, x, y, SX, SY, lambda_)
return f

def gradient_function(xb, wxy, SX, SY):
x, y = xb[:len(wxy)], xb[len(wxy):]
lambda_ = np.array([1000, 1000, 10, 10, 10])
r = gradient_gaussian_numeric(wxy, x, y, SX, SY, lambda_)
return r

def distance_measure_gaussian_numeric(wxy, x, y, SX, SY, lambda_):
flag = 0
sx = SX
sy = SY
bmax = 20
Cb = np.log(4 * bmax ** 2) - 0.577216
b = np.linspace(0.0001, np.sqrt(bmax), 100) ** 2
xx = np.column_stack((x, y))
N = 2
L = wxy.shape[0]

if flag == 0:
flag = 1
G1 = np.pi ** (N / 2) * b ** (N + 1) / (np.sqrt(sx ** 2 + b ** 2) * np.sqrt(sy ** 2 + b ** 2))
pp = UnivariateSpline(b, G1, s=0)
G1, _ = quad(pp, 0, bmax)

G2t = 0

for i in range(L):
G2t += wxy[i] * np.exp(-0.5 * (xx[i, 0] ** 2 / (sx ** 2 + 2 * b ** 2) + xx[i, 1] ** 2 / (sy ** 2 + 2 * b ** 2)))

G2 = -2 * (2 * np.pi) ** (N / 2) * b ** (N + 1) / (np.sqrt(sx ** 2 + 2 * b ** 2) * np.sqrt(sy ** 2 + 2 * b ** 2)) * G2t
pp = UnivariateSpline(b, G2, s=0)
G2, _ = quad(pp, 0, bmax)

Mxx = np.subtract.outer(x, x).T
Myy = np.subtract.outer(y, y).T
T = Mxx ** 2 + Myy ** 2
G3 = np.squeeze(np.pi * wxy @ (4 * bmax ** 2 * np.exp(-0.5 * T / (2 * bmax ** 2)) - Cb * T + xplog(T) - T ** 2 / (4 * bmax ** 2)) @ wxy / 8)

G = G1 + G2 + G3 + lambda_[0] * (wxy @ x) ** 2 + lambda_[1] * (wxy @ y) ** 2 + \
lambda_[2] * (wxy @ (x ** 2) - sx ** 2) ** 2 + lambda_[3] * (wxy @ (x * y)) ** 2 + \
lambda_[4] * (wxy @ (y ** 2) - sy ** 2) ** 2

return G

def gradient_gaussian_numeric(wxy, x, y, SX, SY, lambda_):
s = np.array([SX, SY])
bmax = 20
Cb = np.log(4 * bmax ** 2) - 0.577216
xx = np.column_stack((x, y))
N = 2
L = len(wxy)
db = 0.005
b = np.arange(db, bmax + db, db)

H = 2 * (2 * np.pi) ** (N / 2) * b ** (N + 1) / (np.sqrt(s[0] ** 2 + 2 * b ** 2) * np.sqrt(s[1] ** 2 + 2 * b ** 2))

G1 = zeros((2 * L, len(b)))

for eta in range(2):
k = H / (s[eta] ** 2 + 2 * b ** 2)
for i in range(L):
G1[eta * L + i, :] = wxy[i] * xx[i, eta] * k * np.exp(-0.5 * (xx[i, 0] ** 2 / (s[0] ** 2 + 2 * b ** 2) + xx[i, 1] ** 2 / (s[1] ** 2 + 2 * b ** 2)))

G1 = db * np.sum(G1, axis=1)

Mxx = np.subtract.outer(x, x).T
Myy = np.subtract.outer(y, y).T
M = Mxx ** 2 + Myy ** 2
T = plog(M) - M / (4 * bmax ** 2)

rx = (wxy @ (Mxx * T))
ry = (wxy @ (Myy * T))

G2 = np.pi * np.hstack((np.squeeze(rx) + Cb * (wxy @ x - x), np.squeeze(ry) + Cb * (wxy @ y - y))) / (2 * L)

G3 = np.hstack((
np.squeeze(2 * lambda_[0] * (wxy @ x) * wxy) + 4 * (wxy @ (x ** 2) - s[0] ** 2) * lambda_[2] * np.squeeze(wxy) * x + 2 * (wxy @ (x * y)) * lambda_[3] * np.squeeze(wxy) * y,
np.squeeze(2 * lambda_[1] * (wxy @ y) * wxy) + 4 * (wxy @ (y ** 2) - s[1] ** 2) * lambda_[4] * np.squeeze(wxy) * y + 2 * (wxy @ (x * y)) * lambda_[3] * np.squeeze(wxy) * x))

G = G1 + G2 + G3

return G

def plog(x):
x = np.array(x)
indx = (x == 0)
x[indx] = 1
y = np.log(x)
return y

def xplog(x):
return x * plog(x)

def randn_box_muller(n_samples):
numpy_random_numbers = np.random.randn(n_samples)
return numpy_random_numbers


if __name__ == "__main__":
SX = 1
SY = 0.7
L = 10

for SY in np.arange(0.9, 0.001, -0.02):
wxy = np.ones(L) / L
seed_value = 42
np.random.seed(seed_value)


x = SX * np.array(randn_box_muller(int(L)))
y = SY * np.array(randn_box_muller(int(L)))

x, y, G = glcd(x, y, wxy, SX, SY)
plt.cla()
plt.plot(x, y, '.', markeredgecolor=[1, 1, 1], markersize=10)
plt.plot(x, y, 'r.', markersize=7)
plt.axis('equal')
plt.gca().set_xlim([-4, 4])
plt.gca().set_ylim([-4, 4])
plt.draw()
plt.show(block=False)
plt.pause(.001)
Loading

0 comments on commit e79a755

Please sign in to comment.