Skip to content

Commit df8809d

Browse files
committed
started working on datasets
1 parent 0500504 commit df8809d

File tree

2 files changed

+183
-60
lines changed

2 files changed

+183
-60
lines changed

src/shogun/io/OpenMLFlow.cpp

Lines changed: 129 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ const char* OpenMLReader::list_dataset_qualities = "/data/qualities/{}";
4949
const char* OpenMLReader::list_dataset_filter = "/data/list/{}";
5050
/* FLOW API */
5151
const char* OpenMLReader::flow_file = "/flow/{}";
52+
/* TASK API */
53+
const char* OpenMLReader::task_file = "/task/{}";
5254

5355
const std::unordered_map<std::string, std::string>
5456
OpenMLReader::m_format_options = {{"xml", xml_server},
@@ -102,10 +104,10 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
102104
#endif // HAVE_CURL
103105

104106
/**
105-
* Checks the returned flow in JSON format
106-
* @param doc the parsed flow
107+
* Checks the returned response from OpenML in JSON format
108+
* @param doc the parsed OpenML JSON format response
107109
*/
108-
static void check_flow_response(Document& doc)
110+
static void check_response(const Document& doc, const std::string& type)
109111
{
110112
if (SG_UNLIKELY(doc.HasMember("error")))
111113
{
@@ -115,7 +117,9 @@ static void check_flow_response(Document& doc)
115117
root["message"].GetString())
116118
return;
117119
}
118-
REQUIRE(doc.HasMember("flow"), "Unexpected format of OpenML flow.\n");
120+
REQUIRE(
121+
doc.HasMember(type.c_str()), "Unexpected format of OpenML %s.\n",
122+
type.c_str());
119123
}
120124

