Skip to content

Commit 145b591

Browse files
committed
moved json dependency to library
1 parent 7cf1d10 commit 145b591

File tree

3 files changed

+225
-62
lines changed

3 files changed

+225
-62
lines changed

src/shogun/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ SHOGUN_DEPENDENCIES(
412412
CONFIG_FLAG HAVE_XML)
413413
# RapidJSON
414414
include(external/RapidJSON)
415-
SHOGUN_INCLUDE_DIRS(SCOPE PUBLIC ${RAPIDJSON_INCLUDE_DIR})
415+
SHOGUN_INCLUDE_DIRS(SCOPE PRIVATE ${RAPIDJSON_INCLUDE_DIR})
416416

417417
if (NOT WIN32)
418418
# FIXME: HDF5 linking on WIN32 is broken.

src/shogun/io/OpenMLFlow.cpp

Lines changed: 163 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
*/
66

77
#include <shogun/io/OpenMLFlow.h>
8+
#include <shogun/lib/type_case.h>
9+
#include <shogun/util/factory.h>
10+
11+
#include <rapidjson/document.h>
812

913
#ifdef HAVE_CURL
1014

@@ -59,7 +63,7 @@ void OpenMLReader::openml_curl_request_helper(const std::string& url)
5963

6064
if (!curl_handle)
6165
{
62-
SG_SERROR("Failed to initialise curl handle.")
66+
SG_SERROR("Failed to initialise curl handle.\n")
6367
return;
6468
}
6569

@@ -82,24 +86,61 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
8286
// TODO: call curl_easy_cleanup(curl_handle) ?
8387
SG_SERROR("Curl error: %s\n", curl_easy_strerror(code))
8488
}
85-
// else
86-
// {
87-
// long response_code;
88-
// curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE, &response_code);
89-
// if (response_code == 200)
90-
// return;
91-
// else
92-
// {
93-
// if (response_code == 181)
94-
// SG_SERROR("Unknown flow. The flow with the given ID was not
95-
//found in the database.") else if (response_code == 180) SG_SERROR("")
96-
// SG_SERROR("Server code: %d\n", response_code)
97-
// }
98-
// }
89+
// else
90+
// {
91+
// long response_code;
92+
// curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE,
93+
//&response_code); if (response_code == 200) return;
94+
// else
95+
// {
96+
// if (response_code == 181)
97+
// SG_SERROR("Unknown flow. The flow with the given ID was not
98+
// found in the database.") else if (response_code == 180)
99+
// SG_SERROR("") SG_SERROR("Server code: %d\n", response_code)
100+
// }
101+
// }
99102
}
100103

101-
std::shared_ptr<OpenMLFlow>
102-
OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key)
104+
#endif // HAVE_CURL
105+
106+
/**
 * Inspects a parsed OpenML server reply.
 *
 * If the document carries an "error" member, its "code" and "message"
 * strings are reported through SG_SERROR. Otherwise the document is
 * required to contain a "flow" member.
 *
 * @param doc parsed JSON reply of the OpenML server
 */
static void check_flow_response(rapidjson::Document& doc)
{
	const bool server_reported_error = doc.HasMember("error");
	if (SG_UNLIKELY(server_reported_error))
	{
		const Value& error_node = doc["error"];
		const char* error_code = error_node["code"].GetString();
		const char* error_message = error_node["message"].GetString();
		SG_SERROR("Server error %s: %s\n", error_code, error_message)
		return;
	}

	REQUIRE(doc.HasMember("flow"), "Unexpected format of OpenML flow.\n");
}
118+
119+
/**
 * Shared implementation of the emplace_string_to_map overloads.
 *
 * Looks up member @p name in the JSON object @p v and stores it in
 * @p param_dict under the same key. Members that are absent or that do not
 * hold a string are stored as the empty string.
 *
 * @param v JSON object (a rapidjson value or object view) to read from
 * @param param_dict map receiving the (name, string value) pair
 * @param name key of the member to copy
 */
