Skip to content

Commit

Permalink
improvement: Add generic enumerable conversion with adequate warnings.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimsynz committed Jul 19, 2024
1 parent 5373e19 commit 99f862a
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 49 deletions.
116 changes: 67 additions & 49 deletions lib/iter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ defmodule Iter do
Is the passed value an iterator?
"""
@spec is_iter(any) :: Macro.output()
defguard is_iter(value) when is_struct(value, __MODULE__)
defguard is_iter(iter) when is_struct(iter, __MODULE__)

@doc """
Returns `true` if all elements in the iterator are truthy.
Expand Down Expand Up @@ -114,7 +114,7 @@ defmodule Iter do
false
"""
@spec any?(t) :: boolean
def any?(iterable), do: any?(iterable, &(&1 not in [nil, false]))
def any?(iter), do: any?(iter, &(&1 not in [nil, false]))

@doc """
Returns `true` if `fun.(element)` is truthy for at least one element in the
Expand Down Expand Up @@ -142,8 +142,8 @@ defmodule Iter do
false
"""
@spec any?(t, predicate) :: boolean
def any?(iterable, predicate) when is_iter(iterable) and is_function(predicate, 1),
do: Iterable.any?(iterable.iterable, predicate)
def any?(iter, predicate) when is_iter(iter) and is_function(predicate, 1),
do: Iterable.any?(iter.iterable, predicate)

@doc """
Append a new element to the end of the iterable.
Expand Down Expand Up @@ -362,12 +362,12 @@ defmodule Iter do
[1, 2, 3, 4]
"""
@spec concat(t) :: t
def concat(iter) when is_iter(iter),
do:
iter.iterable
|> Iterable.map(&IntoIterable.into_iterable/1)
|> Iterable.concat()
|> new()
def concat(iter) when is_iter(iter) do
iter.iterable
|> Iterable.map(&IntoIterable.into_iterable/1)
|> Iterable.concat()
|> new()
end

@doc """
Creates an iterator that iterates the first argument, followed by the second argument.
Expand All @@ -380,11 +380,11 @@ defmodule Iter do
[1, 2, 3, 4, 5, 6]
"""
@spec concat(t, t) :: t
def concat(lhs, rhs) when is_iter(lhs) and is_iter(rhs),
do:
[lhs.iterable, rhs.iterable]
|> Iterable.concat()
|> new()
def concat(lhs, rhs) when is_iter(lhs) and is_iter(rhs) do
[lhs.iterable, rhs.iterable]
|> Iterable.concat()
|> new()
end

@doc """
Counts the elements in iterator stopping at `limit`.
Expand Down Expand Up @@ -737,10 +737,7 @@ defmodule Iter do
"""
@spec flat_map(t, mapper) :: t
def flat_map(iter, mapper) when is_iter(iter) and is_function(mapper, 1),
do:
iter.iterable
|> Iterable.flat_map(mapper)
|> new()
do: iter.iterable |> Iterable.flat_map(mapper) |> new()

@doc """
Flattens nested iterators.
Expand All @@ -767,6 +764,24 @@ defmodule Iter do
def from(maybe_iterable),
do: maybe_iterable |> IntoIterable.into_iterable() |> new()

@doc """
Convert an `Enumerable` into an `Iter`.
Provides an `Enumerable` compatible source for `Iter` using a `GenServer` to
orchestrate reduction and block as required.
> #### Warning {: .warning}
> You should almost always implement `IntoIterable` for your enumerable and
> use `from/1` rather than resort to calling this function. Unfortunately it
> cannot always be avoided.
"""
@spec from_enum(Enumerable.t()) :: t
def from_enum(enumerable) do
enumerable
|> Iterable.Enumerable.new()
|> new()
end

@doc """
Intersperses `separator` between each element of the iterator.
Expand Down Expand Up @@ -804,16 +819,16 @@ defmodule Iter do
[0, 1, 2, 3, 4]
"""
@spec iterate(element, (element -> element)) :: t
def iterate(start_value, next_fun) when is_function(next_fun, 1),
do:
Iterable.Resource.new(
fn -> start_value end,
fn acc ->
{[acc], next_fun.(acc)}
end,
fn _ -> :ok end
)
|> new()
def iterate(start_value, next_fun) when is_function(next_fun, 1) do
Iterable.Resource.new(
fn -> start_value end,
fn acc ->
{[acc], next_fun.(acc)}
end,
fn _ -> :ok end
)
|> new()
end

@doc """
Creates a new iterator which applies `mapper` on every `nth` element of the
Expand Down Expand Up @@ -865,10 +880,7 @@ defmodule Iter do
"""
@spec map(t, mapper) :: t
def map(iter, mapper) when is_iter(iter) and is_function(mapper, 1),
do:
iter.iterable
|> Iterable.map(mapper)
|> new()
do: iter.iterable |> Iterable.map(mapper) |> new()

@doc """
Returns the maximal element in the iterator according to Erlang's term sorting.
Expand Down Expand Up @@ -1344,12 +1356,12 @@ defmodule Iter do
[12, 15, 18]
"""
@spec zip_with(t, ([element] -> any)) :: t
def zip_with(iter, zipper) when is_iter(iter) and is_function(zipper, 1),
do:
iter.iterable
|> Iterable.map(&IntoIterable.into_iterable/1)
|> Iterable.zip(zipper)
|> new()
def zip_with(iter, zipper) when is_iter(iter) and is_function(zipper, 1) do
iter.iterable
|> Iterable.map(&IntoIterable.into_iterable/1)
|> Iterable.zip(zipper)
|> new()
end

