Decorator JITs: Python as a DSL
Eli Bendersky's article explores Just-In-Time compilation in Python, focusing on JIT decorators in machine learning libraries like JAX and Triton, highlighting various implementation strategies for enhanced performance.
Eli Bendersky's article discusses the use of Just-In-Time (JIT) compilation in Python, particularly in the context of machine learning libraries like JAX and Triton. The JIT decorator pattern allows Python functions to be transformed into a Domain Specific Language (DSL) that can be compiled and executed more efficiently. The article outlines various implementation strategies for JIT decorators, including AST-based JIT, bytecode-based JIT, and tracing-based JIT. Each method involves translating Python functions into an intermediate representation (IR) before converting them to LLVM IR for execution. The AST-based approach uses Python's Abstract Syntax Tree to create an expression representation, while the bytecode-based method compiles functions to bytecode first. The tracing method, on the other hand, captures the execution of the function to generate the necessary expressions. Bendersky emphasizes the flexibility and power of these techniques, which allow Python to serve as a meta-language for describing computations, ultimately enabling more efficient execution of machine learning algorithms.
- JIT compilation enhances the performance of Python functions, especially in machine learning.
- Different JIT strategies include AST-based, bytecode-based, and tracing-based approaches.
- The transformation of Python functions into an intermediate representation is crucial for JIT execution.
- Libraries like JAX and Triton utilize JIT decorators to optimize performance while maintaining Python's usability.
- Understanding these strategies can help developers leverage Python's capabilities for high-performance computing.
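As a taste of the tracing-based strategy described above, here is a minimal, self-contained sketch (illustrative names, not the article's actual code): the function is called with proxy arguments whose overloaded operators record an expression tree instead of computing values.

import dataclasses

@dataclasses.dataclass(frozen=True)
class Expr:
    op: str      # "var", "add", or "mul"
    args: tuple

    def __add__(self, other): return Expr("add", (self, other))
    def __mul__(self, other): return Expr("mul", (self, other))

def trace(fn, nargs):
    # Feed the function symbolic proxies; its arithmetic then builds
    # an IR that a real JIT would lower to LLVM IR and compile.
    proxies = tuple(Expr("var", (f"arg{i}",)) for i in range(nargs))
    return fn(*proxies)

def f(a, b):
    return a * b + a

print(trace(f, 2))  # Expr(op='add', args=(Expr(op='mul', ...), Expr(op='var', ...)))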
Related
Mining JIT traces for missing optimizations with Z3
Using Z3, PyPy's JIT traces are analyzed to pinpoint inefficient integer operations for further optimization. By translating operations into Z3 formulas, redundancies are identified to enhance PyPy's JIT compiler efficiently.
Python extensions should be lazy
Python's `ast.parse` function is slow due to memory management issues. A Rust extension improved AST processing speed by 16x, suggesting lazy loading strategies for better performance in Python extensions.
A DSL for peephole transformation rules of integer operations in the PyPy JIT
A new domain-specific language for optimizing integer operations in the PyPy JIT compiler has been developed, featuring declarative transformation rules verified for correctness to enhance efficiency and monitor effectiveness.
Musings on Tracing in PyPy
The blog post examines tracing Just-In-Time compilers, particularly in PyPy, noting their performance inconsistencies and challenges. It concludes that tracing is a pragmatic choice for handling Python's complexity.
Musings on Tracing in PyPy
The blog post analyzes tracing Just-In-Time (JIT) compilers, particularly in PyPy, discussing their benefits, challenges, and contrasting them with method-based JITs while expressing optimism for Python's performance improvements.
- Many commenters express challenges and frustrations with Python's bytecode and JIT implementations, highlighting issues like instability and data copying overhead.
- There is a discussion about the potential of using decorators for enhanced performance and the desire for seamless integration between Python and JIT-compiled code.
- Some commenters draw parallels between Python and Lisp, particularly in terms of metaprogramming capabilities and the use of macros.
- Several users reminisce about past JIT efforts in Python, such as Psyco, and discuss the evolution of JIT technologies like PyPy.
- There is a call for improved tooling and frameworks that can facilitate easier JIT compilation and tracing in Python without significant complexity.
Then a few weeks later I came to the conclusion that Python is the new Lisp when it comes to metaprogramming. (and async in Python does the same thing that he coded up with continuations.) I think homoiconicity and the parenthesis are a red herring, the real problem is that we're still stuck with parser generators that aren't composable. You really ought to be able to add
unless(X) { ... }
to Java by adding 1 production to the grammar, a new object for the AST tree, and a transformation for the compiler that rewrites to if(!X) { ... }
probably the actual code would be smaller than the POM file if the compiler was built as if extensibility mattered.
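Python's stdlib ast module makes exactly this kind of rewrite easy to sketch today. Below, `unless` is just a hypothetical marker name (not a real builtin): any `if unless(X):` is rewritten to `if not X:` before compilation.

import ast

class UnlessRewriter(ast.NodeTransformer):
    def visit_If(self, node):
        self.generic_visit(node)
        t = node.test
        # Match `if unless(<expr>):` and rewrite the test to `not <expr>`.
        if (isinstance(t, ast.Call) and isinstance(t.func, ast.Name)
                and t.func.id == "unless" and len(t.args) == 1):
            node.test = ast.UnaryOp(op=ast.Not(), operand=t.args[0])
        return node

src = "if unless(x > 0):\n    print('non-positive')\n"
tree = ast.fix_missing_locations(UnlessRewriter().visit(ast.parse(src)))
exec(compile(tree, '<rewritten>', 'exec'), {'x': -1})  # prints: non-positive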
Almost all the examples in this book (which claims to be a tutorial for Common Lisp programming)
https://www.amazon.com/Paradigms-Artificial-Intelligence-Pro...
are straightforward to code up in Python. The main retort to this I hear from Common Lisp enthusiasts is that some CL implementations are faster, which is true. Still, most languages today have a big helping of "Lisp, the good parts". Maybe some day the Rustifarians will realize the wide-ranging impacts of garbage collection, not least that you can smack together an unlimited number of frameworks and libraries into one program and never have to think about making the memory allocation and deallocation match up.
- CPython's bytecode is extremely unstable. Not only are new opcodes added and removed each release, the meaning of existing opcodes can change. For instance, the meaning of the argument to JUMP_IF_FALSE_OR_POP depends on the CPython version: in CPython 3.10 and below it is an absolute address, while in CPython 3.11 and above it is a relative address.
- The documentation for the bytecode in dis tends to be outdated or outright wrong. I often had to analyze the generated bytecode to figure out what each opcode means (and then submit a corresponding PR to update said documentation). Moreover, it assumes you know the inner details of how CPython works, from the descriptor protocol to how binary operations are implemented (each of which is about a 30-line function when written in Python).
- CPython's bytecode is extremely atypical. For instance, a for-loop keeps its iterator on the stack instead of storing it in a synthetic variable. As a result, when an exception occurs inside a for-loop, instead of the stack containing only the exception, it will also contain the for-loop iterator.
As for why I did this: I have Java calling CPython in a hot loop. Although direct FFI is fast, it causes a memory leak due to Java's and Python's garbage collectors needing to track each other's objects. When using JPype or GraalPy, the overhead of calling Python in a Java hot loop is massive; I got a 100x speedup from translating the CPython bytecode to Java bytecode with identical behaviour (details can be found in my blog post: https://timefold.ai/blog/java-vs-python-speed).
I strongly recommend using the AST instead (although there are no backward compatibility guarantees with the AST either, it is far less likely to break between versions).
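For illustration, here is a minimal way to see both layers side by side, using only the stdlib ast and dis modules (nothing from this thread's specific projects):

import ast
import dis

src = "for x in xs:\n    total += x\n"

# The AST shape of this loop is essentially stable across CPython versions.
print(ast.dump(ast.parse(src).body[0], indent=2))

# The bytecode for the same loop varies between versions (opcode names,
# argument meanings, how the iterator is kept on the stack).
dis.dis(compile(src, "<example>", "exec"))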
Can anyone think of a way to get a nice smooth gradation of tracing-based transformations, based on the effort required or something? I'd love to say, 'okay, in this case I'm willing to put in a bit more effort' and somehow get data-dependent if statements working, but not support data-dependent loops. All I know of now is either tracing with zero data-dependent control flow, or going all the way to writing a Python compiler with whatever set of semantics you want to support, and full failure on whatever you don't.
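For reference, JAX offers one intermediate point on that gradation: under jax.jit a plain Python if on a traced value fails, but jax.lax.cond stages a data-dependent branch into the compiled IR. A minimal sketch:

import jax
from jax import lax

@jax.jit
def f(x):
    # A Python-level `if x > 0:` would raise a tracer error here, since
    # tracing only sees an abstract value. lax.cond stages both branches
    # into the IR and selects between them at run time.
    return lax.cond(x > 0, lambda v: v * 2.0, lambda v: -v, x)

print(f(3.0))   # 6.0
print(f(-3.0))  # 3.0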
On a different note, some easy decorator-DSL pdb integration would be an incredible enabler for these kinds of things. My coworkers are always trying to write little 'engine' DSLs for one thing or another, and it sucks that whenever you implement your own execution engine, you completely lose all language tooling. As I understand it, compiler tooling always carries some burden of shepherding around maps of which part of the source a given thing corresponds to. Ditto for Python decorator DSLs, except nobody bothers, meaning you get the equivalent of a 1960s developer experience in that DSL.
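One cheap stdlib trick in that direction (an illustration, not a full source-map solution): if a DSL execs generated Python, registering the generated source with linecache lets tracebacks, and the tools built on them, show real source lines:

import linecache
import traceback

src = "def generated(x):\n    return x / 0\n"
filename = "<dsl:example>"  # any unique pseudo-filename

# Cache entries are (size, mtime, lines, fullname); mtime=None marks
# the entry as not backed by a real file, so checkcache keeps it.
linecache.cache[filename] = (len(src), None, src.splitlines(True), filename)

ns = {}
exec(compile(src, filename, "exec"), ns)
try:
    ns["generated"](1)
except ZeroDivisionError:
    traceback.print_exc()  # shows `return x / 0` from the pseudo-file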
In Thunder[1], a PyTorch-to-Python JIT compiler for optimizing DL models, we maintain a bytecode interpreter covering 3.10-3.12 (and soon 3.13) for our jit. It allows running Python code while redirecting arbitrary function calls and operations, but it is quite a bit slower than CPython.
While the bytecode changes (and sometimes it is a back-and-forth, for example in the call handling), it seems totally manageable once you embrace that there will be differences between Python versions.
What has been a large change is the new zero-cost (in the happy path) exception handling, but I can totally see why Python switched to that from setting up try-block frames.
I will say that I was happy not to support Python <= 3.9, as the changes were a lot more involved there (the bytecode format itself etc.).
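The zero-cost change is easy to observe with the stdlib dis module (a small illustration; the exact output depends on your CPython version):

import dis

def f():
    try:
        return 1
    except Exception:
        return 2

# On CPython 3.11+ the disassembly ends with an "ExceptionTable:" section
# that is consulted only when an exception actually propagates; on 3.10
# and earlier, a SETUP_FINALLY opcode pushes a handler block up front.
dis.dis(f)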
Of course, working on this also means knowing otherwise useless Python trivia afterwards. One of my favorites is how this works:
l = [1, 2, 3]
l[-1] += l.pop()
print(l)
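For the curious, the result falls out of CPython's evaluation order for augmented assignment (easily verified by running it):

l = [1, 2, 3]
# The target l[-1] is read first (3, while the list is still [1, 2, 3]).
# Then the right-hand side runs: l.pop() returns 3 and shrinks the list
# to [1, 2]. The store l[-1] = 3 + 3 now lands on index 1.
l[-1] += l.pop()
print(l)  # [1, 6]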
1. https://github.com/Lightning-AI/lightning-thunder/
What we need is a framework which can mirror Python code in JIT code without effort.
@jit.share
class ComplexClass:
    def __init__(self, complex_parameters: SharedParameters):
        self.complex_parameters = complex_parameters
        self.x = 42
        self.y = 42

    def function_using_python_specific_libraries_and_stuff(self):
        ...

@jit
def make_it_go_fast(obj: ComplexClass):
    ...
    # modify obj: it is mirrored in JIT code,
    # but shares its memory with the Python object,
    # so any change also affects the passed Python object.
    # initially, this can be done with observers,
    # but then we will need some sort of memory-aligned
    # direct read/write regions for real speed
The complexity arises when function_using_python_specific_libraries_and_stuff uses exported libraries. The JIT code has to detect their calls as wrappers for shared objects and pass them through for seamless integration, compiling only the Python-specific AST.
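A minimal sketch of the 'observers' stopgap mentioned above: intercept attribute writes on the Python object and forward them to a mirror. Everything here (Mirrored, PrintMirror) is hypothetical scaffolding, not a real JIT API.

class PrintMirror:
    # Stand-in for a JIT-side shadow object; a real one would write
    # into a compiled-code memory region instead of printing.
    def write(self, name, value):
        print(f"mirror: {name} = {value!r}")

class Mirrored:
    def __init__(self, mirror):
        # Bypass our own __setattr__ while wiring up the mirror itself.
        object.__setattr__(self, "_mirror", mirror)

    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)
        self._mirror.write(name, value)  # keep the JIT-side copy in sync

obj = Mirrored(PrintMirror())
obj.x = 42  # both the Python object and the mirror observe the write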
https://psyco.sourceforge.net/psycoguide/node8.html
p.s. The Psyco guys then went in another direction, called PyPy.
p.p.s. There's also a PyPy-based decorator, but it limits the parameters to basic types. Sadly I forgot the GitHub link.
Applied over a set of (constrained) functions, the result was a well-formed trace of which var got what value because of what expression over what values.
So can these ~hacks be avoided now, or not really?
[0] https://github.com/svilendobrev/svd_util/blob/master/tracer....
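For context, the classic stdlib mechanism behind this kind of tracer is sys.settrace, which fires a callback per executed line and exposes the frame's locals (whether the linked tracer works exactly this way is an assumption). A minimal sketch:

import sys

def tracer(frame, event, arg):
    # For "line" events, snapshot the locals before the line executes.
    if event == "line":
        print(f"line {frame.f_lineno}: {dict(frame.f_locals)}")
    return tracer  # returning the tracer keeps tracing this frame

def f(a, b):
    c = a * b
    d = c + a
    return d

sys.settrace(tracer)
f(2, 3)
sys.settrace(None)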
it's a fun foot-in-the-door trick to start going into compilation