22import jpype
33import json
44import os
5+ import pickle
56import sys
67import warnings
78
@@ -16,6 +17,24 @@ def ensure_java(classpath, java_home="/usr/lib/jvm/java-11-openjdk"):
1617 print (jpype .isJVMStarted ())
1718
1819
20+
def _get_java_types():
    """
    Build the lookup table from PCLDA config key names to JPype types.

    Keys are grouped by the Java type they are associated with (double, int,
    java.lang.String) and then merged into a single dict, preserving the
    order doubles -> ints -> strings.

    NOTE(review): `jpype.JClass("java.lang.String")` resolves a Java class, so
    this presumably has to be called after the JVM has been started -- confirm
    against the callers.

    Returns

        dict mapping config key (str) to the corresponding JPype type object
    """
    double_keys = ("Alpha", "Beta")
    int_keys = (
        "HyperparamOptimInterval",
        "NoBatches",
        "NoIters",
        "NoTopicBatches",
        "RareThreshold",
        "StartDiagnostic",
        "TfIdfThreshold",
        "TopicInterval",
    )
    string_keys = ("SavedSamplerDirectory", "SavedSamplerDir")

    # Both string-valued keys share the same java.lang.String class wrapper.
    java_string = jpype.JClass("java.lang.String")

    type_by_key = {}
    type_by_key.update(dict.fromkeys(double_keys, jpype.JDouble))
    type_by_key.update(dict.fromkeys(int_keys, jpype.JInt))
    type_by_key.update(dict.fromkeys(string_keys, java_string))
    return type_by_key