template <typename JsonObject>
static SG_FORCED_INLINE void emplace_string_member(
	const JsonObject& v,
	std::unordered_map<std::string, std::string>& param_dict,
	const std::string& name)
{
	// FindMember performs one lookup and signals a missing key through
	// MemberEnd(); operator[] would look the key up twice and asserts
	// when the key does not exist.
	const auto member = v.FindMember(name.c_str());
	if (member != v.MemberEnd() && member->value.IsString())
		param_dict.emplace(name, member->value.GetString());
	else
		param_dict.emplace(name, "");
}

/**
 * Copies the string member @p name of a JSON value into @p param_dict.
 *
 * @param v JSON value holding an object
 * @param param_dict map receiving the (name, string value) pair; an
 * empty string is stored when the member is missing or not a string
 * @param name key of the member to copy
 */
static SG_FORCED_INLINE void emplace_string_to_map(
	const rapidjson::GenericValue<rapidjson::UTF8<char>>& v,
	std::unordered_map<std::string, std::string>& param_dict,
	const std::string& name)
{
	emplace_string_member(v, param_dict, name);
}

/**
 * Copies the string member @p name of a constant JSON object view into
 * @p param_dict.
 *
 * @param v constant JSON object view
 * @param param_dict map receiving the (name, string value) pair; an
 * empty string is stored when the member is missing or not a string
 * @param name key of the member to copy
 */
