diff --git a/.gitignore b/.gitignore index 046bc62068..42b36eacdc 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ src/generator/generator *.csproj *.ilk *.manifest +*.tmp /build/vs2012 /build/vs2013 /build/gmake diff --git a/src/AST/Class.cs b/src/AST/Class.cs index e61296195f..34f013a854 100644 --- a/src/AST/Class.cs +++ b/src/AST/Class.cs @@ -113,6 +113,7 @@ public Class() IsUnion = false; IsOpaque = false; IsPOD = false; + IsForcedRefType = false; Type = ClassType.RefType; Layout = new ClassLayout(); } @@ -137,6 +138,13 @@ public Class(Class @class) HasNonTrivialCopyConstructor = @class.HasNonTrivialCopyConstructor; HasNonTrivialDestructor = @class.HasNonTrivialDestructor; IsStatic = @class.IsStatic; + IsForcedRefType = @class.IsForcedRefType; + } + + public bool IsForcedRefType + { + get; + set; } public bool HasBase diff --git a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs index 0b52e3b95b..536a88dfa3 100644 --- a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs +++ b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs @@ -2573,10 +2573,10 @@ private ParamMarshal GenerateFunctionParamMarshal(Parameter param, int paramInde var paramType = param.Type; Class @class; - if ( (paramType.GetFinalPointee() ?? paramType).Desugar().TryGetClass(out @class) - && @class.IsRefType) + if ( (paramType.GetFinalPointee() ?? paramType).Desugar().TryGetClass(out @class)) { - WriteLine("{0} = new {1}();", param.Name, paramType); + if(@class.IsRefType || @class.IsValueType) + WriteLine("{0} = new {1}();", param.Name, paramType); } } diff --git a/src/Generator/Passes/CheckMacrosPass.cs b/src/Generator/Passes/CheckMacrosPass.cs index 1946c30fd1..568f62db05 100644 --- a/src/Generator/Passes/CheckMacrosPass.cs +++ b/src/Generator/Passes/CheckMacrosPass.cs @@ -46,6 +46,10 @@ namespace CppSharp.Passes /// Used to flag a method as internal to an assembly. So, it is /// not accessible outside that assembly. /// + /// CS_REF_TYPE (classes and structs) + /// Used to flag a type as ref type. So, it is generated as a C# class + /// even when the CheckValueTypeClassesPass is enabled. + /// /// There isn't a standardized header provided by CppSharp so you will /// have to define these on your own. /// @@ -109,6 +113,8 @@ public override bool VisitClassDecl(Class @class) if (expansions.Any(e => e.Text == Prefix + "_VALUE_TYPE")) @class.Type = ClassType.ValueType; + else if (expansions.Any(e => e.Text == Prefix + "_REF_TYPE")) + @class.IsForcedRefType = true; // If the class is a forward declaration, then we process the macro expansions // of the complete class as if they were specified on the forward declaration. diff --git a/src/Generator/Passes/CheckValueTypeClassesPass.cs b/src/Generator/Passes/CheckValueTypeClassesPass.cs new file mode 100644 index 0000000000..3c696d23a4 --- /dev/null +++ b/src/Generator/Passes/CheckValueTypeClassesPass.cs @@ -0,0 +1,39 @@ +using System.Linq; +using CppSharp.AST; + +namespace CppSharp.Passes +{ + public class CheckValueTypeClassesPass : TranslationUnitPass + { + public CheckValueTypeClassesPass() + { + } + + public override bool VisitClassDecl(Class @class) + { + @class.Type = CheckClassIsStructible(@class, Driver) ? ClassType.ValueType : @class.Type; + return base.VisitClassDecl(@class); + } + + private bool CheckClassIsStructible(Class @class, Driver Driver) + { + if (@class.IsUnion || @class.Namespace.Templates.Any(tmp => tmp.Name.Equals(@class.Name))) + return false; + if (@class.IsInterface || @class.IsStatic || @class.IsAbstract) + return false; + if (@class.Declarations.Any(decl => decl.Access == AccessSpecifier.Protected)) + return false; + if (@class.IsDynamic) + return false; + if (@class.HasBaseClass && @class.BaseClass.IsRefType) + return false; + + var allTrUnits = Driver.ASTContext.TranslationUnits; + if (allTrUnits.Any(trUnit => trUnit.Classes.Any( + cls => cls.Bases.Any(clss => clss.IsClass && clss.Class == @class)))) + return false; + + return @class.IsPOD && !@class.IsForcedRefType; + } + } +} \ No newline at end of file diff --git a/tests/Basic/Basic.cs b/tests/Basic/Basic.cs index 269386a702..27213e98f8 100644 --- a/tests/Basic/Basic.cs +++ b/tests/Basic/Basic.cs @@ -32,6 +32,7 @@ public override void Preprocess(Driver driver, ASTContext ctx) { driver.AddTranslationUnitPass(new GetterSetterToPropertyPass()); driver.AddTranslationUnitPass(new CheckMacroPass()); + driver.AddTranslationUnitPass(new CheckValueTypeClassesPass()); ctx.SetClassAsValueType("Bar"); ctx.SetClassAsValueType("Bar2"); ctx.IgnoreClassWithName("IgnoredType"); diff --git a/tests/Basic/Basic.h b/tests/Basic/Basic.h index c9e8b9b8bf..a0582bd2d0 100644 --- a/tests/Basic/Basic.h +++ b/tests/Basic/Basic.h @@ -799,3 +799,8 @@ class DLL_API ReturnsEmpty public: Empty getEmpty(); }; + +DLL_API class TestIsStructFreeClass { }; +DLL_API struct TestIsStructFreeStruct { }; +DLL_API struct TestIsStructInheritedStruct { }; +DLL_API struct TestIsStructInheritingStruct : public TestIsStructInheritedStruct { }; \ No newline at end of file diff --git a/tests/CLITemp/CLITemp.cs b/tests/CLITemp/CLITemp.cs index 71a42f1758..f92266a95e 100644 --- a/tests/CLITemp/CLITemp.cs +++ b/tests/CLITemp/CLITemp.cs @@ -1,5 +1,6 @@ using CppSharp.AST; using CppSharp.Generators; +using CppSharp.Passes; using CppSharp.Utils; namespace CppSharp.Tests @@ -20,6 +21,7 @@ public override void Setup(Driver driver) public override void Preprocess(Driver driver, ASTContext ctx) { + driver.TranslationUnitPasses.AddPass(new CheckValueTypeClassesPass()); } public static void Main(string[] args) diff --git a/tests/STL/STL.cs b/tests/STL/STL.cs index d87e2d70b8..07858f59ad 100644 --- a/tests/STL/STL.cs +++ b/tests/STL/STL.cs @@ -1,5 +1,6 @@ using CppSharp.AST; using CppSharp.Generators; +using CppSharp.Passes; using CppSharp.Utils; namespace CppSharp.Tests @@ -14,6 +15,8 @@ public STL(GeneratorKind kind) public override void Preprocess(Driver driver, ASTContext ctx) { ctx.SetClassAsValueType("IntWrapperValueType"); + driver.TranslationUnitPasses.AddPass(new CheckMacroPass()); + driver.TranslationUnitPasses.AddPass(new CheckValueTypeClassesPass()); } public static void Main(string[] args) diff --git a/tests/STL/STL.h b/tests/STL/STL.h index 3220da3a75..8a63e515ee 100644 --- a/tests/STL/STL.h +++ b/tests/STL/STL.h @@ -2,7 +2,8 @@ #include #include -struct DLL_API IntWrapper +#define CS_REF_TYPE +struct DLL_API CS_REF_TYPE IntWrapper { int Value; };