
Load target address earlier for tail call interpreter #129976

Open

@aconz2

My working branch is here: main...aconz2:cpython:aconz2/early-tail-call-load

I saw the recent merge of the tail call interpreter (#128718), very nice! I have played with this style of interpreter before, and one thing that comes up is deciding when to compute the target address. As it stands, the current interpreter does it in DISPATCH() by doing:

DEF_TARGET(foo) {
    // ...
    TAIL return INSTRUCTION_TABLE[opcode](ARGS);
}
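
To make the two styles easy to play with outside CPython, here is a minimal standalone sketch of this one. It is not CPython code: the two opcodes, ins_t, table, and the op_* handlers are all made up, and clang's __attribute__((musttail)) stands in for TAIL.

#include <stdio.h>

typedef unsigned char code_t;
typedef int (*ins_t)(const code_t *ip, int acc);

static int op_add(const code_t *ip, int acc);
static int op_halt(const code_t *ip, int acc);

/* opcode 0 = HALT, opcode 1 = ADD imm; every instruction is 2 bytes */
static const ins_t table[2] = { op_halt, op_add };

static int op_add(const code_t *ip, int acc) {
    acc += ip[1];
    ip += 2;
    /* DISPATCH()-style: the target address is computed right at the
       jump, so the indirect branch depends on a table load that in
       turn depends on the freshly advanced ip. */
    __attribute__((musttail)) return table[ip[0]](ip, acc);
}

static int op_halt(const code_t *ip, int acc) {
    (void)ip;
    return acc;
}

int main(void) {
    const code_t code[] = { 1, 5, 1, 7, 0, 0 };  /* ADD 5; ADD 7; HALT */
    printf("%d\n", table[code[0]](code, 0));     /* prints 12 */
    return 0;
}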

In the actual interpreter, this results in assembly like:

0000000000289580 <_TAIL_CALL_GET_LEN>:
  289580: 50                           	push	rax
  289581: 89 fb                        	mov	ebx, edi
  289583: 4d 89 7c 24 38               	mov	qword ptr [r12 + 0x38], r15
  289588: 49 83 c7 02                  	add	r15, 0x2
  28958c: 49 8b 7d f8                  	mov	rdi, qword ptr [r13 - 0x8]
  289590: 4d 89 6c 24 40               	mov	qword ptr [r12 + 0x40], r13
  289595: e8 f6 fc ea ff               	call	0x139290 <PyObject_Size>
  28959a: 4d 8b 6c 24 40               	mov	r13, qword ptr [r12 + 0x40]
  28959f: 49 c7 44 24 40 00 00 00 00   	mov	qword ptr [r12 + 0x40], 0x0
  2895a8: 48 85 c0                     	test	rax, rax
  2895ab: 78 2b                        	js	0x2895d8 <_TAIL_CALL_GET_LEN+0x58>
  2895ad: 48 89 c7                     	mov	rdi, rax
  2895b0: e8 eb 54 ec ff               	call	0x14eaa0 <PyLong_FromSsize_t>
  2895b5: 48 85 c0                     	test	rax, rax
  2895b8: 74 1e                        	je	0x2895d8 <_TAIL_CALL_GET_LEN+0x58>
  2895ba: 49 89 45 00                  	mov	qword ptr [r13], rax
  2895be: 49 83 c5 08                  	add	r13, 0x8
  2895c2: 41 0f b7 3f                  	movzx	edi, word ptr [r15]  #<-- Load next_instr
  2895c6: 40 0f b6 c7                  	movzx	eax, dil             #<-- grab opcode
  2895ca: c1 ef 08                     	shr	edi, 0x8
  2895cd: 48 8d 0d 7c 50 1f 00         	lea	rcx, [rip + 0x1f507c]   # 0x47e650 <INSTRUCTION_TABLE>
  2895d4: 5a                           	pop	rdx
  2895d5: ff 24 c1                     	jmp	qword ptr [rcx + 8*rax] #<-- jmp with addr calculation
  2895d8: 89 df                        	mov	edi, ebx
  2895da: 58                           	pop	rax
  2895db: e9 30 dc ff ff               	jmp	0x287210 <_TAIL_CALL_error>

where we jmp to a computed address that depends on the lea and a memory load just a few instructions prior.

Another method looks like:

DEF_TARGET(foo) {
  // ...
  tail_funcptr next_f = INSTRUCTION_TABLE[next_opcode];
  // ...
  TAIL return next_f(ARGS);
}

where we try to get the compiler to compute the target earlier, turning the dispatch into a jmp reg. We have to pay special attention to places where next_instr is modified and reload the pointer there (though hopefully the optimizer will just defer the calculation to the latest possible point).
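
Continuing the standalone sketch from above (same hypothetical names, still not CPython code), the early-load variant of op_add would be; swap it into table[1] to try it:

static int op_add_early(const code_t *ip, int acc) {
    /* Load the next handler as soon as the next opcode is known, so
       the table load can issue before the rest of the handler's work. */
    ins_t next_f = table[ip[2]];
    acc += ip[1];
    ip += 2;
    __attribute__((musttail)) return next_f(ip, acc);
}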

Back in CPython: in my branch, I was able to get this working well enough to see what asm it would generate. For _TAIL_CALL_GET_LEN, the sequence now looks like:

00000000002896b0 <_TAIL_CALL_GET_LEN>:
  2896b0: 55                           	push	rbp
  2896b1: 89 fb                        	mov	ebx, edi
  2896b3: 4d 89 7c 24 38               	mov	qword ptr [r12 + 0x38], r15
  2896b8: 41 0f b6 47 02               	movzx	eax, byte ptr [r15 + 0x2]  #<-- Load next instr opcode
  2896bd: 49 83 c7 02                  	add	r15, 0x2
  2896c1: 48 8d 0d 88 5f 1f 00         	lea	rcx, [rip + 0x1f5f88]   # 0x47f650 <INSTRUCTION_TABLE>
  2896c8: 48 8b 2c c1                  	mov	rbp, qword ptr [rcx + 8*rax]  #<-- load next target addr
  2896cc: 49 8b 7d f8                  	mov	rdi, qword ptr [r13 - 0x8]
  2896d0: 4d 89 6c 24 40               	mov	qword ptr [r12 + 0x40], r13
  2896d5: e8 b6 fb ea ff               	call	0x139290 <PyObject_Size>
  2896da: 4d 8b 6c 24 40               	mov	r13, qword ptr [r12 + 0x40]
  2896df: 49 c7 44 24 40 00 00 00 00   	mov	qword ptr [r12 + 0x40], 0x0
  2896e8: 48 85 c0                     	test	rax, rax
  2896eb: 78 20                        	js	0x28970d <_TAIL_CALL_GET_LEN+0x5d>
  2896ed: 48 89 c7                     	mov	rdi, rax
  2896f0: e8 ab 53 ec ff               	call	0x14eaa0 <PyLong_FromSsize_t>
  2896f5: 48 85 c0                     	test	rax, rax
  2896f8: 74 13                        	je	0x28970d <_TAIL_CALL_GET_LEN+0x5d>
  2896fa: 49 89 45 00                  	mov	qword ptr [r13], rax
  2896fe: 49 83 c5 08                  	add	r13, 0x8
  289702: 41 0f b6 7f 01               	movzx	edi, byte ptr [r15 + 0x1]
  289707: 48 89 e8                     	mov	rax, rbp                  #<-- register rename
  28970a: 5d                           	pop	rbp
  28970b: ff e0                        	jmp	rax                       #<-- jmp to target addr
  28970d: 89 df                        	mov	edi, ebx
  28970f: 5d                           	pop	rbp
  289710: e9 fb da ff ff               	jmp	0x287210 <_TAIL_CALL_error>
  289715: 66 66 2e 0f 1f 84 00 00 00 00 00     	nop	word ptr cs:[rax + rax]

Specifically, rbp is callee-saved in the SysV ABI, so the loaded target survives the calls to PyObject_Size and PyLong_FromSsize_t, and in this case holding it there doesn't seem to add any register pressure. But I haven't looked extensively, so that may not be universally true.

My theory is that this could be better for the CPU: in this example, by the time we return from PyLong_FromSsize_t the jump target is already sitting in a register, so the indirect branch can potentially be resolved (and the target prefetched) earlier.

I have not benchmarked anything yet.

Looking at another example, _TAIL_CALL_BINARY_OP_SUBSCR_GETITEM does a LOAD_IP() towards the end, so we have to reload the target address. The optimizer does seem smart enough to avoid a double load, but here that just produces an almost identical ending:

# main
  28eba5: 41 0f b7 3f                   movzx   edi, word ptr [r15]
  28eba9: 40 0f b6 cf                   movzx   ecx, dil
  28ebad: c1 ef 08                      shr     edi, 0x8
  28ebb0: 48 8d 15 99 fa 1e 00          lea     rdx, [rip + 0x1efa99]   # 0x47e650 <INSTRUCTION_TABLE>
  28ebb7: 49 89 c4                      mov     r12, rax
  28ebba: ff 24 ca                      jmp     qword ptr [rdx + 8*rcx]

# this branch
  28f185: 41 0f b6 0f                   movzx   ecx, byte ptr [r15]
  28f189: 48 8d 15 c0 04 1f 00          lea     rdx, [rip + 0x1f04c0]   # 0x47f650 <INSTRUCTION_TABLE>
  28f190: 41 0f b6 7f 01                movzx   edi, byte ptr [r15 + 0x1]
  28f195: 49 89 c4                      mov     r12, rax
  28f198: ff 24 ca                      jmp     qword ptr [rdx + 8*rcx]
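
In terms of the toy sketch above, a handler that rewrites ip mid-body (the analogue of LOAD_IP()) has to refresh the early-loaded pointer. A hypothetical conditional backward jump, again not CPython code:

static int op_jump_back_if(const code_t *ip, int acc) {
    ins_t next_f = table[ip[2]];   /* early load for the fall-through case */
    if (acc != 0) {
        ip += (signed char)ip[1];  /* ip changed, like LOAD_IP() ... */
        next_f = table[ip[0]];     /* ... so the early load is stale: reload */
    } else {
        ip += 2;
    }
    __attribute__((musttail)) return next_f(ip, acc);
}

As in the disassembly above, the optimizer can usually sink the loads into the branches so that neither path loads twice.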

I did this a bit haphazardly, through a combination of modifying macros and making manual changes to anything that assigns to next_instr, plus a few special cases like exit_unwind that didn't fit the pattern. I could clean it up with some direction.

One super naive metric is

# should be a tab after jmp
llvm-objdump --x86-asm-syntax=intel -D python | grep 'jmp   r' | wc -l

which is 916 for this modification and 731 originally, so 185 more places where we jmp to a register instead of a computed address.
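
A complementary count, if the change is doing what I think, is the number of memory-indirect jumps, which should drop by roughly the same amount (I haven't verified this beyond eyeballing):

# again, should be a tab after jmp
llvm-objdump --x86-asm-syntax=intel -D python | grep 'jmp   qword ptr' | wc -l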


Labels: interpreter-core (Objects, Python, Grammar, and Parser dirs), performance (Performance or resource usage)
