Skip to content

Commit d00c782

Browse files
committed
feat: handle starting jvm + better config read/write logic
1 parent 4b7b710 commit d00c782

1 file changed

Lines changed: 57 additions & 35 deletions

File tree

pcldapy/pclda.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66
import warnings
77

88

9+
def ensure_java(classpath, java_home="/usr/lib/jvm/java-11-openjdk", max_heap="150g"):
    """
    Start the JVM through JPype unless one is already running.

    Parameters
    ----------
    classpath : str
        Path to the jar file (relative to the cwd or absolute) that is put on
        the JVM classpath. It is converted to an absolute path before use.
    java_home : str
        Currently unused; the JVM library is located with
        ``jpype.getDefaultJVMPath()``. Kept so existing callers keep working.
    max_heap : str
        Value appended to the ``-Xmx`` JVM option (maximum heap size),
        e.g. ``"150g"`` or ``"4g"``.

    Returns
    -------
    None
    """
    import logging

    log = logging.getLogger(__name__)

    if jpype.isJVMStarted():
        # Nothing is started or modified when a JVM already exists, so the
        # requested classpath is not applied in that case.
        log.debug("JVM already running; classpath %s was not applied", classpath)
        return

    jar_path = os.path.abspath(classpath)
    if not os.path.exists(jar_path):
        # Warn instead of raising: the JVM start below is still attempted,
        # exactly as before, but the likely cause of later class-loading
        # failures is now visible to the caller.
        warnings.warn(f"Classpath entry does not exist: {jar_path}")

    # TODO java_home var does nothing
    jvm_path = jpype.getDefaultJVMPath()
    jpype.startJVM(
        jvm_path,
        "-Dfile.encoding=UTF-8",
        f"-Xmx{max_heap}",
        classpath=[jar_path],
    )
    log.debug("Started JVM from %s (running: %s)", jvm_path, jpype.isJVMStarted())
17+
918

1019
def write_config(slc, cfg_fn, jar_dict=None):
1120
"""
@@ -23,11 +32,19 @@ def _custom_serializer(obj):
2332
try:
2433
return json.JSONEncoder().default(obj)
2534
except TypeError:
35+
2636
try:
37+
# Handle Java strings from JPype
38+
if hasattr(obj, 'getClass') and obj.getClass().getName() == 'java.lang.String':
39+
return str(obj) # or obj.toString()
2740
return vars(obj)['_jstr']
28-
except Exception as e:
41+
except Exception:
42+
return f"Non-serializable: {type(obj).__name__}"
43+
#try:
44+
# return vars(obj)['_jstr']
45+
#except Exception as e:
2946
# print(f"Non-serializable: {obj}, {vars(obj)} {vars(obj.__class__)} :: {e}\n\n")
30-
return f"Non-serializable: {type(obj).__name__}"
47+
# return f"Non-serializable: {type(obj).__name__}"
3148

3249
defaults = {
3350
"VariableSelectionPrior": 0.5,
@@ -101,10 +118,10 @@ def _custom_serializer(obj):
101118
# [print(e) for e in E]
102119

103120
jars_out = {}
104-
with open(cfg_fn, 'r') as oldconfig:
105-
oslc = json.load(oldconfig)
106-
if "jars" in oslc:
107-
jars_out = oslc["jars"]
121+
#with open(cfg_fn, 'r') as oldconfig:
122+
# oslc = json.load(oldconfig)
123+
# if "jars" in oslc:
124+
# jars_out = oslc["jars"]
108125

109126
for k, v in jar_dict.items():
110127
if k is not None:
@@ -116,19 +133,19 @@ def _custom_serializer(obj):
116133

117134

118135
def new_simple_lda_config(
119-
dataset="dataset.txt",
120-
nr_topics=20,
121-
alpha=None,
122-
beta=None,
123-
iterations=2000,
124-
rareword_threshold=10,
125-
optim_interval=-1,
126-
stoplist_fn="stoplist.txt",
127-
topic_interval=10,
128-
tmpdir="/tmp",
129-
topic_priors="priors.txt",
136+
dataset = "dataset.txt",
137+
nr_topics = 20,
138+
alpha = None,
139+
beta = None,
140+
iterations = 2000,
141+
rareword_threshold = 10,
142+
optim_interval = -1,
143+
stoplist_fn = "stoplist.txt",
144+
topic_interval = 10,
145+
tmpdir = "/tmp",
146+
topic_priors = "priors.txt",
147+
jar_key = "default",
130148
jar_dict = {},
131-
cfg_fn = None
132149
):
133150
"""
134151
Create a new LDA config file with default values, unless otherwise specified.
@@ -154,14 +171,33 @@ def new_simple_lda_config(
154171
config object
155172
"""
156173

174+
print(len(jar_dict))
175+
if len(jar_dict) == 0:
176+
warnings.warn("You need at least one PCLDA jarfile to work with this library, but you haven't provided one.")
177+
inp = input("Do you want to provide a default PCLDA jar file now? Enter the path from the cwd (or q to exit): ")
178+
if inp == 'q':
179+
print("Ok, exiting")
180+
sys.exit()
181+
else:
182+
jar_dict = {"default": os.path.abspath(inp)}
183+
157184
if alpha is None:
158185
alpha = 50 / nr_topics
159186

160187
if beta is None:
161188
beta = nr_topics / 5
162-
189+
190+
# make sure java is running
191+
ensure_java(jar_dict[jar_key])
192+
163193
# Initialize LoggingUtils
164-
lu = jpype.JClass("cc.mallet.util.LoggingUtils")()
194+
try:
195+
lu = jpype.JClass("cc.mallet.util.LoggingUtils")()
196+
print(lu)
197+
except Exception as e:
198+
print(e)
199+
exit()
200+
165201
lu.checkAndCreateCurrentLogDir(tmpdir)
166202

167203
# Initialize SimpleLDAConfiguration
@@ -184,20 +220,6 @@ def new_simple_lda_config(
184220
slc.setDatasetFilename(dataset)
185221
slc.setHyperparamOptimInterval(jpype.JInt(topic_interval))
186222
slc.setNoPreprocess(True)
187-
188-
print(len(jar_dict))
189-
if len(jar_dict) == 0:
190-
warnings.warn("You need at least one PCLDA jarfile to work with this library, but you haven't provided one.")
191-
inp = input("Do you want to provide a default PCLDA jar file now? Enter the path from the cwd (or q to exit): ")
192-
if inp == 'q':
193-
print("Ok, exiting")
194-
sys.exit()
195-
else:
196-
jar_dict = {"default": os.path.abspath(inp)}
197-
198-
if cfg_fn is not None:
199-
write_config(slc, cfg_fn, jar_dict=jar_dict)
200-
201223
return slc
202224

203225

@@ -230,7 +252,7 @@ def load_lda_config(cfg_fn):
230252
j = json.load(inf)
231253

232254
if "Beta" not in j or j["Beta"] is None:
233-
j["Beta"] = j["NoTopics"] / 50
255+
j["pclda_config"]["Beta"] = j["pclda_config"]["NoTopics"] / 50
234256

235257
# Initialize SimpleLDAConfiguration
236258
slc = new_simple_lda_config(jar_dict=j["jars"])

0 commit comments

Comments
 (0)