11package com .ibm .wala .cast .python .ml .client ;
22
3+ import static com .ibm .wala .cast .python .ml .client .RaggedConstant .Parameters .RAGGED_RANK ;
34import static com .ibm .wala .cast .python .ml .types .TensorFlowTypes .DType .FLOAT32 ;
45import static com .ibm .wala .cast .python .types .PythonTypes .Root ;
56import static com .ibm .wala .cast .python .types .PythonTypes .list ;
2829import java .util .ArrayList ;
2930import java .util .EnumSet ;
3031import java .util .List ;
31- import java .util .Optional ;
3232import java .util .Set ;
3333import java .util .logging .Logger ;
3434import java .util .stream .StreamSupport ;
@@ -57,6 +57,65 @@ public RaggedConstant(PointsToSetVariable source) {
5757 super (source );
5858 }
5959
60+ private static Set <Integer > getPossibleInnerListLengths (
61+ PropagationCallGraphBuilder builder , OrdinalSet <InstanceKey > pts ) {
62+ Set <Integer > ret = HashSetFactory .make ();
63+ PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
64+
65+ for (InstanceKey ik : pts ) {
66+ AllocationSiteInNode asin = getAllocationSiteInNode (ik );
67+ TypeReference reference = asin .getConcreteType ().getReference ();
68+
69+ // A `list` or `tuple`.
70+ if (reference .equals (list ) || reference .equals (tuple )) {
71+ OrdinalSet <InstanceKey > objectCatalogPointsToSet =
72+ pointerAnalysis .getPointsToSet (
73+ ((AstPointerKeyFactory ) builder .getPointerKeyFactory ())
74+ .getPointerKeyForObjectCatalog (asin ));
75+
76+ assert objectCatalogPointsToSet .iterator ().hasNext ();
77+
78+ InstanceKey catalogIK =
79+ objectCatalogPointsToSet
80+ .iterator ()
81+ .next (); // Just need one element to check inner length.
82+
83+ ConstantKey <?> constantKey = (ConstantKey <?>) catalogIK ;
84+ Object constantKeyValue = constantKey .getValue ();
85+
86+ Integer fieldIndex = (Integer ) constantKeyValue ;
87+
88+ FieldReference subscript =
89+ FieldReference .findOrCreate (Root , findOrCreateAsciiAtom (fieldIndex .toString ()), Root );
90+
91+ IField f = builder .getClassHierarchy ().resolveField (subscript );
92+
93+ PointerKey pointerKeyForInstanceField = builder .getPointerKeyForInstanceField (asin , f );
94+
95+ OrdinalSet <InstanceKey > instanceFieldPointsToSet =
96+ pointerAnalysis .getPointsToSet (pointerKeyForInstanceField );
97+
98+ boolean containsAllListsOrTuples =
99+ StreamSupport .stream (instanceFieldPointsToSet .spliterator (), false )
100+ .allMatch (
101+ ik -> {
102+ AllocationSiteInNode innerAsin = getAllocationSiteInNode (ik );
103+
104+ if (innerAsin == null ) return false ;
105+
106+ TypeReference innerReference = innerAsin .getConcreteType ().getReference ();
107+ return innerReference .equals (list ) || innerReference .equals (tuple );
108+ });
109+
110+ if (!containsAllListsOrTuples ) ret .add (objectCatalogPointsToSet .size ());
111+ else ret .addAll (getPossibleInnerListLengths (builder , instanceFieldPointsToSet ));
112+ } else
113+ throw new IllegalStateException ("Expected a list or tuple, but found: " + reference + "." );
114+ }
115+
116+ return ret ;
117+ }
118+
60119 private static Set <Integer > getPossibleOuterListLengths (
61120 PropagationCallGraphBuilder builder , OrdinalSet <InstanceKey > valuePointsToSet ) {
62121 Set <Integer > ret = HashSetFactory .make ();
@@ -246,31 +305,58 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
246305 int K = maxDepth ;
247306 LOGGER .fine ("Tensor rank: " + K );
248307
249- Optional <Integer > raggedRank = this .getRaggedRankArgumentValue (builder );
250- int R = raggedRank .orElse (K - 1 );
251- LOGGER .fine ("Ragged rank: " + R );
308+ Set <Long > rankArguments = this .getPossibleRaggedRankArguments (builder );
252309
253- // Step 3: Construct shape with rank K and ragged rank R .
310+ if ( rankArguments . isEmpty ()) rankArguments . add ( K - 1L ); // Default ragged rank.
254311
255- // Get the length of the outer list.
256- Set <Integer > possibleOuterListLengths =
257- getPossibleOuterListLengths (builder , valuePointsToSet );
312+ for (Long R : rankArguments ) {
313+ LOGGER .fine ("Ragged rank: " + R );
258314
259- for ( int outerListLength : possibleOuterListLengths ) {
260- List < Dimension <?>> shape = new ArrayList <>();
261- shape . add ( new NumericDim ( outerListLength ));
315+ // Step 3: Construct shape with rank K and ragged rank R.
316+ // The final shape is constructed by concatenating the Ragged Portion and the Uniform
317+ // Portion.
262318
263- // The first R dimensions are ragged.
264- for (int i = 0 ; i < R ; i ++) shape .add (null ); // Unknown size for ragged dimensions.
319+ // Part A: The Ragged Portion (Dimensions 0 to R)
265320
266- /*
267- // The remaining K - R dimensions are dense.
268- for (int i = R; i < K; i++) {
269- shape.add(new NumericDim(-1)); // Unknown size for dense dimensions.
270- }
271- */
321+ // For the ragged dimensions, TensorFlow does not look for a uniform length. It assigns the
322+ // shape based on the row_splits.
272323
273- ret .add (shape );
324+ // Get the length of the outer list.
325+ Set <Integer > possibleOuterListLengths =
326+ getPossibleOuterListLengths (builder , valuePointsToSet );
327+
328+ for (int outerListLength : possibleOuterListLengths ) {
329+ List <Dimension <?>> shape = new ArrayList <>();
330+
331+ // Dim 0 (Batch): Always fixed. It is simply len(input_list).
332+ shape .add (new NumericDim (outerListLength ));
333+
334+ // The first R dimensions are ragged.
335+ // Dim 1 to R: These are assigned None (or ? in older outputs) in the static shape,
336+ // indicating they can vary.
337+ for (Long i = 0L ; i < R ; i ++) shape .add (null ); // Unknown size for ragged dimensions.
338+
339+ // Part B: The Uniform Portion (Dimensions R + 1 to K)
340+ // If R < K - 1 (meaning you requested fewer ragged dimensions than the total depth),
341+ // TensorFlow enforces uniformity on the remaining inner dimensions.
342+
343+ // 1. It checks the length of every sub-list at these levels.
344+ // 2. If any lengths differ, it throws a ValueError.
345+ // 3. If they match, that length becomes the fixed size for that dimension.
346+
347+ if (R < K - 1 ) {
348+ Set <Integer > possibleInnerListLengths =
349+ getPossibleInnerListLengths (builder , valuePointsToSet );
350+
351+ // Determine the uniform lengths for dimensions R + 1 to K - 1.
352+ for (long i = R + 1 ; i < K ; i ++) {
353+ for (int innerListLength : possibleInnerListLengths )
354+ shape .add (new NumericDim (innerListLength ));
355+ }
356+ }
357+
358+ ret .add (shape );
359+ }
274360 }
275361 }
276362
@@ -288,9 +374,46 @@ private static int getMaximumDepthOfInstance(
288374 return 1 + getMaximumDepthOfEmptyList (builder , instance );
289375 }
290376
291- private Optional <Integer > getRaggedRankArgumentValue (PropagationCallGraphBuilder builder ) {
292- // TODO Auto-generated method stub
293- return Optional .empty ();
377+ protected Set <Long > getPossibleRaggedRankArguments (PropagationCallGraphBuilder builder ) {
378+ Set <Long > ret = HashSetFactory .make ();
379+ int valueNumber = this .getRaggedRankArgumentValueNumber (builder );
380+
381+ if (valueNumber >= 0 ) {
382+ PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
383+ PointerKey raggedRankPK =
384+ pointerAnalysis .getHeapModel ().getPointerKeyForLocal (this .getNode (), valueNumber );
385+ OrdinalSet <InstanceKey > raggedRankPointsToSet = pointerAnalysis .getPointsToSet (raggedRankPK );
386+
387+ if (raggedRankPointsToSet == null || raggedRankPointsToSet .isEmpty ())
388+ throw new IllegalArgumentException (
389+ "Empty points-to set for ragged_rank in source: " + this .getSource () + "." );
390+
391+ for (InstanceKey raggedRankIK : raggedRankPointsToSet )
392+ if (raggedRankIK instanceof ConstantKey ) {
393+ ConstantKey <?> constantKey = (ConstantKey <?>) raggedRankIK ;
394+ Object constantKeyValue = constantKey .getValue ();
395+
396+ if (constantKeyValue instanceof Long ) {
397+ Long raggedRankValue = (Long ) constantKeyValue ;
398+ ret .add (raggedRankValue );
399+ } else
400+ throw new IllegalArgumentException (
401+ "Expected an integer for ragged_rank, but found: " + constantKeyValue + "." );
402+ } else
403+ throw new IllegalArgumentException (
404+ "Expected a constant key for ragged_rank, but found: " + raggedRankIK + "." );
405+ }
406+
407+ return ret ;
408+ }
409+
410+ protected int getRaggedRankParameterPosition () {
411+ return RAGGED_RANK .ordinal ();
412+ }
413+
414+ protected int getRaggedRankArgumentValueNumber (PropagationCallGraphBuilder builder ) {
415+ // TODO: Handle keyword arguments.
416+ return this .getArgumentValueNumber (builder , this .getRaggedRankParameterPosition (), true );
294417 }
295418
296419 /**
@@ -323,6 +446,7 @@ protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
323446 return EnumSet .of (FLOAT32 );
324447 }
325448
449+ // Otherwise, there are values available to infer the dtype from.
326450 return super .getDefaultDTypes (builder );
327451 }
328452}
0 commit comments