55 */
66
77#include < shogun/io/OpenMLFlow.h>
8+ #include < shogun/lib/type_case.h>
9+ #include < shogun/util/factory.h>
10+
11+ #include < rapidjson/document.h>
812
913#ifdef HAVE_CURL
1014
@@ -59,7 +63,7 @@ void OpenMLReader::openml_curl_request_helper(const std::string& url)
5963
6064 if (!curl_handle)
6165 {
62- SG_SERROR (" Failed to initialise curl handle." )
66+ SG_SERROR (" Failed to initialise curl handle.\n " )
6367 return ;
6468 }
6569
@@ -82,24 +86,61 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
8286 // TODO: call curl_easy_cleanup(curl_handle) ?
8387 SG_SERROR (" Curl error: %s\n " , curl_easy_strerror (code))
8488 }
85- // else
86- // {
87- // long response_code;
88- // curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE, &response_code);
89- // if (response_code == 200)
90- // return;
91- // else
92- // {
93- // if (response_code == 181)
94- // SG_SERROR("Unknown flow. The flow with the given ID was not
95- // found in the database.") else if (response_code == 180) SG_SERROR("")
96- // SG_SERROR("Server code: %d\n", response_code)
97- // }
98- // }
89+ // else
90+ // {
91+ // long response_code;
92+ // curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE,
93+ // &response_code); if (response_code == 200) return;
94+ // else
95+ // {
96+ // if (response_code == 181)
97+ // SG_SERROR("Unknown flow. The flow with the given ID was not
98+ // found in the database.") else if (response_code == 180)
99+ // SG_SERROR("") SG_SERROR("Server code: %d\n", response_code)
100+ // }
101+ // }
99102}
100103
101- std::shared_ptr<OpenMLFlow>
102- OpenMLFlow::download_flow (const std::string& flow_id, const std::string& api_key)
104+ #endif // HAVE_CURL
105+
106+ static void check_flow_response (rapidjson::Document& doc)
107+ {
108+ if (SG_UNLIKELY (doc.HasMember (" error" )))
109+ {
110+ const Value& root = doc[" error" ];
111+ SG_SERROR (
112+ " Server error %s: %s\n " , root[" code" ].GetString (),
113+ root[" message" ].GetString ())
114+ return ;
115+ }
116+ REQUIRE (doc.HasMember (" flow" ), " Unexpected format of OpenML flow.\n " );
117+ }
118+
119+ static SG_FORCED_INLINE void emplace_string_to_map (
120+ const rapidjson::GenericValue<rapidjson::UTF8<char >>& v,
121+ std::unordered_map<std::string, std::string>& param_dict,
122+ const std::string& name)
123+ {
124+ if (v[name.c_str ()].GetType () == rapidjson::Type::kStringType )
125+ param_dict.emplace (name, v[name.c_str ()].GetString ());
126+ else
127+ param_dict.emplace (name, " " );
128+ }
129+
130+ static SG_FORCED_INLINE void emplace_string_to_map (
131+ const rapidjson::GenericObject<
132+ true , rapidjson::GenericValue<rapidjson::UTF8<char >>>& v,
133+ std::unordered_map<std::string, std::string>& param_dict,
134+ const std::string& name)
135+ {
136+ if (v[name.c_str ()].GetType () == rapidjson::Type::kStringType )
137+ param_dict.emplace (name, v[name.c_str ()].GetString ());
138+ else
139+ param_dict.emplace (name, " " );
140+ }
141+
142+ std::shared_ptr<OpenMLFlow> OpenMLFlow::download_flow (
143+ const std::string& flow_id, const std::string& api_key)
103144{
104145 Document document;
105146 parameters_type params;
@@ -124,7 +165,8 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key
124165
125166 if (root[" parameter" ].IsArray ())
126167 {
127- for (const auto &v : root[" parameter" ].GetArray ()) {
168+ for (const auto & v : root[" parameter" ].GetArray ())
169+ {
128170 emplace_string_to_map (v, param_dict, " data_type" );
129171 emplace_string_to_map (v, param_dict, " default_value" );
130172 emplace_string_to_map (v, param_dict, " description" );
@@ -146,11 +188,22 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key
146188 // handle components, i.e. kernels
147189 if (root.HasMember (" component" ))
148190 {
149- for (const auto & v : root[" component" ].GetArray ())
191+ if (root[" component" ].IsArray ())
192+ {
193+ for (const auto & v : root[" component" ].GetArray ())
194+ {
195+ components.emplace (
196+ v[" identifier" ].GetString (),
197+ OpenMLFlow::download_flow (
198+ v[" flow" ][" id" ].GetString (), api_key));
199+ }
200+ }
201+ else
150202 {
151203 components.emplace (
152- v[" identifier" ].GetString (),
153- OpenMLFlow::download_flow (v[" flow" ][" id" ].GetString (), api_key));
204+ root[" component" ][" identifier" ].GetString (),
205+ OpenMLFlow::download_flow (
206+ root[" component" ][" flow" ][" id" ].GetString (), api_key));
154207 }
155208 }
156209
@@ -162,26 +215,104 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key
162215 if (root.HasMember (" class_name" ))
163216 class_name = root[" class_name" ].GetString ();
164217
165- auto flow = std::make_shared<OpenMLFlow>(name, description, class_name, components, params);
218+ auto flow = std::make_shared<OpenMLFlow>(
219+ name, description, class_name, components, params);
166220
167221 return flow;
168222}
169223
170- void OpenMLFlow::check_flow_response (Document& doc )
224+ void OpenMLFlow::upload_flow ( const std::shared_ptr<OpenMLFlow>& flow )
171225{
172- if (SG_UNLIKELY (doc.HasMember (" error" )))
226+ }
227+
228+ void OpenMLFlow::dump ()
229+ {
230+ }
231+
232+ std::shared_ptr<OpenMLFlow> OpenMLFlow::from_file ()
233+ {
234+ return std::shared_ptr<OpenMLFlow>();
235+ }
236+
237+ std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model (
238+ std::shared_ptr<OpenMLFlow> flow, bool initialize_with_defaults)
239+ {
240+ std::string name;
241+ std::string val_as_string;
242+ std::shared_ptr<CSGObject> obj;
243+ auto params = flow->get_parameters ();
244+ auto components = flow->get_components ();
245+ auto class_name = get_class_info (flow->get_class_name ());
246+ auto module_name = std::get<0 >(class_name);
247+ auto algo_name = std::get<1 >(class_name);
248+ if (module_name == " machine" )
249+ obj = std::shared_ptr<CSGObject>(machine (algo_name));
250+ else if (module_name == " kernel" )
251+ obj = std::shared_ptr<CSGObject>(kernel (algo_name));
252+ else if (module_name == " distance" )
253+ obj = std::shared_ptr<CSGObject>(distance (algo_name));
254+ else
255+ SG_SERROR (" Unsupported factory \" %s\"\n " , module_name.c_str ())
256+ auto obj_param = obj->get_params ();
257+
258+ auto put_lambda = [&obj, &name, &val_as_string](const auto & val) {
259+ // cast value using type from get, i.e. val
260+ auto val_ = char_to_scalar<std::remove_reference_t <decltype (val)>>(
261+ val_as_string.c_str ());
262+ obj->put (name, val_);
263+ };
264+
265+ if (initialize_with_defaults)
173266 {
174- const Value& root = doc[" error" ];
175- SG_SERROR (
176- " Server error %s: %s\n " , root[" code" ].GetString (),
177- root[" message" ].GetString ())
178- return ;
267+ for (const auto & param : params)
268+ {
269+ Any any_val = obj_param.at (param.first )->get_value ();
270+ name = param.first ;
271+ val_as_string = param.second .at (" default_value" );
272+ sg_any_dispatch (any_val, sg_all_typemap, put_lambda);
273+ }
179274 }
180- REQUIRE (doc.HasMember (" flow" ), " Unexpected format of OpenML flow.\n " );
275+
276+ for (const auto & component : components)
277+ {
278+ CSGObject* a =
279+ flow_to_model (component.second , initialize_with_defaults).get ();
280+ // obj->put(component.first, a);
281+ }
282+
283+ return obj;
181284}
182285
183- void OpenMLFlow::upload_flow (const std::shared_ptr<OpenMLFlow>& flow)
286+ std::shared_ptr<OpenMLFlow>
287+ ShogunOpenML::model_to_flow (const std::shared_ptr<CSGObject>& model)
184288{
289+ return std::shared_ptr<OpenMLFlow>();
185290}
186291
187- #endif // HAVE_CURL
292+ std::tuple<std::string, std::string>
293+ ShogunOpenML::get_class_info (const std::string& class_name)
294+ {
295+ std::vector<std::string> class_components;
296+ auto begin = class_name.begin ();
297+ std::tuple<std::string, std::string> result;
298+
299+ for (auto it = class_name.begin (); it != class_name.end (); ++it)
300+ {
301+ if (*it == ' .' )
302+ {
303+ class_components.emplace_back (std::string (begin, it));
304+ begin = it;
305+ }
306+ }
307+ if (class_components.size () != 3 )
308+ SG_SERROR (" Invalid class name format %s\n " , class_name.c_str ())
309+ if (class_components[0 ] == " shogun" )
310+ result = std::make_tuple (class_components[1 ], class_components[2 ]);
311+ else
312+ SG_SERROR (
313+ " The provided flow is not meant for shogun deserialisation! The "
314+ " required library is \" %s\"\n " ,
315+ class_components[0 ].c_str ())
316+
317+ return result;
318+ }
0 commit comments