File tree 2 files changed +16
-0
lines changed
2 files changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -66,6 +66,13 @@ def test_flatten(self):
66
66
a = a .flatten (0 , 1 )
67
67
self .assertEqual (tuple (a .shape ), (6 , 4 ))
68
68
69
+ def test_copy_ (self ):
70
+ with self .env :
71
+ a = torch .zeros ((2 , 3 ), device = "cpu" )
72
+ b = torch .ones ((2 , 3 ))
73
+ b .copy_ (a )
74
+ self .assertEqual (a , b .cpu ())
75
+
69
76
def test_rnn (self ):
70
77
model = SeqModel ()
71
78
x = torch .randn ((2 , 100 , 20 ))
Original file line number Diff line number Diff line change @@ -95,6 +95,15 @@ def shape(self):
95
95
@property
96
96
def ndim (self ):
97
97
return len (self ._elem .shape )
98
+
99
+ @property
100
+ def data (self ):
101
+ return self
102
+
103
+ def copy_ (self , other ):
104
+ if other .device .type == "cpu" :
105
+ other = other .to (self .device )
106
+ return torch .ops .aten .copy_ (self , other )
98
107
99
108
def flatten (self , start_dim = 0 , end_dim = - 1 ):
100
109
if end_dim == - 1 :
You can’t perform that action at this time.
0 commit comments