@@ -58,7 +58,6 @@ def fix_f2c_input(f2c_input_path: str) -> None:
58
58
lines = f .readlines ()
59
59
new_lines = []
60
60
lines = char1_args_to_int (lines )
61
-
62
61
for line in lines :
63
62
line = fix_string_args (line )
64
63
@@ -91,6 +90,27 @@ def fix_f2c_input(f2c_input_path: str) -> None:
91
90
92
91
new_lines .append (line )
93
92
93
+ # We assume one function per file, since this seems quite consistently true.
94
+ # Figure out if it's supposed to be recursive. f2c can't handle the
95
+ # recursive keyword so we need to remove it and add a comment so we can tell
96
+ # it was supposed to be recursive. In fix_f2c_output, we'll remove the
97
+ # static keywords from all the variables.
98
+ is_recursive = False
99
+ for idx , line in enumerate (new_lines ):
100
+ if "recursive" in line :
101
+ is_recursive = True
102
+ new_lines [idx ] = new_lines [idx ].replace ("recursive" , "" )
103
+ if line .strip () == "recursive" :
104
+ # If whole line was recursive, then the next line starts with an
105
+ # asterisk to indicate line continuation. Fortran is very
106
+ # persnickity so we have to remove the line continuation. Make
107
+ # sure to replace the * with a space because the number of
108
+ # pre-code characters is significant...
109
+ new_lines [idx + 1 ] = new_lines [idx + 1 ].replace ("*" , " " )
110
+ break
111
+ if is_recursive :
112
+ new_lines .insert (0 , "C .. xxISRECURSIVExx ..\n " )
113
+
94
114
with open (f2c_input_path , "w" ) as f :
95
115
f .writelines (new_lines )
96
116
@@ -194,9 +214,12 @@ def fix_f2c_output(f2c_output_path: str) -> str | None:
194
214
90 and Fortran 95.
195
215
"""
196
216
f2c_output = Path (f2c_output_path )
197
-
198
217
with open (f2c_output ) as f :
199
218
lines = f .readlines ()
219
+
220
+ is_recursive = any ("xxISRECURSIVExx" in line for line in lines )
221
+
222
+ lines = list (regroup_lines (lines ))
200
223
if "id_dist" in f2c_output_path :
201
224
# Fix implicit casts in id_dist.
202
225
lines = fix_inconsistent_decls (lines )
@@ -270,16 +293,45 @@ def fix_line(line: str) -> str:
270
293
if "eupd.c" in str (f2c_output ):
271
294
# put signature on a single line to make replacement more
272
295
# straightforward
273
- regrouped_lines = regroup_lines (lines )
274
296
lines = [
275
- re .sub (r",?\s*ftnlen\s*(howmny_len|bmat_len)" , "" , line )
276
- for line in regrouped_lines
297
+ re .sub (r",?\s*ftnlen\s*(howmny_len|bmat_len)" , "" , line ) for line in lines
277
298
]
278
299
279
300
# Fix signature of c_abs to match the OpenBLAS one
280
301
if "REVCOM.c" in str (f2c_output ):
281
302
lines = [line .replace ("double c_abs(" , "float c_abs(" ) for line in lines ]
282
303
304
+ # Non recursive functions declare all their locals as static, ones marked
305
+ # "recursive" need them to be proper local variables. For recursive
306
+ # functions we'll replace them.
307
+ def fix_static (line : str ) -> str :
308
+ static_prefix = " static"
309
+ if not line .startswith (static_prefix ):
310
+ return line
311
+ line = line .removeprefix (static_prefix ).strip ()
312
+ # If line contains a { or " there's already an initializer and we'll get
313
+ # confused. When there's an initializer there's also only one variable
314
+ # so we don't need to do anything.
315
+ if "{" in line or '"' in line :
316
+ return line + "\n "
317
+ # split off type
318
+ type , rest = line .split (" " , 1 )
319
+ # Since there is no { or " each comma separates a variable name
320
+ names = rest [:- 1 ].split ("," )
321
+ init_names = []
322
+ for name in names :
323
+ if "=" in name :
324
+ # There's already an initializer
325
+ init_names .append (name )
326
+ else :
327
+ # = {0} initializes all types to all 0s.
328
+ init_names .append (name + " = {0}" )
329
+ joined_names = "," .join (init_names )
330
+ return f" { type } { joined_names } ;\n "
331
+
332
+ if is_recursive :
333
+ lines = list (map (fix_static , lines ))
334
+
283
335
with open (f2c_output , "w" ) as f :
284
336
f .writelines (lines )
285
337
@@ -333,15 +385,18 @@ def regroup_lines(lines: Iterable[str]) -> Iterator[str]:
333
385
... static doublereal psum[52];
334
386
... extern /* Subroutine */ int dqelg_(integer *, doublereal *, doublereal *,
335
387
... doublereal *, doublereal *, integer *);
336
- ... '''))))
388
+ ... '''))).strip() )
337
389
/* Subroutine */ int clanhfwrp_(real *ret, char *norm, char *transr, char * uplo, integer *n, complex *a, real *work, ftnlen norm_len, ftnlen transr_len, ftnlen uplo_len){
338
390
static doublereal psum[52];
339
391
extern /* Subroutine */ int dqelg_(integer *, doublereal *, doublereal *, doublereal *, doublereal *, integer *);
340
-
341
392
"""
342
393
line_iter = iter (lines )
343
394
for line in line_iter :
344
- if "/* Subroutine */" not in line :
395
+ if "/* Subroutine */" not in line and "static" not in line :
396
+ yield line
397
+ continue
398
+
399
+ if '"' in line :
345
400
yield line
346
401
continue
347
402
@@ -360,7 +415,7 @@ def regroup_lines(lines: Iterable[str]) -> Iterator[str]:
360
415
if is_definition :
361
416
yield joined_line
362
417
else :
363
- yield from (x + ";" for x in joined_line .split (";" )[:- 1 ])
418
+ yield from (x + ";\n " for x in joined_line .split (";" )[:- 1 ])
364
419
365
420
366
421
def fix_inconsistent_decls (lines : list [str ]) -> list [str ]:
@@ -410,7 +465,6 @@ def fix_inconsistent_decls(lines: list[str]) -> list[str]:
410
465
}
411
466
"""
412
467
func_types = {}
413
- lines = list (regroup_lines (lines ))
414
468
for line in lines :
415
469
if not line .startswith ("/* Subroutine */" ):
416
470
continue
0 commit comments