diff --git a/lib/data_structures/seg_tree.ex b/lib/data_structures/seg_tree.ex new file mode 100644 index 0000000..33d5346 --- /dev/null +++ b/lib/data_structures/seg_tree.ex @@ -0,0 +1,116 @@ +defmodule Algorithms.DataStructures.SegTree do + @moduledoc """ + Generic Segment Tree implementation in Elixir. + + Usage example: + + op = fn a, b -> min(a, b) end + e = fn -> :infinity end + + seg = Algorithms.DataStructures.SegTree.new([5, 3, 7, 9, 6], op, e) + Algorithms.DataStructures.SegTree.prod(seg, 1, 4) # => 3 + seg = Algorithms.DataStructures.SegTree.set(seg, 2, 1) + Algorithms.DataStructures.SegTree.prod(seg, 1, 4) # => 1 + """ + + defstruct size: 0, n: 0, data: nil, op: nil, e: nil + + @type t :: %__MODULE__{ + size: non_neg_integer(), + n: non_neg_integer(), + data: :array.array(), + op: (any(), any() -> any()), + e: (() -> any()) + } + + @spec new(list(any()), (any(), any() -> any()), (() -> any())) :: t() + def new(list, op, e) do + n = length(list) + size = bit_ceil(n) + data = :array.new(size * 2, default: e.()) + + data = + Enum.with_index(list) + |> Enum.reduce(data, fn {v, i}, acc -> + :array.set(size + i, v, acc) + end) + + data = build_tree(data, size, op) + + %__MODULE__{size: size, n: n, data: data, op: op, e: e} + end + + defp build_tree(data, size, op) do + Enum.reduce(Enum.reverse(1..(size - 1)), data, fn i, acc -> + left = :array.get(2 * i, acc) + right = :array.get(2 * i + 1, acc) + :array.set(i, op.(left, right), acc) + end) + end + + @spec set(t(), non_neg_integer(), any()) :: t() + def set(%__MODULE__{size: size, data: data, op: op} = st, p, x) do + data = :array.set(size + p, x, data) + data = update_up(size + p, data, op) + %__MODULE__{st | data: data} + end + + defp update_up(1, data, _op), do: data + + defp update_up(k, data, op) do + parent = div(k, 2) + left = :array.get(2 * parent, data) + right = :array.get(2 * parent + 1, data) + data = :array.set(parent, op.(left, right), data) + update_up(parent, data, op) + end + + @spec get(t(), non_neg_integer()) :: any() + def get(%__MODULE__{size: size, data: data}, p) do + :array.get(size + p, data) + end + + @spec prod(t(), non_neg_integer(), non_neg_integer()) :: any() + def prod(%__MODULE__{size: size, data: data, op: op, e: e}, l, r) do + do_prod(l + size, r + size, e.(), e.(), data, op) + end + + defp do_prod(l, r, sml, smr, data, op) do + do_prod_loop(l, r, sml, smr, data, op) + end + + defp do_prod_loop(l, r, sml, smr, data, op) when l < r do + sml = + if rem(l, 2) == 1 do + op.(sml, :array.get(l, data)) + else + sml + end + + l = if rem(l, 2) == 1, do: l + 1, else: l + + smr = + if rem(r, 2) == 1 do + op.(:array.get(r - 1, data), smr) + else + smr + end + + r = if rem(r, 2) == 1, do: r - 1, else: r + + do_prod_loop(div(l, 2), div(r, 2), sml, smr, data, op) + end + + defp do_prod_loop(_, _, sml, smr, _, op), do: op.(sml, smr) + + @spec all_prod(t()) :: any() + def all_prod(%__MODULE__{data: data}), do: :array.get(1, data) + + @spec bit_ceil(non_neg_integer()) :: non_neg_integer() + defp bit_ceil(0), do: 1 + + defp bit_ceil(n) when n > 0 do + pow = :math.ceil(:math.log2(n)) + trunc(:math.pow(2, pow)) + end +end diff --git a/test/data_structures/seg_tree_test.exs b/test/data_structures/seg_tree_test.exs new file mode 100644 index 0000000..116fae3 --- /dev/null +++ b/test/data_structures/seg_tree_test.exs @@ -0,0 +1,56 @@ +defmodule Algorithms.DataStructures.SegTreeTest do + use ExUnit.Case + + alias Algorithms.DataStructures.SegTree + + describe "Segment Tree basic functionality" do + setup do + op = fn a, b -> min(a, b) end + e = fn -> :infinity end + + seg = SegTree.new([5, 3, 7, 9, 6], op, e) + %{seg: seg, op: op, e: e} + end + + test "prod returns correct range minimum", %{seg: seg} do + assert SegTree.prod(seg, 0, 5) == 3 + assert SegTree.prod(seg, 1, 4) == 3 + assert SegTree.prod(seg, 2, 4) == 7 + end + + test "get returns correct value", %{seg: seg} do + assert SegTree.get(seg, 0) == 5 + assert SegTree.get(seg, 1) == 3 + assert SegTree.get(seg, 4) == 6 + end + + test "set updates value and affects prod", %{seg: seg} do + seg = SegTree.set(seg, 1, 10) + assert SegTree.get(seg, 1) == 10 + assert SegTree.prod(seg, 0, 5) == 5 + + seg = SegTree.set(seg, 0, 1) + assert SegTree.prod(seg, 0, 5) == 1 + end + + test "all_prod returns correct result", %{seg: seg} do + assert SegTree.all_prod(seg) == 3 + end + end + + describe "Segment Tree with sum operation" do + setup do + op = fn a, b -> a + b end + e = fn -> 0 end + + seg = SegTree.new([1, 2, 3, 4, 5], op, e) + %{seg: seg} + end + + test "prod returns correct range sum", %{seg: seg} do + assert SegTree.prod(seg, 0, 5) == 15 + assert SegTree.prod(seg, 1, 3) == 5 + assert SegTree.prod(seg, 2, 4) == 7 + end + end +end