diff --git a/src/dual.jl b/src/dual.jl index ad0af94..d71fa34 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -170,17 +170,17 @@ end Base.convert(::Type{Dual}, z::Dual) = z Base.convert(::Type{Dual}, x::Number) = Dual(x) -Base.:(==)(z::Dual, w::Dual) = value(z) == value(w) -Base.:(==)(z::Dual, x::Number) = value(z) == x -Base.:(==)(x::Number, z::Dual) = value(z) == x +Base.:(==)(z::Dual, w::Dual) = value(z) == value(w) && epsilon(z) == epsilon(w) +Base.:(==)(z::Dual, x::Number) = value(z) == x && iszero(epsilon(z)) +Base.:(==)(x::Number, z::Dual) = z == x Base.isequal(z::Dual, w::Dual) = isequal(value(z),value(w)) && isequal(epsilon(z), epsilon(w)) Base.isequal(z::Dual, x::Number) = isequal(value(z), x) && isequal(epsilon(z), zero(x)) Base.isequal(x::Number, z::Dual) = isequal(z, x) -Base.isless(z::Dual{<:Real},w::Dual{<:Real}) = value(z) < value(w) -Base.isless(z::Real,w::Dual{<:Real}) = z < value(w) -Base.isless(z::Dual{<:Real},w::Real) = value(z) < w +Base.isless(z::Dual{<:Real},w::Dual{<:Real}) = isless(value(z), value(w)) || (isequal(value(z), value(w)) && isless(epsilon(z), epsilon(w))) +Base.isless(z::Real,w::Dual{<:Real}) = isless(z, value(w)) || (isequal(z, value(w)) && isless(zero(epsilon(w)), epsilon(w))) +Base.isless(z::Dual{<:Real},w::Real) = isless(value(z), w) || (isequal(value(z), w) && isless(epsilon(z), zero(epsilon(z)))) Base.hash(z::Dual) = (x = hash(value(z)); epsilon(z)==0 ? x : bitmix(x,hash(epsilon(z)))) diff --git a/test/automatic_differentiation_test.jl b/test/automatic_differentiation_test.jl index 2af2b2a..2b43cb2 100644 --- a/test/automatic_differentiation_test.jl +++ b/test/automatic_differentiation_test.jl @@ -43,20 +43,26 @@ powwrap(z, n, epspart=0) = Dual(z, epspart)^n @test powwrap(1, -1) == powwrap(1.0, -1) # special case is handled @test powwrap(1, -2) == powwrap(1.0, -2) # special case is handled @test powwrap(1, -123) == powwrap(1.0, -123) # special case is handled -@test powwrap(1, 0) == Dual(1, 1) -@test powwrap(123, 0) == Dual(1, 1) +@test powwrap(1, 0) == Dual(1, 0) +@test powwrap(1, 0) != Dual(1, 1) +@test powwrap(123, 0) == Dual(1, 0) +@test powwrap(123, 0) != Dual(1, 1) for i ∈ -3:3 - @test powwrap(1, i) == Dual(1, i) + @test powwrap(1, i) == Dual(1, 0) + @test i == 0 || (powwrap(1, i) != Dual(1, i)) end # this no longer throws 1/0 DomainError -@test powwrap(0, Dual(0, 1)) == Dual(1, 0) +@test powwrap(0, Dual(0, 1)) == Dual(1, -Inf) +@test powwrap(0, Dual(0, 1)) != Dual(1, 0) # this never did DomainError because it starts off with a float -@test 0.0^Dual(0, 1) == Dual(1.0, NaN) +@test 0.0^Dual(0, 1) == Dual(1.0, -Inf) +@test 0.0^Dual(0, 1) != Dual(1.0, NaN) # and Dual^Dual uses a log and is now type stable # because the log promotes ints to floats for all values @test typeof(value(powwrap(0, Dual(0, 1)))) == Float64 -@test Dual(0, 1)^Dual(0, 1) == Dual(1, 0) +@test Dual(0, 1)^Dual(0, 1) == Dual(1, -Inf) +@test Dual(0, 1)^Dual(0, 1) != Dual(1, 0) y = Dual(2.0, 1)^UInt64(0) @test !isnan(epsilon(y))