Skip to content

Commit 2b7a5d5

Browse files
authored
Merge branch 'main' into dev
2 parents cbcf3ba + 5664260 commit 2b7a5d5

1 file changed

Lines changed: 75 additions & 18 deletions

File tree

pcldapy/pclda.py

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import jpype
33
import json
44
import os
5+
import pickle
56
import sys
67
import warnings
78

@@ -16,6 +17,24 @@ def ensure_java(classpath, java_home="/usr/lib/jvm/java-11-openjdk"):
1617
print(jpype.isJVMStarted())
1718

1819

20+
21+
def _get_java_types():
    """
    Build the lookup table from config key name to Java type.

    Returns
    -------
    dict mapping each recognised config key (str) to the JPype type object
    associated with it: `jpype.JDouble`, `jpype.JInt` or the
    `java.lang.String` class.
    """
    double_keys = ("Alpha", "Beta")
    int_keys = (
        "HyperparamOptimInterval",
        "NoBatches",
        "NoIters",
        "NoTopicBatches",
        "RareThreshold",
        "StartDiagnostic",
        "TfIdfThreshold",
        "TopicInterval",
    )
    string_keys = ("SavedSamplerDirectory", "SavedSamplerDir")

    # NOTE(review): resolving a class through JClass presumably needs a
    # running JVM -- confirm callers start it before calling this helper.
    java_string = jpype.JClass("java.lang.String")

    # Groups are added in the same order as the keys were originally listed,
    # so the resulting dict iterates identically.
    type_by_key = dict.fromkeys(double_keys, jpype.JDouble)
    type_by_key.update(dict.fromkeys(int_keys, jpype.JInt))
    type_by_key.update(dict.fromkeys(string_keys, java_string))
    return type_by_key
37+
1938
def write_config(slc, cfg_fn, jar_dict=None):
2039
"""
2140
Take a config object and a file name, write the config to a json file
@@ -116,8 +135,24 @@ def _custom_serializer(obj):
116135
# print(" ---- OK")
117136
# except:
118137
# [print(e) for e in E]
138+
for k, v in slc_out.items():
139+
if str(type(v)).endswith("Integer'>"):
140+
slc_out[k] = int(v)
141+
elif str(type(v)).endswith("String'>"):
142+
slc_out[k] = str(v)
143+
elif str(type(v)).endswith("Double'>"):
144+
slc_out[k] = float(v)
145+
for k, v in slc_out.items():
146+
try:
147+
pickle.dumps(v)
148+
except:
149+
try:
150+
picke.dump(vars(v)['_jstr'])
151+
except:
152+
print(f"SERIALIZE ERROR: ({k}) {v}, --{str(v)}-- type: {type(v)} ")
119153

120154
jars_out = {}
155+
121156
#with open(cfg_fn, 'r') as oldconfig:
122157
# oslc = json.load(oldconfig)
123158
# if "jars" in oslc:
@@ -163,13 +198,39 @@ def new_simple_lda_config(
163198
topic_interval: how often to print topic info during sampling
164199
tmpdir: temporary directory for intermediate storage of logging data (default "tmp")
165200
topic_priors: text file with 'prior spec' with one topic per line with format: <topic nr(zero idxed)>, <word1>, <word2>, etc
201+
jar: jar to load (by key in the jar dict)
166202
jar_dict: named jar file dict. If you only work with one jar file this should be `{'default': 'path/to/jarfile'}`
167203
cfg_fn: path to config file. If the file exists, it will be updated with the provided values; if not, the new config will be written to a json file.
168204
169205
Returns
170206
171207
config object
172208
"""
209+
# Make sure you can load JVM any time you load a config.
210+
def ensure_jvm(jar):
    """Start the JVM with `jar` as its classpath unless one is already running."""
    if jpype.isJVMStarted():
        # A JVM is already up; nothing to do.
        return
    jpype.startJVM(classpath=[jar])
213+
214+
215+
if len(jar_dict) == 0:
216+
warnings.warn("You need at least one PCLDA jarfile to work with this library, but you haven't provided one.")
217+
inp = input("Do you want to provide a default PCLDA jar file now? Enter the path from the cwd (or q to exit): ")
218+
if inp == 'q':
219+
print("Ok, exiting")
220+
sys.exit()
221+
else:
222+
jar_dict = {"default": os.path.abspath(inp)}
223+
224+
target_jar = jar_dict[jar]
225+
ensure_jvm(target_jar)
226+
227+
java_types = _get_java_types()
228+
229+
print(jpype.java.lang.System.getProperty("java.class.path"))
230+
231+
# Initialize LoggingUtils
232+
lu = jpype.JClass("cc.mallet.util.LoggingUtils")()
233+
lu.checkAndCreateCurrentLogDir(tmpdir)
173234

