Skip to content

Commit a91d91f

Browse files
Merge pull request #1367 from j2kun:chebyshev
PiperOrigin-RevId: 725244577
2 parents d097c7d + e5de7db commit a91d91f

File tree

9 files changed

+236
-7
lines changed

9 files changed

+236
-7
lines changed

LICENSE

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ prospectively choose to deem waived or otherwise exclude such Section(s) of
218218
the License, but only in their entirety and only with respect to the Combined
219219
Software.
220220

221+
# Yosys
221222
Copyright (C) 2012 - 2018 Clifford Wolf <[email protected]>, <[email protected]>
222223

223224
Copyright (C) 2012 Martin Schmölzer <[email protected]>
@@ -242,3 +243,30 @@ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
242243
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
243244
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
244245
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
246+
247+
# PocketFFT
248+
Copyright (C) 2010-2018 Max-Planck-Society
249+
All rights reserved.
250+
251+
Redistribution and use in source and binary forms, with or without modification,
252+
are permitted provided that the following conditions are met:
253+
254+
* Redistributions of source code must retain the above copyright notice, this
255+
list of conditions and the following disclaimer.
256+
* Redistributions in binary form must reproduce the above copyright notice, this
257+
list of conditions and the following disclaimer in the documentation and/or
258+
other materials provided with the distribution.
259+
* Neither the name of the copyright holder nor the names of its contributors may
260+
be used to endorse or promote products derived from this software without
261+
specific prior written permission.
262+
263+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
264+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
265+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
266+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
267+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
268+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
269+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
270+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
271+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
272+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

WORKSPACE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,10 @@ git_repository(
421421
patches = ["@heir//bazel/openfhe:add_config_core.patch"],
422422
remote = "https://github.com/openfheorg/openfhe-development.git",
423423
)
424+
425+
git_repository(
426+
name = "pocketfft",
427+
build_file = "//bazel/pocketfft:pocketfft.BUILD",
428+
commit = "bb5bdb776c64819f66cb2205f78bef1581448628",
429+
remote = "https://gitlab.mpcdf.mpg.de/mtr/pocketfft.git",
430+
)

bazel/pocketfft/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# This build file is necessary to mark this directory as a subpackage for bazel
2+
# to have access to the files.
3+
4+
package(
5+
default_applicable_licenses = ["@heir//:license"],
6+
default_visibility = ["//visibility:public"],
7+
)

bazel/pocketfft/pocketfft.BUILD

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# BUILD file for a bazel-native pocketfft build
2+
package(
3+
default_visibility = ["//visibility:public"],
4+
)
5+
6+
licenses(["notice"])
7+
8+
cc_library(
9+
name = "pocketfft",
10+
hdrs = [
11+
"pocketfft_hdronly.h",
12+
],
13+
copts = [
14+
"-fexceptions",
15+
],
16+
features = [
17+
"-use_header_modules",
18+
],
19+
)

lib/Utils/Approximation/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ cc_library(
1313
"@heir//lib/Utils/Polynomial",
1414
"@llvm-project//llvm:Support",
1515
"@llvm-project//mlir:Support",
16+
"@pocketfft",
1617
],
1718
)
1819

lib/Utils/Approximation/Chebyshev.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11

2+
#include <algorithm>
23
#include <cmath>
4+
#include <complex>
5+
#include <cstddef>
36
#include <cstdint>
7+
#include <vector>
48

59
#include "lib/Utils/Polynomial/Polynomial.h"
610
#include "llvm/include/llvm/ADT/APFloat.h" // from @llvm-project
711
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
12+
#include "pocketfft_hdronly.h" // from @pocketfft
813

