June 22nd, 2024

Shape Rotation 101: An Intro to Einsum and Jax Transformers

Einsum notation simplifies tensor operations in libraries like NumPy, PyTorch, and Jax. A simple Jax transformer implementation showcases efficient tensor operations for deep learning, emphasizing speed and memory benefits in research and production environments.


This blog post introduces readers to the concepts of Einsum notation and Jax Transformers. Einsum is a tensor manipulation API used in libraries like NumPy, PyTorch, and Jax, leveraging Einstein summation notation to simplify linear algebraic operations on multi-dimensional arrays. It offers speed and memory efficiency benefits but can be initially tricky to grasp due to its notation. The post explains how Einsum works with examples, showcasing its advantages over traditional array functions.
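The article's examples aren't reproduced here, but a few common einsum patterns (illustrative shapes assumed) show how subscript strings replace separate array functions:

```python
import numpy as np

A = np.random.rand(3, 4)
B = np.random.rand(4, 5)

# Matrix multiplication: the repeated index j is summed over
C = np.einsum("ij,jk->ik", A, B)
assert np.allclose(C, A @ B)

# Row sums: an index omitted from the output is summed over
row_sums = np.einsum("ij->i", A)
assert np.allclose(row_sums, A.sum(axis=1))

# Transpose: reordering the output indices permutes axes
At = np.einsum("ij->ji", A)
assert np.allclose(At, A.T)
```

The same subscript syntax works unchanged in PyTorch and Jax, which is part of einsum's appeal as a portable notation.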

The second part of the post delves into a simple Jax Transformer implementation, focusing on a decoder model. The code demonstrates how to implement the attention mechanism and feedforward network operations using Einsum notation within the Jax framework. Jax is highlighted as a tool that combines NumPy's syntax with functional programming concepts for faster computation, making it suitable for both research and production environments.
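The post's actual transformer code isn't reproduced here, but a minimal single-head attention sketch in Jax (hypothetical function name, unbatched `(seq, d)` shapes assumed) illustrates the einsum-based pattern the article describes:

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    # Similarity scores: contract over the feature axis d -> (seq_q, seq_k)
    scores = jnp.einsum("qd,kd->qk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum of values: contract over keys -> (seq_q, d)
    return jnp.einsum("qk,kd->qd", weights, v)

key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (5, 8))
out = attention(q, q, q)  # self-attention over a length-5 sequence
```

Because the subscripts name each axis explicitly, the two einsum calls read almost like the attention equation itself, which is the readability benefit the post emphasizes.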

Related

Implementing General Relativity: What's inside a black hole?

Implementing general relativity for black hole exploration involves coordinate systems, upgrading metrics, calculating tetrads, and parallel transport. Tetrads transform vectors between flat and curved spacetime, crucial for understanding paths.

We no longer use LangChain for building our AI agents

Octomind switched from LangChain due to its inflexibility and excessive abstractions, opting for modular building blocks instead. This change simplified their codebase, increased productivity, and emphasized the importance of well-designed abstractions in AI development.

Exposition of Front End Build Systems

Frontend build systems are crucial in web development, involving transpilation, bundling, and minification steps. Tools like Babel and Webpack optimize code for performance and developer experience. Various bundlers like Webpack, Rollup, Parcel, esbuild, and Turbopack are compared for features and performance.

Writing an IR from Scratch and survive to write a post

Eduardo Blázquez developed an Intermediate Representation (IR) for the Kunai Static Analyzer during his PhD, aiming to enhance Dalvik bytecode analysis. The project, shared on GitHub and published in SoftwareX, transitioned to Shuriken. Blázquez drew inspiration from Triton and LLVM, exploring various IR structures like ASTs and CFGs. MjolnIR, Kunai's IR, utilized a Medium Level IL design with control-flow graphs representing methods. Blázquez's approach involved studying compiler design resources.

Homegrown Rendering with Rust

Embark Studios develops a creative platform for user-generated content, emphasizing gameplay over graphics. They leverage Rust for 3D rendering, introducing the experimental "kajiya" renderer for learning purposes. The team aims to simplify rendering for user-generated content, utilizing Vulkan API and Rust's versatility for GPU programming. They seek to enhance Rust's ecosystem for GPU programming.

6 comments
By @dima55 - 5 months
An important note about numpy broadcasting: numpy broadcasts from the back, so your life improves dramatically when you reference indices from the back as well: use axis references < 0. So if you want to reference a row: refer to axis=-1. This will ALWAYS refer to the row (first broadcasting dimension), whether you have a 1D vector or a 2D matrix or any N-D array. Numpy is deeply unfriendly if you don't do this. To smooth out this and similar issues, there's the numpysane library. But simply using negative axis references goes a long way.
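The commenter's point can be seen with a small example (illustrative shapes only): axis=-1 picks out the same logical axis for arrays of different rank, whereas a positive axis index shifts meaning as dimensions are added.

```python
import numpy as np

v = np.arange(3)                  # shape (3,)
M = np.arange(12).reshape(4, 3)   # shape (4, 3)

# axis=-1 is the last axis for both shapes: the "row contents"
assert int(np.sum(v, axis=-1)) == 3           # 0 + 1 + 2
assert np.sum(M, axis=-1).shape == (4,)       # one sum per row

# By contrast, axis=0 means the only axis of v but the columns of M
assert np.sum(M, axis=0).shape == (3,)
```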
By @nlprtag - 5 months
I find NumPy way too complex for the relatively simple operations used in machine learning. The number of implicit rules like broadcasting, silent int64 => double conversion, einsum complexities etc. is just mind-boggling.

The result is a couple of dense lines, but one cannot just read them without a deep analysis of each line.

It is a pity that this has been accepted as the standard for machine learning. Worse, now every package has its own variant of NumPy (e.g. "import jax.numpy as jnp" in the article), which is incompatible with the standard one:

https://jax.readthedocs.io/en/latest/jax.numpy.html

I really would like a simpler array library that does stricter type checking, supports saner type specifications for composite types, does not broadcast automatically (except perhaps for matrix * scalar) and does one operation at a time. Casting should be explicit as well.

Bonus points if it isn't tied and inextricably linked to Python.

By @cl3misch - 5 months
If you're into shape rotations with numpy arrays, check out einops.

Also consider using "None" instead of "np.newaxis". To newcomers it's not as self-explanatory but it results in more readable code imho.
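The equivalence the commenter relies on can be checked directly: np.newaxis is literally an alias for None, so the two spellings produce identical indexing behavior.

```python
import numpy as np

x = np.arange(4)      # shape (4,)

col = x[:, None]      # shape (4, 1), same as x[:, np.newaxis]
assert col.shape == (4, 1)

# np.newaxis is just None under a more descriptive name
assert np.newaxis is None
```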

By @ishan0102 - 5 months
so good
By @tanvach - 5 months
Don’t know if the author will see this: in the table at the end of the article, there is an error where the text description of dot product and matrix multiplication are swapped.

Otherwise - great article! Didn’t know this exists in numpy. A really neat way to express matrix operations.
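For reference on the swapped table entries the commenter mentions, the correct einsum subscripts for the two operations (in NumPy, same syntax in Jax) are:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])

# Dot product: shared index i, empty output -> scalar
dot = np.einsum("i,i->", a, b)
assert dot == 32.0  # 1*4 + 2*5 + 3*6

# Matrix multiplication: sum over the shared index j
A = np.ones((2, 3))
B = np.ones((3, 4))
assert np.einsum("ij,jk->ik", A, B).shape == (2, 4)
```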

By @earhart - 5 months
I still wish Tile had caught on; einsum is really nice, but sometimes I want a dilated convolution, or a maxpool.

(OTOH, I’m not an einsum expert; please feel free to delight me by pointing out how it’s possible to do these sorts of things :-)