From ebb9ae6db26b3c9423ebdb1712b66a0104505db5 Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Fri, 27 Mar 2026 14:53:46 +0100 Subject: [PATCH] Narrow Zero arithmetic methods to reduce invalidations The methods +(x::Any, ::Zero), *(::Zero, ::Any), etc. used untyped arguments, causing method invalidations by superseding the fundamental +(x, y) and *(x, y) fallbacks in Base. Narrow these to Number, AbstractArray, and AbstractMutable since these cover the types that participate in MutableArithmetics rewrites. Downstream packages with custom types can define their own +(::MyType, ::Zero) methods (as MultivariatePolynomials already does). --- src/dispatch.jl | 10 ++++++++++ src/rewrite.jl | 25 ++++++++++++++++++------- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index f8b70c48..d74aa559 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -13,6 +13,16 @@ abstract type AbstractMutable end +# Zero arithmetic methods for AbstractMutable types. +# The main Zero arithmetic is defined in rewrite.jl with Number/AbstractArray; +# these methods extend it to AbstractMutable. +Base.:*(z::Zero, ::AbstractMutable) = z +Base.:*(::AbstractMutable, z::Zero) = z +Base.:+(::Zero, x::AbstractMutable) = copy_if_mutable(x) +Base.:+(x::AbstractMutable, ::Zero) = copy_if_mutable(x) +Base.:-(::Zero, x::AbstractMutable) = operate(-, x) +Base.:-(x::AbstractMutable, ::Zero) = copy_if_mutable(x) + function Base.sum( a::AbstractArray{T}; dims = :, diff --git a/src/rewrite.jl b/src/rewrite.jl index 485cd89d..f37ca580 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -58,20 +58,31 @@ broadcast!!(::Union{typeof(add_mul),typeof(+)}, ::Zero, x) = copy_if_mutable(x) broadcast!!(::typeof(add_mul), ::Zero, x, y) = x * y # Needed in `@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)` -Base.:*(z::Zero, ::Any) = z -Base.:*(::Any, z::Zero) = z +# These methods are narrowed to `Number` and `AbstractArray` to avoid invalidating +# the very broad `+(x, y)`, `*(x, y)` fallbacks in Base, which causes thousands of +# method invalidations across the ecosystem. Downstream packages that define custom +# types participating in MutableArithmetics rewrites should define their own +# `+(::MyType, ::Zero)` etc. methods. +Base.:*(z::Zero, ::Number) = z +Base.:*(::Number, z::Zero) = z +Base.:*(z::Zero, ::AbstractArray) = z +Base.:*(::AbstractArray, z::Zero) = z Base.:*(z::Zero, ::Zero) = z -Base.:+(::Zero, x::Any) = x -Base.:+(x::Any, ::Zero) = x +Base.:+(::Zero, x::Number) = x +Base.:+(x::Number, ::Zero) = x +Base.:+(::Zero, x::AbstractArray) = x +Base.:+(x::AbstractArray, ::Zero) = x Base.:+(z::Zero, ::Zero) = z -Base.:-(::Zero, x::Any) = -x -Base.:-(x::Any, ::Zero) = x +Base.:-(::Zero, x::Number) = -x +Base.:-(x::Number, ::Zero) = x +Base.:-(::Zero, x::AbstractArray) = -x +Base.:-(x::AbstractArray, ::Zero) = x Base.:-(z::Zero, ::Zero) = z Base.:-(z::Zero) = z Base.:+(z::Zero) = z Base.:*(z::Zero) = z -function Base.:/(z::Zero, x::Any) +function Base.:/(z::Zero, x::Number) if iszero(x) throw(DivideError()) else