174235
print(len(jar_dict))
175236
if len(jar_dict) == 0:
@@ -223,7 +284,7 @@ def new_simple_lda_config(
223284
return slc
224285

225286

226-
def load_lda_config(cfg_fn):
287+
def load_lda_config(cfg_fn, jar='default'):
227288
"""
228289
Load the lda config file from a json file.
229290
@@ -235,18 +296,6 @@ def load_lda_config(cfg_fn):
235296
236297
config object
237298
"""
238-
java_types = {
239-
"Alpha": jpype.JDouble,
240-
"Beta": jpype.JDouble,
241-
"HyperparamOptimInterval": jpype.JInt,
242-
"NoBatches": jpype.JInt,
243-
"NoIters": jpype.JInt,
244-
"NoTopicBatches": jpype.JInt,
245-
"RareThreshold": jpype.JInt,
246-
"StartDiagnostic": jpype.JInt,
247-
"TfIdfThreshold": jpype.JInt,
248-
"TopicInterval": jpype.JInt,
249-
}
250299
# load json cfg declaration as a dict
251300
with open(cfg_fn, 'r') as inf:
252301
j = json.load(inf)
@@ -255,7 +304,9 @@ def load_lda_config(cfg_fn):
255304
j["pclda_config"]["Beta"] = j["pclda_config"]["NoTopics"] / 50
256305

257306
# Initialize SimpleLDAConfiguration
258-
slc = new_simple_lda_config(jar_dict=j["jars"])
307+
slc = new_simple_lda_config(jar_dict=j["jars"], jar=jar)
308+
print(type(slc), slc)
309+
java_types = _get_java_types()
259310

260311
# replace slc init values with dict
261312
for k, v in j["pclda_config"].items():
@@ -268,6 +319,7 @@ def load_lda_config(cfg_fn):
268319
method(v)
269320
else:
270321
warnings.warn(f"Unrecognized key in provided config file :: {k} = {v}")
322+
print(type(slc), slc)
271323
return slc, j["jars"]
272324

273325

@@ -315,7 +367,7 @@ def create_lda_dataset(train, test=None, stoplist_fn="stoplist.txt"):
315367
util = jpype.JClass("cc.mallet.util.LDADatasetStringLoadingUtils")()
316368
#pipe = util.buildSerialPipe(stoplist_fn, jpype.JNull("cc.mallet.types.Alphabet"), True)
317369
pipe = util.buildSerialPipe(stoplist_fn, None, True)
318-
print(pipe)
370+
#print(pipe)
319371
# Create InstanceList for the training data
320372
il = jpype.JClass("cc.mallet.types.InstanceList")(pipe)
321373
il.addThruPipe(jpype.JObject(string_iterator, "java.util.Iterator"))
@@ -357,9 +409,14 @@ def sample_pclda(ldaconfig, ds, iterations=2000, sampler_type="cc.mallet.topics.
357409
if testset is not None:
358410
lda.addTestInstances(testset)
359411

360-
# Perform sampling
361-
lda.sample(iterations)
412+
print("**********************", iterations)
362413

414+
# Perform sampling
415+
print("********************", iterations, type(iterations))
416+
try:
417+
lda.sample(iterations)
418+
except Exception as e:
419+
print("EXCEPTION in lda.sample():", e)
363420
# If we need to save the sampler, perform the saving procedure
364421
if save_sampler:
365422
sampler_dir = jpype.JClass("cc.mallet.configuration.LDAConfiguration").STORED_SAMPLER_DIR_DEFAULT
@@ -369,7 +426,7 @@ def sample_pclda(ldaconfig, ds, iterations=2000, sampler_type="cc.mallet.topics.
369426
util.saveSampler(jpype.JObject(lda, "cc.mallet.topics.LDAGibbsSampler"),
370427
jpype.JObject(ldaconfig, "cc.mallet.configuration.LDAConfiguration"),
371428
sampler_folder)
372-
429+
print("*** retunrn fn")
373430
return lda
374431

375432

0 commit comments

Comments
 (0)