Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checkpointed_scan #60

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 71 additions & 14 deletions src/haliax/hof.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import dataclasses
import functools
import inspect
from functools import wraps
from typing import Any, Callable, ParamSpec, Protocol, Tuple, TypeVar, Union, overload
from typing import Any, Callable, Optional, ParamSpec, Protocol, Sequence, Tuple, TypeVar, Union, overload

import equinox as eqx
import jax
import jax.lax as lax
import numpy as np
from jaxtyping import PyTree

import haliax
Expand All @@ -14,7 +16,7 @@
from ._src.util import index_where
from .axis import Axis, AxisSelector, selects_axis
from .core import NamedArray
from .jax_utils import Static, broadcast_prefix, is_jax_array_like
from .jax_utils import Static, broadcast_prefix, checkpointed_scan, is_jax_array_like
from .partitioning import physical_axis_name
from .util import is_jax_or_hax_array_like, is_named_array

Expand Down Expand Up @@ -45,6 +47,8 @@ def scan(
reverse: bool = False,
unroll: int = 1,
is_scanned: BoolAxisSpec = is_named_or_shaped_array_like,
grad_checkpointing: bool = False,
checkpoint_blocks: Optional[Sequence[int]] = None,
) -> Callable[[Carry, PyTree[X]], Tuple[Carry, PyTree[Y]]]:
...

Expand All @@ -57,6 +61,8 @@ def scan(
reverse: bool = False,
unroll: int = 1,
is_scanned: BoolAxisSpec = is_named_or_shaped_array_like,
grad_checkpointing: bool = False,
checkpoint_blocks: Optional[Sequence[int]] = None,
) -> Callable:
...

Expand All @@ -68,6 +74,8 @@ def scan(
reverse=False,
unroll=1,
is_scanned: BoolAxisSpec = is_named_or_shaped_array_like,
grad_checkpointing: bool = False,
checkpoint_blocks: Optional[Sequence[int]] = None,
):
"""
Scan over a named axis. Non-scalar unnamed arrays will have their first axis scanned over.
Expand Down Expand Up @@ -112,6 +120,16 @@ def scanned_f(init, *args, **kwargs):
# invariants until we're ready to create the result.
axis_first_xs = htu.tree_map(_ensure_first(axis), scanned_xs)

# if we were passed in a string arg, we need to get its axis size out from some arg
if isinstance(axis, str):
true_axis = _infer_axis_size_from_tree(axis_first_xs, axis)
if true_axis is not None:
true_axis
else:
raise ValueError("scan requires either an actual Axis or at least one NamedArray or array arg")
else:
true_axis = axis

# now get a template of an element of "X"
x_elem = htu.tree_map(_select_0th(axis), axis_first_xs)
# NB: we don't want to use htu.tree_structure here because we want to eliminate the leading axis
Expand All @@ -130,16 +148,39 @@ def wrapped_fn(carry, scanned_x_leaves):

# as above, we don't want to use htu.tree_leaves here because we want to eliminate the leading axis
leaves = jax.tree_util.tree_leaves(axis_first_xs)
with jax.named_scope(f"scan({haliax.axis_name(axis)})"):
carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
true_axis = _infer_axis_size_from_result(ys, axis)
if grad_checkpointing:
if unroll != 1:
# TODO: support for case when it's a suffix of block size?
raise ValueError("Can't use grad_checkpointing with unroll != 1")
with jax.named_scope(f"ckpt_scan({haliax.axis_name(axis)})"):
blocks = _rectify_scan_lengths(true_axis, checkpoint_blocks)

scan_fn = functools.partial(checkpointed_scan, lengths=blocks, prevent_cse=False, reverse=reverse)
carry, ys = scan_fn(wrapped_fn, init, leaves)
else:
with jax.named_scope(f"scan({haliax.axis_name(axis)})"):
carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)

true_axis = _infer_axis_size_from_tree(ys, axis)
ys = jax.tree_util.tree_map(_prepend_named_batch_axis(true_axis), ys, is_leaf=_is_passive_array)

return carry, ys

return scanned_f


def _rectify_scan_lengths(axis: Axis, checkpoint_blocks: Optional[Sequence[int]]) -> list[int]:
blocks = checkpoint_blocks or [axis.size]
cur_size = np.prod(blocks)
if cur_size != axis.size:
left = axis.size // cur_size
if left * cur_size != axis.size:
raise ValueError(f"Can't partition {axis.size} into blocks of size {blocks}")
return list(blocks) + [left]
else:
return list(blocks)


@overload
def fold(
fn: Callable[[Carry, X], Carry],
Expand All @@ -148,6 +189,8 @@ def fold(
reverse: bool = False,
unroll: int = 1,
is_scanned: BoolAxisSpec = is_jax_or_hax_array_like,
grad_checkpointing: bool = False,
checkpoint_blocks: Optional[Sequence[int]] = None,
) -> Callable[[Carry, PyTree[X]], Carry]:
...

Expand All @@ -160,6 +203,8 @@ def fold(
reverse: bool = False,
unroll: int = 1,
is_scanned: BoolAxisSpec = is_jax_or_hax_array_like,
grad_checkpointing: bool = False,
checkpoint_blocks: Optional[Sequence[int]] = None,
) -> Callable:
...

Expand All @@ -171,6 +216,8 @@ def fold(
reverse: bool = False,
unroll: int = 1,
is_scanned: BoolAxisSpec = is_named_or_shaped_array_like,
grad_checkpointing: bool = False,
checkpoint_blocks: Optional[Sequence[int]] = None,
) -> Callable:
"""
Slightly simpler implementation of scan that folds over the named axis of the array, not returning intermediates.
Expand All @@ -196,7 +243,15 @@ def fold(
def scan_compatible_fn(carry, *args, **kwargs):
return fn(carry, *args, **kwargs), None

scan_preconfig = scan(scan_compatible_fn, axis, reverse=reverse, unroll=unroll, is_scanned=is_scanned)
scan_preconfig = scan(
scan_compatible_fn,
axis,
reverse=reverse,
unroll=unroll,
is_scanned=is_scanned,
grad_checkpointing=grad_checkpointing,
checkpoint_blocks=checkpoint_blocks,
)

def scanned_f(init, *args, **kwargs):
return scan_preconfig(init, *args, **kwargs)[0]
Expand Down Expand Up @@ -359,7 +414,7 @@ def wrapped_fn(args, kwargs):
result = eqx.combine(result_dynamic, result_static.value)

# if we were passed in a string arg, we need to get its axis size out from some result
true_axis = _infer_axis_size_from_result(result, axis)
true_axis = _infer_axis_size_from_tree(result, axis)
if true_axis is None:
raise ValueError("vmap failed to infer axis size from result")

Expand All @@ -369,17 +424,19 @@ def wrapped_fn(args, kwargs):
return wrapped_vmap_fn


def _infer_axis_size_from_result(result, axis):
def _infer_axis_size_from_tree(result, axis):
if isinstance(axis, str):
result_leaves = jax.tree_util.tree_leaves(result, is_leaf=_is_passive_array)
if len(result_leaves) == 0:
# this really shouldn't happen
return None
if isinstance(result_leaves[0], _PassiveNamedArray):
true_axis_size = result_leaves[0].array.shape[0] # batch axis is defined to be 0 above
leaf = result_leaves[0]
if isinstance(leaf, _PassiveNamedArray):
true_axis_size = leaf.array.shape[0] # batch axis is defined to be 0 above
true_axis = Axis(axis, true_axis_size)
else:
true_axis_size = result_leaves[0].shape[0] # batch axis is defined to be 0 above
elif isinstance(leaf, NamedArray):
true_axis = leaf.resolve_axis(axis)
elif isinstance(leaf, jax.numpy.ndarray) and leaf.ndim > 0:
true_axis_size = leaf.shape[0] # batch axis is defined to be 0 above
true_axis = Axis(axis, true_axis_size)
else:
true_axis = axis
Expand Down Expand Up @@ -424,7 +481,7 @@ def tree_unflatten(cls, aux, tree: Any) -> Any:


def _is_passive_array(arr):
return isinstance(arr, _PassiveNamedArray)
return isinstance(arr, _PassiveNamedArray) or isinstance(arr, NamedArray)


def _prepend_named_batch_axis(leading_axis: Axis):
Expand Down
80 changes: 80 additions & 0 deletions src/haliax/jax_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import functools as ft
import typing
from typing import Any, Callable, List, Optional, Sequence, Union
Expand All @@ -9,6 +10,8 @@
from jax import random as jrandom
from jaxtyping import PRNGKeyArray

import haliax


F = typing.TypeVar("F", bound=Callable[..., Any])

Expand Down Expand Up @@ -140,3 +143,80 @@ def is_pallas_dslice(x: object) -> bool:

_PALLAS_DSLICE_TYPE = type(pdslice(0, 1))
return isinstance(x, _PALLAS_DSLICE_TYPE)


def is_scalarish(x):
if isinstance(x, haliax.NamedArray):
return x.ndim == 0
else:
return jnp.isscalar(x) or x.shape == ()


def checkpointed_scan(
body_fn,
init,
xs,
lengths: Sequence[int],
*,
reverse: bool = False,
policy: Optional[Callable[..., bool]] = None,
prevent_cse: bool = False,
):
"""
Runs a recursive checkpointed scan over xs, where the scan is split into multiple scans, each of which has length
lengths[i] for some i.

This uses less memory than not checkpointing a scan, but more than

Note this uses "vanilla" JAX arrays, not NamedArrays

"""
if len(lengths) == 1:
return jax.lax.scan(jax.checkpoint(body_fn, prevent_cse=prevent_cse, policy=policy), init, xs, lengths[0])
else:
# we want to split the scan up into multiple recursive scans, doing a total of `prod(lengths)` steps
# this makes a tree of scans with depth len(lengths)
# check total length against any xs
total_length = np.prod(lengths)

def check_leaf(x):
assert x.shape[0] == total_length

jax.tree_util.tree_map(lambda x: check_leaf(x), xs)

ckpt = functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)

@ckpt
def _body_fn(carry, i, start):
my_xs = jax.tree_util.tree_map(lambda x: x[start + i], xs)
return body_fn(carry, my_xs)