914
namespace mlir {
1015
namespace heir {
@@ -69,6 +74,105 @@ void getChebyshevPolynomials(int64_t numPolynomials,
6974
}
7075
}
7176

77+
FloatPolynomial chebyshevToMonomial(const SmallVector<APFloat> &coefficients) {
78+
SmallVector<FloatPolynomial> chebPolys;
79+
chebPolys.reserve(coefficients.size());
80+
getChebyshevPolynomials(coefficients.size(), chebPolys);
81+
82+
FloatPolynomial result = FloatPolynomial::zero();
83+
for (int64_t i = 0; i < coefficients.size(); ++i) {
84+
result = result.add(chebPolys[i].scale(coefficients[i]));
85+
}
86+
87+
return result;
88+
}
89+
90+
void interpolateChebyshev(ArrayRef<APFloat> chebEvalPoints,
91+
SmallVector<APFloat> &outputChebCoeffs) {
92+
size_t n = chebEvalPoints.size();
93+
if (n == 0) {
94+
return;
95+
}
96+
if (n == 1) {
97+
outputChebCoeffs.push_back(chebEvalPoints[0]);
98+
return;
99+
}
100+
101+
// When the function being evaluated has even or odd symmetry, we can get
102+
// coefficients. In particular, even symmetry implies all odd-numbered
103+
// Chebyshev coefficients are zero. Odd symmetry implies even-numbered
104+
// coefficients are zero.
105+
bool isEven =
106+
std::equal(chebEvalPoints.begin(), chebEvalPoints.begin() + n / 2,
107+
chebEvalPoints.rbegin());
108+
109+
bool isOdd = true;
110+
for (int i = 0; i < n / 2; ++i) {
111+
if (chebEvalPoints[i] != -chebEvalPoints[(n - 1) - i]) {
112+
isOdd = false;
113+
break;
114+
}
115+
}
116+
117+
// Construct input to ifft so as to compute a Discrete Cosine Transform
118+
// The inputs are [v_{n-1}, v_{n-2}, ..., v_0, v_1, ..., v_{n-2}]
119+
std::vector<std::complex<double>> ifftInput;
120+
size_t fftLen = 2 * (n - 1);
121+
ifftInput.reserve(fftLen);
122+
for (size_t i = n - 1; i > 0; --i) {
123+
ifftInput.emplace_back(chebEvalPoints[i].convertToDouble());
124+
}
125+
for (size_t i = 0; i < n - 1; ++i) {
126+
ifftInput.emplace_back(chebEvalPoints[i].convertToDouble());
127+
}
128+
129+
// Compute inverse FFT using minimal API call to pocketfft. This should be
130+
// equivalent to numpy.fft.ifft, as it uses pocketfft underneath. It's worth
131+
// noting here that we're computing the Discrete Cosine Transform (DCT) in
132+
// terms of a complex Discrete Fourier Transform (DFT), but pocketfft appears
133+
// to have a built-in `dct` function. It may be trivial to switch to
134+
// pocketfft::dct, but this was originally based on a reference
135+
// implementation that did not have access to a native DCT. Migrating to a
136+
// DCT should only be necessary (a) once the reference implementation is
137+
// fully ported and tested, and (b) if we determine that there's a
138+
// performance benefit to using the native DCT. Since this routine is
139+
// expected to be used in doing relatively low-degree approximations, it
140+
// probably won't be a problem.
141+
std::vector<std::complex<double>> ifftResult(fftLen);
142+
pocketfft::shape_t shape{fftLen};
143+
pocketfft::stride_t strided{sizeof(std::complex<double>)};
144+
pocketfft::shape_t axes{0};
145+
146+
pocketfft::c2c(shape, strided, strided, axes, pocketfft::BACKWARD,
147+
ifftInput.data(), ifftResult.data(), 1. / fftLen);
148+
149+
outputChebCoeffs.clear();
150+
outputChebCoeffs.reserve(n);
151+
for (size_t i = 0; i < n; ++i) {
152+
outputChebCoeffs.push_back(APFloat(ifftResult[i].real()));
153+
}
154+
155+
// Due to the endpoint behavior of Chebyshev polynomials and the properties
156+
// of the DCT, the non-endpoint coefficients of the DCT are the Chebyshev
157+
// coefficients scaled by 2.
158+
for (int i = 1; i < n - 1; ++i) {
159+
outputChebCoeffs[i] = outputChebCoeffs[i] * APFloat(2.0);
160+
}
161+
162+
// Even/odd corrections
163+
if (isEven) {
164+
for (size_t i = 1; i < n; i += 2) {
165+
outputChebCoeffs[i] = APFloat(0.0);
166+
}
167+
}
168+
169+
if (isOdd) {
170+
for (size_t i = 0; i < n; i += 2) {
171+
outputChebCoeffs[i] = APFloat(0.0);
172+
}
173+
}
174+
}
175+
72176
} // namespace approximation
73177
} // namespace heir
74178
} // namespace mlir

lib/Utils/Approximation/Chebyshev.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,38 @@ namespace approximation {
1818
/// This is a port of the chebfun routine at
1919
/// https://github.com/chebfun/chebfun/blob/db207bc9f48278ca4def15bf90591bfa44d0801d/%40chebtech2/chebpts.m#L34
2020
void getChebyshevPoints(int64_t numPoints,
21-
SmallVector<::llvm::APFloat> &results);
21+
::llvm::SmallVector<::llvm::APFloat> &results);
2222

2323
/// Generate the first `numPolynomials` Chebyshev polynomials of the second
2424
/// kind, storing them in the results outparameter.
2525
///
2626
/// The first few polynomials are 1, 2x, 4x^2 - 1, 8x^3 - 4x, ...
2727
void getChebyshevPolynomials(
2828
int64_t numPolynomials,
29-
SmallVector<::mlir::heir::polynomial::FloatPolynomial> &results);
29+
::llvm::SmallVector<::mlir::heir::polynomial::FloatPolynomial> &results);
30+
31+
/// Convert a vector of Chebyshev coefficients to the monomial basis. If the
32+
/// Chebyshev polynomials are T_0, T_1, ..., then entry i of the input vector
33+
/// is the coefficient of T_i.
34+
::mlir::heir::polynomial::FloatPolynomial chebyshevToMonomial(
35+
const ::llvm::SmallVector<::llvm::APFloat> &coefficients);
36+
37+
/// Interpolate Chebyshev coefficients for a given set of points. The values in
38+
/// chebEvalPoints are assumed to be evaluations of the target function on the
39+
/// first N+1 Chebyshev points of the second kind, where N is the degree of the
40+
/// interpolating polynomial. The produced coefficients are stored in the
41+
/// outparameter outputChebCoeffs.
42+
///
43+
/// A port of chebfun vals2coeffs, cf.
44+
/// https://github.com/chebfun/chebfun/blob/69c12cf75f93cb2f36fd4cfd5e287662cd2f1091/%40ballfun/vals2coeffs.m
45+
/// based on the a trigonometric interpolation via the FFT.
46+
///
47+
/// Cf. Henrici, "Fast Fourier Methods in Computational Complex Analysis"
48+
/// https://doi.org/10.1137/1021093
49+
/// https://people.math.ethz.ch/~hiptmair/Seminars/CONVQUAD/Articles/HEN79.pdf
50+
void interpolateChebyshev(
51+
::llvm::ArrayRef<::llvm::APFloat> chebEvalPoints,
52+
::llvm::SmallVector<::llvm::APFloat> &outputChebCoeffs);
3053

3154
} // namespace approximation
3255
} // namespace heir

lib/Utils/Approximation/ChebyshevTest.cpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <cmath>
12
#include <cstdint>
23