37+
1938def write_config (slc , cfg_fn , jar_dict = None ):
2039 """
2140 Take a config object and a file name, write the config to a json file
@@ -116,8 +135,24 @@ def _custom_serializer(obj):
116135 # print(" ---- OK")
117136 # except:
118137 # [print(e) for e in E]
138+ for k , v in slc_out .items ():
139+ if str (type (v )).endswith ("Integer'>" ):
140+ slc_out [k ] = int (v )
141+ elif str (type (v )).endswith ("String'>" ):
142+ slc_out [k ] = str (v )
143+ elif str (type (v )).endswith ("Double'>" ):
144+ slc_out [k ] = float (v )
145+ for k , v in slc_out .items ():
146+ try :
147+ pickle .dumps (v )
148+ except :
149+ try :
150+ picke .dump (vars (v )['_jstr' ])
151+ except :
152+ print (f"SERIALIZE ERROR: ({ k } ) { v } , --{ str (v )} -- type: { type (v )} " )
119153
120154 jars_out = {}
155+
121156 #with open(cfg_fn, 'r') as oldconfig:
122157 # oslc = json.load(oldconfig)
123158 # if "jars" in oslc:
@@ -163,13 +198,39 @@ def new_simple_lda_config(
163198 topic_interval: how often to print topic info during sampling
164199 tmpdir: temporary directory for intermediate storage of logging data (default "tmp")
165200 topic_priors: text file with 'prior spec' with one topic per line with format: <topic nr(zero idxed)>, <word1>, <word2>, etc
201+ jar: jar to load (by key in the jar dict)
166202 jar_dict: named jar file dict. If you only work with one jar file this should be `{'default': 'path/to/jarfile'}`
167203 cfg_fn: path to config file. if the file exists, it will be updated with provided values, if not the new config will be written to a json file.
168204
169205 Returns
170206
171207 config object
172208 """
209+ # Make sure you can load JVM any time you load a config.
210+ def ensure_jvm (jar ):
211+ if not jpype .isJVMStarted ():
212+ jpype .startJVM (classpath = [jar ])
213+
214+
215+ if len (jar_dict ) == 0 :
216+ warnings .warn ("You need at least one PCLDA jarfile to work with this library, but you haven't provided one." )
217+ inp = input ("Do you want to provide a default PCLDA jar file now? Enter the path from the cwd (or q to exit): " )
218+ if inp == 'q' :
219+ print ("Ok, exiting" )
220+ sys .exit ()
221+ else :
222+ jar_dict = {"default" : os .path .abspath (inp )}
223+
224+ target_jar = jar_dict [jar ]
225+ ensure_jvm (target_jar )
226+
227+ java_types = _get_java_types ()
228+
229+ print (jpype .java .lang .System .getProperty ("java.class.path" ))
230+
231+ # Initialize LoggingUtils
232+ lu = jpype .JClass ("cc.mallet.util.LoggingUtils" )()
233+ lu .checkAndCreateCurrentLogDir (tmpdir )
173234
174235 print (len (jar_dict ))
175236 if len (jar_dict ) == 0 :
@@ -223,7 +284,7 @@ def new_simple_lda_config(
223284 return slc
224285
225286
226- def load_lda_config (cfg_fn ):
287+ def load_lda_config (cfg_fn , jar = 'default' ):
227288 """
228289 Load the lda config file from a json file.
229290
@@ -235,18 +296,6 @@ def load_lda_config(cfg_fn):
235296
236297 config object
237298 """
238- java_types = {
239- "Alpha" : jpype .JDouble ,
240- "Beta" : jpype .JDouble ,
241- "HyperparamOptimInterval" : jpype .JInt ,
242- "NoBatches" : jpype .JInt ,
243- "NoIters" : jpype .JInt ,
244- "NoTopicBatches" : jpype .JInt ,
245- "RareThreshold" : jpype .JInt ,
246- "StartDiagnostic" : jpype .JInt ,
247- "TfIdfThreshold" : jpype .JInt ,
248- "TopicInterval" : jpype .JInt ,
249- }
250299 # load JSON cfg declaration as a dict
251300 with open (cfg_fn , 'r' ) as inf :
252301 j = json .load (inf )
@@ -255,7 +304,9 @@ def load_lda_config(cfg_fn):
255304 j ["pclda_config" ]["Beta" ] = j ["pclda_config" ]["NoTopics" ] / 50
256305
257306 # Initialize SimpleLDAConfiguration
258- slc = new_simple_lda_config (jar_dict = j ["jars" ])
307+ slc = new_simple_lda_config (jar_dict = j ["jars" ], jar = jar )
308+ print (type (slc ), slc )
309+ java_types = _get_java_types ()
259310
260311 # replace slc init values with dict
261312 for k , v in j ["pclda_config" ].items ():
@@ -268,6 +319,7 @@ def load_lda_config(cfg_fn):
268319 method (v )
269320 else :
270321 warnings .warn (f"Unrecognized key in provided config file :: { k } = { v } " )
322+ print (type (slc ), slc )
271323 return slc , j ["jars" ]
272324
273325
@@ -315,7 +367,7 @@ def create_lda_dataset(train, test=None, stoplist_fn="stoplist.txt"):
315367 util = jpype .JClass ("cc.mallet.util.LDADatasetStringLoadingUtils" )()
316368 #pipe = util.buildSerialPipe(stoplist_fn, jpype.JNull("cc.mallet.types.Alphabet"), True)
317369 pipe = util .buildSerialPipe (stoplist_fn , None , True )
318- print (pipe )
370+ # print(pipe)
319371 # Create InstanceList for the training data
320372 il = jpype .JClass ("cc.mallet.types.InstanceList" )(pipe )
321373 il .addThruPipe (jpype .JObject (string_iterator , "java.util.Iterator" ))
@@ -357,9 +409,14 @@ def sample_pclda(ldaconfig, ds, iterations=2000, sampler_type="cc.mallet.topics.
357409 if testset is not None :
358410 lda .addTestInstances (testset )
359411
360- # Perform sampling
361- lda .sample (iterations )
412+ print ("**********************" , iterations )
362413
414+ # Perform sampling
415+ print ("********************" , iterations , type (iterations ))
416+ try :
417+ lda .sample (iterations )
418+ except Exception as e :
419+ print ("EXCEPTION in lda.sample():" , e )
363420 # If we need to save the sampler, perform the saving procedure
364421 if save_sampler :
365422 sampler_dir = jpype .JClass ("cc.mallet.configuration.LDAConfiguration" ).STORED_SAMPLER_DIR_DEFAULT
@@ -369,7 +426,7 @@ def sample_pclda(ldaconfig, ds, iterations=2000, sampler_type="cc.mallet.topics.
369426 util .saveSampler (jpype .JObject (lda , "cc.mallet.topics.LDAGibbsSampler" ),
370427 jpype .JObject (ldaconfig , "cc.mallet.configuration.LDAConfiguration" ),
371428 sampler_folder )
372-
429+ print ( "*** retunrn fn" )
373430 return lda
374431
375432
0 commit comments