1515using namespace shogun ;
1616using namespace rapidjson ;
1717
18+ /* *
19+ * The writer callback function used to write the packets to a C++ string.
20+ * @param data the data received in CURL request
21+ * @param size always 1
22+ * @param nmemb the size of data
23+ * @param buffer_in the buffer to write to
24+ * @return the size of buffer that was written
25+ */
1826size_t writer (char * data, size_t size, size_t nmemb, std::string* buffer_in)
1927{
2028 // adapted from https://stackoverflow.com/a/5780603
@@ -30,13 +38,16 @@ size_t writer(char* data, size_t size, size_t nmemb, std::string* buffer_in)
3038 return 0 ;
3139}
3240
41+ /* OpenML server format */
3342const char * OpenMLReader::xml_server = " https://www.openml.org/api/v1/xml" ;
3443const char * OpenMLReader::json_server = " https://www.openml.org/api/v1/json" ;
44+ /* DATA API */
3545const char * OpenMLReader::dataset_description = " /data/{}" ;
3646const char * OpenMLReader::list_data_qualities = " /data/qualities/list" ;
3747const char * OpenMLReader::data_features = " /data/features/{}" ;
3848const char * OpenMLReader::list_dataset_qualities = " /data/qualities/{}" ;
3949const char * OpenMLReader::list_dataset_filter = " /data/list/{}" ;
50+ /* FLOW API */
4051const char * OpenMLReader::flow_file = " /flow/{}" ;
4152
4253const std::unordered_map<std::string, std::string>
@@ -84,25 +95,16 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
8495 if (code != CURLE_OK)
8596 {
8697 // TODO: call curl_easy_cleanup(curl_handle) ?
87- SG_SERROR (" Curl error: %s\n " , curl_easy_strerror (code))
98+ SG_SERROR (" Connection error: %s. \n " , curl_easy_strerror (code))
8899 }
89- // else
90- // {
91- // long response_code;
92- // curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE,
93- // &response_code); if (response_code == 200) return;
94- // else
95- // {
96- // if (response_code == 181)
97- // SG_SERROR("Unknown flow. The flow with the given ID was not
98- // found in the database.") else if (response_code == 180)
99- // SG_SERROR("") SG_SERROR("Server code: %d\n", response_code)
100- // }
101- // }
102100}
103101
104102#endif // HAVE_CURL
105103
104+ /* *
105+ * Checks the returned flow in JSON format
106+ * @param doc the parsed flow
107+ */
106108static void check_flow_response (rapidjson::Document& doc)
107109{
108110 if (SG_UNLIKELY (doc.HasMember (" error" )))
@@ -116,24 +118,36 @@ static void check_flow_response(rapidjson::Document& doc)
116118 REQUIRE (doc.HasMember (" flow" ), " Unexpected format of OpenML flow.\n " );
117119}
118120
121+ /* *
122+ * Helper function to add JSON objects as string in map
123+ * @param v a RapidJSON GenericValue, i.e. string
124+ * @param param_dict the map to write to
125+ * @param name the name of the key
126+ */
119127static SG_FORCED_INLINE void emplace_string_to_map (
120- const rapidjson:: GenericValue<rapidjson:: UTF8<char >>& v,
128+ const GenericValue<UTF8<char >>& v,
121129 std::unordered_map<std::string, std::string>& param_dict,
122130 const std::string& name)
123131{
124- if (v[name.c_str ()].GetType () == rapidjson:: Type::kStringType )
132+ if (v[name.c_str ()].GetType () == Type::kStringType )
125133 param_dict.emplace (name, v[name.c_str ()].GetString ());
126134 else
127135 param_dict.emplace (name, " " );
128136}
129137
138+ /* *
139+ * Helper function to add JSON objects as string in map
140+ * @param v a RapidJSON GenericObject, i.e. array
141+ * @param param_dict the map to write to
142+ * @param name the name of the key
143+ */
130144static SG_FORCED_INLINE void emplace_string_to_map (
131- const rapidjson:: GenericObject<
132- true , rapidjson:: GenericValue<rapidjson:: UTF8<char >>>& v,
145+ const GenericObject<
146+ true , GenericValue<UTF8<char >>>& v,
133147 std::unordered_map<std::string, std::string>& param_dict,
134148 const std::string& name)
135149{
136- if (v[name.c_str ()].GetType () == rapidjson:: Type::kStringType )
150+ if (v[name.c_str ()].GetType () == Type::kStringType )
137151 param_dict.emplace (name, v[name.c_str ()].GetString ());
138152 else
139153 param_dict.emplace (name, " " );
@@ -234,52 +248,235 @@ std::shared_ptr<OpenMLFlow> OpenMLFlow::from_file()
234248 return std::shared_ptr<OpenMLFlow>();
235249}
236250
251+ /* *
252+ * Class using the Any visitor pattern to convert
253+ * a string to a C++ type that can be used as a parameter
254+ * in a Shogun model.
255+ */
256+ class StringToShogun : public AnyVisitor
257+ {
258+ public:
259+ explicit StringToShogun (std::shared_ptr<CSGObject> model)
260+ : m_model(model), m_parameter(" " ), m_string_val(" " ){};
261+
262+ StringToShogun (
263+ std::shared_ptr<CSGObject> model, const std::string& parameter,
264+ const std::string& string_val)
265+ : m_model(model), m_parameter(parameter), m_string_val(string_val){};
266+
267+ void on (bool * v) final
268+ {
269+ if (!is_null ())
270+ {
271+ SG_SDEBUG (" bool: %s=%s\n " , m_parameter, m_string_val)
272+ bool result = strcmp (m_string_val.c_str (), " true" ) == 0 ;
273+ m_model->put (m_parameter, result);
274+ }
275+ }
276+ void on (int32_t * v) final
277+ {
278+ if (!is_null ())
279+ {
280+ SG_SDEBUG (" int32: %s=%s\n " , m_parameter, m_string_val)
281+ try
282+ {
283+ int32_t result = std::stoi (m_string_val);
284+ m_model->put (m_parameter, result);
285+ }
286+ catch (const std::invalid_argument&)
287+ {
288+ // it's an option, i.e. internally represented
289+ // as an enum but in swig exposed as a string
290+ m_string_val.erase (
291+ std::remove_if (
292+ m_string_val.begin (), m_string_val.end (),
293+ // remove quotes
294+ [](const auto & val) { return val == ' \" ' ; }),
295+ m_string_val.end ());
296+ m_model->put (m_parameter, m_string_val);
297+ }
298+ }
299+ }
300+ void on (int64_t * v) final
301+ {
302+ if (!is_null ())
303+ {
304+ SG_SDEBUG (" int64: %s=%s\n " , m_parameter, m_string_val)
305+ int64_t result = std::stol (m_string_val);
306+ m_model->put (m_parameter, result);
307+ }
308+ }
309+ void on (float * v) final
310+ {
311+ if (!is_null ())
312+ {
313+ SG_SDEBUG (" float: %s=%s\n " , m_parameter, m_string_val)
314+ char * end;
315+ float32_t result = std::strtof (m_string_val.c_str (), &end);
316+ m_model->put (m_parameter, result);
317+ }
318+ }
319+ void on (double * v) final
320+ {
321+ if (!is_null ())
322+ {
323+ SG_SDEBUG (" double: %s=%s\n " , m_parameter, m_string_val)
324+ char * end;
325+ float64_t result = std::strtod (m_string_val.c_str (), &end);
326+ m_model->put (m_parameter, result);
327+ }
328+ }
329+ void on (long double * v)
330+ {
331+ if (!is_null ())
332+ {
333+ SG_SDEBUG (" long double: %s=%s\n " , m_parameter, m_string_val)
334+ char * end;
335+ floatmax_t result = std::strtold (m_string_val.c_str (), &end);
336+ m_model->put (m_parameter, result);
337+ }
338+ }
339+ void on (CSGObject** v) final
340+ {
341+ SG_SDEBUG (" CSGObject: %s=%s\n " , m_parameter, m_string_val)
342+ }
343+ void on (SGVector<int >* v) final
344+ {
345+ SG_SDEBUG (" SGVector<int>: %s=%s\n " , m_parameter, m_string_val)
346+ }
347+ void on (SGVector<float >* v) final
348+ {
349+ SG_SDEBUG (" SGVector<float>: %s=%s\n " , m_parameter, m_string_val)
350+ }
351+ void on (SGVector<double >* v) final
352+ {
353+ SG_SDEBUG (" SGVector<double>: %s=%s\n " , m_parameter, m_string_val)
354+ }
355+ void on (SGMatrix<int >* mat) final
356+ {
357+ SG_SDEBUG (" SGMatrix<int>: %s=%s\n " , m_parameter, m_string_val)
358+ }
359+ void on (SGMatrix<float >* mat) final
360+ {
361+ SG_SDEBUG (" SGMatrix<float>: %s=%s\n " , m_parameter, m_string_val)
362+ }
363+ void on (SGMatrix<double >* mat) final
364+ {
365+ SG_SDEBUG (" SGMatrix<double>: %s=%s\n " , m_parameter, m_string_val)
366+ }
367+
368+ bool is_null ()
369+ {
370+ bool result = strcmp (m_string_val.c_str (), " null" ) == 0 ;
371+ return result;
372+ }
373+
374+ void set_parameter_name (const std::string& name)
375+ {
376+ m_parameter = name;
377+ }
378+
379+ void set_string_value (const std::string& value)
380+ {
381+ m_string_val = value;
382+ }
383+
384+ private:
385+ std::shared_ptr<CSGObject> m_model;
386+ std::string m_parameter;
387+ std::string m_string_val;
388+ };
389+
390+ /* *
391+ * Instantiates a CSGObject using a factory
392+ * @param factory_name the name of the factory
393+ * @param algo_name the name of algorithm passed to factory
394+ * @return the instantiated object using a factory
395+ */
396+ std::shared_ptr<CSGObject> instantiate_model_from_factory (
397+ const std::string& factory_name, const std::string& algo_name)
398+ {
399+ std::shared_ptr<CSGObject> obj;
400+ if (factory_name == " machine" )
401+ obj = std::shared_ptr<CSGObject>(machine (algo_name));
402+ else if (factory_name == " kernel" )
403+ obj = std::shared_ptr<CSGObject>(kernel (algo_name));
404+ else if (factory_name == " distance" )
405+ obj = std::shared_ptr<CSGObject>(distance (algo_name));
406+ else
407+ SG_SERROR (" Unsupported factory \" %s\" .\n " , factory_name.c_str ())
408+
409+ return obj;
410+ }
411+
412+ /* *
413+ * Downcasts a CSGObject and puts it in the map of obj.
414+ * @param obj the main object
415+ * @param nested_obj the object to be casted and put in the obj map.
416+ * @param parameter_name the name of nested_obj
417+ */
418+ void cast_and_put (
419+ const std::shared_ptr<CSGObject>& obj,
420+ const std::shared_ptr<CSGObject>& nested_obj,
421+ const std::string& parameter_name)
422+ {
423+ if (auto casted_obj = std::dynamic_pointer_cast<CMachine>(nested_obj))
424+ {
425+ // TODO: remove clone
426+ // temporary fix until shared_ptr PR merged
427+ auto * tmp_clone = dynamic_cast <CMachine*>(casted_obj->clone ());
428+ obj->put (parameter_name, tmp_clone);
429+ }
430+ else if (auto casted_obj = std::dynamic_pointer_cast<CKernel>(nested_obj))
431+ {
432+ auto * tmp_clone = dynamic_cast <CKernel*>(casted_obj->clone ());
433+ obj->put (parameter_name, tmp_clone);
434+ }
435+ else if (auto casted_obj = std::dynamic_pointer_cast<CDistance>(nested_obj))
436+ {
437+ auto * tmp_clone = dynamic_cast <CDistance*>(casted_obj->clone ());
438+ obj->put (parameter_name, tmp_clone);
439+ }
440+ else
441+ SG_SERROR (" Could not cast SGObject.\n " )
442+ }
443+
237444std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model (
238445 std::shared_ptr<OpenMLFlow> flow, bool initialize_with_defaults)
239446{
240- std::string name;
241- std::string val_as_string;
242- std::shared_ptr<CSGObject> obj;
243447 auto params = flow->get_parameters ();
244448 auto components = flow->get_components ();
245449 auto class_name = get_class_info (flow->get_class_name ());
246450 auto module_name = std::get<0 >(class_name);
247451 auto algo_name = std::get<1 >(class_name);
248- if (module_name == " machine" )
249- obj = std::shared_ptr<CSGObject>(machine (algo_name));
250- else if (module_name == " kernel" )
251- obj = std::shared_ptr<CSGObject>(kernel (algo_name));
252- else if (module_name == " distance" )
253- obj = std::shared_ptr<CSGObject>(distance (algo_name));
254- else
255- SG_SERROR (" Unsupported factory \" %s\"\n " , module_name.c_str ())
452+
453+ auto obj = instantiate_model_from_factory (module_name, algo_name);
256454 auto obj_param = obj->get_params ();
257455
258- auto put_lambda = [&obj, &name, &val_as_string](const auto & val) {
259- // cast value using type from get, i.e. val
260- auto val_ = char_to_scalar<std::remove_reference_t <decltype (val)>>(
261- val_as_string.c_str ());
262- obj->put (name, val_);
263- };
456+ std::unique_ptr<StringToShogun> visitor (new StringToShogun (obj));
264457
265458 if (initialize_with_defaults)
266459 {
267460 for (const auto & param : params)
268461 {
269462 Any any_val = obj_param.at (param.first )->get_value ();
270- name = param.first ;
271- val_as_string = param.second .at (" default_value" );
272- sg_any_dispatch (any_val, sg_all_typemap, put_lambda);
463+ std::string name = param.first ;
464+ std::string val_as_string = param.second .at (" default_value" );
465+ visitor->set_parameter_name (name);
466+ visitor->set_string_value (val_as_string);
467+ any_val.visit (visitor.get ());
273468 }
274469 }
275470
276471 for (const auto & component : components)
277472 {
278- CSGObject* a =
279- flow_to_model (component.second , initialize_with_defaults). get () ;
280- // obj->put( component.first, a );
473+ std::shared_ptr< CSGObject> nested_obj =
474+ flow_to_model (component.second , initialize_with_defaults);
475+ cast_and_put (obj, nested_obj, component.first );
281476 }
282477
478+ SG_SDEBUG (" Final object: %s.\n " , obj->to_string ());
479+
283480 return obj;
284481}
285482
@@ -306,15 +503,15 @@ ShogunOpenML::get_class_info(const std::string& class_name)
306503 if (std::next (it) == class_name.end ())
307504 class_components.emplace_back (std::string (begin, std::next (it)));
308505 }
309- if (class_components.size () != 3 )
310- SG_SERROR (" Invalid class name format %s\n " , class_name.c_str ())
311506 if (class_components[0 ] == " shogun" )
312507 result = std::make_tuple (class_components[1 ], class_components[2 ]);
313508 else
314509 SG_SERROR (
315510 " The provided flow is not meant for shogun deserialisation! The "
316- " required library is \" %s\"\n " ,
511+ " required library is \" %s\" . \n " ,
317512 class_components[0 ].c_str ())
513+ if (class_components.size () != 3 )
514+ SG_SERROR (" Invalid class name format %s.\n " , class_name.c_str ())
318515
319516 return result;
320517}
0 commit comments