These are work-in-progress notes to help me think about a conceptual part of a project I'm working on. Parts might not be clear and the overall flow might be wonky.
I recently took a course on program analysis and for the final project, my friend and I built a static analysis to infer PyTorch tensor shapes, in the hopes of creating a useful tool for developers to catch shape mismatches early. What we submitted was a good proof of concept, but the analysis is far from ready to take on most PyTorch programs, as we only handled a subset of the language so we could focus on the tensor inference bits.
Looking around the dev tool landscape, I don't see many analyses over languages like Python that have made it out of academia. This is a little worrying.
Dynamic languages by definition have less information known at compile time, which makes it harder to write good analyses over them. As I'm looking to build our proof-of-concept analysis into a useful tool, I want to think more about why languages like Python are so hard to analyze, and what simplifying assumptions could make the problem tractable.
Our analysis
At a high level, our analysis assumes tensor parameters are annotated with shapes, and then performs a dataflow analysis to infer tensor shapes at each location in a function. Given this function,
def concat(x: T["m a"], y: T["m b"], z: T["m c"]):
    w = torch.concat((x, y, z), dim=1)
our analysis infers that w is a tensor with shape $[m, a+b+c]$.
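Below is a minimal sketch of what a shape rule for `torch.concat` might look like. It assumes each shape is a list of symbolic dimension strings. That representation and the name `concat_shape` are hypothetical, not our analysis's actual code. Every dimension other than `dim` must agree across the inputs, and the concatenated dimension becomes the sum of the inputs' sizes along `dim`.

```python
from typing import List, Sequence

Shape = List[str]  # hypothetical representation: one symbolic expression per dimension

def concat_shape(shapes: Sequence[Shape], dim: int) -> Shape:
    """Symbolic output shape of torch.concat(tensors, dim=dim)."""
    if not shapes:
        raise ValueError("concat needs at least one tensor")
    rank = len(shapes[0])
    if rank == 0:
        raise ValueError("zero-dimensional tensors cannot be concatenated")
    if any(len(s) != rank for s in shapes):
        raise ValueError("all tensors must have the same number of dimensions")
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} out of range for rank {rank}")
    dim = dim % rank  # normalize negative dims
    out = list(shapes[0])
    for i in range(rank):
        # Every dimension except `dim` must agree across all inputs.
        if i != dim and any(s[i] != out[i] for s in shapes):
            raise ValueError(f"dimension {i} mismatch: {[s[i] for s in shapes]}")
    # The concatenated dimension is the sum of the inputs' sizes along `dim`.
    out[dim] = "+".join(s[dim] for s in shapes)
    return out

print(concat_shape([["m", "a"], ["m", "b"], ["m", "c"]], dim=1))  # ['m', 'a+b+c']
```

This sketch compares dimensions purely syntactically, so `a+b` and `b+a` would count as different. A real implementation would need some symbolic arithmetic.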
Why are dynamic languages tricky to analyze?
I'm going to focus on several important questions that analysts like to be able to answer about programs. For each, I'll talk about why answering the question well is harder in a dynamic language (Python) than in a language where more is known at compile time (Rust). Then, I want to explore what tradeoffs we can make by sacrificing some amount of soundness in the hope of solving the problem to a useful extent.
Here are the questions / problems:
- Typing: What type is this variable? What attributes and methods does it have?
- Determining Intraprocedural Control-Flow Graphs: What statements can execute after this one?
- Heap analysis: What variables can reference this object instance, and what values does it hold?
- to be written
- Iterators: How many times will this loop run, and what values will the loop variable take?
- to be written
Typing
- "What type is this object? What attributes and methods does it have?"
Probably the most common kind of analysis is one that attempts to determine the type of each variable, e.g. whether it's a primitive type or a more complex struct/class with a set of fields and methods (with their associated types).
Typing is important for downstream analyses that need to, e.g., match call sites to specific method definitions (method resolution).
In a statically typed language like Rust, each variable has exactly one type, and it's known at compile time, so we get this essentially for free.
How do types work in Python? Each variable refers to an object, which is an instance of a class.
The same variable can refer to objects of different classes over different execution paths:
def foo(a: bool):
    x = 5
    if a:
        x = "no longer an int"
    print(type(x))  # might be <class 'int'> or <class 'str'>
And unlike structs in Rust, classes aren't a static container of attributes and methods that is fixed at compile-time:
class Foo: pass

def bar(a: bool):
    b = Foo()
    if a:
        b.woo = lambda x: x + 2
    b.woo(12)  # this might dispatch to the woo added above, or to
               # a .woo added before, or __getattr__ might be invoked,
               # potentially raising AttributeError
In Python, the set of attributes an object has is really created and mutated at runtime. Method and attribute lookup happens at runtime and can raise exceptions.
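A small sketch shows how much a single call site can depend on runtime state. The classes `Foo` and `Fallback` and the helper `call_woo` are hypothetical. The same expression `obj.woo(arg)` can fail, call a function attached to one particular instance, or fall through to `__getattr__`, depending only on which object shows up.

```python
class Foo:
    pass

class Fallback:
    def __getattr__(self, name):
        # Only called when normal lookup (instance dict, then the class and its bases) fails.
        return lambda x: x * 10

def call_woo(obj, arg):
    # One call site; which function runs, if any, is only decided at runtime.
    return obj.woo(arg)

f = Foo()
try:
    call_woo(f, 12)
except AttributeError as err:
    print("AttributeError:", err)

f.woo = lambda x: x + 2          # attach a callable to this one instance only
print(call_woo(f, 12))           # 14
print(call_woo(Fallback(), 12))  # 120
```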
How do type checkers for dynamic languages work?
Let's focus on mypy, a popular static type checker for Python.
mypy infers a variable's type from its first assignment and treats that type as fixed, unless the variable is explicitly annotated otherwise. This means that it will error on our program from before:
def foo(a: bool):
    x = 5
    if a:
        x = "no longer an int"  # error: Incompatible types in assignment
    print(type(x))
even though this is valid Python. mypy is making the assumption that this type of dynamism is typically an error unless explicitly annotated:
def foo(a: bool):
    x: int | str = 5
    # unchanged...
# Success: no issues found

from typing import Any

def foo(a: bool):
    x: Any = 5
    # unchanged...
# Success: no issues found
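Here is a toy illustration of the "type fixed by first assignment" rule, built on Python's standard `ast` module. It is emphatically not how mypy is implemented. It only types assignments of literal constants to plain names, and it trusts annotated variables completely. A real checker would also check each later assignment against the declared type. `check_assignments` is a hypothetical name.

```python
import ast
from typing import Dict, List, Set

def check_assignments(source: str) -> List[str]:
    """Toy rule: a variable's type is fixed by its first literal assignment,
    unless the variable has an explicit annotation (which this sketch trusts blindly)."""
    errors: List[str] = []
    for func in ast.walk(ast.parse(source)):
        if not isinstance(func, ast.FunctionDef):
            continue
        # ast.walk is breadth-first, so sort assignments back into source order.
        assigns = sorted(
            (n for n in ast.walk(func) if isinstance(n, (ast.Assign, ast.AnnAssign))),
            key=lambda n: (n.lineno, n.col_offset),
        )
        inferred: Dict[str, str] = {}  # name -> type of its first literal assignment
        declared: Set[str] = set()     # names with an explicit annotation
        for node in assigns:
            if isinstance(node, ast.AnnAssign):
                if isinstance(node.target, ast.Name):
                    declared.add(node.target.id)
                continue
            if not isinstance(node.value, ast.Constant):
                continue  # only literal constants are typed in this sketch
            value_type = type(node.value.value).__name__
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id not in declared:
                    first = inferred.setdefault(target.id, value_type)
                    if first != value_type:
                        errors.append(
                            f"line {node.lineno}: {target.id!r} is {first}, assigned {value_type}"
                        )
    return errors

SNIPPET = '''
def foo(a: bool):
    x = 5
    if a:
        x = "no longer an int"
    print(type(x))
'''
print(check_assignments(SNIPPET))  # ["line 5: 'x' is int, assigned str"]
print(check_assignments(SNIPPET.replace("x = 5", "x: int | str = 5")))  # []
```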
mypy also disallows adding attributes to objects at runtime unless the class declares them:
class Foo: pass

def bar(a: bool):
    b = Foo()
    if a:
        b.woo = lambda x: x + 2  # error: "Foo" has no attribute "woo"
    b.woo(12)
And we can again recover the dynamism with explicit typing[1]:
from typing import Callable

class Foo:
    woo: Callable[[int], int]

def bar(a: bool):
    b = Foo()
    if a:
        b.woo = lambda x: x + 2  # no error now
    b.woo(12)
# Success: no issues found
Essentially, mypy deliberately accepts only a subset of valid Python programs. It deals with the annoying (from an analysis perspective) parts of a dynamic language by leaving them out of its semantics unless they're explicitly annotated. This obviously makes analysis much easier! And with annotations and gradual typing as easy escape hatches, tools like mypy have chosen a decent tradeoff.
Control-Flow Analysis
- "After this statement, what statements could execute next?"
It's often useful to transform functions into an intermediate representation: a directed graph over blocks of statements that always execute together, in sequence, with no jumps into or out of the middle of a block. This is often called a Control-Flow Graph (CFG). Consider the function below:
def foo(a: int):
    if a > 5:
        print("a > 5")
    else:
        print("a <= 5")
This function would transform into a representation like this:
block0:
    if a > 5 jmp block1 else jmp block2
block1:
    print("a > 5")
    jmp block3
block2:
    print("a <= 5")
    jmp block3
block3:
    return None
Each block is a series of statements, ending with a terminator, which can be either a conditional/unconditional jump, or a return.
Having a representation like this is nice because it provides a standardized way to reason about control flow, rather than handling if statements, while loops, for loops, etc. each in their own way. Answering the question "What statements can execute after statement X?" also becomes easy. If X is not the last statement in its block, only the following statement in that block can execute next. If X is the terminator at the end of a block, the first statement of any block it may jump to could execute next.
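Here is a minimal sketch of the CFG above as a data structure, together with the "what can execute next?" query. Several choices are assumptions made for this sketch: statements are plain strings, and the names `Block` and `next_locations` and the `(block, index)` location scheme are all hypothetical. An index equal to `len(statements)` stands for the block's terminator.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

@dataclass
class Block:
    statements: List[str]   # straight-line statements (plain text in this sketch)
    terminator: str         # the jump or return that ends the block
    successors: List[str] = field(default_factory=list)  # blocks the terminator may jump to

# The CFG for foo above.
cfg: Dict[str, Block] = {
    "block0": Block([], "if a > 5 jmp block1 else jmp block2", ["block1", "block2"]),
    "block1": Block(['print("a > 5")'], "jmp block3", ["block3"]),
    "block2": Block(['print("a <= 5")'], "jmp block3", ["block3"]),
    "block3": Block([], "return None"),
}

# A location is (block name, index); index == len(statements) means the terminator.
Location = Tuple[str, int]

def next_locations(cfg: Dict[str, Block], loc: Location) -> List[Location]:
    """Which locations can execute immediately after `loc`?"""
    name, index = loc
    block = cfg[name]
    if index < len(block.statements):
        # Not the last thing in the block: control falls through to the next one.
        return [(name, index + 1)]
    # The terminator: the first location of any block it may jump to.
    return [(succ, 0) for succ in block.successors]

print(next_locations(cfg, ("block0", 0)))  # [('block1', 0), ('block2', 0)]
print(next_locations(cfg, ("block1", 0)))  # [('block1', 1)]
print(next_locations(cfg, ("block1", 1)))  # [('block3', 0)]
```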
So is converting a function to a CFG harder in Python than Rust?
Ignoring control flow arising from exceptions, constructing a CFG isn't that much trickier for Python programs than Rust programs. However, when doing analysis over the CFG, we might care about determining in what cases a jump condition is truthy vs. falsy.
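Truthiness is one reason this is harder in Python. `if cond:` calls `cond.__bool__()` if it exists. Otherwise it falls back to `__len__()`, and if neither exists the object is simply truthy. Deciding a branch statically can therefore mean reasoning about arbitrary user code. The classes below are hypothetical, written only to show that behavior.

```python
class Batch:
    """Hypothetical container with __len__ but no __bool__."""
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

class Flaky:
    """Hypothetical object whose truthiness flips on every check."""
    def __init__(self):
        self.checks = 0

    def __bool__(self):
        # Arbitrary code runs each time this object is used as a branch condition.
        self.checks += 1
        return self.checks % 2 == 1

def branch(cond) -> str:
    if cond:  # calls __bool__, else __len__, else the object is truthy
        return "then"
    return "else"

print(branch(Batch([])), branch(Batch([1, 2])))  # else then
flaky = Flaky()
print(branch(flaky), branch(flaky))              # then else
```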
For the purposes of the tensor shape inference analysis I'm working on, I'm curious whether ignoring conditions and assuming that any path through the CFG could be taken is a viable approach.
Because our analysis is a dataflow analysis, it maintains a mapping from variables to the possible shapes they could have at each statement, and joins the possible shapes when multiple control-flow paths meet. Here is an example for our analysis:
def foo(a: T["b d"], flag: bool):
    if flag:
        a = torch.flatten(a)
        # a has shape [bd]
    else:
        a = a * 2
        # a has shape [b d] (unchanged)
    # here a can have shape [bd] or [b d]
We're explicitly ignoring the flag condition and assuming that either branch could execute. Then, when inferring future shapes, we run our inference model on each possible shape a variable could have, accepting the ones where dimension constraints are met. If none meet dimension constraints, we throw an error. You could also imagine throwing an error if any of the possible shapes don't satisfy a dimension constraint, but we believed this would lead to a lot of false positives.
This approach has a potential issue: the set of possible shapes for a variable could get very large, and we might miss true errors. On the other hand, it probably won't produce many false positives.
While the true test will be applying our analysis to many real PyTorch programs, my hunch (hope?) is that this won't be too big of a problem in practice, as (1) most branches around tensor operations change tensor values, not tensor shapes, and (2) our strategy of rejecting inferred shapes that don't meet constraints will keep the set of possible shapes relatively small.
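Here is a minimal sketch of the join-and-filter behavior described above. It assumes shapes are tuples of symbolic dimension names and that each operation has a shape rule returning `None` when its constraints fail. The names `join`, `apply_rule`, and `matmul_by_dk` (a rule for `x @ W` with `W` of shape [d, k]) are hypothetical, not our tool's code. The flattened dimension is written `b*d` here rather than `bd`.

```python
from typing import Callable, Dict, FrozenSet, Optional, Tuple

Shape = Tuple[str, ...]              # one symbolic name per dimension
State = Dict[str, FrozenSet[Shape]]  # variable -> every shape it might have here

def join(left: State, right: State) -> State:
    """Where two control-flow paths meet, a variable may have any shape from either path."""
    return {
        var: left.get(var, frozenset()) | right.get(var, frozenset())
        for var in left.keys() | right.keys()
    }

def apply_rule(shapes: FrozenSet[Shape], rule: Callable[[Shape], Optional[Shape]]) -> FrozenSet[Shape]:
    """Run a shape rule on each candidate input; keep outputs whose constraints were met."""
    outputs = frozenset(o for o in (rule(s) for s in shapes) if o is not None)
    if not outputs:
        raise TypeError(f"no candidate shape satisfies the constraints: {sorted(shapes)}")
    return outputs

def flatten(shape: Shape) -> Optional[Shape]:
    return ("*".join(shape),)  # [b, d] -> [b*d]

def matmul_by_dk(shape: Shape) -> Optional[Shape]:
    # Hypothetical rule for `x @ W` with W of shape [d, k]: x's last dimension must be d.
    return shape[:-1] + ("k",) if shape and shape[-1] == "d" else None

# The `flag` example above: each branch produces its own state, then they are joined.
then_state: State = {"a": apply_rule(frozenset({("b", "d")}), flatten)}
else_state: State = {"a": frozenset({("b", "d")})}  # a * 2 keeps the shape
merged = join(then_state, else_state)
print(sorted(merged["a"]))                        # [('b', 'd'), ('b*d',)]
print(set(apply_rule(merged["a"], matmul_by_dk)))  # {('b', 'k')}
```

Dimensions are compared purely by name, so `b*d` never matches `d`. The flattened candidate is silently dropped rather than reported, which is the trade-off described above: fewer false positives, at the risk of missing a real error on the `flag=True` path.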
Object Tracking / Heap Analysis
- "What object instance does this variable reference?" and "What concrete value might this object take?"
- why do we need this?
  - a very common pattern is to have, e.g., an nn.Module's call signature be parameterized by its attributes
- why is pointer + constant analysis trickier in Python?
- what simplifying assumptions can we make?
At this point, we're assuming we have an IR for each function—consisting of a directed graph over blocks of statements—as well as inferred or annotated types for every variable (this includes the Any type).
- a few things we might want next:
  - pointer analysis: which variables can refer to which objects, without necessarily modeling the concrete states of those objects
  - some sort of constant analysis on top of that: modeling the values of objects' fields
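To make the pointer-analysis idea concrete, here is a minimal sketch of a flow-insensitive, Andersen-style points-to analysis. It abstracts each object by its allocation site and runs over a hypothetical four-statement mini-IR: allocate, copy, field store, field load. None of this comes from our tool; the statement encoding and the name `points_to` are made up for illustration. Field contents are tracked only as points-to sets, so a constant analysis would have to sit on top of it.

```python
from collections import defaultdict
from typing import DefaultDict, List, Set, Tuple

# Hypothetical mini-IR, one tuple per statement:
#   ("alloc", x, site)    x = <new object created at allocation site `site`>
#   ("copy",  x, y)       x = y
#   ("store", x, f, y)    x.f = y
#   ("load",  x, y, f)    x = y.f
Stmt = Tuple[str, ...]

def points_to(stmts: List[Stmt]):
    var_pts: DefaultDict[str, Set[str]] = defaultdict(set)               # variable -> sites
    field_pts: DefaultDict[Tuple[str, str], Set[str]] = defaultdict(set)  # (site, field) -> sites

    def add(target: Set[str], new: Set[str]) -> bool:
        """Union `new` into `target`; report whether anything changed."""
        before = len(target)
        target |= new
        return len(target) != before

    changed = True
    while changed:  # flow-insensitive: re-apply every statement until nothing grows
        changed = False
        for stmt in stmts:
            kind = stmt[0]
            if kind == "alloc":
                _, x, site = stmt
                changed |= add(var_pts[x], {site})
            elif kind == "copy":
                _, x, y = stmt
                changed |= add(var_pts[x], set(var_pts[y]))
            elif kind == "store":
                _, x, f, y = stmt
                for base in list(var_pts[x]):
                    changed |= add(field_pts[(base, f)], set(var_pts[y]))
            elif kind == "load":
                _, x, y, f = stmt
                for base in list(var_pts[y]):
                    changed |= add(var_pts[x], set(field_pts[(base, f)]))
            else:
                raise ValueError(f"unknown statement kind: {kind!r}")
    return var_pts, field_pts

# Roughly: self = Module(); lin = Linear(); alias = self; self.proj = lin; p = alias.proj
# (listed out of order on purpose: a flow-insensitive analysis ignores statement order)
prog = [
    ("load", "p", "alias", "proj"),
    ("alloc", "self", "Module@1"),
    ("alloc", "lin", "Linear@2"),
    ("copy", "alias", "self"),
    ("store", "self", "proj", "lin"),
]
var_pts, field_pts = points_to(prog)
print(var_pts["p"])  # {'Linear@2'}
```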
Iterators
- "How many times will this loop run, and what values will the loop variable take?"
- Python
- Rust
  - because method dispatch is easy, iteration can be represented easily
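On the Python side, it helps to spell out what a `for` loop actually does. Below is a simplified desugaring that ignores `else` clauses and `break`/`continue`; `desugared_for` and `Countdown` are hypothetical names. Both the number of iterations and the values of the loop variable come from whichever `__iter__` and `__next__` methods the iterable resolves to at runtime. Answering the loop question therefore runs straight back into the method-dispatch problem from the Typing section. Rust's `for` loop is also sugar over an iterator (`IntoIterator` and `Iterator::next`), but there the `next` being called is known statically.

```python
from typing import Callable, Iterable, TypeVar

Item = TypeVar("Item")

def desugared_for(iterable: Iterable[Item], body: Callable[[Item], None]) -> None:
    """Roughly `for x in iterable: body(x)`, ignoring else/break/continue."""
    iterator = iter(iterable)       # dispatches to type(iterable).__iter__
    while True:
        try:
            x = next(iterator)      # dispatches to type(iterator).__next__
        except StopIteration:
            break
        body(x)

class Countdown:
    """Hypothetical iterable whose trip count depends on a runtime value."""
    def __init__(self, start: int):
        self.start = start

    def __iter__(self):
        n = self.start
        while n > 0:
            yield n
            n -= 1

seen = []
desugared_for(Countdown(3), seen.append)
print(seen)  # [3, 2, 1]
```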
How far can the approach taken by RacerD (a static race detector) get you?
- if we're willing to sacrifice soundness, can we make assumptions about what types of programming patterns sane developers use that make analysis more tractable?
[1] Our woo annotation tells mypy to trust us that .woo on a Foo instance will resolve correctly, yet unless we monkey-patch Foo at runtime to have a woo callable, we'll still get an AttributeError at runtime when we call bar(a=False).
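We can check this claim at runtime. A bare class-level annotation is only recorded in `Foo.__annotations__` and doesn't create an attribute, so the `a=False` path still fails. The snippet repeats the mypy example, except that `bar` returns the call's result so it can be inspected.

```python
from typing import Callable

class Foo:
    woo: Callable[[int], int]  # annotation only: informs the type checker, assigns nothing

def bar(a: bool) -> int:
    b = Foo()
    if a:
        b.woo = lambda x: x + 2
    return b.woo(12)

print("woo" in Foo.__annotations__)  # True
print(hasattr(Foo, "woo"))           # False
print(bar(True))                     # 14
try:
    bar(False)
except AttributeError as err:
    print("AttributeError:", err)
```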