1212use Rubix \ML \Classifiers \KNearestNeighbors ;
1313use Rubix \ML \Classifiers \MultilayerPerceptron ;
1414use Rubix \ML \Datasets \Labeled ;
15+ use Rubix \ML \DataType ;
1516use Rubix \ML \Estimator ;
1617use Rubix \ML \NeuralNet \Layers \Dense ;
1718use Rubix \ML \NeuralNet \Layers \PReLU ;
@@ -131,14 +132,7 @@ public function handle()
131132 private function getEstimator (string $ modelPath , Estimator $ baseEstimator ): Estimator
132133 {
133134 $ estimator = new PersistentModel (
134- new Pipeline (
135- [
136- new MissingDataImputer (),
137- new OneHotEncoder (),
138- new ZScaleStandardizer (),
139- ],
140- $ baseEstimator
141- ),
135+ new Pipeline ($ this ->getTransformers ($ baseEstimator ), $ baseEstimator ),
142136 new Filesystem ($ modelPath )
143137 );
144138
@@ -151,12 +145,6 @@ private function getEstimator(string $modelPath, Estimator $baseEstimator): Esti
151145
152146 private function getDefaultBaseEstimator (bool $ continuous ): Estimator
153147 {
154- // $layers = [
155- // new Dense(100),
156- // new Dense(100),
157- // new Dense(100),
158- // ];
159-
160148 $ baseEstimator = new KDNeighbors ();
161149
162150 if ($ continuous ) {
@@ -165,4 +153,20 @@ private function getDefaultBaseEstimator(bool $continuous): Estimator
165153
166154 return $ baseEstimator ;
167155 }
156+
157+ private function getTransformers (Estimator $ estimator ): array
158+ {
159+ $ dataTypes = $ estimator ->compatibility ();
160+
161+ $ transformers = [];
162+ $ transformers [] = new MissingDataImputer ();
163+
164+ if (!in_array (DataType::categorical (), $ dataTypes ) && in_array (DataType::continuous (), $ dataTypes )) {
165+ $ transformers [] = new OneHotEncoder ();
166+ }
167+
168+ $ transformers [] = new ZScaleStandardizer ();
169+
170+ return $ transformers ;
171+ }
168172}
0 commit comments