Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
e9639ff
implement the multinomial GLM (squash commit)
May 10, 2026
fed8db5
undo irrelevant changes
May 10, 2026
fd52669
Merge commit '726f9914af03df2b743f4541a4d0dae243ef267b' into HEAD
yashikno May 10, 2026
02de278
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 10, 2026
8b3204e
fix broken rev test
May 11, 2026
d39065d
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 11, 2026
9020b83
remove the non-propto overload from the OpenCL file:
May 11, 2026
fbcefd0
fix failing CI test
May 11, 2026
c9409fa
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 11, 2026
40b3a4d
Merge branch 'develop' into jachym_multinomial2
May 22, 2026
cddc109
Merge branch 'stan-dev:develop' into jachym_multinomial2
jachymb May 26, 2026
48e1049
add mix test + minor code polishing
May 26, 2026
2b29d5a
Merge branch 'jachym_multinomial2' of https://github.com/jachymb/math…
May 26, 2026
6f46041
Fix failing mix test
May 26, 2026
d8a9a94
Merge commit '123176d90e9615fe9ed503518ee9358f914a74d0' into HEAD
yashikno May 26, 2026
c7eb138
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 26, 2026
1d361ae
For opencl multinomial lpmf
SteveBronder Jun 3, 2026
7de0313
Merge commit '58feca35edd163f37d49bd17d7218108cfc2bf9c' into HEAD
yashikno Jun 3, 2026
51b1088
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jun 3, 2026
7214731
fix docs
SteveBronder Jun 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions stan/math/opencl/kernels/multinomial_logit_glm_lpmf.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#ifndef STAN_MATH_OPENCL_KERNELS_MULTINOMIAL_LOGIT_GLM_LPMF_HPP
#define STAN_MATH_OPENCL_KERNELS_MULTINOMIAL_LOGIT_GLM_LPMF_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_cl.hpp>
#include <string>

