@@ -30,6 +30,7 @@ impl<R: Records, S> DatasetBase<R, S> {
3030 targets,
3131 weights : Array1 :: zeros ( 0 ) ,
3232 feature_names : Vec :: new ( ) ,
33+ target_names : Vec :: new ( ) ,
3334 }
3435 }
3536
@@ -60,14 +61,8 @@ impl<R: Records, S> DatasetBase<R, S> {
6061 /// A feature name gives a human-readable string describing the purpose of a single feature.
6162 /// This allows the reader to understand its purpose while analysing results, for example
6263 /// in correlation analysis or feature importance.
63- pub fn feature_names ( & self ) -> Vec < String > {
64- if !self . feature_names . is_empty ( ) {
65- self . feature_names . clone ( )
66- } else {
67- ( 0 ..self . records . nfeatures ( ) )
68- . map ( |idx| format ! ( "feature-{idx}" ) )
69- . collect ( )
70- }
64+ pub fn feature_names ( & self ) -> & [ String ] {
65+ & self . feature_names
7166 }
7267
7368 /// Return records of a dataset
@@ -81,13 +76,14 @@ impl<R: Records, S> DatasetBase<R, S> {
8176 /// Updates the records of a dataset
8277 ///
8378 /// This function overwrites the records in a dataset. It also invalidates the weights and
84- /// feature names.
79+ /// feature/target names, resetting each of them to empty; the targets are carried over unchanged.
8580 pub fn with_records < T : Records > ( self , records : T ) -> DatasetBase < T , S > {
8681 DatasetBase {
8782 records,
8883 targets : self . targets ,
8984 weights : Array1 :: zeros ( 0 ) ,
9085 feature_names : Vec :: new ( ) ,
86+ target_names : Vec :: new ( ) ,
9187 }
9288 }
9389
@@ -100,6 +96,7 @@ impl<R: Records, S> DatasetBase<R, S> {
10096 targets,
10197 weights : self . weights ,
10298 feature_names : self . feature_names ,
99+ target_names : self . target_names ,
103100 }
104101 }
105102
@@ -111,11 +108,14 @@ impl<R: Records, S> DatasetBase<R, S> {
111108 }
112109
113110 /// Updates the feature names of a dataset
111+ ///
112+ /// **Panics** if `names` is non-empty and its length differs from the number of features
114113 pub fn with_feature_names < I : Into < String > > ( mut self , names : Vec < I > ) -> DatasetBase < R , S > {
115- let feature_names = names. into_iter ( ) . map ( |x| x. into ( ) ) . collect ( ) ;
116-
117- self . feature_names = feature_names;
118-
114+ assert ! (
115+ names. is_empty( ) || names. len( ) == self . nfeatures( ) ,
116+ "Wrong number of feature names"
117+ ) ;
118+ self . feature_names = names. into_iter ( ) . map ( |x| x. into ( ) ) . collect ( ) ;
119119 self
120120 }
121121}
@@ -131,6 +131,18 @@ impl<X, Y> Dataset<X, Y> {
131131}
132132
133133impl < L , R : Records , T : AsTargets < Elem = L > > DatasetBase < R , T > {
134+ /// Updates the target names of a dataset
135+ ///
136+ /// **Panics** if `names` is non-empty and its length differs from the number of targets
137+ pub fn with_target_names < I : Into < String > > ( mut self , names : Vec < I > ) -> DatasetBase < R , T > {
138+ assert ! (
139+ names. is_empty( ) || names. len( ) == self . ntargets( ) ,
140+ "Wrong number of target names"
141+ ) ;
142+ self . target_names = names. into_iter ( ) . map ( |x| x. into ( ) ) . collect ( ) ;
143+ self
144+ }
145+
134146 /// Map targets with a function `f`
135147 ///
136148 /// # Example
@@ -153,6 +165,7 @@ impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
153165 targets,
154166 weights,
155167 feature_names,
168+ target_names,
156169 ..
157170 } = self ;
158171
@@ -163,9 +176,17 @@ impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
163176 targets : targets. map ( fnc) ,
164177 weights,
165178 feature_names,
179+ target_names,
166180 }
167181 }
168182
183+ /// Returns the target names as a borrowed slice (empty if no names have been set)
184+ ///
185+ /// A target name gives a human-readable string describing the purpose of a single target.
186+ pub fn target_names ( & self ) -> & [ String ] {
187+ & self . target_names
188+ }
189+
169190 /// Return the number of targets in the dataset
170191 ///
171192 /// # Example
@@ -217,6 +238,7 @@ impl<'a, F: 'a, L: 'a, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
217238where
218239 D : Data < Elem = F > ,
219240 T : AsTargets < Elem = L > + FromTargetArray < ' a > ,
241+ T :: View : AsTargets < Elem = L > ,
220242{
221243 /// Creates a view of a dataset
222244 pub fn view ( & ' a self ) -> DatasetBase < ArrayView2 < ' a , F > , T :: View > {
@@ -226,6 +248,7 @@ where
226248 DatasetBase :: new ( records, targets)
227249 . with_feature_names ( self . feature_names . clone ( ) )
228250 . with_weights ( self . weights . clone ( ) )
251+ . with_target_names ( self . target_names . clone ( ) )
229252 }
230253
231254 /// Iterate over features
@@ -268,6 +291,7 @@ impl<L, R: Records, T: AsTargetsMut<Elem = L>> AsTargetsMut for DatasetBase<R, T
268291impl < ' a , L : ' a , F , T > DatasetBase < ArrayView2 < ' a , F > , T >
269292where
270293 T : AsTargets < Elem = L > + FromTargetArray < ' a > ,
294+ T :: View : AsTargets < Elem = L > ,
271295{
272296 /// Split dataset into two disjoint chunks
273297 ///
@@ -299,11 +323,13 @@ where
299323 } ;
300324 let dataset1 = DatasetBase :: new ( records_first, targets_first)
301325 . with_weights ( first_weights)
302- . with_feature_names ( self . feature_names . clone ( ) ) ;
326+ . with_feature_names ( self . feature_names . clone ( ) )
327+ . with_target_names ( self . target_names . clone ( ) ) ;
303328
304329 let dataset2 = DatasetBase :: new ( records_second, targets_second)
305330 . with_weights ( second_weights)
306- . with_feature_names ( self . feature_names . clone ( ) ) ;
331+ . with_feature_names ( self . feature_names . clone ( ) )
332+ . with_target_names ( self . target_names . clone ( ) ) ;
307333
308334 ( dataset1, dataset2)
309335 }
@@ -349,7 +375,8 @@ where
349375 label,
350376 DatasetBase :: new ( self . records ( ) . view ( ) , targets)
351377 . with_feature_names ( self . feature_names . clone ( ) )
352- . with_weights ( self . weights . clone ( ) ) ,
378+ . with_weights ( self . weights . clone ( ) )
379+ . with_target_names ( self . target_names . clone ( ) ) ,
353380 )
354381 } )
355382 . collect ( ) )
@@ -405,6 +432,7 @@ impl<F, D: Data<Elem = F>, I: Dimension> From<ArrayBase<D, I>>
405432 targets : empty_targets,
406433 weights : Array1 :: zeros ( 0 ) ,
407434 feature_names : Vec :: new ( ) ,
435+ target_names : Vec :: new ( ) ,
408436 }
409437 }
410438}
@@ -421,6 +449,7 @@ where
421449 targets : rec_tar. 1 ,
422450 weights : Array1 :: zeros ( 0 ) ,
423451 feature_names : Vec :: new ( ) ,
452+ target_names : Vec :: new ( ) ,
424453 }
425454 }
426455}
@@ -957,7 +986,8 @@ impl<F, E, I: TargetDim> Dataset<F, E, I> {
957986 let n1 = ( self . nsamples ( ) as f32 * ratio) . ceil ( ) as usize ;
958987 let n2 = self . nsamples ( ) - n1;
959988
960- let feature_names = self . feature_names ( ) ;
989+ let feature_names = self . feature_names ( ) . to_vec ( ) ;
990+ let target_names = self . target_names ( ) . to_vec ( ) ;
961991
962992 // split records into two disjoint arrays
963993 let mut array_buf = self . records . into_raw_vec ( ) ;
@@ -990,10 +1020,12 @@ impl<F, E, I: TargetDim> Dataset<F, E, I> {
9901020 // create new datasets with attached weights
9911021 let dataset1 = Dataset :: new ( first, first_targets)
9921022 . with_weights ( self . weights )
993- . with_feature_names ( feature_names. clone ( ) ) ;
1023+ . with_feature_names ( feature_names. clone ( ) )
1024+ . with_target_names ( target_names. clone ( ) ) ;
9941025 let dataset2 = Dataset :: new ( second, second_targets)
9951026 . with_weights ( second_weights)
996- . with_feature_names ( feature_names) ;
1027+ . with_feature_names ( feature_names. clone ( ) )
1028+ . with_target_names ( target_names. clone ( ) ) ;
9971029
9981030 ( dataset1, dataset2)
9991031 }
0 commit comments