[RFC][vLLM IR]: Automatically compile native impl for IR ops #38744
Description
Motivation.
Sometimes we need to call an IR op inside another opaque torch custom op. That means the IR op will be invisible to model-level compilation, and dispatching to the raw native implementation will hurt performance. This problem is not unique to vLLM IR; it happens for CustomOp instances as well, and we currently circumvent it by wrapping forward_native with torch.compile.
Prime examples of this are SiluAndMul and QuantFP8 inside fused_moe; the same mechanism is used by _DecodeConcatQuantFP8 inside the MLA custom op.
Proposed Change.
We wrap the native implementation (or multiple implementations) with a torch.compile decorator. We can do that by setting IrOpImpl.impl_fn = torch.compile(IrOpImpl.impl_fn, ...) (including dynamic shape annotations).
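A minimal sketch of the swap-and-restore pattern, assuming `IrOpImpl` exposes its native callable as `impl_fn` (class and attribute names here are hypothetical). A no-op stand-in is used in place of `torch.compile` so the sketch is self-contained; in vLLM this would be `torch.compile(fn, ...)` with dynamic shape annotations:

```python
from contextlib import contextmanager

def fake_compile(fn):
    """Stand-in for torch.compile so the sketch runs without torch."""
    def compiled(*args, **kwargs):
        return fn(*args, **kwargs)
    compiled.__wrapped__ = fn
    return compiled

class IrOpImpl:
    """Hypothetical holder for an IR op's native implementation."""
    def __init__(self, impl_fn):
        self.impl_fn = impl_fn

@contextmanager
def compiled_native(op, compile_fn=fake_compile):
    # Swap in a compiled impl for the duration of the context,
    # then restore the raw native implementation on exit.
    original = op.impl_fn
    op.impl_fn = compile_fn(original)
    try:
        yield op
    finally:
        op.impl_fn = original

op = IrOpImpl(lambda x: x * 2)
with compiled_native(op):
    result = op.impl_fn(21)  # runs through the "compiled" wrapper
assert result == 42
assert op.impl_fn(3) == 6  # raw impl restored after the context
```

Tying the swap to a context manager (e.g. the set_priority context) keeps the compiled wrapper from leaking into unrelated code, which is the lifetime concern discussed below.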
The big question is lifetime: ideally we would set this inside the set_priority context and restore it afterwards. But will torch persist the compiled code across context entries, or will it recompile every time? If torch doesn't cache it, we could cache the compiled functions manually.
We can optionally guard this with torch.compiler.is_compiling() so we don't wrap an impl that is already inside a compiled region, although I think torch.compile already handles that for us.
Alternative: just compile once
I worry this would let compiled state escape the context arbitrarily, so multiple LLM instances with different configs could affect each other.
Alternative: register a compiled_native implementation
I think this is worse, because the dispatching logic becomes more complex, and we'd have to dispatch differently in the "global" and "wrapped" regions.
Feedback Period.
4/1 - 4/8
CC List.
@zou3519 @tjtanaa @gmagogsfm @angelayi @bringlein @LucasWilkinson @mgoin
Any Other Things.
No response