-
Notifications
You must be signed in to change notification settings - Fork 608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LinalgExt] Generalize attribute setting for attention decomposition #18780
Conversation
compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
Show resolved
Hide resolved
auto qkAttrs = (*this)->getAttrOfType<DictionaryAttr>("qk_attrs"); | ||
auto pvAttrs = (*this)->getAttrOfType<DictionaryAttr>("pv_attrs"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is better to declare constant variables for qk/pv_attrs
because they are also used in KernelConfig.cpp
. Also, we can get rid of magic numbers/strings. We can declare the constant variables in MarkerUtils.h. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we can do that. MarkerUtils is part of Codegen/, while LinalgExt is part of Dialects/. I can make it part of the attention op definition as well I guess. But it doesnt really make sense to be part of op definition as well? It's codegen specific.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, I missed that! Can we add some getters to extraClassDeclaration
? The concern is that the magic strings are floating in the codebase, which looks brittle to me. It looks better if we can centralize to somewhere, e.g., we can do something like:
// They share the same entry which could be "config".
void setConfig(DictionaryAttr attr) { ... }
DictionaryAttr getConfig() { return (*this)->getAttrOfType<DictionaryAttr>("config"); }
bool hasConfig(StringRef str) {
auto config = getConfig();
if (!config) return false;
return config.getNamed(str);
}
// Marker-like things
static StringRef getQKAttrStr() { return "qk_attrs"; }
static StringRef getPVAttrStr() { return "pv_attrs"; }
iree/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
Lines 530 to 561 in 0c2c627
let extraClassDeclaration = [{ | |
// Method to implement for specifying output range for | |
// DestinationStyleOpInterface | |
MutableOperandRange getDpsInitsMutable(); | |
SmallVector<AffineMap> getIndexingMapsArray(); | |
AffineMap getQueryMap() { | |
return cast<AffineMap>(getIndexingMapsArray()[0]); | |
} | |
AffineMap getKeyMap() { | |
return cast<AffineMap>(getIndexingMapsArray()[1]); | |
} | |
AffineMap getValueMap() { | |
return cast<AffineMap>(getIndexingMapsArray()[2]); | |
} | |
AffineMap getScaleMap() { | |
return cast<AffineMap>(getIndexingMapsArray()[3]); | |
} | |
std::optional<AffineMap> getMaskMap() { | |
if (getMask()) { | |
return cast<AffineMap>(getIndexingMapsArray()[4]); | |
} | |
return std::nullopt; | |
} | |
AffineMap getOutputMap() { | |
return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs()]); | |
} | |
int64_t getIterationDomainRank() { | |
return getQueryMap().getNumDims(); | |
} | |
}]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a decompositon_config attribute to AttentionOp, which acts as the config attribute you asked for above. I also added getQKAttrStr/getPVAttrStr to extraClassDeclarations.
4ef4853
to
9152ae3
Compare
9152ae3
to
e7c3c79
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just one final question
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, thanks!
This PR teaches attention decomposition to set attributes for attention matmuls by passing attribute dictionaries to iree_linalg_ext.online_attention operation. This allows us to further control codegen of matmuls (generally the root operations) after decomposition (for example, setting lowering config on the decompose matmuls).