Skip to content

Commit 28d32a0

Browse files
committed
Fix build issues, update C extension, and enhance tests
1 parent de74d3c commit 28d32a0

File tree

4 files changed

+25
-5
lines changed

4 files changed

+25
-5
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ OptimRL is a **high-performance reinforcement learning library** that introduces
66
## 🏅 Badges
77

88
![PyPI Version](https://img.shields.io/pypi/v/optimrl)
9-
<!-- ![PyPI Downloads](https://img.shields.io/pypi/dm/optimrl) -->
10-
<!-- ![Python Version](https://img.shields.io/pypi/pyversions/optimrl) -->
119
![Python](https://img.shields.io/badge/Python-3.8%2B-blue?logo=python&logoColor=white)
1210
![C](https://img.shields.io/badge/C-99-00599C?logo=c&logoColor=white)
1311
![NumPy](https://img.shields.io/badge/Library-NumPy-013243?logo=numpy&logoColor=white)

optimrl/c_src/grpo.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,21 @@
22
#include <stdio.h>
33
#include <stdlib.h>
44
#include <math.h>
5+
#define PY_SSIZE_T_CLEAN
6+
#include <Python.h>
7+
8+
// Example implementation of the PyInit_libgrpo function
9+
PyMODINIT_FUNC PyInit_libgrpo(void) {
10+
static struct PyModuleDef moduledef = {
11+
PyModuleDef_HEAD_INIT,
12+
"libgrpo", // Module name
13+
"GRPO C extension", // Module docstring
14+
-1, // Size of per-interpreter state of the module
15+
NULL, // Module methods
16+
};
17+
18+
return PyModule_Create(&moduledef);
19+
}
520

621
// Helper function to compute robust statistics
722
void compute_reward_stats(double* rewards, int group_size, double* out_mean, double* out_std) {

optimrl/c_src/libgrpo.dylib

160 Bytes
Binary file not shown.

setup.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
import platform
66
import shutil
77
import versioneer
8+
import sysconfig
9+
10+
# Get the Python include path dynamically
11+
python_include_path = sysconfig.get_path('include')
12+
python_lib_path = sysconfig.get_config_var('LIBDIR')
813

914

1015
class CustomBuildExt(build_ext):
@@ -77,12 +82,14 @@ def finalize_options(self):
7782

7883

7984
# Define the extension module
85+
8086
grpo_module = Extension(
8187
'optimrl.c_src.libgrpo',
8288
sources=['optimrl/c_src/grpo.c'],
83-
include_dirs=['optimrl/c_src'],
89+
include_dirs=['optimrl/c_src',python_include_path],
8490
libraries=['m'] if platform.system() != 'Windows' else [],
85-
extra_compile_args=['-O3', '-fPIC'] if platform.system() != 'Windows' else ['/O2']
91+
extra_compile_args=['-O3', '-fPIC'] if platform.system() != 'Windows' else ['/O2'],
92+
extra_link_args=[] if platform.system() != 'Windows' else ['/EXPORT:PyInit_libgrpo']
8693
)
8794

8895
# Read the README file
@@ -110,7 +117,7 @@ def finalize_options(self):
110117
"torch>=1.8.0"
111118
],
112119
extras_require={
113-
'test': ['pytest>=6.0'],
120+
'test': ['pytest>=6.0', 'flake8', 'isort', 'black'],
114121
'dev': ['pytest>=6.0', 'black', 'isort', 'flake8']
115122
},
116123
python_requires=">=3.8",

0 commit comments

Comments
 (0)