Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions lib/data_structures/seg_tree.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
defmodule Algorithms.DataStructures.SegTree do
@moduledoc """
Generic Segment Tree implementation in Elixir.

Usage example:

op = fn a, b -> min(a, b) end
e = fn -> :infinity end

seg = Algorithms.DataStructures.SegTree.new([5, 3, 7, 9, 6], op, e)
Algorithms.DataStructures.SegTree.prod(seg, 1, 4) # => 3
seg = Algorithms.DataStructures.SegTree.set(seg, 2, 1)
Algorithms.DataStructures.SegTree.prod(seg, 1, 4) # => 1
"""

defstruct size: 0, n: 0, data: nil, op: nil, e: nil

@type t :: %__MODULE__{
size: non_neg_integer(),
n: non_neg_integer(),
data: :array.array(),
op: (any(), any() -> any()),
e: (() -> any())
}

@spec new(list(any()), (any(), any() -> any()), (() -> any())) :: t()
def new(list, op, e) do
n = length(list)
size = bit_ceil(n)
data = :array.new(size * 2, default: e.())

data =
Enum.with_index(list)
|> Enum.reduce(data, fn {v, i}, acc ->
:array.set(size + i, v, acc)
end)

data = build_tree(data, size, op)

%__MODULE__{size: size, n: n, data: data, op: op, e: e}
end

defp build_tree(data, size, op) do
Enum.reduce(Enum.reverse(1..(size - 1)), data, fn i, acc ->
left = :array.get(2 * i, acc)
right = :array.get(2 * i + 1, acc)
:array.set(i, op.(left, right), acc)
end)
end

@spec set(t(), non_neg_integer(), any()) :: t()
def set(%__MODULE__{size: size, data: data, op: op} = st, p, x) do
data = :array.set(size + p, x, data)
data = update_up(size + p, data, op)
%__MODULE__{st | data: data}
end

defp update_up(1, data, _op), do: data

defp update_up(k, data, op) do
parent = div(k, 2)
left = :array.get(2 * parent, data)
right = :array.get(2 * parent + 1, data)
data = :array.set(parent, op.(left, right), data)
update_up(parent, data, op)
end

@spec get(t(), non_neg_integer()) :: any()
def get(%__MODULE__{size: size, data: data}, p) do
:array.get(size + p, data)
end

@spec prod(t(), non_neg_integer(), non_neg_integer()) :: any()
def prod(%__MODULE__{size: size, data: data, op: op, e: e}, l, r) do
do_prod(l + size, r + size, e.(), e.(), data, op)
end

defp do_prod(l, r, sml, smr, data, op) do
do_prod_loop(l, r, sml, smr, data, op)
end

defp do_prod_loop(l, r, sml, smr, data, op) when l < r do
sml =
if rem(l, 2) == 1 do
op.(sml, :array.get(l, data))
else
sml
end

l = if rem(l, 2) == 1, do: l + 1, else: l

smr =
if rem(r, 2) == 1 do
op.(:array.get(r - 1, data), smr)
else
smr
end

r = if rem(r, 2) == 1, do: r - 1, else: r

do_prod_loop(div(l, 2), div(r, 2), sml, smr, data, op)
end

defp do_prod_loop(_, _, sml, smr, _, op), do: op.(sml, smr)

@spec all_prod(t()) :: any()
def all_prod(%__MODULE__{data: data}), do: :array.get(1, data)

@spec bit_ceil(non_neg_integer()) :: non_neg_integer()
defp bit_ceil(0), do: 1

defp bit_ceil(n) when n > 0 do
pow = :math.ceil(:math.log2(n))
trunc(:math.pow(2, pow))
end
end
56 changes: 56 additions & 0 deletions test/data_structures/seg_tree_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
defmodule Algorithms.DataStructures.SegTreeTest do
use ExUnit.Case

alias Algorithms.DataStructures.SegTree

describe "Segment Tree basic functionality" do
setup do
op = fn a, b -> min(a, b) end
e = fn -> :infinity end

seg = SegTree.new([5, 3, 7, 9, 6], op, e)
%{seg: seg, op: op, e: e}
end

test "prod returns correct range minimum", %{seg: seg} do
assert SegTree.prod(seg, 0, 5) == 3
assert SegTree.prod(seg, 1, 4) == 3
assert SegTree.prod(seg, 2, 4) == 7
end

test "get returns correct value", %{seg: seg} do
assert SegTree.get(seg, 0) == 5
assert SegTree.get(seg, 1) == 3
assert SegTree.get(seg, 4) == 6
end

test "set updates value and affects prod", %{seg: seg} do
seg = SegTree.set(seg, 1, 10)
assert SegTree.get(seg, 1) == 10
assert SegTree.prod(seg, 0, 5) == 5

seg = SegTree.set(seg, 0, 1)
assert SegTree.prod(seg, 0, 5) == 1
end

test "all_prod returns correct result", %{seg: seg} do
assert SegTree.all_prod(seg) == 3
end
end

describe "Segment Tree with sum operation" do
setup do
op = fn a, b -> a + b end
e = fn -> 0 end

seg = SegTree.new([1, 2, 3, 4, 5], op, e)
%{seg: seg}
end

test "prod returns correct range sum", %{seg: seg} do
assert SegTree.prod(seg, 0, 5) == 15
assert SegTree.prod(seg, 1, 3) == 5
assert SegTree.prod(seg, 2, 4) == 7
end
end
end