mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
implement the first _real_ unit-tests, testing the results for correctness instead
of just checking compilation. Fix the many issues discovered by these unit-tests, by the way fixing a performance bug.
This commit is contained in:
@@ -27,5 +27,6 @@ namespace Eigen {
|
||||
#include "Core/Trace.h"
|
||||
#include "Core/Dot.h"
|
||||
#include "Core/Random.h"
|
||||
#include "Core/Fuzzy.h"
|
||||
|
||||
} // namespace Eigen
|
||||
|
||||
@@ -75,9 +75,9 @@ template<typename MatrixType> class Column
|
||||
|
||||
template<typename Scalar, typename Derived>
|
||||
Column<Derived>
|
||||
Object<Scalar, Derived>::col(int i)
|
||||
Object<Scalar, Derived>::col(int i) const
|
||||
{
|
||||
return Column<Derived>(static_cast<Derived*>(this)->ref(), i);
|
||||
return Column<Derived>(static_cast<Derived*>(const_cast<Object*>(this))->ref(), i);
|
||||
}
|
||||
|
||||
#endif // EI_COLUMN_H
|
||||
|
||||
@@ -40,7 +40,7 @@ template<int UnrollCount, int Rows> struct CopyHelperUnroller
|
||||
}
|
||||
};
|
||||
|
||||
template<int Rows> struct CopyHelperUnroller<0, Rows>
|
||||
template<int Rows> struct CopyHelperUnroller<1, Rows>
|
||||
{
|
||||
template <typename Derived1, typename Derived2>
|
||||
static void run(Derived1 &dst, const Derived2 &src)
|
||||
|
||||
@@ -32,7 +32,7 @@ struct DotUnroller
|
||||
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
|
||||
{
|
||||
DotUnroller<Index-1, Size, Derived1, Derived2>::run(v1, v2, dot);
|
||||
dot += v1[Index] * Conj(v2[Index]);
|
||||
dot += v1[Index] * NumTraits<typename Derived1::Scalar>::conj(v2[Index]);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -41,7 +41,7 @@ struct DotUnroller<0, Size, Derived1, Derived2>
|
||||
{
|
||||
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
|
||||
{
|
||||
dot = v1[0] * Conj(v2[0]);
|
||||
dot = v1[0] * NumTraits<typename Derived1::Scalar>::conj(v2[0]);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -67,9 +67,9 @@ Scalar Object<Scalar, Derived>::dot(const OtherDerived& other) const
|
||||
::run(*static_cast<const Derived*>(this), other, res);
|
||||
else
|
||||
{
|
||||
res = (*this)[0] * Conj(other[0]);
|
||||
res = (*this)[0] * NumTraits<Scalar>::conj(other[0]);
|
||||
for(int i = 1; i < size(); i++)
|
||||
res += (*this)[i]* Conj(other[i]);
|
||||
res += (*this)[i]* NumTraits<Scalar>::conj(other[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -78,7 +78,7 @@ template<typename Scalar, typename Derived>
|
||||
typename NumTraits<Scalar>::Real Object<Scalar, Derived>::norm2() const
|
||||
{
|
||||
assert(IsVector);
|
||||
return Real(dot(*this));
|
||||
return NumTraits<Scalar>::real(dot(*this));
|
||||
}
|
||||
|
||||
template<typename Scalar, typename Derived>
|
||||
|
||||
@@ -40,7 +40,8 @@ bool Object<Scalar, Derived>::isApprox(
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < cols(); i++)
|
||||
if(!col(i).isApprox(other.col(i), prec))
|
||||
if((col(i) - other.col(i)).norm2()
|
||||
> std::min(col(i).norm2(), other.col(i).norm2()) * prec * prec)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
@@ -59,7 +60,7 @@ bool Object<Scalar, Derived>::isMuchSmallerThan(
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < cols(); i++)
|
||||
if(!col(i).isMuchSmallerThan(other, prec))
|
||||
if(col(i).norm2() > NumTraits<Scalar>::abs2(other) * prec * prec)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
@@ -79,10 +80,10 @@ bool Object<Scalar, Derived>::isMuchSmallerThan(
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < cols(); i++)
|
||||
if(!col(i).isMuchSmallerThan(other.col(i), prec))
|
||||
if(col(i).norm2() > other.col(i).norm2() * prec * prec)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // EI_FUZZY_H
|
||||
#endif // EI_FUZZY_H
|
||||
|
||||
@@ -45,10 +45,10 @@ template<> struct NumTraits<int>
|
||||
static double sqrt(const int& x) { return std::sqrt(static_cast<double>(x)); }
|
||||
static int abs(const int& x) { return std::abs(x); }
|
||||
static int abs2(const int& x) { return x*x; }
|
||||
static int rand()
|
||||
static int random()
|
||||
{
|
||||
// "rand() % n" is bad, they say, because the low-order bits are not random enough.
|
||||
// However here, 21 is odd, so rand() % 21 uses the high-order bits
|
||||
// "random() % n" is bad, they say, because the low-order bits are not random enough.
|
||||
// However here, 21 is odd, so random() % 21 uses the high-order bits
|
||||
// as well, so there's no problem.
|
||||
return (std::rand() % 21) - 10;
|
||||
}
|
||||
@@ -86,7 +86,7 @@ template<> struct NumTraits<float>
|
||||
static float sqrt(const float& x) { return std::sqrt(x); }
|
||||
static float abs(const float& x) { return std::abs(x); }
|
||||
static float abs2(const float& x) { return x*x; }
|
||||
static float rand()
|
||||
static float random()
|
||||
{
|
||||
return std::rand() / (RAND_MAX/20.0f) - 10.0f;
|
||||
}
|
||||
@@ -120,7 +120,7 @@ template<> struct NumTraits<double>
|
||||
static double sqrt(const double& x) { return std::sqrt(x); }
|
||||
static double abs(const double& x) { return std::abs(x); }
|
||||
static double abs2(const double& x) { return x*x; }
|
||||
static double rand()
|
||||
static double random()
|
||||
{
|
||||
return std::rand() / (RAND_MAX/20.0) - 10.0;
|
||||
}
|
||||
@@ -158,9 +158,9 @@ template<typename _Real> struct NumTraits<std::complex<_Real> >
|
||||
{ return std::abs(static_cast<FloatingPoint>(x)); }
|
||||
static Real abs2(const Complex& x)
|
||||
{ return std::real(x) * std::real(x) + std::imag(x) * std::imag(x); }
|
||||
static Complex rand()
|
||||
static Complex random()
|
||||
{
|
||||
return Complex(NumTraits<Real>::rand(), NumTraits<Real>::rand());
|
||||
return Complex(NumTraits<Real>::random(), NumTraits<Real>::random());
|
||||
}
|
||||
static bool isMuchSmallerThan(const Complex& a, const Complex& b, const Real& prec = precision())
|
||||
{
|
||||
|
||||
@@ -34,6 +34,20 @@ template<typename Scalar, typename Derived> class Object
|
||||
template<typename OtherDerived>
|
||||
void _copy_helper(const Object<Scalar, OtherDerived>& other);
|
||||
|
||||
template<typename OtherDerived>
|
||||
bool _isApprox_helper(
|
||||
const OtherDerived& other,
|
||||
const typename NumTraits<Scalar>::Real& prec = NumTraits<Scalar>::precision()
|
||||
) const;
|
||||
bool _isMuchSmallerThan_helper(
|
||||
const Scalar& other,
|
||||
const typename NumTraits<Scalar>::Real& prec = NumTraits<Scalar>::precision()
|
||||
) const;
|
||||
template<typename OtherDerived>
|
||||
bool _isMuchSmallerThan_helper(
|
||||
const Object<Scalar, OtherDerived>& other,
|
||||
const typename NumTraits<Scalar>::Real& prec = NumTraits<Scalar>::precision()
|
||||
) const;
|
||||
public:
|
||||
static const int SizeAtCompileTime
|
||||
= RowsAtCompileTime == Dynamic || ColsAtCompileTime == Dynamic
|
||||
@@ -81,8 +95,8 @@ template<typename Scalar, typename Derived> class Object
|
||||
return *static_cast<Derived*>(this);
|
||||
}
|
||||
|
||||
Row<Derived> row(int i);
|
||||
Column<Derived> col(int i);
|
||||
Row<Derived> row(int i) const;
|
||||
Column<Derived> col(int i) const;
|
||||
Minor<Derived> minor(int row, int col);
|
||||
Block<Derived> block(int startRow, int endRow, int startCol, int endCol);
|
||||
Transpose<Derived> transpose();
|
||||
@@ -111,7 +125,7 @@ template<typename Scalar, typename Derived> class Object
|
||||
) const;
|
||||
template<typename OtherDerived>
|
||||
bool isMuchSmallerThan(
|
||||
const OtherDerived& other,
|
||||
const Object<Scalar, OtherDerived>& other,
|
||||
const typename NumTraits<Scalar>::Real& prec = NumTraits<Scalar>::precision()
|
||||
) const;
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ template<typename MatrixType> class Random
|
||||
{
|
||||
EI_UNUSED(row);
|
||||
EI_UNUSED(col);
|
||||
return NumTraits<Scalar>::rand();
|
||||
return NumTraits<Scalar>::random();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
@@ -80,9 +80,9 @@ template<typename MatrixType> class Row
|
||||
|
||||
template<typename Scalar, typename Derived>
|
||||
Row<Derived>
|
||||
Object<Scalar, Derived>::row(int i)
|
||||
Object<Scalar, Derived>::row(int i) const
|
||||
{
|
||||
return Row<Derived>(static_cast<Derived*>(this)->ref(), i);
|
||||
return Row<Derived>(static_cast<Derived*>(const_cast<Object*>(this))->ref(), i);
|
||||
}
|
||||
|
||||
#endif // EI_ROW_H
|
||||
|
||||
Reference in New Issue
Block a user