Skip to content

common/layersのSoftmaxWithLossの動作について #25

@tetsuomiyoshi

Description

@tetsuomiyoshi

common/layersのSoftmaxWithLossの動作についてですが、現状のものでは、

forward(x, t)の引数のx, tが例えば、
x = np.array([[1.0, 1.5, 2.0], [1.2, 1.5, 1.7]])
t = np.array([[0, 0, 1], [0, 1, 0]])
または、
x = np.array([[1.0, 1.5, 2.0], [1.2, 1.5, 1.7]])
t = np.array([2, 1])
などのバッチ形式の時は問題ないのですが、

x = np.array([1.0, 1.5, 2.0])
t = np.array([0, 0, 1])
または、
x = np.array([1.0, 1.5, 2.0])
t = np.array(2)
などのベクトル形式の場合、エラーになってしまいます。

一応、クラスSoftmaxWithLossを下記のように変更し、

66 class SoftmaxWithLoss:
67 def init(self):
68 self.params, self.grads = [], []
69 self.y = None # softmaxの出力
70 self.t = None # 教師ラベル
71
72 def forward(self, x, t):
73 self.t = t
74 self.y = softmax(x)
75
76 # 教師ラベルがone-hotベクトルの場合、正解のインデックスに変換
77 if self.y.ndim == 1: # add
78 self.t = self.t.reshape(1, self.t.size) # add
79 self.y = self.y.reshape(1, self.y.size) # add
80
81 if self.t.size == self.y.size:
82 self.t = self.t.argmax(axis=1)
83
84 loss = cross_entropy_error(self.y, self.t)
85 return loss
86
87 def backward(self, dout=1):
88 #batch_size = self.t.shape[0 # delete]
89 batch_size = self.y.shape[0] # modify
90 #print('here1')
91 #print(batch_size)
92 dx = self.y.copy()
93 dx[np.arange(batch_size), self.t] -= 1
94 dx *= dout
95 dx = dx / batch_size
96
97 return dx
98

common.functionsの関数cross_entropy_error()を下記の様に変更すると、

25 def cross_entropy_error(y, t):
26 #print('here2')
27 #if y.ndim == 1: # delete
28 # t = t.reshape(1, t.size) # delete
29 # y = y.reshape(1, y.size) # delete
30
31 # 教師データがone-hot-vectorの場合、正解ラベルのインデックスに変換
32 #if t.size == y.size: # delet
33 # t = t.argmax(axis=1) # delet
34
35 batch_size = y.shape[0]
36
37 return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size
38

ベクトル形式でも問題なく動作してくれる様なのですが、現状のものは引数のxとtに関して
これ以外の入力形式を想定して書かれているのでしょうか?その場合、どういった入力形式を
想定しているのか教えて下さい。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions