Skip to content

Cooley-Tukey in Scala #746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions contents/cooley_tukey/code/scala/fft.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@

case class Complex(val real: Double, val imag: Double) {
def +(other: Complex): Complex = new Complex(this.real + other.real, this.imag + other.imag)
def -(other: Complex): Complex = new Complex(this.real - other.real, this.imag - other.imag)

def *(other: Complex): Complex = new Complex(
this.real * other.real - this.imag * other.imag,
this.real * other.imag + this.imag * other.real)

def *(other: Double): Complex = new Complex(this.real * other, this.imag * other)
}

object Complex {
def fromPolar(mag: Double, phase: Double): Complex =
new Complex(mag * Math.cos(phase), mag * Math.sin(phase))
}

object FT {

/** Calculates a single DFT coefficient. */
private def coefficient(n: Int, k: Int, ftLength: Int): Complex =
Complex.fromPolar(1.0, -2.0 * Math.PI * k * n / ftLength)

/** Calculates one value of the DFT of signal */
private def dftValue(signal: IndexedSeq[Double], k: Int): Complex = {
// Multiply the signal with the coefficients vector
val terms = for (i <- signal.indices) yield coefficient(i, k, signal.length) * signal(i)

terms reduce { _ + _ }
}

def dft(signal: IndexedSeq[Double]): IndexedSeq[Complex] =
signal.indices map { dftValue(signal, _) }

/** Combines the transforms of the even and odd indices */
private def mergeTransforms(evens: IndexedSeq[Complex], odds: IndexedSeq[Complex]): IndexedSeq[Complex] = {
val oddTerms = for (i <- odds.indices)
yield coefficient(1, i, 2 * odds.length) * odds(i)

val pairs = evens.zip(oddTerms)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't you need to calculate even terms too?

Copy link
Contributor Author

@0xJonas 0xJonas Oct 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because only the odd values are multiplied with the coefficients. The 'evenTerms' would just be the 'evens' parameter.


(pairs map { case (e, o) => e + o }) ++ (pairs map { case (e, o) => e - o })
}

def cooleyTukey(signal: IndexedSeq[Double]): IndexedSeq[Complex] = signal.length match {
case 2 => mergeTransforms(
Vector(new Complex(signal(0), 0)),
Vector(new Complex(signal(1), 0)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't the last parenthesis be on a separate line for this code style?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't find anything in the style guide on that topic, but all the examples put the closing paren on the same line, so I did it as well.

case _ => {
// Split signal into even and odd indices and call cooleyTukey recursively on each.
val evens = cooleyTukey(for (i <- 0 until signal.length by 2) yield signal(i))
val odds = cooleyTukey(for (i <- 1 until signal.length by 2) yield signal(i))

mergeTransforms(evens, odds)
}
}

/** Reverses the bits in value. */
private def reverseBits(value: Int, length: Int): Int = length match {
case 1 => value
case _ => {
// Split bits in the middle.
val lowerHalf = length / 2

// The upper half will be longer if the number of bits is odd.
val upperHalf = length - lowerHalf
val mask = (1 << lowerHalf) - 1

// Reverse each half recursively and then swap them.
(reverseBits(value & mask, lowerHalf) << upperHalf) +
reverseBits(value >> lowerHalf, upperHalf)
}
}


private def log2(x: Double): Double = Math.log(x) / Math.log(2.0)

def bitReverseIndices(signal: IndexedSeq[Double]): IndexedSeq[Double] = {
// Find the maximum number of bits needed.
val bitLength = log2(signal.length).ceil.toInt

for (i <- signal.indices)
yield signal(reverseBits(i, bitLength))
}

private def butterfly(x1: Complex, x2: Complex, coeff: Complex): (Complex, Complex) =
(x1 + coeff * x2, x1 - coeff * x2)

@scala.annotation.tailrec
private def iterativeCooleyTukeyLoop(signal: IndexedSeq[Complex], dist: Int): IndexedSeq[Complex] = {
// Distance between subsequent groups of butterflies
val stride = 2 * dist
val result = new Array[Complex](signal.length)

for (groupStart <- 0 until signal.length by stride; i <- 0 until dist) {
val index = groupStart + i
val (r1, r2) = butterfly(
signal(index),
signal(index + dist),
coefficient(1, i, stride))

result(index) = r1
result(index + dist) = r2
}

if (stride >= signal.length)
result.toVector
else
iterativeCooleyTukeyLoop(result.toVector, dist * 2)
}

def iterativeCooleyTukey(signal: IndexedSeq[Double]): IndexedSeq[Complex] =
iterativeCooleyTukeyLoop(bitReverseIndices(signal) map { new Complex(_, 0.0) }, 1)
}

object Main {

private def approxEqual(a: IndexedSeq[Complex], b: IndexedSeq[Complex]): Boolean = {
val diffs = a.zip(b) map { case (x, y) => Math.abs(x.real - y.real + x.imag - y.imag) }
diffs map { _ < 1e-12 } reduce { _ && _ }
}

def main(args: Array[String]): Unit = {
val signal = for (i <- 0 until 16) yield Math.random() * 2 - 1
val x = FT.dft(signal)
val y = FT.cooleyTukey(signal)
val z = FT.iterativeCooleyTukey(signal)

println("DFT and Cooley-Tukey approx. equal: " + approxEqual(x, y))
println("DFT and iterative Cooley-Tukey approx. equal: " + approxEqual(x, z))
println("Cooley-Tukey and iterative Cooley-Tukey approx. equal: " + approxEqual(y, z))
}
}
6 changes: 6 additions & 0 deletions contents/cooley_tukey/cooley_tukey.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ For some reason, though, putting code to this transformation really helped me fi
[import:3-15, lang:"javascript"](code/javascript/fft.js)
{% sample lang="rs" %}
[import:24-37, lang:"rust"](code/rust/fft.rs)
{% sample lang="scala" %}
[import:20-33, lang:"scala"](code/scala/fft.scala)
{% endmethod %}

In this function, we define `n` to be a set of integers from $$0 \rightarrow N-1$$ and arrange them to be a column.
Expand Down Expand Up @@ -142,6 +144,8 @@ In the end, the code looks like:
[import:17-39, lang="javascript"](code/javascript/fft.js)
{% sample lang="rs" %}
[import:39-55, lang:"rust"](code/rust/fft.rs)
{% sample lang="scala" %}
[import:35-56, lang:"scala"](code/scala/fft.scala)
{% endmethod %}

As a side note, we are enforcing that the array must be a power of 2 for the operation to work.
Expand Down Expand Up @@ -255,6 +259,8 @@ Some rather impressive scratch code was submitted by Jie and can be found here:
[import, lang:"javascript"](code/javascript/fft.js)
{% sample lang="rs" %}
[import, lang:"rust"](code/rust/fft.rs)
{% sample lang="scala" %}
[import, lang:"scala"](code/scala/fft.scala)
{% endmethod %}

<script>
Expand Down