Skip to content

Commit ff77fd3

Browse files
authored
[Bindings] pass shamrock sys.argv to python (Shamrock-code#1553)
1 parent d0dcf81 commit ff77fd3

5 files changed

Lines changed: 23 additions & 7 deletions

File tree

src/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,13 @@ int main(int argc, char *argv[]) {
190190
"cannot run ipython mode with > 1 processes");
191191
}
192192

193-
shambindings::start_ipython(true);
193+
shambindings::start_ipython(true, argc, argv);
194194

195195
} else if (opts::has_option("--rscript")) {
196196
__shamrock_stack_entry();
197197
std::string fname = std::string(opts::get_option("--rscript"));
198198

199-
shambindings::run_py_file(fname, shamcomm::world_rank() == 0);
199+
shambindings::run_py_file(fname, shamcomm::world_rank() == 0, argc, argv);
200200

201201
} else {
202202
if (shamcomm::world_rank() == 0) {

src/shambindings/include/shambindings/start_python.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ namespace shambindings {
2727
*/
2828
void setpypath(std::string path);
2929

30+
/**
31+
* @brief set the value of sys.argv
32+
*
33+
* This function will throw if bindings were not initialized in embed mode
34+
*/
35+
void set_sys_argv(int argc, char *argv[]);
36+
3037
/**
3138
* @brief set the value of sys.path before init from the supplied binary
3239
*
@@ -42,7 +49,7 @@ namespace shambindings {
4249
* @warning This function shall not be called if more than one processes are running
4350
* @param do_print print log at python startup
4451
*/
45-
void start_ipython(bool do_print);
52+
void start_ipython(bool do_print, int argc, char *argv[]);
4653

4754
/**
4855
* @brief run python runscript
@@ -52,7 +59,7 @@ namespace shambindings {
5259
* @param do_print print log at python startup
5360
* @param file_path path to the runscript
5461
*/
55-
void run_py_file(std::string file_path, bool do_print);
62+
void run_py_file(std::string file_path, bool do_print, int argc, char *argv[]);
5663

5764
/**
5865
* @brief Modify Python sys.path to point to one detected during cmake invocation

src/shambindings/src/run_ipython.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@
2222
###
2323
"""
2424

25-
start_ipython(config=c)
25+
start_ipython(argv=[], config=c)

src/shambindings/src/start_python.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "shambindings/pybindings.hpp"
2121
#include "shambindings/start_python.hpp"
2222
#include <pybind11/embed.h>
23+
#include <pybind11/stl.h>
2324
#include <cstdlib>
2425
#include <optional>
2526
#include <string>
@@ -109,10 +110,16 @@ namespace shambindings {
109110
py::exec(modify_path);
110111
}
111112

112-
void start_ipython(bool do_print) {
113+
void set_sys_argv(int argc, char *argv[]) {
114+
std::vector<std::string> cpp_argv(argv, argv + argc);
115+
py::module_::import("sys").attr("argv") = py::cast(cpp_argv);
116+
}
117+
118+
void start_ipython(bool do_print, int argc, char *argv[]) {
113119

114120
py::scoped_interpreter guard{};
115121
modify_py_sys_path(do_print);
122+
set_sys_argv(argc, argv);
116123

117124
if (do_print) {
118125
shambase::println("--------------------------------------------");
@@ -127,9 +134,10 @@ namespace shambindings {
127134
}
128135
}
129136

130-
void run_py_file(std::string file_path, bool do_print) {
137+
void run_py_file(std::string file_path, bool do_print, int argc, char *argv[]) {
131138
py::scoped_interpreter guard{};
132139
modify_py_sys_path(do_print);
140+
set_sys_argv(argc, argv);
133141

134142
if (do_print) {
135143
shambase::println("-----------------------------------");

src/shamtest/shamtest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ namespace shamtest {
501501

502502
ON_RANK_0(shamcomm::logs::print_faint_row());
503503
shambindings::modify_py_sys_path(shamcomm::world_rank() == 0);
504+
shambindings::set_sys_argv(argc, argv);
504505
ON_RANK_0(shamcomm::logs::print_faint_row());
505506

506507
// import shamrock in pybind

0 commit comments

Comments
 (0)