diff --git a/examples/gmrf.py b/examples/gmrf.py index 5228d0d..8b330c4 100644 --- a/examples/gmrf.py +++ b/examples/gmrf.py @@ -155,6 +155,7 @@ # # The following training loop requires a GPU with at least 11 GB of memory. + # %% @jax.jit def loss(noisy_image, target_image, log_potentials): diff --git a/examples/ising_model.py b/examples/ising_model.py index 1befd27..5a2826b 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -65,6 +65,7 @@ # %% [markdown] # ### Gradients and batching + # %% def loss(log_potentials_updates, evidence_updates): bp_arrays = bp.init( diff --git a/examples/rcn.py b/examples/rcn.py index df5a30d..24a7eaf 100644 --- a/examples/rcn.py +++ b/examples/rcn.py @@ -225,6 +225,7 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5) -> tuple[np.ndarray, np.n # %% [markdown] # ### 4.2.1 Pre-compute the valid configs for different perturb radius. + # %% def valid_configs(r: int, hps: int, vps: int) -> np.ndarray: """Returns the valid configurations for a factor given the perturb radius. @@ -294,6 +295,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray: # %% [markdown] # ## 5.1 Helper functions to initialize the evidence for a given image + # %% def get_bu_msg(img: np.ndarray) -> np.ndarray: """Computes the bottom-up messages given a test image. @@ -365,6 +367,7 @@ def get_bu_msg(img: np.ndarray) -> np.ndarray: # %% [markdown] # ## 5.2 Run MAP inference on all test images + # %% def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray: """Returns the evidence (shape (n_frcs, M)). diff --git a/pgmax/factor/enum.py b/pgmax/factor/enum.py index 856965e..bd27075 100644 --- a/pgmax/factor/enum.py +++ b/pgmax/factor/enum.py @@ -256,7 +256,6 @@ def pass_enum_fac_to_var_messages( num_val_configs: int, temperature: float, ) -> jnp.ndarray: - """Passes messages from EnumFactors to Variables. The update is performed in two steps. First, a "summary" array is generated that has an entry for every valid diff --git a/pgmax/factor/logical.py b/pgmax/factor/logical.py index b9d9fcd..97e19a9 100644 --- a/pgmax/factor/logical.py +++ b/pgmax/factor/logical.py @@ -312,7 +312,6 @@ def pass_logical_fac_to_var_messages( temperature: float, log_potentials: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: - """Passes messages from LogicalFactors to Variables. Args: