
Let the gradient code for convolution be simplified more (relative to the hand-written gradient that exploits multi-linearity) #123

@Mikolaj

Description


Let's look at the Haskell code of conv2dSameS (which uses operations from the horde-ad library and was originally contributed by Ben Lippmeier and Tran Ma), one of the three shaped convolution functions in that file written in point-full style using sbuild (ranked versions of the three functions live in another file):

-- | Full convolution, where the output image size is the same
-- as the input size.
conv2dSameS
  :: forall nImgs nCinp nCinpA nCout nAh nAw nKh nKw shK shA shB shK1
            target r.
     ( KnownNat nImgs, KnownNat nCinp, KnownNat nCout
     , KnownNat nAh, KnownNat nAw, KnownNat nKh, KnownNat nKw
     , ADReady target, GoodScalar r
     , nCinpA ~ nCinp
     , shK ~ '[nCout, nCinp, nKh, nKw]
     , shA ~ '[nImgs, nCinp, nAh, nAw]
     , shB ~ '[nImgs, nCout, nAh, nAw]
     , shK1 ~ '[1, nCinpA, nKh, nKw]
     )
  => target (TKS shK r) -> target (TKS shA r) -> target (TKS shB r)
conv2dSameS arrK arrA =
  sbuild @(Rank shB) $ \case
    [iImg, iCout, iBh, iBw] ->
      let arrAt = slicezS @shK1 arrA [iImg, 0, iBh, iBw]
          arrKt = slicezS arrK [iCout, 0, 0, 0]
      in sdot0 arrAt arrKt
    _ -> error "conv2dSameS: impossible pattern needlessly required"
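To fix intuition, here is what conv2dSameS computes, sketched in plain Haskell on nested lists for a single image and a single channel (conv2dSameRef and at are hypothetical names; that slicezS reads zeros outside the array bounds is an assumption made for this sketch):

```haskell
-- Same-padding 2-D convolution on nested lists, single image and
-- channel.  Output element (ih, iw) is the dot product of the kernel
-- with the input window anchored at (ih, iw); reads past the border
-- yield zero, so the output has the same size as the input.
conv2dSameRef :: Num a => [[a]] -> [[a]] -> [[a]]
conv2dSameRef krn img =
  [ [ sum [ k * at (ih + kh) (iw + kw)
          | (kh, row) <- zip [0 ..] krn
          , (kw, k)   <- zip [0 ..] row ]
    | iw <- [0 .. w - 1] ]
  | ih <- [0 .. h - 1] ]
 where
  h = length img
  w = length (head img)
  at i j  -- zero outside the bounds, mimicking zero-padded slicing
    | i < 0 || j < 0 || i >= h || j >= w = 0
    | otherwise = img !! i !! j
```

For example, a 1x1 identity kernel reproduces the image, and an all-ones 2x2 kernel on an all-ones 2x2 image yields [[4,2],[2,1]], the border elements seeing fewer nonzero inputs.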

Unscientific timings for a hundred runs of the gradient programs (described below) of this function, instantiated to small arrays, are

    conv2dSameVjp Bench dKrn Handwritten:           OK (2.31s)
    conv2dSameVjp Bench dKrn HandwrittenVectorized: OK (0.07s)
    conv2dSameVjp Bench dKrn Symbolic:              OK (0.23s)
    conv2dSameVjp Bench dKrn Concrete:              OK (37.93s)

The goal of this ticket is to get the third timing (Symbolic) closer to the second (HandwrittenVectorized).

The gradient (reverse derivative) programs are as follows. Handwritten is the simplest handwritten code of the gradient (with respect to the kernels) of the conv2dSameS function:

-- | Hand-written reverse derivative of full convolution with respect
-- to the kernels.
-- This code vectorized is pretty-printed in test testSameCNNOPPKrnHandwritten.
-- Example code that horde-ad generates for the same is in testSameCNNOPP0cW.
conv2dSame_dKrn
  :: forall nImgs nCinp nCout nAh nAw nKh nKw shK shA shB shB1
            target r.
     ( KnownNat nImgs, KnownNat nCinp, KnownNat nCout
     , KnownNat nAh, KnownNat nAw, KnownNat nKh, KnownNat nKw
     , ADReady target, GoodScalar r
     , shK ~ '[nCout, nCinp, nKh, nKw]
     , shA ~ '[nImgs, nCinp, nAh, nAw]
     , shB ~ '[nImgs, nCout, nAh, nAw]
     , shB1 ~ '[nImgs, 1, nAh, nAw] )
  => target (TKS shA r)
  -> target (TKS shB r)
  -> target (TKS shK r)
conv2dSame_dKrn arrA arrB =
  sbuild @(Rank shK) $ \case
    [iCout, iCinp, iKh, iKw] ->
      let arrBt = slicezS @shB1 arrB [0, iCout, 0, 0]
          arrAt = slicezS arrA [0, iCinp, iKh, iKw]
      in sdot0 arrBt arrAt
    _ -> error "conv2dSame_dKrn: impossible pattern needlessly required"
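The shape of this gradient is easy to check by hand: the kernel entry at (kh, kw) receives the dot product of the incoming cotangent arrB with the input shifted by (kh, kw). A plain-Haskell sketch for a single image and channel (conv2dSame_dKrnRef is a hypothetical name; out-of-range reads are again assumed to be zero):

```haskell
-- Kernel gradient of same-padding convolution, single image and
-- channel: entry (kh, kw) is the dot product of the cotangent arrB
-- with the input arrA shifted by (kh, kw), zero-padded at the border.
conv2dSame_dKrnRef :: Num a => [[a]] -> [[a]] -> Int -> Int -> [[a]]
conv2dSame_dKrnRef arrA arrB nKh nKw =
  [ [ sum [ b * at (ih + kh) (iw + kw)
          | (ih, row) <- zip [0 ..] arrB
          , (iw, b)   <- zip [0 ..] row ]
    | kw <- [0 .. nKw - 1] ]
  | kh <- [0 .. nKh - 1] ]
 where
  h = length arrA
  w = length (head arrA)
  at i j  -- zero outside the bounds of the input
    | i < 0 || j < 0 || i >= h || j >= w = 0
    | otherwise = arrA !! i !! j
```

A quick sanity check: a cotangent that is 1 at a single output position (ih, iw) and 0 elsewhere selects exactly the input window anchored at (ih, iw) as the gradient.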

HandwrittenVectorized below is the same code but vectorized (passed through the bulk-operation transform (BOT) so that it uses bulk array operations). The displayed code is additionally instantiated to a particular set of types of small arrays. Note that sgather can be trivially expressed as sbuild, if that's what the user prefers, but in horde-ad this probably degrades performance (TODO: benchmark, understand/improve/fuse more).

\u0 ->
  \u99 ->
    ssum @108
      (stranspose @'[ 4, 0, 1, 2, 3]
         (sreshape @'[ 3, 3, 3, 3, 108]
            (str
               (sreplicate @3
                  (str (sreplicate @3 (str (sreplicate @3 (stranspose @'[ 1, 2, 0] (sreplicate @1 (str u99)))))))) *
             sreplicate @3
               (stranspose @'[ 1, 2, 3, 4, 0]
                  (sreplicate @1
                     (stranspose @'[ 2, 3, 0, 4, 5, 1]
                        (sgather
                           (stranspose @'[ 4, 2, 0, 3, 1]
                              (sgather (stranspose @'[ 2, 1, 0] u0) (\[i18, i11] -> [i18 + i11])))
                           (\[i16, i12] -> [i16 + i12]))))))))
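The two sgather calls above implement an im2col-style expansion: an index map such as \[i18, i11] -> [i18 + i11] materialises all shifted windows along one axis, after which the convolution becomes an elementwise product and a sum. A minimal plain-Haskell sketch of that pattern on a single axis (shiftWindows is a hypothetical name; reads past the end are taken to be zero, mimicking zero-padded slicing):

```haskell
-- Materialise the nOut shifted windows of length nWin along one axis:
-- row i of the result is the slice starting at offset i, i.e. the
-- gather with index map \[i, j] -> [i + j], reading zero past the end.
shiftWindows :: Num a => Int -> Int -> [a] -> [[a]]
shiftWindows nOut nWin xs =
  [ [ at (i + j) | j <- [0 .. nWin - 1] ] | i <- [0 .. nOut - 1] ]
 where at k = if k < length xs then xs !! k else 0
```

The transposes in the pretty-printed code merely route the two axes of each gather into the right positions of the six-dimensional intermediate array.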

Symbolic is the code (vectorized) that the symbolic pipeline of horde-ad produces [edit: for a ranked variant of the convolution, I've just realized, hence the conversions such as rfromS; please ignore these]:

\u0 ->
  \dret u1 ->
    rfromS
      (ssum @6
         (ssum @6
            (sdot1In
               (stranspose @'[ 2, 3, 0, 4, 5, 6, 1]
                  (sreplicate @2
                     (stranspose @'[ 2, 3, 0, 4, 5, 1]
                        (sgather
                           (stranspose @'[ 4, 2, 0, 3, 1]
                              (sgather (stranspose @'[ 2, 0, 1] (sfromR u0)) (\[i81, i83] -> [i81 + i83])))
                           (\[i24, i25] -> [i24 + i25])))))
               (stranspose @'[ 2, 3, 1, 4, 5, 6, 0]
                  (sreshape @'[ 6, 2, 6, 6, 2, 2, 2] (stranspose @'[ 1, 2, 3, 4, 0] (sreplicate @8 (sfromR dret))))))))

Concrete represents the runtime of the primitive horde-ad pipeline (similar to what the ad package does), in which no code is produced. Its performance is terrible because this is the worst case for horde-ad's non-symbolic AD: indexing in a loop and lots of unrolling instead of bulk operations. The latter is also what makes the performance of Handwritten bad, though compiling the code instead of interpreting it may help; see #124.

Note how u1 (the point at which we differentiate) is unused in the code of Symbolic (dret corresponds to u99 in HandwrittenVectorized). This consequence of the (bi- and multi-) linearity of conv2dSameS has apparently been exploited by the simplifier of horde-ad. However, the code is still three times slower for small arrays, and the ratio seems to grow with array size (TODO), just as the ratio of the runtimes of Handwritten and HandwrittenVectorized does (TODO). This may or may not be fully explained by the fact that Symbolic uses three summing constructs instead of one and creates intermediate arrays of higher ranks than HandwrittenVectorized.

It's quite possible that both of these programs are of a form that leads to bad performance and that both should be transformed, though it's not clear how and to what form. Rather than wondering about that, one may focus on the more direct task of automatically simplifying Symbolic to HandwrittenVectorized, or of making AD generate HandwrittenVectorized instead of Symbolic, with or without using the linearity properties of convolution (which we may even require the user to declare via extra arguments to vjp, if that would help). Such improvements, especially the former, would surely apply to many other cases.

Yet another avenue to speeding up convolution specifically is adding convolution primitives to horde-ad and recovering them after AD (and hence after vectorization). This is similar to how adding dot-product (and maybe matmul) primitives might help, and it's similarly future work. Each of these approaches, if successful, would benefit some class of user programs, and none of them is likely to completely supersede the others.

Metadata

Labels: help wanted