February 3rd, 2025

Decorator JITs: Python as a DSL

Eli Bendersky's article explores Just-In-Time compilation in Python, focusing on JIT decorators in machine learning libraries like JAX and Triton, highlighting various implementation strategies for enhanced performance.


Eli Bendersky's article discusses the use of Just-In-Time (JIT) compilation in Python, particularly in the context of machine learning libraries like JAX and Triton. The JIT decorator pattern allows Python functions to be transformed into a Domain Specific Language (DSL) that can be compiled and executed more efficiently. The article outlines various implementation strategies for JIT decorators, including AST-based JIT, bytecode-based JIT, and tracing-based JIT. Each method involves translating Python functions into an intermediate representation (IR) before converting them to LLVM IR for execution. The AST-based approach uses Python's Abstract Syntax Tree to create an expression representation, while the bytecode-based method compiles functions to bytecode first. The tracing method, on the other hand, captures the execution of the function to generate the necessary expressions. Bendersky emphasizes the flexibility and power of these techniques, which allow Python to serve as a meta-language for describing computations, ultimately enabling more efficient execution of machine learning algorithms.

- JIT compilation enhances the performance of Python functions, especially in machine learning.

- Different JIT strategies include AST-based, bytecode-based, and tracing-based approaches.

- The transformation of Python functions into an intermediate representation is crucial for JIT execution.

- Libraries like JAX and Triton utilize JIT decorators to optimize performance while maintaining Python's usability.

- Understanding these strategies can help developers leverage Python's capabilities for high-performance computing.
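To make the tracing strategy concrete, here is a minimal sketch of a tracing-based JIT decorator: operator overloading records the traced function's operations into a tiny expression IR, which is then "compiled" — here to a plain Python lambda via `eval`, as a toy stand-in for the LLVM backend the article describes:

```python
import functools

class Node:
    """Base class for IR nodes; overloads arithmetic to build expressions."""
    def __add__(self, other): return Op("+", self, wrap(other))
    def __radd__(self, other): return Op("+", wrap(other), self)
    def __mul__(self, other): return Op("*", self, wrap(other))
    def __rmul__(self, other): return Op("*", wrap(other), self)

class Arg(Node):
    def __init__(self, name): self.name = name
    def emit(self): return self.name

class Const(Node):
    def __init__(self, value): self.value = value
    def emit(self): return repr(self.value)

class Op(Node):
    def __init__(self, op, left, right): self.op, self.left, self.right = op, left, right
    def emit(self): return f"({self.left.emit()} {self.op} {self.right.emit()})"

def wrap(x):
    return x if isinstance(x, Node) else Const(x)

def jit(fn):
    """Trace fn once with symbolic args, then compile the recorded IR.
    Toy version: positional args only, + and * only."""
    names = fn.__code__.co_varnames[:fn.__code__.co_argcount]
    ir = fn(*[Arg(n) for n in names]).emit()
    compiled = eval(f"lambda {', '.join(names)}: {ir}")  # the "backend"
    return functools.wraps(fn)(compiled)

@jit
def poly(x, y):
    return x * x + 2 * x + y
```

After decoration, `poly(3, 4)` runs the compiled lambda rather than re-tracing; a real implementation would lower the same IR to LLVM instead of `eval`.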

AI: What people are saying
The comments on Eli Bendersky's article about Just-In-Time compilation in Python reveal several key themes and insights regarding JIT and decorators in Python.
  • Many commenters express challenges and frustrations with Python's bytecode and JIT implementations, highlighting issues like instability and data copying overhead.
  • There is a discussion about the potential of using decorators for enhanced performance and the desire for seamless integration between Python and JIT-compiled code.
  • Some commenters draw parallels between Python and Lisp, particularly in terms of metaprogramming capabilities and the use of macros.
  • Several users reminisce about past JIT efforts in Python, such as Psyco, and discuss the evolution of JIT technologies like PyPy.
  • There is a call for improved tooling and frameworks that can facilitate easier JIT compilation and tracing in Python without significant complexity.
12 comments
By @PaulHoule - 3 months
I read On Lisp by Graham recently and first thought "this is the best programming book I read in a while", and then had the urge to make copy editing kind of changes "he didn't define nconc" and then thought "if he was using Clojure he wouldn't be fighting with nconc", and by the end thought "most of the magic is in functions, mostly he gets efficiency out of macros, the one case that really needs macros is the use of continuations" and "I'm disappointed he didn't write any macros that do a real tree transformation"

Then a few weeks later I came to the conclusion that Python is the new Lisp when it comes to metaprogramming (and async in Python does the same thing that he coded up with continuations). I think homoiconicity and the parentheses are a red herring; the real problem is that we're still stuck with parser generators that aren't composable. You really ought to be able to add

   unless(X) { ... }
to Java by adding 1 production to the grammar, a new object for the AST tree, and a transformation for the compiler that rewrites to

   if(!X) { ... }
probably the actual code would be smaller than the POM file if the compiler was built as if extensibility mattered.
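In Python this kind of construct-by-rewrite is already within reach via the ast module. A minimal sketch, using a hypothetical `unless(...)` marker that a `NodeTransformer` rewrites into a negated `if`:

```python
import ast

class UnlessRewriter(ast.NodeTransformer):
    """Rewrite `if unless(X): ...` into `if not X: ...`."""
    def visit_If(self, node):
        self.generic_visit(node)
        test = node.test
        if (isinstance(test, ast.Call)
                and isinstance(test.func, ast.Name)
                and test.func.id == "unless"
                and len(test.args) == 1):
            node.test = ast.UnaryOp(op=ast.Not(), operand=test.args[0])
        return node

src = """
def sign(x):
    if unless(x >= 0):
        return -1
    return 1
"""
tree = UnlessRewriter().visit(ast.parse(src))
ast.fix_missing_locations(tree)
ns = {}
exec(compile(tree, "<unless>", "exec"), ns)
sign = ns["sign"]  # sign(-5) == -1, sign(3) == 1
```

One new node shape, one transformation — roughly the "1 production plus a rewrite" recipe, minus the grammar change, since the marker reuses existing call syntax.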

Almost all the examples in this book (which claims to be a tutorial for Common Lisp programming)

https://www.amazon.com/Paradigms-Artificial-Intelligence-Pro...

are straightforward to code up in Python. The main retort to this I hear from Common Lisp enthusiasts is that some CL implementations are faster, which is true. Still, most languages today have a big helping of "Lisp, the good parts". Maybe some day the Rustifarians will realize the wide-ranging impacts of garbage collection, not least that you can smack together an unlimited number of frameworks and libraries into one program and never have to think about making the memory allocation and deallocation match up.

By @cchianel - 3 months
I had the misfortune of translating CPython bytecode to Java bytecode, and I do not wish that experience on anyone:

- CPython's bytecode is extremely unstable. Not only are new opcodes added/removed each release, the meaning of existing opcodes can change. For instance, the meaning of the argument to JUMP_IF_FALSE_OR_POP changes depending on the CPython version: in CPython 3.10 and below, it is an absolute address; in CPython 3.11 and above, it is a relative address.

- The documentation for the bytecode in dis tends to be outdated or outright wrong. I often had to analyze the generated bytecode to figure out what each opcode means (and then submit a corresponding PR to update said documentation). Moreover, it assumes you know the inner details of how CPython works, from the descriptor protocol to how binary operations are implemented (each of which is about a 30-line function when written in Python).

- CPython's bytecode is extremely atypical. For instance, a for-loop keeps its iterator on the stack instead of storing it in a synthetic variable. As a result, when an exception occurs inside a for-loop, instead of the stack containing only the exception, it will also contain the for-loop iterator.

As for why I did this, I have Java calling CPython in a hot loop. Although direct FFI is fast, it causes a memory leak due to Java's and Python's Garbage collectors needing to track each other's objects. When using JPype or GraalPy, the overhead of calling Python in a Java hot-loop is massive; I got a 100x speedup from translating the CPython bytecode to Java bytecode with identical behaviour (details can be found in my blog post: https://timefold.ai/blog/java-vs-python-speed).

I strongly recommend using the AST instead (although there are no backward compatibility guarantees with the AST either, it is far less likely to break between versions).
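The iterator-on-the-stack behavior is easy to observe with the dis module — a quick check using opcodes that have stayed stable across recent CPython versions:

```python
import dis

def total(xs):
    # The iterator created by GET_ITER lives on the value stack for the
    # whole loop; no STORE_* opcode ever binds it to a named variable.
    acc = 0
    for x in xs:
        acc += x
    return acc

ops = [ins.opname for ins in dis.get_instructions(total)]
# GET_ITER pushes the iterator; FOR_ITER repeatedly advances it in place.
print("GET_ITER" in ops, "FOR_ITER" in ops)
```

So any translator that models the stack (or an exception handler that unwinds it) has to account for that hidden iterator slot.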

By @6gvONxR4sf7o - 3 months
I've had a lot of fun with tracing decorators in python, but the limitation of data dependent control flow (e.g. an if statement, a for loop) always ends up being more painful than I'd hoped. It's a shame, since it's such a great pattern otherwise.

Can anyone think of a way to get a nice smooth gradation of tracing-based transformations, based on the effort required? I'd love to say, 'okay, in this case I'm willing to put in a bit more effort' and somehow get data-dependent if statements working, but not support data-dependent loops. All I know of now is either tracing with zero data-dependent control flow, or going all the way to writing a Python compiler with whatever set of semantics you want to support, and full failure on what you don't.
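The pain point is easy to reproduce with a toy tracer: symbolic values happily record arithmetic and comparisons, but the moment Python needs a real bool for a branch, tracing has nothing to offer:

```python
class TraceError(TypeError):
    pass

class T:
    """Symbolic value: records ops as tuples, refuses to be a branch condition."""
    def __init__(self, expr):
        self.expr = expr
    def _op(self, name, other):
        o = other.expr if isinstance(other, T) else other
        return T((name, self.expr, o))
    def __add__(self, other): return self._op("add", other)
    def __mul__(self, other): return self._op("mul", other)
    def __gt__(self, other): return self._op("gt", other)
    def __bool__(self):
        # An `if`/`while` on a traced value forces truthiness -- bail out here.
        raise TraceError("data-dependent control flow reached during tracing")

def trace(fn, nargs):
    """Run fn once on symbolic arguments and return the recorded expression."""
    return fn(*[T(("arg", i)) for i in range(nargs)]).expr

trace(lambda x, y: x * x + y, 2)             # fine: pure dataflow
# trace(lambda x, y: x if x > 0 else y, 2)   # raises TraceError
```

One middle-ground sketch for the "bit more effort" tier: have `__bool__` raise a resumable signal, re-trace the function once per branch outcome, and stitch the traces into a select/where node — which handles data-dependent ifs but still not loops.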

On a different note, some easy decorator DSL based pdb integration would be an incredible enabler for these kinds of things. My coworkers are always trying to write little 'engine' DSLs for one thing or another, and it sucks that whenever you implement your own execution engine, you completely lose all language tooling. As I understand it, in compiler tooling, you always have some burden of shepherding around maps of what part of the source a given thing corresponds to. Ditto for python decorator DSLs, except nobody bothers, meaning you get the equivalent of a 1960's developer experience in that DSL.

By @sega_sai - 3 months
I hope this is the future for Python. Write in pure Python, but if needed, the code can be JIT-compiled (or not) into something faster (provided your code does not rely too much on low-level Python features, such as dunder (__) methods).
By @t-vi - 3 months
If you like JIT wrappers and Python interpreters:

In Thunder[1], a PyTorch-to-Python JIT compiler for optimizing DL models, we are maintaining a bytecode interpreter covering 3.10-3.12 (and 3.13 soon) for our JIT. That allows running Python code while redirecting arbitrary function calls and operations, but it is quite a bit slower than CPython.

While the bytecode changes (and sometimes it is a back-and-forth, for example in the call handling), it seems totally fine once you embrace that there will be differences between Python versions.

What has been a large change is the new zero-cost (in the happy path) exception handling, but I can totally see why Python changed to that from setting up try-block frames.

I will say that I was happy not to support Python <= 3.9 as changes were a lot more involved there (the bytecode format itself etc.).
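Nowhere near Thunder's coverage, but the shape of such an interpreter can be sketched for a tiny straight-line subset of opcodes, with the per-version differences handled as extra branches (opcode names are assumptions checked against recent CPython releases):

```python
import dis
import operator

BINOPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}

def run(fn, *args):
    """Interpret a straight-line function's bytecode. Toy subset only:
    no jumps, no calls, + - * arithmetic."""
    frame_locals = dict(zip(fn.__code__.co_varnames, args))
    stack = []
    for ins in dis.get_instructions(fn):
        if ins.opname in ("RESUME", "NOP"):          # bookkeeping (3.11+)
            continue
        elif ins.opname == "LOAD_FAST":
            stack.append(frame_locals[ins.argval])
        elif ins.opname == "LOAD_FAST_LOAD_FAST":    # 3.13 superinstruction
            a, b = ins.argval
            stack += [frame_locals[a], frame_locals[b]]
        elif ins.opname in ("LOAD_CONST", "LOAD_SMALL_INT"):
            stack.append(ins.argval)
        elif ins.opname == "BINARY_OP":              # 3.11+: one opcode, oparg selects op
            r, l = stack.pop(), stack.pop()
            stack.append(BINOPS[ins.argrepr](l, r))
        elif ins.opname in ("BINARY_ADD", "BINARY_SUBTRACT", "BINARY_MULTIPLY"):
            # <=3.10: one opcode per operator
            sym = {"BINARY_ADD": "+", "BINARY_SUBTRACT": "-",
                   "BINARY_MULTIPLY": "*"}[ins.opname]
            r, l = stack.pop(), stack.pop()
            stack.append(BINOPS[sym](l, r))
        elif ins.opname == "RETURN_VALUE":
            return stack.pop()
        elif ins.opname == "RETURN_CONST":           # 3.12+
            return ins.argval
        else:
            raise NotImplementedError(ins.opname)

def f(a, b):
    return a * b + 1

run(f, 2, 3)  # interprets f's bytecode instead of calling it
```

Even at this size, three of the branches exist only to absorb version differences — which is the maintenance cost being described above.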

Of course, working on this also means knowing otherwise useless Python trivia afterwards. One of my favorites is how this works:

  l = [1, 2, 3]
  l[-1] += l.pop()
  print(l)
1. https://github.com/Lightning-AI/lightning-thunder/
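For checking your answer: the augmented assignment loads `l[-1]` before evaluating the right-hand side, so the `pop` shortens the list before the store lands. Desugared:

```python
l = [1, 2, 3]
old = l[-1]        # 3, loaded first (before the RHS runs)
rhs = l.pop()      # 3; l is now [1, 2]
l[-1] = old + rhs  # stores 6 at index -1 of the *shortened* list
print(l)           # [1, 6]
```

`dis` shows why: the target's container and index are pushed and subscripted before the right-hand side is evaluated, then the stored-to index is reused for the final `STORE_SUBSCR`.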
By @nurettin - 3 months
The problem with these JIT/kernel implementations is that there is an extra step of copying data over to the function. Some implementations like numba work around that by exposing raw pointers to python for numeric arrays, but I don't know of any framework which can do that with python objects.
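The raw-pointer trick for numeric buffers can be sketched with nothing but the stdlib — `ctypes` standing in for what numba does natively: both sides see the same memory, and no copy is made:

```python
from array import array
import ctypes

# A contiguous numeric buffer owned by Python.
buf = array("d", [1.0, 2.0, 3.0])
addr, n = buf.buffer_info()

# View the same memory through a raw pointer -- no copy.
view = (ctypes.c_double * n).from_address(addr)
view[0] = 42.0   # write through the "native" side

print(buf[0])    # the Python-side array sees the write: 42.0
```

This works because `array` (like numpy arrays) has a flat C layout; arbitrary Python objects don't, which is exactly why mirroring them into JIT code is the hard part.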

What we need is a framework which can mirror python code in jit code without effort.

    @jit.share
    class ComplexClass:
        def __init__(self, complex_parameters: SharedParameters):
            self.complex_parameters = complex_parameters
            self.x = 42
            self.y = 42
        
        def function_using_python_specific_libraries_and_stuff(self):
            ...


    @jit
    def make_it_go_fast(obj: ComplexClass):
        ...
        # modify obj, it is mirrored in jit code, 
        # but shares its memory with the python object 
        # so any change also affects the passed python object
        # initially, this can be done with observers, 
        # but then we will need some sort of memory aligned 
        # direct read/write regions for real speed

The complexity arises when function_using_python_specific_libraries_and_stuff uses exported libraries. The JIT code has to detect their calls as wrappers for shared objects and pass them through for seamless integration, and only compile the Python-specific AST.
By @est - 3 months
Aha, anyone remember psyco from the python 2.x era?

https://psyco.sourceforge.net/psycoguide/node8.html

p.s. The psyco guys then went in another direction, called PyPy.

p.p.s. There's also a PyPy-based decorator, but it limits the parameters to basic types. Sadly, I forgot the github.

By @hardmath123 - 3 months
Here's another blog post on this theme! https://github.com/kach/art-deco/blob/main/art-deco.ipynb
By @sitkack - 3 months
Great article, https://www.taichi-lang.org/ also uses the JIT decorator technique for their GPU compiled kernels.
By @svilen_dobrev - 3 months
I needed to make the "tracing" part - which I called "explain" - without JITs, in 2007-8, using a combination of operator overloading, variable-"declaring", and bytecode hacks [0].

Applied over a set of (constrained) functions, the result was a well-formed trace of which var got what value because of what expression over what values.

So can these ~hacks be avoided now - or not really?

[0] https://github.com/svilendobrev/svd_util/blob/master/tracer....
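The operator-overloading half of that "explain" trace can still be sketched without bytecode hacks — a toy version that logs which value each expression produced and why:

```python
import operator

log = []

class V:
    """A 'declared' variable that logs every operation performed on it."""
    def __init__(self, name, value):
        self.name, self.value = name, value
    def _bin(self, other, sym, fn):
        oval = other.value if isinstance(other, V) else other
        oname = other.name if isinstance(other, V) else repr(other)
        result = fn(self.value, oval)
        log.append(f"{self.name} {sym} {oname} = {result}")
        return V(f"({self.name} {sym} {oname})", result)
    def __add__(self, other): return self._bin(other, "+", operator.add)
    def __mul__(self, other): return self._bin(other, "*", operator.mul)

a, b = V("a", 2), V("b", 3)
c = a + b * b
# log now explains each step: "b * b = 9", then "a + (b * b) = 11"
```

What overloading alone cannot see — assignments, control flow, calls into opaque code — is presumably where the bytecode hacks came in, and those are the parts that newer tracing/AST tooling makes less painful.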

By @agumonkey - 3 months
Less complex libraries do Python AST analysis wrapped in decorators, for instance to ensure purity of code.

It's a fun foot-in-the-door trick for getting into compilation.
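A minimal sketch of that trick, checking source text for obviously impure constructs (here just global/nonlocal declarations; a real check would cover far more):

```python
import ast

def assert_pure(src: str) -> None:
    """Toy purity gate: reject functions whose AST declares outside state."""
    for node in ast.walk(ast.parse(src)):
        if isinstance(node, (ast.Global, ast.Nonlocal)):
            raise ValueError(f"impure construct: {type(node).__name__}")

assert_pure("def f(x):\n    return x + 1")   # passes silently
# assert_pure("def g():\n    global n\n    n = 1")  # would raise ValueError
```

In a decorator, the same walk would run over `inspect.getsource` of the wrapped function at definition time — the gentler cousin of the full AST-to-IR pipelines in the article.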

By @bjourne - 3 months
Great, but afaict, it's not a JIT. It is using LLVM to AOT-compile Python code. Decorators are called when their respective functions are compiled, not when they are called.