Ever stared at a CUDA Out of Memory error and wondered where it all went wrong? Or perhaps you’ve noticed your GPU utilization sitting at a measly 30% while training your shiny new Transformer? Welcome to MemWall.
MemWall is a comprehensive Python library designed to help ML practitioners plan, profile, and optimize their hardware utilization. We bridge the gap between abstract model architectures and the harsh reality of hardware physics.
At the heart of optimization lies the Roofline Model, an intuitive visual model that relates attainable computational performance to memory bandwidth and arithmetic intensity.
What is the Ridge Point?
Every piece of hardware (like your A100 or RTX 4090) has two critical ceilings: its peak compute throughput (FLOP/s) and its peak memory bandwidth (bytes/s).
The Ridge Point is the arithmetic intensity at which these two ceilings meet, calculated simply as:
Ridge Point = Peak FLOP/s / Peak Bandwidth (bytes/s), measured in FLOPs per byte
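This single number is your North Star. Calculate the Arithmetic Intensity (FLOPs / bytes moved) of your operation and compare it to the Ridge Point. If the intensity falls below the Ridge Point, the operation is memory-bound and limited by bandwidth. If it sits above, the operation is compute-bound and limited by peak FLOP/s.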
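The comparison can be sketched in a few lines of plain Python. This is not MemWall's API. The peak figures are approximate published numbers for an A100 80GB SXM (dense FP16/BF16 Tensor Core throughput and HBM bandwidth) and are assumptions, so check them against your own datasheet. The helper names `ridge_point` and `classify` are hypothetical.

```python
# Approximate published peaks for an NVIDIA A100 80GB SXM.
# Assumed values for illustration only; check your own datasheet.
PEAK_FLOPS = 312e12        # FLOP/s, dense FP16/BF16 Tensor Core
PEAK_BANDWIDTH = 2.039e12  # bytes/s, HBM


def ridge_point(peak_flops: float, peak_bandwidth: float) -> float:
    """Arithmetic intensity (FLOPs/byte) at which the two ceilings meet."""
    return peak_flops / peak_bandwidth


def classify(intensity: float, ridge: float) -> str:
    """At or above the ridge point, compute is the limit; below it, bandwidth is."""
    return "compute-bound" if intensity >= ridge else "memory-bound"


ridge = ridge_point(PEAK_FLOPS, PEAK_BANDWIDTH)
print(f"Ridge point: {ridge:.1f} FLOPs/byte")

# An elementwise FP16 add c = a + b does 1 FLOP per element while moving
# 6 bytes (read a, read b, write c), so its intensity is about 0.17.
print(classify(1 / 6, ridge))
```

On these assumed numbers the ridge point lands at roughly 153 FLOPs/byte. An elementwise add sits far below that and is memory-bound no matter how many FLOPs the GPU advertises.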
MemWall makes understanding and visualizing this dynamic effortless.
pip install memwall
To really understand what MemWall can do, let’s walk through the three core examples included in the library. Think of this as your hands-on guide to solving ML performance bottlenecks.
examples/basic_estimation.py
The Question: “I want to run Llama-7B with a batch size of 4 and a sequence length of 2048. Will it fit on my GPU?”
Before you even load a single tensor or write a PyTorch script, MemWall can tell you exactly what you’re getting into. The basic_estimation.py script demonstrates how to load a preset estimator for models like Llama-7B.
What it does:
It analytically estimates the VRAM your configuration will need. Rather than giving a single number, it breaks the total down into its individual components.
Result: You get a clean readout of how many gigabytes you need, saving you from trial-and-error OOM crashes.
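To show the kind of arithmetic behind such an estimate, here is a back-of-envelope sketch. It is not MemWall's estimator. It assumes inference (not training) with FP16 weights and KV cache, and the commonly cited Llama-7B shape (about 6.7B parameters, 32 layers, hidden size 4096, full multi-head attention). It ignores activations, the CUDA context and allocator overhead. The function names are hypothetical.

```python
# Back-of-envelope inference memory for a Llama-7B-style model.
# Assumptions: FP16 weights and KV cache, ~6.7B parameters, 32 layers,
# hidden size 4096. Activations and framework overhead are ignored.
BYTES_PER_VALUE = 2  # FP16
NUM_PARAMS = 6.7e9
NUM_LAYERS = 32
HIDDEN_SIZE = 4096
GiB = 2**30


def weight_bytes(num_params: float, bytes_per_value: int = BYTES_PER_VALUE) -> float:
    """Memory needed just to hold the model parameters."""
    return num_params * bytes_per_value


def kv_cache_bytes(batch: int, seq_len: int, layers: int = NUM_LAYERS,
                   hidden: int = HIDDEN_SIZE,
                   bytes_per_value: int = BYTES_PER_VALUE) -> int:
    """One key and one value vector of size `hidden` per token, per layer."""
    return 2 * layers * batch * seq_len * hidden * bytes_per_value


weights = weight_bytes(NUM_PARAMS)
kv = kv_cache_bytes(batch=4, seq_len=2048)
print(f"weights:  {weights / GiB:.1f} GiB")
print(f"KV cache: {kv / GiB:.1f} GiB")
print(f"total:    {(weights + kv) / GiB:.1f} GiB before activations and overhead")
```

Under these assumptions the KV cache alone comes to exactly 4 GiB for batch 4 and sequence length 2048. It grows linearly with both, which is often what pushes an otherwise comfortable setup over the edge.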
examples/pytorch_profiling.py
The Question: “My training script is eating 40GB of VRAM, but my model is only 2GB. Where is the memory going?”
Sometimes math isn’t enough: you need to see what PyTorch is actually doing. The pytorch_profiling.py example highlights MemWall’s lightweight PyTorch hooks.
What it does:
By wrapping your model with MemWall’s profile_model function and passing dummy data, the library tracks every operation in the forward pass. It outputs how much memory each layer (such as fc1, relu, and fc2) consumed.
Result: You instantly spot the specific layers that are hoarding memory, allowing you to selectively apply gradient checkpointing or optimize your architecture.
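The underlying hook mechanism can be sketched with plain PyTorch. This is not MemWall's profile_model. It uses the standard `register_forward_hook` API to record the size of each layer's output tensor. The toy model, its layer sizes and the batch size are hypothetical and chosen only to mirror the fc1, relu and fc2 names. The sketch measures activation tensors only, not the full allocator footprint; on a GPU, `torch.cuda.max_memory_allocated()` reports the latter.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical toy model using the layer names mentioned above.
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(1024, 4096)),
    ("relu", nn.ReLU()),
    ("fc2", nn.Linear(4096, 1024)),
]))

activation_bytes = {}


def make_hook(name):
    def hook(module, inputs, output):
        # Bytes occupied by the tensor this layer produced in the forward pass.
        activation_bytes[name] = output.numel() * output.element_size()
    return hook


handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_children()]

dummy = torch.randn(32, 1024)  # dummy batch of 32 float32 samples
with torch.no_grad():
    model(dummy)

for h in handles:
    h.remove()  # detach hooks so later forward passes are unaffected

for name, nbytes in activation_bytes.items():
    print(f"{name}: {nbytes / 2**20:.2f} MiB")
```

Even in this tiny example, fc1 and relu each produce an output four times larger than fc2's. During training those intermediate outputs are kept alive for the backward pass, which is where gradient checkpointing helps.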
examples/roofline_analysis.py
The Question: “Why is my matrix multiplication taking so long? Do I need a GPU with more raw FLOPs, or is memory bandwidth the real bottleneck?”
This script brings the Ridge Point concept to life.
What it does:
You pick a hardware profile (e.g., A100_80GB_SXM), and MemWall loads its peak constraints.
As a bonus, the script automatically generates a roofline_example.png plot, mapping your operations directly against the theoretical limits of your hardware.
Result: You stop guessing why your code is slow. You get a quantitative diagnosis pointing you toward either memory optimization (for memory-bound operations) or algorithmic improvements (for compute-bound ones).
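For a feel of what a roofline check on a matrix multiplication involves, here is a minimal sketch. It is not MemWall's implementation. It assumes FP16 operands, approximate A100 80GB SXM peaks, and ideal data reuse, meaning each of A, B and C crosses the memory bus exactly once. The function names are hypothetical.

```python
# Roofline check for an FP16 matrix multiply C[M,N] = A[M,K] @ B[K,N].
# Assumptions: approximate A100 80GB SXM peaks and ideal reuse (each
# matrix moved to or from memory exactly once).
PEAK_FLOPS = 312e12        # FLOP/s, dense FP16 Tensor Core (approx.)
PEAK_BANDWIDTH = 2.039e12  # bytes/s (approx.)
BYTES_PER_VALUE = 2        # FP16


def matmul_intensity(m: int, n: int, k: int) -> float:
    flops = 2 * m * n * k  # one multiply and one add per multiply-accumulate
    bytes_moved = BYTES_PER_VALUE * (m * k + k * n + m * n)  # read A, B; write C
    return flops / bytes_moved


def attainable_flops(intensity: float) -> float:
    # The roofline: capped by peak compute or by bandwidth * intensity.
    return min(PEAK_FLOPS, intensity * PEAK_BANDWIDTH)


ridge = PEAK_FLOPS / PEAK_BANDWIDTH
shapes = {"square GEMM": (4096, 4096, 4096), "matrix-vector": (1, 4096, 4096)}
for label, (m, n, k) in shapes.items():
    ai = matmul_intensity(m, n, k)
    bound = "compute-bound" if ai >= ridge else "memory-bound"
    print(f"{label}: {ai:.1f} FLOPs/byte, "
          f"at most {attainable_flops(ai) / 1e12:.1f} TFLOP/s, {bound}")
```

Because ideal reuse gives the minimum possible traffic, the computed intensity is an upper bound. A square 4096 GEMM lands well above the ridge point, while a matrix-vector product with the same weight matrix falls near 1 FLOP/byte and is bandwidth-limited. That is why single-token LLM decoding tends to be memory-bound.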
MemWall ships with hardware specifications for modern accelerators, such as the A100_80GB_SXM profile used in the roofline example.
MIT License. Built for the community, by the community.