Skip to content

Commit 872953a

Browse files
committed
wip
1 parent c3d29ec commit 872953a

20 files changed

+1641
-72
lines changed

pyjopa/ast.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,13 @@ class Modifier(ASTNode):
8585
annotation: Optional[Annotation]
8686

8787

88-
class TypeDeclaration(ASTNode):
89-
"""Base class for type declarations."""
88+
class ClassBodyDeclaration(ASTNode):
    """Marker base class for declarations that appear in a class body."""
91+
92+
93+
class TypeDeclaration(ClassBodyDeclaration):
    """Base class for type declarations.

    It derives from ClassBodyDeclaration because a type declaration can also
    be nested inside a class.
    """
9196

9297

pyjopa/classfile.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,15 @@ def _write_constant_value_attribute(cp: ConstantPool, out: bytearray, const: tup
624624
out.extend(struct.pack(">H", idx))
625625

626626

627+
@dataclass
class InnerClassInfo:
    """One entry of the InnerClasses attribute.

    Attributes:
        inner_class: Internal name of the inner class, e.g. "Outer$Inner".
        outer_class: Internal name of the enclosing class, e.g. "Outer";
            None for anonymous/local classes.
        inner_name: Simple name of the inner class, e.g. "Inner"; None for
            anonymous classes.
        access_flags: Access flags for the inner class.
    """

    inner_class: str
    outer_class: Optional[str]
    inner_name: Optional[str]
    access_flags: int
634+
635+
627636
class ClassFile:
628637
"""Represents a Java class file."""
629638

@@ -641,6 +650,7 @@ def __init__(self, name: str, super_class: str = "java/lang/Object",
641650
self.cp = ConstantPool()
642651
self.signature: Optional[str] = None
643652
self.annotations: list[AnnotationInfo] = []
653+
self.inner_classes: list[InnerClassInfo] = []
644654

645655
def add_method(self, method: MethodInfo):
646656
self.methods.append(method)
@@ -738,6 +748,16 @@ def to_bytes(self) -> bytes:
738748
for ann in param_anns:
739749
self.cp.add_utf8(ann.type_descriptor)
740750

751+
# Pre-add "InnerClasses" attribute name and class names
752+
if self.inner_classes:
753+
self.cp.add_utf8("InnerClasses")
754+
for ic in self.inner_classes:
755+
self.cp.add_class(ic.inner_class)
756+
if ic.outer_class:
757+
self.cp.add_class(ic.outer_class)
758+
if ic.inner_name:
759+
self.cp.add_utf8(ic.inner_name)
760+
741761
# Magic number
742762
out.extend(struct.pack(">I", self.MAGIC))
743763

@@ -777,14 +797,41 @@ def to_bytes(self) -> bytes:
777797
attr_count += 1
778798
if self.annotations:
779799
attr_count += 1
800+
if self.inner_classes:
801+
attr_count += 1
780802
out.extend(struct.pack(">H", attr_count))
781803
if self.signature:
782804
_write_signature_attribute(self.cp, out, self.signature)
783805
if self.annotations:
784806
write_annotations_attribute(self.cp, out, "RuntimeVisibleAnnotations", self.annotations)
807+
if self.inner_classes:
808+
self._write_inner_classes_attribute(self.cp, out)
785809

786810
return bytes(out)
787811

812+
def _write_inner_classes_attribute(self, cp: ConstantPool, out: bytearray):
813+
"""Write the InnerClasses attribute."""
814+
attr_name_idx = cp.add_utf8("InnerClasses")
815+
out.extend(struct.pack(">H", attr_name_idx))
816+
817+
# Attribute length: 2 (number_of_classes) + 8 * number_of_classes
818+
attr_len = 2 + 8 * len(self.inner_classes)
819+
out.extend(struct.pack(">I", attr_len))
820+
821+
# Number of classes
822+
out.extend(struct.pack(">H", len(self.inner_classes)))
823+
824+
# Write each inner class entry
825+
for ic in self.inner_classes:
826+
inner_class_idx = cp.add_class(ic.inner_class)
827+
outer_class_idx = cp.add_class(ic.outer_class) if ic.outer_class else 0
828+
inner_name_idx = cp.add_utf8(ic.inner_name) if ic.inner_name else 0
829+
830+
out.extend(struct.pack(">H", inner_class_idx))
831+
out.extend(struct.pack(">H", outer_class_idx))
832+
out.extend(struct.pack(">H", inner_name_idx))
833+
out.extend(struct.pack(">H", ic.access_flags))
834+
788835
def write(self, path: str):
789836
with open(path, "wb") as f:
790837
f.write(self.to_bytes())
@@ -896,6 +943,16 @@ def ldc_string(self, value: str):
896943
self.code.extend(struct.pack(">H", idx))
897944
self._push()
898945

946+
def ldc_class(self, class_name: str):
    """Load the class constant (Class<?> object) for ``class_name`` onto the stack.

    Uses LDC with a one-byte operand when the constant-pool index is at most
    255; otherwise emits LDC_W followed by the index as a big-endian
    two-byte operand.
    """
    pool_index = self.cp.add_class(class_name)
    if pool_index > 255:
        # Index does not fit in one byte: wide form with a u2 operand.
        self._emit(Opcode.LDC_W)
        self.code.extend(struct.pack(">H", pool_index))
    else:
        self._emit(Opcode.LDC, pool_index)
    self._push()
955+
899956
def iload(self, slot: int):
900957
if slot <= 3:
901958
self._emit(Opcode.ILOAD_0 + slot)
@@ -1269,6 +1326,13 @@ def instanceof_(self, class_name: str):
12691326
self._pop()
12701327
self._push()
12711328

1329+
def checkcast(self, class_name: str):
    """Emit CHECKCAST for ``class_name`` with its two-byte constant-pool index.

    The instruction pops a reference and pushes the same reference (typed),
    so the tracked stack depth is deliberately left untouched.
    """
    class_index = self.cp.add_class(class_name)
    self._emit(Opcode.CHECKCAST)
    operand = struct.pack(">H", class_index)
    self.code.extend(operand)
1335+
12721336
def newarray(self, atype: int):
12731337
"""Create new primitive array. atype: T_BOOLEAN=4, T_CHAR=5, T_FLOAT=6, T_DOUBLE=7, T_BYTE=8, T_SHORT=9, T_INT=10, T_LONG=11"""
12741338
self._emit(Opcode.NEWARRAY, atype)

pyjopa/cli.py

Lines changed: 161 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,136 @@ def parse_command(args):
2828
sys.exit(1)
2929

3030

31+
def _get_file_dependencies(source_file: Path, parser) -> tuple[str, set[str]]:
32+
"""Extract package and imported types from a source file.
33+
Returns: (package_name, set_of_imported_types)"""
34+
ast = parser.parse_file(str(source_file))
35+
36+
package = ast.package.name if ast.package else ""
37+
imported_types = set()
38+
39+
# Collect imported types (simple names only, from same or imported packages)
40+
for imp in ast.imports:
41+
if not imp.is_static and not imp.is_wildcard:
42+
# Single-type import: java.util.List -> List
43+
simple_name = imp.name.split(".")[-1]
44+
imported_types.add(simple_name)
45+
46+
return package, imported_types
47+
48+
49+
def _topological_sort(files: list[Path], parser) -> list[Path]:
50+
"""Sort files in dependency order using topological sort.
51+
Files that define types used by other files come first."""
52+
from collections import defaultdict, deque
53+
54+
# Map: package.TypeName -> source file
55+
type_to_file = {}
56+
# Map: source file -> set of types it depends on
57+
file_deps = {}
58+
59+
# First pass: identify what types each file defines
60+
for f in files:
61+
ast = parser.parse_file(str(f))
62+
package = ast.package.name if ast.package else ""
63+
64+
for type_decl in ast.types:
65+
type_name = type_decl.name
66+
full_name = f"{package}.{type_name}" if package else type_name
67+
type_to_file[full_name] = f
68+
69+
# Second pass: identify dependencies
70+
for f in files:
71+
ast = parser.parse_file(str(f))
72+
package = ast.package.name if ast.package else ""
73+
deps = set()
74+
75+
# Add dependencies from imports
76+
for imp in ast.imports:
77+
if not imp.is_static and not imp.is_wildcard:
78+
# Check if this import is for a type in our compilation set
79+
imported_type = imp.name
80+
if imported_type in type_to_file:
81+
deps.add(imported_type)
82+
else:
83+
deps.add(imp.name)
84+
85+
# For same-package dependencies, check if type names appear in source
86+
# This is a heuristic for detecting usage without full semantic analysis
87+
source_text = f.read_text()
88+
for full_type_name, type_file in type_to_file.items():
89+
if type_file != f: # Don't depend on ourselves
90+
simple_name = full_type_name.split(".")[-1]
91+
# Check if type name appears in source (crude but effective)
92+
if package:
93+
# Same package?
94+
type_package = ".".join(full_type_name.split(".")[:-1])
95+
if type_package == package and simple_name in source_text:
96+
deps.add(full_type_name)
97+
98+
# Check if types defined in this file extend/implement types in other files
99+
for type_decl in ast.types:
100+
from pyjopa.ast import ClassDeclaration, InterfaceDeclaration, EnumDeclaration
101+
102+
# Get superclass and interfaces
103+
if isinstance(type_decl, ClassDeclaration) and type_decl.extends:
104+
# extends Type -> might be in same package
105+
super_name = type_decl.extends.name if hasattr(type_decl.extends, 'name') else str(type_decl.extends)
106+
# Try same package first
107+
if package:
108+
candidate = f"{package}.{super_name}"
109+
if candidate in type_to_file:
110+
deps.add(candidate)
111+
else:
112+
if super_name in type_to_file:
113+
deps.add(super_name)
114+
115+
if isinstance(type_decl, (ClassDeclaration, EnumDeclaration)):
116+
for iface in type_decl.implements:
117+
iface_name = iface.name if hasattr(iface, 'name') else str(iface)
118+
if package:
119+
candidate = f"{package}.{iface_name}"
120+
if candidate in type_to_file:
121+
deps.add(candidate)
122+
else:
123+
if iface_name in type_to_file:
124+
deps.add(iface_name)
125+
126+
file_deps[f] = deps
127+
128+
# Build adjacency list: file -> files that depend on it
129+
in_degree = {f: 0 for f in files}
130+
adj = defaultdict(list)
131+
132+
for file, deps in file_deps.items():
133+
for dep in deps:
134+
if dep in type_to_file:
135+
dep_file = type_to_file[dep]
136+
if dep_file != file: # Skip self-dependencies
137+
adj[dep_file].append(file)
138+
in_degree[file] += 1
139+
140+
# Topological sort using Kahn's algorithm
141+
queue = deque([f for f in files if in_degree[f] == 0])
142+
result = []
143+
144+
while queue:
145+
current = queue.popleft()
146+
result.append(current)
147+
148+
for neighbor in adj[current]:
149+
in_degree[neighbor] -= 1
150+
if in_degree[neighbor] == 0:
151+
queue.append(neighbor)
152+
153+
# Check for cycles
154+
if len(result) != len(files):
155+
# Circular dependency detected - return original order
156+
return files
157+
158+
return result
159+
160+
31161
def compile_command(args):
32162
"""Compile Java files to .class bytecode."""
33163
from .parser import Java8Parser
@@ -43,15 +173,36 @@ def compile_command(args):
43173
classpath.add_rt_jar()
44174
except FileNotFoundError:
45175
print("Warning: rt.jar not found, method resolution may be limited", file=sys.stderr)
176+
else:
177+
classpath = ClassPath()
46178

47179
output_dir = Path(args.output) if args.output else Path(".")
180+
181+
# Create output directory first
48182
output_dir.mkdir(parents=True, exist_ok=True)
49183

184+
# Add custom classpath entries
185+
if args.classpath:
186+
import os
187+
for entry in args.classpath.split(os.pathsep):
188+
if entry:
189+
classpath.add_path(entry)
190+
191+
# Add output directory to classpath so previously compiled classes can be found
192+
if classpath:
193+
classpath.add_path(str(output_dir.absolute()))
194+
195+
# Sort files in dependency order if multiple files
196+
file_paths = [Path(f) for f in args.files]
197+
if len(file_paths) > 1:
198+
file_paths = _topological_sort(file_paths, parser)
199+
if args.verbose:
200+
print(f"Compilation order: {[str(f) for f in file_paths]}")
201+
50202
total_classes = 0
51-
for source_file in args.files:
52-
path = Path(source_file)
203+
for path in file_paths:
53204
if not path.exists():
54-
print(f"Error: File not found: {source_file}", file=sys.stderr)
205+
print(f"Error: File not found: {path}", file=sys.stderr)
55206
sys.exit(1)
56207

57208
try:
@@ -61,21 +212,22 @@ def compile_command(args):
61212

62213
for name, bytecode in class_files.items():
63214
class_path = output_dir / f"{name}.class"
215+
class_path.parent.mkdir(parents=True, exist_ok=True)
64216
with open(class_path, "wb") as f:
65217
f.write(bytecode)
66218
if args.verbose:
67219
print(f"Wrote {class_path}")
68220
total_classes += 1
69221

70222
except Exception as e:
71-
print(f"Error compiling {source_file}: {e}", file=sys.stderr)
223+
print(f"Error compiling {path}: {e}", file=sys.stderr)
72224
sys.exit(1)
73225

74226
if classpath:
75227
classpath.close()
76228

77229
if not args.quiet:
78-
print(f"Compiled {len(args.files)} file(s) to {total_classes} class(es)")
230+
print(f"Compiled {len(file_paths)} file(s) to {total_classes} class(es)")
79231

80232

81233
def main():
@@ -113,6 +265,10 @@ def main():
113265
"-o", "--output",
114266
help="Output directory for .class files (default: current directory)",
115267
)
268+
compile_parser.add_argument(
269+
"-cp", "--classpath",
270+
help="Additional classpath entries (colon-separated paths to .jar files or directories)",
271+
)
116272
compile_parser.add_argument(
117273
"--no-rt",
118274
action="store_true",

pyjopa/codegen/boxing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,30 @@ def emit_conversion(self, source_type: JType, target_type: JType, builder: Bytec
8888
return self.emit_boxing(source_type, builder)
8989
elif self.needs_unboxing(source_type, target_type):
9090
return self.emit_unboxing(source_type, builder)
91+
# Box primitive to reference type if needed
92+
elif isinstance(source_type, PrimitiveJType) and isinstance(target_type, ClassJType):
93+
# Box the primitive to its wrapper
94+
boxed_type = self.emit_boxing(source_type, builder)
95+
# Then checkcast to target if needed
96+
if boxed_type.internal_name() != target_type.internal_name():
97+
builder.checkcast(target_type.internal_name())
98+
return target_type
99+
return boxed_type
100+
elif self._needs_checkcast(source_type, target_type):
101+
builder.checkcast(target_type.internal_name())
102+
return target_type
91103
return source_type
92104

105+
def _needs_checkcast(self, source_type: JType, target_type: JType) -> bool:
106+
"""Check if a checkcast is needed from source to target."""
107+
# Skip if same type
108+
if source_type == target_type:
109+
return False
110+
# Skip primitives
111+
if not isinstance(source_type, ClassJType) or not isinstance(target_type, ClassJType):
112+
return False
113+
# Need checkcast if source is Object or a supertype of target
114+
# For simplicity, if they're different reference types, we need checkcast
115+
# (proper implementation would check the class hierarchy)
116+
return True
117+

0 commit comments

Comments
 (0)