aboutsummaryrefslogtreecommitdiff
path: root/libc
diff options
context:
space:
mode:
authorTue Ly <lntue@google.com>2022-01-25 15:11:15 -0500
committerTue Ly <lntue@google.com>2022-01-28 13:39:03 -0500
commitad4ee2d778a8956e578632aeba4e85bc4c8da508 (patch)
treea20f280f86007315d0ea559cfa0970b8b42dc730 /libc
parent[mlir][taco] Accept an integer list for the ordering when defining a tensor f... (diff)
downloadllvm-project-ad4ee2d778a8956e578632aeba4e85bc4c8da508.tar.gz
llvm-project-ad4ee2d778a8956e578632aeba4e85bc4c8da508.tar.bz2
llvm-project-ad4ee2d778a8956e578632aeba4e85bc4c8da508.zip
[libc] Refactor sqrt implementations and add tests for generic sqrt implementations.
Re-apply https://reviews.llvm.org/D118173 with fix for aarch64. Reviewed By: michaelrj Differential Revision: https://reviews.llvm.org/D118433
Diffstat (limited to 'libc')
-rw-r--r--libc/src/__support/FPUtil/CMakeLists.txt11
-rw-r--r--libc/src/__support/FPUtil/Sqrt.h192
-rw-r--r--libc/src/__support/FPUtil/aarch64/sqrt.h38
-rw-r--r--libc/src/__support/FPUtil/generic/CMakeLists.txt6
-rw-r--r--libc/src/__support/FPUtil/generic/sqrt.h214
-rw-r--r--libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h (renamed from libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h)48
-rw-r--r--libc/src/__support/FPUtil/sqrt.h22
-rw-r--r--libc/src/__support/FPUtil/x86_64/sqrt.h44
-rw-r--r--libc/src/math/aarch64/CMakeLists.txt20
-rw-r--r--libc/src/math/generic/CMakeLists.txt12
-rw-r--r--libc/src/math/generic/sqrt.cpp2
-rw-r--r--libc/src/math/generic/sqrtf.cpp2
-rw-r--r--libc/src/math/generic/sqrtl.cpp2
-rw-r--r--libc/src/math/x86_64/CMakeLists.txt30
-rw-r--r--libc/src/math/x86_64/sqrt.cpp20
-rw-r--r--libc/src/math/x86_64/sqrtf.cpp20
-rw-r--r--libc/src/math/x86_64/sqrtl.cpp20
-rw-r--r--libc/test/src/math/CMakeLists.txt77
-rw-r--r--libc/test/src/math/generic_sqrt_test.cpp13
-rw-r--r--libc/test/src/math/generic_sqrtf_test.cpp13
-rw-r--r--libc/test/src/math/generic_sqrtl_test.cpp13
21 files changed, 469 insertions, 350 deletions
diff --git a/libc/src/__support/FPUtil/CMakeLists.txt b/libc/src/__support/FPUtil/CMakeLists.txt
index 6d005a9166c2..d02cd9fcce0e 100644
--- a/libc/src/__support/FPUtil/CMakeLists.txt
+++ b/libc/src/__support/FPUtil/CMakeLists.txt
@@ -22,3 +22,14 @@ add_header_library(
libc.src.__support.common
libc.src.__support.CPP.standalone_cpp
)
+
+add_header_library(
+ sqrt
+ HDRS
+ sqrt.h
+ DEPENDS
+ .fputil
+ libc.src.__support.FPUtil.generic.sqrt
+)
+
+add_subdirectory(generic)
diff --git a/libc/src/__support/FPUtil/Sqrt.h b/libc/src/__support/FPUtil/Sqrt.h
deleted file mode 100644
index 652883ffc96b..000000000000
--- a/libc/src/__support/FPUtil/Sqrt.h
+++ /dev/null
@@ -1,192 +0,0 @@
-//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
-#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
-
-#include "FPBits.h"
-#include "PlatformDefs.h"
-
-#include "src/__support/CPP/TypeTraits.h"
-
-namespace __llvm_libc {
-namespace fputil {
-
-namespace internal {
-
-template <typename T>
-static inline void normalize(int &exponent,
- typename FPBits<T>::UIntType &mantissa);
-
-template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
- // Use binary search to shift the leading 1 bit.
- // With MantissaWidth<float> = 23, it will take
- // ceil(log2(23)) = 5 steps checking the mantissa bits as followed:
- // Step 1: 0000 0000 0000 XXXX XXXX XXXX
- // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
- // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
- // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
- // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
- constexpr int NSTEPS = 5; // = ceil(log2(MantissaWidth))
- constexpr uint32_t BOUNDS[NSTEPS] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
- 1 << 23};
- constexpr int SHIFTS[NSTEPS] = {12, 6, 3, 2, 1};
-
- for (int i = 0; i < NSTEPS; ++i) {
- if (mantissa < BOUNDS[i]) {
- exponent -= SHIFTS[i];
- mantissa <<= SHIFTS[i];
- }
- }
-}
-
-template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
- // Use binary search to shift the leading 1 bit similar to float.
- // With MantissaWidth<double> = 52, it will take
- // ceil(log2(52)) = 6 steps checking the mantissa bits.
- constexpr int NSTEPS = 6; // = ceil(log2(MantissaWidth))
- constexpr uint64_t BOUNDS[NSTEPS] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
- 1ULL << 49, 1ULL << 51, 1ULL << 52};
- constexpr int SHIFTS[NSTEPS] = {27, 14, 7, 4, 2, 1};
-
- for (int i = 0; i < NSTEPS; ++i) {
- if (mantissa < BOUNDS[i]) {
- exponent -= SHIFTS[i];
- mantissa <<= SHIFTS[i];
- }
- }
-}
-
-#ifdef LONG_DOUBLE_IS_DOUBLE
-template <>
-inline void normalize<long double>(int &exponent, uint64_t &mantissa) {
- normalize<double>(exponent, mantissa);
-}
-#elif !defined(SPECIAL_X86_LONG_DOUBLE)
-template <>
-inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
- // Use binary search to shift the leading 1 bit similar to float.
- // With MantissaWidth<long double> = 112, it will take
- // ceil(log2(112)) = 7 steps checking the mantissa bits.
- constexpr int NSTEPS = 7; // = ceil(log2(MantissaWidth))
- constexpr __uint128_t BOUNDS[NSTEPS] = {
- __uint128_t(1) << 56, __uint128_t(1) << 84, __uint128_t(1) << 98,
- __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
- __uint128_t(1) << 112};
- constexpr int SHIFTS[NSTEPS] = {57, 29, 15, 8, 4, 2, 1};
-
- for (int i = 0; i < NSTEPS; ++i) {
- if (mantissa < BOUNDS[i]) {
- exponent -= SHIFTS[i];
- mantissa <<= SHIFTS[i];
- }
- }
-}
-#endif
-
-} // namespace internal
-
-// Correctly rounded IEEE 754 SQRT with round to nearest, ties to even.
-// Shift-and-add algorithm.
-template <typename T,
- cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
-static inline T sqrt(T x) {
- using UIntType = typename FPBits<T>::UIntType;
- constexpr UIntType ONE = UIntType(1) << MantissaWidth<T>::VALUE;
-
- FPBits<T> bits(x);
-
- if (bits.is_inf_or_nan()) {
- if (bits.get_sign() && (bits.get_mantissa() == 0)) {
- // sqrt(-Inf) = NaN
- return FPBits<T>::build_nan(ONE >> 1);
- } else {
- // sqrt(NaN) = NaN
- // sqrt(+Inf) = +Inf
- return x;
- }
- } else if (bits.is_zero()) {
- // sqrt(+0) = +0
- // sqrt(-0) = -0
- return x;
- } else if (bits.get_sign()) {
- // sqrt( negative numbers ) = NaN
- return FPBits<T>::build_nan(ONE >> 1);
- } else {
- int x_exp = bits.get_exponent();
- UIntType x_mant = bits.get_mantissa();
-
- // Step 1a: Normalize denormal input and append hidden bit to the mantissa
- if (bits.get_unbiased_exponent() == 0) {
- ++x_exp; // let x_exp be the correct exponent of ONE bit.
- internal::normalize<T>(x_exp, x_mant);
- } else {
- x_mant |= ONE;
- }
-
- // Step 1b: Make sure the exponent is even.
- if (x_exp & 1) {
- --x_exp;
- x_mant <<= 1;
- }
-
- // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
- // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
- // Notice that the output of sqrt is always in the normal range.
- // To perform shift-and-add algorithm to find y, let denote:
- // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
- // r(n) = 2^n ( x_mant - y(n)^2 ).
- // That leads to the following recurrence formula:
- // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
- // with the initial conditions: y(0) = 1, and r(0) = x - 1.
- // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
- // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
- // 0 otherwise.
- UIntType y = ONE;
- UIntType r = x_mant - ONE;
-
- for (UIntType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
- r <<= 1;
- UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
- if (r >= tmp) {
- r -= tmp;
- y += current_bit;
- }
- }
-
- // We compute one more iteration in order to round correctly.
- bool lsb = y & 1; // Least significant bit
- bool rb = false; // Round bit
- r <<= 2;
- UIntType tmp = (y << 2) + 1;
- if (r >= tmp) {
- r -= tmp;
- rb = true;
- }
-
- // Remove hidden bit and append the exponent field.
- x_exp = ((x_exp >> 1) + FPBits<T>::EXPONENT_BIAS);
-
- y = (y - ONE) | (static_cast<UIntType>(x_exp) << MantissaWidth<T>::VALUE);
- // Round to nearest, ties to even
- if (rb && (lsb || (r != 0))) {
- ++y;
- }
-
- return *reinterpret_cast<T *>(&y);
- }
-}
-
-} // namespace fputil
-} // namespace __llvm_libc
-
-#ifdef SPECIAL_X86_LONG_DOUBLE
-#include "x86_64/SqrtLongDouble.h"
-#endif // SPECIAL_X86_LONG_DOUBLE
-
-#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
diff --git a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h
new file mode 100644
index 000000000000..479ebf76b678
--- /dev/null
+++ b/libc/src/__support/FPUtil/aarch64/sqrt.h
@@ -0,0 +1,38 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_AARCH64_SQRT_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_AARCH64_SQRT_H
+
+#include "src/__support/architectures.h"
+
+#if !defined(LLVM_LIBC_ARCH_AARCH64)
+#error "Invalid include"
+#endif
+
+#include "src/__support/FPUtil/generic/sqrt.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+template <> inline float sqrt<float>(float x) {
+ float y;
+ __asm__ __volatile__("fsqrt %s0, %s1\n\t" : "=w"(y) : "w"(x));
+ return y;
+}
+
+template <> inline double sqrt<double>(double x) {
+ double y;
+ __asm__ __volatile__("fsqrt %d0, %d1\n\t" : "=w"(y) : "w"(x));
+ return y;
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_AARCH64_SQRT_H
diff --git a/libc/src/__support/FPUtil/generic/CMakeLists.txt b/libc/src/__support/FPUtil/generic/CMakeLists.txt
new file mode 100644
index 000000000000..bf69e7dd961c
--- /dev/null
+++ b/libc/src/__support/FPUtil/generic/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_header_library(
+ sqrt
+ HDRS
+ sqrt.h
+ sqrt_80_bit_long_double.h
+)
diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
new file mode 100644
index 000000000000..1882d3e82ecf
--- /dev/null
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -0,0 +1,214 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
+
+#include "sqrt_80_bit_long_double.h"
+#include "src/__support/CPP/TypeTraits.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
+#include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/FPUtil/PlatformDefs.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+namespace internal {
+
+template <typename T> struct SpecialLongDouble {
+ static constexpr bool VALUE = false;
+};
+
+#if defined(SPECIAL_X86_LONG_DOUBLE)
+template <> struct SpecialLongDouble<long double> {
+ static constexpr bool VALUE = true;
+};
+#endif // SPECIAL_X86_LONG_DOUBLE
+
+template <typename T>
+static inline void normalize(int &exponent,
+ typename FPBits<T>::UIntType &mantissa);
+
+template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
+ // Use binary search to shift the leading 1 bit.
+ // With MantissaWidth<float> = 23, it will take
+ // ceil(log2(23)) = 5 steps checking the mantissa bits as followed:
+ // Step 1: 0000 0000 0000 XXXX XXXX XXXX
+ // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
+ // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
+ // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
+ // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
+ constexpr int NSTEPS = 5; // = ceil(log2(MantissaWidth))
+ constexpr uint32_t BOUNDS[NSTEPS] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
+ 1 << 23};
+ constexpr int SHIFTS[NSTEPS] = {12, 6, 3, 2, 1};
+
+ for (int i = 0; i < NSTEPS; ++i) {
+ if (mantissa < BOUNDS[i]) {
+ exponent -= SHIFTS[i];
+ mantissa <<= SHIFTS[i];
+ }
+ }
+}
+
+template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
+ // Use binary search to shift the leading 1 bit similar to float.
+ // With MantissaWidth<double> = 52, it will take
+ // ceil(log2(52)) = 6 steps checking the mantissa bits.
+ constexpr int NSTEPS = 6; // = ceil(log2(MantissaWidth))
+ constexpr uint64_t BOUNDS[NSTEPS] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
+ 1ULL << 49, 1ULL << 51, 1ULL << 52};
+ constexpr int SHIFTS[NSTEPS] = {27, 14, 7, 4, 2, 1};
+
+ for (int i = 0; i < NSTEPS; ++i) {
+ if (mantissa < BOUNDS[i]) {
+ exponent -= SHIFTS[i];
+ mantissa <<= SHIFTS[i];
+ }
+ }
+}
+
+#ifdef LONG_DOUBLE_IS_DOUBLE
+template <>
+inline void normalize<long double>(int &exponent, uint64_t &mantissa) {
+ normalize<double>(exponent, mantissa);
+}
+#elif !defined(SPECIAL_X86_LONG_DOUBLE)
+template <>
+inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
+ // Use binary search to shift the leading 1 bit similar to float.
+ // With MantissaWidth<long double> = 112, it will take
+ // ceil(log2(112)) = 7 steps checking the mantissa bits.
+ constexpr int NSTEPS = 7; // = ceil(log2(MantissaWidth))
+ constexpr __uint128_t BOUNDS[NSTEPS] = {
+ __uint128_t(1) << 56, __uint128_t(1) << 84, __uint128_t(1) << 98,
+ __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
+ __uint128_t(1) << 112};
+ constexpr int SHIFTS[NSTEPS] = {57, 29, 15, 8, 4, 2, 1};
+
+ for (int i = 0; i < NSTEPS; ++i) {
+ if (mantissa < BOUNDS[i]) {
+ exponent -= SHIFTS[i];
+ mantissa <<= SHIFTS[i];
+ }
+ }
+}
+#endif
+
+} // namespace internal
+
+// Correctly rounded IEEE 754 SQRT for all rounding modes.
+// Shift-and-add algorithm.
+template <typename T>
+static inline cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, T>
+sqrt(T x) {
+
+ if constexpr (internal::SpecialLongDouble<T>::VALUE) {
+ // Special 80-bit long double.
+ return x86::sqrt(x);
+ } else {
+ // IEEE floating points formats.
+ using UIntType = typename FPBits<T>::UIntType;
+ constexpr UIntType ONE = UIntType(1) << MantissaWidth<T>::VALUE;
+
+ FPBits<T> bits(x);
+
+ if (bits.is_inf_or_nan()) {
+ if (bits.get_sign() && (bits.get_mantissa() == 0)) {
+ // sqrt(-Inf) = NaN
+ return FPBits<T>::build_nan(ONE >> 1);
+ } else {
+ // sqrt(NaN) = NaN
+ // sqrt(+Inf) = +Inf
+ return x;
+ }
+ } else if (bits.is_zero()) {
+ // sqrt(+0) = +0
+ // sqrt(-0) = -0
+ return x;
+ } else if (bits.get_sign()) {
+ // sqrt( negative numbers ) = NaN
+ return FPBits<T>::build_nan(ONE >> 1);
+ } else {
+ int x_exp = bits.get_exponent();
+ UIntType x_mant = bits.get_mantissa();
+
+ // Step 1a: Normalize denormal input and append hidden bit to the mantissa
+ if (bits.get_unbiased_exponent() == 0) {
+ ++x_exp; // let x_exp be the correct exponent of ONE bit.
+ internal::normalize<T>(x_exp, x_mant);
+ } else {
+ x_mant |= ONE;
+ }
+
+ // Step 1b: Make sure the exponent is even.
+ if (x_exp & 1) {
+ --x_exp;
+ x_mant <<= 1;
+ }
+
+ // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
+ // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
+ // Notice that the output of sqrt is always in the normal range.
+ // To perform shift-and-add algorithm to find y, let denote:
+ // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
+ // r(n) = 2^n ( x_mant - y(n)^2 ).
+ // That leads to the following recurrence formula:
+ // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
+ // with the initial conditions: y(0) = 1, and r(0) = x - 1.
+ // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
+ // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
+ // 0 otherwise.
+ UIntType y = ONE;
+ UIntType r = x_mant - ONE;
+
+ for (UIntType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
+ r <<= 1;
+ UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
+ if (r >= tmp) {
+ r -= tmp;
+ y += current_bit;
+ }
+ }
+
+ // We compute one more iteration in order to round correctly.
+ bool lsb = y & 1; // Least significant bit
+ bool rb = false; // Round bit
+ r <<= 2;
+ UIntType tmp = (y << 2) + 1;
+ if (r >= tmp) {
+ r -= tmp;
+ rb = true;
+ }
+
+ // Remove hidden bit and append the exponent field.
+ x_exp = ((x_exp >> 1) + FPBits<T>::EXPONENT_BIAS);
+
+ y = (y - ONE) | (static_cast<UIntType>(x_exp) << MantissaWidth<T>::VALUE);
+
+ switch (get_round()) {
+ case FE_TONEAREST:
+ // Round to nearest, ties to even
+ if (rb && (lsb || (r != 0)))
+ ++y;
+ break;
+ case FE_UPWARD:
+ if (rb || (r != 0))
+ ++y;
+ break;
+ }
+
+ return *reinterpret_cast<T *>(&y);
+ }
+ }
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
diff --git a/libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h b/libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h
index 22d2ba2592c8..82a996b21378 100644
--- a/libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h
+++ b/libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h
@@ -6,26 +6,18 @@
//
//===----------------------------------------------------------------------===//
-#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_LONG_DOUBLE_H
-#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_LONG_DOUBLE_H
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H
-#include "src/__support/architectures.h"
-
-#if !defined(LLVM_LIBC_ARCH_X86)
-#error "Invalid include"
-#endif
-
-#include "src/__support/CPP/TypeTraits.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
#include "src/__support/FPUtil/FPBits.h"
-#include "src/__support/FPUtil/Sqrt.h"
+#include "src/__support/FPUtil/PlatformDefs.h"
namespace __llvm_libc {
namespace fputil {
+namespace x86 {
-namespace internal {
-
-template <>
-inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
+inline void normalize(int &exponent, __uint128_t &mantissa) {
// Use binary search to shift the leading 1 bit similar to float.
// With MantissaWidth<long double> = 63, it will take
// ceil(log2(63)) = 6 steps checking the mantissa bits.
@@ -43,11 +35,14 @@ inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
}
}
-} // namespace internal
+// if constexpr statement in sqrt.h still requires x86::sqrt to be declared
+// even when it's not used.
+static inline long double sqrt(long double x);
-// Correctly rounded SQRT with round to nearest, ties to even.
+// Correctly rounded SQRT for all rounding modes.
// Shift-and-add algorithm.
-template <> inline long double sqrt<long double, 0>(long double x) {
+#if defined(SPECIAL_X86_LONG_DOUBLE)
+static inline long double sqrt(long double x) {
using UIntType = typename FPBits<long double>::UIntType;
constexpr UIntType ONE = UIntType(1)
<< int(MantissaWidth<long double>::VALUE);
@@ -78,7 +73,7 @@ template <> inline long double sqrt<long double, 0>(long double x) {
if (bits.get_implicit_bit()) {
x_mant |= ONE;
} else if (bits.get_unbiased_exponent() == 0) {
- internal::normalize<long double>(x_exp, x_mant);
+ normalize(x_exp, x_mant);
}
// Step 1b: Make sure the exponent is even.
@@ -126,9 +121,16 @@ template <> inline long double sqrt<long double, 0>(long double x) {
y |= (static_cast<UIntType>(x_exp)
<< (MantissaWidth<long double>::VALUE + 1));
- // Round to nearest, ties to even
- if (rb && (lsb || (r != 0))) {
- ++y;
+ switch (get_round()) {
+ case FE_TONEAREST:
+ // Round to nearest, ties to even
+ if (rb && (lsb || (r != 0)))
+ ++y;
+ break;
+ case FE_UPWARD:
+ if (rb || (r != 0))
+ ++y;
+ break;
}
// Extract output
@@ -140,8 +142,10 @@ template <> inline long double sqrt<long double, 0>(long double x) {
return out;
}
}
+#endif // SPECIAL_X86_LONG_DOUBLE
+} // namespace x86
} // namespace fputil
} // namespace __llvm_libc
-#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_LONG_DOUBLE_H
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H
diff --git a/libc/src/__support/FPUtil/sqrt.h b/libc/src/__support/FPUtil/sqrt.h
new file mode 100644
index 000000000000..6e02d9c77e8e
--- /dev/null
+++ b/libc/src/__support/FPUtil/sqrt.h
@@ -0,0 +1,22 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
+
+#include "src/__support/architectures.h"
+
+#if defined(LLVM_LIBC_ARCH_X86_64)
+#include "x86_64/sqrt.h"
+#elif defined(LLVM_LIBC_ARCH_AARCH64)
+#include "aarch64/sqrt.h"
+#else
+#include "generic/sqrt.h"
+
+#endif
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
diff --git a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h
new file mode 100644
index 000000000000..8a8f8cf2238d
--- /dev/null
+++ b/libc/src/__support/FPUtil/x86_64/sqrt.h
@@ -0,0 +1,44 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_H
+
+#include "src/__support/architectures.h"
+
+#if !defined(LLVM_LIBC_ARCH_X86)
+#error "Invalid include"
+#endif
+
+#include "src/__support/FPUtil/generic/sqrt.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+template <> inline float sqrt<float>(float x) {
+ float result;
+ __asm__ __volatile__("sqrtss %x1, %x0" : "=x"(result) : "x"(x));
+ return result;
+}
+
+template <> inline double sqrt<double>(double x) {
+ double result;
+ __asm__ __volatile__("sqrtsd %x1, %x0" : "=x"(result) : "x"(x));
+ return result;
+}
+
+template <> inline long double sqrt<long double>(long double x) {
+ long double result;
+ __asm__ __volatile__("fsqrt" : "=t"(result) : "t"(x));
+ return result;
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_H
diff --git a/libc/src/math/aarch64/CMakeLists.txt b/libc/src/math/aarch64/CMakeLists.txt
index 6ce89441857c..bbe927a1c7c8 100644
--- a/libc/src/math/aarch64/CMakeLists.txt
+++ b/libc/src/math/aarch64/CMakeLists.txt
@@ -77,23 +77,3 @@ add_entrypoint_object(
COMPILE_OPTIONS
-O2
)
-
-add_entrypoint_object(
- sqrt
- SRCS
- sqrt.cpp
- HDRS
- ../sqrt.h
- COMPILE_OPTIONS
- -O2
-)
-
-add_entrypoint_object(
- sqrtf
- SRCS
- sqrtf.cpp
- HDRS
- ../sqrtf.h
- COMPILE_OPTIONS
- -O2
-)
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index 88c29eb4ba0b..df2ef34d42ce 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -859,8 +859,10 @@ add_entrypoint_object(
../sqrt.h
DEPENDS
libc.src.__support.FPUtil.fputil
+ libc.src.__support.FPUtil.sqrt
COMPILE_OPTIONS
- -O2
+ -O3
+ -Wno-c++17-extensions
)
add_entrypoint_object(
@@ -871,8 +873,10 @@ add_entrypoint_object(
../sqrtf.h
DEPENDS
libc.src.__support.FPUtil.fputil
+ libc.src.__support.FPUtil.sqrt
COMPILE_OPTIONS
- -O2
+ -O3
+ -Wno-c++17-extensions
)
add_entrypoint_object(
@@ -883,8 +887,10 @@ add_entrypoint_object(
../sqrtl.h
DEPENDS
libc.src.__support.FPUtil.fputil
+ libc.src.__support.FPUtil.sqrt
COMPILE_OPTIONS
- -O2
+ -O3
+ -Wno-c++17-extensions
)
add_entrypoint_object(
diff --git a/libc/src/math/generic/sqrt.cpp b/libc/src/math/generic/sqrt.cpp
index bd43a5c6919a..de21f329e15a 100644
--- a/libc/src/math/generic/sqrt.cpp
+++ b/libc/src/math/generic/sqrt.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "src/math/sqrt.h"
-#include "src/__support/FPUtil/Sqrt.h"
+#include "src/__support/FPUtil/sqrt.h"
#include "src/__support/common.h"
namespace __llvm_libc {
diff --git a/libc/src/math/generic/sqrtf.cpp b/libc/src/math/generic/sqrtf.cpp
index bae39dd4b27e..3ca8d381898b 100644
--- a/libc/src/math/generic/sqrtf.cpp
+++ b/libc/src/math/generic/sqrtf.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "src/math/sqrtf.h"
-#include "src/__support/FPUtil/Sqrt.h"
+#include "src/__support/FPUtil/sqrt.h"
#include "src/__support/common.h"
namespace __llvm_libc {
diff --git a/libc/src/math/generic/sqrtl.cpp b/libc/src/math/generic/sqrtl.cpp
index efbc98eed844..970646a2e4d1 100644
--- a/libc/src/math/generic/sqrtl.cpp
+++ b/libc/src/math/generic/sqrtl.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "src/math/sqrtl.h"
-#include "src/__support/FPUtil/Sqrt.h"
+#include "src/__support/FPUtil/sqrt.h"
#include "src/__support/common.h"
namespace __llvm_libc {
diff --git a/libc/src/math/x86_64/CMakeLists.txt b/libc/src/math/x86_64/CMakeLists.txt
index d2a48231b787..cd129e3eefb7 100644
--- a/libc/src/math/x86_64/CMakeLists.txt
+++ b/libc/src/math/x86_64/CMakeLists.txt
@@ -27,33 +27,3 @@ add_entrypoint_object(
COMPILE_OPTIONS
-O2
)
-
-add_entrypoint_object(
- sqrt
- SRCS
- sqrt.cpp
- HDRS
- ../sqrt.h
- COMPILE_OPTIONS
- -O2
-)
-
-add_entrypoint_object(
- sqrtf
- SRCS
- sqrtf.cpp
- HDRS
- ../sqrtf.h
- COMPILE_OPTIONS
- -O2
-)
-
-add_entrypoint_object(
- sqrtl
- SRCS
- sqrtl.cpp
- HDRS
- ../sqrtl.h
- COMPILE_OPTIONS
- -O2
-)
diff --git a/libc/src/math/x86_64/sqrt.cpp b/libc/src/math/x86_64/sqrt.cpp
deleted file mode 100644
index 5d4e9424e603..000000000000
--- a/libc/src/math/x86_64/sqrt.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//===-- Implementation of the sqrt function for x86_64 --------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "src/math/sqrt.h"
-#include "src/__support/common.h"
-
-namespace __llvm_libc {
-
-LLVM_LIBC_FUNCTION(double, sqrt, (double x)) {
- double result;
- __asm__ __volatile__("sqrtsd %x1, %x0" : "=x"(result) : "x"(x));
- return result;
-}
-
-} // namespace __llvm_libc
diff --git a/libc/src/math/x86_64/sqrtf.cpp b/libc/src/math/x86_64/sqrtf.cpp
deleted file mode 100644
index 51d22dff2cbc..000000000000
--- a/libc/src/math/x86_64/sqrtf.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//===-- Implementation of the sqrtf function for x86_64 -------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "src/math/sqrtf.h"
-#include "src/__support/common.h"
-
-namespace __llvm_libc {
-
-LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) {
- float result;
- __asm__ __volatile__("sqrtss %x1, %x0" : "=x"(result) : "x"(x));
- return result;
-}
-
-} // namespace __llvm_libc
diff --git a/libc/src/math/x86_64/sqrtl.cpp b/libc/src/math/x86_64/sqrtl.cpp
deleted file mode 100644
index 8b0c39e95fdd..000000000000
--- a/libc/src/math/x86_64/sqrtl.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//===-- Implementation of the sqrtl function for x86_64 -------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "src/math/sqrtl.h"
-#include "src/__support/common.h"
-
-namespace __llvm_libc {
-
-LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) {
- long double result;
- __asm__ __volatile__("fsqrt" : "=t"(result) : "t"(x));
- return result;
-}
-
-} // namespace __llvm_libc
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index 73ecef959aba..827dc867ff51 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -983,26 +983,63 @@ add_fp_unittest(
libc.src.__support.FPUtil.fputil
)
-# The quad precision test for sqrt against MPFR currently suffers
-# from insufficient precision in MPFR calculations leading to
-# https://hal.archives-ouvertes.fr/hal-01091186/document. We will
-# renable after fixing the precision issue.
-if(${LIBC_TARGET_ARCHITECTURE_IS_X86})
- add_fp_unittest(
- sqrtl_test
- NEED_MPFR
- SUITE
- libc_math_unittests
- SRCS
- sqrtl_test.cpp
- DEPENDS
- libc.include.math
- libc.src.math.sqrtl
- libc.src.__support.FPUtil.fputil
- )
-else()
- message(STATUS "Skipping sqrtl_test")
-endif()
+add_fp_unittest(
+ sqrtl_test
+ NEED_MPFR
+ SUITE
+ libc_math_unittests
+ SRCS
+ sqrtl_test.cpp
+ DEPENDS
+ libc.include.math
+ libc.src.math.sqrtl
+ libc.src.__support.FPUtil.fputil
+)
+
+add_fp_unittest(
+ generic_sqrtf_test
+ NEED_MPFR
+ SUITE
+ libc_math_unittests
+ SRCS
+ generic_sqrtf_test.cpp
+ DEPENDS
+ libc.src.__support.FPUtil.fputil
+ libc.src.__support.FPUtil.generic.sqrt
+ COMPILE_OPTIONS
+ -O3
+ -Wno-c++17-extensions
+)
+
+add_fp_unittest(
+ generic_sqrt_test
+ NEED_MPFR
+ SUITE
+ libc_math_unittests
+ SRCS
+ generic_sqrt_test.cpp
+ DEPENDS
+ libc.src.__support.FPUtil.fputil
+ libc.src.__support.FPUtil.generic.sqrt
+ COMPILE_OPTIONS
+ -O3
+ -Wno-c++17-extensions
+)
+
+add_fp_unittest(
+ generic_sqrtl_test
+ NEED_MPFR
+ SUITE
+ libc_math_unittests
+ SRCS
+ generic_sqrtl_test.cpp
+ DEPENDS
+ libc.src.__support.FPUtil.fputil
+ libc.src.__support.FPUtil.generic.sqrt
+ COMPILE_OPTIONS
+ -O3
+ -Wno-c++17-extensions
+)
add_fp_unittest(
remquof_test
diff --git a/libc/test/src/math/generic_sqrt_test.cpp b/libc/test/src/math/generic_sqrt_test.cpp
new file mode 100644
index 000000000000..cecfc0ee3de3
--- /dev/null
+++ b/libc/test/src/math/generic_sqrt_test.cpp
@@ -0,0 +1,13 @@
+//===-- Unittests for generic implementation of sqrt ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SqrtTest.h"
+
+#include "src/__support/FPUtil/generic/sqrt.h"
+
+LIST_SQRT_TESTS(double, __llvm_libc::fputil::sqrt<double>)
diff --git a/libc/test/src/math/generic_sqrtf_test.cpp b/libc/test/src/math/generic_sqrtf_test.cpp
new file mode 100644
index 000000000000..64bf92133b98
--- /dev/null
+++ b/libc/test/src/math/generic_sqrtf_test.cpp
@@ -0,0 +1,13 @@
+//===-- Unittests for generic implementation of sqrtf----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SqrtTest.h"
+
+#include "src/__support/FPUtil/generic/sqrt.h"
+
+LIST_SQRT_TESTS(float, __llvm_libc::fputil::sqrt<float>)
diff --git a/libc/test/src/math/generic_sqrtl_test.cpp b/libc/test/src/math/generic_sqrtl_test.cpp
new file mode 100644
index 000000000000..6b68aaed9700
--- /dev/null
+++ b/libc/test/src/math/generic_sqrtl_test.cpp
@@ -0,0 +1,13 @@
+//===-- Unittests for generic implementation of sqrtl----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SqrtTest.h"
+
+#include "src/__support/FPUtil/generic/sqrt.h"
+
+LIST_SQRT_TESTS(long double, __llvm_libc::fputil::sqrt<long double>)