121125
/**
@@ -142,8 +146,7 @@ static SG_FORCED_INLINE void emplace_string_to_map(
142146
* @param name the name of the key
143147
*/
144148
static SG_FORCED_INLINE void emplace_string_to_map(
145-
const GenericObject<
146-
true, GenericValue<UTF8<char>>>& v,
149+
const GenericObject<true, GenericValue<UTF8<char>>>& v,
147150
std::unordered_map<std::string, std::string>& param_dict,
148151
const std::string& name)
149152
{
@@ -167,7 +170,7 @@ std::shared_ptr<OpenMLFlow> OpenMLFlow::download_flow(
167170
auto reader = OpenMLReader(api_key);
168171
auto return_string = reader.get("flow_file", "json", flow_id);
169172
document.Parse(return_string.c_str());
170-
check_flow_response(document);
173+
check_response(document, "flow");
171174

172175
// store root for convenience. We know it exists from previous check.
173176
const Value& root = document["flow"];
@@ -248,10 +251,63 @@ std::shared_ptr<OpenMLFlow> OpenMLFlow::from_file()
248251
return std::shared_ptr<OpenMLFlow>();
249252
}
250253

254+
std::shared_ptr<OpenMLTask>
255+
OpenMLTask::get_dataset(const std::string& task_id, const std::string& api_key)
256+
{
257+
Document document;
258+
std::string task_name;
259+
std::string task_type;
260+
std::string task_type_id;
261+
std::pair<std::shared_ptr<OpenMLData>, std::shared_ptr<OpenMLSplit>>
262+
task_descriptor;
263+
264+
auto reader = OpenMLReader(api_key);
265+
auto return_string = reader.get("task_file", "json", task_id);
266+
267+
document.Parse(return_string.c_str());
268+
check_response(document, "task");
269+
270+
const Value& root = document["flow"];
271+
272+
REQUIRE(
273+
task_id == root["task_id"].GetString(),
274+
"Expected downloaded task to have the same id as the requested task "
275+
"id.\n")
276+
277+
task_name = root["task_name"].GetString();
278+
task_type = root["task_type"].GetString();
279+
task_type_id = root["task_type_id"].GetString();
280+
281+
// expect two elements in input array: dataset and split
282+
const Value& json_input = root["input"];
283+
284+
REQUIRE(
285+
json_input.IsArray(), "Currently the dataset reader can only handle "
286+
"inputs with a dataset and split field")
287+
288+
auto input_array = json_input.GetArray();
289+
REQUIRE(
290+
input_array.Size() == 2,
291+
"Currently the dataset reader can only handle inputs with a dataset "
292+
"and split field. Found %d elements.",
293+
input_array.Size())
294+
295+
// handle dataset
296+
auto json_dataset = input_array[0].GetObject();
297+
298+
auto result = std::make_shared<OpenMLTask>(
299+
task_id, task_name, task_type, task_type_id, task_descriptor);
300+
301+
return result;
302+
}
303+
251304
/**
252305
* Class using the Any visitor pattern to convert
253306
* a string to a C++ type that can be used as a parameter
254-
* in a Shogun model.
307+
* in a Shogun model. If the string value is not "null", it is cast to the
308+
* parameter's type and put in the given model under the provided parameter
309+
* name. If the value is "null" nothing happens, i.e. no error is thrown
310+
* and no value is put in the model.
255311
*/
256312
class StringToShogun : public AnyVisitor
257313
{
@@ -266,18 +322,20 @@ class StringToShogun : public AnyVisitor
266322

267323
void on(bool* v) final
268324
{
325+
SG_SDEBUG(
326+
"bool: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
269327
if (!is_null())
270328
{
271-
SG_SDEBUG("bool: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
272329
bool result = strcmp(m_string_val.c_str(), "true") == 0;
273330
m_model->put(m_parameter, result);
274331
}
275332
}
276333
void on(int32_t* v) final
277334
{
335+
SG_SDEBUG(
336+
"int32: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
278337
if (!is_null())
279338
{
280-
SG_SDEBUG("int32: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
281339
try
282340
{
283341
int32_t result = std::stoi(m_string_val);
@@ -299,84 +357,98 @@ class StringToShogun : public AnyVisitor
299357
}
300358
void on(int64_t* v) final
301359
{
360+
SG_SDEBUG(
361+
"int64: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
302362
if (!is_null())
303363
{
304-
SG_SDEBUG("int64: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
364+
305365
int64_t result = std::stol(m_string_val);
306366
m_model->put(m_parameter, result);
307367
}
308368
}
309369
void on(float* v) final
310370
{
371+
SG_SDEBUG(
372+
"float: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
311373
if (!is_null())
312374
{
313-
SG_SDEBUG("float: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
314-
char* end;
315-
float32_t result = std::strtof(m_string_val.c_str(), &end);
375+
float32_t result = std::stof(m_string_val);
316376
m_model->put(m_parameter, result);
317377
}
318378
}
319379
void on(double* v) final
320380
{
381+
SG_SDEBUG(
382+
"double: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
321383
if (!is_null())
322384
{
323-
SG_SDEBUG("double: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
324-
char* end;
325-
float64_t result = std::strtod(m_string_val.c_str(), &end);
385+
float64_t result = std::stod(m_string_val);
326386
m_model->put(m_parameter, result);
327387
}
328388
}
329389
void on(long double* v)
330390
{
391+
SG_SDEBUG(
392+
"long double: %s=%s\n", m_parameter.c_str(),
393+
m_string_val.c_str())
331394
if (!is_null())
332395
{
333-
SG_SDEBUG("long double: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
334-
char* end;
335-
floatmax_t result = std::strtold(m_string_val.c_str(), &end);
396+
floatmax_t result = std::stold(m_string_val);
336397
m_model->put(m_parameter, result);
337398
}
338399
}
339400
void on(CSGObject** v) final
340401
{
341-
SG_SDEBUG("CSGObject: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
402+
SG_SDEBUG(
403+
"CSGObject: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
342404
}
343405
void on(SGVector<int>* v) final
344406
{
345-
SG_SDEBUG("SGVector<int>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
407+
SG_SDEBUG(
408+
"SGVector<int>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
346409
}
347410
void on(SGVector<float>* v) final
348411
{
349-
SG_SDEBUG("SGVector<float>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
412+
SG_SDEBUG(
413+
"SGVector<float>: %s=%s\n", m_parameter.c_str(),
414+
m_string_val.c_str())
350415
}
351416
void on(SGVector<double>* v) final
352417
{
353-
SG_SDEBUG("SGVector<double>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
418+
SG_SDEBUG(
419+
"SGVector<double>: %s=%s\n", m_parameter.c_str(),
420+
m_string_val.c_str())
354421
}
355422
void on(SGMatrix<int>* mat) final
356423
{
357-
SG_SDEBUG("SGMatrix<int>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
424+
SG_SDEBUG(
425+
"SGMatrix<int>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
358426
}
359427
void on(SGMatrix<float>* mat) final
360428
{
361-
SG_SDEBUG("SGMatrix<float>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
362-
}
363-
void on(SGMatrix<double>* mat) final
364-
{
365-
SG_SDEBUG("SGMatrix<double>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
429+
SG_SDEBUG(
430+
"SGMatrix<float>: %s=%s\n", m_parameter.c_str(),
431+
m_string_val.c_str())
366432
}
367-
368-
bool is_null()
433+
void on(SGMatrix<double>* mat) final{SG_SDEBUG(
434+
"SGMatrix<double>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())}
435+
436+
/**
437+
* In OpenML "null" is an empty parameter value field.
438+
* @return whether the field is "null"
439+
*/
440+
SG_FORCED_INLINE bool is_null()
369441
{
370442
bool result = strcmp(m_string_val.c_str(), "null") == 0;
371443
return result;
372444
}
373445

374-
void set_parameter_name(const std::string& name)
446+
SG_FORCED_INLINE void set_parameter_name(const std::string& name)
375447
{
376448
m_parameter = name;
377449
}
378450

379-
void set_string_value(const std::string& value)
451+
SG_FORCED_INLINE void set_string_value(const std::string& value)
380452
{
381453
m_string_val = value;
382454
}
@@ -396,17 +468,16 @@ class StringToShogun : public AnyVisitor
396468
std::shared_ptr<CSGObject> instantiate_model_from_factory(
397469
const std::string& factory_name, const std::string& algo_name)
398470
{
399-
std::shared_ptr<CSGObject> obj;
400471
if (factory_name == "machine")
401-
obj = std::shared_ptr<CSGObject>(machine(algo_name));
402-
else if (factory_name == "kernel")
403-
obj = std::shared_ptr<CSGObject>(kernel(algo_name));
404-
else if (factory_name == "distance")
405-
obj = std::shared_ptr<CSGObject>(distance(algo_name));
406-
else
407-
SG_SERROR("Unsupported factory \"%s\".\n", factory_name.c_str())
472+
return std::shared_ptr<CSGObject>(machine(algo_name));
473+
if (factory_name == "kernel")
474+
return std::shared_ptr<CSGObject>(kernel(algo_name));
475+
if (factory_name == "distance")
476+
return std::shared_ptr<CSGObject>(distance(algo_name));
408477

409-
return obj;
478+
SG_SERROR("Unsupported factory \"%s\".\n", factory_name.c_str())
479+
480+
return nullptr;
410481
}
411482

412483
/**
@@ -426,19 +497,21 @@ void cast_and_put(
426497
// temporary fix until shared_ptr PR merged
427498
auto* tmp_clone = dynamic_cast<CMachine*>(casted_obj->clone());
428499
obj->put(parameter_name, tmp_clone);
500+
return;
429501
}
430-
else if (auto casted_obj = std::dynamic_pointer_cast<CKernel>(nested_obj))
502+
if (auto casted_obj = std::dynamic_pointer_cast<CKernel>(nested_obj))
431503
{
432504
auto* tmp_clone = dynamic_cast<CKernel*>(casted_obj->clone());
433505
obj->put(parameter_name, tmp_clone);
506+
return;
434507
}
435-
else if (auto casted_obj = std::dynamic_pointer_cast<CDistance>(nested_obj))
508+
if (auto casted_obj = std::dynamic_pointer_cast<CDistance>(nested_obj))
436509
{
437510
auto* tmp_clone = dynamic_cast<CDistance*>(casted_obj->clone());
438511
obj->put(parameter_name, tmp_clone);
512+
return;
439513
}
440-
else
441-
SG_SERROR("Could not cast SGObject.\n")
514+
SG_SERROR("Could not cast SGObject.\n")
442515
}
443516

444517
std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model(
@@ -447,8 +520,8 @@ std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model(
447520
auto params = flow->get_parameters();
448521
auto components = flow->get_components();
449522
auto class_name = get_class_info(flow->get_class_name());
450-
auto module_name = std::get<0>(class_name);
451-
auto algo_name = std::get<1>(class_name);
523+
auto module_name = class_name.first;
524+
auto algo_name = class_name.second;
452525

453526
auto obj = instantiate_model_from_factory(module_name, algo_name);
454527
auto obj_param = obj->get_params();
@@ -486,12 +559,12 @@ ShogunOpenML::model_to_flow(const std::shared_ptr<CSGObject>& model)
486559
return std::shared_ptr<OpenMLFlow>();
487560
}
488561

489-
std::tuple<std::string, std::string>
562+
std::pair<std::string, std::string>
490563
ShogunOpenML::get_class_info(const std::string& class_name)
491564
{
492565
std::vector<std::string> class_components;
493566
auto begin = class_name.begin();
494-
std::tuple<std::string, std::string> result;
567+
std::pair<std::string, std::string> result;
495568

496569
for (auto it = class_name.begin(); it != class_name.end(); ++it)
497570
{
@@ -503,15 +576,16 @@ ShogunOpenML::get_class_info(const std::string& class_name)
503576
if (std::next(it) == class_name.end())
504577
class_components.emplace_back(std::string(begin, std::next(it)));
505578
}
506-
if (class_components[0] == "shogun")
507-
result = std::make_tuple(class_components[1], class_components[2]);
579+
580+
if (class_components[0] == "shogun" && class_components.size() == 3)
581+
result = std::make_pair(class_components[1], class_components[2]);
582+
else if (class_components[0] == "shogun" && class_components.size() != 3)
583+
SG_SERROR("Invalid class name format %s.\n", class_name.c_str())
508584
else
509585
SG_SERROR(
510586
"The provided flow is not meant for shogun deserialisation! The "
511587
"required library is \"%s\".\n",
512588
class_components[0].c_str())
513-
if (class_components.size() != 3)
514-
SG_SERROR("Invalid class name format %s.\n", class_name.c_str())
515589

516590
return result;
517591
}

0 commit comments

Comments
 (0)