Skip to content

Commit 52187b5

Browse files
committed
clean code fixup tensor dim and steps as Int type
1 parent 340e009 commit 52187b5

16 files changed

Lines changed: 403 additions & 398 deletions

File tree

build.sbt

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,10 @@ releaseProcess := Seq[ReleaseStep](
105105
//libraryDependencies += "org.bytedeco" % "pytorch-platform-gpu" % "2.7.1-1.5.12"
106106
libraryDependencies += "org.bytedeco" % "cuda" % "12.9-9.10-1.5.12"
107107
libraryDependencies += "org.apache.commons" % "commons-pool2" % "2.12.1"
108+
// https://mvnrepository.com/artifact/com.fasterxml.jackson.core/jackson-databind
109+
libraryDependencies += "com.fasterxml.jackson.core" % "jackson-databind" % "2.20.0"
108110

109-
libraryDependencies += "ai.djl" % "api" % "0.33.0"
111+
//libraryDependencies += "ai.djl" % "api" % "0.33.0"
110112
libraryDependencies += "com.alibaba.fastjson2" % "fastjson2" % "2.0.57"
111113
//libraryDependencies += "org.bytedeco" % "cuda-platform" % "12.9-9.10-1.5.12"
112114
libraryDependencies += "io.github.mullerhai" % "storch-numpy_3" % "0.1.7"
@@ -120,8 +122,8 @@ excludeDependencies += ExclusionRule(organization = "spire", name = "compat_3")
120122
excludeDependencies ++= Seq(
121123
"com.thesamet.scalapb" % "lenses_2.13" ,
122124
"com.thesamet.scalapb" % "scalapb-runtime_2.13",
123-
"io.github.mullerhai" % "storch-safe-tensor_3",
124-
"io.github.mullerhai" % "storch-polar_3",
125+
// "io.github.mullerhai" % "storch-safe-tensor_3",
126+
// "io.github.mullerhai" % "storch-polar_3",
125127
"com.lihaoyi" % "sourcecode_2.13",
126128
"spire" %"compat_2.13",
127129
"spire" %"compat_3"
@@ -149,25 +151,26 @@ lazy val storch_core = project
149151
"com.lihaoyi" %% "os-lib" % "0.9.1",
150152
"com.lihaoyi" %% "sourcecode" % "0.3.0",
151153
"dev.dirs" % "directories" % "26",
152-
"ai.djl" % "api" % "0.33.0",
154+
// "ai.djl" % "api" % "0.33.0",
155+
"com.fasterxml.jackson.core" % "jackson-databind" % "2.20.0",
153156
"com.alibaba.fastjson2" % "fastjson2" % "2.0.57",
154157
"org.apache.commons" % "commons-pool2" % "2.12.1",
155158
"io.github.mullerhai" % "storch-numpy_3" % "0.1.7",
156-
"io.github.mullerhai" % "storch-pandas_3" % "0.1.5",
159+
// "io.github.mullerhai" % "storch-pandas_3" % "0.1.5",
157160
"io.github.mullerhai" % "storch-pickle_3" % "0.1.4",
158161
"io.github.mullerhai" % "storch-tensorboard-proto_3" % "0.1.1",
159162
"io.github.mullerhai" % "storch-plot_3" % "0.0.3",
160163
"io.github.mullerhai" % "storch-scalapy_3" % "0.1.4-1.15.2" exclude("com.lihaoyi","sourcecode_2.13"),
161-
"io.github.mullerhai" % "storch-scikit-learn_3" % "0.1.2-1.15.2" exclude("org.scala-lang.modules","scala-collection-compat_2.13") exclude("org.typelevel","algebra_2.13")exclude("org.typelevel","cats-kernel_2.13"),
164+
// "io.github.mullerhai" % "storch-scikit-learn_3" % "0.1.2-1.15.2" exclude("org.scala-lang.modules","scala-collection-compat_2.13") exclude("org.typelevel","algebra_2.13")exclude("org.typelevel","cats-kernel_2.13"),
162165
"org.scalameta" %% "munit" % "0.7.29" % Test,
163166
"org.scalameta" %% "munit-scalacheck" % "0.7.29" % Test
164167
),
165168
excludeDependencies ++= Seq(
166169
"com.thesamet.scalapb" % "lenses_2.13" ,
167170
"com.lihaoyi" % "sourcecode_2.13",
168171
"com.thesamet.scalapb" % "scalapb-runtime_2.13",
169-
"io.github.mullerhai" % "storch-safe-tensor_3",
170-
"io.github.mullerhai" % "storch-polar_3",
172+
// "io.github.mullerhai" % "storch-safe-tensor_3",
173+
// "io.github.mullerhai" % "storch-polar_3",
171174
"spire" %"compat_2.13",
172175
"spire" %"compat_3"
173176
),

project/build.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
sbt.version=1.9.8
1+
sbt.version=1.11.0

storch_core/src/main/scala/torch/Tensor.scala

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
290290

291291
def scalar_type(un_used: Int*) = native.scalar_type()
292292

293-
def stride(dim: Long) = native.stride(dim)
293+
def stride(dim: Int) = native.stride(dim.toLong)
294294

295295
def size(dim: Long) = native.size(dim)
296296

@@ -1071,15 +1071,15 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
10711071
)
10721072
}
10731073

