Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 3, 2023
1 parent b49721e commit 2f72194
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions examples/gmrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
#
# The following training loop requires a GPU with at least 11 GB of memory.


# %%
@jax.jit
def loss(noisy_image, target_image, log_potentials):
Expand Down
1 change: 1 addition & 0 deletions examples/ising_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
# %% [markdown]
# ### Gradients and batching


# %%
def loss(log_potentials_updates, evidence_updates):
bp_arrays = bp.init(
Expand Down
3 changes: 3 additions & 0 deletions examples/rcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5) -> tuple[np.ndarray, np.n
# %% [markdown]
# ### 4.2.1 Pre-compute the valid configs for different perturb radius.


# %%
def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
"""Returns the valid configurations for a factor given the perturb radius.
Expand Down Expand Up @@ -294,6 +295,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
# %% [markdown]
# ## 5.1 Helper functions to initialize the evidence for a given image


# %%
def get_bu_msg(img: np.ndarray) -> np.ndarray:
"""Computes the bottom-up messages given a test image.
Expand Down Expand Up @@ -365,6 +367,7 @@ def get_bu_msg(img: np.ndarray) -> np.ndarray:
# %% [markdown]
# ## 5.2 Run MAP inference on all test images


# %%
def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
"""Returns the evidence (shape (n_frcs, M)).
Expand Down
1 change: 0 additions & 1 deletion pgmax/factor/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def pass_enum_fac_to_var_messages(
num_val_configs: int,
temperature: float,
) -> jnp.ndarray:

"""Passes messages from EnumFactors to Variables.
The update is performed in two steps. First, a "summary" array is generated that has an entry for every valid
Expand Down
1 change: 0 additions & 1 deletion pgmax/factor/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,6 @@ def pass_logical_fac_to_var_messages(
temperature: float,
log_potentials: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:

"""Passes messages from LogicalFactors to Variables.
Args:
Expand Down

0 comments on commit 2f72194

Please sign in to comment.