Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
7ccf12c
feat: implement mixed-precision eigensolver for CG and Davidson methods
laoba657 May 23, 2026
5768c4d
feat: implement mixed-precision eigensolver for CG and Davidson methods
laoba657 May 23, 2026
0261f61
fix: resolve merge conflicts with upstream develop
laoba657 May 25, 2026
a3f1eb1
docs: translate Chinese comments/docs to English per reviewer feedback
laoba657 May 25, 2026
82a5942
fix: address Copilot AI review comments
laoba657 May 25, 2026
1c4a3e0
fix: exclude benchmark test from CI to avoid compilation issues
laoba657 May 25, 2026
dae656d
fix: restore CMakeLists.txt to upstream structure, fix merge corruption
laoba657 May 25, 2026
54630f7
debug: remove new test targets to isolate CI failure cause
laoba657 May 25, 2026
ef9456d
fix: remove stale use_paw reference in diag_mixed_precision
laoba657 May 25, 2026
9c0fe7f
fix: guard mixed precision code with ENABLE_MIXED_PRECISION to avoid …
laoba657 May 25, 2026
cf09169
fix: update hsolver_pw_sup.h constructors and remove unlinkable mixed…
laoba657 May 25, 2026
a7e6961
fix: remove junk files and test report, fix benchmark includes per re…
laoba657 May 25, 2026
5ed3b3a
fix: translate remaining Chinese comments to English in precision_str…
laoba657 May 25, 2026
ba17087
refactor: simplify mixed-precision solver per reviewer feedback
laoba657 May 25, 2026
65f7451
feat: wire diago_precision_mode INPUT parameter to HSolverPW
laoba657 May 25, 2026
653596d
fix: restore #ifdef ENABLE_MIXED_PRECISION guards for CI compatibility
laoba657 May 25, 2026
e795975
fix: update DiagoDavid constructor in hsolver_pw_sup.h to match new s…
laoba657 May 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions source/source_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
hsolver::DiagoIterAssist<T, Device>::need_subspace,
PARAM.inp.use_k_continuity);

hsolver_pw_obj.set_diago_precision_mode(parse_precision_mode(PARAM.inp.diago_precision_mode));

hsolver_pw_obj.solve(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), *this->stp.template get_psi_t<T, Device>(), this->pelec, this->pelec->ekb.c,
GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat);
}
Expand Down
2 changes: 2 additions & 0 deletions source/source_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ void ESolver_SDFT_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, int istep, i
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
hsolver::DiagoIterAssist<T, Device>::need_subspace);

hsolver_pw_sdft_obj.set_diago_precision_mode(parse_precision_mode(PARAM.inp.diago_precision_mode));

hsolver_pw_sdft_obj.solve(ucell,
static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt),
*this->stp.template get_psi_t<T, Device>(),
Expand Down
176 changes: 176 additions & 0 deletions source/source_hsolver/diago_cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ DiagoCG<T, Device>::DiagoCG(const std::string& basis_type,
pw_diag_thr_ = pw_diag_thr;
pw_diag_nmax_ = pw_diag_nmax;
nproc_in_pool_ = nproc_in_pool;
precision_mode_ = PrecisionMode::kDouble;
this->one_ = new T(static_cast<T>(1.0));
this->zero_ = new T(static_cast<T>(0.0));
this->neg_one_ = new T(static_cast<T>(-1.0));
Expand Down Expand Up @@ -578,6 +579,166 @@ bool DiagoCG<T, Device>::test_exit_cond(const int& ntry, const int& notconv) con
return f1 && (f2 || f3);
}