1074-
def norm[D1 <: DType, S <: ScalaType](p: S, dim: Long*): Tensor[D1] = {
1074+
def norm[D1 <: DType, S <: ScalaType](p: S, dim: Int*): Tensor[D1] = {
10751075

10761076
val pFloat = p match {
10771077
case m: Float => m
10781078
case m: Double => m.toFloat
10791079
case m: Int => m.toFloat
10801080
case m: Long => m.toFloat
10811081
}
1082-
fromNative(native.norm(ScalarOptional(toScalar(pFloat)), dim*))
1082+
fromNative(native.norm(ScalarOptional(toScalar(pFloat)), dim.map(_.toLong)*))
10831083
}
10841084

10851085
def norm[D1 <: DType, S <: ScalaType](
@@ -1138,8 +1138,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
11381138
* If there are multiple maximal values in a reduced row then the indices of the first maximal
11391139
* value are returned.
11401140
*/
1141-
def max(dim: Long, keepdim: Boolean = false): TensorTuple[D] =
1142-
val nativeTuple = native.max(dim, keepdim)
1141+
def max(dim: Int, keepdim: Boolean = false): TensorTuple[D] =
1142+
val nativeTuple = native.max(dim.toLong, keepdim)
11431143
TensorTuple(values = fromNative(nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1))
11441144

11451145
def maximum[D2 <: DType](other: Tensor[D2]): Tensor[Promoted[D, D2]] =
@@ -1229,7 +1229,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
12291229
def shape: Seq[Int] = size
12301230

12311231
def softmax[Out <: FloatNN | Derive](
1232-
dim: Long,
1232+
dim: Int,
12331233
dtype: Out = derive
12341234
): Tensor[DTypeOrDeriveFromTensor[D, Out]] = F.softmax(input = this, dim = dim, dtype = dtype)
12351235

@@ -1854,23 +1854,23 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
18541854

18551855
def to_sparse: Tensor[D] = fromNative(native.to_sparse)
18561856

1857-
def to_sparse_csr(dense_dim: Long): Tensor[D] = fromNative(
1858-
native.to_sparse_csr(LongOptional(dense_dim))
1857+
def to_sparse_csr(dense_dim: Int): Tensor[D] = fromNative(
1858+
native.to_sparse_csr(LongOptional(dense_dim.toLong))
18591859
)
18601860

1861-
def to_sparse_csc(dense_dim: Long): Tensor[D] = fromNative(
1862-
native.to_sparse_csc(LongOptional(dense_dim))
1861+
def to_sparse_csc(dense_dim: Int): Tensor[D] = fromNative(
1862+
native.to_sparse_csc(LongOptional(dense_dim.toLong))
18631863
)
18641864

1865-
def to_sparse_bsr(blockSize: Seq[Long], dense_dim: Long): Tensor[D] = fromNative(
1866-
native.to_sparse_bsr(blockSize.toArray, LongOptional(dense_dim))
1865+
def to_sparse_bsr(blockSize: Seq[Long], dense_dim: Int): Tensor[D] = fromNative(
1866+
native.to_sparse_bsr(blockSize.toArray, LongOptional(dense_dim.toLong))
18671867
)
18681868

1869-
def to_sparse_bsc(blockSize: Seq[Long], dense_dim: Long): Tensor[D] = fromNative(
1870-
native.to_sparse_bsc(blockSize.toArray, LongOptional(dense_dim))
1869+
def to_sparse_bsc(blockSize: Seq[Long], dense_dim: Int): Tensor[D] = fromNative(
1870+
native.to_sparse_bsc(blockSize.toArray, LongOptional(dense_dim.toLong))
18711871
)
18721872

1873-
def to_sparse_coo(sparse_dim: Long): Tensor[D] = fromNative(native.to_sparse(sparse_dim))
1873+
def to_sparse_coo(sparse_dim: Int): Tensor[D] = fromNative(native.to_sparse(sparse_dim.toLong))
18741874

18751875
def to_dense(un_used: Int*): Tensor[D] = fromNative(native.to_dense)
18761876

@@ -3284,8 +3284,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
32843284
fromNative(native.repeat_interleave(repeats.native, dimOpt, outputSizeOpt))
32853285
}
32863286

3287-
def repeat_interleave[D1 <: DType](repeats: Long, dim: Long): Tensor[D1] =
3288-
fromNative(native.repeat_interleave(repeats, new LongOptional(dim), new LongOptional()))
3287+
def repeat_interleave[D1 <: DType](repeats: Long, dim: Int): Tensor[D1] =
3288+
fromNative(native.repeat_interleave(repeats, new LongOptional(dim.toLong), new LongOptional()))
32893289

32903290
def repeat_interleave[D1 <: DType](
32913291
repeats: Long,
@@ -3744,10 +3744,10 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
37443744

37453745
def std(unbiased: Boolean): Tensor[D] = fromNative(native.std(unbiased))
37463746

3747-
def prod_with_dim(dim: Long, keepdim: Boolean = false, dtype: ScalarTypeOptional): Tensor[D] =
3748-
fromNative(native.prod(dim, keepdim, dtype))
3747+
def prod_with_dim(dim: Int, keepdim: Boolean = false, dtype: ScalarTypeOptional): Tensor[D] =
3748+
fromNative(native.prod(dim.toLong, keepdim, dtype))
37493749

3750-
def prod(dim: Long): Tensor[D] = fromNative(native.prod(dim))
3750+
def prod(dim: Int): Tensor[D] = fromNative(native.prod(dim.toLong))
37513751

37523752
def prod(un_used: Int*): Tensor[D] = fromNative(native.prod())
37533753

@@ -5078,7 +5078,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
50785078
fromNative(native.gather(dim.toLong, index.to(dtype = torch.int64).native, sparse_grad))
50795079
}
50805080

5081-
def gather(dim: Long, index: Tensor[Int64] | Tensor[Int32]): Tensor[D] = {
5081+
def gather(dim: Int, index: Tensor[Int64] | Tensor[Int32]): Tensor[D] = {
50825082
index.dtype match
50835083
case torch.int64 => fromNative(native.gather(dim.toLong, index.native))
50845084
case torch.int32 =>
@@ -5249,12 +5249,12 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
52495249

52505250
def digamma: Tensor[D] = fromNative(native.digamma())
52515251

5252-
def polygamma_(n: Long): this.type = {
5253-
native.polygamma_(n)
5252+
def polygamma_(n: Int): this.type = {
5253+
native.polygamma_(n.toLong)
52545254
this
52555255
}
52565256

5257-
def polygamma(n: Long): Tensor[D] = fromNative(native.polygamma(n))
5257+
def polygamma(n: Int): Tensor[D] = fromNative(native.polygamma(n.toLong))
52585258

52595259
def erfinv_(): this.type = {
52605260
native.erfinv_()
@@ -5455,12 +5455,12 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
54555455
native.remainder(other.native)
54565456
)
54575457

5458-
def renorm_[S <: ScalaType](p: S, dim: Long, maxnorm: S): this.type = {
5459-
native.renorm_(toScalar(p), dim, toScalar(maxnorm))
5458+
def renorm_[S <: ScalaType](p: S, dim: Int, maxnorm: S): this.type = {
5459+
native.renorm_(toScalar(p), dim.toLong, toScalar(maxnorm))
54605460
this
54615461
}
5462-
def renorm_(p: Float, dim: Long, maxnorm: Float): this.type = {
5463-
native.renorm_(toScalar(p), dim, toScalar(maxnorm))
5462+
def renorm_(p: Float, dim: Int, maxnorm: Float): this.type = {
5463+
native.renorm_(toScalar(p), dim.toLong, toScalar(maxnorm))
54645464
this
54655465
}
54665466

@@ -5566,8 +5566,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
55665566
(s, t)
55675567
}
55685568

5569-
def unfold(dimension: Long, size: Long, step: Long): Tensor[D] = fromNative(
5570-
native.unfold(dimension, size, step)
5569+
def unfold(dimension: Int, size: Int, step: Int): Tensor[D] = fromNative(
5570+
native.unfold(dimension.toLong, size.toLong, step.toLong)
55715571
)
55725572
def float_power[D1 <: DType](exponent: Tensor[D1]): Tensor[Promoted[D1, D]] = fromNative(
55735573
native.float_power(exponent.native)
@@ -5581,13 +5581,13 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
55815581
native.pow(exponent.native)
55825582
)
55835583

5584-
def renorm(p: Float, dim: Long, maxnorm: Float): Tensor[D] = fromNative(
5585-
native.renorm(toScalar(p), dim, toScalar(maxnorm))
5584+
def renorm(p: Float, dim: Int, maxnorm: Float): Tensor[D] = fromNative(
5585+
native.renorm(toScalar(p), dim.toLong, toScalar(maxnorm))
55865586
)
55875587

5588-
def renorm[S <: ScalaType](p: S, dim: Long, maxnorm: S): Tensor[Promoted[D, ScalaToDType[S]]] =
5588+
def renorm[S <: ScalaType](p: S, dim: Int, maxnorm: S): Tensor[Promoted[D, ScalaToDType[S]]] =
55895589
fromNative(
5590-
native.renorm(toScalar(p), dim, toScalar(maxnorm))
5590+
native.renorm(toScalar(p), dim.toLong, toScalar(maxnorm))
55915591
)
55925592

55935593
def alias(un_used: Int*): Tensor[D] = fromNative(native.alias())
@@ -5687,9 +5687,9 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
56875687
case _ =>
56885688
val innerSummary = {
56895689
def summarizeSlice(index: Int) = summarize(tensor(index), maxEntries)
5690-
val sliceLen = tensor.size(0).toInt
5690+
val sliceLen = tensor.size(0)
56915691
if sliceLen <= math.max(maxEntries, 6) then
5692-
for (i <- 0 until sliceLen.toInt) yield summarizeSlice(i)
5692+
for (i <- 0 until sliceLen) yield summarizeSlice(i)
56935693
else
56945694
val start = for (i <- 0 until maxEntries / 2) yield summarizeSlice(i)
56955695
val end = for (i <- sliceLen - maxEntries / 2 until sliceLen) yield summarizeSlice(i)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
package torch.amp
22

3+
//https://gist.github.com/dorpxam/67ad2bc222b2cf567d4a6fc298375e13
4+
35
class GradScaler {}

storch_core/src/main/scala/torch/nn/functional/Activations.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ private[torch] trait Activations {
4646
*/
4747
def logSoftmax[In <: DType, Out <: FloatNN | Derive](
4848
input: Tensor[In],
49-
dim: Long,
49+
dim: Int,
5050
dtype: Out = derive
5151
): Tensor[DTypeOrDeriveFromTensor[In, Out]] =
5252
val derivedDType = dtype match
@@ -55,11 +55,11 @@ private[torch] trait Activations {
5555
val nativeDType =
5656
if dtype == input.dtype then ScalarTypeOptional()
5757
else ScalarTypeOptional(derivedDType.toScalarType)
58-
fromNative(torchNative.log_softmax(input.native, dim, nativeDType))
58+
fromNative(torchNative.log_softmax(input.native, dim.toLong, nativeDType))
5959

6060
def log_softmax[In <: DType, Out <: FloatNN | Derive](
6161
input: Tensor[In],
62-
dim: Long,
62+
dim: Int,
6363
dtype: Out = derive
6464
) = logSoftmax(input, dim, dtype)
6565

@@ -94,7 +94,7 @@ private[torch] trait Activations {
9494
*/
9595
def softmax[In <: DType, Out <: FloatNN | Derive](
9696
input: Tensor[In],
97-
dim: Long,
97+
dim: Int,
9898
dtype: Out = derive
9999
): Tensor[DTypeOrDeriveFromTensor[In, Out]] =
100100
val derivedDType = dtype match
@@ -103,7 +103,7 @@ private[torch] trait Activations {
103103
val nativeDType =
104104
if dtype == input.dtype then ScalarTypeOptional()
105105
else ScalarTypeOptional(derivedDType.toScalarType)
106-
fromNative(torchNative.softmax(input.native, dim, nativeDType))
106+
fromNative(torchNative.softmax(input.native, dim.toLong, nativeDType))
107107

108108
def relu_[D <: DType](input: Tensor[D]): Tensor[D] = fromNative(torchNative.relu_(input.native))
109109

@@ -160,8 +160,8 @@ private[torch] trait Activations {
160160

161161
def selu_[D <: DType](input: Tensor[D]): Tensor[D] = fromNative(torchNative.selu_(input.native))
162162

163-
def glu[D <: DType](input: Tensor[D], dim: Long): Tensor[D] = fromNative(
164-
torchNative.glu(input.native, dim)
163+
def glu[D <: DType](input: Tensor[D], dim: Int): Tensor[D] = fromNative(
164+
torchNative.glu(input.native, dim.toLong)
165165
)
166166

167167
def gelu[D <: DType](input: Tensor[D], approximate: String = "none"): Tensor[D] = {
@@ -193,8 +193,8 @@ private[torch] trait Activations {
193193
torchNative.softsign(input.native)
194194
)
195195

196-
def softmin[D <: DType](input: Tensor[D], dim: Long, dtype: DType): Tensor[D] = {
197-
val options = SoftminFuncOptions(dim)
196+
def softmin[D <: DType](input: Tensor[D], dim: Int, dtype: DType): Tensor[D] = {
197+
val options = SoftminFuncOptions(dim.toLong)
198198
options.dtype().put(ScalarTypeOptional(dtype.toScalarType))
199199
fromNative(
200200
torchNative.softmin(input.native, options)

0 commit comments

Comments
 (0)