Skip to content

Commit 0b2c363

Browse files
authored
Bug fixes for python notebooks (#1955)
* bug fixes
* update
* use torch backend for all notebooks
* test coverage
1 parent 3880104 commit 0b2c363

17 files changed

+185
-121
lines changed

autokeras/engine/analyser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def update(self, data):
4141
data: np.ndarray. The entire dataset.
4242
"""
4343
if self.dtype is None:
44-
if np.issubdtype(data.dtype, np.str_):
44+
if np.issubdtype(data.dtype, np.str_) or np.issubdtype(
45+
data.dtype, np.bytes_
46+
):
4547
self.dtype = "string"
4648
else:
4749
self.dtype = str(data.dtype)

autokeras/engine/analyser_test.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2020 The AutoKeras Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
18+
from autokeras.engine.analyser import Analyser
19+
20+
21+
def test_analyser_update_unicode_string_dtype():
22+
analyser = Analyser()
23+
data = np.array(["hello", "world"], dtype="U10")
24+
25+
analyser.update(data)
26+
27+
assert analyser.dtype == "string"
28+
assert analyser.shape == [2]
29+
assert analyser.batch_size == 2
30+
assert analyser.num_samples == 2
31+
32+
33+
def test_analyser_update_byte_string_dtype():
34+
analyser = Analyser()
35+
data = np.array([b"hello", b"world"], dtype="S10")
36+
37+
analyser.update(data)
38+
39+
assert analyser.dtype == "string"
40+
assert analyser.shape == [2]
41+
assert analyser.batch_size == 2
42+
assert analyser.num_samples == 2
43+
44+
45+
def test_analyser_update_numeric_dtype():
46+
analyser = Analyser()
47+
data = np.array([1, 2, 3], dtype=np.int32)
48+
49+
analyser.update(data)
50+
51+
assert analyser.dtype == "int32"
52+
assert analyser.shape == [3]
53+
assert analyser.batch_size == 3
54+
assert analyser.num_samples == 3
55+
56+
57+
def test_analyser_update_float_dtype():
58+
analyser = Analyser()
59+
data = np.array([1.0, 2.0, 3.0], dtype=np.float64)
60+
61+
analyser.update(data)
62+
63+
assert analyser.dtype == "float64"
64+
assert analyser.shape == [3]
65+
assert analyser.batch_size == 3
66+
assert analyser.num_samples == 3
67+
68+
69+
def test_analyser_finalize_not_implemented():
70+
analyser = Analyser()
71+
72+
with pytest.raises(NotImplementedError):
73+
analyser.finalize()

autokeras/graph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,9 @@ def _compile_keras_model(self, hp, model):
296296
elif optimizer_name == "sgd":
297297
optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
298298
elif optimizer_name == "adam_weight_decay":
299-
steps_per_epoch = int(self.num_samples / self.batch_size)
299+
steps_per_epoch = max(
300+
1, int(self.num_samples / (self.batch_size or 32))
301+
)
300302
num_train_steps = steps_per_epoch * self.epochs
301303

302304
lr_schedule = keras.optimizers.schedules.PolynomialDecay(
@@ -335,6 +337,7 @@ def set_fit_args(self, validation_split, epochs=None):
335337
# Epochs not specified by the user
336338
if self.epochs is None:
337339
self.epochs = 1
340+
validation_split = validation_split or 0
338341
# num_samples from analysers are before split
339342
self.num_samples = self.inputs[0].num_samples * (1 - validation_split)
340343

autokeras/graph_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ def test_adamw_optimizer():
8989
hp.Choice("optimizer", ["adam", "sgd", "adam_weight_decay"], default="adam")
9090
hp.values["optimizer"] = "adam_weight_decay"
9191
graph = graph_module.Graph(inputs=input_node, outputs=output_node)
92-
graph.num_samples = 10000
92+
graph.inputs[0].num_samples = 100
9393
graph.inputs[0].batch_size = 32
9494
graph.epochs = 10
95+
graph.set_fit_args(0, epochs=10)
9596
model = graph.build(hp)
9697
assert model.input_shape == (None, 30)
9798
assert model.output_shape == (None, 1)
@@ -168,3 +169,16 @@ def test_graph_can_init_with_one_missing_output():
168169
ak.ClassificationHead()(output_node)
169170

170171
graph_module.Graph(input_node, output_node)
172+
173+
174+
def test_set_fit_args_with_none_validation_split():
175+
input_node = ak.Input(shape=(30,))
176+
output_node = input_node
177+
output_node = ak.DenseBlock()(output_node)
178+
output_node = ak.RegressionHead(shape=(1,))(output_node)
179+
180+
graph = graph_module.Graph(inputs=input_node, outputs=output_node)
181+
graph.inputs[0].num_samples = 100
182+
graph.inputs[0].batch_size = 32
183+
graph.set_fit_args(None, epochs=1)
184+
assert graph.num_samples == 100 # Should handle None as 0

autokeras/nodes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,12 @@ def get_block(self):
145145

146146
def get_hyper_preprocessors(self):
147147
return [
148+
hyper_preprocessors.DefaultHyperPreprocessor(
149+
preprocessors.CastToString()
150+
),
148151
hyper_preprocessors.DefaultHyperPreprocessor(
149152
preprocessors.TextTokenizer()
150-
)
153+
),
151154
]
152155

153156

autokeras/preprocessors/common.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections import Counter
16+
1517
import keras
1618
import numpy as np
1719

@@ -39,28 +41,35 @@ class CastToString(preprocessor.Preprocessor):
3941
"""Cast the dataset shape to string."""
4042

4143
def transform(self, dataset):
42-
return dataset.astype("str")
44+
if np.issubdtype(dataset.dtype, np.bytes_):
45+
return np.array(
46+
[x.decode("utf-8", errors="ignore") for x in dataset]
47+
)
48+
else:
49+
return dataset.astype("str")
4350

4451

4552
@keras.utils.register_keras_serializable(package="autokeras")
4653
class TextTokenizer(preprocessor.Preprocessor):
4754
"""Simple text tokenizer that converts strings to integer sequences."""
4855

49-
def __init__(self, max_len=100, vocab=None, **kwargs):
56+
def __init__(self, max_len=100, vocab=None, max_vocab=500, **kwargs):
5057
super().__init__(**kwargs)
5158
self.max_len = max_len
5259
self.vocab = vocab
60+
self.max_vocab = max_vocab
5361

5462
def fit(self, dataset):
5563
# Build vocab from unique words in the dataset
56-
unique_words = set()
64+
unique_words = []
5765
for text in dataset:
5866
words = text.split()
59-
unique_words.update(words)
60-
# Sort for consistency
61-
sorted_words = sorted(unique_words)
67+
unique_words.extend(words)
68+
word_counts = Counter(unique_words)
69+
sorted_words = sorted(word_counts, key=word_counts.get, reverse=True)
6270
self.vocab = {
63-
word: idx + 1 for idx, word in enumerate(sorted_words)
71+
word: idx + 1
72+
for idx, word in enumerate(sorted_words[: self.max_vocab])
6473
} # Start from 1, 0 for padding
6574

6675
def transform(self, dataset):
@@ -80,6 +89,7 @@ def get_config(self):
8089
{
8190
"max_len": self.max_len,
8291
"vocab": self.vocab,
92+
"max_vocab": self.max_vocab,
8393
}
8494
)
8595
return config

autokeras/preprocessors/common_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numpy as np
16+
1517
from autokeras import test_utils
1618
from autokeras.preprocessors import common
1719

@@ -21,3 +23,39 @@ def test_cast_to_int32_return_int32():
2123
x = x.astype("uint8")
2224
x = common.CastToInt32().transform(x)
2325
assert x.dtype == "int32"
26+
27+
28+
def test_cast_to_string_with_bytes():
29+
x = np.array([b"hello", b"world"])
30+
result = common.CastToString().transform(x)
31+
assert result.dtype.kind in ["U", "S"] # Unicode or byte string
32+
assert result[0] == "hello"
33+
assert result[1] == "world"
34+
35+
36+
def test_cast_to_string_with_strings():
37+
x = np.array(["hello", "world"])
38+
result = common.CastToString().transform(x)
39+
assert result.dtype.kind in ["U", "S"]
40+
assert result[0] == "hello"
41+
assert result[1] == "world"
42+
43+
44+
def test_text_tokenizer_vocab_limit():
45+
x = np.array(["word1 word2 word3", "word1 word4 word5"])
46+
tokenizer = common.TextTokenizer(max_vocab=2)
47+
tokenizer.fit(x)
48+
assert len(tokenizer.vocab) <= 3  # vocab stores at most max_vocab (=2) words; index 0 is reserved as the padding/unknown fallback, not stored in the dict
49+
# word1 should be most frequent
50+
assert "word1" in tokenizer.vocab
51+
assert tokenizer.vocab["word1"] == 1
52+
53+
54+
def test_text_tokenizer_transform():
55+
x = np.array(["hello world", "hello"])
56+
tokenizer = common.TextTokenizer(max_vocab=10)
57+
tokenizer.fit(x)
58+
result = tokenizer.transform(x)
59+
assert result.shape == (2, 100) # max_len=100
60+
assert result.dtype == np.int32
61+
assert result[0][0] == tokenizer.vocab.get("hello", 0)

docs/py/customized.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""shell
2+
export KERAS_BACKEND="torch"
23
pip install autokeras
34
"""
45

docs/py/export.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""shell
2+
export KERAS_BACKEND="torch"
23
pip install autokeras
34
"""
45

docs/py/image_classification.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""shell
2+
export KERAS_BACKEND="torch"
23
pip install autokeras
34
"""
45

0 commit comments

Comments (0)