Skip to content

Commit

Permalink
Update use_with_jax.mdx
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq authored Jun 4, 2024
1 parent 8e37eaf commit fd3566c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions docs/source/use_with_jax.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ If your dataset consists of N-dimensional arrays, you will see that by default t

```py
>>> from datasets import Dataset
>>> data = [[[1, 2],[3, 4]], [[5, 6],[7, 8]]]
>>> data = [[[1, 2],[3, 4]], [[5, 6],[7, 8]]] # fixed shape
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("jax")
>>> ds[0]
Expand All @@ -93,7 +93,7 @@ If your dataset consists of N-dimensional arrays, you will see that by default t

```py
>>> from datasets import Dataset
>>> data = [[[1, 2],[3]], [[4, 5, 6],[7, 8]]]
>>> data = [[[1, 2],[3]], [[4, 5, 6],[7, 8]]] # varying shape
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("jax")
>>> ds[0]
Expand Down

0 comments on commit fd3566c

Please sign in to comment.