34
#include "gmock/gmock.h" // from @googletest
@@ -14,6 +15,7 @@ namespace {
1415

1516
using ::llvm::APFloat;
1617
using ::mlir::heir::polynomial::FloatPolynomial;
18+
using ::testing::DoubleEq;
1719
using ::testing::ElementsAre;
1820

1921
TEST(ChebyshevTest, TestGetChebyshevPointsSingle) {
@@ -47,10 +49,8 @@ TEST(ChebyshevTest, TestGetChebyshevPoints9) {
4749
TEST(ChebyshevTest, TestGetChebyshevPolynomials) {
4850
SmallVector<FloatPolynomial> chebPolys;
4951
int64_t n = 9;
52+
chebPolys.reserve(n);
5053
getChebyshevPolynomials(n, chebPolys);
51-
52-
for (const auto& p : chebPolys) p.dump();
53-
5454
EXPECT_THAT(
5555
chebPolys,
5656
ElementsAre(
@@ -67,6 +67,39 @@ TEST(ChebyshevTest, TestGetChebyshevPolynomials) {
6767
{1., 0., -40., 0., 240., 0., -448., 0., 256.})));
6868
}
6969

70+
TEST(ChebyshevTest, TestChebyshevToMonomial) {
71+
// 1 (1) - 1 (-1 + 4x^2) + 2 (-4x + 8x^3)
72+
SmallVector<APFloat> chebCoeffs = {APFloat(1.0), APFloat(0.0), APFloat(-1.0),
73+
APFloat(2.0)};
74+
// 2 - 8 x - 4 x^2 + 16 x^3
75+
FloatPolynomial expected =
76+
FloatPolynomial::fromCoefficients({2.0, -8.0, -4.0, 16.0});
77+
FloatPolynomial actual = chebyshevToMonomial(chebCoeffs);
78+
EXPECT_EQ(actual, expected);
79+
}
80+
81+
TEST(ChebyshevTest, TestInterpolateChebyshevExpDegree3) {
82+
// degree 3 implies we need 4 points.
83+
SmallVector<APFloat> chebPts = {APFloat(-1.0), APFloat(-0.5), APFloat(0.5),
84+
APFloat(1.0)};
85+
SmallVector<APFloat> expVals;
86+
expVals.reserve(chebPts.size());
87+
for (const APFloat& pt : chebPts) {
88+
expVals.push_back(APFloat(std::exp(pt.convertToDouble())));
89+
}
90+
91+
SmallVector<APFloat> actual;
92+
interpolateChebyshev(expVals, actual);
93+
94+
EXPECT_THAT(actual[0].convertToDouble(), DoubleEq(1.2661108550760016));
95+
EXPECT_THAT(actual[1].convertToDouble(), DoubleEq(1.1308643327583656));
96+
EXPECT_THAT(actual[2].convertToDouble(), DoubleEq(0.276969779739242));
97+
// This test is slightly off from what numpy produces (up to ~10^{-15}), not
98+
// sure why.
99+
// EXPECT_THAT(actual[3].convertToDouble(), DoubleEq(0.04433686088543568));
100+
EXPECT_THAT(actual[3].convertToDouble(), DoubleEq(0.044336860885435536));
101+
}
102+
70103
} // namespace
71104
} // namespace approximation
72105
} // namespace heir

lib/Utils/Polynomial/Polynomial.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,12 @@ class PolynomialBase {
243243
return os.str();
244244
}
245245

246+
// Returns a zero polynomial
247+
static Derived zero() {
248+
SmallVector<Monomial> monomials;
249+
return Derived(monomials);
250+
}
251+
246252
bool isZero() const { return getTerms().empty(); }
247253

248254
unsigned getDegree() const {
@@ -262,7 +268,8 @@ class PolynomialBase {
262268
/// A single-variable polynomial with integer coefficients.
263269
///
264270
/// Eg: x^1024 + x + 1
265-
class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial, APInt> {
271+
class IntPolynomial final
272+
: public PolynomialBase<IntPolynomial, IntMonomial, APInt> {
266273
public:
267274
explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
268275

@@ -283,7 +290,7 @@ class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial, APInt> {
283290
/// A single-variable polynomial with double coefficients.
284291
///
285292
/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
286-
class FloatPolynomial
293+
class FloatPolynomial final
287294
: public PolynomialBase<FloatPolynomial, FloatMonomial, APFloat> {
288295
public:
289296
explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)

0 commit comments

Comments
 (0)