def rec_scan_fn(lengths):
# returns a fn that, when called, scans over prod(lengths) steps recursively
if len(lengths) == 1:
range = jnp.arange(lengths[0])
return ckpt(
lambda carry, start: jax.lax.scan(
functools.partial(_body_fn, start=start), carry, range, lengths[0], reverse=reverse
)
)
else:
my_len = lengths[0]
rest_len = lengths[1:]
range_to_scan = jnp.arange(my_len) * np.prod(rest_len)
return ckpt(
lambda carry, start: jax.lax.scan(
rec_scan_fn(rest_len),
carry,
range_to_scan + start,
reverse=reverse,
)
)

res, unflattened = rec_scan_fn(lengths)(init, 0)

# need to flatten the output
# we need to flatten the leading len(lengths) dimensions of the output
flattened = jax.tree_util.tree_map(lambda y: jnp.reshape(y, (-1,) + y.shape[len(lengths) :]), unflattened)

return res, flattened
31 changes: 25 additions & 6 deletions src/haliax/nn/scan.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import functools
import math
from typing import Dict, Generic, Optional, Protocol, Sequence, Type, TypeVar

import equinox as eqx
import jax

import haliax
import haliax.util
from haliax.jax_utils import filter_checkpoint
from haliax.jax_utils import filter_checkpoint, named_call

from ..axis import Axis

Expand Down Expand Up @@ -70,7 +71,7 @@ class Stacked(eqx.Module, Generic[M]):

@staticmethod
def init(
Block: Axis, module: Type[M], *, gradient_checkpointing: bool = False, prevent_cse: bool = True
Block: Axis, module: Type[M], *, gradient_checkpointing: bool = False, prevent_cse: bool = False
) -> ModuleInit["Stacked[M]"]:
"""
Initialize a Stacked module. This method is curried: you can pass in the Block and module, and it will return
Expand All @@ -89,16 +90,34 @@ def fn(*args, **kwargs):

return fn

def scan(self, init, *extra_args, **extra_kwargs):
@named_call(name="Stacked.scan")
def scan(self, init, *args, **kwargs):
if self.gradient_checkpointing:
do_block = filter_checkpoint(self._do_block, prevent_cse=self.prevent_cse)
# determine a checkpoint block size, should be roughly sqrt(self.Block.size)
size = int(math.sqrt(self.Block.size))
num_blocks = int(math.ceil(self.Block.size / size))
rest = self.Block.size // size
block_spec = [num_blocks, rest]

return haliax.scan(
do_block, self.Block, grad_checkpointing=self.gradient_checkpointing, checkpoint_blocks=block_spec
)(init, self.stacked, *args, **kwargs)
else:
do_block = self._do_block
return haliax.scan(do_block, self.Block)(init, self.stacked, *extra_args, **extra_kwargs)
return haliax.scan(self._do_block, self.Block)(init, self.stacked, *args, **kwargs)

@named_call(name="Stacked.fold")
def fold(self, init, *args, **kwargs):
print(f"FOLD! {self.gradient_checkpointing} {self.prevent_cse}", flush=True)
if self.gradient_checkpointing:
do_block = filter_checkpoint(self._do_block)
do_block = filter_checkpoint(self._do_block, prevent_cse=self.prevent_cse)
# determine a checkpoint block size, should be roughly sqrt(self.Block.size)
size = int(math.sqrt(self.Block.size))
num_blocks = int(math.ceil(self.Block.size / size))

return haliax.fold(
do_block, self.Block, grad_checkpointing=self.gradient_checkpointing, checkpoint_blocks=[num_blocks]
)(init, self.stacked, *args, **kwargs)
else:
do_block = self._do_block

Expand Down
38 changes: 38 additions & 0 deletions tests/test_hof.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,41 @@ def __call__(self, x):
Width = Axis("Width", 3)

hax.vmap(lambda a: Module(a), Batch)(Width)


def test_scan_raises_with_string_arg_and_no_args():
def scan_fun(acc):
return acc, acc

try:
hax.scan(scan_fun, "Height")(0.0)
except ValueError as e:
assert "scan requires either an actual Axis or at least one NamedArray or array" in str(e)
else:
assert False, "should have raised"


def test_scan_works_with_string_arg_and_one_arg():
Height = Axis("Height", 10)
named1 = hax.random.uniform(PRNGKey(0), (Height,))

def scan_fun(acc, x):
return acc + x.scalar(), x

total, named2 = hax.scan(scan_fun, "Height")(0.0, named1)

assert jnp.all(jnp.isclose(total, jnp.sum(named1.array)))
assert jnp.all(jnp.equal(named1.array, named2.array))


def test_scan_works_with_string_and_unnamed_args():
Height = Axis("Height", 10)
named1 = hax.random.uniform(PRNGKey(0), (Height,))

def scan_fun(acc, x):
return acc + x, x

total, named2 = hax.scan(scan_fun, "Height")(0.0, named1.array)

assert jnp.all(jnp.isclose(total, jnp.sum(named1.array)))
assert jnp.all(jnp.equal(named1.array, named2))
Loading