static SG_FORCED_INLINE void emplace_string_to_map(
	const rapidjson::GenericObject<
		true, rapidjson::GenericValue<rapidjson::UTF8<char>>>& v,
	std::unordered_map<std::string, std::string>& param_dict,
	const std::string& name)
{
	emplace_string_member(v, param_dict, name);
}
141+
142+
std::shared_ptr<OpenMLFlow> OpenMLFlow::download_flow(
143+
const std::string& flow_id, const std::string& api_key)
103144
{
104145
Document document;
105146
parameters_type params;
@@ -124,7 +165,8 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key
124165

125166
if (root["parameter"].IsArray())
126167
{
127-
for (const auto &v : root["parameter"].GetArray()) {
168+
for (const auto& v : root["parameter"].GetArray())
169+
{
128170
emplace_string_to_map(v, param_dict, "data_type");
129171
emplace_string_to_map(v, param_dict, "default_value");
130172
emplace_string_to_map(v, param_dict, "description");
@@ -146,11 +188,22 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key
146188
// handle components, i.e. kernels
147189
if (root.HasMember("component"))
148190
{
149-
for (const auto& v : root["component"].GetArray())
191+
if (root["component"].IsArray())
192+
{
193+
for (const auto& v : root["component"].GetArray())
194+
{
195+
components.emplace(
196+
v["identifier"].GetString(),
197+
OpenMLFlow::download_flow(
198+
v["flow"]["id"].GetString(), api_key));
199+
}
200+
}
201+
else
150202
{
151203
components.emplace(
152-
v["identifier"].GetString(),
153-
OpenMLFlow::download_flow(v["flow"]["id"].GetString(), api_key));
204+
root["component"]["identifier"].GetString(),
205+
OpenMLFlow::download_flow(
206+
root["component"]["flow"]["id"].GetString(), api_key));
154207
}
155208
}
156209

@@ -162,26 +215,104 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key
162215
if (root.HasMember("class_name"))
163216
class_name = root["class_name"].GetString();
164217

165-
auto flow = std::make_shared<OpenMLFlow>(name, description, class_name, components, params);
218+
auto flow = std::make_shared<OpenMLFlow>(
219+
name, description, class_name, components, params);
166220

167221
return flow;
168222
}
169223

170-
void OpenMLFlow::check_flow_response(Document& doc)
224+
/**
 * Placeholder for uploading a flow; the body is intentionally empty.
 *
 * @param flow flow that would be uploaded (currently unused)
 */
void OpenMLFlow::upload_flow(const std::shared_ptr<OpenMLFlow>& flow)
{
	// Nothing is done with the flow yet.
}
227+
228+
/**
 * Placeholder for dumping the flow; the body is intentionally empty.
 */
void OpenMLFlow::dump()
{
	// Nothing is written yet.
}
231+
232+
std::shared_ptr<OpenMLFlow> OpenMLFlow::from_file()
233+
{
234+
return std::shared_ptr<OpenMLFlow>();
235+
}
236+
237+
std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model(
238+
std::shared_ptr<OpenMLFlow> flow, bool initialize_with_defaults)
239+
{
240+
std::string name;
241+
std::string val_as_string;
242+
std::shared_ptr<CSGObject> obj;
243+
auto params = flow->get_parameters();
244+
auto components = flow->get_components();
245+
auto class_name = get_class_info(flow->get_class_name());
246+
auto module_name = std::get<0>(class_name);
247+
auto algo_name = std::get<1>(class_name);
248+
if (module_name == "machine")
249+
obj = std::shared_ptr<CSGObject>(machine(algo_name));
250+
else if (module_name == "kernel")
251+
obj = std::shared_ptr<CSGObject>(kernel(algo_name));
252+
else if (module_name == "distance")
253+
obj = std::shared_ptr<CSGObject>(distance(algo_name));
254+
else
255+
SG_SERROR("Unsupported factory \"%s\"\n", module_name.c_str())
256+
auto obj_param = obj->get_params();
257+
258+
auto put_lambda = [&obj, &name, &val_as_string](const auto& val) {
259+
// cast value using type from get, i.e. val
260+
auto val_ = char_to_scalar<std::remove_reference_t<decltype(val)>>(
261+
val_as_string.c_str());
262+
obj->put(name, val_);
263+
};
264+
265+
if (initialize_with_defaults)
173266
{
174-
const Value& root = doc["error"];
175-
SG_SERROR(
176-
"Server error %s: %s\n", root["code"].GetString(),
177-
root["message"].GetString())
178-
return;
267+
for (const auto& param : params)
268+
{
269+
Any any_val = obj_param.at(param.first)->get_value();
270+
name = param.first;
271+
val_as_string = param.second.at("default_value");
272+
sg_any_dispatch(any_val, sg_all_typemap, put_lambda);
273+
}
179274
}
180-
REQUIRE(doc.HasMember("flow"), "Unexpected format of OpenML flow.\n");
275+
276+
for (const auto& component : components)
277+
{
278+
CSGObject* a =
279+
flow_to_model(component.second, initialize_with_defaults).get();
280+
// obj->put(component.first, a);
281+
}
282+
283+
return obj;
181284
}
182285

183-
void OpenMLFlow::upload_flow(const std::shared_ptr<OpenMLFlow>& flow)
286+
std::shared_ptr<OpenMLFlow>
287+
ShogunOpenML::model_to_flow(const std::shared_ptr<CSGObject>& model)
184288
{
289+
return std::shared_ptr<OpenMLFlow>();
185290
}
186291

187-
#endif // HAVE_CURL
292+
std::tuple<std::string, std::string>
293+
ShogunOpenML::get_class_info(const std::string& class_name)
294+
{
295+
std::vector<std::string> class_components;
296+
auto begin = class_name.begin();
297+
std::tuple<std::string, std::string> result;
298+
299+
for (auto it = class_name.begin(); it != class_name.end(); ++it)
300+
{
301+
if (*it == '.')
302+
{
303+
class_components.emplace_back(std::string(begin, it));
304+
begin = it;
305+
}
306+
}
307+
if (class_components.size() != 3)
308+
SG_SERROR("Invalid class name format %s\n", class_name.c_str())
309+
if (class_components[0] == "shogun")
310+
result = std::make_tuple(class_components[1], class_components[2]);
311+
else
312+
SG_SERROR(
313+
"The provided flow is not meant for shogun deserialisation! The "
314+
"required library is \"%s\"\n",
315+
class_components[0].c_str())
316+
317+
return result;
318+
}

src/shogun/io/OpenMLFlow.h

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include <shogun/io/SGIO.h>
1616

1717
#include <curl/curl.h>
18-
#include <rapidjson/document.h>
1918

2019
#include <iostream>
2120
#include <memory>
@@ -150,9 +149,9 @@ namespace shogun
150149

151150
public:
152151
using components_type =
153-
std::unordered_map<std::string, std::shared_ptr<OpenMLFlow>>;
152+
std::unordered_map<std::string, std::shared_ptr<OpenMLFlow>>;
154153
using parameters_type = std::unordered_map<
155-
std::string, std::unordered_map<std::string, std::string>>;
154+
std::string, std::unordered_map<std::string, std::string>>;
156155

157156
OpenMLFlow(
158157
const std::string& name, const std::string& description,
@@ -163,13 +162,16 @@ namespace shogun
163162
{
164163
}
165164

166-
~OpenMLFlow()= default;
167-
168165
static std::shared_ptr<OpenMLFlow>
169166
download_flow(const std::string& flow_id, const std::string& api_key);
170167

168+
static std::shared_ptr<OpenMLFlow>
169+
from_file();
170+
171171
static void upload_flow(const std::shared_ptr<OpenMLFlow>& flow);
172172

173+
void dump();
174+
173175
std::shared_ptr<OpenMLFlow> get_subflow(const std::string& name)
174176
{
175177
auto find_flow = m_components.find(name);
@@ -181,40 +183,70 @@ namespace shogun
181183
return nullptr;
182184
}
183185

186+
#ifndef SWIG
187+
/**
 * Returns a copy of the flow's parameter dictionary.
 *
 * Const-qualified so it can also be called on a const flow.
 *
 * @return copy of m_parameters
 */
SG_FORCED_INLINE parameters_type get_parameters() const
{
	return m_parameters;
}
191+
192+
/**
 * Returns a copy of the flow's component map.
 *
 * Const-qualified so it can also be called on a const flow.
 *
 * @return copy of m_components
 */
SG_FORCED_INLINE components_type get_components() const
{
	return m_components;
}
196+
197+
/**
 * Returns the class name stored in the flow.
 *
 * Const-qualified so it can also be called on a const flow.
 *
 * @return copy of m_class_name
 */
SG_FORCED_INLINE std::string get_class_name() const
{
	return m_class_name;
}
201+
#endif // SWIG
202+
184203
private:
185204
std::string m_name;
186205
std::string m_description;
187206
std::string m_class_name;
188207
parameters_type m_parameters;
189208
components_type m_components;
209+
};
190210

191-
#ifndef SWIG
192-
static void check_flow_response(rapidjson::Document& doc);
211+
template <typename T>
212+
T char_to_scalar(const char* string_val)
213+
{
214+
SG_SERROR("No registered conversion for type %s\n", demangled_type<T>().c_str())
215+
return 0;
216+
}
193217

194-
static SG_FORCED_INLINE void emplace_string_to_map(
195-
const rapidjson::GenericValue<rapidjson::UTF8<char>>& v,
196-
std::unordered_map<std::string, std::string>& param_dict,
197-
const std::string& name)
198-
{
199-
if (v[name.c_str()].GetType() == rapidjson::Type::kStringType)
200-
param_dict.emplace(name, v[name.c_str()].GetString());
201-
else
202-
param_dict.emplace(name, "");
203-
}
218+
template <>
219+
float32_t char_to_scalar(const char* string_val)
220+
{
221+
char* end;
222+
return std::strtof(string_val, &end);
223+
}
204224

205-
static SG_FORCED_INLINE void emplace_string_to_map(
206-
const rapidjson::GenericObject<
207-
true, rapidjson::GenericValue<rapidjson::UTF8<char>>>& v,
208-
std::unordered_map<std::string, std::string>& param_dict,
209-
const std::string& name)
210-
{
211-
if (v[name.c_str()].GetType() == rapidjson::Type::kStringType)
212-
param_dict.emplace(name, v[name.c_str()].GetString());
213-
else
214-
param_dict.emplace(name, "");
215-
}
225+
template <>
226+
float64_t char_to_scalar(const char* string_val)
227+
{
228+
char* end;
229+
return std::strtod(string_val, &end);
230+
}
216231

217-
#endif // SWIG
232+
template <>
233+
floatmax_t char_to_scalar(const char* string_val)
234+
{
235+
char* end;
236+
return std::strtold(string_val, &end);
237+
}
238+
239+
class ShogunOpenML
240+
{
241+
public:
242+
static std::shared_ptr<CSGObject> flow_to_model(
243+
std::shared_ptr<OpenMLFlow> flow, bool initialize_with_defaults);
244+
245+
static std::shared_ptr<OpenMLFlow>
246+
model_to_flow(const std::shared_ptr<CSGObject>& model);
247+
248+
private:
249+
static std::tuple<std::string, std::string> get_class_info(const std::string& class_name);
218250
};
219251
} // namespace shogun
220252
#endif // HAVE_CURL

0 commit comments

Comments
 (0)