Skip to content

Commit

Permalink
Merge pull request #840 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.20.19 release
  • Loading branch information
ablaom authored Sep 15, 2022
2 parents 2023856 + d2af6bd commit 4cb559c
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 19 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.20.18"
version = "0.20.19"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down Expand Up @@ -46,4 +46,4 @@ ScientificTypes = "3"
StatisticalTraits = "3.2"
StatsBase = "0.32, 0.33"
Tables = "0.2, 1.0"
julia = "1.6"
julia = "1.6"
39 changes: 22 additions & 17 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,20 @@
## TODO: need to add checks on the arguments of
## predict(::Machine, ) and transform(::Machine, )

_err_rows_not_allowed() =
throw(ArgumentError("Calling `transform(mach, rows=...)` or "*
"`predict(mach, rows=...)` when "*
"`mach.model isa Static` is not allowed, as no data "*
"is bound to `mach` in this case. Specify a explicit "*
"data or node, as in `transform(mach, X)`, or "*
"`transform(mach, X1, X2, ...)`. "))
_err_serialized(operation) =
throw(ArgumentError("Calling $operation on a "*
"deserialized machine with no data "*
"bound to it. "))
const ERR_ROWS_NOT_ALLOWED = ArgumentError(
"Calling `transform(mach, rows=...)` or "*
"`predict(mach, rows=...)` when "*
"`mach.model isa Static` is not allowed, as no data "*
"is bound to `mach` in this case. Specify a explicit "*
"data or node, as in `transform(mach, X)`, or "*
"`transform(mach, X1, X2, ...)`. "
)

err_serialized(operation) = ArgumentError(
"Calling $operation on a "*
"deserialized machine with no data "*
"bound to it. "
)

warn_serializable_mach(operation) = "The operation $operation has been called on a "*
"deserialised machine mach whose learned parameters "*
Expand All @@ -42,9 +45,11 @@ warn_serializable_mach(operation) = "The operation $operation has been called on
# `ret` in the ordinary case that the operation does not include an "report" component ;
# otherwise update `mach.report` with that component and return the non-report part of
# `ret`:
named_tuple(t::Nothing) = NamedTuple()
named_tuple(t) = t
function get!(ret, operation, mach)
if operation in reporting_operations(mach.model)
report = last(ret)
report = named_tuple(last(ret))
if isnothing(mach.report) || isempty(mach.report)
mach.report = report
else
Expand All @@ -66,13 +71,12 @@ for operation in OPERATIONS
ex = quote
function $(operation)(mach::Machine{<:Model,false}; rows=:)
# catch deserialized machine with no data:
isempty(mach.args) && _err_serialized($operation)
ret = ($operation)(mach, mach.args[1](rows=rows))
return get!(ret, $quoted_operation, mach)
isempty(mach.args) && throw(err_serialized($operation))
return ($operation)(mach, mach.args[1](rows=rows))
end
function $(operation)(mach::Machine{<:Model,true}; rows=:)
# catch deserialized machine with no data:
isempty(mach.args) && _err_serialized($operation)
isempty(mach.args) && throw(err_serialized($operation))
model = mach.model
ret = ($operation)(
model,
Expand All @@ -83,7 +87,8 @@ for operation in OPERATIONS
end

# special case of Static models (no training arguments):
$operation(mach::Machine{<:Static}; rows=:) = _err_rows_not_allowed()
$operation(mach::Machine{<:Static,true}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED)
$operation(mach::Machine{<:Static,false}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED)
end
eval(ex)

Expand Down
1 change: 1 addition & 0 deletions test/composition/models/static_transformers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ MLJBase.predict(transf::YourTransformer, verbosity, X) =
X = (x1=rand(3), x2=[1, 2, 3]);
mach = machine(YourTransformer(:x2))
@test transform(mach, X) == [1, 2, 3]
@test_throws MLJBase.ERR_ROWS_NOT_ALLOWED transform(mach, rows=:)
@test predict(mach, X) == [3, 2, 1]
@test report(mach).nrows == 3
transform(mach, (x2=["a", "b"],))
Expand Down
12 changes: 12 additions & 0 deletions test/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,18 @@ end
rm(filename)
end

struct ReportingDynamic <: Unsupervised end
MLJBase.fit(::ReportingDynamic, _, X) = nothing, 16, NamedTuple()
MLJBase.transform(::ReportingDynamic,_, X) = (X, (news=42,))
MLJBase.reporting_operations(::Type{<:ReportingDynamic}) = (:transform, )

@testset "corner case for operation applied to a reporting machinw" begin
model = ReportingDynamic()
mach = fit!(machine(model, [1,2,3]), verbosity=0)
@test transform(mach, rows=:) == [1, 2, 3]
@test transform(mach, rows=1:2) == [1, 2]
end

end # module

true

0 comments on commit 4cb559c

Please sign in to comment.