-
Notifications
You must be signed in to change notification settings - Fork 12.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Single commit implementing the enzyme/autodiff frontend
Co-authored-by: Lorenz Schmidt <bytesnake@mailbox.org>
- Loading branch information
Showing
23 changed files
with
1,875 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,281 @@ | ||
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute, | ||
//! we create an `AutoDiffItem` which contains the source and target function names. The source | ||
//! is the function to which the autodiff attribute is applied, and the target is the function | ||
//! getting generated by us (with a name given by the user as the first autodiff arg). | ||
|
||
use std::fmt::{self, Display, Formatter}; | ||
use std::str::FromStr; | ||
|
||
use crate::expand::typetree::TypeTree; | ||
use crate::expand::{Decodable, Encodable, HashStable_Generic}; | ||
use crate::ptr::P; | ||
use crate::{Ty, TyKind}; | ||
|
||
/// Forward and Reverse Mode are well known names for automatic differentiation implementations. | ||
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants | ||
/// are a hack to support higher order derivatives. We need to compute first order derivatives | ||
/// before we compute second order derivatives, otherwise we would differentiate our placeholder | ||
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations, | ||
/// as it's already done in the C++ and Julia frontend of Enzyme. (FIXME) remove *First variants. | ||
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and | ||
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online. | ||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub enum DiffMode { | ||
/// No autodiff is applied (used during error handling). | ||
Error, | ||
/// The primal function which we will differentiate. | ||
Source, | ||
/// The target function, to be created using forward mode AD. | ||
Forward, | ||
/// The target function, to be created using reverse mode AD. | ||
Reverse, | ||
/// The target function, to be created using forward mode AD. | ||
/// This target function will also be used as a source for higher order derivatives, | ||
/// so compute it before all Forward/Reverse targets and optimize it through llvm. | ||
ForwardFirst, | ||
/// The target function, to be created using reverse mode AD. | ||
/// This target function will also be used as a source for higher order derivatives, | ||
/// so compute it before all Forward/Reverse targets and optimize it through llvm. | ||
ReverseFirst, | ||
} | ||
|
||
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity. | ||
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode | ||
/// we add to the previous shadow value. To not surprise users, we picked different names. | ||
/// Dual numbers is also a quite well known name for forward mode AD types. | ||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub enum DiffActivity { | ||
/// Implicit or Explicit () return type, so a special case of Const. | ||
None, | ||
/// Don't compute derivatives with respect to this input/output. | ||
Const, | ||
/// Reverse Mode, Compute derivatives for this scalar input/output. | ||
Active, | ||
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute | ||
/// the original return value. | ||
ActiveOnly, | ||
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument | ||
/// with it. | ||
Dual, | ||
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument | ||
/// with it. Drop the code which updates the original input/output for maximum performance. | ||
DualOnly, | ||
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. | ||
Duplicated, | ||
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. | ||
/// Drop the code which updates the original input for maximum performance. | ||
DuplicatedOnly, | ||
/// All Integers must be Const, but these are used to mark the integer which represents the | ||
/// length of a slice/vec. This is used for safety checks on slices. | ||
FakeActivitySize, | ||
} | ||
/// We generate one of these structs for each `#[autodiff(...)]` attribute. | ||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct AutoDiffItem { | ||
/// The name of the function getting differentiated | ||
pub source: String, | ||
/// The name of the function being generated | ||
pub target: String, | ||
pub attrs: AutoDiffAttrs, | ||
/// Describe the memory layout of input types | ||
pub inputs: Vec<TypeTree>, | ||
/// Describe the memory layout of the output type | ||
pub output: TypeTree, | ||
} | ||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct AutoDiffAttrs { | ||
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and | ||
/// e.g. in the [JAX | ||
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions). | ||
pub mode: DiffMode, | ||
pub ret_activity: DiffActivity, | ||
pub input_activity: Vec<DiffActivity>, | ||
} | ||
|
||
impl DiffMode { | ||
pub fn is_rev(&self) -> bool { | ||
matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst) | ||
} | ||
pub fn is_fwd(&self) -> bool { | ||
matches!(self, DiffMode::Forward | DiffMode::ForwardFirst) | ||
} | ||
} | ||
|
||
impl Display for DiffMode { | ||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { | ||
match self { | ||
DiffMode::Error => write!(f, "Error"), | ||
DiffMode::Source => write!(f, "Source"), | ||
DiffMode::Forward => write!(f, "Forward"), | ||
DiffMode::Reverse => write!(f, "Reverse"), | ||
DiffMode::ForwardFirst => write!(f, "ForwardFirst"), | ||
DiffMode::ReverseFirst => write!(f, "ReverseFirst"), | ||
} | ||
} | ||
} | ||
|
||
/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...). | ||
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...). | ||
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output. | ||
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg, | ||
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong. | ||
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { | ||
if activity == DiffActivity::None { | ||
// Only valid if primal returns (), but we can't check that here. | ||
return true; | ||
} | ||
match mode { | ||
DiffMode::Error => false, | ||
DiffMode::Source => false, | ||
DiffMode::Forward | DiffMode::ForwardFirst => { | ||
activity == DiffActivity::Dual | ||
|| activity == DiffActivity::DualOnly | ||
|| activity == DiffActivity::Const | ||
} | ||
DiffMode::Reverse | DiffMode::ReverseFirst => { | ||
activity == DiffActivity::Const | ||
|| activity == DiffActivity::Active | ||
|| activity == DiffActivity::ActiveOnly | ||
} | ||
} | ||
} | ||
|
||
/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value | ||
/// for the given argument, but we generally can't know the size of such a type. | ||
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated, | ||
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value | ||
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent | ||
/// users here from marking scalars as Duplicated, due to type aliases. | ||
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool { | ||
use DiffActivity::*; | ||
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it. | ||
if matches!(activity, Const) { | ||
return true; | ||
} | ||
if matches!(activity, Dual | DualOnly) { | ||
return true; | ||
} | ||
// FIXME(ZuseZ4) We should make this more robust to also | ||
// handle type aliases. Once that is done, we can be more restrictive here. | ||
if matches!(activity, Active | ActiveOnly) { | ||
return true; | ||
} | ||
matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..)) | ||
&& matches!(activity, Duplicated | DuplicatedOnly) | ||
} | ||
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { | ||
use DiffActivity::*; | ||
return match mode { | ||
DiffMode::Error => false, | ||
DiffMode::Source => false, | ||
DiffMode::Forward | DiffMode::ForwardFirst => { | ||
matches!(activity, Dual | DualOnly | Const) | ||
} | ||
DiffMode::Reverse | DiffMode::ReverseFirst => { | ||
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const) | ||
} | ||
}; | ||
} | ||
|
||
impl Display for DiffActivity { | ||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { | ||
match self { | ||
DiffActivity::None => write!(f, "None"), | ||
DiffActivity::Const => write!(f, "Const"), | ||
DiffActivity::Active => write!(f, "Active"), | ||
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"), | ||
DiffActivity::Dual => write!(f, "Dual"), | ||
DiffActivity::DualOnly => write!(f, "DualOnly"), | ||
DiffActivity::Duplicated => write!(f, "Duplicated"), | ||
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"), | ||
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"), | ||
} | ||
} | ||
} | ||
|
||
impl FromStr for DiffMode { | ||
type Err = (); | ||
|
||
fn from_str(s: &str) -> Result<DiffMode, ()> { | ||
match s { | ||
"Error" => Ok(DiffMode::Error), | ||
"Source" => Ok(DiffMode::Source), | ||
"Forward" => Ok(DiffMode::Forward), | ||
"Reverse" => Ok(DiffMode::Reverse), | ||
"ForwardFirst" => Ok(DiffMode::ForwardFirst), | ||
"ReverseFirst" => Ok(DiffMode::ReverseFirst), | ||
_ => Err(()), | ||
} | ||
} | ||
} | ||
impl FromStr for DiffActivity { | ||
type Err = (); | ||
|
||
fn from_str(s: &str) -> Result<DiffActivity, ()> { | ||
match s { | ||
"None" => Ok(DiffActivity::None), | ||
"Active" => Ok(DiffActivity::Active), | ||
"ActiveOnly" => Ok(DiffActivity::ActiveOnly), | ||
"Const" => Ok(DiffActivity::Const), | ||
"Dual" => Ok(DiffActivity::Dual), | ||
"DualOnly" => Ok(DiffActivity::DualOnly), | ||
"Duplicated" => Ok(DiffActivity::Duplicated), | ||
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly), | ||
_ => Err(()), | ||
} | ||
} | ||
} | ||
|
||
impl AutoDiffAttrs { | ||
pub fn has_ret_activity(&self) -> bool { | ||
self.ret_activity != DiffActivity::None | ||
} | ||
pub fn has_active_only_ret(&self) -> bool { | ||
self.ret_activity == DiffActivity::ActiveOnly | ||
} | ||
|
||
pub fn error() -> Self { | ||
AutoDiffAttrs { | ||
mode: DiffMode::Error, | ||
ret_activity: DiffActivity::None, | ||
input_activity: Vec::new(), | ||
} | ||
} | ||
pub fn source() -> Self { | ||
AutoDiffAttrs { | ||
mode: DiffMode::Source, | ||
ret_activity: DiffActivity::None, | ||
input_activity: Vec::new(), | ||
} | ||
} | ||
|
||
pub fn is_active(&self) -> bool { | ||
self.mode != DiffMode::Error | ||
} | ||
|
||
pub fn is_source(&self) -> bool { | ||
self.mode == DiffMode::Source | ||
} | ||
pub fn apply_autodiff(&self) -> bool { | ||
!matches!(self.mode, DiffMode::Error | DiffMode::Source) | ||
} | ||
|
||
pub fn into_item( | ||
self, | ||
source: String, | ||
target: String, | ||
inputs: Vec<TypeTree>, | ||
output: TypeTree, | ||
) -> AutoDiffItem { | ||
AutoDiffItem { source, target, inputs, output, attrs: self } | ||
} | ||
} | ||
|
||
impl fmt::Display for AutoDiffItem { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
write!(f, "Differentiating {} -> {}", self.source, self.target)?; | ||
write!(f, " with attributes: {:?}", self.attrs)?; | ||
write!(f, " with inputs: {:?}", self.inputs)?; | ||
write!(f, " with output: {:?}", self.output) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
use std::fmt; | ||
|
||
use crate::expand::{Decodable, Encodable, HashStable_Generic}; | ||
|
||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub enum Kind { | ||
Anything, | ||
Integer, | ||
Pointer, | ||
Half, | ||
Float, | ||
Double, | ||
Unknown, | ||
} | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct TypeTree(pub Vec<Type>); | ||
|
||
impl TypeTree { | ||
pub fn new() -> Self { | ||
Self(Vec::new()) | ||
} | ||
pub fn all_ints() -> Self { | ||
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }]) | ||
} | ||
pub fn int(size: usize) -> Self { | ||
let mut ints = Vec::with_capacity(size); | ||
for i in 0..size { | ||
ints.push(Type { | ||
offset: i as isize, | ||
size: 1, | ||
kind: Kind::Integer, | ||
child: TypeTree::new(), | ||
}); | ||
} | ||
Self(ints) | ||
} | ||
} | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct FncTree { | ||
pub args: Vec<TypeTree>, | ||
pub ret: TypeTree, | ||
} | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct Type { | ||
pub offset: isize, | ||
pub size: usize, | ||
pub kind: Kind, | ||
pub child: TypeTree, | ||
} | ||
|
||
impl Type { | ||
pub fn add_offset(self, add: isize) -> Self { | ||
let offset = match self.offset { | ||
-1 => add, | ||
x => add + x, | ||
}; | ||
|
||
Self { size: self.size, kind: self.kind, child: self.child, offset } | ||
} | ||
} | ||
|
||
impl fmt::Display for Type { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
<Self as fmt::Debug>::fmt(self, f) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.