@doc """
Zips corresponding elements from two iterators into a new one, transforming
Expand All @@ -1367,8 +1379,11 @@ defmodule Iter do
[5, 7, 9]
"""
@spec zip_with(t, t, (element, element -> any)) :: t
def zip_with(lhs, rhs, zipper) when is_iter(lhs) and is_iter(rhs) and is_function(zipper, 2),
do: [lhs.iterable, rhs.iterable] |> Iterable.zip(fn [a, b] -> zipper.(a, b) end) |> new()
def zip_with(lhs, rhs, zipper) when is_iter(lhs) and is_iter(rhs) and is_function(zipper, 2) do
[lhs.iterable, rhs.iterable]
|> Iterable.zip(fn [a, b] -> zipper.(a, b) end)
|> new()
end

@doc """
Zips corresponding elements from a finite collection of iterators into one iterator of tuples.
Expand All @@ -1387,12 +1402,12 @@ defmodule Iter do
[{1, :a, "a"}, {2, :b, "b"}, {3, :c, "c"}]
"""
@spec zip(t) :: t
def zip(iter) when is_iter(iter),
do:
iter.iterable
|> Iterable.map(&IntoIterable.into_iterable/1)
|> Iterable.zip(&List.to_tuple/1)
|> new()
def zip(iter) when is_iter(iter) do
iter.iterable
|> Iterable.map(&IntoIterable.into_iterable/1)
|> Iterable.zip(&List.to_tuple/1)
|> new()
end

@doc """
Zips to iterators together.
Expand All @@ -1408,8 +1423,11 @@ defmodule Iter do
[{1, :a}, {2, :b}, {3, :c}]
"""
@spec zip(t, t) :: t
def zip(lhs, rhs) when is_iter(lhs) and is_iter(rhs),
do: [lhs.iterable, rhs.iterable] |> Iterable.zip(&List.to_tuple/1) |> new()
def zip(lhs, rhs) when is_iter(lhs) and is_iter(rhs) do
[lhs.iterable, rhs.iterable]
|> Iterable.zip(&List.to_tuple/1)
|> new()
end

defp new(iterable), do: %__MODULE__{iterable: iterable}
end
107 changes: 107 additions & 0 deletions lib/iter/iterable/enumerable.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
defmodule Iter.Iterable.Enumerable do
defstruct pid: nil

@moduledoc """
Can we convert a enum into an iterable? Let's find out.
"""

alias Iter.{Impl, IntoIterable, Iterable}

@type t :: %__MODULE__{pid: pid}

use GenServer

@doc "Wrap an enumerable in a genserver"
@spec new(Enumerable.t()) :: t
def new(enum) do
case GenServer.start_link(__MODULE__, enum) do
{:ok, pid} -> %__MODULE__{pid: pid}
{:error, reason} -> raise reason
end
end

defimpl Iterable do
use Impl

@doc false
@impl true
def next(%{pid: pid} = enum) do
case GenServer.call(pid, :next, :infinity) do
{:ok, element} -> {:ok, element, enum}
:done -> :done
end
catch
:exit, _ -> :done
end
end

defimpl IntoIterable do
@doc false
@impl true
def into_iterable(self), do: self
end

@doc false
@impl GenServer
def init(enum) do
{:ok, %{enum: enum}}
end

@doc false
@impl GenServer
def handle_call(:next, from, %{enum: enum}) do
receiver = self()

{:ok, pid} =
Task.start_link(fn ->
Enum.reduce_while(enum, :ok, fn element, :ok ->
case GenServer.call(receiver, {:element, element}, :infinity) do

Check warning on line 58 in lib/iter/iterable/enumerable.ex

View workflow job for this annotation

GitHub Actions / Continuous Integration

Function body is nested too deep (max depth is 2, was 3).
:ok -> {:cont, :ok}
:halt -> {:halt, :ok}
end
end)
end)

Process.monitor(pid)

{:noreply, %{next_reply_to: from, source: pid}}
end

def handle_call(:next, _from, %{element: element, element_reply_to: from} = state) do
GenServer.reply(from, :ok)
{:reply, {:ok, element}, Map.drop(state, [:element, :element_reply_to])}
end

def handle_call(:next, _from, %{source: :done} = state) do
{:stop, :normal, :done, state}
end

def handle_call(:next, from, state) do
{:noreply, Map.put(state, :next_reply_to, from)}
end

def handle_call({:element, element}, from, state) do
case Map.pop(state, :next_reply_to) do
{nil, state} ->
state = Map.merge(state, %{element: element, element_reply_to: from})
{:noreply, state}

{from, state} ->
GenServer.reply(from, {:ok, element})
{:reply, :ok, state}
end
end

@doc false
@impl true
def handle_info({:DOWN, _, :process, pid, _}, %{source: pid} = state)
when is_map_key(state, :next_reply_to) do
GenServer.reply(state.next_reply_to, :done)

{:stop, :normal}
end

def handle_info({:DOWN, _, :process, pid, _}, %{source: pid} = state) do
{:noreply, %{state | source: :done}}
end
end
20 changes: 20 additions & 0 deletions test/iter/iterable/enumerable_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
defmodule Iter.Iterable.EnumerableTest do
@moduledoc false
use ExUnit.Case, async: true
alias Iter.Iterable

test "it can iterate a normal enum" do
iterable = Iterable.Enumerable.new([1, 2, 3])
assert [1, 2, 3] = Iterable.to_list(iterable)
end

test "it can iterate an infinite stream" do
stream = Stream.cycle([1, 2, 3])

assert [1, 2, 3, 1, 2] =
stream
|> Iterable.Enumerable.new()
|> Iterable.take_head(5)
|> Iterable.to_list()
end
end

0 comments on commit 99f862a

Please sign in to comment.