CH09: whole program compiler driver #16

Open · wants to merge 24 commits into main

Conversation

@esc (Member) commented on May 12, 2025

Next attempt at the whole program compiler.

For the program at:

https://gist.github.com/sklam/13c0646a4e6b2d401b731835629b1be4

it currently produces the following output:

 💣 zsh» python wpc.py llm.py
{'softmax': SymbolInfo(name='softmax',
                       ast=<ast.FunctionDef object at 0x104e9ff10>,
                       calls=[<ast.Attribute object at 0x104e9fc50>,
                              <ast.Attribute object at 0x104e9fb10>,
                              <ast.Attribute object at 0x104e9f710>]),
 'scaled_dot_product_attention': SymbolInfo(name='scaled_dot_product_attention',
                                            ast=<ast.FunctionDef object at 0x104e9f310>,
                                            calls=[<ast.Attribute object at 0x104e9ed90>,
                                                   <ast.Attribute object at 0x104e9e550>,
                                                   <ast.Attribute object at 0x104e9dd10>,
                                                   <ast.Attribute object at 0x104e9ce50>,
                                                   <ast.Attribute object at 0x104e9c810>,
                                                   <ast.Name object at 0x104e9c450>,
                                                   <ast.Attribute object at 0x104e9bf50>,
                                                   <ast.Attribute object at 0x104e9bb90>]),
 'MultiHeadAttention.__init__': SymbolInfo(name='MultiHeadAttention.__init__',
                                           ast=<ast.FunctionDef object at 0x104e9b610>,
                                           calls=[]),
 'MultiHeadAttention.split_heads': SymbolInfo(name='MultiHeadAttention.split_heads',
                                              ast=<ast.FunctionDef object at 0x104e9ac10>,
                                              calls=[<ast.Attribute object at 0x104e99f90>,
                                                     <ast.Attribute object at 0x104e99a50>]),
 'MultiHeadAttention.combine_heads': SymbolInfo(name='MultiHeadAttention.combine_heads',
                                                ast=<ast.FunctionDef object at 0x104e99510>,
                                                calls=[<ast.Attribute object at 0x104e98750>]),
 'MultiHeadAttention.forward': SymbolInfo(name='MultiHeadAttention.forward',
                                          ast=<ast.FunctionDef object at 0x104e8fd90>,
                                          calls=[<ast.Attribute object at 0x104e8f790>,
                                                 <ast.Attribute object at 0x104e8f3d0>,
                                                 <ast.Attribute object at 0x104e8f010>,
                                                 <ast.Name object at 0x104e8d890>,
                                                 <ast.Attribute object at 0x104e8d550>]),
 'FeedForwardNetwork.__init__': SymbolInfo(name='FeedForwardNetwork.__init__',
                                           ast=<ast.FunctionDef object at 0x104e8c750>,
                                           calls=[<ast.Attribute object at 0x104e8dd90>,
                                                  <ast.Attribute object at 0x104e8e2d0>]),
 'FeedForwardNetwork.forward': SymbolInfo(name='FeedForwardNetwork.forward',
                                          ast=<ast.FunctionDef object at 0x104e8e690>,
                                          calls=[<ast.Attribute object at 0x104ea9290>,
                                                 <ast.Attribute object at 0x104ea8f10>]),
 'TransformerLayer.__init__': SymbolInfo(name='TransformerLayer.__init__',
                                         ast=<ast.FunctionDef object at 0x104ea8750>,
                                         calls=[<ast.Name object at 0x104ea9910>,
                                                <ast.Name object at 0x104ea9d50>]),
 'TransformerLayer.forward': SymbolInfo(name='TransformerLayer.forward',
                                        ast=<ast.FunctionDef object at 0x104eaa1d0>,
                                        calls=[<ast.Attribute object at 0x104eaa810>,
                                               <ast.Attribute object at 0x104eaad50>])}
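The SymbolInfo entries above pair each function's qualified name with its AST and the call expressions found inside it. A minimal sketch of how such a table could be collected (the names mirror the output, but the actual wpc.py implementation may differ):

```python
# Hypothetical sketch of collecting SymbolInfo entries like the ones
# printed above; modeled on the output, not on the real wpc.py code.
import ast
from dataclasses import dataclass, field

@dataclass
class SymbolInfo:
    name: str
    ast: ast.FunctionDef
    calls: list = field(default_factory=list)

def collect_symbols(source: str) -> dict[str, SymbolInfo]:
    """Map qualified function names to their AST and call sites."""
    symbols = {}
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            # Methods get qualified as "ClassName.method".
            prefix, members = node.name + ".", node.body
        elif isinstance(node, ast.Module):
            prefix, members = "", node.body
        else:
            continue
        for item in members:
            if isinstance(item, ast.FunctionDef):
                # Record the callee expression of every call in the body.
                calls = [n.func for n in ast.walk(item)
                         if isinstance(n, ast.Call)]
                symbols[prefix + item.name] = SymbolInfo(
                    prefix + item.name, item, calls)
    return symbols
```

The callee expressions are `ast.Name` nodes for plain calls like `softmax(...)` and `ast.Attribute` nodes for dotted calls like `np.matmul(...)`, which matches the mix seen in the dump.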

@esc (Member, Author) commented on May 14, 2025

The next iteration of the code is ready. It now generates qualified names for the calls, and it also captures and prints module-level (global) calls, so that they can be used to determine the call graph and the compilation order.

 💣 zsh» ipython -i  wpc.py llm.py
Python 3.13.2 | packaged by Anaconda, Inc. | (main, Feb  6 2025, 12:55:35) [Clang 14.0.6 ]
Type 'copyright', 'credits' or 'license' for more information
IPython 9.1.0 -- An enhanced Interactive Python. Type '?' for help.
Tip: You can use `files = !ls *.png`
{'softmax': SymbolInfo(name='softmax',
                       ast=<ast.FunctionDef object at 0x103262410>,
                       calls=[(<ast.Attribute object at 0x103623cd0>, 'np.exp'),
                              (<ast.Attribute object at 0x103623b90>, 'np.max'),
                              (<ast.Attribute object at 0x1036238d0>,
                               'np.sum')]),
 'scaled_dot_product_attention': SymbolInfo(name='scaled_dot_product_attention',
                                            ast=<ast.FunctionDef object at 0x1036234d0>,
                                            calls=[(<ast.Attribute object at 0x103622f10>,
                                                    'query.reshape'),
                                                   (<ast.Attribute object at 0x1036226d0>,
                                                    'key.reshape'),
                                                   (<ast.Attribute object at 0x103621e90>,
                                                    'value.reshape'),
                                                   (<ast.Attribute object at 0x103621250>,
                                                    'np.matmul'),
                                                   (<ast.Attribute object at 0x103620050>,
                                                    'np.sqrt'),
                                                   (<ast.Name object at 0x103620690>,
                                                    'softmax'),
                                                   (<ast.Attribute object at 0x1032e5c10>,
                                                    'np.matmul'),
                                                   (<ast.Attribute object at 0x103610550>,
                                                    'context.reshape')]),
 'MultiHeadAttention.__init__': SymbolInfo(name='MultiHeadAttention.__init__',
                                           ast=<ast.FunctionDef object at 0x103610a90>,
                                           calls=[]),
 'MultiHeadAttention.split_heads': SymbolInfo(name='MultiHeadAttention.split_heads',
                                              ast=<ast.FunctionDef object at 0x103611550>,
                                              calls=[(<ast.Attribute object at 0x1036121d0>,
                                                      'x.reshape'),
                                                     (<ast.Attribute object at 0x103612790>,
                                                      'x.transpose')]),
 'MultiHeadAttention.combine_heads': SymbolInfo(name='MultiHeadAttention.combine_heads',
                                                ast=<ast.FunctionDef object at 0x103612cd0>,
                                                calls=[(<ast.Attribute object at 0x103613950>,
                                                        'x.transpose.reshape')]),
 'MultiHeadAttention.forward': SymbolInfo(name='MultiHeadAttention.forward',
                                          ast=<ast.FunctionDef object at 0x103617e90>,
                                          calls=[(<ast.Attribute object at 0x103617850>,
                                                  'MultiHeadAttention.split_heads'),
                                                 (<ast.Attribute object at 0x103617490>,
                                                  'MultiHeadAttention.split_heads'),
                                                 (<ast.Attribute object at 0x103617050>,
                                                  'MultiHeadAttention.split_heads'),
                                                 (<ast.Name object at 0x103616b50>,
                                                  'scaled_dot_product_attention'),
                                                 (<ast.Attribute object at 0x103616790>,
                                                  'MultiHeadAttention.combine_heads')]),
 'FeedForwardNetwork.__init__': SymbolInfo(name='FeedForwardNetwork.__init__',
                                           ast=<ast.FunctionDef object at 0x103616150>,
                                           calls=[(<ast.Attribute object at 0x103615a50>,
                                                   'np.random.randn'),
                                                  (<ast.Attribute object at 0x103615510>,
                                                   'np.random.randn')]),
 'FeedForwardNetwork.forward': SymbolInfo(name='FeedForwardNetwork.forward',
                                          ast=<ast.FunctionDef object at 0x103615110>,
                                          calls=[(<ast.Attribute object at 0x103614d10>,
                                                  'np.matmul'),
                                                 (<ast.Attribute object at 0x103614990>,
                                                  'np.matmul')]),
 'TransformerLayer.__init__': SymbolInfo(name='TransformerLayer.__init__',
                                         ast=<ast.FunctionDef object at 0x103618910>,
                                         calls=[(<ast.Name object at 0x1036441d0>,
                                                 'MultiHeadAttention.__init__'),
                                                (<ast.Name object at 0x103645050>,
                                                 'FeedForwardNetwork.__init__')]),
 'TransformerLayer.forward': SymbolInfo(name='TransformerLayer.forward',
                                        ast=<ast.FunctionDef object at 0x1036454d0>,
                                        calls=[(<ast.Attribute object at 0x103645b10>,
                                                'TransformerLayer.self_attn.forward'),
                                               (<ast.Attribute object at 0x103646050>,
                                                'TransformerLayer.feed_forward.forward')])}
[(<ast.Attribute object at 0x103646690>, 'np.random.seed'),
 (<ast.Name object at 0x103647050>, 'TransformerLayer.__init__'),
 (<ast.Attribute object at 0x103647510>, 'np.random.randn'),
 (<ast.Attribute object at 0x103647a50>, 'transformer_layer.forward'),
 (<ast.Name object at 0x103647d10>, 'print'),
 (<ast.Name object at 0x1036581d0>, 'print')]

For now I have chosen to include the AST node for each ast.Attribute in a call. I'm not yet sure whether that will turn out to be useful.
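The dotted names paired with each node in the output (e.g. 'np.exp', 'x.transpose.reshape') can be recovered by flattening the attribute chain of the callee expression. A small sketch of one way to do it:

```python
# Sketch: flatten an ast.Attribute chain into the dotted name shown
# in the output above; not necessarily how wpc.py does it.
import ast

def dotted_name(node: ast.expr) -> str:
    """Turn np.random.randn / x.transpose(...).reshape into a dotted string."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
    elif isinstance(node, ast.Call):
        # Chained call like x.transpose(...).reshape(...) flattens to
        # "x.transpose.reshape", as in the MultiHeadAttention.combine_heads entry.
        return dotted_name(node.func) + "." + ".".join(reversed(parts))
    return ".".join(reversed(parts))

expr = ast.parse("np.random.randn(3, 4)", mode="eval").body
print(dotted_name(expr.func))  # -> np.random.randn
```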

Next step will be to use networkx to generate the call graphs.
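What that networkx step would compute can already be previewed with the stdlib graphlib: feed each function's callees in as dependencies, and a topological sort yields a compilation order in which callees come before their callers. The edges below are a hand-picked subset of the qualified names printed above; the real driver would build them from the full call table.

```python
# Preview of the planned call-graph step using the stdlib graphlib;
# the PR intends to use networkx for the real implementation.
from graphlib import TopologicalSorter

# function -> callees (a subset of the qualified names from the output)
calls = {
    "softmax": [],
    "scaled_dot_product_attention": ["softmax"],
    "MultiHeadAttention.forward": ["MultiHeadAttention.split_heads",
                                   "scaled_dot_product_attention",
                                   "MultiHeadAttention.combine_heads"],
    "TransformerLayer.forward": ["MultiHeadAttention.forward",
                                 "FeedForwardNetwork.forward"],
}

# static_order() emits dependencies first, i.e. callees before callers.
order = list(TopologicalSorter(calls).static_order())
print(order)
```

With these edges, `softmax` is compiled before `scaled_dot_product_attention`, which is compiled before `MultiHeadAttention.forward`, and so on up to `TransformerLayer.forward`.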

@esc force-pushed the ch09_whole_prog_compiler branch from 9333be3 to bf43da6 on May 23, 2025 at 08:41
@esc (Member, Author) commented on May 23, 2025

Hi @sklam @seibert,

This chapter has now been converted to the Jupyter format used by the other chapters. We can also now connect call graphs across class attributes by using the provided type annotations. I have tested both the ipynb and the html variants.
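The annotation-based linking can be sketched roughly as follows: an annotated attribute such as `self.self_attn: MultiHeadAttention` in `__init__` lets a later call like `self.self_attn.forward(...)` be resolved to `MultiHeadAttention.forward`, as seen in the TransformerLayer.forward entry above. The mechanism in the chapter may differ; this is an illustrative minimum:

```python
# Hypothetical sketch: resolve self.<attr>.<method> calls through the
# type annotations on the attributes assigned in __init__.
import ast

source = """
class TransformerLayer:
    def __init__(self):
        self.self_attn: MultiHeadAttention = MultiHeadAttention()
"""

attr_types = {}
tree = ast.parse(source)
for node in ast.walk(tree):
    if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Attribute):
        target = node.target
        if isinstance(target.value, ast.Name) and target.value.id == "self":
            # Record "self.<attr>" -> annotated class name.
            attr_types[target.attr] = ast.unparse(node.annotation)

def resolve(attr: str, method: str) -> str:
    """Map self.<attr>.<method> to Class.<method> via the annotation table."""
    return f"{attr_types[attr]}.{method}"

print(resolve("self_attn", "forward"))  # -> MultiHeadAttention.forward
```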

Also, this includes a modification to the Makefile such that no *.ipynb files are built for ancillary files, i.e. everything whose name doesn't start with ch or demo: 3037f3d

Technically, this does not yet "work" in the sense that an actual program can be compiled. Completing that capability still waits on:

a) An interface to the new jit compiler that can take an AST. The SymbolInfo objects of this driver include the raw AST, which could be handed off to a jit_compile function, for example.

b) A decision on what to do with the np.* calls. We either need our own implementations, i.e. NumPy rewritten in mcl, or a way to link against the existing NumPy implementation, for example by way of PIXIE or similar.

Essentially, some more blanks must be filled in before this becomes a "working" compiler, but the overall structure for analysing the source code and extracting the information required to compile the functions in a Python module is in place.
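As a placeholder for point a), the stdlib can already materialise a FunctionDef stored in a SymbolInfo: wrap it in a Module, compile, and exec. A real jit_compile would lower the same node to machine code instead; this is only a sketch of the hand-off shape.

```python
# Sketch: turn a FunctionDef from a SymbolInfo back into a callable with
# the stdlib, standing in for the planned AST-taking jit_compile interface.
import ast

# A stand-in FunctionDef, as it would come out of a SymbolInfo.ast field.
func_def = ast.parse("def softmax_stub(x):\n    return x").body[0]

module = ast.Module(body=[func_def], type_ignores=[])
ast.fix_missing_locations(module)
namespace = {}
exec(compile(module, filename="<wpc>", mode="exec"), namespace)

softmax_stub = namespace["softmax_stub"]
print(softmax_stub(42))  # -> 42
```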
