mirror of
https://gitlab.com/libeigen/eigen.git
synced 2026-04-10 11:34:33 +08:00
Created additional tests for the tensor code.
This commit is contained in:
@@ -28,10 +28,10 @@ static void test_1d()
|
||||
|
||||
float data3[6];
|
||||
TensorMap<Tensor<float, 1>> vec3(data3, 6);
|
||||
vec3 = vec1.cwiseSqrt();
|
||||
vec3 = vec1.sqrt();
|
||||
float data4[6];
|
||||
TensorMap<Tensor<float, 1, RowMajor>> vec4(data4, 6);
|
||||
vec4 = vec2.cwiseSqrt();
|
||||
vec4 = vec2.square();
|
||||
|
||||
VERIFY_IS_APPROX(vec3(0), sqrtf(4.0));
|
||||
VERIFY_IS_APPROX(vec3(1), sqrtf(8.0));
|
||||
@@ -40,12 +40,12 @@ static void test_1d()
|
||||
VERIFY_IS_APPROX(vec3(4), sqrtf(23.0));
|
||||
VERIFY_IS_APPROX(vec3(5), sqrtf(42.0));
|
||||
|
||||
VERIFY_IS_APPROX(vec4(0), sqrtf(0.0));
|
||||
VERIFY_IS_APPROX(vec4(1), sqrtf(1.0));
|
||||
VERIFY_IS_APPROX(vec4(2), sqrtf(2.0));
|
||||
VERIFY_IS_APPROX(vec4(3), sqrtf(3.0));
|
||||
VERIFY_IS_APPROX(vec4(4), sqrtf(4.0));
|
||||
VERIFY_IS_APPROX(vec4(5), sqrtf(5.0));
|
||||
VERIFY_IS_APPROX(vec4(0), 0.0f);
|
||||
VERIFY_IS_APPROX(vec4(1), 1.0f);
|
||||
VERIFY_IS_APPROX(vec4(2), 2.0f * 2.0f);
|
||||
VERIFY_IS_APPROX(vec4(3), 3.0f * 3.0f);
|
||||
VERIFY_IS_APPROX(vec4(4), 4.0f * 4.0f);
|
||||
VERIFY_IS_APPROX(vec4(5), 5.0f * 5.0f);
|
||||
|
||||
vec3 = vec1 + vec2;
|
||||
VERIFY_IS_APPROX(vec3(0), 4.0f + 0.0f);
|
||||
@@ -79,8 +79,8 @@ static void test_2d()
|
||||
|
||||
Tensor<float, 2> mat3(2,3);
|
||||
Tensor<float, 2, RowMajor> mat4(2,3);
|
||||
mat3 = mat1.cwiseAbs();
|
||||
mat4 = mat2.cwiseAbs();
|
||||
mat3 = mat1.abs();
|
||||
mat4 = mat2.abs();
|
||||
|
||||
VERIFY_IS_APPROX(mat3(0,0), 0.0f);
|
||||
VERIFY_IS_APPROX(mat3(0,1), 1.0f);
|
||||
@@ -102,7 +102,7 @@ static void test_3d()
|
||||
Tensor<float, 3> mat1(2,3,7);
|
||||
Tensor<float, 3, RowMajor> mat2(2,3,7);
|
||||
|
||||
float val = 0.0;
|
||||
float val = 1.0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
@@ -118,28 +118,147 @@ static void test_3d()
|
||||
Tensor<float, 3, RowMajor> mat4(2,3,7);
|
||||
mat4 = mat2 * 3.14f;
|
||||
Tensor<float, 3> mat5(2,3,7);
|
||||
mat5 = mat1.cwiseSqrt().cwiseSqrt();
|
||||
mat5 = mat1.inverse().log();
|
||||
Tensor<float, 3, RowMajor> mat6(2,3,7);
|
||||
mat6 = mat2.cwiseSqrt() * 3.14f;
|
||||
mat6 = mat2.pow(0.5f) * 3.14f;
|
||||
Tensor<float, 3> mat7(2,3,7);
|
||||
mat7 = mat1.cwiseMax(mat5 * 2.0f).exp();
|
||||
Tensor<float, 3, RowMajor> mat8(2,3,7);
|
||||
mat8 = (-mat2).exp() * 3.14f;
|
||||
|
||||
val = 0.0;
|
||||
val = 1.0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_APPROX(mat3(i,j,k), val + val);
|
||||
VERIFY_IS_APPROX(mat4(i,j,k), val * 3.14f);
|
||||
VERIFY_IS_APPROX(mat5(i,j,k), sqrtf(sqrtf(val)));
|
||||
VERIFY_IS_APPROX(mat5(i,j,k), logf(1.0f/val));
|
||||
VERIFY_IS_APPROX(mat6(i,j,k), sqrtf(val) * 3.14f);
|
||||
VERIFY_IS_APPROX(mat7(i,j,k), expf((std::max)(val, mat5(i,j,k) * 2.0f)));
|
||||
VERIFY_IS_APPROX(mat8(i,j,k), expf(-val) * 3.14f);
|
||||
val += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void test_constants()
|
||||
{
|
||||
Tensor<float, 3> mat1(2,3,7);
|
||||
Tensor<float, 3> mat2(2,3,7);
|
||||
Tensor<float, 3> mat3(2,3,7);
|
||||
|
||||
float val = 1.0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
mat1(i,j,k) = val;
|
||||
val += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
mat2 = mat1.constant(3.14f);
|
||||
mat3 = mat1.cwiseMax(7.3f).exp();
|
||||
|
||||
val = 1.0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_APPROX(mat2(i,j,k), 3.14f);
|
||||
VERIFY_IS_APPROX(mat3(i,j,k), expf((std::max)(val, 7.3f)));
|
||||
val += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void test_functors()
|
||||
{
|
||||
Tensor<float, 3> mat1(2,3,7);
|
||||
Tensor<float, 3> mat2(2,3,7);
|
||||
Tensor<float, 3> mat3(2,3,7);
|
||||
|
||||
float val = 1.0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
mat1(i,j,k) = val;
|
||||
val += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
mat2 = mat1.inverse().unaryExpr(&asinf);
|
||||
mat3 = mat1.unaryExpr(&tanhf);
|
||||
|
||||
val = 1.0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_APPROX(mat2(i,j,k), asinf(1.0f / mat1(i,j,k)));
|
||||
VERIFY_IS_APPROX(mat3(i,j,k), tanhf(mat1(i,j,k)));
|
||||
val += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void test_type_casting()
|
||||
{
|
||||
Tensor<bool, 3> mat1(2,3,7);
|
||||
Tensor<float, 3> mat2(2,3,7);
|
||||
Tensor<double, 3> mat3(2,3,7);
|
||||
mat1.setRandom();
|
||||
mat2.setRandom();
|
||||
|
||||
mat3 = mat1.template cast<double>();
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_APPROX(mat3(i,j,k), mat1(i,j,k) ? 1.0 : 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mat3 = mat2.template cast<double>();
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_APPROX(mat3(i,j,k), static_cast<double>(mat2(i,j,k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void test_select()
|
||||
{
|
||||
Tensor<float, 3> selector(2,3,7);
|
||||
Tensor<float, 3> mat1(2,3,7);
|
||||
Tensor<float, 3> mat2(2,3,7);
|
||||
Tensor<float, 3> result(2,3,7);
|
||||
|
||||
selector.setRandom();
|
||||
mat1.setRandom();
|
||||
mat2.setRandom();
|
||||
result = (selector > selector.constant(0.5f)).select(mat1, mat2);
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_APPROX(result(i,j,k), (selector(i,j,k) > 0.5f) ? mat1(i,j,k) : mat2(i,j,k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_cxx11_tensor_expr()
|
||||
{
|
||||
CALL_SUBTEST(test_1d());
|
||||
CALL_SUBTEST(test_2d());
|
||||
CALL_SUBTEST(test_3d());
|
||||
CALL_SUBTEST(test_constants());
|
||||
CALL_SUBTEST(test_functors());
|
||||
CALL_SUBTEST(test_type_casting());
|
||||
CALL_SUBTEST(test_select());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user