/**
 * @brief Mixed-precision CG diagonalization: a reduced-precision solve followed by one
 *        diag_once() pass in the working precision T.
 *
 * When ENABLE_MIXED_PRECISION is defined:
 *   1. psi_in is wrapped as an (nband x ld_psi) tensor and its leading (nband x dim) slice
 *      is cast to MixedT (float if T is double, std::complex<float> otherwise).
 *   2. A DiagoCG<MixedT, Device> solver runs diag() on that reduced-precision copy. The
 *      H|psi>, S|psi> and subspace callbacks are adapted by casting MixedT -> T, invoking
 *      the T-precision callback, and casting the result back to MixedT.
 *   3. The reduced-precision result is cast back to T and diag_once() is run once on it.
 *   4. The refined vectors are written back into psi_in.
 *
 * When ENABLE_MIXED_PRECISION is not defined the function does nothing and returns 0.0.
 *
 * @param hpsi_func     callback applying H to a set of vectors (T precision)
 * @param spsi_func     callback applying S to a set of vectors (T precision)
 * @param ld_psi        leading dimension of psi_in
 * @param nband         number of bands (rows of psi_in)
 * @param dim           number of elements of each band that take part in the solve
 * @param psi_in        in/out wavefunction buffer of nband * ld_psi elements
 * @param eigenvalue_in out buffer for nband eigenvalues (wrapped as a CPU tensor)
 * @param ethr_band     per-band convergence thresholds, forwarded unchanged to both stages
 * @param prec          optional preconditioner of dim elements (wrapped as a CPU tensor);
 *                      may be nullptr
 * @return avg_iter_ after adding the average iteration count reported by the
 *         reduced-precision solver
 */
