Skip to content

Commit 03ed067

Browse files
author
Muller Zhang
committed
fix pytorch cuda mkl openblas new release version code compat
1 parent f7a9f10 commit 03ed067

19 files changed

Lines changed: 308 additions & 303 deletions

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ lazy val commonSettings = Seq(
5656
// This is a hack to avoid depending on the native libs when publishing
5757
// but conveniently have them on the classpath during development.
5858
// There's probably a cleaner way to do this.
59-
tlJdkRelease := Some(11)
59+
tlJdkRelease := Some(21)
6060
) ++ tlReplaceCommandAlias(
6161
"tlReleaseLocal",
6262
List(

project/plugins.sbt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ addSbtPlugin("com.github.sbt" % "sbt-release" % "1.4.0")
88

99
addSbtPlugin("org.bytedeco" % "sbt-javacpp" % "1.17")
1010
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.2")
11-
addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.6.5")
12-
addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.6.5")
11+
addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.7.7")
12+
addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.7.7")
1313
addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.12.2")
1414
addSbtPlugin("org.scalameta" % "sbt-native-image" % "0.3.4")
1515
if (sys.env.isDefinedAt("GITHUB_ACTION")) {

storch_core/src/main/scala/torch/Tensor.scala

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,47 +17,11 @@
1717
package torch
1818

1919
import torch.numpy.matrix.NDArray
20-
import org.bytedeco.javacpp.{
21-
Pointer,
22-
BoolPointer,
23-
BytePointer,
24-
DoublePointer,
25-
FloatPointer,
26-
IntPointer,
27-
LongPointer,
28-
ShortPointer
29-
}
20+
import org.bytedeco.javacpp.{BoolPointer, BytePointer, DoublePointer, FloatPointer, IntPointer, LongPointer, Pointer, ShortPointer}
3021
import org.bytedeco.pytorch
31-
import org.bytedeco.pytorch.{
32-
Tensor as NativeTensor,
33-
BoolOptional,
34-
TensorOptions,
35-
TensorBase,
36-
DoubleOptional,
37-
EllipsisIndexType,
38-
Generator,
39-
GeneratorOptional,
40-
LongOptional,
41-
NamedTensorMeta,
42-
Node,
43-
Quantizer,
44-
ScalarOptional,
45-
ScalarTypeOptional,
46-
Storage,
47-
SymInt,
48-
SymIntOptional,
49-
TensorArrayRefOptional,
50-
TensorIndex,
51-
TensorIndexArrayRef,
52-
TensorIndexVector,
53-
TensorOptional,
54-
TensorOptionalList,
55-
TensorTensorHook,
56-
TensorVector,
57-
VoidTensorHook
58-
}
22+
import org.bytedeco.pytorch.{BoolOptional, DoubleOptional, EllipsisIndexType, Generator, GeneratorOptional, LongOptional, NamedTensorMeta, Node, Quantizer, ScalarOptional, ScalarTypeOptional, Storage, SymInt, SymIntOptional, TensorArrayRefOptional, TensorBase, TensorIndex, TensorIndexArrayRef, TensorIndexVector, TensorOptional, TensorOptionalList, TensorOptions, TensorTensorHook, TensorVector, VoidTensorHook, Tensor as NativeTensor}
5923
import org.bytedeco.pytorch.global.torch as torchNative
60-
import org.bytedeco.pytorch.global.torch.ScalarType
24+
import org.bytedeco.pytorch.global.torch.{ScalarType,Backend}
6125

6226
import java.nio.{Buffer, ByteBuffer, DoubleBuffer, FloatBuffer, IntBuffer, LongBuffer, ShortBuffer}
6327
import scala.collection.immutable.ArraySeq
@@ -1897,7 +1861,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
18971861

18981862
def put[D2 <: DType](x: Tensor[D2]): Tensor[Promoted[D, D2]] = fromNative(native.put(x.native))
18991863

1900-
def toBackend(b: Int): Tensor[D] = fromNative(native.toBackend(b))
1864+
def toBackend(b: Backend): Tensor[D] = fromNative(native.toBackend(b))
19011865

19021866
def not(un_used: Int*): Tensor[D] = fromNative(native.not())
19031867

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
11
package torch.amp
22

3-
class autocast {}
3+
import org.bytedeco.pytorch.global.torch
4+
import org.bytedeco.pytorch.global.torch.DeviceType
5+
6+
object autocast {
7+
8+
def is_autocast_enabled(device_type: DeviceType): Boolean ={
9+
// torch.set_autocast_enabled()
10+
// torch.is_autocast_cache_enabled()
11+
// torch.is_autocast_eligible()
12+
torch.is_autocast_enabled(device_type)
13+
}
14+
// torch.is_autocast_enabled(DeviceType.CUDA)
15+
16+
17+
}

storch_core/src/main/scala/torch/cuda/TorchCUDAAllocator.scala

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,11 @@ class TorchCUDAAllocator(val nativeAllocator: CUDAAllocator) extends Closeable {
115115
* @param device
116116
* 设备索引
117117
*/
118-
def ensureExistsAndIncrefPool(
118+
def createOrIncrefPool(
119+
119120
device: Byte,
120121
mempool_id: DeviceAssertionsDataVectorCUDAKernelLaunchInfoVectorPair
121-
): Unit = nativeAllocator.ensureExistsAndIncrefPool(device, mempool_id)
122+
): Unit = nativeAllocator.createOrIncrefPool(device, mempool_id)
122123

123124
def checkPoolLiveAllocations(
124125
device: Byte,
@@ -132,19 +133,35 @@ class TorchCUDAAllocator(val nativeAllocator: CUDAAllocator) extends Closeable {
132133

133134
def isHistoryEnabled = nativeAllocator.isHistoryEnabled
134135

136+
// public native @Cast("bool") boolean isHistoryEnabled();
137+
// public native void recordHistory(
138+
// @Cast("bool") boolean enabled,
139+
// @ByVal @Cast("c10::cuda::CUDACachingAllocator::CreateContextFn*") Pointer context_recorder,
140+
// @Cast("size_t") long alloc_trace_max_entries,
141+
// RecordContext when,
142+
// @Cast("bool") boolean clearHistory);
143+
// public native void recordHistory(
144+
// @Cast("bool") boolean enabled,
145+
// @ByVal @Cast("c10::cuda::CUDACachingAllocator::CreateContextFn*") Pointer context_recorder,
146+
// @Cast("size_t") long alloc_trace_max_entries,
147+
// @Cast("c10::cuda::CUDACachingAllocator::RecordContext") int when,
148+
// @Cast("bool") boolean clearHistory);
149+
135150
def recordHistory(
136151
enabled: Boolean,
137152
context_recorder: Pointer,
138153
alloc_trace_max_entries: Long,
139-
when: torch_cuda.RecordContext
140-
) = nativeAllocator.recordHistory(enabled, context_recorder, alloc_trace_max_entries, when)
154+
when: torch_cuda.RecordContext,
155+
clearHistory: Boolean
156+
) = nativeAllocator.recordHistory(enabled, context_recorder, alloc_trace_max_entries, when,clearHistory)
141157

142158
def recordHistory(
143159
enabled: Boolean,
144160
context_recorder: Pointer,
145161
alloc_trace_max_entries: Long,
146-
when: Int
147-
) = nativeAllocator.recordHistory(enabled, context_recorder, alloc_trace_max_entries, when)
162+
when: Int,
163+
clearHistory: Boolean
164+
) = nativeAllocator.recordHistory(enabled, context_recorder, alloc_trace_max_entries, when,clearHistory)
148165

149166
def recordAnnotation(md: StringPair) = nativeAllocator.recordAnnotation(md)
150167

0 commit comments

Comments
 (0)