Skip to content

Commit 8b0467f

Browse files
committed
GROOVY-11335: STC: loop item type from UnionTypeClassNode
1 parent e4c7e2d commit 8b0467f

File tree

3 files changed

+90
-83
lines changed

3 files changed

+90
-83
lines changed

src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java

+32-29
Original file line numberDiff line numberDiff line change
@@ -1979,33 +1979,34 @@ public void visitForLoop(final ForStatement forLoop) {
19791979
* @see #inferComponentType
19801980
*/
19811981
public static ClassNode inferLoopElementType(final ClassNode collectionType) {
1982-
ClassNode componentType = collectionType.getComponentType();
1983-
if (componentType == null) {
1984-
if (isOrImplements(collectionType, ITERABLE_TYPE)) {
1985-
ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
1986-
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
1987-
1988-
} else if (isOrImplements(collectionType, MAP_TYPE)) { // GROOVY-6240
1989-
ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
1990-
componentType = makeClassSafe0(MAP_ENTRY_TYPE, col.getGenericsTypes());
1991-
1992-
} else if (isOrImplements(collectionType, STREAM_TYPE)) { // GROOVY-10476
1993-
ClassNode col = GenericsUtils.parameterizeType(collectionType, STREAM_TYPE);
1994-
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
1995-
1996-
} else if (isOrImplements(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
1997-
ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
1998-
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
1999-
2000-
} else if (isOrImplements(collectionType, Iterator_TYPE)) { // GROOVY-10712
2001-
ClassNode col = GenericsUtils.parameterizeType(collectionType, Iterator_TYPE);
2002-
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
2003-
2004-
} else if (isStringType(collectionType)) {
2005-
componentType = STRING_TYPE;
2006-
} else {
2007-
componentType = OBJECT_TYPE;
2008-
}
1982+
ClassNode componentType;
1983+
if (collectionType.isArray()) { // GROOVY-11335
1984+
componentType = collectionType.getComponentType();
1985+
1986+
} else if (isOrImplements(collectionType, ITERABLE_TYPE)) {
1987+
ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
1988+
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
1989+
1990+
} else if (isOrImplements(collectionType, MAP_TYPE)) { // GROOVY-6240
1991+
ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
1992+
componentType = makeClassSafe0(MAP_ENTRY_TYPE, col.getGenericsTypes());
1993+
1994+
} else if (isOrImplements(collectionType, STREAM_TYPE)) { // GROOVY-10476
1995+
ClassNode col = GenericsUtils.parameterizeType(collectionType, STREAM_TYPE);
1996+
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
1997+
1998+
} else if (isOrImplements(collectionType, Iterator_TYPE)) { // GROOVY-10712
1999+
ClassNode col = GenericsUtils.parameterizeType(collectionType, Iterator_TYPE);
2000+
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
2001+
2002+
} else if (isOrImplements(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
2003+
ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
2004+
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
2005+
2006+
} else if (isStringType(collectionType)) {
2007+
componentType = STRING_TYPE;
2008+
} else {
2009+
componentType = OBJECT_TYPE;
20092010
}
20102011
return componentType;
20112012
}
@@ -4716,8 +4717,10 @@ protected static ClassNode getGroupOperationResultType(final ClassNode a, final
47164717
}
47174718

47184719
protected ClassNode inferComponentType(final ClassNode receiverType, final ClassNode subscriptType) {
4719-
ClassNode componentType = receiverType.getComponentType();
4720-
if (componentType == null) {
4720+
ClassNode componentType = null;
4721+
if (receiverType.isArray()) { // GROOVY-11335
4722+
componentType = receiverType.getComponentType();
4723+
} else {
47214724
MethodCallExpression mce;
47224725
if (subscriptType != null) { // GROOVY-5521: check for a suitable "getAt(T)" method
47234726
mce = callX(varX("#", receiverType), "getAt", varX("selector", subscriptType));

src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java

+44-54
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.codehaus.groovy.transform.ASTTransformation;
3636

3737
import java.util.Arrays;
38-
import java.util.Collections;
3938
import java.util.HashSet;
4039
import java.util.Iterator;
4140
import java.util.LinkedHashSet;
@@ -172,59 +171,51 @@ public void addTransform(final Class<? extends ASTTransformation> transform, fin
172171
throw new UnsupportedOperationException();
173172
}
174173

175-
@Override
176-
public boolean declaresInterface(final ClassNode classNode) {
177-
for (ClassNode delegate : delegates) {
178-
if (delegate.declaresInterface(classNode)) return true;
179-
}
180-
return false;
181-
}
182-
183174
@Override
184175
public List<MethodNode> getAbstractMethods() {
185-
List<MethodNode> allMethods = new LinkedList<MethodNode>();
176+
List<MethodNode> answer = new LinkedList<>();
186177
for (ClassNode delegate : delegates) {
187-
allMethods.addAll(delegate.getAbstractMethods());
178+
answer.addAll(delegate.getAbstractMethods());
188179
}
189-
return allMethods;
180+
return answer;
190181
}
191182

192183
@Override
193184
public List<MethodNode> getAllDeclaredMethods() {
194-
List<MethodNode> allMethods = new LinkedList<MethodNode>();
185+
List<MethodNode> answer = new LinkedList<>();
195186
for (ClassNode delegate : delegates) {
196-
allMethods.addAll(delegate.getAllDeclaredMethods());
187+
answer.addAll(delegate.getAllDeclaredMethods());
197188
}
198-
return allMethods;
189+
return answer;
199190
}
200191

201192
@Override
202193
public Set<ClassNode> getAllInterfaces() {
203-
Set<ClassNode> allMethods = new HashSet<ClassNode>();
194+
Set<ClassNode> answer = new HashSet<>();
204195
for (ClassNode delegate : delegates) {
205-
allMethods.addAll(delegate.getAllInterfaces());
196+
answer.addAll(delegate.getAllInterfaces());
206197
}
207-
return allMethods;
198+
return answer;
208199
}
209200

210201
@Override
211202
public List<AnnotationNode> getAnnotations() {
212-
List<AnnotationNode> nodes = new LinkedList<AnnotationNode>();
203+
List<AnnotationNode> answer = new LinkedList<>();
213204
for (ClassNode delegate : delegates) {
214205
List<AnnotationNode> annotations = delegate.getAnnotations();
215-
if (annotations != null) nodes.addAll(annotations);
206+
if (annotations != null) answer.addAll(annotations);
216207
}
217-
return nodes;
208+
return answer;
218209
}
219210

220211
@Override
221212
public List<AnnotationNode> getAnnotations(final ClassNode type) {
222-
List<AnnotationNode> nodes = new LinkedList<AnnotationNode>();
213+
List<AnnotationNode> answer = new LinkedList<>();
223214
for (ClassNode delegate : delegates) {
224215
List<AnnotationNode> annotations = delegate.getAnnotations(type);
225-
if (annotations != null) nodes.addAll(annotations);
216+
if (annotations != null) answer.addAll(annotations);
226217
}
227-
return nodes;
218+
return answer;
228219
}
229220

230221
@Override
@@ -234,11 +225,11 @@ public ClassNode getComponentType() {
234225

235226
@Override
236227
public List<ConstructorNode> getDeclaredConstructors() {
237-
List<ConstructorNode> nodes = new LinkedList<ConstructorNode>();
228+
List<ConstructorNode> answer = new LinkedList<>();
238229
for (ClassNode delegate : delegates) {
239-
nodes.addAll(delegate.getDeclaredConstructors());
230+
answer.addAll(delegate.getDeclaredConstructors());
240231
}
241-
return nodes;
232+
return answer;
242233
}
243234

244235
@Override
@@ -261,12 +252,12 @@ public MethodNode getDeclaredMethod(final String name, final Parameter[] paramet
261252

262253
@Override
263254
public List<MethodNode> getDeclaredMethods(final String name) {
264-
List<MethodNode> nodes = new LinkedList<MethodNode>();
255+
List<MethodNode> answer = new LinkedList<>();
265256
for (ClassNode delegate : delegates) {
266257
List<MethodNode> methods = delegate.getDeclaredMethods(name);
267-
if (methods != null) nodes.addAll(methods);
258+
if (methods != null) answer.addAll(methods);
268259
}
269-
return nodes;
260+
return answer;
270261
}
271262

272263
@Override
@@ -290,12 +281,12 @@ public FieldNode getField(final String name) {
290281

291282
@Override
292283
public List<FieldNode> getFields() {
293-
List<FieldNode> nodes = new LinkedList<FieldNode>();
284+
List<FieldNode> answer = new LinkedList<>();
294285
for (ClassNode delegate : delegates) {
295286
List<FieldNode> fields = delegate.getFields();
296-
if (fields != null) nodes.addAll(fields);
287+
if (fields != null) answer.addAll(fields);
297288
}
298-
return nodes;
289+
return answer;
299290
}
300291

301292
@Override
@@ -305,22 +296,25 @@ public Iterator<InnerClassNode> getInnerClasses() {
305296

306297
@Override
307298
public ClassNode[] getInterfaces() {
308-
Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
299+
Set<ClassNode> answer = new LinkedHashSet<>();
309300
for (ClassNode delegate : delegates) {
310-
ClassNode[] interfaces = delegate.getInterfaces();
311-
if (interfaces != null) Collections.addAll(nodes, interfaces);
301+
if (delegate.isInterface()) {
302+
answer.remove(delegate); answer.add(delegate);
303+
} else {
304+
answer.addAll(Arrays.asList(delegate.getInterfaces()));
305+
}
312306
}
313-
return nodes.toArray(ClassNode.EMPTY_ARRAY);
307+
return answer.toArray(ClassNode.EMPTY_ARRAY);
314308
}
315309

316310
@Override
317311
public List<MethodNode> getMethods() {
318-
List<MethodNode> nodes = new LinkedList<MethodNode>();
312+
List<MethodNode> answer = new LinkedList<>();
319313
for (ClassNode delegate : delegates) {
320314
List<MethodNode> methods = delegate.getMethods();
321-
if (methods != null) nodes.addAll(methods);
315+
if (methods != null) answer.addAll(methods);
322316
}
323-
return nodes;
317+
return answer;
324318
}
325319

326320
@Override
@@ -334,12 +328,12 @@ public ClassNode getPlainNodeReference(final boolean skipPrimitives) {
334328

335329
@Override
336330
public List<PropertyNode> getProperties() {
337-
List<PropertyNode> nodes = new LinkedList<PropertyNode>();
331+
List<PropertyNode> answer = new LinkedList<>();
338332
for (ClassNode delegate : delegates) {
339333
List<PropertyNode> properties = delegate.getProperties();
340-
if (properties != null) nodes.addAll(properties);
334+
if (properties != null) answer.addAll(properties);
341335
}
342-
return nodes;
336+
return answer;
343337
}
344338

345339
@Override
@@ -349,22 +343,18 @@ public Class getTypeClass() {
349343

350344
@Override
351345
public ClassNode[] getUnresolvedInterfaces() {
352-
Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
353-
for (ClassNode delegate : delegates) {
354-
ClassNode[] interfaces = delegate.getUnresolvedInterfaces();
355-
if (interfaces != null) Collections.addAll(nodes, interfaces);
356-
}
357-
return nodes.toArray(ClassNode.EMPTY_ARRAY);
346+
return getUnresolvedInterfaces(false);
358347
}
359348

360349
@Override
361350
public ClassNode[] getUnresolvedInterfaces(final boolean useRedirect) {
362-
Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
363-
for (ClassNode delegate : delegates) {
364-
ClassNode[] interfaces = delegate.getUnresolvedInterfaces(useRedirect);
365-
if (interfaces != null) Collections.addAll(nodes, interfaces);
351+
ClassNode[] interfaces = getInterfaces();
352+
if (useRedirect) {
353+
for (int i = 0; i < interfaces.length; ++i) {
354+
interfaces[i] = interfaces[i].redirect();
355+
}
366356
}
367-
return nodes.toArray(ClassNode.EMPTY_ARRAY);
357+
return interfaces;
368358
}
369359

370360
@Override

src/test/groovy/transform/stc/LoopsSTCTest.groovy

+14
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,20 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
247247
'''
248248
}
249249

250+
// GROOVY-11335
251+
void testForInLoopOnCollection() {
252+
assertScript '''
253+
def whatever(Collection<String> coll) {
254+
if (coll instanceof Serializable) {
255+
for (item in coll) {
256+
return item.toLowerCase()
257+
}
258+
}
259+
}
260+
assert whatever(['Works']) == 'works'
261+
'''
262+
}
263+
250264
// GROOVY-6123
251265
void testForInLoopOnEnumeration() {
252266
assertScript '''

0 commit comments

Comments
 (0)