Skip to content

Commit 54b0ee2

Browse files
micklatlamblin
authored andcommitted
Check that mlp with index targets and vector targets behave the same.
Compare 'misclass' and 'nll' monitoring channels.
1 parent 15eefc5 commit 54b0ee2

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

pylearn2/models/tests/test_mlp.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,46 @@ def test_softmax_binary_targets():
400400
np.testing.assert_allclose(cost_bin(X_data, y_bin_data),
401401
cost_vec(X_data, y_vec_data))
402402

403+
def test_softmax_bin_targets_channels(seed=0):
404+
"""
405+
Constructs softmax layers with binary target and with vector targets
406+
to check that they give the same 'misclass' channel value.
407+
"""
408+
np.random.seed(seed)
409+
num_classes = 2
410+
batch_size = 5
411+
mlp_bin = MLP(
412+
layers=[Softmax(num_classes, 's1', irange=0.1,
413+
binary_target_dim=1)],
414+
nvis=100
415+
)
416+
mlp_vec = MLP(
417+
layers=[Softmax(num_classes, 's1', irange=0.1)],
418+
nvis=100
419+
)
420+
421+
X = mlp_bin.get_input_space().make_theano_batch()
422+
y_bin = mlp_bin.get_target_space().make_theano_batch()
423+
y_vec = mlp_vec.get_target_space().make_theano_batch()
424+
425+
X_data = np.random.random(size=(batch_size, 100))
426+
X_data = X_data.astype(theano.config.floatX)
427+
y_bin_data = np.random.randint(low=0, high=num_classes,
428+
size=(batch_size, 1))
429+
y_vec_data = np.zeros((batch_size, num_classes), dtype=theano.config.floatX)
430+
y_vec_data[np.arange(batch_size),y_bin_data.flatten()] = 1
431+
432+
def channel_value(channel_name, model, y, y_data):
433+
chans = model.get_monitoring_channels((X,y))
434+
f_channel = theano.function([X,y], chans['s1_'+channel_name])
435+
return f_channel(X_data, y_data)
436+
437+
for channel_name in ['misclass', 'nll']:
438+
vec_val = channel_value(channel_name, mlp_vec, y_vec, y_vec_data)
439+
bin_val = channel_value(channel_name, mlp_bin, y_bin, y_bin_data)
440+
print channel_name, vec_val, bin_val
441+
np.testing.assert_allclose(vec_val, bin_val)
442+
403443
def test_set_get_weights_Softmax():
404444
"""
405445
Tests setting and getting weights for Softmax layer.

0 commit comments

Comments
 (0)