template <typename T, typename Device>
double DiagoCG<T, Device>::diag_mixed_precision(const HPsiFunc& hpsi_func,
                                                const SPsiFunc& spsi_func,
                                                const int ld_psi,
                                                const int nband,
                                                const int dim,
                                                T* psi_in,
                                                Real* eigenvalue_in,
                                                const std::vector<double>& ethr_band,
                                                const Real* prec)
{
#ifdef ENABLE_MIXED_PRECISION
    // Reduced-precision counterpart of T and of its real type.
    using MixedT = typename std::conditional<std::is_same<T, double>::value,
                                             float,
                                             std::complex<float>>::type;
    using MixedReal = typename GetTypeReal<MixedT>::type;

    // diag() dispatches here before running the rest of its body, and diag_once() receives
    // no operator arguments, so store the T-precision callbacks in the members that the
    // refinement pass below relies on.
    // NOTE(review): assumes diag_once() reads hpsi_func_/spsi_func_ — confirm.
    hpsi_func_ = hpsi_func;
    spsi_func_ = spsi_func;

    // View over the caller's wavefunctions and the (nband x dim) working slice.
    auto psi = ct::TensorMap(psi_in,
                             ct::DataTypeToEnum<T>::value,
                             ct::DeviceTypeToEnum<ct_Device>::value,
                             ct::TensorShape({nband, ld_psi}));
    auto psi_temp = psi.slice({0, 0}, {nband, dim});
    auto psi_mixed = psi_temp.cast<MixedT>();

    // Reduced-precision preconditioner on the solver device (only if one was supplied).
    ct::Tensor prec_mixed;
    if (prec != nullptr)
    {
        auto prec_map = ct::TensorMap(const_cast<Real*>(prec),
                                      ct::DataTypeToEnum<Real>::value,
                                      ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
                                      ct::TensorShape({dim}));
        prec_mixed = prec_map.template cast<MixedReal>().template to_device<ct_Device>();
    }

    // Eigenvalues of the reduced-precision stage; they are not copied to eigenvalue_in.
    std::vector<MixedReal> eigen_mixed(nband, static_cast<MixedReal>(0.0));

    // H|psi> adapter: MixedT in -> T -> hpsi_func -> T -> MixedT out.
    auto hpsi_func_mixed = [hpsi_func](MixedT* psi_in_mixed,
                                       MixedT* hpsi_out_mixed,
                                       const int ld_psi_mixed,
                                       const int nvec) {
        auto psi_in_map = ct::TensorMap(psi_in_mixed,
                                        ct::DataTypeToEnum<MixedT>::value,
                                        ct::DeviceTypeToEnum<ct_Device>::value,
                                        ct::TensorShape({nvec, ld_psi_mixed}));
        auto psi_in_double = psi_in_map.cast<T>();
        auto hpsi_double = ct::Tensor(ct::DataTypeToEnum<T>::value,
                                      ct::DeviceTypeToEnum<ct_Device>::value,
                                      ct::TensorShape({nvec, ld_psi_mixed}));
        hpsi_func(psi_in_double.template data<T>(), hpsi_double.template data<T>(), ld_psi_mixed, nvec);
        auto hpsi_mixed_out = hpsi_double.cast<MixedT>();
        ct::TensorMap hpsi_out_tensor(hpsi_out_mixed,
                                      ct::DataTypeToEnum<MixedT>::value,
                                      ct::DeviceTypeToEnum<ct_Device>::value,
                                      ct::TensorShape({nvec, ld_psi_mixed}));
        hpsi_out_tensor.CopyFrom(hpsi_mixed_out);
    };

    // S|psi> adapter: same round trip as the H adapter, around spsi_func.
    auto spsi_func_mixed = [spsi_func](MixedT* psi_in_mixed,
                                       MixedT* spsi_out_mixed,
                                       const int ld_psi_mixed,
                                       const int nvec) {
        auto psi_in_map = ct::TensorMap(psi_in_mixed,
                                        ct::DataTypeToEnum<MixedT>::value,
                                        ct::DeviceTypeToEnum<ct_Device>::value,
                                        ct::TensorShape({nvec, ld_psi_mixed}));
        auto psi_in_double = psi_in_map.cast<T>();
        auto spsi_double = ct::Tensor(ct::DataTypeToEnum<T>::value,
                                      ct::DeviceTypeToEnum<ct_Device>::value,
                                      ct::TensorShape({nvec, ld_psi_mixed}));
        spsi_func(psi_in_double.template data<T>(), spsi_double.template data<T>(), ld_psi_mixed, nvec);
        auto spsi_mixed_out = spsi_double.cast<MixedT>();
        ct::TensorMap spsi_out_tensor(spsi_out_mixed,
                                      ct::DataTypeToEnum<MixedT>::value,
                                      ct::DeviceTypeToEnum<ct_Device>::value,
                                      ct::TensorShape({nvec, ld_psi_mixed}));
        spsi_out_tensor.CopyFrom(spsi_mixed_out);
    };

    // Subspace adapter: captures a copy of this solver's subspace_func_; a no-op when that
    // callback is empty.
    auto double_subspace = subspace_func_;
    auto subspace_func_mixed = [double_subspace](MixedT* psi_in_mixed,
                                                 MixedT* psi_out_mixed,
                                                 const int ld_psi_mixed,
                                                 const int nband_mixed,
                                                 const bool S_orth) {
        if (!double_subspace)
        {
            return;
        }
        auto psi_in_map = ct::TensorMap(psi_in_mixed,
                                        ct::DataTypeToEnum<MixedT>::value,
                                        ct::DeviceTypeToEnum<ct_Device>::value,
                                        ct::TensorShape({nband_mixed, ld_psi_mixed}));
        auto psi_in_double = psi_in_map.cast<T>();
        auto psi_out_double = ct::Tensor(ct::DataTypeToEnum<T>::value,
                                         ct::DeviceTypeToEnum<ct_Device>::value,
                                         ct::TensorShape({nband_mixed, ld_psi_mixed}));
        double_subspace(psi_in_double.template data<T>(), psi_out_double.template data<T>(), ld_psi_mixed, nband_mixed, S_orth);
        auto psi_out_mixed_tensor = psi_out_double.cast<MixedT>();
        ct::TensorMap psi_out_tensor(psi_out_mixed,
                                     ct::DataTypeToEnum<MixedT>::value,
                                     ct::DeviceTypeToEnum<ct_Device>::value,
                                     ct::TensorShape({nband_mixed, ld_psi_mixed}));
        psi_out_tensor.CopyFrom(psi_out_mixed_tensor);
    };

    // Stage 1: reduced-precision solver configured like this one. It is put in kFloat mode
    // so that its diag() does not take the kMixed branch again.
    hsolver::DiagoCG<MixedT, Device> mixed_solver(
        basis_type_,
        calculation_,
        need_subspace_,
        subspace_func_mixed,
        pw_diag_thr_,
        pw_diag_nmax_,
        nproc_in_pool_);
    mixed_solver.set_precision_mode(hsolver::PrecisionMode::kFloat);

    double float_avg_iter = mixed_solver.diag(hpsi_func_mixed,
                                              spsi_func_mixed,
                                              ld_psi,
                                              nband,
                                              dim,
                                              psi_mixed.template data<MixedT>(),
                                              eigen_mixed.data(),
                                              ethr_band,
                                              prec != nullptr ? prec_mixed.template data<MixedReal>() : nullptr);

    // Stage 2: promote the reduced-precision vectors back to T as the refinement start.
    auto psi_refined = psi_mixed.template cast<T>();
    psi_temp.CopyFrom(psi_refined);

    // Wrap the caller's eigenvalue buffer directly (like `psi` above) instead of
    // copy-initializing a separate ct::Tensor from the map, so that the eigenvalues
    // produced by diag_once() land in eigenvalue_in.
    // NOTE(review): assumes ct::TensorMap binds to the `ct::Tensor&` parameter of
    // diag_once() — confirm.
    auto eigen = ct::TensorMap(eigenvalue_in,
                               ct::DataTypeToEnum<Real>::value,
                               ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
                               ct::TensorShape({nband}));

    // T-precision preconditioner on the solver device; left empty when prec is nullptr.
    ct::Tensor prec_tensor;
    if (prec != nullptr)
    {
        prec_tensor = ct::TensorMap(const_cast<Real*>(prec),
                                    ct::DataTypeToEnum<Real>::value,
                                    ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
                                    ct::TensorShape({dim}))
                          .template to_device<ct_Device>();
    }

    avg_iter_ += float_avg_iter;
    this->diag_once(prec_tensor, psi_temp, eigen, ethr_band);

    if (this->notconv_ > std::max(5, this->n_band_ / 4))
    {
        std::cout << "\n notconv = " << this->notconv_;
        std::cout << "\n DiagoCG::diag_mixed_precision', too many bands are not converged! \n";
    }

    // Write the refined (nband x dim) slice back into the caller's (nband x ld_psi) buffer.
    psi.zero();
    psi.sync(psi_temp);
    return avg_iter_;
#else
    // Mixed precision is compiled out; diag() only calls this function under the same guard.
    return 0.0;
#endif
}

