
Does storch support assigning to a strided tensor slice such as tensor[2::4] (with broadcasting), PyTorch style?  #82

@mullerhai


Hi,
In Python, assigning to a strided tensor slice with broadcasting is easy and convenient, like this:

encoding[:, 0::2] = torch.sin(position * div_term)
encoding[:, 1::2] = torch.cos(position * div_term)
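The two assignments write the sine terms into the even columns and the cosine terms into the odd columns. Here `position` is the column vector 0, 1, ..., max_len - 1 and `div_term` is computed as `exp(arange(0, d_model, 2) * (-log(10000) / d_model))`, as in the storch attempt further down. Writing p for the row (position) and k for the column pair, the result is:

```latex
\begin{aligned}
\text{div\_term}_k &= \exp\!\left(2k \cdot \frac{-\ln 10000}{d_{\text{model}}}\right) = 10000^{-2k/d_{\text{model}}}, \qquad k = 0, \dots, \tfrac{d_{\text{model}}}{2} - 1 \\
\text{encoding}[p, 2k] &= \sin\!\left(p \cdot 10000^{-2k/d_{\text{model}}}\right) \\
\text{encoding}[p, 2k+1] &= \cos\!\left(p \cdot 10000^{-2k/d_{\text{model}}}\right)
\end{aligned}
```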

But in storch I cannot assign values to a strided slice this way. How can I do this?

I could write it with explicit loops, like this C++ code:

for (int ti = 0; ti < pTimeSize; ti++) {
    for (int ba = 0; ba < pBatchSize; ba++) {
        for (int ch = 0; ch < pChannelSize; ch++) {
            for (int ro = 0; ro < pRowSize; ro++) {
                for (int co = 0; co < pColumnSize; co++) {
                    int   i = 0, pos = 0;
                    float div_term = 0, pe = 0;
                    i   = co / 2;
                    pos = ch;
                    div_term = pow(10000, 2 * i / (double)pColumnSize);
                    if (co % 2 == 0)
                        pe = sin(pos / div_term);
                    else
                        pe = cos(pos / div_term);
                    (*pPositionalEncoding)[Index5D(peShape, ti, ba, ch, ro, co)] = pe;
                }
            }
        }
    }
}

but I don't think this fits the PyTorch style.

Or like this, in Scala:

private def initializeEncodings(): Unit = {
  for (pos <- 0 until config.maxSeqLen; i <- 0 until config.embeddingDim) {
    // 2.0 forces floating-point division; with Ints, 2 * (i / 2) / embeddingDim is always 0
    val angle = pos / math.pow(10000, 2.0 * (i / 2) / config.embeddingDim)
    encodings(pos)(i) = if (i % 2 == 0) math.sin(angle) else math.cos(angle)
  }
}
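The object below is a self-contained, plain-Scala version of the loop approach. It has no storch dependency, so it can be run and checked on its own. The object name `PositionalEncodingLoop` and the `maxSeqLen`/`embeddingDim` parameters are hypothetical stand-ins for `config.maxSeqLen`, `config.embeddingDim` and the preallocated `encodings` array.

```scala
// Hypothetical stand-alone sketch; parameters replace config.maxSeqLen / config.embeddingDim.
object PositionalEncodingLoop {
  def encodings(maxSeqLen: Int, embeddingDim: Int): Array[Array[Double]] = {
    val enc = Array.ofDim[Double](maxSeqLen, embeddingDim)
    for (pos <- 0 until maxSeqLen; i <- 0 until embeddingDim) {
      // Columns 2k and 2k + 1 share the exponent 2k / embeddingDim;
      // 2.0 keeps the division in floating point.
      val exponent = 2.0 * (i / 2) / embeddingDim
      val angle = pos / math.pow(10000.0, exponent)
      // Even columns get sin, odd columns get cos.
      enc(pos)(i) = if (i % 2 == 0) math.sin(angle) else math.cos(angle)
    }
    enc
  }

  def main(args: Array[String]): Unit = {
    val enc = encodings(maxSeqLen = 4, embeddingDim = 6)
    enc.foreach(row => println(row.map(v => f"$v%.4f").mkString(" ")))
  }
}
```

This works, but it fills the array one element at a time. That is exactly what the strided slice assignment avoids.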

Thanks!

Here is my current attempt in storch:


val arr = Seq(max_len, d_model)
var encoding = torch.zeros(size = arr.map(_.toInt), dtype = torch.Float32)
val position = torch.arange(0, max_len, dtype = torch.Float32).unsqueeze(1)
val div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(Tensor(10000.0)) / d_model))
// Python equivalent:
// encoding[:, 0::2] = torch.sin(position * div_term)
// encoding[:, 1::2] = torch.cos(position * div_term)
// Even columns (start 0, step 2) and odd columns (start 1, step 2), assuming Slice(start, end, step)
val sliceSin = torch.indexing.Slice(Some(0), None, Some(2))
val sliceCos = torch.indexing.Slice(Some(1), None, Some(2))
// What I want to express (Python-style [] assignment does not compile in Scala):
encoding[::, sliceSin] = torch.sin(position * div_term)
encoding[::, sliceCos] = torch.cos(position * div_term)
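One workaround avoids strided assignment entirely. You compute the sin and cos halves as full (max_len, d_model / 2) blocks and then interleave their columns. In PyTorch this is commonly written as `torch.stack([torch.sin(position * div_term), torch.cos(position * div_term)], dim=2).reshape(max_len, d_model)`. Whether storch exposes matching `stack`/`reshape` calls is an assumption to check against its API docs.

The sketch below shows only the interleaving, on plain Scala arrays, so the column layout can be verified without storch. `InterleaveSketch` and its methods are hypothetical names.

```scala
// Hypothetical sketch: "compute halves, then interleave" on plain arrays.
// Column 2k gets sin(pos * divTerm(k)) and column 2k + 1 gets cos(pos * divTerm(k)),
// the same layout that encoding[:, 0::2] = ... and encoding[:, 1::2] = ... produce.
object InterleaveSketch {
  // divTerm(k) = exp(2k * (-ln 10000 / embeddingDim)), mirroring div_term above.
  def divTerm(embeddingDim: Int): Array[Double] =
    Array.tabulate(embeddingDim / 2)(k => math.exp(2.0 * k * (-math.log(10000.0) / embeddingDim)))

  // Merge two (rows x n) matrices into one (rows x 2n) matrix: e0, o0, e1, o1, ...
  def interleave(even: Array[Array[Double]], odd: Array[Array[Double]]): Array[Array[Double]] =
    even.zip(odd).map { case (e, o) => e.zip(o).flatMap { case (a, b) => Seq(a, b) } }

  def encodings(maxSeqLen: Int, embeddingDim: Int): Array[Array[Double]] = {
    require(embeddingDim % 2 == 0, "embeddingDim must be even")
    val dt = divTerm(embeddingDim)
    // Broadcast-style product: position (column vector) times divTerm (row vector).
    val angles = Array.tabulate(maxSeqLen, dt.length)((pos, k) => pos * dt(k))
    interleave(angles.map(_.map(math.sin)), angles.map(_.map(math.cos)))
  }

  def main(args: Array[String]): Unit = {
    val enc = encodings(maxSeqLen = 3, embeddingDim = 4)
    enc.foreach(row => println(row.map(v => f"$v%.4f").mkString(" ")))
  }
}
```

The stack-then-reshape form does the same interleaving in one tensor operation. Stacking along a new last axis and flattening it places each sin value directly before its matching cos value.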