namespace stan {
namespace math {
namespace opencl_kernels {

// \cond
static constexpr const char* multinomial_logit_glm_kernel_code = STRINGIFY(
// \endcond
/** \ingroup opencl_kernels
* GPU kernel for the Generalized Linear Model (GLM) with multinomial
* distribution and softmax (logit) link function.
*
* Must be run with at least N_instances threads and local size LOCAL_SIZE_.
* Each thread handles one instance n = gid.
*
* The kernel performs two passes over the K classes for instance n:
* 1. find max(eta[n,:]) for numerical stability,
* 2. accumulate sum_exp, S_n, and logp using shifted eta
* (eta[n,k] - max) to avoid catastrophic cancellation; skips
* skips y_nk=0 terms to implement the 0*log(0)=0 convention;
* if need_delta, stash exp(eta[n,k] - max) into delta_global.
* A final loop normalizes delta (if need_delta) and subtracts
* lgamma(y_nk+1) terms (if need_logp_gamma), reading only y_global
* and delta_global, without re-reading x_beta_global or alpha_global.
*
* @param[out] logp_global partial logp sums, one per work group
* @param[out] delta_global residual matrix N_instances x N_classes
* (col-major)
* @param[in] y_global outcome counts, N_instances x N_classes (col-major)
* @param[in] x_beta_global product x*beta, N_instances x N_classes
* (col-major)
* @param[in] alpha_global intercepts: K values if is_alpha_vector, else
* N_instances x N_classes (col-major)
* @param N_instances number of instances
* @param N_classes number of outcome classes
* @param is_alpha_vector 1 if alpha is shared 1xK row, 0 if NxK
* @param need_delta 1 if delta_global should be computed and written
* @param need_logp_gamma 1 if lgamma terms should be included in logp
*/
__kernel void multinomial_logit_glm(
__global double* logp_global, __global double* delta_global,
const __global int* y_global, const __global double* x_beta_global,
const __global double* alpha_global, const int N_instances,
const int N_classes, const int is_alpha_vector, const int need_delta,
const int need_logp_gamma) {
const int gid = get_global_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const int wg_id = get_group_id(0);

__local double local_storage[LOCAL_SIZE_];

double logp = 0;
if (gid < N_instances) {
// Pass 1: row-wise max of eta for numerical stability.
double eta_max = -DBL_MAX;
for (int k = 0; k < N_classes; k++) {
int nk = k * N_instances + gid;
int alpha_idx = is_alpha_vector * k + !is_alpha_vector * nk;
double eta_k = x_beta_global[nk] + alpha_global[alpha_idx];
eta_max = fmax(eta_k, eta_max);
}

// Pass 2: sum_exp, S_n, logp; if need_delta stash exp_k in
// delta_global.
double sum_exp = 0;
int S_n = 0;
for (int k = 0; k < N_classes; k++) {
int nk = k * N_instances + gid;
int alpha_idx = is_alpha_vector * k + !is_alpha_vector * nk;
double shifted_eta_k
= x_beta_global[nk] + alpha_global[alpha_idx] - eta_max;
double exp_k = exp(shifted_eta_k);
sum_exp += exp_k;
int y_nk = y_global[nk];
S_n += y_nk;
logp += (y_nk != 0) ? y_nk * shifted_eta_k : 0;
if (need_delta) {
delta_global[nk] = exp_k;
}
}
logp -= S_n * log(sum_exp);

if (need_logp_gamma) {
logp += lgamma(S_n + 1.0);
}

// Normalize delta and/or subtract lgamma(y_nk+1) in one pass.
if (need_delta || need_logp_gamma) {
for (int k = 0; k < N_classes; k++) {
int nk = k * N_instances + gid;
int y_nk = y_global[nk];
logp -= need_logp_gamma * lgamma(y_nk + 1.0);
if (need_delta) {
delta_global[nk] = y_nk - S_n * delta_global[nk] / sum_exp;
}
}
}
}

// Work-group reduction of logp.
local_storage[lid] = logp;
barrier(CLK_LOCAL_MEM_FENCE);
for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
step /= REDUCTION_STEP_SIZE) {
if (lid < step) {
for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
local_storage[lid] += local_storage[lid + step * i];
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
if (lid == 0) {
logp_global[wg_id] = local_storage[0];
}
}
// \cond
);
// \endcond

/** \ingroup opencl_kernels
* See the docs for \link kernels/multinomial_logit_glm_lpmf.hpp
* multinomial_logit_glm() \endlink
*/
const kernel_cl<out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, int,
int, int, int, int>
multinomial_logit_glm("multinomial_logit_glm",
{multinomial_logit_glm_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

} // namespace opencl_kernels
} // namespace math
} // namespace stan
#endif
#endif
1 change: 1 addition & 0 deletions stan/math/opencl/prim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
#include <stan/math/opencl/prim/neg_binomial_2_lpmf.hpp>
#include <stan/math/opencl/prim/neg_binomial_2_log_lpmf.hpp>
#include <stan/math/opencl/prim/neg_binomial_2_log_glm_lpmf.hpp>
#include <stan/math/opencl/prim/multinomial_logit_glm_lpmf.hpp>
#include <stan/math/opencl/prim/normal_id_glm_lpdf.hpp>
#include <stan/math/opencl/prim/normal_cdf.hpp>
#include <stan/math/opencl/prim/normal_lccdf.hpp>
Expand Down
177 changes: 177 additions & 0 deletions stan/math/opencl/prim/multinomial_logit_glm_lpmf.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
#ifndef STAN_MATH_OPENCL_PRIM_MULTINOMIAL_LOGIT_GLM_LPMF_HPP
#define STAN_MATH_OPENCL_PRIM_MULTINOMIAL_LOGIT_GLM_LPMF_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/prim/size.hpp>
#include <stan/math/opencl/rev/operands_and_partials.hpp>
#include <stan/math/opencl/matrix_cl.hpp>
#include <stan/math/opencl/copy.hpp>
#include <stan/math/opencl/prim/multiply.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/opencl/kernels/multinomial_logit_glm_lpmf.hpp>

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/eval.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/Eigen.hpp>

#include <vector>

namespace stan {
namespace math {

/** \ingroup opencl
* Returns the log PMF of the Generalized Linear Model (GLM)
* with multinomial distribution and softmax (logit) link function.
* This is an OpenCL overload of
* `prim/prob/multinomial_logit_glm_lpmf.hpp`.
* Alpha can be either a shared 1×K row vector or an N×K per-instance matrix.
*
* @tparam T_x type of the design matrix (N×M kernel expression)
* @tparam T_alpha type of the intercept (1×K or N×K kernel expression)
* @tparam T_beta type of the weight matrix (M×K kernel expression)
* @param y outcome count vectors: `y[n]` is a length-K vector of non-negative
* integer counts for instance n
* @param x design matrix (N×M) on OpenCL device
* @param alpha intercept: 1×K broadcast row or N×K per-instance matrix
* @param beta weight matrix (M×K) on OpenCL device
* @return log sum of multinomial log PMFs over all N instances
* @throw std::domain_error if any element of x or beta is infinite or NaN,
* or if alpha contains `+inf` or NaN (`-inf` forces the corresponding softmax
* probability to zero and is allowed)
* @throw std::domain_error if any count in y is negative
* @throw std::invalid_argument if container sizes mismatch
*/
template <bool propto = false, typename T_x, typename T_alpha, typename T_beta,
require_all_prim_or_rev_kernel_expression_t<T_x, T_alpha,
T_beta>* = nullptr>
inline return_type_t<T_x, T_alpha, T_beta> multinomial_logit_glm_lpmf(
const std::vector<std::vector<int>>& y, T_x&& x, T_alpha&& alpha,
T_beta&& beta) {
if (size_zero(y)) {
return 0;
}
return multinomial_logit_glm_lpmf<propto>(
matrix_cl<int>(as_array_or_scalar(y)), std::forward<T_x>(x),
std::forward<T_alpha>(alpha), std::forward<T_beta>(beta));
}

/** \ingroup opencl
* Returns the log PMF of the Generalized Linear Model (GLM)
* with multinomial distribution and softmax (logit) link function.
* This is an OpenCL overload of
* `prim/prob/multinomial_logit_glm_lpmf.hpp`.
* Alpha can be either a shared 1×K row vector or an N×K per-instance matrix.
* @tparam T_y expression of the outcome count matrix (N×K kernel expression)
* @tparam T_x expression of the design matrix (N×M kernel expression)
* @tparam T_alpha expression of the intercept (1×K or N×K kernel expression)
* @tparam T_beta expression of the weight matrix (M×K kernel expression)
* @param y outcome count matrix. Each row is of length-K non-negative
* integer counts for instance n
* @param x design matrix (N×M) on OpenCL device
* @param alpha intercept: 1×K broadcast row or N×K per-instance matrix
* @param beta weight matrix (M×K) on OpenCL device
* @return log sum of multinomial log PMFs over all N instances
* @throw std::domain_error if any element of x or beta is infinite or NaN,
* or if alpha contains `+inf` or NaN (`-inf` forces the corresponding softmax
* probability to zero and is allowed)
* @throw std::domain_error if any count in y is negative
* @throw std::invalid_argument if container sizes mismatch
*/
template <bool propto = false, typename T_y, typename T_x, typename T_alpha,
typename T_beta,
require_all_prim_or_rev_kernel_expression_t<T_y, T_x, T_alpha,
T_beta>* = nullptr>
inline return_type_t<T_x, T_alpha, T_beta> multinomial_logit_glm_lpmf(
T_y&& y, T_x&& x, T_alpha&& alpha, T_beta&& beta) {
using T_partials_return = partials_return_t<T_x, T_alpha, T_beta>;
static constexpr const char* function = "multinomial_logit_glm_lpmf";

const int N_instances = x.rows();
const int N_classes = beta.cols();
check_size_match(function, "Rows of", "y", y.rows(), "rows of", "x",
N_instances);
check_size_match(function, "Columns of", "y", y.cols(), "columns of", "beta",
N_classes);
check_size_match(function, "Columns of", "beta", N_classes, "columns of",
"alpha", alpha.cols());
check_size_match(function, "Columns of", "x", x.cols(), "rows of", "beta",
beta.rows());

const int alpha_rows = alpha.rows();
const bool is_alpha_vector = alpha_rows == 1;
if (!is_alpha_vector) {
check_size_match(function, "Rows of", "alpha", alpha_rows, "rows of", "x",
N_instances);
}

if (N_instances == 0) {
return 0;
}

if constexpr (!include_summand<propto, T_x, T_alpha, T_beta>::value) {
return 0;
}

decltype(auto) y_val = eval(value_of(y));
decltype(auto) x_val = eval(value_of(x));
decltype(auto) alpha_val = eval(value_of(alpha));
decltype(auto) beta_val = eval(value_of(beta));

matrix_cl<double> x_beta_cl = x_val * beta_val;

const int local_size
= opencl_kernels::multinomial_logit_glm.get_option("LOCAL_SIZE_");
const int wgs = (N_instances + local_size - 1) / local_size;

constexpr bool need_delta = is_any_autodiff_v<T_x, T_alpha, T_beta>;

matrix_cl<double> logp_cl(wgs, 1);
matrix_cl<double> delta_cl(0, 0);
if constexpr (need_delta) {
delta_cl = matrix_cl<double>(N_instances, N_classes);
}

try {
opencl_kernels::multinomial_logit_glm(
cl::NDRange(local_size * wgs), cl::NDRange(local_size), logp_cl,
delta_cl, y_val, x_beta_cl, alpha_val, N_instances, N_classes,
is_alpha_vector, need_delta, !propto);
} catch (const cl::Error& e) {
check_opencl_error(function, e);
}

T_partials_return logp = sum(from_matrix_cl(logp_cl));

if (!std::isfinite(logp)) {
check_cl(function, "outcome counts", y_val, "nonnegative") = y_val >= 0;
check_cl(function, "Design matrix", x_val, "finite") = isfinite(x_val);
check_cl(function, "Intercept", alpha_val, "finite") = isfinite(alpha_val);
check_cl(function, "Weight matrix", beta_val, "finite")
= isfinite(beta_val);
}

auto ops_partials = make_partials_propagator(std::forward<T_x>(x),
std::forward<T_alpha>(alpha),
std::forward<T_beta>(beta));
if constexpr (need_delta) {
if constexpr (is_autodiff_v<T_x>) {
partials<0>(ops_partials) = delta_cl * transpose(beta_val);
}
if constexpr (is_autodiff_v<T_alpha>) {
partials<1>(ops_partials)
= is_alpha_vector ? colwise_sum(delta_cl) : delta_cl;
}
if constexpr (is_autodiff_v<T_beta>) {
partials<2>(ops_partials) = transpose(x_val) * delta_cl;
}
}
return ops_partials.build(logp);
}

} // namespace math
} // namespace stan

#endif
#endif
1 change: 1 addition & 0 deletions stan/math/prim/prob.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
#include <stan/math/prim/prob/multi_student_t_cholesky_rng.hpp>
#include <stan/math/prim/prob/multi_student_t_lpdf.hpp>
#include <stan/math/prim/prob/multi_student_t_rng.hpp>
#include <stan/math/prim/prob/multinomial_logit_glm_lpmf.hpp>
#include <stan/math/prim/prob/multinomial_logit_lpmf.hpp>
#include <stan/math/prim/prob/multinomial_logit_rng.hpp>
#include <stan/math/prim/prob/multinomial_lpmf.hpp>
Expand Down
Loading
Loading