template <typename T, typename Device>
double DiagoCG<T, Device>::diag(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
Expand All @@ -593,6 +754,21 @@ double DiagoCG<T, Device>::diag(const HPsiFunc& hpsi_func,
REQUIRES_OK(static_cast<int>(ethr_band.size()) >= nband,
"DiagoCG::diag: ethr_band size must be >= nband");

if (precision_mode_ == PrecisionMode::kMixed)
{
#ifdef ENABLE_MIXED_PRECISION
return diag_mixed_precision(hpsi_func,
spsi_func,
ld_psi,
nband,
dim,
psi_in,
eigenvalue_in,
ethr_band,
prec);
#endif
}

auto psi = ct::TensorMap(psi_in,
ct::DataTypeToEnum<T>::value,
ct::DeviceTypeToEnum<ct_Device>::value,
Expand Down
50 changes: 50 additions & 0 deletions source/source_hsolver/diago_cg.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,42 @@
#include <ATen/core/tensor.h>
#include <ATen/core/tensor_types.h>

#include <string>

namespace hsolver {

/**
 * @brief Numerical precision strategy selectable for the diagonalization solvers.
 */
enum class PrecisionMode
{
    kDouble = 0, ///< Double precision throughout (the default)
    kFloat = 1,  ///< Single precision throughout
    kMixed = 2   ///< Single-precision iteration followed by double-precision refinement
};

} // namespace hsolver

/**
 * @brief Translate a textual precision setting into a PrecisionMode.
 *
 * "float" and "single" select kFloat; "mixed" and "auto" select kMixed. The comparison is
 * exact (case-sensitive), and every other string, including the empty one, yields kDouble.
 *
 * @param mode_str textual precision setting
 * @return the matching PrecisionMode, or kDouble when nothing matches
 */
inline hsolver::PrecisionMode parse_precision_mode(const std::string& mode_str)
{
    using hsolver::PrecisionMode;

    const bool wants_single = (mode_str == "float") || (mode_str == "single");
    if (wants_single)
    {
        return PrecisionMode::kFloat;
    }

    const bool wants_mixed = (mode_str == "mixed") || (mode_str == "auto");
    return wants_mixed ? PrecisionMode::kMixed : PrecisionMode::kDouble;
}

/**
 * @brief Canonical lowercase name of a PrecisionMode.
 *
 * @param mode precision mode to name
 * @return "float" for kFloat, "mixed" for kMixed, and "double" for kDouble or any other value
 */
inline std::string precision_mode_to_string(hsolver::PrecisionMode mode)
{
    if (mode == hsolver::PrecisionMode::kFloat)
    {
        return "float";
    }
    if (mode == hsolver::PrecisionMode::kMixed)
    {
        return "mixed";
    }
    // kDouble and any out-of-range value share the same name.
    return "double";
}

namespace hsolver {

template <typename T, typename Device = base_device::DEVICE_CPU>
Expand All @@ -25,6 +61,7 @@ class DiagoCG final
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
using SPsiFunc = std::function<void(T*, T*, const int, const int)>;
using SubspaceFunc = std::function<void(T*, T*, const int, const int, const bool)>;

// Constructor need:
// 1. temporary mock of Hamiltonian "Hamilt_PW"
// 2. precondition pointer should point to place of precondition array.
Expand All @@ -38,6 +75,8 @@ class DiagoCG final
const int& pw_diag_nmax,
const int& nproc_in_pool);

/// Store the precision mode consulted by diag(); with kMixed (and ENABLE_MIXED_PRECISION
/// defined) diag() forwards to diag_mixed_precision().
void set_precision_mode(const PrecisionMode mode) { precision_mode_ = mode; }

~DiagoCG();

// virtual void init(){};
Expand Down Expand Up @@ -80,6 +119,7 @@ class DiagoCG final
std::string calculation_ = {};

bool need_subspace_ = false;
PrecisionMode precision_mode_ = PrecisionMode::kDouble;
/// A function object that performs the hPsi calculation.
HPsiFunc hpsi_func_ = nullptr;
/// A function object that performs the sPsi calculation.
Expand Down Expand Up @@ -133,6 +173,16 @@ class DiagoCG final
ct::Tensor& eigen,
const std::vector<double>& ethr_band);

/// Mixed-precision variant of diag(), called by diag() when precision_mode_ is kMixed and
/// ENABLE_MIXED_PRECISION is defined. It runs a reduced-precision DiagoCG solve on a cast
/// copy of psi_in, then one diag_once() pass in precision T, and writes the result back to
/// psi_in. Without ENABLE_MIXED_PRECISION the definition returns 0.0 and does nothing.
/// Takes the same arguments as diag(); prec may be nullptr. Returns avg_iter_.
double diag_mixed_precision(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
const int ld_psi,
const int nband,
const int dim,
T* psi_in,
Real* eigenvalue_in,
const std::vector<double>& ethr_band,
const Real* prec);

bool test_exit_cond(const int& ntry, const int& notconv) const;

using dot_real_op = ModuleBase::dot_real_op<T, Device>;
Expand Down
Loading
Loading