@@ -49,6 +49,8 @@ const char* OpenMLReader::list_dataset_qualities = "/data/qualities/{}";
4949const char * OpenMLReader::list_dataset_filter = " /data/list/{}" ;
5050/* FLOW API */
5151const char * OpenMLReader::flow_file = " /flow/{}" ;
52+ /* TASK API */
53+ const char * OpenMLReader::task_file = " /task/{}" ;
5254
5355const std::unordered_map<std::string, std::string>
5456 OpenMLReader::m_format_options = {{" xml" , xml_server},
@@ -102,10 +104,10 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
102104#endif // HAVE_CURL
103105
104106/* *
105- * Checks the returned flow in JSON format
106- * @param doc the parsed flow
107+ * Checks the returned response from OpenML in JSON format
108+ * @param doc the parsed OpenML JSON format response
107109 */
108- static void check_flow_response ( Document& doc)
110+ static void check_response ( const Document& doc, const std::string& type )
109111{
110112 if (SG_UNLIKELY (doc.HasMember (" error" )))
111113 {
@@ -115,7 +117,9 @@ static void check_flow_response(Document& doc)
115117 root[" message" ].GetString ())
116118 return ;
117119 }
118- REQUIRE (doc.HasMember (" flow" ), " Unexpected format of OpenML flow.\n " );
120+ REQUIRE (
121+ doc.HasMember (type.c_str ()), " Unexpected format of OpenML %s.\n " ,
122+ type.c_str ());
119123}
120124
121125/* *
@@ -142,8 +146,7 @@ static SG_FORCED_INLINE void emplace_string_to_map(
142146 * @param name the name of the key
143147 */
144148static SG_FORCED_INLINE void emplace_string_to_map (
145- const GenericObject<
146- true , GenericValue<UTF8<char >>>& v,
149+ const GenericObject<true , GenericValue<UTF8<char >>>& v,
147150 std::unordered_map<std::string, std::string>& param_dict,
148151 const std::string& name)
149152{
@@ -167,7 +170,7 @@ std::shared_ptr<OpenMLFlow> OpenMLFlow::download_flow(
167170 auto reader = OpenMLReader (api_key);
168171 auto return_string = reader.get (" flow_file" , " json" , flow_id);
169172 document.Parse (return_string.c_str ());
170- check_flow_response (document);
173+ check_response (document, " flow " );
171174
172175 // store root for convenience. We know it exists from previous check.
173176 const Value& root = document[" flow" ];
@@ -248,10 +251,63 @@ std::shared_ptr<OpenMLFlow> OpenMLFlow::from_file()
248251 return std::shared_ptr<OpenMLFlow>();
249252}
250253
254+ std::shared_ptr<OpenMLTask>
255+ OpenMLTask::get_dataset (const std::string& task_id, const std::string& api_key)
256+ {
257+ Document document;
258+ std::string task_name;
259+ std::string task_type;
260+ std::string task_type_id;
261+ std::pair<std::shared_ptr<OpenMLData>, std::shared_ptr<OpenMLSplit>>
262+ task_descriptor;
263+
264+ auto reader = OpenMLReader (api_key);
265+ auto return_string = reader.get (" task_file" , " json" , task_id);
266+
267+ document.Parse (return_string.c_str ());
268+ check_response (document, " task" );
269+
270+ const Value& root = document[" flow" ];
271+
272+ REQUIRE (
273+ task_id == root[" task_id" ].GetString (),
274+ " Expected downloaded task to have the same id as the requested task "
275+ " id.\n " )
276+
277+ task_name = root[" task_name" ].GetString ();
278+ task_type = root[" task_type" ].GetString ();
279+ task_type_id = root[" task_type_id" ].GetString ();
280+
281+ // expect two elements in input array: dataset and split
282+ const Value& json_input = root[" input" ];
283+
284+ REQUIRE (
285+ json_input.IsArray (), " Currently the dataset reader can only handle "
286+ " inputs with a dataset and split field" )
287+
288+ auto input_array = json_input.GetArray ();
289+ REQUIRE (
290+ input_array.Size () == 2 ,
291+ " Currently the dataset reader can only handle inputs with a dataset "
292+ " and split field. Found %d elements." ,
293+ input_array.Size ())
294+
295+ // handle dataset
296+ auto json_dataset = input_array[0 ].GetObject ();
297+
298+ auto result = std::make_shared<OpenMLTask>(
299+ task_id, task_name, task_type, task_type_id, task_descriptor);
300+
301+ return result;
302+ }
303+
251304/* *
252305 * Class using the Any visitor pattern to convert
253306 * a string to a C++ type that can be used as a parameter
254- * in a Shogun model.
307+ * in a Shogun model. If the string value is not "null" it will
308+ * be put in its casted type in the given model with the provided parameter
309+ * name. If the value is null nothing happens, i.e. no error is thrown
310+ * and no value is put in model.
255311 */
256312class StringToShogun : public AnyVisitor
257313{
@@ -266,18 +322,20 @@ class StringToShogun : public AnyVisitor
266322
267323 void on (bool * v) final
268324 {
325+ SG_SDEBUG (
326+ " bool: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
269327 if (!is_null ())
270328 {
271- SG_SDEBUG (" bool: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
272329 bool result = strcmp (m_string_val.c_str (), " true" ) == 0 ;
273330 m_model->put (m_parameter, result);
274331 }
275332 }
276333 void on (int32_t * v) final
277334 {
335+ SG_SDEBUG (
336+ " int32: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
278337 if (!is_null ())
279338 {
280- SG_SDEBUG (" int32: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
281339 try
282340 {
283341 int32_t result = std::stoi (m_string_val);
@@ -299,84 +357,98 @@ class StringToShogun : public AnyVisitor
299357 }
300358 void on (int64_t * v) final
301359 {
360+ SG_SDEBUG (
361+ " int64: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
302362 if (!is_null ())
303363 {
304- SG_SDEBUG ( " int64: %s=%s \n " , m_parameter. c_str (), m_string_val. c_str ())
364+
305365 int64_t result = std::stol (m_string_val);
306366 m_model->put (m_parameter, result);
307367 }
308368 }
309369 void on (float * v) final
310370 {
371+ SG_SDEBUG (
372+ " float: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
311373 if (!is_null ())
312374 {
313- SG_SDEBUG (" float: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
314- char * end;
315- float32_t result = std::strtof (m_string_val.c_str (), &end);
375+ float32_t result = std::stof (m_string_val);
316376 m_model->put (m_parameter, result);
317377 }
318378 }
319379 void on (double * v) final
320380 {
381+ SG_SDEBUG (
382+ " double: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
321383 if (!is_null ())
322384 {
323- SG_SDEBUG (" double: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
324- char * end;
325- float64_t result = std::strtod (m_string_val.c_str (), &end);
385+ float64_t result = std::stod (m_string_val);
326386 m_model->put (m_parameter, result);
327387 }
328388 }
329389 void on (long double * v)
330390 {
391+ SG_SDEBUG (
392+ " long double: %s=%s\n " , m_parameter.c_str (),
393+ m_string_val.c_str ())
331394 if (!is_null ())
332395 {
333- SG_SDEBUG (" long double: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
334- char * end;
335- floatmax_t result = std::strtold (m_string_val.c_str (), &end);
396+ floatmax_t result = std::stold (m_string_val);
336397 m_model->put (m_parameter, result);
337398 }
338399 }
339400 void on (CSGObject** v) final
340401 {
341- SG_SDEBUG (" CSGObject: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
402+ SG_SDEBUG (
403+ " CSGObject: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
342404 }
343405 void on (SGVector<int >* v) final
344406 {
345- SG_SDEBUG (" SGVector<int>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
407+ SG_SDEBUG (
408+ " SGVector<int>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
346409 }
347410 void on (SGVector<float >* v) final
348411 {
349- SG_SDEBUG (" SGVector<float>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
412+ SG_SDEBUG (
413+ " SGVector<float>: %s=%s\n " , m_parameter.c_str (),
414+ m_string_val.c_str ())
350415 }
351416 void on (SGVector<double >* v) final
352417 {
353- SG_SDEBUG (" SGVector<double>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
418+ SG_SDEBUG (
419+ " SGVector<double>: %s=%s\n " , m_parameter.c_str (),
420+ m_string_val.c_str ())
354421 }
355422 void on (SGMatrix<int >* mat) final
356423 {
357- SG_SDEBUG (" SGMatrix<int>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
424+ SG_SDEBUG (
425+ " SGMatrix<int>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
358426 }
359427 void on (SGMatrix<float >* mat) final
360428 {
361- SG_SDEBUG (" SGMatrix<float>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
362- }
363- void on (SGMatrix<double >* mat) final
364- {
365- SG_SDEBUG (" SGMatrix<double>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
429+ SG_SDEBUG (
430+ " SGMatrix<float>: %s=%s\n " , m_parameter.c_str (),
431+ m_string_val.c_str ())
366432 }
367-
368- bool is_null ()
433+ void on (SGMatrix<double >* mat) final {SG_SDEBUG (
434+ " SGMatrix<double>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())}
435+
436+ /* *
437+ * In OpenML "null" is an empty parameter value field.
438+ * @return whether the field is "null"
439+ */
440+ SG_FORCED_INLINE bool is_null ()
369441 {
370442 bool result = strcmp (m_string_val.c_str (), " null" ) == 0 ;
371443 return result;
372444 }
373445
374- void set_parameter_name (const std::string& name)
446+ SG_FORCED_INLINE void set_parameter_name (const std::string& name)
375447 {
376448 m_parameter = name;
377449 }
378450
379- void set_string_value (const std::string& value)
451+ SG_FORCED_INLINE void set_string_value (const std::string& value)
380452 {
381453 m_string_val = value;
382454 }
@@ -396,17 +468,16 @@ class StringToShogun : public AnyVisitor
396468std::shared_ptr<CSGObject> instantiate_model_from_factory (
397469 const std::string& factory_name, const std::string& algo_name)
398470{
399- std::shared_ptr<CSGObject> obj;
400471 if (factory_name == " machine" )
401- obj = std::shared_ptr<CSGObject>(machine (algo_name));
402- else if (factory_name == " kernel" )
403- obj = std::shared_ptr<CSGObject>(kernel (algo_name));
404- else if (factory_name == " distance" )
405- obj = std::shared_ptr<CSGObject>(distance (algo_name));
406- else
407- SG_SERROR (" Unsupported factory \" %s\" .\n " , factory_name.c_str ())
472+ return std::shared_ptr<CSGObject>(machine (algo_name));
473+ if (factory_name == " kernel" )
474+ return std::shared_ptr<CSGObject>(kernel (algo_name));
475+ if (factory_name == " distance" )
476+ return std::shared_ptr<CSGObject>(distance (algo_name));
408477
409- return obj;
478+ SG_SERROR (" Unsupported factory \" %s\" .\n " , factory_name.c_str ())
479+
480+ return nullptr ;
410481}
411482
412483/* *
@@ -426,19 +497,21 @@ void cast_and_put(
426497 // temporary fix until shared_ptr PR merged
427498 auto * tmp_clone = dynamic_cast <CMachine*>(casted_obj->clone ());
428499 obj->put (parameter_name, tmp_clone);
500+ return ;
429501 }
430- else if (auto casted_obj = std::dynamic_pointer_cast<CKernel>(nested_obj))
502+ if (auto casted_obj = std::dynamic_pointer_cast<CKernel>(nested_obj))
431503 {
432504 auto * tmp_clone = dynamic_cast <CKernel*>(casted_obj->clone ());
433505 obj->put (parameter_name, tmp_clone);
506+ return ;
434507 }
435- else if (auto casted_obj = std::dynamic_pointer_cast<CDistance>(nested_obj))
508+ if (auto casted_obj = std::dynamic_pointer_cast<CDistance>(nested_obj))
436509 {
437510 auto * tmp_clone = dynamic_cast <CDistance*>(casted_obj->clone ());
438511 obj->put (parameter_name, tmp_clone);
512+ return ;
439513 }
440- else
441- SG_SERROR (" Could not cast SGObject.\n " )
514+ SG_SERROR (" Could not cast SGObject.\n " )
442515}
443516
444517std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model (
@@ -447,8 +520,8 @@ std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model(
447520 auto params = flow->get_parameters ();
448521 auto components = flow->get_components ();
449522 auto class_name = get_class_info (flow->get_class_name ());
450- auto module_name = std::get< 0 >( class_name) ;
451- auto algo_name = std::get< 1 >( class_name) ;
523+ auto module_name = class_name. first ;
524+ auto algo_name = class_name. second ;
452525
453526 auto obj = instantiate_model_from_factory (module_name, algo_name);
454527 auto obj_param = obj->get_params ();
@@ -486,12 +559,12 @@ ShogunOpenML::model_to_flow(const std::shared_ptr<CSGObject>& model)
486559 return std::shared_ptr<OpenMLFlow>();
487560}
488561
489- std::tuple <std::string, std::string>
562+ std::pair <std::string, std::string>
490563ShogunOpenML::get_class_info (const std::string& class_name)
491564{
492565 std::vector<std::string> class_components;
493566 auto begin = class_name.begin ();
494- std::tuple <std::string, std::string> result;
567+ std::pair <std::string, std::string> result;
495568
496569 for (auto it = class_name.begin (); it != class_name.end (); ++it)
497570 {
@@ -503,15 +576,16 @@ ShogunOpenML::get_class_info(const std::string& class_name)
503576 if (std::next (it) == class_name.end ())
504577 class_components.emplace_back (std::string (begin, std::next (it)));
505578 }
506- if (class_components[0 ] == " shogun" )
507- result = std::make_tuple (class_components[1 ], class_components[2 ]);
579+
580+ if (class_components[0 ] == " shogun" && class_components.size () == 3 )
581+ result = std::make_pair (class_components[1 ], class_components[2 ]);
582+ else if (class_components[0 ] == " shogun" && class_components.size () != 3 )
583+ SG_SERROR (" Invalid class name format %s.\n " , class_name.c_str ())
508584 else
509585 SG_SERROR (
510586 " The provided flow is not meant for shogun deserialisation! The "
511587 " required library is \" %s\" .\n " ,
512588 class_components[0 ].c_str ())
513- if (class_components.size () != 3 )
514- SG_SERROR (" Invalid class name format %s.\n " , class_name.c_str ())
515589
516590 return result;
517591}
0 commit comments