|
14 | 14 |
|
15 | 15 |
|
16 | 16 | import numpy as np |
| 17 | +import pandas as pd |
17 | 18 | import pytest |
18 | | -import tensorflow as tf |
19 | 19 |
|
20 | 20 | from autokeras import test_utils |
21 | 21 | from autokeras.adapters import input_adapters |
|
25 | 25 | def test_image_input_adapter_transform_to_dataset(): |
26 | 26 | x = test_utils.generate_data() |
27 | 27 | adapter = input_adapters.ImageAdapter() |
28 | | - assert isinstance(adapter.adapt(x, batch_size=32), tf.data.Dataset) |
| 28 | + assert isinstance(adapter.adapt(x), np.ndarray) |
29 | 29 |
|
30 | 30 |
|
31 | 31 | def test_image_input_unsupported_type(): |
32 | 32 | x = "unknown" |
33 | 33 | adapter = input_adapters.ImageAdapter() |
34 | 34 | with pytest.raises(TypeError) as info: |
35 | | - x = adapter.adapt(x, batch_size=32) |
| 35 | + x = adapter.adapt(x) |
36 | 36 | assert "Expect the data to ImageInput to be numpy" in str(info.value) |
37 | 37 |
|
38 | 38 |
|
39 | 39 | def test_image_input_numerical(): |
40 | 40 | x = np.array([[["unknown"]]]) |
41 | 41 | adapter = input_adapters.ImageAdapter() |
42 | 42 | with pytest.raises(TypeError) as info: |
43 | | - x = adapter.adapt(x, batch_size=32) |
| 43 | + x = adapter.adapt(x) |
44 | 44 | assert "Expect the data to ImageInput to be numerical" in str(info.value) |
45 | 45 |
|
46 | 46 |
|
47 | 47 | def test_input_type_error(): |
48 | 48 | x = "unknown" |
49 | 49 | adapter = input_adapters.InputAdapter() |
50 | 50 | with pytest.raises(TypeError) as info: |
51 | | - x = adapter.adapt(x, batch_size=32) |
| 51 | + x = adapter.adapt(x) |
52 | 52 | assert "Expect the data to Input to be numpy" in str(info.value) |
53 | 53 |
|
54 | 54 |
|
55 | 55 | def test_input_numerical(): |
56 | 56 | x = np.array([[["unknown"]]]) |
57 | 57 | adapter = input_adapters.InputAdapter() |
58 | 58 | with pytest.raises(TypeError) as info: |
59 | | - x = adapter.adapt(x, batch_size=32) |
| 59 | + x = adapter.adapt(x) |
60 | 60 | assert "Expect the data to Input to be numerical" in str(info.value) |
61 | 61 |
|
62 | 62 |
|
63 | | -def test_text_adapt_unbatched_dataset(): |
64 | | - x = tf.data.Dataset.from_tensor_slices(np.array(["a b c", "b b c"])) |
65 | | - adapter = input_adapters.TextAdapter() |
66 | | - x = adapter.adapt(x, batch_size=32) |
67 | | - |
68 | | - assert data_utils.dataset_shape(x).as_list() == [None] |
69 | | - assert isinstance(x, tf.data.Dataset) |
70 | | - |
71 | | - |
72 | | -def test_text_adapt_batched_dataset(): |
73 | | - x = tf.data.Dataset.from_tensor_slices(np.array(["a b c", "b b c"])).batch( |
74 | | - 32 |
75 | | - ) |
76 | | - adapter = input_adapters.TextAdapter() |
77 | | - x = adapter.adapt(x, batch_size=32) |
78 | | - |
79 | | - assert data_utils.dataset_shape(x).as_list() == [None] |
80 | | - assert isinstance(x, tf.data.Dataset) |
81 | | - |
82 | | - |
83 | 63 | def test_text_adapt_np(): |
84 | 64 | x = np.array(["a b c", "b b c"]) |
85 | 65 | adapter = input_adapters.TextAdapter() |
86 | | - x = adapter.adapt(x, batch_size=32) |
| 66 | + x = adapter.adapt(x) |
87 | 67 |
|
88 | | - assert data_utils.dataset_shape(x).as_list() == [None] |
89 | | - assert isinstance(x, tf.data.Dataset) |
| 68 | + assert data_utils.dataset_shape(x) == [2] |
| 69 | + assert isinstance(x, np.ndarray) |
90 | 70 |
|
91 | 71 |
|
92 | 72 | def test_text_input_type_error(): |
93 | 73 | x = "unknown" |
94 | 74 | adapter = input_adapters.TextAdapter() |
95 | 75 | with pytest.raises(TypeError) as info: |
96 | | - x = adapter.adapt(x, batch_size=32) |
| 76 | + x = adapter.adapt(x) |
97 | 77 | assert "Expect the data to TextInput to be numpy" in str(info.value) |
| 78 | + |
| 79 | + |
| 80 | +def test_structured_data_input_unsupported_type_error(): |
| 81 | + with pytest.raises(TypeError) as info: |
| 82 | + adapter = input_adapters.StructuredDataAdapter() |
| 83 | + adapter.adapt("unknown") |
| 84 | + |
| 85 | + assert "Unsupported type" in str(info.value) |
| 86 | + |
| 87 | + |
| 88 | +def test_structured_data_input_transform_to_dataset(): |
| 89 | + x = pd.read_csv(test_utils.TRAIN_CSV_PATH).to_numpy().astype(str) |
| 90 | + adapter = input_adapters.StructuredDataAdapter() |
| 91 | + |
| 92 | + x = adapter.adapt(x) |
| 93 | + |
| 94 | + assert isinstance(x, np.ndarray) |
0 commit comments