diff --git a/src/Interpolation.hpp b/src/Interpolation.hpp index 956b478..1b50378 100644 --- a/src/Interpolation.hpp +++ b/src/Interpolation.hpp @@ -44,11 +44,20 @@ inline std::ostream &operator<<(std::ostream &os, const std::vector &v) { return os; } +template class MultiDimVector { + It it; + public: // put the first dimension into an outer vector for processing reasons std::vector> data; - MultiDimVector(size_t n) : data(n){}; + + MultiDimVector(It it) : it(it), data(it.firstIndexBound()) { + auto sizes = it.numValuesPerFirstIndex(); + for (size_t i = 0; i < sizes.size(); ++i) { + data[i].resize(sizes[i]); + } + }; void swap(MultiDimVector &other) { data.swap(other.data); } @@ -57,23 +66,21 @@ class MultiDimVector { data[dim].clear(); } } + + It getJumpIterator() const { return it; } }; template -void multiply_lower_triangular_inplace(It it, std::vector> L, - MultiDimVector &v) { +void multiply_lower_triangular_inplace(std::vector> L, + MultiDimVector &v) { // the multiplication is based on a cyclic permutation of the indices: the last index of v becomes // the first index of w size_t d = L.size(); - auto n = it.indexBounds(); + It it = v.getJumpIterator(); for (int k = d - 1; k >= 0; --k) { - MultiDimVector w(n[k]); - std::vector indexes(n[k], 0); - auto sizes = it.numValuesPerFirstIndex(); - for (size_t idx = 0; idx < n[k]; ++idx) { - w.data[idx].resize(sizes[idx]); - } + MultiDimVector w(it); + std::vector indexes(w.data.size(), 0); auto &Lk = L[k]; it.reset(); @@ -117,20 +124,16 @@ void multiply_lower_triangular_inplace(It it, std::vector -void multiply_upper_triangular_inplace(It it, std::vector> U, - MultiDimVector &v) { +void multiply_upper_triangular_inplace(std::vector> U, + MultiDimVector &v) { // the multiplication is based on a cyclic permutation of the indices: the last index of v becomes // the first index of w size_t d = U.size(); - auto n = it.indexBounds(); + It it = v.getJumpIterator(); for (int k = d - 1; k >= 0; --k) { - MultiDimVector w(n[k]); - std::vector indexes(n[k], 0); - auto sizes = it.numValuesPerFirstIndex(); - for (size_t idx = 0; idx < n[k]; ++idx) { - w.data[idx].resize(sizes[idx]); - } + MultiDimVector w(it); + std::vector indexes(w.data.size(), 0); auto &Uk = U[k]; it.reset(); @@ -264,17 +267,17 @@ class SparseTPOperator { } } - MultiDimVector apply(MultiDimVector input) { + MultiDimVector apply(MultiDimVector input) { prepareApply(); - multiply_upper_triangular_inplace(it, U, input); - multiply_lower_triangular_inplace(it, L, input); + multiply_upper_triangular_inplace(U, input); + multiply_lower_triangular_inplace(L, input); return input; } - MultiDimVector solve(MultiDimVector rhs) { + MultiDimVector solve(MultiDimVector rhs) { prepareSolve(); - multiply_lower_triangular_inplace(it, Linv, rhs); - multiply_upper_triangular_inplace(it, Uinv, rhs); + multiply_lower_triangular_inplace(Linv, rhs); + multiply_upper_triangular_inplace(Uinv, rhs); return rhs; } }; @@ -306,10 +309,11 @@ SparseTPOperator createInterpolationOperator(It it, Phi phi, X x) { } template -MultiDimVector evaluateFunction(It it, Func f, X x) { +MultiDimVector evaluateFunction(It it, Func f, X x) { size_t d = it.dim(); auto n = it.indexBounds(); - MultiDimVector v(n[0]); + MultiDimVector v(it); + std::vector indexes(v.data.size(), 0); it.reset(); std::vector point(d); @@ -324,7 +328,8 @@ MultiDimVector evaluateFunction(It it, Func f, X x) { double function_value = f(point); - v.data[it.firstIndex()].push_back(function_value); + size_t first_index = it.firstIndex(); + v.data[first_index][indexes[first_index]++] = function_value; } it.next(); @@ -334,7 +339,7 @@ MultiDimVector evaluateFunction(It it, Func f, X x) { } template -MultiDimVector interpolate(Func f, It it, Phi phi, X x) { +MultiDimVector interpolate(Func f, It it, Phi phi, X x) { auto rhs = evaluateFunction(it, f, x); auto op = createInterpolationOperator(it, phi, x); return op.solve(rhs); diff --git a/test/common.hpp b/test/common.hpp index a59781b..5adffa4 100644 --- a/test/common.hpp +++ b/test/common.hpp @@ -70,7 +70,8 @@ inline std::ostream &operator<<(std::ostream &os, const std::vector &v) { return os; } -inline std::ostream &operator<<(std::ostream &os, fsi::MultiDimVector const &v) { +template +std::ostream &operator<<(std::ostream &os, fsi::MultiDimVector const &v) { for (size_t i = 0; i < v.data.size(); ++i) { std::cout << v.data[i] << "\n\n"; } diff --git a/test/main.cpp b/test/main.cpp index 5e4b4ae..3458dec 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -26,10 +26,10 @@ limitations under the License. */ void runFunctions() { - constexpr size_t d = 30; - size_t bound = 8; - fsi::TemplateBoundedSumIterator it(bound); - // fsi::BoundedSumIterator it(d, bound); + constexpr size_t d = 8; + size_t bound = 24; + // fsi::TemplateBoundedSumIterator it(bound); + fsi::BoundedSumIterator it(d, bound); std::vector phi(d); std::vector x(d);