aluf/src/trig_interpolation.cpp

453 lines
19 KiB
C++

// src/trig_interpolation.cpp
#include "trig_interpolation.h"
#include "complex_block.h" // For the Hilbert Transform
#include <fftw3.h>
#include <algorithm>
#include <cmath>
#include <complex>
#include <functional> // For std::function
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
// Forward declaration: create_match_curve is defined further down in this file
// but is called from TrigInterpolation::process_and_generate_fir, which appears
// before the definition.
//   smooth_source - curve whose local extrema are used as anchor points
//   smooth_target - curve the anchor values are moved towards
//   strength      - blend factor applied to each anchor (0 keeps the source value,
//                   1 takes the target value)
// Returns a curve with the same length as smooth_source.
static std::vector<double> create_match_curve(
const std::vector<double>& smooth_source,
const std::vector<double>& smooth_target,
double strength
);
// --- FIX: Add new static helper function to correctly process magnitude spectrums ---
// Converts a half magnitude spectrum (N/2 + 1 bins) into a complex cepstrum-domain
// signal of length N.
//
//   mag_spectrum      - magnitudes for bins 0 .. N/2; values are floored at 1e-20
//                       before the logarithm is taken
//   hilbert_processor - BlockHilbert instance used for the final Hilbert step
//   ifft_func         - callable performing an inverse FFT from its first argument
//                       into its (pre-sized) second argument
//
// Returns the first element of the pair produced by hilbertTransform(), or an
// empty vector when fewer than two bins are supplied.
static std::vector<std::complex<double>> get_cepstrum_from_magnitude_spectrum(
    const std::vector<double>& mag_spectrum,
    BlockHilbert& hilbert_processor,
    std::function<void(const std::vector<std::complex<double>>&, std::vector<std::complex<double>>&)> ifft_func
) {
    // A half-spectrum of N/2 + 1 bins needs at least two bins to describe N >= 2.
    const size_t bin_count = mag_spectrum.size();
    if (bin_count < 2) return {};
    const size_t full_len = 2 * (bin_count - 1);

    // Log-magnitude for bins 0 .. N/2 (imaginary part stays zero).
    std::vector<std::complex<double>> log_spectrum(full_len, std::complex<double>(0.0, 0.0));
    for (size_t k = 0; k < bin_count; ++k) {
        const double floored = std::max(1e-20, mag_spectrum[k]);
        log_spectrum[k] = std::complex<double>(std::log(floored), 0.0);
    }

    // Mirror bins 1 .. N/2 - 1 into the upper half so the inverse transform is real.
    for (size_t k = 1; k + 1 < bin_count; ++k) {
        log_spectrum[full_len - k] = log_spectrum[k];
    }

    // Inverse transform of the log spectrum gives the real cepstrum.
    std::vector<std::complex<double>> real_cepstrum(full_len);
    ifft_func(log_spectrum, real_cepstrum);

    // Feed the real part to the Hilbert processor (same data for both inputs)
    // and keep the first result.
    // NOTE(review): presumably hilbertTransform returns one analytic signal per
    // input channel - confirm against complex_block.h.
    std::vector<double> cepstrum_real_part(full_len);
    std::transform(real_cepstrum.begin(), real_cepstrum.end(), cepstrum_real_part.begin(),
                   [](const std::complex<double>& c) { return c.real(); });
    auto analytic = hilbert_processor.hilbertTransform(cepstrum_real_part, cepstrum_real_part);
    return analytic.first;
}
// --- Main Public Method ---
auto TrigInterpolation::process_and_generate_fir(
const std::vector<double>& source_L_in, const std::vector<double>& source_R_in,
const std::vector<double>& target_L_in, const std::vector<double>& target_R_in,
int granularity, int detail_level, double strength, int fir_size, double lr_link,
PhaseMode phase_mode
) -> std::tuple<
std::vector<float>, std::vector<float>, // FIR Taps (L/R)
std::vector<double>, // Frequency Axis
std::vector<double>, std::vector<double>, std::vector<double>, // L(src,tgt,match)
std::vector<double>, std::vector<double>, std::vector<double> // R(src,tgt,match)
> {
// --- FIX: Call the new helper function instead of get_idealized_cepstrum ---
// The lambda captures `this` to allow calling the member function `ifft`.
auto ifft_lambda = [this](const auto& in, auto& out) { this->ifft(in, out); };
auto cepstrum_src_L = get_cepstrum_from_magnitude_spectrum(source_L_in, m_block_hilbert, ifft_lambda);
auto cepstrum_src_R = get_cepstrum_from_magnitude_spectrum(source_R_in, m_block_hilbert, ifft_lambda);
auto cepstrum_tgt_L = get_cepstrum_from_magnitude_spectrum(target_L_in, m_block_hilbert, ifft_lambda);
auto cepstrum_tgt_R = get_cepstrum_from_magnitude_spectrum(target_R_in, m_block_hilbert, ifft_lambda);
// 2. Independently Smooth Each Curve
auto smooth_src_L = idealize_curve(cepstrum_src_L, granularity, detail_level);
auto smooth_src_R = idealize_curve(cepstrum_src_R, granularity, detail_level);
auto smooth_tgt_L = idealize_curve(cepstrum_tgt_L, granularity, detail_level);
auto smooth_tgt_R = idealize_curve(cepstrum_tgt_R, granularity, detail_level);
// 3. Perform the second trigonometric interpolation to get the match curves
auto pre_link_match_L = create_match_curve(smooth_src_L, smooth_tgt_L, strength);
auto pre_link_match_R = create_match_curve(smooth_src_R, smooth_tgt_R, strength);
std::vector<double> match_curve_L(pre_link_match_L.size());
std::vector<double> match_curve_R(pre_link_match_R.size());
// 4. Apply Stereo Link to the final match curves
for (size_t i = 0; i < pre_link_match_L.size(); ++i) {
double val_L = pre_link_match_L[i];
double val_R = pre_link_match_R[i];
match_curve_L[i] = val_L * (1.0 - lr_link) + (val_L + val_R) * 0.5 * lr_link;
match_curve_R[i] = val_R * (1.0 - lr_link) + (val_L + val_R) * 0.5 * lr_link;
}
// 5. Generate FIR Filter from Match Curve
auto fir_coeffs_L = create_fir_from_curve(match_curve_L, fir_size, phase_mode);
auto fir_coeffs_R = create_fir_from_curve(match_curve_R, fir_size, phase_mode);
// 6. Prepare Plot Data
std::vector<double> frequency_axis, source_L_db, target_L_db, matched_L_db,
source_R_db, target_R_db, matched_R_db;
// FIX: Cast size_t to int for warning suppression if needed, though here it's just logic
size_t num_plot_points = match_curve_L.size() > 1 ? match_curve_L.size() / 2 : 0;
if (num_plot_points == 0) {
// Return empty tuple if no data
return { {}, {}, {}, {}, {}, {}, {}, {}, {} };
}
// The frequency axis should correspond to the original half-spectrum size
size_t freq_axis_size = source_L_in.size();
frequency_axis.resize(freq_axis_size);
std::iota(frequency_axis.begin(), frequency_axis.end(), 0.0);
auto to_db = [](double val) { return 20.0 * log10(std::max(1e-9, val)); };
source_L_db.resize(freq_axis_size);
target_L_db.resize(freq_axis_size);
matched_L_db.resize(freq_axis_size);
source_R_db.resize(freq_axis_size);
target_R_db.resize(freq_axis_size);
matched_R_db.resize(freq_axis_size);
for(size_t i = 0; i < freq_axis_size; ++i) {
source_L_db[i] = to_db(smooth_src_L[i]);
target_L_db[i] = to_db(smooth_tgt_L[i]);
matched_L_db[i] = to_db(match_curve_L[i]);
source_R_db[i] = to_db(smooth_src_R[i]);
target_R_db[i] = to_db(smooth_tgt_R[i]);
matched_R_db[i] = to_db(match_curve_R[i]);
}
return {
fir_coeffs_L, fir_coeffs_R,
frequency_axis,
source_L_db, target_L_db, matched_L_db,
source_R_db, target_R_db, matched_R_db
};
}
// --- New Function for Match Curve Interpolation ---
// Builds a "match" curve by moving the local extrema of smooth_source towards
// smooth_target and re-interpolating between them.
//
//   smooth_source - curve providing the anchor positions (strict local maxima/minima)
//   smooth_target - curve providing the values the anchors are moved towards;
//                   anchors beyond its end keep their source value
//   strength      - blend factor: 0 keeps the source value, 1 takes the target value
//
// Returns a curve of smooth_source.size() samples. An empty source yields an
// empty result; a source without interior extrema is returned unchanged.
static std::vector<double> create_match_curve(
    const std::vector<double>& smooth_source,
    const std::vector<double>& smooth_target,
    double strength) {
    constexpr double kPi = 3.14159265358979323846;

    const size_t n = smooth_source.size();
    if (n == 0) return {};

    // 1 + 2. Find strict local extrema of the source and blend each one towards
    //        the target value at the same index.
    std::vector<std::pair<size_t, double>> match_extrema;
    for (size_t i = 1; i + 1 < n; ++i) {
        const double prev = smooth_source[i - 1];
        const double cur = smooth_source[i];
        const double next = smooth_source[i + 1];
        const bool is_peak = cur > prev && cur > next;
        const bool is_valley = cur < prev && cur < next;
        if (!is_peak && !is_valley) continue;

        // A target shorter than the source has no value here: leave the anchor as is
        // rather than reading past the end of smooth_target.
        const double target_val = (i < smooth_target.size()) ? smooth_target[i] : cur;
        match_extrema.emplace_back(i, cur + (target_val - cur) * strength);
    }
    if (match_extrema.empty()) return smooth_source; // No features found.

    // Lobe half-width per anchor: half the span between its neighbours (curve
    // ends stand in for missing neighbours). Anchors are strictly increasing and
    // n > 2 here, so every width is positive. Computed once instead of per sample.
    const size_t anchor_count = match_extrema.size();
    std::vector<double> lobe_widths(anchor_count);
    for (size_t j = 0; j < anchor_count; ++j) {
        const size_t prev_idx = (j > 0) ? match_extrema[j - 1].first : 0;
        const size_t next_idx = (j + 1 < anchor_count) ? match_extrema[j + 1].first : n - 1;
        lobe_widths[j] = static_cast<double>(next_idx - prev_idx) / 2.0;
    }

    // 3. Reconstruct the curve as a raised-cosine weighted average of the anchors.
    std::vector<double> final_match_curve(n);
    for (size_t i = 0; i < n; ++i) {
        double total_weight = 0.0;
        double weighted_sum = 0.0;
        for (size_t j = 0; j < anchor_count; ++j) {
            const double dist = static_cast<double>(i) - static_cast<double>(match_extrema[j].first);
            if (std::abs(dist) < lobe_widths[j]) {
                const double weight = (std::cos(dist / lobe_widths[j] * kPi) + 1.0) / 2.0; // Hann lobe
                weighted_sum += match_extrema[j].second * weight;
                total_weight += weight;
            }
        }
        if (total_weight > 0) {
            final_match_curve[i] = weighted_sum / total_weight;
            continue;
        }

        // Sample outside every lobe: hold the nearest anchor at the edges,
        // otherwise interpolate linearly between the surrounding anchors.
        // The search compares indices only; anchor values play no role in ordering.
        const auto it = std::lower_bound(
            match_extrema.begin(), match_extrema.end(), i,
            [](const std::pair<size_t, double>& anchor, size_t idx) { return anchor.first < idx; });
        if (it == match_extrema.begin()) {
            final_match_curve[i] = it->second;
        } else if (it == match_extrema.end()) {
            final_match_curve[i] = (it - 1)->second;
        } else {
            const auto before = it - 1;
            const double progress =
                static_cast<double>(i - before->first) / static_cast<double>(it->first - before->first);
            final_match_curve[i] = before->second + (it->second - before->second) * progress;
        }
    }
    return final_match_curve;
}
// --- Private Helper Implementations ---
// Turns a magnitude curve into a windowed, gain-normalised FIR filter.
//
//   match_curve - magnitude values used as the filter's spectrum (one per FFT bin)
//   fir_size    - number of taps to produce
//   phase_mode  - PhaseMode::Hilbert reconstructs a phase from the folded cepstrum;
//                 any other mode uses the magnitudes as a real (zero-phase) spectrum
//
// Returns fir_size taps, Hann-windowed and scaled so they sum to 1 (when the sum
// is not ~0). Returns an empty vector for an empty curve or fir_size <= 0.
std::vector<float> TrigInterpolation::create_fir_from_curve(const std::vector<double>& match_curve, int fir_size, PhaseMode phase_mode) {
    if (match_curve.empty() || fir_size <= 0) return {};
    const size_t n = match_curve.size();

    // Convert magnitude to a complex spectrum (reconstructing phase).
    std::vector<std::complex<double>> spectrum(n, {0.0, 0.0});
    if (phase_mode == PhaseMode::Hilbert) {
        // Minimum phase via the cepstrum: log -> IFFT -> fold to causal part -> FFT -> exp.
        std::vector<std::complex<double>> log_mag(n);
        for (size_t i = 0; i < n; ++i) {
            log_mag[i] = {std::log(std::max(1e-20, match_curve[i])), 0.0};
        }
        std::vector<std::complex<double>> cepstrum(n);
        ifft(log_mag, cepstrum);

        // Fold: keep index 0, double indices 1 .. n/2 - 1, zero the rest.
        // NOTE(review): index n/2 is zeroed here; the textbook folding keeps it
        // unscaled for even n - confirm whether dropping it is intentional.
        for (size_t i = n / 2; i < n; ++i) cepstrum[i] = {0.0, 0.0};
        for (size_t i = 1; i < n / 2; ++i) cepstrum[i] *= 2.0;

        std::vector<std::complex<double>> min_phase_spectrum(n);
        fft(cepstrum, min_phase_spectrum);
        for (size_t i = 0; i < n; ++i) spectrum[i] = std::exp(min_phase_spectrum[i]);
    } else {
        // Real spectrum: the impulse response is centred on index 0 and becomes
        // linear phase once shifted to the middle of the tap array below.
        for (size_t i = 0; i < n; ++i) spectrum[i] = {match_curve[i], 0.0};
    }

    // IFFT to get the (circular) impulse response.
    std::vector<std::complex<double>> impulse_response(n);
    ifft(spectrum, impulse_response);

    // Window and extract taps so that impulse index 0 lands on tap `center`.
    std::vector<float> fir_taps(fir_size);
    const int center = fir_size / 2;
    const long long period = static_cast<long long>(n);
    for (int i = 0; i < fir_size; ++i) {
        // A one-tap Hann window would divide by zero; use unity gain instead.
        const double hann_window = (fir_size > 1)
            ? 0.5 * (1.0 - std::cos(2.0 * M_PI * i / (fir_size - 1)))
            : 1.0;

        // Wrap into [0, n). The remainder is negative for negative offsets, and a
        // single "+ n" is not enough once fir_size / 2 exceeds n.
        long long impulse_idx = (static_cast<long long>(i) - center) % period;
        if (impulse_idx < 0) impulse_idx += period;

        fir_taps[i] = static_cast<float>(impulse_response[static_cast<size_t>(impulse_idx)].real() * hann_window);
    }

    // Normalise to unity tap sum (DC gain) to prevent a volume drop.
    const double tap_sum = std::accumulate(fir_taps.begin(), fir_taps.end(), 0.0);
    if (std::abs(tap_sum) > 1e-9) {
        for (float& tap : fir_taps) {
            tap = static_cast<float>(tap / tap_sum);
        }
    }
    return fir_taps;
}
// Computes a complex-cepstrum based representation of a real signal:
// Hilbert transform -> FFT -> complex log with unwrapped phase -> IFFT ->
// Hilbert transform of the real part.
//
//   signal - real input samples
//
// Returns the first element of the pair produced by the final hilbertTransform()
// call, or an empty vector for empty input.
std::vector<std::complex<double>> TrigInterpolation::get_idealized_cepstrum(const std::vector<double>& signal) {
    if (signal.empty()) return {};
    const size_t n = signal.size();

    // 1. Hilbert transform to obtain a complex input. The same signal is passed
    //    for both inputs because a single channel is processed here.
    // NOTE(review): presumably hilbertTransform returns one analytic signal per
    // input channel - confirm against complex_block.h.
    BlockHilbert hilbert;
    auto analytic_pair = hilbert.hilbertTransform(signal, signal);
    const std::vector<std::complex<double>> analytic_signal = analytic_pair.first;

    // 2. FFT.
    std::vector<std::complex<double>> spectrum(n);
    fft(analytic_signal, spectrum);

    // 3. Log-magnitude. Bins at or below 1e-9 get log-magnitude 0 and are zeroed
    //    so that their (meaningless) phase reads as 0 in the next step.
    std::vector<double> log_magnitude(n, 0.0);
    for (size_t i = 0; i < n; ++i) {
        const double magnitude = std::abs(spectrum[i]);
        if (magnitude > 1e-9) {
            log_magnitude[i] = std::log(magnitude);
        } else {
            spectrum[i] = {0.0, 0.0};
        }
    }

    // 4. Unwrap the phase of the spectrum itself. This must happen before the
    //    bins are replaced by their logarithm: unwrap_phase() takes std::arg of
    //    its input, and arg(log X) is not the phase of X.
    const std::vector<double> unwrapped = unwrap_phase(spectrum);
    for (size_t i = 0; i < n; ++i) {
        spectrum[i] = {log_magnitude[i], unwrapped[i]};
    }

    // 5. IFFT of the complex log spectrum gives the complex cepstrum.
    std::vector<std::complex<double>> cepstrum(n);
    ifft(spectrum, cepstrum);

    // 6. Hilbert transform substitution on the real part.
    std::vector<double> real_part(n);
    for (size_t i = 0; i < n; ++i) real_part[i] = cepstrum[i].real();
    auto substituted_pair = hilbert.hilbertTransform(real_part, real_part);
    return substituted_pair.first; // This is the final analytic cepstrum
}
// --- Full Implementation of the Idealize Curve Logic ---
// Smooths the magnitude of a complex curve by repeatedly rebuilding it from its
// local extrema inside logarithmically spaced bins.
//
//   analytic_cepstrum - complex input; only std::abs of each sample is used
//   granularity       - mapped to the number of bins: 4 + granularity * 0.6 (at least 1)
//   detail_level      - mapped to the number of passes: 1 + detail_level * 0.04
//                       NOTE(review): both look like 0..100 slider values - confirm with caller.
//
// Returns a curve with one value per input sample (empty for empty input).
//
// Samples not reached by any extremum are filled afterwards: linearly between
// rebuilt samples, and by holding the nearest rebuilt value before the first and
// after the last one. A pass that rebuilds nothing leaves the curve unchanged
// rather than zeroing it.
std::vector<double> TrigInterpolation::idealize_curve(const std::vector<std::complex<double>>& analytic_cepstrum, int granularity, int detail_level) {
    if (analytic_cepstrum.empty()) return {};
    const size_t n = analytic_cepstrum.size();

    std::vector<double> current_curve(n);
    std::transform(analytic_cepstrum.begin(), analytic_cepstrum.end(), current_curve.begin(),
                   [](const std::complex<double>& c) { return std::abs(c); });

    // Map slider values to algorithm parameters. A non-positive bin count would
    // divide by zero below, so at least one bin is always used.
    const int num_bins = std::max(1, 4 + static_cast<int>(granularity / 100.0 * 60.0));
    const int iterations = 1 + static_cast<int>(detail_level / 100.0 * 4.0);

    // 1. Binning: boundaries at n^(i / num_bins), clamped to [1, n - 1], with the
    //    first boundary forced to 0. They depend only on n and num_bins, so they
    //    are computed once for all passes.
    std::vector<size_t> bin_boundaries;
    bin_boundaries.reserve(static_cast<size_t>(num_bins) + 1);
    const double log_n = std::log(static_cast<double>(n));
    for (int i = 0; i <= num_bins; ++i) {
        const double ratio = static_cast<double>(i) / num_bins;
        const size_t raw_boundary = static_cast<size_t>(std::exp(ratio * log_n));
        bin_boundaries.push_back(std::max<size_t>(1, std::min(n - 1, raw_boundary)));
    }
    bin_boundaries[0] = 0;

    for (int iter = 0; iter < iterations; ++iter) {
        std::vector<double> next_curve(n, 0.0);
        std::vector<bool> is_point_set(n, false);

        // 2. Find extrema and rebuild the curve inside each bin.
        for (int b = 0; b < num_bins; ++b) {
            const size_t bin_start = bin_boundaries[b];
            const size_t bin_end = bin_boundaries[b + 1];
            if (bin_start + 1 >= bin_end) continue; // No interior sample to test.

            std::vector<std::pair<size_t, double>> extrema;
            for (size_t j = bin_start + 1; j < bin_end; ++j) {
                const double prev = current_curve[j - 1];
                const double cur = current_curve[j];
                const double next = current_curve[j + 1];
                if ((cur > prev && cur > next) || (cur < prev && cur < next)) {
                    extrema.push_back({j, cur});
                }
            }
            if (extrema.empty()) continue;

            // 3. Weaving/smoothing: raised-cosine weighted average of the bin's
            //    extrema. Every lobe uses half the bin width (an approximation).
            const double lobe_width = (bin_end - bin_start) / 2.0;
            for (size_t j = bin_start; j <= bin_end; ++j) {
                double total_weight = 0.0;
                double weighted_sum = 0.0;
                for (const auto& extremum : extrema) {
                    const double dist = static_cast<double>(j) - static_cast<double>(extremum.first);
                    if (std::abs(dist) < lobe_width) {
                        const double weight = (std::cos(dist / lobe_width * M_PI) + 1.0) / 2.0; // Hann lobe
                        weighted_sum += extremum.second * weight;
                        total_weight += weight;
                    }
                }
                if (total_weight > 0) {
                    next_curve[j] = weighted_sum / total_weight;
                    is_point_set[j] = true;
                }
            }
        }

        // 4. Gap filling.
        size_t first_set = n;
        for (size_t i = 0; i < n; ++i) {
            if (is_point_set[i]) { first_set = i; break; }
        }
        if (first_set == n) {
            // Nothing was rebuilt: the curve has no usable extrema. Keep it as it
            // is; further passes would see the same curve and do the same.
            break;
        }

        // Leading samples hold the first rebuilt value.
        for (size_t j = 0; j < first_set; ++j) next_curve[j] = next_curve[first_set];

        // Interior gaps are bridged by linear interpolation.
        size_t last_set_point = first_set;
        for (size_t i = first_set + 1; i < n; ++i) {
            if (!is_point_set[i]) continue;
            if (i > last_set_point + 1) {
                const double start_val = next_curve[last_set_point];
                const double end_val = next_curve[i];
                for (size_t j = last_set_point + 1; j < i; ++j) {
                    const double progress = static_cast<double>(j - last_set_point) / (i - last_set_point);
                    next_curve[j] = start_val + (end_val - start_val) * progress;
                }
            }
            last_set_point = i;
        }

        // Trailing samples hold the last rebuilt value.
        for (size_t j = last_set_point + 1; j < n; ++j) next_curve[j] = next_curve[last_set_point];

        current_curve.swap(next_curve);
    }
    return current_curve;
}
// Returns the phase (std::arg) of every element with 2*pi discontinuities removed.
//
//   c_vec - complex samples
//
// The result has one entry per input element. Whenever two neighbouring wrapped
// phases differ by more than pi, a multiple of 2*pi is added to all following
// samples so the sequence becomes continuous.
std::vector<double> TrigInterpolation::unwrap_phase(const std::vector<std::complex<double>>& c_vec) {
    std::vector<double> phases(c_vec.size());
    double phase_correction = 0.0;
    double previous_wrapped = 0.0;
    for (size_t i = 0; i < c_vec.size(); ++i) {
        const double wrapped = std::arg(c_vec[i]);
        if (i > 0) {
            // Compare wrapped against wrapped. Testing against the already
            // corrected previous value would see the accumulated offset as a
            // fresh jump and add another 2*pi on every following sample.
            const double diff = wrapped - previous_wrapped;
            if (diff > M_PI) {
                phase_correction -= 2.0 * M_PI;
            } else if (diff < -M_PI) {
                phase_correction += 2.0 * M_PI;
            }
        }
        phases[i] = wrapped + phase_correction;
        previous_wrapped = wrapped;
    }
    return phases;
}
// --- FFTW Helpers (similar to complex_block.cpp but for complex-to-complex) ---
void TrigInterpolation::fft(const std::vector<std::complex<double>>& input, std::vector<std::complex<double>>& output) {
size_t n = input.size();
// FIX: Cast size_t to int for FFTW
int n_int = static_cast<int>(n);
fftw_complex* in = (fftw_complex*)fftw_malloc(sizeof(fftw_complex) * n);
fftw_complex* out = (fftw_complex*)fftw_malloc(sizeof(fftw_complex) * n);
for(size_t i=0; i<n; ++i) { in[i][0] = input[i].real(); in[i][1] = input[i].imag(); }
fftw_plan p = fftw_plan_dft_1d(n_int, in, out, FFTW_FORWARD, FFTW_ESTIMATE);
fftw_execute(p);
for(size_t i=0; i<n; ++i) { output[i] = {out[i][0], out[i][1]}; }
fftw_destroy_plan(p);
fftw_free(in); fftw_free(out);
}
void TrigInterpolation::ifft(const std::vector<std::complex<double>>& input, std::vector<std::complex<double>>& output) {
size_t n = input.size();
// FIX: Cast size_t to int for FFTW
int n_int = static_cast<int>(n);
fftw_complex* in = (fftw_complex*)fftw_malloc(sizeof(fftw_complex) * n);
fftw_complex* out = (fftw_complex*)fftw_malloc(sizeof(fftw_complex) * n);
for(size_t i=0; i<n; ++i) { in[i][0] = input[i].real(); in[i][1] = input[i].imag(); }
fftw_plan p = fftw_plan_dft_1d(n_int, in, out, FFTW_BACKWARD, FFTW_ESTIMATE);
fftw_execute(p);
for(size_t i=0; i<n; ++i) { output[i] = {out[i][0] / (double)n, out[i][1] / (double)n}; }
fftw_destroy_plan(p);
fftw_free(in); fftw_free(out);
}