summaryrefslogtreecommitdiff
path: root/vec.h
diff options
context:
space:
mode:
Diffstat (limited to 'vec.h')
-rw-r--r--vec.h55
1 files changed, 50 insertions, 5 deletions
diff --git a/vec.h b/vec.h
index 78580a8..b233220 100644
--- a/vec.h
+++ b/vec.h
@@ -9,6 +9,7 @@
#include <random>
#include <ostream>
#include <cassert>
+#include <algorithm>
static inline bool eq(int a, int b) {
return a == b;
@@ -102,11 +103,46 @@ struct vec3 {
return *this * (1.0 / norm());
}
+ // If all the three float-point scalars are finite real number.
+ bool valid() const {
+ return std::isfinite(x) && std::isfinite(y) && std::isfinite(z);
+ }
+
// Get the reflected vector. Current vector is the normal vector (length should be 1), v is the incoming vector.
vec3 reflect(const vec3 &v) const {
assert(fabs(mod2() - 1.0) < 1e-8);
return v - (2.0 * dot(v)) * (*this);
}
+
+ // Get the refracted vector. Current vector is the normal vector (length should be 1),
+ // r1 is the incoming vector, ri_inv is the relative refraction index n2/n1,
+ // where n2 is the destination media's refraction index, and n1 is the source media's refraction index.
+ // TIR (Total Internal Reflection) is optionally enabled by macro TIR_OR and TIR_OFF.
+ // If TIR happens, the ray will be reflected.
+ template<bool Enable_TIR>
+ vec3 refract(const vec3 &r1, double ri_inv) const {
+ assert(fabs(mod2() - 1.0) < 1e-12);
+ assert(fabs(r1.mod2() - 1.0) < 1e-12);
+ assert(ri_inv > 0);
+ assert(dot(r1) < 0); // normal vector must be on the same side
+ const auto &n = *this; // normal vector
+ const auto m_cos1 = std::max(dot(r1), (T) (-1)); // cos(a1), a1 is the incoming angle
+// assert(m_cos1 <= 0); // incoming angle must smaller than 90deg
+ const auto c = ri_inv; // c = nx * r`x + ny * r`y
+ auto d = 1 - c * c * (1 - m_cos1 * m_cos1);
+ if (d < 0) {
+ // TODO test TIR
+ if (Enable_TIR) {
+ // ri_inv < sin(a1), cannot refract, must reflect (Total Internal Reflection)
+ return reflect(r1);
+ } else {
+ d = -d; // abs, just make the sqrt has a real solution
+ }
+ }
+ const auto n2 = (r1 - dot(r1) * n) * c - sqrt(d) * n;
+ assert(n2.valid());
+ return n2;
+ }
};
// print to ostream
@@ -115,33 +151,42 @@ inline std::ostream &operator<<(std::ostream &out, const vec3<T> &vec) {
return out << "vec3[x=" << vec.x << ", y=" << vec.y << ", z=" << vec.z << ']';
}
-// product vec3 by a scalar
+// product vec3 by a scalar, with fp assertions
template<
typename T,
typename S,
typename = typename std::enable_if<std::is_arithmetic<S>::value, S>::type
>
inline vec3<T> operator*(const vec3<T> &vec, const S &b) {
+ if (std::is_floating_point<S>::value) {
+ assert(std::isfinite(b));;
+ }
return vec3<T>{.x=(T) (vec.x * b), .y=(T) (vec.y * b), .z=(T) (vec.z * b)};
}
-// product vec3 by a scalar
+// product vec3 by a scalar, with fp assertions
template<
typename T,
- typename S,
- typename = typename std::enable_if<std::is_arithmetic<S>::value, S>::type
+ typename S
>
inline vec3<T> operator*(const S &b, const vec3<T> &vec) {
+ if (std::is_floating_point<S>::value) {
+ assert(std::isfinite(b));
+ }
return vec3<T>{.x=(T) (vec.x * b), .y=(T) (vec.y * b), .z=(T) (vec.z * b)};
}
-// product vec3 by the inversion of a scalar (div by a scalar)
+// product vec3 by the inversion of a scalar (div by a scalar), with fp assertions
template<
typename T,
typename S,
typename = typename std::enable_if<std::is_arithmetic<S>::value, S>::type
>
inline vec3<T> operator/(const vec3<T> &vec, const S &b) {
+ if (std::is_floating_point<S>::value) {
+ assert(std::isfinite(b));
+ assert(b != 0);
+ }
return vec3<T>{.x=(T) (vec.x / b), .y=(T) (vec.y / b), .z=(T) (vec.z / b)};
}