mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Fix tensor casts for large packets and casts to/from std::complex
The original tensor casts were only defined for `SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the missing 1:N and 8:1. We also add casting `Eigen::half` to/from `std::complex<T>`, which was missing to make it consistent with `Eigen:bfloat16`, and generalize the overload to work for any complex type. Tests were added to `basicstuff`, `packetmath`, and `cxx11_tensor_casts` to test all cast configurations.
This commit is contained in:
committed by
Antonio Sánchez
parent
145e51516f
commit
9cb8771e9c
@@ -10,6 +10,7 @@
|
||||
#define EIGEN_NO_STATIC_ASSERT
|
||||
|
||||
#include "main.h"
|
||||
#include "random_without_cast_overflow.h"
|
||||
|
||||
template<typename MatrixType> void basicStuff(const MatrixType& m)
|
||||
{
|
||||
@@ -90,7 +91,7 @@ template<typename MatrixType> void basicStuff(const MatrixType& m)
|
||||
Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> cv(rows);
|
||||
rv = square.row(r);
|
||||
cv = square.col(r);
|
||||
|
||||
|
||||
VERIFY_IS_APPROX(rv, cv.transpose());
|
||||
|
||||
if(cols!=1 && rows!=1 && MatrixType::SizeAtCompileTime!=Dynamic)
|
||||
@@ -120,28 +121,28 @@ template<typename MatrixType> void basicStuff(const MatrixType& m)
|
||||
m1 = m2;
|
||||
VERIFY(m1==m2);
|
||||
VERIFY(!(m1!=m2));
|
||||
|
||||
|
||||
// check automatic transposition
|
||||
sm2.setZero();
|
||||
for(Index i=0;i<rows;++i)
|
||||
sm2.col(i) = sm1.row(i);
|
||||
VERIFY_IS_APPROX(sm2,sm1.transpose());
|
||||
|
||||
|
||||
sm2.setZero();
|
||||
for(Index i=0;i<rows;++i)
|
||||
sm2.col(i).noalias() = sm1.row(i);
|
||||
VERIFY_IS_APPROX(sm2,sm1.transpose());
|
||||
|
||||
|
||||
sm2.setZero();
|
||||
for(Index i=0;i<rows;++i)
|
||||
sm2.col(i).noalias() += sm1.row(i);
|
||||
VERIFY_IS_APPROX(sm2,sm1.transpose());
|
||||
|
||||
|
||||
sm2.setZero();
|
||||
for(Index i=0;i<rows;++i)
|
||||
sm2.col(i).noalias() -= sm1.row(i);
|
||||
VERIFY_IS_APPROX(sm2,-sm1.transpose());
|
||||
|
||||
|
||||
// check ternary usage
|
||||
{
|
||||
bool b = internal::random<int>(0,10)>5;
|
||||
@@ -194,14 +195,72 @@ template<typename MatrixType> void basicStuffComplex(const MatrixType& m)
|
||||
VERIFY(!static_cast<const MatrixType&>(cm).imag().isZero());
|
||||
}
|
||||
|
||||
template<int>
|
||||
void casting()
|
||||
template<typename SrcScalar, typename TgtScalar>
|
||||
void casting_test()
|
||||
{
|
||||
Matrix4f m = Matrix4f::Random(), m2;
|
||||
Matrix4d n = m.cast<double>();
|
||||
VERIFY(m.isApprox(n.cast<float>()));
|
||||
m2 = m.cast<float>(); // check the specialization when NewType == Type
|
||||
VERIFY(m.isApprox(m2));
|
||||
Matrix<SrcScalar,4,4> m;
|
||||
for (int i=0; i<m.rows(); ++i) {
|
||||
for (int j=0; j<m.cols(); ++j) {
|
||||
m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value();
|
||||
}
|
||||
}
|
||||
Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
|
||||
for (int i=0; i<m.rows(); ++i) {
|
||||
for (int j=0; j<m.cols(); ++j) {
|
||||
VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename SrcScalar, typename EnableIf = void>
|
||||
struct casting_test_runner {
|
||||
static void run() {
|
||||
casting_test<SrcScalar, bool>();
|
||||
casting_test<SrcScalar, int8_t>();
|
||||
casting_test<SrcScalar, uint8_t>();
|
||||
casting_test<SrcScalar, int16_t>();
|
||||
casting_test<SrcScalar, uint16_t>();
|
||||
casting_test<SrcScalar, int32_t>();
|
||||
casting_test<SrcScalar, uint32_t>();
|
||||
casting_test<SrcScalar, int64_t>();
|
||||
casting_test<SrcScalar, uint64_t>();
|
||||
casting_test<SrcScalar, half>();
|
||||
casting_test<SrcScalar, bfloat16>();
|
||||
casting_test<SrcScalar, float>();
|
||||
casting_test<SrcScalar, double>();
|
||||
casting_test<SrcScalar, std::complex<float>>();
|
||||
casting_test<SrcScalar, std::complex<double>>();
|
||||
}
|
||||
};
|
||||
|
||||
template<typename SrcScalar>
|
||||
struct casting_test_runner<SrcScalar, typename internal::enable_if<(NumTraits<SrcScalar>::IsComplex)>::type>
|
||||
{
|
||||
static void run() {
|
||||
// Only a few casts from std::complex<T> are defined.
|
||||
casting_test<SrcScalar, half>();
|
||||
casting_test<SrcScalar, bfloat16>();
|
||||
casting_test<SrcScalar, std::complex<float>>();
|
||||
casting_test<SrcScalar, std::complex<double>>();
|
||||
}
|
||||
};
|
||||
|
||||
void casting_all() {
|
||||
casting_test_runner<bool>::run();
|
||||
casting_test_runner<int8_t>::run();
|
||||
casting_test_runner<uint8_t>::run();
|
||||
casting_test_runner<int16_t>::run();
|
||||
casting_test_runner<uint16_t>::run();
|
||||
casting_test_runner<int32_t>::run();
|
||||
casting_test_runner<uint32_t>::run();
|
||||
casting_test_runner<int64_t>::run();
|
||||
casting_test_runner<uint64_t>::run();
|
||||
casting_test_runner<half>::run();
|
||||
casting_test_runner<bfloat16>::run();
|
||||
casting_test_runner<float>::run();
|
||||
casting_test_runner<double>::run();
|
||||
casting_test_runner<std::complex<float>>::run();
|
||||
casting_test_runner<std::complex<double>>::run();
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
@@ -210,12 +269,12 @@ void fixedSizeMatrixConstruction()
|
||||
Scalar raw[4];
|
||||
for(int k=0; k<4; ++k)
|
||||
raw[k] = internal::random<Scalar>();
|
||||
|
||||
|
||||
{
|
||||
Matrix<Scalar,4,1> m(raw);
|
||||
Array<Scalar,4,1> a(raw);
|
||||
for(int k=0; k<4; ++k) VERIFY(m(k) == raw[k]);
|
||||
for(int k=0; k<4; ++k) VERIFY(a(k) == raw[k]);
|
||||
for(int k=0; k<4; ++k) VERIFY(a(k) == raw[k]);
|
||||
VERIFY_IS_EQUAL(m,(Matrix<Scalar,4,1>(raw[0],raw[1],raw[2],raw[3])));
|
||||
VERIFY((a==(Array<Scalar,4,1>(raw[0],raw[1],raw[2],raw[3]))).all());
|
||||
}
|
||||
@@ -277,6 +336,7 @@ EIGEN_DECLARE_TEST(basicstuff)
|
||||
CALL_SUBTEST_5( basicStuff(MatrixXcd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_6( basicStuff(Matrix<float, 100, 100>()) );
|
||||
CALL_SUBTEST_7( basicStuff(Matrix<long double,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE),internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_8( casting_all() );
|
||||
|
||||
CALL_SUBTEST_3( basicStuffComplex(MatrixXcf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_5( basicStuffComplex(MatrixXcd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
@@ -288,6 +348,4 @@ EIGEN_DECLARE_TEST(basicstuff)
|
||||
CALL_SUBTEST_1(fixedSizeMatrixConstruction<int>());
|
||||
CALL_SUBTEST_1(fixedSizeMatrixConstruction<long int>());
|
||||
CALL_SUBTEST_1(fixedSizeMatrixConstruction<std::ptrdiff_t>());
|
||||
|
||||
CALL_SUBTEST_2(casting<0>());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user