summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test_vec.cpp23
-rw-r--r--vec.h55
2 files changed, 68 insertions, 10 deletions
diff --git a/test_vec.cpp b/test_vec.cpp
index 143d5b3..9739160 100644
--- a/test_vec.cpp
+++ b/test_vec.cpp
@@ -77,18 +77,18 @@ TEST(Vec, Mod2) {
}
TEST(Vec, IsZero) {
- vec3i a{1,0,0}, b{0,-1,0}, c{1,2,3};
+ vec3i a{1, 0, 0}, b{0, -1, 0}, c{1, 2, 3};
ASSERT_FALSE(a.is_zero());
ASSERT_FALSE(b.is_zero());
ASSERT_FALSE(c.is_zero());
- vec3d d{0.1,0,0}, e{0,-0.1,0},f{0.1,0.1,0.1};
+ vec3d d{0.1, 0, 0}, e{0, -0.1, 0}, f{0.1, 0.1, 0.1};
ASSERT_FALSE(d.is_zero());
ASSERT_FALSE(e.is_zero());
ASSERT_FALSE(f.is_zero());
- vec3i g{0,0,0};
- vec3d h{0,0,0}, i{1e-10,0,0}, j{1e-10,1e-10,1e-10};
+ vec3i g{0, 0, 0};
+ vec3d h{0, 0, 0}, i{1e-10, 0, 0}, j{1e-10, 1e-10, 1e-10};
ASSERT_TRUE(g.is_zero());
ASSERT_TRUE(h.is_zero());
ASSERT_TRUE(i.is_zero());
@@ -96,6 +96,19 @@ TEST(Vec, IsZero) {
}
TEST(Vec, Reflect) {
- vec3d n{1,0,0}, u{-1,1.1,0}, v{1,1.1,0};
+ vec3d n{1, 0, 0}, u{-1, 1.1, 0}, v{1, 1.1, 0};
ASSERT_EQ(v, n.reflect(u));
+}
+
+TEST(Vec, Refract) {
+ vec3d n{1, 0, 0}, u{-1, 0, -1}, v{-sqrt(14), 0, -sqrt(2)};
+ ASSERT_EQ(u.unit_vec(), n.refract<true>(u.unit_vec(), 1).unit_vec());
+ ASSERT_EQ(u.unit_vec(), n.refract<false>(u.unit_vec(), 1).unit_vec());
+ ASSERT_EQ(v.unit_vec(), n.refract<true>(u.unit_vec(), 0.5).unit_vec());
+ ASSERT_EQ(v.unit_vec(), n.refract<false>(u.unit_vec(), 0.5).unit_vec());
+}
+
+TEST(Vec, Refract_TIR) {
+ vec3d n{1, 0, 0}, u{-1, 0, -sqrt(3)}, v{1, 0, -sqrt(3)};
+ ASSERT_EQ(v.unit_vec(), n.refract<true>(u.unit_vec(), 2));
} \ No newline at end of file
diff --git a/vec.h b/vec.h
index 78580a8..b233220 100644
--- a/vec.h
+++ b/vec.h
@@ -9,6 +9,7 @@
#include <random>
#include <ostream>
#include <cassert>
+#include <algorithm>
static inline bool eq(int a, int b) {
return a == b;
@@ -102,11 +103,46 @@ struct vec3 {
return *this * (1.0 / norm());
}
+ // If all the three float-point scalars are finite real number.
+ bool valid() const {
+ return std::isfinite(x) && std::isfinite(y) && std::isfinite(z);
+ }
+
// Get the reflected vector. Current vector is the normal vector (length should be 1), v is the incoming vector.
vec3 reflect(const vec3 &v) const {
assert(fabs(mod2() - 1.0) < 1e-8);
return v - (2.0 * dot(v)) * (*this);
}
+
+ // Get the refracted vector. Current vector is the normal vector (length should be 1),
+ // r1 is the incoming vector, ri_inv is the relative refraction index n2/n1,
+ // where n2 is the destination media's refraction index, and n1 is the source media's refraction index.
+ // TIR (Total Internal Reflection) is optionally enabled by macro TIR_OR and TIR_OFF.
+ // If TIR happens, the ray will be reflected.
+ template<bool Enable_TIR>
+ vec3 refract(const vec3 &r1, double ri_inv) const {
+ assert(fabs(mod2() - 1.0) < 1e-12);
+ assert(fabs(r1.mod2() - 1.0) < 1e-12);
+ assert(ri_inv > 0);
+ assert(dot(r1) < 0); // normal vector must be on the same side
+ const auto &n = *this; // normal vector
+ const auto m_cos1 = std::max(dot(r1), (T) (-1)); // cos(a1), a1 is the incoming angle
+// assert(m_cos1 <= 0); // incoming angle must smaller than 90deg
+ const auto c = ri_inv; // c = nx * r`x + ny * r`y
+ auto d = 1 - c * c * (1 - m_cos1 * m_cos1);
+ if (d < 0) {
+ // TODO test TIR
+ if (Enable_TIR) {
+ // ri_inv < sin(a1), cannot refract, must reflect (Total Internal Reflection)
+ return reflect(r1);
+ } else {
+ d = -d; // abs, just make the sqrt has a real solution
+ }
+ }
+ const auto n2 = (r1 - dot(r1) * n) * c - sqrt(d) * n;
+ assert(n2.valid());
+ return n2;
+ }
};
// print to ostream
@@ -115,33 +151,42 @@ inline std::ostream &operator<<(std::ostream &out, const vec3<T> &vec) {
return out << "vec3[x=" << vec.x << ", y=" << vec.y << ", z=" << vec.z << ']';
}
-// product vec3 by a scalar
+// product vec3 by a scalar, with fp assertions
template<
typename T,
typename S,
typename = typename std::enable_if<std::is_arithmetic<S>::value, S>::type
>
inline vec3<T> operator*(const vec3<T> &vec, const S &b) {
+ if (std::is_floating_point<S>::value) {
+ assert(std::isfinite(b));;
+ }
return vec3<T>{.x=(T) (vec.x * b), .y=(T) (vec.y * b), .z=(T) (vec.z * b)};
}
-// product vec3 by a scalar
+// product vec3 by a scalar, with fp assertions
template<
typename T,
- typename S,
- typename = typename std::enable_if<std::is_arithmetic<S>::value, S>::type
+ typename S
>
inline vec3<T> operator*(const S &b, const vec3<T> &vec) {
+ if (std::is_floating_point<S>::value) {
+ assert(std::isfinite(b));
+ }
return vec3<T>{.x=(T) (vec.x * b), .y=(T) (vec.y * b), .z=(T) (vec.z * b)};
}
-// product vec3 by the inversion of a scalar (div by a scalar)
+// product vec3 by the inversion of a scalar (div by a scalar), with fp assertions
template<
typename T,
typename S,
typename = typename std::enable_if<std::is_arithmetic<S>::value, S>::type
>
inline vec3<T> operator/(const vec3<T> &vec, const S &b) {
+ if (std::is_floating_point<S>::value) {
+ assert(std::isfinite(b));
+ assert(b != 0);
+ }
return vec3<T>{.x=(T) (vec.x / b), .y=(T) (vec.y / b), .z=(T) (vec.z / b)};
}