Skip to content

Commit b985c61

Browse files
committed
Move proxstorm from TypeP to SOL and change constructor to match other algorithms in SOL. Moved dualsvm and logistic examples from example/proxstorm to example/ and use the moved proxstorm.
Added test/sol/test_17.cpp, based on test/algorithm/TypeP/test_08.cpp and adjusted to use proxstorm (with no stochasticity) Todo: Add stochasticity.
1 parent e8a93d1 commit b985c61

File tree

17 files changed

+720
-21
lines changed

17 files changed

+720
-21
lines changed

example/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ ADD_SUBDIRECTORY(dense-hessian)
2222
ADD_SUBDIRECTORY(stream-buffer)
2323
ADD_SUBDIRECTORY(oed)
2424
ADD_SUBDIRECTORY(lincon-test)
25-
ADD_SUBDIRECTORY(proxstorm)
25+
ADD_SUBDIRECTORY(dualsvm)
26+
ADD_SUBDIRECTORY(logistic)
2627

2728
IF(NOT STANDALONE_ROL)
2829

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
#include "ROL_GlobalMPISession.hpp"
1414

1515
#include "ROL_TypeB_Algorithm.hpp"
16-
#include "ROL_TypeP_STORMAlgorithm.hpp"
17-
#include "ROL_TypeP_STORMAlgorithm_Def.hpp"
1816
#include "ROL_TypeP_TrustRegionAlgorithm.hpp"
1917
#include "ROL_TypeP_TrustRegionAlgorithm_Def.hpp"
2018
#include "ROL_ScalarLinearConstraint.hpp"
19+
#include "ROL_STORMAlgorithm.hpp"
20+
#include "ROL_STORMAlgorithm_Def.hpp"
2121

2222
#include <iostream>
2323
#include <random>
@@ -193,8 +193,7 @@ int main(int argc, char *argv[]) {
193193
ROL::Ptr<ROL::PolyhedralProjection<RealT>> proj = ROL::PolyhedralProjectionFactory<RealT>(*x, x->dual(), bnd, con, *lam, *res, *parlist);
194194
ROL::Ptr<DualSVMConstraint<RealT>> nobj = ROL::makePtr<DualSVMConstraint<RealT>>(y, proj);
195195

196-
ROL::Ptr<ROL::TypeP::STORMAlgorithm<RealT>> algo = ROL::makePtr<ROL::TypeP::STORMAlgorithm<RealT>>(*parlist);
197-
ROL::Ptr<ROL::TypeP::TrustRegionAlgorithm<RealT>> algo2 = ROL::makePtr<ROL::TypeP::TrustRegionAlgorithm<RealT>>(*parlist);
196+
ROL::Ptr<ROL::TypeP::TrustRegionAlgorithm<RealT>> algo_nonsmooth_tr = ROL::makePtr<ROL::TypeP::TrustRegionAlgorithm<RealT>>(*parlist);
198197

199198
bool checkDeriv = parlist->sublist("Problem").get("Check Derivatives",false);
200199
if ( checkDeriv ) {
@@ -208,10 +207,22 @@ int main(int argc, char *argv[]) {
208207
obj->checkHessSym(*x,*dx,*hx,true,*outStream);
209208
}
210209

211-
212210
//std::clock_t timer = std::clock();
213-
algo2->run(*xy, *obj, *nobj, *outStream); // nonsmooth tr
214-
algo->run(*x, *obj, *nobj, *outStream); // storm
211+
algo_nonsmooth_tr->run(*xy, *obj, *nobj, *outStream); // nonsmooth tr
212+
213+
// Set up and run proxstorm
214+
ROL::Ptr<ROL::Problem<RealT>> problem = ROL::makePtr<ROL::Problem<RealT>>(
215+
obj, x
216+
);
217+
problem->addProximableObjective(nobj);
218+
problem->finalize(false, true, *outStream);
219+
220+
ROL::Ptr<ROL::STORMAlgorithm<RealT>> algo_storm = ROL::makePtr<ROL::STORMAlgorithm<RealT>>(
221+
problem,
222+
nullptr,
223+
*parlist
224+
);
225+
algo_storm->run(*outStream); // storm
215226

216227
}
217228
catch (std::logic_error& err) {
File renamed without changes.
Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
#include "ROL_GlobalMPISession.hpp"
1414

1515
#include "ROL_TypeB_Algorithm.hpp"
16-
#include "ROL_TypeP_STORMAlgorithm.hpp"
17-
#include "ROL_TypeP_STORMAlgorithm_Def.hpp"
16+
#include "ROL_STORMAlgorithm.hpp"
17+
#include "ROL_STORMAlgorithm_Def.hpp"
1818
#include "ROL_TypeP_TrustRegionAlgorithm.hpp"
1919
#include "ROL_TypeP_TrustRegionAlgorithm_Def.hpp"
2020
#include "ROL_ScalarLinearConstraint.hpp"
@@ -189,8 +189,7 @@ int main(int argc, char *argv[]) {
189189
xy->randomize();
190190
nobj = ROL::makePtr<ROL::l1Objective<RealT>>(wts);
191191

192-
ROL::Ptr<ROL::TypeP::STORMAlgorithm<RealT>> algo = ROL::makePtr<ROL::TypeP::STORMAlgorithm<RealT>>(*parlist);
193-
ROL::Ptr<ROL::TypeP::TrustRegionAlgorithm<RealT>> algo2 = ROL::makePtr<ROL::TypeP::TrustRegionAlgorithm<RealT>>(*parlist);
192+
ROL::Ptr<ROL::TypeP::TrustRegionAlgorithm<RealT>> algo_nonsmooth_tr = ROL::makePtr<ROL::TypeP::TrustRegionAlgorithm<RealT>>(*parlist);
194193
bool checkDeriv = parlist->sublist("Problem").get("Check Derivatives",true);
195194
if ( checkDeriv ) {
196195
ROL::Ptr<ROL::StdVector<RealT>> dx = ROL::makePtr<ROL::StdVector<RealT>>(dim,0);
@@ -204,9 +203,21 @@ int main(int argc, char *argv[]) {
204203
}
205204

206205
//std::clock_t timer = std::clock();
207-
algo2->run(*xy, *obj, *nobj, *outStream); // nonsmooth tr
208-
algo->run(*x, *obj, *nobj, *outStream); // storm
209-
206+
algo_nonsmooth_tr->run(*xy, *obj, *nobj, *outStream); // nonsmooth tr
207+
208+
// Set up and run proxstorm
209+
ROL::Ptr<ROL::Problem<RealT>> problem = ROL::makePtr<ROL::Problem<RealT>>(
210+
obj, x
211+
);
212+
problem->addProximableObjective(nobj);
213+
problem->finalize(false, true, *outStream);
214+
215+
ROL::Ptr<ROL::STORMAlgorithm<RealT>> algo_storm = ROL::makePtr<ROL::STORMAlgorithm<RealT>>(
216+
problem,
217+
nullptr,
218+
*parlist
219+
);
220+
algo_storm->run(*outStream); // storm
210221
}
211222
catch (std::logic_error& err) {
212223
*outStream << err.what() << "\n";
File renamed without changes.

0 commit comments

Comments
 (0)