"Low-Level" QR Implementation #1366

Open · wants to merge 18 commits into master
947 changes: 839 additions & 108 deletions src/ops/linalg_ops.ts
@@ -20,15 +20,21 @@
*/

import {ENV} from '../environment';
import {dispose} from '../globals';
import {range, scalar} from './tensor_ops';
import {Tensor, Tensor1D, Tensor2D} from '../tensor';
import {TensorLike, TypedArray} from '../types';
import {add, mul, sub} from './binary_ops';
import {logicalAnd} from './logical_ops';
import {complex, real, imag} from './complex_ops';
import {assert} from '../util';
import {convertToTensor} from '../tensor_util_env';
import {squeeze, stack} from './array_ops';
import {split} from './concat_split';
import {matMul} from './matmul';
import {norm} from './norm';
import {op} from './operation';
import {sum} from './reduction_ops';
import {tensor2d} from './tensor_ops';
import {upcastType} from '../types';

/**
* Gram-Schmidt orthogonalization.
@@ -106,12 +112,770 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D {
}
}

/**
* Conjugates a tensor of matrices and then transposes the last two dimensions.
* The adjoint is also commonly known as the Hermitian Transpose.
*
* ```js
* const a = tf.tensor3d([[[1, 2],
* [3, 4]],
* [[5, 6],
* [7, 8]]]);
* const aT = tf.linalg.adjoint(a);
* aT.print();
* // Output:
* // [[[1, 3],
* // [2, 4]],
* // [[5, 7],
* // [6, 8]]]
* ```
*
* @param a Tensor of shape [...,M,N]. The tensor of matrices that is to be
* transposed.
*
* @returns Tensor of shape [...,N,M]. The adjoint of `a`.
*/
/**
* @doc {heading:'Operations',
* subheading:'Linear Algebra',
* namespace:'linalg'}
*/
function adjoint_<T extends Tensor>( a: T|TensorLike ): T
{
let $a = convertToTensor(a,'a','adjoint');

if( $a.rank < 2 ) {
throw new Error(`adjoint(): a.rank = ${$a.rank} < 2.`);
}

const axes = Array.from( $a.shape, (_,i) => i );
axes[axes.length-2] = axes.length-1;
axes[axes.length-1] = axes.length-2;

if( $a.dtype.startsWith('complex') ) {
$a = complex( real($a), imag($a).neg() ); // <- TODO: implement tf.conj
}

return $a.transpose(axes);
}

/**
* Copies a tensor of matrices, setting everything outside a central band
* in each matrix to zero. Does not yet support Infinity or NaN entries.
*
* ```js
* const a = tf.tensor2d([[11, 12, 13, 14],
* [21, 22, 23, 24],
* [31, 32, 33, 34],
* [41, 42, 43, 44]]);
* tf.linalg.bandPart(a,0,2);
* // Output:
* // [[11, 12, 13, 0],
* // [ 0, 22, 23, 24],
* // [ 0, 0, 33, 34],
* // [ 0, 0, 0, 44]]
*
* tf.linalg.bandPart(a,1,-1);
* // Output:
* // [[11, 12, 13, 14],
* // [21, 22, 23, 24],
* // [ 0, 32, 33, 34],
* // [ 0, 0, 43, 44]]
* ```
*
* @param a Tensor of matrices from which the band part is extracted.
* @param numLower The number of subdiagonals to be copied.
* If set to `-1`, all entries below the diagonal are
* copied.
* @param numUpper The number of superdiagonals to be copied.
* If set to `-1`, all entries above the diagonal are
* copied.
*/
/**
* @doc {heading:'Operations',
* subheading:'Linear Algebra',
* namespace:'linalg'}
*/
function bandPart_<T extends Tensor>(
a: T|TensorLike, numLower: number, numUpper: number
): T
{
if( numLower%1 !== 0 ){
throw new Error(`bandPart(): numLower=${numLower} not an integer.`);
}
if( numUpper%1 !== 0 ){
throw new Error(`bandPart(): numUpper=${numUpper} not an integer.`);
}

return ENV.engine.tidy( () => {
const $a = convertToTensor(a,'a','bandPart');

if( $a.rank < 2 ) {
throw new Error(`bandPart(): a.rank = ${$a.rank} < 2.`);
}

if( ! isFinite($a.abs().max().dataSync()[0]) ) {
throw new Error(`bandPart(): NaN and Infinity not yet supported.`);
}

const [M,N] = $a.shape.slice(-2);

if( !(numLower <= M) ) {
throw new Error(`bandPart() check failed: numLower <= #rows.` );
}
if( !(numUpper <= N) ) {
throw new Error(`bandPart() check failed: numUpper <= #columns.`);
}

if( numLower < 0 ) { numLower = M; }
if( numUpper < 0 ) { numUpper = N; }

const i = range(0,M, 1, 'int32').reshape([-1,1]),
j = range(0,N, 1, 'int32');

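// An entry (i,j) is kept iff  i-j <= numLower  and  j-i <= numUpper.
// `inBand` is the corresponding 0/1 mask, broadcast against the
// trailing [M,N] axes of `$a`.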
const inBand = logicalAnd(
sub(i,j).lessEqual( scalar(numLower,'int32') ),
sub(j,i).lessEqual( scalar(numUpper,'int32') )
).cast($a.dtype);

return mul($a,inBand);
});
}

function triangularSolveKernel(
l: Tensor, y: Tensor, lower: boolean, adjoint: boolean
): Tensor
{
if( ! l.dtype.startsWith('float') ) {
throw new Error(`triangularSolve(): l.dtype=${l.dtype} not supported.`);
}
if( ! y.dtype.startsWith('float') ) {
throw new Error(`triangularSolve(): y.dtype=${y.dtype} not supported.`);
}
if( l.rank < 2 ) {
throw new Error('triangularSolve(): l must be at least 2D.');
}
if( y.rank < 2 ) {
throw new Error('triangularSolve(): y must be at least 2D.');
}
if( l.rank !== y.rank ) {
throw new Error('triangularSolve(): l and y must have same rank.');
}
for( let i=l.rank-2; i-- > 0; ) {
if( l.shape[i] !== y.shape[i] ) {
throw new Error('triangularSolve(): leading dimensions do not match.');
}
}

const [N,M] = l.shape.slice(-2),
[I,J] = y.shape.slice(-2);
if( N !== M ) {
throw new Error('triangularSolve(): Last two axes of L not square.');
}
if( I !== M ) {
throw new Error('triangularSolve(): L and y do not match.');
}

const
rank = l.rank,
xShape = Array.from(l.shape);
xShape[rank-2] = I;
xShape[rank-1] = J;

// GENERATE RESULT DATA
const
dtype = 'float32',
// dtype = ( l.dtype === 'float64' ||
// y.dtype === 'float64' ) ? 'float64' : 'float32',
// tslint:disable
DTypeArray = Float32Array,
// tslint:enable
// DTypeArray = dtype === 'float32' ? Float32Array
// : Float64Array,
L = l.dataSync(),
X = DTypeArray.from( y.dataSync() ) as TypedArray;
l = undefined;
y = undefined;

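// Layout: matrices are stored row-major and batched contiguously, i.e.
// entry (i,j) of batch b lives at L[b*N*N + N*i + j] and
// X[b*I*J + J*i + j]. X is overwritten in place with the solution.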
for( let lOff = 0,
xOff = 0; xOff < X.length; xOff += N*J,
lOff += N*N )
{
if( ! adjoint )
{
if(lower)
{ // FORWARD SUBSTITUTION
for( let i=0; i < I; i++ ) {
for( let k=0; k < i; k++ ) {
for( let j=0; j < J; j++ ) {
X[xOff + J*i+j] -= L[lOff + N*i+k] * X[xOff + J*k+j];
}}

for( let j=0; j < J; j++ ) {
X[xOff + J*i+j] /= L[lOff + N*i+i];
}
}
}
else
{ // BACKWARD SUBSTITUTION
for( let i=I; i-- > 0; ) {
for( let j=J; j-- > 0; ) {
X[xOff + J*i+j] /= L[lOff + N*i+i];
}

for( let k=i; k-- > 0; ) {
for( let j=J; j-- > 0; ) {
X[xOff + J*k+j] -= L[lOff + N*k+i] * X[xOff + J*i+j];
}}
}
}
}
else
{
if(lower)
{ // BACKWARD SUBSTITUTION (TRANSPOSED)
for( let i=I; i-- > 0; ) {
for( let j=J; j-- > 0; ) {
X[xOff + J*i+j] /= L[lOff + N*i+i];
}

for( let k=i; k-- > 0; ) {
for( let j=J; j-- > 0; ) {
X[xOff + J*k+j] -= L[lOff + N*i+k] * X[xOff + J*i+j];
}}
}
}
else
{ // FORWARD SUBSTITUTION (TRANSPOSED)
for( let i=0; i < I; i++ ) {
for( let k=0; k < i; k++ ) {
for( let j=0; j < J; j++ ) {
X[xOff + J*i+j] -= L[lOff + N*k+i] * X[xOff + J*k+j];
}}

for( let j=0; j < J; j++ ) {
X[xOff + J*i+j] /= L[lOff + N*i+i];
}
}
}
}
}

return Tensor.make(xShape,{values: X},dtype);
}

/**
* Solves a triangular linear equation system (LES).
*
* @param l The triangular matrix of the LES.
* @param y The right-hand-side of the LES.
* @param lower If set to `true`, `l` is interpreted as a lower triangular
* matrix and the strict upper triangular entries are ignored.
* If set to `false`, `l` is interpreted as an upper triangular
* matrix and the strict lower triangular entries are ignored.
* @param adjoint If set to `true`, the Hermitian transpose of `l` is used in
* the LES.
*
* @returns The solution of one of the following LES:
* <dl>
* <dt>lower=true, adjoint=false <dd>tril(l) ∙x == y
* <dt>lower=false, adjoint=false <dd>triu(l) ∙x == y
* <dt>lower=true, adjoint=true <dd>tril(l)ᴴ∙x == y
* <dt>lower=false, adjoint=true <dd>triu(l)ᴴ∙x == y
* </dl>
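*
* A small usage sketch, solving a lower triangular 2x2 system:
*
* ```js
* const l = tf.tensor2d([[2, 0],
* [3, 4]]);
* const y = tf.tensor2d([[ 2],
* [11]]);
* const x = tf.linalg.triangularSolve(l, y); // lower=true by default
* x.print(); // ~[[1], [2]]
* ```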
*/
/**
* @doc {heading:'Operations',
* subheading:'Linear Algebra',
* namespace:'linalg'}
*/
function triangularSolve_(
l: Tensor|TensorLike, y: Tensor|TensorLike, lower=true, adjoint=false
): Tensor
{
// FIXME: if `l` is singular the right hand side could be
// checked for 0 and then some/any solution could be used

// let [$l,$y] = broadcastMatrices(
// convertToTensor(l,'l','triangularSolve'),
// convertToTensor(y,'y','triangularSolve')
// );
let $l = convertToTensor(l,'l','triangularSolve'),
$y = convertToTensor(y,'y','triangularSolve');
l=undefined;
y=undefined;
if( $l.rank < 2 ){
throw new Error(`triangularSolve(): l.rank must be at least 2.`);
}
if( $y.rank < 2 ){
throw new Error(`triangularSolve(): y.rank must be at least 2.`);
}

const dtype = upcastType($l.dtype, $y.dtype);
if( $l.dtype !== dtype ) { $l = $l.cast(dtype); }
if( $y.dtype !== dtype ) { $y = $y.cast(dtype); }

// WHERE THE BACKPROP COMES FROM:
// x = L⁻¹∙y
// => dx = d(L⁻¹)∙y + L⁻¹∙dy = L⁻¹∙dy - L⁻¹∙dL∙L⁻¹∙y = L⁻¹∙dy - L⁻¹∙dL∙x
// => df = tr( (∂f/∂x)∙dxᵀ )
// = tr( (∂f/∂x)∙dyᵀ∙L⁻ᵀ ) - tr( (∂f/∂x)∙yᵀ∙L⁻ᵀ∙dLᵀ∙L⁻ᵀ )
// = tr( (∂f/∂x)ᵀ∙L⁻¹∙dy ) - tr( (∂f/∂x)∙yᵀ∙L⁻ᵀ∙(L⁻¹∙dL)ᵀ )
// = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( L⁻¹∙y∙(∂f/∂x)ᵀ∙ L⁻¹∙dL )
// = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( x∙(∂f/∂x)ᵀ∙ L⁻¹∙dL )
// = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( L⁻ᵀ ∙(∂f/∂x) ∙ xᵀ ∙dLᵀ )
// => ∂f/∂y = L⁻ᵀ∙(∂f/∂x)
// ∂f/∂L = -L⁻ᵀ∙(∂f/∂x)∙xᵀ = -(∂f/∂y)∙xᵀ
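//
// In the code below, ∂f/∂y is computed via a second triangular solve
// against the adjoint system, and ∂f/∂L as -(∂f/∂y)∙xᵀ, masked back to
// the triangular half that entered the forward solve.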

// tslint:disable
// SEE: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L218
// tslint:enable
return ENV.engine.runKernel(
(backend,saveFn) => {
const x = triangularSolveKernel($l,$y,lower,adjoint);
saveFn(x);
return x;
},
{$l,$y},
(dx,[x]) => {
const dy = triangularSolve($l, dx, lower, !adjoint);
return {
$l: () => {
let dl = adjoint ? matMul( x, dy, false, true)
: matMul(dy, x, false, true);
dl = dl.neg();
dl = lower ? bandPart(dl,-1, 0)
: bandPart(dl, 0,-1);
return dl;
},
$y: () => dy
};
}
);
}

/** Computes the economic QR Decomposition.
*/
function qrEcoDecompKernel( a: Tensor ): [Tensor,Tensor]
{
if( a.rank < 2 ) {
throw new Error(`qrEco(): input must have rank >= 2, got rank ${a.rank}.`);
}
if( a.dtype !== 'float32' ) {
throw new Error(`qrEco(): only float32 currently supported as dtype.`);
}
if( a.shape[a.rank-2] < a.shape[a.rank-1] ) {
throw new Error(`qrEco(): a must have at least as many rows as columns.`);
}

const dtype = 'float32',
// tslint:disable
DTypeArray = Float32Array,
// tslint:enable
qShape = Array.from( a.shape ),
rShape = Array.from( qShape ),
[M,N] = qShape.slice(-2);
rShape[rShape.length-2] = N;
Object.freeze(qShape);
Object.freeze(rShape);

const Q = DTypeArray.from( a.dataSync() ); a = undefined;
const R = new DTypeArray( Q.length/M*N ),
cs = new DTypeArray( 2*M*N - N*(N+1) );// <- MEMOIZE ROTATIONS

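// For each entry below the diagonal, a Givens rotation
//
//   G = [[ c, s],      c = A[j][j] / norm,
//        [-s, c]]      s = A[i][j] / norm,  norm = hypot(A[j][j], A[i][j])
//
// is applied to rows j and i from the left, zeroing out entry (i,j).
// The (c,s) pairs are memoized in `cs` and afterwards applied in reverse
// order to an identity matrix to accumulate Q.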
for(
let rOff=0,
qOff=0; qOff < Q.length; qOff += M*N,
rOff += N*N
)
{
let csi = 0;

for( let i=1; i < M; i++ ) { const J = Math.min(i,N);
for( let j=0; j < J; j++ )
{ // DETERMINE GIVENS ROTATION cos AND sin
const rIJ = Q[qOff + N*i+j];
if( 0.0 === rIJ ) { cs[csi++]=1.0; cs[csi++]=0.0; continue; }
const rJJ = Q[qOff + N*j+j],
norm = Math.hypot(rJJ,rIJ),
c = rJJ / norm,
s = rIJ / norm;
cs[csi++] = c;
cs[csi++] = s;
Q[qOff + N*j+j] = norm;
Q[qOff + N*i+j] = 0;
// ROTATE ROWS IN R (WHICH IS CURRENTLY STORED IN Q)
for( let k=j; ++k < N; )
{ const rJK = Q[qOff + N*j+k],
rIK = Q[qOff + N*i+k];
Q[qOff + N*j+k] = s*rIK + c*rJK;
Q[qOff + N*i+k] = c*rIK - s*rJK;
}
}}

assert( csi === cs.length, `qrEco(): rotation count mismatch: ${csi} !== ${cs.length}.` );

// COPY R FROM Q -> R
for( let i=0; i < N; i++ ) {
for( let j=i; j < N; j++ ) {
R[rOff + N*i+j] = Q[qOff + N*i+j];
Q[qOff + N*i+j] = i !== j ? 0.0 : 1.0;
}}

// COMPUTE Q
for( let i=M; --i > 0; ) { const J = Math.min(i,N);
for( let j=J; j-- > 0; )
{ const s = cs[--csi],
c = cs[--csi];
// ROTATE ROWS IN Q
for( let k=N; k-- > 0; )
{ const qJK = Q[qOff + N*j+k],
qIK = Q[qOff + N*i+k];
Q[qOff + N*j+k] = c*qJK - s*qIK;
Q[qOff + N*i+k] = s*qJK + c*qIK;
}
}}

assert( csi === 0, `qrEco(): not all rotations were consumed: ${csi} !== 0.` );
}

const q = Tensor.make(qShape, { values: Q }, dtype);
const r = Tensor.make(rShape, { values: R }, dtype);

return [q,r];
}

/** Computes the full QR Decomposition and memoizes the
* Givens rotation angles in the process.
*/
function qrFullDecompKernel( a: Tensor ): [Tensor,Tensor,Tensor]
{
if( a.rank < 2 ) {
throw new Error(`qrFullDecomp(): input must have rank >= 2, got rank ${a.rank}.`);
}
if( a.dtype !== 'float32' ) {
throw new Error(`qrFullDecomp(): only float32 currently supported as dtype.`);
}

const dtype = 'float32',
// tslint:disable
DTypeArray = Float32Array,
// tslint:enable
rShape = Array.from( a.shape ),
qShape = Array.from( a.shape ),
[M,N] = a.shape.slice(-2),
R = DTypeArray.from( a.dataSync() );
a = undefined;
const L = Math.min(M,N),
Q = new DTypeArray( R.length/N*M ),
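// Number of Givens rotations per matrix = number of entries to be
// zeroed below the diagonal = L*(L-1)/2 + max(0,M-N)*N. Two values
// (cos,sin) are memoized per rotation, for each of the R.length/(M*N)
// matrices in the batch: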
CS = new DTypeArray( R.length/N/M * 2 * (
(L*(L-1) >>> 1) + Math.max(0,M-N)*N
));
qShape[qShape.length-1] = M;
Object.freeze(qShape);
Object.freeze(rShape);

let l = 0;
for( let qOff=0,
rOff=0; qOff < Q.length; qOff += M*M,
rOff += M*N )
{
// INIT Q TO IDENTITY
for( let i=0; i < M; i++ ) { Q[qOff + M*i+i] = 1; }

// BEGIN QR DECOMPOSITION
for( let i=1; i < M; i++ ) { const J = Math.min(i,N);
for( let j=0; j < J; j++ )
{
// DETERMINE GIVENS ROTATION cos AND sin
const rIJ = R[rOff + N*i+j];
if( 0.0 === rIJ ) { CS[l++]=1.0; CS[l++]=0.0; continue; }
const rJJ = R[rOff + N*j+j],
norm = Math.hypot(rJJ,rIJ),
c = rJJ / norm,
s = rIJ / norm;
CS[l++] = c;
CS[l++] = s;
R[rOff + N*j+j] = norm;
R[rOff + N*i+j] = 0;
// ROTATE ROWS IN R
for( let k=j; ++k < N; )
{ const rJK = R[rOff + N*j+k],
rIK = R[rOff + N*i+k];
R[rOff + N*j+k] = s*rIK + c*rJK;
R[rOff + N*i+k] = c*rIK - s*rJK;
}
// ROTATE ROWS IN Qᵀ
for( let k=0; k <= i; k++ )
{ const qJK = Q[qOff + M*j+k],
qIK = Q[qOff + M*i+k];
Q[qOff + M*j+k] = s*qIK + c*qJK;
Q[qOff + M*i+k] = c*qIK - s*qJK;
}
}} // END QR DECOMPOSITION

// TRANSPOSE Q (was transposed for cache locality)
for( let i=0; i < M; i++ ) {
for( let j=0; j < i; j++ ) {
const qIJ = Q[qOff + M*i+j];
Q[qOff + M*i+j] = Q[qOff + M*j+i];
Q[qOff + M*j+i] = qIJ;
}}
}
assert( l === CS.length, `qrFullDecomp(): rotation count mismatch: ${l} !== ${CS.length}.` );

const q = Tensor.make(qShape, {values: Q}, dtype);
const r = Tensor.make(rShape, {values: R}, dtype);
const cs = Tensor.make([CS.length], {values: CS}, dtype);

return [q,r,cs];
}

/** Computes the backpropagation of the full QR Decomposition,
* reusing the memoized Givens rotation angles.
*/
function qrFullBackpropKernel(
q: Tensor, dq: Tensor, r: Tensor, dr: Tensor, cs: Tensor
): Tensor
{
if( q.rank !== dq.rank ) {
throw new Error(
`qrFullBackprop(): q.rank == ${q.rank} != ${dq.rank} == dq.rank`
);
}
if( q.rank !== dr.rank ) {
throw new Error(
`qrFullBackprop(): q.rank == ${q.rank} != ${dr.rank} == dr.rank`
);
}
if( q.rank !== r.rank ) {
throw new Error(
`qrFullBackprop(): q.rank == ${q.rank} != ${ r.rank} == r.rank`
);
}

if( cs.rank !== 1 ) {
throw new Error(`qrFullBackprop(): cs.rank == ${cs.rank} != 1`);
}

const rank = q.rank;

if( rank < 2 ) {
throw new Error(
`qrFullBackprop(): input must have rank >= 2, got rank ${rank}.`
);
}

for( let i=rank-2; i-- > 0; )
{
if( q.shape[i] !== dq.shape[i] ) {
throw new Error(
'qrFullBackprop(): '
+ `q.shape[${i}] == ${q.shape[i]} != ${dq.shape[i]} == dq.shape[${i}]`
);
}
if( q.shape[i] !== dr.shape[i] ) {
throw new Error(
'qrFullBackprop(): '
+ `q.shape[${i}] == ${q.shape[i]} != ${dr.shape[i]} == dr.shape[${i}]`
);
}
if( q.shape[i] !== r.shape[i] ) {
throw new Error(
'qrFullBackprop(): '
+ `q.shape[${i}] == ${q.shape[i]} != ${ r.shape[i]} == r.shape[${i}]`
);
}
}

if( q.shape[rank-2] !== q.shape[rank-1] ) {
throw new Error(
'qrFullBackprop(): ' +
`q.shape[-2] == ${q.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]`
);
}
if( q.shape[rank-2] !== dq.shape[rank-1] ) {
throw new Error(
'qrFullBackprop(): ' +
`q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-1]} == dq.shape[-1]`
);
}
if( q.shape[rank-2] !== dq.shape[rank-2] ) {
throw new Error(
'qrFullBackprop(): ' +
`q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-2]} == dq.shape[-2]`
);
}
if( r.shape[rank-2] !== q.shape[rank-1] ) {
throw new Error(
'qrFullBackprop(): ' +
`r.shape[-2] == ${r.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]`
);
}
if( r.shape[rank-1] !== dr.shape[rank-1] ) {
throw new Error(
'qrFullBackprop(): ' +
`r.shape[-1] == ${r.shape[rank-1]} != ${dr.shape[rank-1]} == dr.shape[-1]`
);
}
if( r.shape[rank-2] !== dr.shape[rank-2] ) {
throw new Error(
'qrFullBackprop(): ' +
`r.shape[-2] == ${r.shape[rank-2]} != ${dr.shape[rank-2]} == dr.shape[-2]`
);
}

if( q.dtype !== dq.dtype ) {
throw new Error(
`qrFullBackprop(): q.dtype == ${q.dtype} != ${dq.dtype} == dq.dtype`
);
}
if( q.dtype !== dr.dtype ) {
throw new Error(
`qrFullBackprop(): q.dtype == ${q.dtype} != ${dr.dtype} == dr.dtype`
);
}
if( q.dtype !== r.dtype ) {
throw new Error(
`qrFullBackprop(): q.dtype == ${q.dtype} != ${r.dtype} == r.dtype`
);
}
if( q.dtype !== cs.dtype ) {
throw new Error(
`qrFullBackprop(): q.dtype == ${q.dtype} != ${cs.dtype} == cs.dtype`
);
}

if( q.dtype !== 'float32' ) {
throw new Error(
`qrFullBackprop(): only float32 currently supported as dtype.`
);
}

const dtype = 'float32',
// tslint:disable
DTypeArray = Float32Array,
// tslint:enable
dAShape = Array.from( r.shape ),
[M,N] = dAShape.slice(-2);
const Q = DTypeArray.from( q.dataSync() ); q = undefined;
const dQ = DTypeArray.from( dq.dataSync() ); dq = undefined;
const R = DTypeArray.from( r.dataSync() ); r = undefined;
const dR = DTypeArray.from( dr.dataSync() ); dr = undefined;
const CS = cs.dataSync();
Object.freeze(dAShape);

let l = CS.length;
for( let rOff=R.length,
qOff=Q.length; qOff > 0; )
{
qOff -= M*M;
rOff -= M*N;

// TRANSPOSE Q (for cache locality)
for( let i=0; i < M; i++ ) {
for( let j=0; j < i; j++ ) {
const qIJ = Q[qOff + M*i+j];
Q[qOff + M*i+j] = Q[qOff + M*j+i];
Q[qOff + M*j+i] = qIJ;
}}

// TRANSPOSE dQ (for cache locality)
for( let i=0; i < M; i++ ) {
for( let j=0; j < i; j++ ) {
const dQij = dQ[qOff + M*i+j];
dQ[qOff + M*i+j] = dQ[qOff + M*j+i];
dQ[qOff + M*j+i] = dQij;
}}

// BEGIN QR DECOMPOSITION
for( let i=M; --i > 0; ) { const J = Math.min(i,N);
for( let j=J; j-- > 0; )
{
// DETERMINE GIVENS ROTATION cos AND sin
const s = CS[--l]; if( 0 === s ) { continue; }
const c = CS[--l],
norm = R[rOff + N*j+j];

// ROTATE ROWS IN R
for( let k=j; k < N; k++ )
{ const rJK = R[rOff + N*j+k],
rIK = R[rOff + N*i+k];
R[rOff + N*j+k] = c*rJK - s*rIK;
R[rOff + N*i+k] = s*rJK + c*rIK;
}

// ROTATE ROWS IN Qᵀ
for( let k=0; k <= i; k++ )
{ const qJK = Q[qOff + M*j+k],
qIK = Q[qOff + M*i+k];
Q[qOff + M*j+k] = c*qJK - s*qIK;
Q[qOff + M*i+k] = s*qJK + c*qIK;
}

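// Partial derivatives of c = rJJ/norm and s = rIJ/norm with respect
// to the original entries R[j][j] and R[i][j]. Note that rIJ and rJJ
// below are already divided by norm, so e.g.
// dc/dR[j][j] = R[i][j]²/norm³ = rIJ*rIJ/norm.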
const rIJ = R[rOff + N*i+j] / norm,
rJJ = R[rOff + N*j+j] / norm,
dCdJ = +rIJ*rIJ / norm,
dCdI = -rIJ*rJJ / norm,
dSdJ = -rJJ*rIJ / norm,
dSdI = +rJJ*rJJ / norm;
let dj = 0.0,
di = 0.0;

// ROTATE ROWS IN dR
for( let k=j; k < N; k++ )
{ const dRjk = dR[rOff + N*j+k],
dRik = dR[rOff + N*i+k];
dR[rOff + N*j+k] = c*dRjk - s*dRik;
dR[rOff + N*i+k] = s*dRjk + c*dRik;

const rJK = R[rOff + N*j+k],
rIK = R[rOff + N*i+k];

dj += dRjk*(rIK*dSdJ + rJK*dCdJ) + dRik*(rIK*dCdJ - rJK*dSdJ);
di += dRjk*(rIK*dSdI + rJK*dCdI) + dRik*(rIK*dCdI - rJK*dSdI);
}

// ROTATE ROWS IN dQᵀ
for( let k=0; k <= i; k++ )
{ const dQjk = dQ[qOff + M*j+k],
dQik = dQ[qOff + M*i+k];
dQ[qOff + M*j+k] = c*dQjk - s*dQik;
dQ[qOff + M*i+k] = s*dQjk + c*dQik;

const qJK = Q[qOff + M*j+k],
qIK = Q[qOff + M*i+k];

dj += dQjk*(qIK*dSdJ + qJK*dCdJ) + dQik*(qIK*dCdJ - qJK*dSdJ);
di += dQjk*(qIK*dSdI + qJK*dCdI) + dQik*(qIK*dCdI - qJK*dSdI);
}

dR[rOff + N*j+j] += dj;
dR[rOff + N*i+j] += di;
}} // END QR DECOMPOSITION
}
assert( 0 === l, `qrFullBackprop(): not all rotations were consumed: ${l} !== 0.` );

return Tensor.make(dAShape,{values: dR},dtype);
}

/**
* Computes the QR decomposition of an m-by-n matrix using Givens rotations.
*
* See: http://www.math.usm.edu/lambers/mat610/sum10/lecture9.pdf
*
* ```js
* const a = tf.tensor2d([[1, 2], [3, 4]]);
* let [q, r] = tf.linalg.qr(a);
* console.log('Q');
* q.print();
* console.log('R');
* r.print();
* console.log('Orthogonalized');
* q.dot(q.transpose()).print(); // should be nearly the identity matrix.
* console.log('Reconstructed');
* q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
* ```
*
* ```js
* const a = tf.tensor2d([[1, 2], [3, 4]]);
@@ -150,115 +914,82 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D {
* subheading:'Linear Algebra',
* namespace:'linalg'}
*/
function qr_( a: Tensor, fullMatrices = false ): [Tensor, Tensor] {
if( a.rank < 2 ) {
throw new Error(
`qr() requires input tensor to have a rank >= 2, but got rank ${a.rank}`
);
}
if( a.dtype.startsWith('complex') ) {
throw new Error(`qr() not yet supported for complex tensors.`);
}

const [m,n] = a.shape.slice(-2);

if( m === n || m > n && !fullMatrices )
{
// FIXME: What if R is (nearly) singular?
return ENV.engine.runKernel(
(backend,saveFunc) => {
const [q,r] = qrEcoDecompKernel(a);
saveFunc(q);
saveFunc(r);
return [q,r];
},
{a},
([dq,dr], [q,r]) => ({
a: () => {
// TODO: is tidy required here?
// tslint:disable
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L160
// tslint:enable
const qdq = matMul(q,dq, true, false),
rdr = matMul(r,dr, false, true),
qdq_ = qdq.sub( adjoint(qdq) ),
rdr_ = rdr.sub( adjoint(rdr) ),
tril = bandPart( add(qdq_,rdr_), -1, 0 );

const triSolv = (x: Tensor,r: Tensor) => adjoint(
triangularSolve(r, adjoint(x), /*lower=*/false, /*adjoint_r*/false)
);

const gradA = matMul( q, dr.add( triSolv(tril,r) ) ),
gradB = triSolv( dq.sub( matMul(q,qdq) ), r );

return add(gradA,gradB);
}
})
) as [Tensor, Tensor];
}

let [q,r] = ENV.engine.runKernel(
(backend,saveFunc) => {
const [q,r,cs] = qrFullDecompKernel(a);
saveFunc(q);
saveFunc(r);
saveFunc(cs);
return [q,r];
},
{a},
([dq,dr], [q,r,cs]) => ({
a: () => ENV.engine.runKernel(
(backend,saveFunc) => qrFullBackpropKernel(q,dq, r,dr, cs),
{ $dq: dq, $dr: dr }
)
})
);

if( ! fullMatrices && m > n ) {
const end = a.shape.slice();
q = q.slice([0, 0], end); end[end.length-2] = n;
r = r.slice([0, 0], end);
}

return [q,r];
}

export const adjoint = op({adjoint_});
export const bandPart = op({bandPart_});
export const gramSchmidt = op({gramSchmidt_});
export const qr = op({qr_});
export const triangularSolve = op({triangularSolve_});
578 changes: 459 additions & 119 deletions src/ops/linalg_ops_test.ts
@@ -16,12 +16,114 @@
*/

import * as tf from '../index';
import {ENV} from '../environment';
import {describeWithFlags} from '../jasmine_util';
import {Scalar, Tensor, Tensor1D, Tensor2D} from '../tensor';
import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util';

import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops';

/** Returns a random integer in the range of [from,until).
*/
const randInt = (from: number, until: number) => {
return Math.floor(Math.random()*(until-from)) + from;
};

/**
* Computes the gradients using finite differences. The current
* implementation uses an O(h⁴) central difference.
*
* SEE: https://en.wikipedia.org/wiki/Finite_difference
*
* FIXME this is terribly imprecise... wish there was
* double precision support *hint hint*.
*/
const numDiff = (f: (x: Tensor) => Scalar) => (a: Tensor) => {
if( a.dtype !== 'float32' ) {
throw new Error(`numDiff(): dtype=${a.dtype} not supported.`);
}

const aData = Float32Array.from( a.dataSync() );

const eps = Math.sqrt( ENV.get('EPSILON') );

return ENV.engine.tidy(() => {

const dA = new Float32Array( aData.length );

for( let i=0; i < aData.length; i++ )
{ // use central difference
const x = aData[i],
h = Math.max( Math.abs(x)*eps, eps );

const g = ( x: number ) => ENV.engine.tidy( () => {
aData[i] = x;

const b = Tensor.make(a.shape, {values: aData});
const scalar = f(b);

if( scalar.rank !== 0 ) {
throw new Error('f() returned a non-scalar value.');
}

return scalar.dataSync()[0];
});

// https://www.geometrictools.com/Documentation/FiniteDifferences.pdf
dA[i] = (-g(x+2*h) + 8*g(x+h) - 8*g(x-h) + g(x-2*h) ) / (12*h);
aData[i] = x; // <- undo modifications
}

return Tensor.make(a.shape,{values: dA});
});
};
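// A hypothetical sanity check for `numDiff` (illustration only, not part
// of the test suite): the numerical gradient of x ↦ mean(x²) should
// match the analytic gradient computed by tf.grad:
//
// const f = (x: Tensor) => x.square().mean() as Scalar;
// const a = tf.tensor2d([[1, 2], [3, 4]]);
// expectArraysClose( numDiff(f)(a), tf.grad(f)(a) );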

/**
* A tensor equivalence assertion that uses a comparison operator
* very similar to NumPy's `isclose()` function.
*/
function expectTensorsRelativelyClose(
actual: Tensor, expected: Tensor, rtol?: number, atol?: number
): void
{
if( expected.shape.some( (s,i) => s !== actual.shape[i] ) ) {
throw new Error(
`Shapes [${actual.shape}] and [${expected.shape}] do not match.`
);
}

if( null == atol ) { atol = ENV.get('TEST_EPSILON'); }
if( null == rtol ) { rtol = ENV.get('TEST_EPSILON'); }

const act = actual.dataSync(),
exp = expected.dataSync();

const isClose = (x: number, y: number) => {
const mag = Math.abs(x) + Math.abs(y);
return Math.abs(x-y) <= atol + rtol/2*mag;
};

for( let i=act.length; i-- > 0; ) {
if( ! isClose(act[i],exp[i]) )
{
console.log( 'actual:'); actual.print();
console.log('expected:'); expected.print();
const idx = [],
shape = actual.shape;
for( let j=i, d=shape.length; d-- > 0; )
{
const size = shape[d];
idx.unshift(j % size);
j = Math.trunc(j / size);
}
throw new Error(
`actual[${idx}] = ${act[i]} != ${exp[i]} = expected[${idx}]`
);
}
}
}

describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => {
it('2x2, Array of Tensor1D', () => {
const xs: Tensor1D[] = [
@@ -94,137 +196,375 @@ describeWithFlags('gramSchmidt-non-tiny', WEBGL_ENVS, () => {
});
});

describeWithFlags('adjoint', ALL_ENVS, () => {
it('2x3', () => {
const a = tf.tensor2d([[1,2,3],
[4,5,6]], [2,3]),
aT = tf.tensor2d([[1,4],
[2,5],
[3,6]],[3,2]);
// FIXME: shouldn't tf.transpose be lossless?
// Yet this fails on Travis with `expectArraysEqual`...
expectArraysClose( tf.linalg.adjoint(a), aT );
});

it('3x2x1', () => {
const a = tf.tensor3d([[[1],[2]],
[[3],[4]],
[[5],[6]]], [3,2,1]),
aT = tf.tensor3d([[[1,2]],
[[3,4]],
[[5,6]]], [3,1,2]);
expectArraysClose( tf.linalg.adjoint(a), aT );
});
});

describeWithFlags('bandPart', ALL_ENVS, () => {
const la = tf.linalg;

// FIXME: shouldn't 1*x be lossless?
// It's even in the IEEE spec somewhere...
// Yet this fails on Travis with `expectArraysEqual`...
const expectArraysEqual = expectArraysClose;

it('3x4', () => {
const a = tf.tensor2d([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9,10,11,12]
]);
expectArraysEqual(
la.bandPart(a,0,0),
tf.tensor2d([[1, 0, 0, 0],
[0, 6, 0, 0],
[0, 0,11, 0]])
);
expectArraysEqual(
la.bandPart(a,0,1),
tf.tensor2d([[1, 2, 0, 0],
[0, 6, 7, 0],
[0, 0,11,12]])
);
expectArraysEqual(
la.bandPart(a,0,2),
tf.tensor2d([[1, 2, 3, 0],
[0, 6, 7, 8],
[0, 0,11,12]])
);
for( const numUpper of [3,4,-1,-2] ) {
expectArraysEqual(
la.bandPart(a,0,numUpper),
tf.tensor2d([[1, 2, 3, 4],
[0, 6, 7, 8],
[0, 0,11,12]])
);
}

expectArraysEqual(
la.bandPart(a,1,0),
tf.tensor2d([[1, 0, 0, 0],
[5, 6, 0, 0],
[0,10,11, 0]])
);
expectArraysEqual(
la.bandPart(a,1,1),
tf.tensor2d([[1, 2, 0, 0],
[5, 6, 7, 0],
[0,10,11,12]])
);
expectArraysEqual(
la.bandPart(a,1,2),
tf.tensor2d([[1, 2, 3, 0],
[5, 6, 7, 8],
[0,10,11,12]])
);
for( const numUpper of [3,4,-1,-2] ) {
expectArraysEqual(
la.bandPart(a,1,numUpper),
tf.tensor2d([[1, 2, 3, 4],
[5, 6, 7, 8],
[0,10,11,12]])
);
}

for( const numLower of [2,3,-1,-2])
{
expectArraysEqual(
la.bandPart(a,numLower,0),
tf.tensor2d([[1, 0, 0, 0],
[5, 6, 0, 0],
[9,10,11, 0]])
);
expectArraysEqual(
la.bandPart(a,numLower,1),
tf.tensor2d([[1, 2, 0, 0],
[5, 6, 7, 0],
[9,10,11,12]])
);
expectArraysEqual(
la.bandPart(a,numLower,2),
tf.tensor2d([[1, 2, 3, 0],
[5, 6, 7, 8],
[9,10,11,12]])
);
for( const numUpper of [3,4,-1,-2] ) {
expectArraysEqual(
la.bandPart(a,numLower,numUpper),
tf.tensor2d([[1, 2, 3, 4],
[5, 6, 7, 8],
[9,10,11,12]])
);
}
}
// following test is only required for custom backend implementations
//
// for( const numUpper of [0,1,2,3,4,-1,-2] ) {
// for( const numLower of [0,1,2,3, -1,-2] ) {
// const w = tf.randomUniform(a.shape),
// f = (x: Tensor) => {
// return la.bandPart(x,numLower,numUpper).mul(w).mean() as Scalar;
// },
// g = numDiff(f),
// h = tf.grad(f);
// expectArraysClose( g(a), h(a) );
// }}
});
});

describeWithFlags('triangularSolve', CPU_ENVS, () => {
const la = tf.linalg;

const testWith = (L: Tensor, y: Tensor) => {
const test = (adjoint?: boolean) =>
{
let tril = la.bandPart(L,-1, 0),
triu = la.bandPart(L, 0,-1);
if( adjoint ) {
tril = la.adjoint(tril);
triu = la.adjoint(triu);
}
for( const lower of [true,undefined] )
{
const x = la.triangularSolve(L,y, lower, adjoint);
const [a,b] = [y,tril.matMul(x)];
expectArraysClose(a,b);
}
const x = la.triangularSolve(L,y, /*lower=*/false, adjoint);
const [a,b] = [y,triu.matMul(x)];
// const [a,b] = broadcastMatrices( y, triu.matMul(x) );
expectArraysClose(a,b);

for( const lower of [false,true,undefined] )
{
const w = tf.randomUniform(y.shape,-1,+1),
f = (L: Tensor, y: Tensor) => {
return la.triangularSolve(L,y,lower).mul(w).mean() as Scalar;
},
[g1,g2] = tf.grads(f)([L,y]),
h1 = numDiff( (L: Tensor) => f(L,y) )(L),
h2 = numDiff( (y: Tensor) => f(L,y) )(y);
expectArraysClose(g1,h1);
expectArraysClose(g2,h2);
}
};
test(undefined);
test(false);
test(true);
};

it('3x3', () => testWith(
tf.tensor2d([[1,2,3],
[4,5,6],
[7,8,9]]),
tf.tensor2d([[10,11],
[12,13],
[14,15]])
));

for( let run=0; run < 128; run++ )
{
const lShape = Array.from({ length: randInt(2,5) }, () => randInt(1,7) ),
yShape = lShape.slice();
lShape[lShape.length-1] = lShape[lShape.length-2];

// RUN TEST
it(`random#${run}_${lShape.join('x')}_${yShape.join('x')}`, () => {
const ONE = tf.scalar(1),
TWO = tf.scalar(2);
const y = tf.randomUniform(yShape,-1,+1);
let L: Tensor = tf.randomUniform(lShape,-1,+1);
// SET THE DIAGONAL TO BE FAR FROM ZERO
const i = tf.range(0,lShape[lShape.length-2]).reshape([-1,1]),
j = tf.range(0,lShape[lShape.length-1]),
diag = tf.equal(i,j).cast('float32'),
magn = tf.randomNormal (lShape, /*mean=*/1,/*stdDev=*/0.1),
sign = tf.randomUniform(lShape, 0,2, 'int32')
.cast('float32').mul(TWO).sub(ONE);
L = tf.add(
diag.sub(ONE).mul(L), // <- off-diagonal
diag.mul(sign).mul(magn) // <- diagonal
);
L = tf.clone(L);
testWith(L,y);
});
}
});

describeWithFlags('qr', CPU_ENVS, () => {
const testWith = (a: Tensor) => {
const [m,n] = a.shape.slice(-2),
l = Math.min(m,n),
// Indices of matrix transpose.
T = Array.from({ length: a.rank }, (_,i) => i );
T[T.length-2] = T.length-1;
T[T.length-1] = T.length-2;

for( const fullMatrices of [undefined,false,true] )
{
const tril = (() => {
const [p,q] = fullMatrices ? [m,n] : [l,n],
i = tf.range(0,p).reshape([p,1]),
j = tf.range(0,q).reshape([1,q]);
return i.greater(j).cast('float32');
})();
const EYE = (() => {
const d = fullMatrices ? m : l;
return tf.stack(
Array.from(
{ length: a.shape.slice(0,-2).reduce( (x,y) => x*y, 1 ) },
() => tf.eye(d)
)
).reshape([...a.shape.slice(0,-2),d,d]);
})();
const [q,r] = tf.linalg.qr(a,fullMatrices);

// TEST SHAPE OF Q
expectArraysEqual( q.shape.slice(0,-1), a.shape.slice(0,-1) );
expectArraysEqual( q.shape.slice( -1), fullMatrices ? [m ] : [l ] );

// TEST SHAPE OF R
expectArraysEqual( r.shape.slice(0,-2), a.shape.slice(0,-2) );
expectArraysEqual( r.shape.slice( -2), fullMatrices ? [m,n] : [l,n] );

// TEST DECOMPOSITION (Q @ R == A)
try {
expectArraysClose( q.matMul(r), a );
} catch(err) {
console.log('A'); a.print();
console.log('Q'); q.print();
console.log('R'); r.print();
throw err;
}

const qT = q.transpose(T);

// TEST ORTHOGONALITY OF Q
if( fullMatrices || n >= m ) {
expectArraysClose( tf.matMul(q,qT), EYE );
}
expectArraysClose( tf.matMul(qT,q), EYE );

// TEST TRIANGULARITY OF R
expectArraysEqual( tril.mul(r), tf.zeros(r.shape) );

// TEST GRADIENTS
const wQ = tf.randomUniform(q.shape,-1,+1),
wR = tf.randomUniform(r.shape,-1,+1),
f = (a: Tensor) => {
const [q,r] = tf.linalg.qr(a,fullMatrices);
return tf.add(
q.mul(wQ).mean(),
r.mul(wR).mean()
) as Scalar;
};
const g = numDiff(f);
const h = tf.grad(f);
try {
expectTensorsRelativelyClose(g(a), h(a), /*rtol=*/1e-2, /*atol=*/1e-2);
}
catch(err) {
console.log('fullMatrices:', fullMatrices);
console.log('A:'); a .print();
// const [q,r] = tf.linalg.qr(a,fullMatrices);
// console.log('Q:'); q .print();
// console.log('R:'); r .print();
// console.log('G:'); g(a).print();
// console.log('H:'); h(a).print();
throw err;
}
}
};

it('1x1', () => testWith( tensor2d([[10]], [1, 1]) ) );

it('2x2', () => testWith( tensor2d([[ 1, 3],
[-2,-4]], [2, 2]) ) );

it('2x2x2', () => testWith( tensor3d([[[-1,-3],
[ 2, 4]],
[[ 1, 3],
[-2,-4]]], [2, 2, 2]) ) );

it('2x1x2x2', () => testWith( tensor4d([[[[-1,-3],
[ 2, 4]]],
[[[ 1, 3],
[-2,-4]]]], [2, 1, 2, 2]) ) );

it('3x3', () => testWith( tensor2d([[ 1, 3, 2],
[-2, 0, 7],
[ 8,-9, 4]], [3, 3]) ) );

it('3x2', () => testWith( tensor2d([[ 1, 2],
[ 3,-3],
[-2, 1]], [3, 2]) ) );

it('2x3', () => testWith( tensor2d([[ 1, 2, 3],
[-3,-2, 1]], [2, 3]) ) );

for( let run=0; run < 128; run++ )
{
const shape = Array.from({ length: randInt(2,5) }, () => randInt(1,7) );
it(
`random#${run}_${shape.join('x')}`,
() => testWith( tf.randomUniform(shape,-1,+1) )
);
}

it('Is reasonably fast', () => {
// TODO is there a better way to test this with a timeout?
const N = 128,
A = tf.randomUniform([N,N],-1,+1),
wQ = tf.randomUniform([N,N],-1,+1),
wR = tf.randomUniform([N,N],-1,+1),
f = (a: Tensor) => {
const [q,r] = tf.linalg.qr(a);
return q.mul(wQ).mean().add( r.mul(wR).mean() );
};
const g = tf.grad(f);
// the following hopefully prevents g(A) from being JITed/optimized away...
expectArraysClose( g(A), g(A) );
});

it('Does not leak memory', () => {
const x = tensor2d([[ 1, 3],
[-2,-4]], [2, 2]);
// The first call to qr creates and keeps internal singleton tensors.
// Subsequent calls should always create exactly two tensors.
tf.linalg.qr(x);