Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/pytorch.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
release/2.12
40e21dcd4b92d59842b3e3b7f542f855dedddb91
2 changes: 1 addition & 1 deletion runtime/core/portable_type/c10/c10/util/BFloat16-math.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ template <
typename T,
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
inline T rsqrt(T a) {
return 1.0 / std::sqrt(float(a));
return 1.0f / std::sqrt(float(a));
}
template <
typename T,
Expand Down
2 changes: 1 addition & 1 deletion runtime/core/portable_type/c10/c10/util/complex_math.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) {
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
#if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \
defined(__HIPCC__)
defined(__HIPCC__) || defined(__SYCL_DEVICE_ONLY__)
// For Mac, the new implementation yielded a high relative error. Falling back
// to the old version for now.
// See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
/*
* - Choose either results of conversion of input as a normalized number, or
* as a denormalized number, depending on the input exponent. The variable
* two_w contains input exponent in bits 27-31, therefore if its smaller than
* two_w contains input exponent in bits 27-31, therefore if it's smaller than
* 2**27, the input is either a denormal number, or zero.
* - Combine the result of conversion of exponent and mantissa with the sign
* of the input number.
Expand Down
55 changes: 55 additions & 0 deletions runtime/core/portable_type/c10/torch/headeronly/util/complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <complex>

#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>

#if defined(__CUDACC__) || defined(__HIPCC__)
Expand Down Expand Up @@ -588,6 +589,60 @@ struct alignas(4) complex<Half> {
}
};

template <>
struct alignas(4) complex<BFloat16> {
BFloat16 real_;
BFloat16 imag_;

// Constructors
complex() = default;
// BFloat16 constructor is not constexpr so the following constructor can't
// be constexpr
C10_HOST_DEVICE explicit inline complex(
const BFloat16& real,
const BFloat16& imag)
: real_(real), imag_(imag) {}
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
: real_(value.real()), imag_(value.imag()) {}

// Conversion operator
inline C10_HOST_DEVICE operator c10::complex<float>() const {
return {real_, imag_};
}

constexpr C10_HOST_DEVICE BFloat16 real() const {
return real_;
}
constexpr C10_HOST_DEVICE BFloat16 imag() const {
return imag_;
}

C10_HOST_DEVICE complex<BFloat16>& operator+=(
const complex<BFloat16>& other) {
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
return *this;
}

C10_HOST_DEVICE complex<BFloat16>& operator-=(
const complex<BFloat16>& other) {
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
return *this;
}

C10_HOST_DEVICE complex<BFloat16>& operator*=(
const complex<BFloat16>& other) {
auto a = static_cast<float>(real_);
auto b = static_cast<float>(imag_);
auto c = static_cast<float>(other.real());
auto d = static_cast<float>(other.imag());
real_ = a * c - b * d;
imag_ = a * d + b * c;
return *this;
}
};

} // namespace c10

HIDDEN_NAMESPACE_BEGIN(torch, headeronly)
Expand Down
2 changes: 1 addition & 1 deletion torch_pin.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
TORCH_VERSION = "2.12.0"
# NIGHTLY_VERSION = "dev20260318" Temporarily pinning to stable release candidate. Revert https://github.com/pytorch/executorch/pull/18287
NIGHTLY_VERSION = "dev20260614"
Loading