Mapping.cpp: Domain Mapping and XAD-enabled Multiphysics Solver

mapping.cpp

A C++ MFEM-based module implementing a DomainMapper to map reference coordinates to a deformed physical domain, plus a suite of mapped coefficients, nonlinear integrators, and gravity/multiphysics solvers with an EOS via XAD autodiff. It handles domain decomposition (stellar/core/envelope/vacuum), quadrupole moments, gravitational potential, and includes a test harness and CLI-driven entry point.

//region Includes
#include <memory>
#include <mfem.hpp>
#include <print>
#include <format>
#include <string>
#include <functional>
#include <utility>
#include <vector>
 
#include <cmath>
#include <expected>
 
#include <CLI/CLI.hpp>
#include <XAD/XAD.hpp>
#include <chrono>
//endregion
 
// Back-of-the-envelope estimate of how a given error in the potential
// translates into an error in the luminosity.
// Let the relative target luminosity error be δL <= 0.01 (1%), which, from
// what I can find, seems to be below current observational uncertainties.
// The target error on the potential can then be approximated as δL/ν, where
// ν is the index on the most temperature-sensitive reaction in the network.
// Taking a very conservative approach and requiring Si burning to be handled
// at 0.01 luminosity error (from the potential) gives ν ~ 50, so
// δL/ν ~ 0.01/50 ~ 2e-4. Therefore, so long as the potential error is below
// this value it will not contribute more than ~1% error to the luminosity.
// This should be validated more rigorously at a later date, but it is a
// useful target to use for now.
constexpr double APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING = 2e-4;
 
 
// Token-pasting helpers. CONCAT goes through CONCAT_INNER so that macro
// arguments (e.g. __COUNTER__) are expanded before being pasted with ##.
#define CONCAT_INNER(a, b) a##b
#define CONCAT(a, b) CONCAT_INNER(a, b)

// ANSI escape sequences used to colour terminal output.
constexpr std::string_view ANSI_GREEN = "\033[32m";
constexpr std::string_view ANSI_RED = "\033[31m";
constexpr std::string_view ANSI_YELLOW = "\033[33m";
constexpr std::string_view ANSI_BLUE = "\033[34m";
constexpr std::string_view ANSI_MAGENTA = "\033[35m";
constexpr std::string_view ANSI_CYAN = "\033[36m";
constexpr std::string_view ANSI_RESET = "\033[0m";

// Produces the identifier <prefix><N>, where N is the current __COUNTER__ value.
#define MAKE_UNIQUE_VAR_NAME(prefix) CONCAT(prefix, __COUNTER__)

// Declares a local `start_<timer_name>` holding the current steady_clock time.
#define START_TIMER(timer_name) \
auto start_##timer_name = std::chrono::steady_clock::now()

// Prints `elapsed_<timer_name>` (declared by END_TIMER) labelled with the timer name.
#define REPORT_TIMER(timer_name) \
std::println("Timer '{}': {}", #timer_name, elapsed_##timer_name)

// Closes the timer opened by START_TIMER(timer_name): declares `end_<timer_name>` and
// `elapsed_<timer_name>` (a millisecond duration) and prints the result via REPORT_TIMER.
// Expands to several statements that declare locals, so it needs START_TIMER with the
// same name in scope and cannot be the unbraced body of an if/for.
#define END_TIMER(timer_name) \
auto end_##timer_name = std::chrono::steady_clock::now(); \
std::chrono::duration<double, std::milli> elapsed_##timer_name = end_##timer_name - start_##timer_name; \
REPORT_TIMER(timer_name)
 
 
// Runs `test_func_w_args` between two banners. MPI world rank 0 prints a start banner
// before the call and an end banner with the measured wall-clock runtime (milliseconds)
// after it; every rank executes the test expression itself.
// `id` is pasted onto the helper locals so that several expansions can share one scope.
// The expansion is a sequence of declarations and statements (not wrapped in a block),
// so it cannot be used as the unbraced body of an if/for.
#define RUN_TEST_IMPL(test_name, test_func_w_args, id) \
const int CONCAT(mpi_world_rank_, id) = mfem::Mpi::WorldRank(); \
auto CONCAT(start_var_, id) = std::chrono::steady_clock::now(); \
if (CONCAT(mpi_world_rank_, id) == 0) { \
std::println("{}===== TEST: {} ====={}", ANSI_MAGENTA, test_name, ANSI_RESET); \
} \
test_func_w_args; \
if (CONCAT(mpi_world_rank_, id) == 0) { \
auto CONCAT(end_var_, id) = std::chrono::steady_clock::now(); \
std::chrono::duration<double, std::milli> CONCAT(elapsed_var_, id) = \
CONCAT(end_var_, id) - CONCAT(start_var_, id); \
std::println("{}===== END TEST: {} ({}runtime: {:+0.2f}ms{}) ====={}", ANSI_MAGENTA, \
test_name, \
ANSI_YELLOW, \
CONCAT(elapsed_var_, id).count(), \
ANSI_MAGENTA, \
ANSI_RESET); \
}

// Executes `proc` only on MPI world rank 0.
#define RANK_GUARD(proc) \
    if (mfem::Mpi::WorldRank() == 0) { \
        proc \
    }

// Entry point for running a test: forwards to RUN_TEST_IMPL with a fresh __COUNTER__ id.
#define RUN_TEST(test_name, test_func_w_args) \
RUN_TEST_IMPL(test_name, test_func_w_args, __COUNTER__)
 
// Per-thread bookkeeping for LOG_PPOINT: the time of the previous log point and
// whether any log point has been hit yet on this thread.
struct LogPointState {
    static inline thread_local std::chrono::steady_clock::time_point last_time{};
    static inline thread_local bool is_active{false};
};


// Prints a progress marker on MPI world rank 0: the given label, the source line
// of the expansion and either "[Started]" (first log point on this thread) or the
// number of milliseconds since the previous log point.
// The label must be formattable by std::println.
// (An unused local that depended on a PLATFORM_STRINGIFY macro, which is not
// defined in this file, was removed; the printed output is unchanged.)
#define LOG_PPOINT(prefix) \
RANK_GUARD( \
[&]() { \
auto now = std::chrono::steady_clock::now(); \
if (LogPointState::is_active) { \
auto dur = std::chrono::duration_cast<std::chrono::milliseconds>(now - LogPointState::last_time).count(); \
std::println("Log Point: {} (Line: {}) [+{}ms]", prefix, __LINE__, dur); \
} else { \
std::println("Log Point: {} (Line: {}) [Started]", prefix, __LINE__); \
LogPointState::is_active = true; \
} \
LogPointState::last_time = now; \
}(); \
)

// Log point labelled with the name of the enclosing function.
#define LOG_POINT LOG_PPOINT(__FUNCTION__)
//region Test Utilities
 
// Outcome category of a test; fmt_test_msg uses it to choose the output colour
// (SUCCESS -> green, FAILURE -> red, PARTIAL -> yellow).
enum class TEST_RESULT_TYPE : uint8_t {
    SUCCESS,
    FAILURE,
    PARTIAL
};
 
/**
 * @brief Formats a coloured one-line test summary "[TEST: <name>] <passed>/<total>".
 *
 * @param test_name Name shown in the summary.
 * @param type      Result category; selects the colour (green / red / yellow).
 *                  Any other value falls back to ANSI_RESET.
 * @param num_fails Number of failed checks.
 * @param total     Total number of checks.
 * @return The formatted string, terminated with ANSI_RESET.
 *
 * The passed count is clamped at zero: with unsigned arithmetic
 * `total - num_fails` would otherwise wrap to a huge value when
 * num_fails > total.
 */
std::string fmt_test_msg(const std::string_view test_name, const TEST_RESULT_TYPE type, size_t num_fails, size_t total) {
    std::string_view color = ANSI_RESET;
    switch (type) {
        case TEST_RESULT_TYPE::SUCCESS:
            color = ANSI_GREEN;
            break;
        case TEST_RESULT_TYPE::FAILURE:
            color = ANSI_RED;
            break;
        case TEST_RESULT_TYPE::PARTIAL:
            color = ANSI_YELLOW;
            break;
        default:
            break;
    }
    const size_t num_passed = (num_fails > total) ? 0 : total - num_fails;
    return std::format("{}[TEST: {}] {}/{}{}", color, test_name, num_passed, total, ANSI_RESET);
}
//endregion
 
//region Constants
/********************
 * Constants
 *********************/
// All three are set to 1.0 — presumably a G = M = R = 1 unit system; confirm at the use sites.
constexpr double G = 1.0;
constexpr double MASS = 1.0;
constexpr double RADIUS = 1.0;

// Host/port pair, presumably for the visualisation socket used by view_mesh — TODO confirm.
// "localhost" is 9 characters plus the terminating NUL, hence char[10].
[[maybe_unused]] constexpr char HOST[10] = "localhost";
[[maybe_unused]] constexpr int PORT = 19916;
 
//endregion
 
//region Concepts and Typedefs
/********************
 * Concepts
 *********************/
// Satisfied exactly by xad::AReal<long double>, xad::AReal<double> and xad::AReal<float>.
template <typename T>
concept is_xad =
       std::is_same_v<T, xad::AReal<long double>>
    || std::is_same_v<T, xad::AReal<double>>
    || std::is_same_v<T, xad::AReal<float>>;

// Satisfied by any built-in floating-point type or any type accepted by is_xad.
template <typename T>
concept is_real = std::is_floating_point_v<T> || is_xad<T>;

/********************
 * Type Defs
 *********************/
// Callable taking (rho, temp) and returning a value of the same scalar type T;
// the alias name suggests it is the pressure of an equation of state.
template <is_real T>
using EOS_P = std::function<T(T rho, T temp)>;
//endregion
 
//region User Argument Structs
/********************
 * User Args
 *********************/
// Settings for the potential solve. Field meanings are inferred from their
// names only — confirm against the code that consumes them.
struct potential {
    double rtol;      // presumably a relative tolerance
    double atol;      // presumably an absolute tolerance
    int max_iters;    // presumably an iteration cap
};

// Rotation settings.
struct rot {
    bool enabled;     // whether rotation is switched on
    double omega;     // presumably the angular velocity — units not visible here
};

// Bundle of user-supplied run options. All members are value-initialised
// (quad_boost explicitly to 0) unless the caller overrides them.
struct Args {
    std::string mesh_file;   // path/name of the mesh file
    potential p{};           // potential-solve settings
    rot r{};                 // rotation settings
    bool verbose{};
    double index{};          // NOTE(review): looks like a polytropic index — confirm
    double mass{};
    double c{};              // NOTE(review): meaning not evident from this section

    int quad_boost{0};       // presumably extra quadrature order — TODO confirm

    int max_iters{};
    double tol{};
};
//endregion
 
//region Misc Structs
// Parameters of an oblate-spheroid potential; disabled by default (use == false).
// a, c and rho_0 default to 1 — presumably the two semi-axes and a reference
// density; confirm against oblate_spheroid_surface_potential.
struct OblatePotential {
    bool use{false};
    double a{1};
    double c{1};
    double rho_0{1};
};

// Pair of reference radii: the stellar-surface radius and the outer ("infinity") radius.
struct Bounds {
    double r_star_ref;
    double r_inf_ref;
};

// Error codes for bounds discovery (used as the error type of DiscoverBounds).
enum BoundsError : uint8_t {
    CANNOT_FIND_VACUUM
};
//endregion
 
//region Domain Enums
/// Bit flags identifying mesh regions. CORE, ENVELOPE and VACUUM are single
/// bits; STELLAR and ALL are the corresponding unions.
enum class Domains : uint8_t {
    CORE     = 0x01,
    ENVELOPE = 0x02,
    VACUUM   = 0x04,
    STELLAR  = CORE | ENVELOPE,
    ALL      = STELLAR | VACUUM
};

/// Union of two domain masks.
inline Domains operator|(Domains lhs, Domains rhs) {
    const auto lhs_bits = static_cast<uint8_t>(lhs);
    const auto rhs_bits = static_cast<uint8_t>(rhs);
    return static_cast<Domains>(lhs_bits | rhs_bits);
}

/// Intersection of two domain masks.
inline Domains operator&(Domains lhs, Domains rhs) {
    const auto lhs_bits = static_cast<uint8_t>(lhs);
    const auto rhs_bits = static_cast<uint8_t>(rhs);
    return static_cast<Domains>(lhs_bits & rhs_bits);
}
 
/// Boundary identifiers: the stellar surface (1) and the outer "infinity" surface (2).
enum class Boundaries : uint8_t {
    STELLAR_SURFACE = 1,
    INF_SURFACE = 2
};

/**
 * @brief Subtracts an integer offset from a boundary identifier.
 *
 * Typical use is `Boundaries::X - 1` to turn a 1-based identifier into a
 * 0-based index.
 *
 * The subtraction is carried out in `int`. The offset is no longer narrowed
 * to uint8_t first, which previously produced wrong results for negative
 * offsets or offsets above 255.
 *
 * @param b Boundary identifier.
 * @param a Offset to subtract.
 * @return  The numeric value of @p b minus @p a.
 */
inline int operator-(Boundaries b, const int a) {
    return static_cast<int>(b) - a;
}
 
//endregion
 
//region Domain Mapper
/********************
 * Mappers
 *********************/
/**
 * @brief Maps reference-mesh coordinates to the physical domain.
 *
 * The map is x_ref -> x_ref + d(x_ref), where d is an optional displacement
 * GridFunction (no displacement means the identity map). In elements flagged
 * as vacuum the displaced point is additionally scaled by
 *
 *     k(r) = r_star_ref / (r * (1 - xi)),
 *     xi   = (r - r_star_ref) / (r_inf_ref - r_star_ref),   r = |x_ref|,
 *
 * with xi clamped to [0, m_xi_clamp] (the "Kelvin" mapping).
 *
 * Notes:
 *  - The displacement is held through a non-owning pointer; it must outlive
 *    every call that uses it.
 *  - The element DOF cache is keyed on element number and type only and is
 *    invalidated by SetDisplacement()/ResetDisplacement(). If the values of
 *    the attached GridFunction are changed in place, call SetDisplacement()
 *    again to drop the cached element data.
 *  - const member functions write to mutable scratch buffers without any
 *    synchronisation, so one instance must not be used from several threads
 *    at once.
 */
class DomainMapper {
public:
    /**
     * @brief Creates an identity mapper (no displacement attached).
     * @param r_star_ref Reference radius where the Kelvin scaling starts (xi = 0).
     * @param r_inf_ref  Reference radius where xi reaches 1.
     * The spatial dimension keeps its default value of 3.
     */
    DomainMapper(
        const double r_star_ref,
        const double r_inf_ref
    ) :
    m_d(nullptr),
    m_r_star_ref(r_star_ref),
    m_r_inf_ref(r_inf_ref) {
        InitAllScratchSpaces();
    }

    /**
     * @brief Creates a mapper driven by the displacement field @p d.
     * @param d          Displacement field; only its address is stored.
     * @param r_star_ref Reference radius where the Kelvin scaling starts (xi = 0).
     * @param r_inf_ref  Reference radius where xi reaches 1.
     * The spatial dimension is taken from the mesh of @p d.
     */
    explicit DomainMapper(
        const mfem::GridFunction &d,
        const double r_star_ref,
        const double r_inf_ref
    ) :
    m_d(&d),
    m_dim(d.FESpace()->GetMesh()->Dimension()),
    m_r_star_ref(r_star_ref),
    m_r_inf_ref(r_inf_ref) {
        InitAllScratchSpaces();
    }


    /**
     * @brief Tells whether @p T belongs to the vacuum region.
     *
     * Volume elements are vacuum when their attribute equals m_vacuum_attr;
     * boundary elements when their attribute equals m_vacuum_attr - 1.
     * Any other transformation type is reported as non-vacuum.
     */
    [[nodiscard]] bool is_vacuum(const mfem::ElementTransformation &T) const {
        if (T.ElementType == mfem::ElementTransformation::ELEMENT) {
            return T.Attribute == m_vacuum_attr;
        } else if (T.ElementType == mfem::ElementTransformation::BDR_ELEMENT) {
            return T.Attribute == m_vacuum_attr - 1; // TODO: In a more robust code this should really be read from the stroid API to ensure that the vacuum boundary is really 1 - the vacuum material attribute
        }
        return false;
    }

    /**
     * @brief Attaches a displacement field and drops the element cache.
     * @throws std::invalid_argument if the mesh dimension of @p d differs from
     *         the dimension this mapper was built for.
     */
    void SetDisplacement(const mfem::GridFunction &d) {
        if (m_dim != d.FESpace()->GetMesh()->Dimension()) {
            const std::string err_msg = std::format("Dimension mismatch: DomainMapper is initialized for dimension {}, but provided displacement field has dimension {}.", m_dim, d.FESpace()->GetMesh()->Dimension());
            throw std::invalid_argument(err_msg);
        }
        m_d = &d;
        InvalidateCache();
    }

    /// True when no displacement field is attached.
    [[nodiscard]] bool IsIdentity() const {
        return (m_d == nullptr);
    }

    /// Detaches the displacement field (back to the identity map) and drops the element cache.
    void ResetDisplacement() {
        m_d = nullptr;
        InvalidateCache();
    }

    /**
     * @brief Computes the Jacobian of the map at the integration point currently set on @p T.
     *
     * Outside the vacuum region J = I + grad(d) (or I for the identity map).
     * Inside it the Kelvin scaling is composed on top, see ComputeKelvinJacobian().
     *
     * @param T Element transformation; its integration point must already be set.
     * @param J Output, resized to dim x dim.
     */
    void ComputeJacobian(mfem::ElementTransformation &T, mfem::DenseMatrix &J) const {
        J.SetSize(m_dim, m_dim);
        J = 0.0;
        m_J_D = 0.0;
        if (IsIdentity()) {
            for (int i = 0; i < m_dim; ++i) {
                m_J_D(i, i) = 1.0; // Identity mapping
            }
        } else {
            UpdateElementCache(T);
            m_dshape.SetSize(m_fe->GetDof(), m_dim);
            m_fe->CalcPhysDShape(T, m_dshape);
            // grad(d) = (element DOFs)^T * (physical shape gradients)
            mfem::MultAtB(m_dof_mat, m_dshape, m_J_D);

            for (int i = 0; i < m_dim; ++i) {
                m_J_D(i, i) += 1.0;
            }

        }

        if (is_vacuum(T)) {
            T.Transform(T.GetIntPoint(), m_x_ref);

            if (IsIdentity()) {
                m_x_disp = m_x_ref;
            } else {
                // Displaced point x_ref + d(x_ref); the element cache was refreshed above.
                m_shape.SetSize(m_fe->GetDof());
                m_fe->CalcShape(T.GetIntPoint(), m_shape);
                m_dof_mat.MultTranspose(m_shape, m_d_val);
                add(m_x_ref, m_d_val, m_x_disp);
            }

            ComputeKelvinJacobian(m_x_ref, m_x_disp, m_J_D, J);
        } else {
            J = m_J_D;
        }
    }

    /**
     * @brief Determinant of the map's Jacobian at @p ip.
     *
     * Returns 1 immediately (without touching @p T) for the identity map
     * outside the vacuum region; otherwise sets @p ip on @p T and evaluates
     * the Jacobian. Uses the internal scratch matrix instead of allocating a
     * new one on every call.
     */
    double ComputeDetJ(mfem::ElementTransformation& T, const mfem::IntegrationPoint& ip) const {
        if (IsIdentity() && !is_vacuum(T)) return 1.0; // If no mapping, the determinant of the Jacobian is 1
        T.SetIntPoint(&ip);
        ComputeJacobian(T, m_J_temp);
        return m_J_temp.Det();
    }

    /**
     * @brief Computes D = |det J| * J^{-1} * J^{-T} at the integration point set on @p T.
     * @param D Output, resized to dim x dim.
     */
    void ComputeMappedDiffusionTensor(mfem::ElementTransformation &T, mfem::DenseMatrix &D) const {
        ComputeJacobian(T, m_J_temp);
        const double detJ = m_J_temp.Det();
        mfem::CalcInverse(m_J_temp, m_JInv_temp);
        D.SetSize(m_dim, m_dim);
        mfem::MultABt(m_JInv_temp, m_JInv_temp, D);
        D *= fabs(detJ);
    }

    /**
     * @brief Computes J^{-1} at the integration point set on @p T.
     * @param JInv Output, resized to dim x dim.
     */
    void ComputeInverseJacobian(mfem::ElementTransformation &T, mfem::DenseMatrix &JInv) const {
        ComputeJacobian(T, m_J_temp);
        JInv.SetSize(m_dim, m_dim);
        mfem::CalcInverse(m_J_temp, JInv);
    }

    /**
     * @brief Computes the physical position of the point @p ip of element @p T.
     * @param x_phys Output, resized to dim: x_ref + d(x_ref), additionally
     *               Kelvin-scaled when the element is vacuum.
     */
    void GetPhysicalPoint(mfem::ElementTransformation& T, const mfem::IntegrationPoint& ip, mfem::Vector& x_phys) const {
        x_phys.SetSize(m_dim);
        T.Transform(ip, m_x_ref);

        if (IsIdentity()) {
            x_phys = m_x_ref;
        } else {
            UpdateElementCache(T);

            m_shape.SetSize(m_fe->GetDof());
            m_fe->CalcShape(ip, m_shape);

            m_dof_mat.MultTranspose(m_shape, m_d_val);
            add(m_x_ref, m_d_val, x_phys);
        }
        if (is_vacuum(T)) {
            ApplyKelvinMapping(m_x_ref, x_phys);
        }
    }

    /// The attached displacement field, or nullptr for the identity map.
    [[nodiscard]] const mfem::GridFunction* GetDisplacement() const { return m_d; }

    /**
     * @brief Returns 1 / (1 - m_xi_clamp).
     *
     * NOTE(review): ApplyKelvinMapping scales by r_star_ref / (r * (1 - xi)),
     * so this value looks like the physical outer radius only when
     * r_star_ref == 1 — confirm whether r_star_ref should be a factor here.
     */
    [[nodiscard]] double GetPhysInfRadius() const {
        return 1.0 / (1.0 - m_xi_clamp);
    }

    /// Number of element-cache hits since construction or the last ResetCacheStats().
    [[nodiscard]] size_t GetCacheHits() const {
        return m_cache_hits;
    }

    /// Number of element-cache misses since construction or the last ResetCacheStats().
    [[nodiscard]] size_t GetCacheMisses() const {
        return m_cache_misses;
    }

    /**
     * @brief Fraction of element-cache lookups that were hits, in [0, 1].
     * @return 0.0 when no lookup has been recorded yet (avoids a 0/0 NaN).
     */
    [[nodiscard]] double GetCacheHitRate() const {
        const size_t lookups = m_cache_hits + m_cache_misses;
        if (lookups == 0) {
            return 0.0;
        }
        return static_cast<double>(m_cache_hits) / static_cast<double>(lookups);
    }

    /// Resets the hit/miss counters to zero.
    void ResetCacheStats() const {
        m_cache_hits = 0;
        m_cache_misses = 0;
    }


private:
    /// Sizes every scratch matrix/vector for the current dimension.
    void InitAllScratchSpaces() const {
        m_J_D.SetSize(m_dim, m_dim);
        m_J_temp.SetSize(m_dim, m_dim);
        m_JInv_temp.SetSize(m_dim, m_dim);
        m_x_ref.SetSize(m_dim);
        m_x_disp.SetSize(m_dim);
        m_d_val.SetSize(m_dim);
    }

    /**
     * @brief Scales @p x_phys in place by k(r) = r_star_ref / (r * (1 - xi)), r = |x_ref|.
     * xi is clamped to [0, m_xi_clamp], which keeps the factor finite at r = r_inf_ref.
     */
    void ApplyKelvinMapping(const mfem::Vector& x_ref, mfem::Vector& x_phys) const {
        const double r_ref = x_ref.Norml2();
        double xi = (r_ref - m_r_star_ref) / (m_r_inf_ref - m_r_star_ref);
        xi = std::clamp(xi, 0.0, m_xi_clamp);
        const double factor = m_r_star_ref / (r_ref * (1 - xi));
        x_phys *= factor;
    }

    /**
     * @brief Jacobian of x_ref -> k(|x_ref|) * x_disp(x_ref).
     *
     * J(i,j) = (dk/dr / r) * x_disp(i) * x_ref(j) + k * J_D(i,j), where J_D is
     * the Jacobian of the displacement map and r = |x_ref|.
     * dk/dr is evaluated with the (clamped) xi but always includes the
     * d(xi)/dr term, also where the clamp is active.
     */
    void ComputeKelvinJacobian(const mfem::Vector& x_ref, const mfem::Vector &x_disp, const mfem::DenseMatrix &J_D, mfem::DenseMatrix& J) const {
        const double r_ref = x_ref.Norml2();
        const double delta_R = m_r_inf_ref - m_r_star_ref;

        double xi = (r_ref - m_r_star_ref) / delta_R;
        xi = std::clamp(xi, 0.0, m_xi_clamp);

        const double denom = 1.0 - xi;

        const double k = m_r_star_ref / (r_ref * denom);

        const double dk_dr = m_r_star_ref * (( 1.0 / (delta_R* r_ref * denom * denom)) - ( 1.0 / (r_ref * r_ref * denom)));

        J.SetSize(m_dim, m_dim);
        const double outer_factor = dk_dr / r_ref;

        for (int i = 0; i < m_dim; ++i) {
            for (int j = 0; j < m_dim; ++j) {
                J(i, j) = outer_factor * x_disp(i) * x_ref(j) + k * J_D(i, j);

            }
        }
    }

    /// Marks the element cache as empty so the next lookup reloads the element data.
    void InvalidateCache() const {
        m_cached_elem_id = -1;
    }

    /**
     * @brief Makes m_fe / m_elem_dofs / m_dof_mat refer to the element behind @p T.
     *
     * No-op for the identity map. On a miss (different element number or
     * type) the finite element and the displacement DOFs of that element are
     * fetched; m_dof_mat then views m_elem_dofs as an (ndof x vdim) matrix.
     */
    void UpdateElementCache(const mfem::ElementTransformation& T) const {
        if (IsIdentity()) return;

        if (T.ElementNo != m_cached_elem_id || T.ElementType != m_cached_elem_type) {
            m_cache_misses++;
            m_cached_elem_id = T.ElementNo;
            m_cached_elem_type = T.ElementType;

            const mfem::FiniteElementSpace *fes = m_d->FESpace();
            mfem::Array<int> vdofs;

            if (T.ElementType == mfem::ElementTransformation::ELEMENT) {
                m_fe = fes->GetFE(m_cached_elem_id);
                fes->GetElementVDofs(m_cached_elem_id, vdofs);
            } else {
                m_fe = fes->GetBE(m_cached_elem_id);
                fes->GetBdrElementVDofs(m_cached_elem_id, vdofs);
            }

            m_d->GetSubVector(vdofs, m_elem_dofs);

            const int nd = m_fe->GetDof();
            const int vd = fes->GetVDim();

            m_dof_mat.UseExternalData(m_elem_dofs.GetData(), nd, vd);
        } else {
            m_cache_hits++;
        }
    }

private:
    const mfem::GridFunction *m_d;                      // non-owning; nullptr means identity map
    std::unique_ptr<mfem::GridFunction> m_internal_d;   // not used by any member function of this class
    const int m_dim{3};
    const int m_vacuum_attr{3};
    const double m_r_star_ref{1.0};
    const double m_r_inf_ref{2.0};
    const double m_xi_clamp{0.99999999};                // upper bound for xi; keeps 1/(1 - xi) finite

    // Element cache (see UpdateElementCache).
    mutable int m_cached_elem_id{-1};
    mutable int m_cached_elem_type{mfem::ElementTransformation::ELEMENT};
    mutable const mfem::FiniteElement* m_fe{nullptr};

    mutable mfem::Vector m_elem_dofs;
    mutable mfem::DenseMatrix m_dof_mat;   // view onto m_elem_dofs
    mutable mfem::DenseMatrix m_dshape;
    mutable mfem::Vector m_shape;

    mutable size_t m_cache_hits{0};
    mutable size_t m_cache_misses{0};

    // Scratch buffers reused across calls.
    mutable mfem::DenseMatrix m_J_D;
    mutable mfem::DenseMatrix m_J_temp;
    mutable mfem::DenseMatrix m_JInv_temp;
    mutable mfem::Vector m_x_ref;
    mutable mfem::Vector m_x_disp;
    mutable mfem::Vector m_d_val;
};
 
//endregion
 
/********************
 * Cache Types
 *********************/
 
//region State Types
class MappedScalarCoefficient;
 
/********************
 * State Types
 *********************/
struct LORPrecWrapper : public mfem::Solver {
    mfem::Solver& m_amg;
    explicit LORPrecWrapper(mfem::Solver& amg) : mfem::Solver(amg.Height(), amg.Width()), m_amg(amg) {}
 
    void SetOperator(const Operator &op) override {};
 
    void Mult(const mfem::Vector &x, mfem::Vector &y) const override { m_amg.Mult(x, y); }
};
 
// Bundle of objects kept alive for the gravitational-potential solve.
// Roles below are inferred from names and types only; the code that fills and
// uses them is outside this section. Members are destroyed in reverse order of
// declaration, so the declaration order should not be changed casually.
struct GravityContext {
    std::unique_ptr<mfem::ParBilinearForm> ho_laplacian;    // presumably the high-order Laplacian form
    std::unique_ptr<mfem::ParBilinearForm> lor_laplacian;   // presumably its low-order-refined counterpart

    std::unique_ptr<mfem::ParGridFunction> phi;             // presumably the potential



    mfem::Array<int> ho_ess_tdof_list;    // essential true-DOF list for the high-order system (by name)
    mfem::Array<int> lor_ess_tdof_list;   // essential true-DOF list for the LOR system (by name)

    mfem::Array<int> stellar_mask;

    std::unique_ptr<mfem::MatrixCoefficient> diff_coeff;

    // Right-hand-side pieces.
    std::unique_ptr<mfem::ParLinearForm> b;
    std::unique_ptr<mfem::GridFunctionCoefficient> rho_coeff;
    std::unique_ptr<mfem::ConstantCoefficient> four_pi_G_coeff;
    std::unique_ptr<mfem::ProductCoefficient> rhs_coeff;
    std::unique_ptr<MappedScalarCoefficient> mapped_rhs_coeff;
    std::unique_ptr<mfem::ConstantCoefficient> unit_coeff;

    // Preconditioner and Krylov solver.
    std::unique_ptr<mfem::HypreBoomerAMG> amg_prec;
    std::unique_ptr<LORPrecWrapper> amg_wrapper;
    std::unique_ptr<mfem::GMRESSolver> solver;

    mfem::OperatorPtr A_ho;
    mfem::OperatorPtr A_lor;

    mfem::Vector B_true;
    mfem::Vector X_true;

};

// Boundary marker arrays for the outer ("infinity") and stellar surfaces (by name).
struct BoundaryContext {
    mfem::Array<int> inf_bounds;
    mfem::Array<int> stellar_bounds;
};
 
// Top-level container for the finite-element state: mesh, spaces, the domain
// mapping and the gravity/boundary contexts. Roles of individual members are
// inferred from names and types; they are populated outside this section
// (setup_fem is declared to return one).
struct FEM {
    std::unique_ptr<mfem::ParMesh> mesh;
    std::unique_ptr<mfem::FiniteElementCollection> H1_fec;
    std::unique_ptr<mfem::ParFiniteElementSpace> H1_fes;       // scalar H1 space
    std::unique_ptr<mfem::ParFiniteElementSpace> Vec_H1_fes;   // vector-valued H1 space
    std::unique_ptr<mfem::ParLORDiscretization> H1_lor_disc;
    const mfem::ParFiniteElementSpace* H1_lor_fes{nullptr};    // non-owning; presumably owned by H1_lor_disc
    std::unique_ptr<DomainMapper> mapping;

    mfem::Array<int> block_true_offsets;
    std::unique_ptr<mfem::ParGridFunction> reference_x;

    mfem::Vector com;        // presumably the centre of mass
    mfem::DenseMatrix Q;     // presumably the quadrupole moment tensor
    mfem::Array<int> ess_tdof_x;


    int int_order{3};
    std::unique_ptr<mfem::IntegrationRule> int_rule;   // dereferenced without a null check by FluidIntegrator

    GravityContext gravity_context;
    BoundaryContext boundary_context;


    // True when the mesh, the H1 collection and both H1 spaces have been created.
    [[nodiscard]] bool okay() const { return (mesh != nullptr) && (H1_fec != nullptr) && (H1_fes != nullptr) && (Vec_H1_fes != nullptr); }

    // True when a DomainMapper has been attached.
    [[nodiscard]] bool has_mapping() const { return mapping != nullptr; }
};
//endregion
 
//region Function Definitions
/********************
 * Core Setup Functions
 *********************/
// Forward declarations only; the definitions are not part of this section, so
// the notes below describe signatures, not verified behaviour.
FEM setup_fem(const std::string& filename, const Args &args);

/********************
 * Utility Functions
 *********************/
void view_mesh(const std::string& host, int port, const mfem::Mesh& mesh, const mfem::GridFunction& gf, const std::string& title);
// `domain` defaults to Domains::ALL.
double domain_integrate_grid_function(const FEM& fem, const mfem::GridFunction& gf, Domains domain = Domains::ALL);
mfem::Vector get_com(const FEM& fem, const mfem::GridFunction &rho);
void get_physical_coordinates(const mfem::GridFunction& reference_pos, const mfem::GridFunction& displacement, mfem::GridFunction& physical_pos);
void populate_element_mask(const FEM& fem, Domains domain, mfem::Array<int>& mask);
// Returns either the discovered Bounds or a BoundsError.
std::expected<Bounds, BoundsError> DiscoverBounds(const mfem::Mesh *mesh, int vacuum_attr);
double EvalGridFunctionAtPhysicalPoint(const FEM& fem, const mfem::GridFunction& u, const mfem::Vector& x_phys);

int get_mesh_order(const mfem::Mesh &mesh);
void conserve_mass(const FEM& fem, mfem::GridFunction& rho, double target_mass);

/********************
 * Physics Functions
 *********************/
double centrifugal_potential(const mfem::Vector& phys_x, double omega);
double get_moment_of_inertia(const FEM& fem, const mfem::GridFunction& rho);
double oblate_spheroid_surface_potential(const mfem::Vector& x, double a, double c, double total_mass);

// Returns a reference rather than a copy; the lifetime of the referenced
// GridFunction depends on the definition — presumably it lives in
// fem.gravity_context (TODO confirm). `phi_warm` defaults to false.
const mfem::GridFunction &grav_potential(FEM &fem, const Args &args, const mfem::GridFunction &rho,
                                         bool phi_warm = false);

mfem::GridFunction get_potential(FEM &fem, const Args &args, const mfem::GridFunction &rho);
mfem::DenseMatrix compute_quadrupole_moment_tensor(const FEM& fem, const mfem::GridFunction& rho, const mfem::Vector& com);
double l2_multipole_potential(const FEM& fem, double total_mass, const mfem::Vector& phys_x);
//endregion
 
//region Mapping Coefficients
class MappedScalarCoefficient : public mfem::Coefficient {
public:
    enum class EVAL_POINTS : uint8_t {
        PHYSICAL,
        REFERENCE
    };
 
    MappedScalarCoefficient(
        const DomainMapper& map,
        mfem::Coefficient& coeff,
        const EVAL_POINTS eval_point=EVAL_POINTS::PHYSICAL
    ) :
    m_map(map),
    m_coeff(coeff),
    m_eval_point(eval_point) {};
 
    double Eval(mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip) override {
        T.SetIntPoint(&ip);
 
        double detJ = m_map.ComputeDetJ(T, ip);
        double f_val = 0.0;
 
        switch (m_eval_point) {
            case EVAL_POINTS::PHYSICAL: {
                f_val = eval_at_point(m_coeff, T, ip);
                break;
            }
            case EVAL_POINTS::REFERENCE: {
                f_val = m_coeff.Eval(T, ip);
                break;
            }
        }
        return f_val * fabs(detJ);
    }
private:
    static double eval_at_point(mfem::Coefficient& c, mfem::ElementTransformation& T, const mfem::IntegrationPoint& ip) {
        return c.Eval(T, ip);
    }
 
private:
    const DomainMapper& m_map;
    mfem::Coefficient& m_coeff;
    EVAL_POINTS m_eval_point;
 
};
 
class MappedDiffusionCoefficient : public mfem::MatrixCoefficient {
public:
    MappedDiffusionCoefficient(
        const DomainMapper& map,
        mfem::Coefficient& sigma,
        const int dim
    ) :
    mfem::MatrixCoefficient(dim),
    m_map(map),
    m_scalar(&sigma),
    m_tensor(nullptr) {};
 
    MappedDiffusionCoefficient(
        const DomainMapper& map,
        mfem::MatrixCoefficient& sigma
    ) :
    mfem::MatrixCoefficient(sigma.GetHeight()),
    m_map(map),
    m_scalar(nullptr),
    m_tensor(&sigma) {};
 
    void Eval(mfem::DenseMatrix &K, mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip) override {
        const int dim = height;
        T.SetIntPoint(&ip);
 
        mfem::DenseMatrix J(dim, dim), JInv(dim, dim);
        m_map.ComputeJacobian(T, J);
        const double detJ = J.Det();
        mfem::CalcInverse(J, JInv);
 
        if (m_scalar) {
            const double sig_val = m_scalar->Eval(T, ip);
            mfem::MultABt(JInv, JInv, K);
            K *= sig_val * fabs(detJ);
        } else {
            mfem::DenseMatrix sig_mat(dim, dim);
            m_tensor->Eval(sig_mat, T, ip);
 
            mfem::DenseMatrix temp(dim, dim);
            Mult(JInv, sig_mat, temp);
 
            MultABt(temp, JInv, K);
            K *= fabs(detJ);
        }
    }
private:
    const DomainMapper& m_map;
    mfem::Coefficient* m_scalar;
    mfem::MatrixCoefficient* m_tensor;
};
 
class MappedVectorCoefficient : public mfem::VectorCoefficient {
public:
    MappedVectorCoefficient(
        const DomainMapper& map,
        mfem::VectorCoefficient& coeff
    ) :
    mfem::VectorCoefficient(coeff.GetVDim()),
    m_map(map),
    m_coeff(coeff) {};
 
    void Eval(mfem::Vector& V, mfem::ElementTransformation& T, const mfem::IntegrationPoint& ip) override {
        const int dim = vdim;
        T.SetIntPoint(&ip);
 
        mfem::DenseMatrix JInv(dim, dim);
        m_map.ComputeInverseJacobian(T, JInv);
        double detJ = m_map.ComputeDetJ(T, ip);
 
        mfem::Vector C_phys(dim);
        m_coeff.Eval(C_phys, T, ip);
 
        V.SetSize(dim);
        JInv.Mult(C_phys, V);
        V *= fabs(detJ);
    }
private:
    const DomainMapper& m_map;
    mfem::VectorCoefficient& m_coeff;
};
 
class PhysicalPositionFunctionCoefficient : public mfem::Coefficient {
public:
    using Func = std::function<double(const mfem::Vector& x)>;
 
    PhysicalPositionFunctionCoefficient(
        const DomainMapper& map,
        Func f
    ) :
    m_f(std::move(f)),
    m_map(map) {};
 
    double Eval(mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip) override {
        T.SetIntPoint(&ip);
        mfem::Vector x;
        m_map.GetPhysicalPoint(T, ip, x);
        return m_f(x);
    }
private:
    Func m_f;
    const DomainMapper& m_map;
};
//endregion
 
//region Integrators
/********************
 * Integrators
 *********************/
/**
 * @brief Nonlinear integrator for an EOS term evaluated on the scalar unknown.
 *
 * With u the finite-element field and P = eos(u, 0):
 *  - AssembleElementVector adds  sum_q w_q * P(u_q) * shape_q
 *  - AssembleElementGrad    adds  sum_q w_q * dP/du(u_q) * shape_q shape_q^T
 * where w_q = ip.weight * Tr.Weight(), multiplied by |det J| of the domain
 * map when a mapper is attached. The derivative is obtained with XAD adjoint
 * mode and is taken as 0 wherever u <= 1e-15.
 *
 * The quadrature rule is read from fem.int_rule, which must be non-null. The
 * FEM object is held by reference; the mapper is a non-owning pointer.
 */
template <is_xad EOS_T>
class FluidIntegrator : public mfem::NonlinearFormIntegrator {
    using Scalar = typename EOS_T::value_type;
public:
    explicit FluidIntegrator(
        const FEM& fem,
        EOS_P<EOS_T> eos,
        const DomainMapper* map = nullptr
    )
        : m_fem(fem),
          m_eos(std::move(eos)),
          m_map(map) {}

    /// Element residual contribution; @p elvect is resized to the element DOF count.
    void AssembleElementVector(
        const mfem::FiniteElement &el,
        mfem::ElementTransformation &Tr,
        const mfem::Vector &elfun,
        mfem::Vector &elvect
    ) override {
        const int ndof = el.GetDof();
        elvect.SetSize(ndof);
        elvect = 0.0;

        const mfem::IntegrationRule *rule = m_fem.int_rule.get();
        const int npts = rule->GetNPoints();

        mfem::Vector basis(ndof);
        for (int q = 0; q < npts; ++q) {
            const mfem::IntegrationPoint& ip = rule->IntPoint(q);
            Tr.SetIntPoint(&ip);
            el.CalcShape(ip, basis);

            const double u = basis * elfun;
            EOS_T rho = u;
            const double pressure = m_eos(rho, 0.0).value();

            elvect.Add(quadrature_weight(Tr, ip, pressure), basis);
        }
    }

    /// Element Jacobian contribution; @p elmat is resized to ndof x ndof.
    void AssembleElementGrad(
        const mfem::FiniteElement &el,
        mfem::ElementTransformation &Tr,
        const mfem::Vector &elfun,
        mfem::DenseMatrix &elmat
    ) override {
        const int ndof = el.GetDof();
        elmat.SetSize(ndof);
        elmat = 0.0;

        const mfem::IntegrationRule *rule = m_fem.int_rule.get();
        const int npts = rule->GetNPoints();

        mfem::Vector basis(ndof);
        for (int q = 0; q < npts; ++q) {
            const mfem::IntegrationPoint& ip = rule->IntPoint(q);
            Tr.SetIntPoint(&ip);
            el.CalcShape(ip, basis);

            const double u = basis * elfun;
            const double slope = eos_slope(u);

            mfem::AddMult_a_VVt(quadrature_weight(Tr, ip, slope), basis, elmat);
        }
    }

    [[nodiscard]] bool has_mapping() const { return m_map != nullptr; }

    void set_mapping(const DomainMapper* map) { m_map = map; }

    void clear_mapping() { m_map = nullptr; }
private:
    /// ip.weight * Tr.Weight() * integrand, times |det J| of the map when one is attached.
    double quadrature_weight(
        mfem::ElementTransformation &Tr,
        const mfem::IntegrationPoint &ip,
        const double integrand
    ) const {
        double w = ip.weight * Tr.Weight() * integrand;
        if (m_map) {
            w *= fabs(m_map->ComputeDetJ(Tr, ip));
        }
        return w;
    }

    /// d eos(rho, 0) / d rho at rho = u via an XAD tape; 0 unless u > 1e-15.
    double eos_slope(const double u) const {
        if (!(u > 1e-15)) {
            return 0.0;
        }

        xad::Tape<Scalar> tape;
        EOS_T x_r = u;
        EOS_T x_t = 0.0; // In future this is one area where we introduce a temp dependency

        tape.registerInput(x_r);
        EOS_T result = m_eos(x_r, x_t);
        tape.registerOutput(result);
        result.setAdjoint(1.0);
        tape.computeAdjoints();
        return x_r.getAdjoint();
    }

    const FEM& m_fem;
    EOS_P<EOS_T> m_eos;
    const DomainMapper* m_map{nullptr};
};
 
//endregion
 
//region Coefficients
/********************
 * Coefficient Defs
 *********************/
/**
 * @brief Boundary traction proportional to the pressure mismatch.
 *
 * Eval() returns V = (P(rho) - P_fit) * n, where P = eos(rho, 0), rho is the
 * density GridFunction evaluated at the integration point and n is the
 * (non-normalised) normal obtained from CalcOrtho on the transformation's
 * Jacobian.
 *
 * The EOS callable is stored by value. The constructor takes it by const
 * reference, so a caller passing a lambda or other temporary would otherwise
 * leave a dangling reference behind. `fem` and `rho` are held by reference
 * and must outlive this object.
 */
template <is_xad EOS_T>
class PressureBoundaryForce : public mfem::VectorCoefficient {
public:
    /**
     * @param dim   Vector dimension of the coefficient.
     * @param fem   FEM container (stored, not used by Eval).
     * @param rho   Density field evaluated in Eval.
     * @param eos   Pressure EOS; copied.
     * @param P_fit Target pressure subtracted from the EOS pressure.
     */
    PressureBoundaryForce(
        const int dim,
        const FEM& fem,
        const mfem::GridFunction& rho,
        const EOS_P<EOS_T>& eos,
        const double P_fit
    ) : VectorCoefficient(dim), m_fem(fem), m_rho(rho), m_eos(eos), m_P_fit(P_fit) {}

    /// Evaluates the traction at @p ip; @p V is resized to vdim.
    void Eval(mfem::Vector &V, mfem::ElementTransformation &T, const mfem::IntegrationPoint &ip) override {
        V.SetSize(vdim);
        V = 0.0;

        const double rho = m_rho.GetValue(T, ip);

        EOS_T x_rho = rho;
        const double P_curr = m_eos(x_rho, 0.0).value();
        const double delta_P = P_curr - m_P_fit;

        // Make sure the Jacobian below is evaluated at ip, independent of what
        // the caller or GetValue left set on the transformation.
        T.SetIntPoint(&ip);
        mfem::Vector normal(vdim);
        mfem::CalcOrtho(T.Jacobian(), normal);

        for (int i = 0; i < vdim; ++i) {
            V(i) = delta_P * normal(i);
        }
    }
private:
    const FEM& m_fem;
    const mfem::GridFunction& m_rho;
    EOS_P<EOS_T> m_eos;
    double m_P_fit;
};
 
//endregion
 
//region Operators
/********************
 * Operator Defs
 *********************/
/**
 * @brief Operator wrapping a single vector as an n x 1 or 1 x n matrix.
 *
 * In column mode (is_col == true) Mult computes y = x(0) * v (size n); in row
 * mode it computes y(0) = v . x. The vector is copied. A default-constructed
 * operator must be given a vector with SetVector() before Mult() is called.
 * The template parameter is not used by the implementation.
 */
template <is_xad EOS_T>
class VectorOperator : public mfem::Operator {
public:
    VectorOperator() = default;

    VectorOperator(
        const mfem::Vector& v,
        const bool is_col
    )
        : Operator(is_col ? v.Size() : 1, is_col ? 1 : v.Size()),
          m_v(v),
          m_is_col(is_col),
          m_is_initialized(true) {}

    /// Replaces the wrapped vector and orientation and updates height/width to match.
    void SetVector(const mfem::Vector& v, const bool is_col) {
        const int n = v.Size();
        if (m_v.Size() != n) {
            m_v.SetSize(n);
        }
        m_v = v;
        m_is_col = is_col;
        if (is_col) {
            height = n;
            width = 1;
        } else {
            height = 1;
            width = n;
        }
        m_is_initialized = true;
    }

    /// @throws std::runtime_error if no vector has been set.
    void Mult(const mfem::Vector &x, mfem::Vector &y) const override {
        if (!m_is_initialized) {
            throw std::runtime_error("VectorOperator Not initialized");
        }

        if (!m_is_col) {
            // Row mode: scalar product.
            y.SetSize(1);
            y(0) = m_v * x;
            return;
        }

        // Column mode: scale the stored vector by the single input entry.
        y.SetSize(m_v.Size());
        y = 0.0;
        y.Add(x(0), m_v);
    }
private:
    mfem::Vector m_v;
    bool m_is_col{false};
    bool m_is_initialized{false};
};
 
/**
 * @brief Boundary coupling block between the vector (displacement) space and
 *        the scalar (density) space.
 *
 * Assemble() builds a sparse matrix with rows in Vec_H1_fes and columns in
 * H1_fes. For every boundary element it accumulates
 *
 *     elmat(k + c*sdof, j) += -dP/drho * shape_j * shape_k * w * n_c
 *
 * over the face quadrature points, where dP/drho comes from XAD adjoint mode
 * (0 wherever rho <= 1e-15) and n is the non-normalised face normal from
 * CalcOrtho.
 *
 * Fixes relative to the previous version:
 *  - The row indices are taken with GetElementVDofs. GetElementDofs returns
 *    only the scalar dofs, which made the element matrix sdof x sdof while it
 *    was indexed up to row dim*sdof - 1 (out-of-bounds writes).
 *  - The EOS callable is stored by value, so passing a temporary to the
 *    constructor no longer leaves a dangling reference.
 *
 * `fem` and `rho` are held by reference and must outlive this object.
 */
template <is_xad EOS_T>
class PressureDensityCoupling : public mfem::Operator {
public:
    /// Sizes the operator as (vector true dofs) x (scalar true dofs) and assembles it.
    PressureDensityCoupling(
        FEM& fem,
        const mfem::GridFunction& rho,
        const EOS_P<EOS_T>& eos_pressure)
    : Operator(fem.Vec_H1_fes->GetTrueVSize(), fem.H1_fes->GetTrueVSize()),
    m_fem(fem),
    m_rho(rho),
    m_eos(eos_pressure) {
        Assemble();
    }

    /// (Re)builds the local sparse matrix from the current density field.
    void Assemble() {
        const int dim = m_fem.mesh->Dimension();
        const int scalar_size = m_fem.H1_fes->GetTrueVSize();
        const int vector_size  = m_fem.Vec_H1_fes->GetTrueVSize();

        m_mat = std::make_unique<mfem::SparseMatrix>(vector_size, scalar_size);

        for (int be = 0; be < m_fem.mesh->GetNBE(); ++be) {
            auto* ftr = m_fem.mesh->GetBdrFaceTransformations(be);
            if (!ftr) continue;

            const int elem = ftr->Elem1No;

            const auto& scalar_fe = *m_fem.H1_fes->GetFE(elem);

            const int sdof = scalar_fe.GetDof();

            mfem::Array<int> scalar_dofs, vector_dofs;
            m_fem.H1_fes->GetElementDofs(elem, scalar_dofs);
            // Vector dofs: one block of sdof entries per component, matching
            // the (k + c * sdof) row indexing below.
            m_fem.Vec_H1_fes->GetElementVDofs(elem, vector_dofs);

            mfem::DenseMatrix elmat(vector_dofs.Size(), scalar_dofs.Size());
            elmat = 0.0;

            const auto& face_ir = mfem::IntRules.Get(ftr->GetGeometryType(), 2 * scalar_fe.GetOrder() + 2);

            mfem::Vector shape(sdof);
            mfem::Vector normal(dim);

            for (int q = 0; q < face_ir.GetNPoints(); ++q) {
                const auto& face_ip = face_ir.IntPoint(q);

                ftr->SetAllIntPoints(&face_ip);

                const mfem::IntegrationPoint& vol_ip = ftr->GetElement1IntPoint();

                scalar_fe.CalcShape(vol_ip, shape);

                // Non-normalised normal; its length carries the face area element.
                mfem::CalcOrtho(ftr->Face->Jacobian(), normal);

                mfem::ElementTransformation& vol_trans = ftr->GetElement1Transformation();
                vol_trans.SetIntPoint(&vol_ip);
                double rho_val = m_rho.GetValue(vol_trans, vol_ip);

                double dPdrho = 0.0;
                if (rho_val > 1e-15) {
                    using Scalar = typename EOS_T::value_type;
                    xad::Tape<Scalar> tape;

                    EOS_T x_rho = rho_val;
                    tape.registerInput(x_rho);
                    EOS_T P = m_eos(x_rho, EOS_T(0.0));
                    tape.registerOutput(P);

                    P.setAdjoint(1.0);

                    tape.computeAdjoints();
                    dPdrho = x_rho.getAdjoint();
                }

                const double w = face_ip.weight;

                for (int k = 0; k < sdof; ++k) {
                    for (int j = 0; j < sdof; ++j) {
                        double base = -dPdrho * shape(j) * shape(k) * w;
                        for (int c = 0; c < dim; ++c) {
                            elmat(k + c * sdof, j) += base * normal(c);
                        }
                    }
                }
            }
            m_mat->AddSubMatrix(vector_dofs, scalar_dofs, elmat);
        }
        m_mat->Finalize();
    }

    /**
     * @brief y = R_vec * M * P_scalar * x on true-dof vectors.
     *
     * NOTE(review): the result is reduced with the restriction matrix of the
     * vector space; for an assembled (dual) vector the transpose of the
     * prolongation is the usual choice — confirm which is intended.
     */
    void Mult(const mfem::Vector &x, mfem::Vector &y) const override {
        mfem::Vector x_loc(m_fem.H1_fes->GetVSize());
        mfem::Vector y_loc(m_fem.Vec_H1_fes->GetVSize());

        m_fem.H1_fes->GetProlongationMatrix()->Mult(x, x_loc);
        m_mat->Mult(x_loc, y_loc);
        m_fem.Vec_H1_fes->GetRestrictionMatrix()->Mult(y_loc, y);
    }

    /// Access to the assembled local sparse matrix.
    [[nodiscard]] mfem::SparseMatrix& SpMat() const { return *m_mat; }

private:
    FEM& m_fem;
    const mfem::GridFunction& m_rho;
    EOS_P<EOS_T> m_eos;
    std::unique_ptr<mfem::SparseMatrix> m_mat;
};
 
/**
 * @brief Rank-one operator built from a single vector D on the vector H1 true dofs.
 *
 * For vector basis function k and component c, the assembled entry is
 *
 *     D(k, c) = sum_q rho * |det J| * sum_j J^{-1}(j, c) * dphi_k/dx_j * w_q
 *
 * where J is the Jacobian reported by the domain mapping. This looks like the
 * sensitivity of the mapped mass integral to the displacement field
 * (NOTE(review): confirm against the derivation).
 *
 * Depending on @c is_col the operator acts as a column (scalar -> vector,
 * y = x(0) * D) or as a row (vector -> scalar, y(0) = D . x).
 *
 * @tparam EOS_T XAD scalar type; not used by this class itself.
 */
template <is_xad EOS_T>
class MassDisplacementCoupling : public mfem::Operator {
public:
    /**
     * @param fem    Finite element context providing the mesh, spaces and mapping.
     * @param rho    Density field sampled at the quadrature points. Held by
     *               reference, so it must outlive this object.
     * @param is_col true: operator is (true vsize x 1); false: (1 x true vsize).
     */
    MassDisplacementCoupling(
        FEM& fem,
        const mfem::GridFunction& rho,
        const bool is_col
    ) :
    Operator(
        is_col ? fem.Vec_H1_fes->GetTrueVSize() : 1,
        is_col ? 1 : fem.Vec_H1_fes->GetTrueVSize()
    ),
    m_fem(fem),
    m_rho(rho),
    m_is_col(is_col){
        m_D.SetSize(m_fem.Vec_H1_fes->GetTrueVSize());
        Assemble();
    }

    /**
     * @brief (Re)assembles D from the current density and mapping state.
     *
     * Declared const because D is a mutable cache; call it again whenever the
     * density or the mapping displacement has changed.
     */
    void Assemble() const {
        const int dim = m_fem.mesh->Dimension();

        mfem::Vector D_loc(m_fem.Vec_H1_fes->GetVSize());
        D_loc = 0.0;
        for (int elemID = 0; elemID < m_fem.mesh->GetNE(); ++elemID) {
            auto* trans = m_fem.mesh->GetElementTransformation(elemID);
            const auto& fe = *m_fem.Vec_H1_fes->GetFE(elemID);
            const int dof = fe.GetDof();

            // The element vector has dof * dim entries laid out component by
            // component, so the full list of *vector* dofs is required here.
            // GetElementDofs() would return only `dof` scalar indices, and
            // AddElementVector() would then scatter just the first component.
            mfem::Array<int> vdofs;
            m_fem.Vec_H1_fes->GetElementVDofs(elemID, vdofs);

            const auto& ir = mfem::IntRules.Get(trans->GetGeometryType(), 2 * fe.GetOrder() + 1);

            mfem::DenseMatrix dshape(dof, dim);
            mfem::Vector elvec(dof * dim);
            elvec = 0.0;

            for (int q = 0; q < ir.GetNPoints(); ++q) {
                const auto& ip = ir.IntPoint(q);
                trans->SetIntPoint(&ip);

                fe.CalcPhysDShape(*trans, dshape);
                double rho_val = m_rho.GetValue(elemID, ip);
                double ref_weight = trans->Weight() * ip.weight;

                // Jacobian of the domain mapping at this quadrature point.
                mfem::DenseMatrix J_map(dim, dim), J_inv(dim, dim);
                m_fem.mapping->ComputeJacobian(*trans, J_map);
                double detJ = J_map.Det();
                mfem::CalcInverse(J_map, J_inv);

                for (int k = 0; k < dof; ++k) {
                    for (int c = 0; c < dim; ++c) {
                        // Column c of J^{-1} contracted with the gradient of shape k.
                        double trace_contrib = 0.0;
                        for (int j = 0; j < dim; ++j) {
                            trace_contrib += J_inv(j, c) * dshape(k, j);
                        }
                        elvec(k + c * dof) += rho_val * fabs(detJ) * trace_contrib * ref_weight;
                    }
                }
            }
            D_loc.AddElementVector(vdofs, elvec);
        }

        // D_loc is a dual vector holding only this rank's element contributions.
        // P^T accumulates the partial sums on shared dofs onto the owning rank;
        // the restriction matrix R would drop the other ranks' contributions.
        // In serial both choices coincide.
        m_fem.Vec_H1_fes->GetProlongationMatrix()->MultTranspose(D_loc, m_D);
    }

    /// Returns the assembled vector D on the vector H1 true dofs.
    [[nodiscard]] mfem::Vector& GetVec() const {
        return m_D;
    }

    /**
     * @brief Applies the rank-one operator.
     *
     * Column mode: y = x(0) * D. Row mode: y(0) = D . x, computed with the
     * rank-local dot product only.
     * NOTE(review): row mode performs no MPI reduction - confirm callers
     * account for that in parallel runs.
     */
    void Mult(const mfem::Vector &x, mfem::Vector &y) const override {
        if (m_is_col) {
            y.SetSize(m_D.Size());
            y = 0.0;
            y.Add(x(0), m_D);
        } else {
            y.SetSize(1);
            y(0) = m_D * x;
        }
    }


private:
    const FEM& m_fem;
    const mfem::GridFunction& m_rho;
    const bool m_is_col;

    mutable mfem::Vector m_D;  // cached assembly result, refreshed by Assemble()
};
 
/**
 * @brief Abstract interface for a reference density/potential pair evaluated
 *        at physical positions.
 *
 * Implementations refresh their internal state from the current density
 * field via UpdateState() and are then queried point-wise.
 */
class ReferenceGravityModel {
public:
    virtual ~ReferenceGravityModel() = default;

    /// Refreshes the model's internal state from the current density field.
    virtual void UpdateState(const FEM& fem, const mfem::GridFunction& current_rho) = 0;

    /// Reference density at the physical position @p x_phys.
    [[nodiscard]] virtual double EvalDensity(const mfem::Vector& x_phys) const = 0;

    /// Reference potential at the physical position @p x_phys.
    [[nodiscard]] virtual double EvalPotential(const mfem::Vector& x_phys) const = 0;
};
 
class UniformReferenceGravityModel : public ReferenceGravityModel {
public:
    explicit UniformReferenceGravityModel(const double R) : m_R(R), m_rho0(0.0), m_M(0.0) {};
    void UpdateState(const FEM &fem, const mfem::GridFunction &current_rho) override {
        m_M = domain_integrate_grid_function(fem, current_rho, Domains::STELLAR);
        m_rho0 = m_M / ((4.0 / 3.0) * M_PI * pow(m_R, 3));
    }
 
    [[nodiscard]] double EvalDensity (const mfem::Vector& x_phys) const override {
        return (x_phys.Norml2() <= m_R) ? m_rho0 : 0.0;
    }
 
    [[nodiscard]] double EvalPotential (const mfem::Vector& x_phys) const override {
        const double r = x_phys.Norml2();
        if (r <= m_R) {
            return (-G * m_M / (2.0 * std::pow(m_R, 3.0))) * (3.0 * m_R * m_R - r * r);
        }
        return -G * m_M / r;
    }
 
private:
    double m_R;
    double m_rho0;
    double m_M;
};
 
class DynamicPotentialSplitter {
public:
    explicit DynamicPotentialSplitter(
        std::unique_ptr<ReferenceGravityModel> ref_gravity_model
    ) :
    m_grav_model(std::move(ref_gravity_model)) {}
 
    void ComputeTotalPotential(FEM& fem, const Args& args, const mfem::GridFunction& rho, mfem::ParGridFunction& phi_total) {
        m_grav_model->UpdateState(fem, rho);
 
        mfem::GridFunction rho_ref(fem.H1_fes.get());
        PhysicalPositionFunctionCoefficient rho_ref_coeff(
            *fem.mapping,
            [this](const mfem::Vector& x) {return m_grav_model->EvalDensity(x); }
        );
        rho_ref.ProjectCoefficient(rho_ref_coeff);
 
        mfem::GridFunction delta_rho(fem.H1_fes.get());
        delta_rho = rho;
        delta_rho -= rho_ref;
 
        fem.com = get_com(fem, delta_rho);
        fem.Q = compute_quadrupole_moment_tensor(fem, delta_rho, fem.com);
 
        const mfem::GridFunction& delta_phi = grav_potential(fem, args, delta_rho, false);
 
        mfem::GridFunction phi_ref(fem.H1_fes.get());
        PhysicalPositionFunctionCoefficient phi_ref_coeff(
            *fem.mapping,
            [this](const mfem::Vector& x) {return m_grav_model->EvalPotential(x); }
        );
 
        phi_ref.ProjectCoefficient(phi_ref_coeff);
 
        phi_total.SetSpace(fem.H1_fes.get());
        phi_total = delta_phi;
        phi_total += phi_ref;
    }
private:
    std::unique_ptr<ReferenceGravityModel> m_grav_model;
};
 
/**
 * @brief Nonlinear residual of the coupled density / displacement system.
 *
 * The unknown is a block vector laid out by fem.block_true_offsets as
 * [rho (scalar H1 true dofs) | d (vector H1 true dofs)]. Mult() evaluates the
 * residual; GetGradient() is legacy code that returns an approximate block
 * Jacobian.
 *
 * All referenced objects (fem, args, both EOS callables) are held by
 * reference and must outlive this operator.
 *
 * @tparam EOS_T XAD scalar type used by the equation-of-state callables.
 */
template <is_xad EOS_T>
class ResidualOperator : public mfem::Operator {
public:
    /**
     * @param fem          Finite element context (mesh, spaces, mapping, rules).
     * @param args         Run arguments (target mass, gravity options, ...).
     * @param eos_enthalpy EOS callable used by the fluid (density) integrator.
     * @param eos_pressure EOS callable used for the boundary pressure force
     *                     and the pressure/density coupling.
     * @param alpha        Coefficient of the vector mass term added to the
     *                     displacement stiffness inside the stellar domain.
     */
    ResidualOperator(
        FEM& fem,
        const Args& args,
        const EOS_P<EOS_T>& eos_enthalpy,
        const EOS_P<EOS_T>& eos_pressure,
        const double alpha
    ) :
    Operator(fem.block_true_offsets.Last()),
    m_fem(fem),
    m_args(args),
    m_eos_enthalpy(eos_enthalpy),
    m_eos_pressure(eos_pressure),
    m_alpha(std::make_unique<mfem::ConstantCoefficient>(alpha)),
    m_fluid_nlf(m_fem.H1_fes.get()),
    m_reference_stiffness(m_fem.Vec_H1_fes.get())
    {
        auto* fluid_integrator = new FluidIntegrator<EOS_T>(m_fem, m_eos_enthalpy, m_fem.has_mapping() ? m_fem.mapping.get() : nullptr);
        fluid_integrator->SetIntRule(fem.int_rule.get());

        populate_element_mask(m_fem, Domains::STELLAR, m_stellar_mask);
        // NOTE(review): boundary attribute 1 (index 0) is hard-coded here -
        // confirm it matches the intended stellar surface attribute.
        m_bdr_mask.SetSize(m_fem.mesh->bdr_attributes.Max());
        m_bdr_mask = 0;
        m_bdr_mask[0] = 1;

        m_fluid_nlf.AddDomainIntegrator(fluid_integrator, m_stellar_mask);

        auto* alpha_integrator = new mfem::VectorMassIntegrator(*m_alpha);
        alpha_integrator->SetIntRule(fem.int_rule.get());

        auto* diff_integrator = new mfem::VectorDiffusionIntegrator();
        diff_integrator->SetIntRule(fem.int_rule.get());

        m_reference_stiffness.AddDomainIntegrator(alpha_integrator, m_stellar_mask);
        m_reference_stiffness.AddDomainIntegrator(diff_integrator, m_stellar_mask);

        // Add an identity component to avoid 0 on diagonals in the vacuum
        mfem::Array<int> vacuum_mask;
        populate_element_mask(m_fem, Domains::VACUUM, vacuum_mask);

        m_dummy_one = std::make_unique<mfem::ConstantCoefficient>(1.0);
        auto* vacuum_padding = new mfem::VectorMassIntegrator(*m_dummy_one);
        vacuum_padding->SetIntRule(fem.int_rule.get());
        m_reference_stiffness.AddDomainIntegrator(vacuum_padding, vacuum_mask);

        // vacuum_mask is a local: the stiffness form is assembled here, while
        // the mask is still alive, and is not reassembled afterwards.
        m_reference_stiffness.Assemble();
        m_reference_stiffness.Finalize();

        // =============================
        // Persistent Memory Allocation
        // =============================
        m_state_rho_gf = std::make_unique<mfem::ParGridFunction>(m_fem.H1_fes.get());
        m_state_d_gf = std::make_unique<mfem::ParGridFunction>(m_fem.Vec_H1_fes.get());
        *m_state_rho_gf = 0.0;
        *m_state_d_gf = 0.0;

        m_approx_jacobian = std::make_unique<mfem::BlockOperator>(m_fem.block_true_offsets);
        m_block_prec = std::make_unique<mfem::BlockLowerTriangularPreconditioner>(m_fem.block_true_offsets);

        m_ident_rho = std::make_unique<mfem::IdentityOperator>(m_fem.H1_fes->GetTrueVSize());

        // Scalar mass matrix used as the (0,0) preconditioner block.
        mfem::ParBilinearForm M_rho(m_fem.H1_fes.get());
        M_rho.AddDomainIntegrator(new mfem::MassIntegrator(), m_stellar_mask);
        M_rho.AddDomainIntegrator(new mfem::MassIntegrator(*m_dummy_one), vacuum_mask);

        M_rho.Assemble();
        M_rho.Finalize();

        m_M_rho_mat.reset(M_rho.ParallelAssemble());
        m_M_rho_prec = std::make_unique<mfem::HypreSmoother>(*m_M_rho_mat, mfem::HypreSmoother::Jacobi);
        m_block_prec->SetDiagonalBlock(0, m_M_rho_prec.get());


        // Displacement stiffness with AMG as the (1,1) preconditioner block.
        m_J_d.reset(m_reference_stiffness.ParallelAssemble());
        m_J_d->EliminateRowsCols(m_fem.ess_tdof_x);

        m_amg_d = std::make_unique<mfem::HypreBoomerAMG>(*m_J_d);
        m_amg_d->SetPrintLevel(0);
        m_amg_d->SetSystemsOptions(m_fem.Vec_H1_fes->GetVDim());
        m_block_prec->SetDiagonalBlock(1, m_amg_d.get());

        m_approx_jacobian->SetBlock(1, 1, &m_reference_stiffness);

        m_C = std::make_unique<PressureDensityCoupling<EOS_T>>(m_fem, *m_state_rho_gf, m_eos_pressure);
        m_D = std::make_unique<MassDisplacementCoupling<EOS_T>>(m_fem, *m_state_rho_gf, false);

        m_approx_jacobian->SetBlock(1, 0, m_C.get());

        // m_B_vec is filled by ParLinearForm::ParallelAssemble() and added to
        // the rho residual block, both of which live on TRUE dofs. Sizing it
        // with the local vsize only happens to work in serial, where the two
        // sizes agree.
        m_B_vec.SetSize(m_fem.H1_fes->GetTrueVSize());
        m_B_vec = 0.0;

        m_B_vec_op_col.SetVector(m_B_vec, true);
        m_B_vec_op_row.SetVector(m_B_vec, false);

        m_block_prec->SetBlock(1, 0, m_C.get());

        // Mesh-size based scale: h^2 ~ RADIUS^2 / N_elements^(2/3).
        double h_sq = std::pow(RADIUS, 2.0) / std::pow(static_cast<double>(m_fem.mesh->GetGlobalNE()), 2.0/3.0);
        double grav_diag_scale = 4.0 * M_PI * h_sq;
        m_grav_coeff = std::make_unique<mfem::ConstantCoefficient>(grav_diag_scale);
    }


    /**
     * @brief Evaluates the residual r(u) for the block state u = [rho | d].
     *
     * Side effects: updates the persistent state grid functions, sets the
     * mapping displacement, overwrites fem.com / fem.Q, reassembles the mass
     * constraint vectors and prints one diagnostic line per call.
     *
     * @param u Block state vector on true dofs.
     * @param r Block residual on true dofs (written in place).
     */
    void Mult(const mfem::Vector &u, mfem::Vector &r) const override {
        // Non-owning views of the two blocks of u.
        mfem::Vector u_rho_true(u.GetData() + m_fem.block_true_offsets[0], m_fem.H1_fes->GetTrueVSize());
        mfem::Vector u_d_true(u.GetData() + m_fem.block_true_offsets[1], m_fem.Vec_H1_fes->GetTrueVSize());

        // 2. Update persistent GridFunctions (renamed from m_grad_*)
        m_state_rho_gf->SetFromTrueDofs(u_rho_true);
        m_state_d_gf->SetFromTrueDofs(u_d_true);

        // 3. Safely set the mapping to our persistent displacement
        m_fem.mapping->SetDisplacement(*m_state_d_gf);

        // Mass constraint vector B: mapped integral of each scalar basis
        // function over the stellar domain.
        m_B_vec = 0.0;
        mfem::ConstantCoefficient one(1.0);
        MappedScalarCoefficient mapped_b_vec(*m_fem.mapping, one);
        mfem::ParLinearForm b_lf(m_fem.H1_fes.get());

        auto* lf_integrator = new mfem::DomainLFIntegrator(mapped_b_vec);
        lf_integrator->SetIntRule(m_fem.int_rule.get());
        b_lf.AddDomainIntegrator(lf_integrator, m_stellar_mask);
        b_lf.Assemble();
        b_lf.ParallelAssemble(m_B_vec);


        // Multipole data and gravitational potential for the current density.
        m_fem.com = get_com(m_fem, *m_state_rho_gf);
        m_fem.Q = compute_quadrupole_moment_tensor(m_fem, *m_state_rho_gf, m_fem.com);
        auto phi = get_potential(m_fem, m_args, *m_state_rho_gf);

        // Non-owning views of the two blocks of r.
        mfem::Vector r_rho_true(r.GetData() + m_fem.block_true_offsets[0], m_fem.H1_fes->GetTrueVSize());
        mfem::Vector r_d_true(r.GetData() + m_fem.block_true_offsets[1], m_fem.Vec_H1_fes->GetTrueVSize());

        // rho block: fluid (enthalpy) term ...
        m_fluid_nlf.Mult(u_rho_true, r_rho_true);

        // ... plus the mapped potential tested against the scalar basis.
        mfem::GridFunctionCoefficient phi_c(&phi);
        MappedScalarCoefficient mapped_phi_c(*m_fem.mapping, phi_c);
        mfem::ParLinearForm phi_lf(m_fem.H1_fes.get());

        auto* potential_lf_integrator = new mfem::DomainLFIntegrator(mapped_phi_c);
        potential_lf_integrator -> SetIntRule(m_fem.int_rule.get());

        phi_lf.AddDomainIntegrator(potential_lf_integrator, m_stellar_mask);
        phi_lf.Assemble();

        mfem::Vector phi_true(m_fem.H1_fes->GetTrueVSize());
        phi_lf.ParallelAssemble(phi_true);
        r_rho_true += phi_true;

        // d block: reference stiffness minus the boundary pressure force.
        m_reference_stiffness.Mult(u_d_true, r_d_true);

        PressureBoundaryForce<EOS_T> pbf(
            m_fem.H1_fes->GetMesh()->Dimension(),
            m_fem,
            *m_state_rho_gf,
            m_eos_pressure,
            m_args.c
        );
        mfem::ParLinearForm pbf_lf(m_fem.Vec_H1_fes.get());

        auto* vec_bdr_integrator = new mfem::VectorBoundaryLFIntegrator(pbf);
        vec_bdr_integrator->SetIntRule(m_fem.int_rule.get());

        pbf_lf.AddBoundaryIntegrator(vec_bdr_integrator, m_bdr_mask);
        pbf_lf.Assemble();

        mfem::Vector pbf_true(m_fem.Vec_H1_fes->GetTrueVSize());
        pbf_lf.ParallelAssemble(pbf_true);

        r_d_true -= pbf_true;

        // Essential displacement dofs carry no residual.
        for (int i = 0; i < m_fem.ess_tdof_x.Size(); ++i) {
            r_d_true(m_fem.ess_tdof_x[i]) = 0.0;
        }

        // Mass constraint: penalty term (1/eps) * (M - M_target) combined
        // with the externally supplied multiplier m_lambda.
        const double current_mass = domain_integrate_grid_function(m_fem, *m_state_rho_gf, Domains::STELLAR);
        const double eps = 1e-7;
        const double penalty_lagrangian_multiplier = ((1.0/eps) * (current_mass - m_args.mass) - m_lambda);

        r_rho_true.Add(penalty_lagrangian_multiplier, m_B_vec);

        // Refresh D for the current density / mapping before using it.
        m_D->Assemble();
        r_d_true.Add(penalty_lagrangian_multiplier, m_D->GetVec());

        double min_rho = m_state_rho_gf->Min();
        double min_d = m_state_d_gf->Min();

        std::println("[DIAGNOSTIC] ρ_min: {:+0.2E}, d_min: {:+0.2E}, com: <{:+0.2E}, {:+0.2E}, {:+0.2E}>, current mass:  {:+0.2E}", min_rho, min_d, m_fem.com(0), m_fem.com(1), m_fem.com(2), current_mass);
    }

    /// Euclidean norm of the most recently assembled mass constraint vector B.
    [[nodiscard]] double GetMassPenaltyNorm() const {
        return m_B_vec.Norml2();
    }

    /// Sets the multiplier subtracted from the mass penalty term in Mult().
    void SetLambda(const double lambda) const {m_lambda = lambda;};

    /**
     * @brief Returns the approximate block Jacobian at @p u (legacy path).
     *
     * Updates the state and mapping, installs the fluid gradient as block
     * (0,0), swaps the (0,0) preconditioner block for AMG on the scalar mass
     * matrix, and reassembles the B, C and D coupling objects.
     *
     * NOTE(review): the local stellar_mask, P00_form and global_schur are
     * built but not used by anything that is returned.
     */
    [[nodiscard]] Operator& GetGradient(const mfem::Vector &u) const override {
        // Note this is legacy code, GetGradient is not being used at the moment
        mfem::Array<int> stellar_mask;
        populate_element_mask(m_fem, Domains::STELLAR, stellar_mask);


        mfem::Vector u_rho_true(u.GetData() + m_fem.block_true_offsets[0], m_fem.H1_fes->GetTrueVSize());
        mfem::Vector u_d_true(u.GetData() + m_fem.block_true_offsets[1], m_fem.Vec_H1_fes->GetTrueVSize());

        m_state_rho_gf->SetFromTrueDofs(u_rho_true);
        m_state_d_gf->SetFromTrueDofs(u_d_true);

        m_fem.mapping->SetDisplacement(*m_state_d_gf);

        Operator& grad_rho_op = m_fluid_nlf.GetGradient(u_rho_true);
        m_approx_jacobian->SetBlock(0, 0, &grad_rho_op);

        mfem::ParNonlinearForm P00_form(m_fem.H1_fes.get());
        auto *enthalpy_integrator = new FluidIntegrator<EOS_T>(m_fem, m_eos_enthalpy, m_fem.has_mapping() ? m_fem.mapping.get() : nullptr);
        enthalpy_integrator->SetIntRule(m_fem.int_rule.get());
        P00_form.AddDomainIntegrator(enthalpy_integrator, m_stellar_mask);

        auto* grav_mass_integrator = new mfem::MassIntegrator(*m_grav_coeff);
        grav_mass_integrator->SetIntRule(m_fem.int_rule.get());
        P00_form.AddDomainIntegrator(grav_mass_integrator, m_stellar_mask);


        m_amg_rho = std::make_unique<mfem::HypreBoomerAMG>(*m_M_rho_mat);
        m_amg_rho->SetPrintLevel(0);
        m_block_prec->SetDiagonalBlock(0, m_amg_rho.get());

        // Reassemble the mass constraint vector for the current mapping.
        m_B_vec = 0.0;
        mfem::ConstantCoefficient one(1.0);
        MappedScalarCoefficient mapped_b_vec(*m_fem.mapping, one);
        mfem::ParLinearForm b_lf(m_fem.H1_fes.get());

        auto* lf_integrator = new mfem::DomainLFIntegrator(mapped_b_vec);
        lf_integrator->SetIntRule(m_fem.int_rule.get());
        b_lf.AddDomainIntegrator(lf_integrator, m_stellar_mask);
        b_lf.Assemble();

        b_lf.ParallelAssemble(m_B_vec);
        double local_schur = mfem::InnerProduct(m_B_vec, m_B_vec);
        double global_schur = 0.0;

        MPI_Allreduce(&local_schur, &global_schur, 1, MPI_DOUBLE, MPI_SUM, m_fem.H1_fes->GetComm());

        if (std::abs(global_schur) < 1e-12) global_schur = 1.0;

        // Row operator uses +B, column operator uses -B.
        m_B_vec_op_row.SetVector(m_B_vec, false);
        m_B_vec *= -1.0;
        m_B_vec_op_col.SetVector(m_B_vec, true);

        m_C->Assemble();
        m_D->Assemble();

        return *m_approx_jacobian;
    }

    /// Block lower-triangular preconditioner assembled in the constructor.
    mfem::Solver& GetBlockPreconditioner() const {
        return *m_block_prec;
    }
private:
    FEM& m_fem;
    const Args& m_args;
    const EOS_P<EOS_T>& m_eos_enthalpy;
    const EOS_P<EOS_T>& m_eos_pressure;
    const std::unique_ptr<mfem::ConstantCoefficient> m_alpha;

    // Markers are members because the forms store pointers to them.
    mutable mfem::Array<int> m_stellar_mask;
    mutable mfem::Array<int> m_bdr_mask;

    mutable mfem::ParNonlinearForm m_fluid_nlf;
    mutable mfem::ParBilinearForm m_reference_stiffness;
    std::unique_ptr<mfem::ConstantCoefficient> m_dummy_one;

    // ==================================
    // Persistent Memory Scratch Space
    // ==================================

    // State Variables
    mutable std::unique_ptr<mfem::ParGridFunction> m_state_rho_gf;
    mutable std::unique_ptr<mfem::ParGridFunction> m_state_d_gf;

    // Top Level Block Structure
    mutable std::unique_ptr<mfem::BlockOperator> m_approx_jacobian = nullptr;
    mutable std::unique_ptr<mfem::BlockLowerTriangularPreconditioner> m_block_prec;

    // Constant Preconditioner Blocks
    mutable std::unique_ptr<mfem::HypreParMatrix> m_J_d;
    mutable std::unique_ptr<mfem::HypreBoomerAMG> m_amg_d;
    mutable std::unique_ptr<mfem::HypreBoomerAMG> m_amg_rho;
    mutable std::unique_ptr<mfem::HypreParMatrix> m_M_rho_mat;
    mutable std::unique_ptr<mfem::HypreSmoother> m_M_rho_prec;
    mutable std::unique_ptr<mfem::IdentityOperator> m_ident_rho;
    mutable std::unique_ptr<mfem::IdentityOperator> m_ident_lambda;  // never assigned in this class

    // Dynamic Coupling Operators
    mutable std::unique_ptr<PressureDensityCoupling<EOS_T>> m_C;
    mutable std::unique_ptr<MassDisplacementCoupling<EOS_T>> m_D;

    // Dynamic Vector Operators
    mutable mfem::Vector m_B_vec;  // mass constraint vector on scalar H1 true dofs
    mutable VectorOperator<EOS_T> m_B_vec_op_col;
    mutable VectorOperator<EOS_T> m_B_vec_op_row;

    // Constant Coefficients
    mutable std::unique_ptr<mfem::ConstantCoefficient> m_grav_coeff;

    mutable double m_lambda{0.0};

};
//endregion
 
//region Utility Functions
/**
 * @brief Builds the complete finite element context for a mesh file.
 *
 * Loads the serial mesh, distributes it over MPI_COMM_WORLD, creates the
 * scalar/vector H1 spaces and their LOR counterpart, discovers the vacuum
 * shell bounds for the domain mapping, and assembles every persistent object
 * of the gravity solve (Laplacians, preconditioner, GMRES solver, RHS form).
 *
 * @param filename Path of the mesh file to load.
 * @param args     Run arguments; reads quad_boost and the solver tolerances
 *                 p.rtol / p.atol.
 * @return The populated FEM context.
 * @throws std::runtime_error if the vacuum bounds cannot be determined or no
 *         mapping is available.
 */
FEM setup_fem(const std::string& filename, const Args &args) {
    FEM fem;
    auto& ctx = fem.gravity_context;

    //==================================================================
    // Section 1: Mesh and FE Space Setup
    //==================================================================
    mfem::Mesh serial_mesh(filename, 0, 0);
    fem.mesh = std::make_unique<mfem::ParMesh>(MPI_COMM_WORLD, serial_mesh);
    fem.mesh->EnsureNodes();
    const int geom_order = get_mesh_order(*fem.mesh);
    int dim = fem.mesh->Dimension();

    // Assume solution space order == geometrical order
    fem.H1_fec   = std::make_unique<mfem::H1_FECollection>(geom_order, dim);
    fem.H1_fes   = std::make_unique<mfem::ParFiniteElementSpace>(fem.mesh.get(), fem.H1_fec.get());
    fem.Vec_H1_fes = std::make_unique<mfem::ParFiniteElementSpace>(fem.mesh.get(), fem.H1_fec.get(), dim, mfem::Ordering::byNODES);

    // LOR discretization for fast preconditioning
    fem.H1_lor_disc = std::make_unique<mfem::ParLORDiscretization>(*fem.H1_fes);
    fem.H1_lor_fes = &fem.H1_lor_disc->GetParFESpace();

    //==================================================================
    // Section 2: Domain Mapping
    //==================================================================
    auto [r_star_ref, r_inf_ref] = DiscoverBounds(fem.mesh.get(), 3)
    .or_else([](const BoundsError& err)->std::expected<Bounds, BoundsError> {
        throw std::runtime_error("Unable to determine vacuum domain reference boundary...");
    }).value();

    fem.mapping = std::make_unique<DomainMapper>(r_star_ref, r_inf_ref);

    //==================================================================
    // Section 3: Multi-physics Block-offsets
    // Layout: [rho (scalar H1) | d (vector H1) | λ (scalar)]
    //==================================================================
    fem.block_true_offsets.SetSize(3);
    fem.block_true_offsets[0] = 0;
    fem.block_true_offsets[1] = fem.H1_fes->GetTrueVSize();
    fem.block_true_offsets[2] = fem.block_true_offsets[1] + fem.Vec_H1_fes->GetTrueVSize();
    // fem.block_true_offsets[3] = fem.block_true_offsets[2] + 1;

    //==================================================================
    // Section 4: Multipole BC setup.
    // Assumes:
    //         - COM starts at origin
    //         - Geometry starts out as a spherically symmetric
    //==================================================================
    fem.com.SetSize(dim); fem.com = 0.0;
    fem.Q.SetSize(dim, dim); fem.Q = 0.0;

    fem.reference_x = std::make_unique<mfem::ParGridFunction>(fem.Vec_H1_fes.get());
    fem.mesh->GetNodes(*fem.reference_x);

    //==================================================================
    // Section 5: Integration Rules
    // Assumes:
    //         - Mesh is composed of hex elements (CUBE)
    //==================================================================
    // MFEM_VERIFY (unlike MFEM_ASSERT) is also active in release builds; the
    // rule below is hard-wired to Geometry::CUBE, so a non-hex mesh must not
    // slip through silently. A rank may own no elements, in which case there
    // is nothing to inspect locally.
    MFEM_VERIFY(fem.mesh->GetNE() == 0 || fem.mesh->GetElementGeometry(0) == mfem::Geometry::CUBE, "Currently only hexahedral meshes are supported");
    const int element_order = fem.H1_fes->GetMaxElementOrder();
    fem.int_order = 2 * element_order + geom_order - 2 + args.quad_boost;

    fem.int_rule = std::make_unique<mfem::IntegrationRule>(mfem::IntRules.Get(mfem::Geometry::CUBE, fem.int_order));

    //==================================================================
    // Section 6: Essential Degrees of Freedom and Boundary Markers
    //==================================================================
    fem.ess_tdof_x.SetSize(0); // No essential boundary conditions for the displacement, the null space here is handled with a weak penalty term

    populate_element_mask(fem, Domains::STELLAR, ctx.stellar_mask);

    const int n_bdr_attrs = fem.mesh->bdr_attributes.Max();
    fem.boundary_context.inf_bounds.SetSize(n_bdr_attrs);
    fem.boundary_context.stellar_bounds.SetSize(n_bdr_attrs);

    fem.boundary_context.inf_bounds = 0;
    fem.boundary_context.stellar_bounds = 0;

    fem.boundary_context.inf_bounds[Boundaries::INF_SURFACE - 1] = 1; // Vacuum boundary
    fem.boundary_context.stellar_bounds[Boundaries::STELLAR_SURFACE - 1] = 1; // Stellar boundary

    fem.H1_fes->GetEssentialTrueDofs(fem.boundary_context.inf_bounds, ctx.ho_ess_tdof_list);
    fem.H1_lor_fes->GetEssentialTrueDofs(fem.boundary_context.inf_bounds, ctx.lor_ess_tdof_list);

    //==================================================================
    // Section 7: Gravity Solution Field
    //==================================================================
    ctx.phi = std::make_unique<mfem::ParGridFunction>(fem.H1_fes.get());
    *ctx.phi = 0.0;

    //==================================================================
    // Section 8: Laplacian Coefficients
    //==================================================================
    ctx.unit_coeff = std::make_unique<mfem::ConstantCoefficient>(1.0);
    if (fem.has_mapping()) {
        ctx.diff_coeff = std::make_unique<MappedDiffusionCoefficient>(*fem.mapping, *ctx.unit_coeff, fem.mesh->Dimension());
    } else {
        throw std::runtime_error("Unable to determine domain reference boundary or mapping...");
    }


    //==================================================================
    // Section 9: High order Laplacian (Partial Assembly)
    //==================================================================
    ctx.ho_laplacian = std::make_unique<mfem::ParBilinearForm>(fem.H1_fes.get());
    {
        auto* ho_di = new mfem::DiffusionIntegrator(*ctx.diff_coeff);
        ho_di->SetIntRule(fem.int_rule.get());
        ctx.ho_laplacian->AddDomainIntegrator(ho_di);
    }
    ctx.ho_laplacian->SetAssemblyLevel(mfem::AssemblyLevel::PARTIAL);
    ctx.ho_laplacian->Assemble();

    //==================================================================
    // Section 10: Low order Laplacian (preconditioning)
    //==================================================================
    ctx.lor_laplacian = std::make_unique<mfem::ParBilinearForm>(&fem.H1_lor_disc->GetParFESpace());
    {
        auto* lo_di = new mfem::DiffusionIntegrator(*ctx.diff_coeff);
        ctx.lor_laplacian->AddDomainIntegrator(lo_di);
    }
    ctx.lor_laplacian->Assemble();
    ctx.lor_laplacian->Finalize();

    //==================================================================
    // Section 11: Constrained System Operators
    //==================================================================
    ctx.ho_laplacian->FormSystemMatrix(
        ctx.ho_ess_tdof_list,
        ctx.A_ho
    );

    ctx.lor_laplacian->FormSystemMatrix(
        ctx.lor_ess_tdof_list,
        ctx.A_lor
    );

    //==================================================================
    // Section 12: AMG Preconditioner and GMRES Solver
    //==================================================================
    ctx.amg_prec = std::make_unique<mfem::HypreBoomerAMG>(*ctx.A_lor.As<mfem::HypreParMatrix>());
    ctx.amg_prec->SetPrintLevel(0);

    ctx.amg_wrapper = std::make_unique<LORPrecWrapper>(*ctx.amg_prec);

    ctx.solver = std::make_unique<mfem::GMRESSolver>(MPI_COMM_WORLD);
    ctx.solver->SetOperator(*ctx.A_ho.Ptr());
    ctx.solver->SetPreconditioner(*ctx.amg_wrapper);
    ctx.solver->SetRelTol(args.p.rtol);
    ctx.solver->SetAbsTol(args.p.atol);
    ctx.solver->SetMaxIter(1000);
    ctx.solver->SetKDim(100);
    ctx.solver->SetPrintLevel(0);

    //==================================================================
    // Section 13: RHS Coefficient Chain
    // Assumes:
    //         - fem.gravity_context.rho_coeff->SetGridFunction(&rho)
    //           must be called before each assembly of the RHS
    //==================================================================
    ctx.four_pi_G_coeff = std::make_unique<mfem::ConstantCoefficient>(-4.0 * M_PI * G);
    ctx.rho_coeff = std::make_unique<mfem::GridFunctionCoefficient>();
    ctx.rhs_coeff = std::make_unique<mfem::ProductCoefficient>(*ctx.four_pi_G_coeff, *ctx.rho_coeff);
    ctx.mapped_rhs_coeff = std::make_unique<MappedScalarCoefficient>(*fem.mapping, *ctx.rhs_coeff, MappedScalarCoefficient::EVAL_POINTS::REFERENCE);

    //==================================================================
    // Section 14: RHS Linear Form
    //==================================================================
    ctx.b = std::make_unique<mfem::ParLinearForm>(fem.H1_fes.get());

    {
        auto* grav_rhs_integrator = new mfem::DomainLFIntegrator(*ctx.mapped_rhs_coeff);
        grav_rhs_integrator->SetIntRule(fem.int_rule.get());
        ctx.b->AddDomainIntegrator(grav_rhs_integrator, ctx.stellar_mask);
    }


    //=========================================================
    // Section 15: Diagnostic Output
    //=========================================================
    std::println(
        "{}Setup (rank: {}): element_order={}, geom_order={}, int_order={}, dim={}, "
        "r_star={:.3f}, r_inf={:.3f}, HO_ndofs={}, LOR_ndofs={}{}",
        ANSI_BLUE,
        mfem::Mpi::WorldRank(),
        element_order, geom_order, fem.int_order, dim,
        r_star_ref, r_inf_ref,
        fem.H1_fes->GlobalTrueVSize(),       // Global DOF count for HO space
        fem.H1_lor_fes->GlobalTrueVSize(),   // Global DOF count for LOR space
        ANSI_RESET);
    return fem;
}
 
/**
 * @brief Rebuilds the gravity Laplacians and rewires the solver to them.
 *
 * The low-order (LOR) Laplacian and its AMG preconditioner are rebuilt with
 * the mapping displacement temporarily reset; the displacement is then
 * restored (when one was set) and the partially assembled high-order
 * Laplacian is rebuilt and handed to the solver.
 *
 * @param fem Finite element context whose gravity_context is updated in place.
 */
void update_stiffness_matrix(FEM& fem) {
    auto& gravity = fem.gravity_context;

    // Remember the active displacement so it can be reinstated below.
    const mfem::GridFunction* previous_displacement = fem.mapping->GetDisplacement();
    fem.mapping->ResetDisplacement();

    // --- Low-order operator and preconditioner (no displacement applied) ---
    gravity.lor_laplacian = std::make_unique<mfem::ParBilinearForm>(&fem.H1_lor_disc->GetParFESpace());
    auto* lor_diffusion = new mfem::DiffusionIntegrator(*gravity.diff_coeff);
    gravity.lor_laplacian->AddDomainIntegrator(lor_diffusion);
    gravity.lor_laplacian->Assemble();
    gravity.lor_laplacian->Finalize();

    gravity.A_lor.Clear();
    gravity.lor_laplacian->FormSystemMatrix(gravity.lor_ess_tdof_list, gravity.A_lor);

    gravity.amg_prec = std::make_unique<mfem::HypreBoomerAMG>(*gravity.A_lor.As<mfem::HypreParMatrix>());
    gravity.amg_prec->SetPrintLevel(0);

    gravity.amg_wrapper = std::make_unique<LORPrecWrapper>(*gravity.amg_prec);
    gravity.solver->SetPreconditioner(*gravity.amg_wrapper);

    // --- Reinstate the displacement before building the high-order operator ---
    if (previous_displacement != nullptr) {
        fem.mapping->SetDisplacement(*previous_displacement);
    }

    // --- High-order operator (partial assembly) ---
    gravity.ho_laplacian = std::make_unique<mfem::ParBilinearForm>(fem.H1_fes.get());
    auto* ho_diffusion = new mfem::DiffusionIntegrator(*gravity.diff_coeff);
    ho_diffusion->SetIntRule(fem.int_rule.get());
    gravity.ho_laplacian->AddDomainIntegrator(ho_diffusion);
    gravity.ho_laplacian->SetAssemblyLevel(mfem::AssemblyLevel::PARTIAL);
    gravity.ho_laplacian->Assemble();

    gravity.A_ho.Clear();
    gravity.ho_laplacian->FormSystemMatrix(gravity.ho_ess_tdof_list, gravity.A_ho);
    gravity.solver->SetOperator(*gravity.A_ho.Ptr());
}
 
/**
 * @brief Streams a mesh and grid function to a visualization server over a socket.
 *
 * Does nothing if the socket cannot be opened.
 *
 * @param host  Server host name.
 * @param port  Server port.
 * @param mesh  Mesh to send.
 * @param gf    Solution field to send.
 * @param title Window title, sent single-quoted.
 */
void view_mesh(const std::string& host, int port, const mfem::Mesh& mesh, const mfem::GridFunction& gf, const std::string& title) {
    mfem::socketstream sol_sock(host.c_str(), port);
    if (!sol_sock.is_open()) return;
    sol_sock << "solution\n" << mesh << gf;
    // The title must be closed with a matching quote; the opening quote alone
    // leaves the window_title command with an unterminated string.
    sol_sock << "window_title '" << title << "'\n" << std::flush;
}
 
/**
 * @brief Integrates a grid function over the selected domain(s), summed over
 *        all MPI ranks of the scalar H1 space's communicator.
 *
 * With a mapping present the integrand is wrapped in a MappedScalarCoefficient
 * and integrated with fem.int_rule; otherwise the raw coefficient is used with
 * the integrator's default rule.
 *
 * @param fem    Finite element context.
 * @param gf     Field to integrate.
 * @param domain Domain selection used to build the element mask.
 * @return The global integral.
 */
double domain_integrate_grid_function(const FEM& fem, const mfem::GridFunction& gf, Domains domain) {
    mfem::LinearForm accumulator(fem.H1_fes.get());
    mfem::GridFunctionCoefficient integrand(&gf);

    mfem::Array<int> domain_mask;
    populate_element_mask(fem, domain, domain_mask);

    double rank_integral = 0.0;
    if (!fem.has_mapping()) {
        accumulator.AddDomainIntegrator(new mfem::DomainLFIntegrator(integrand), domain_mask);
        accumulator.Assemble();
        rank_integral = accumulator.Sum();
    } else {
        // The mapped coefficient only needs to live until Assemble() returns.
        MappedScalarCoefficient mapped_integrand(*fem.mapping, integrand);

        // Ownership of the integrator passes to the linear form.
        auto* mapped_integrator = new mfem::DomainLFIntegrator(mapped_integrand);
        mapped_integrator->SetIntRule(fem.int_rule.get());
        accumulator.AddDomainIntegrator(mapped_integrator, domain_mask);
        accumulator.Assemble();
        rank_integral = accumulator.Sum();
    }

    // Each rank integrated only its local elements; combine them.
    double total_integral = 0.0;
    MPI_Allreduce(&rank_integral, &total_integral, 1, MPI_DOUBLE, MPI_SUM, fem.H1_fes->GetComm());
    return total_integral;
}
 
/**
 * @brief Computes the centre of mass of a density field in physical coordinates.
 *
 * Elements with attribute 3 are skipped. Quadrature uses fem.int_rule; with a
 * mapping present the weight is scaled by the mapping's det J and positions
 * are taken from the mapping. Mass and first moments are summed over all ranks.
 *
 * @param fem Finite element context.
 * @param rho Density field.
 * @return The centre of mass, or the zero vector when the total mass is not
 *         greater than 1e-18.
 */
mfem::Vector get_com(const FEM& fem, const mfem::GridFunction &rho) {
    const int dim = fem.mesh->Dimension();

    double rank_mass = 0.0;
    mfem::Vector rank_moment(dim);
    rank_moment = 0.0;

    const int n_elements = fem.H1_fes->GetNE();
    for (int elem = 0; elem < n_elements; ++elem) {
        if (fem.mesh->GetAttribute(elem) == 3) {
            continue;
        }

        mfem::ElementTransformation *elem_trans = fem.H1_fes->GetElementTransformation(elem);
        const mfem::IntegrationRule &rule = *fem.int_rule;
        const int n_points = rule.GetNPoints();

        for (int q = 0; q < n_points; ++q) {
            const mfem::IntegrationPoint &point = rule.IntPoint(q);
            elem_trans->SetIntPoint(&point);

            // Volume element on the mesh, optionally scaled by the mapping.
            double dV = elem_trans->Weight() * point.weight;
            if (fem.has_mapping()) {
                dV *= fem.mapping->ComputeDetJ(*elem_trans, point);
            }
            const double density = rho.GetValue(elem, point);

            mfem::Vector position(dim);
            if (fem.has_mapping()) {
                fem.mapping->GetPhysicalPoint(*elem_trans, point, position);
            } else {
                elem_trans->Transform(point, position);
            }

            const double dm = density * dV;
            rank_mass += dm;

            for (int axis = 0; axis < dim; ++axis) {
                rank_moment(axis) += position(axis) * dm;
            }
        }
    }

    MPI_Comm comm = fem.H1_fes->GetComm();

    double total_mass = 0.0;
    MPI_Allreduce(&rank_mass, &total_mass, 1, MPI_DOUBLE, MPI_SUM, comm);

    mfem::Vector centre(dim);
    MPI_Allreduce(rank_moment.GetData(), centre.GetData(), dim, MPI_DOUBLE, MPI_SUM, comm);

    // Guard against dividing by a vanishing (or non-positive) total mass.
    if (total_mass > 1e-18) {
        centre /= total_mass;
    } else {
        centre = 0.0;
    }

    return centre;
}
 
/// @brief Writes reference_pos + displacement into physical_pos.
///
/// @param reference_pos Positions on the undeformed (reference) configuration.
/// @param displacement  Displacement field added to the reference positions.
/// @param physical_pos  Output; receives the entry-wise sum of the two inputs.
void get_physical_coordinates(const mfem::GridFunction& reference_pos,
                              const mfem::GridFunction& displacement,
                              mfem::GridFunction& physical_pos) {
    mfem::add(reference_pos, displacement, physical_pos);
}
 
void populate_element_mask(const FEM &fem, Domains domain, mfem::Array<int> &mask) {
    mask.SetSize(fem.mesh->attributes.Max());
    mask = 0;
 
    if ((domain & Domains::CORE) == Domains::CORE) {
        mask[0] = 1;
    }
 
    if ((domain & Domains::ENVELOPE) == Domains::ENVELOPE) {
        mask[1] = 1;
    }
 
    if ((domain & Domains::VACUUM) == Domains::VACUUM) {
        mask[2] = 1;
    }
}
 
/// @brief Finds the radial extent of all elements carrying the vacuum attribute.
///
/// The radius of a vertex is its Euclidean distance from the origin, computed
/// from its first three coordinates. Minimum, maximum and the "vacuum seen"
/// flag are reduced over the mesh communicator (MPI_COMM_WORLD when the mesh
/// is not an mfem::ParMesh).
///
/// @param mesh        Mesh to scan.
/// @param vacuum_attr Element attribute identifying vacuum elements.
/// @return Bounds(min radius, max radius), or CANNOT_FIND_VACUUM when no rank
///         owns an element with that attribute.
std::expected<Bounds, BoundsError> DiscoverBounds(const mfem::Mesh *mesh, const int vacuum_attr) {
    double r_lo = std::numeric_limits<double>::max();
    double r_hi = std::numeric_limits<double>::lowest();
    int seen_locally = 0;

    const int n_elems = mesh->GetNE();
    for (int elem = 0; elem < n_elems; ++elem) {
        if (mesh->GetAttribute(elem) != vacuum_attr) {
            continue;
        }
        seen_locally = 1;

        mfem::Array<int> vertex_ids;
        mesh->GetElementVertices(elem, vertex_ids);
        for (int k = 0; k < vertex_ids.Size(); ++k) {
            const double* xyz = mesh->GetVertex(vertex_ids[k]);
            const double radius = std::sqrt(xyz[0]*xyz[0] + xyz[1]*xyz[1] + xyz[2]*xyz[2]);
            r_lo = std::min(r_lo, radius);
            r_hi = std::max(r_hi, radius);
        }
    }

    // Parallel meshes reduce over their own communicator; anything else over the world.
    MPI_Comm comm = MPI_COMM_WORLD;
    const auto* par_mesh = dynamic_cast<const mfem::ParMesh*>(mesh);
    if (par_mesh != nullptr) {
        comm = par_mesh->GetComm();
    }

    double r_min_all, r_max_all;
    int seen_anywhere;
    MPI_Allreduce(&r_lo, &r_min_all, 1, MPI_DOUBLE, MPI_MIN, comm);
    MPI_Allreduce(&r_hi, &r_max_all, 1, MPI_DOUBLE, MPI_MAX, comm);
    MPI_Allreduce(&seen_locally, &seen_anywhere, 1, MPI_INT, MPI_MAX, comm);

    if (!seen_anywhere) {
        return std::unexpected(CANNOT_FIND_VACUUM);
    }
    return Bounds(r_min_all, r_max_all);
}
 
/// @brief Evaluates a scalar grid function at a single 3D point, collectively over all ranks.
///
/// The point is located with fem.mesh->FindPoints. FindPoints always sizes
/// elem_ids to the number of query points and reports "not found here" with a
/// negative element id, so the local hit test must look at the id itself and
/// not at the array size.
///
/// Ranks that do not hold the point contribute nothing. The value and a hit
/// counter are summed over the communicator and the value is divided by the
/// number of hits, so the result is correct for negative values (a max
/// against a 0.0 placeholder is not) and also when more than one rank
/// reports the point.
///
/// NOTE(review): the search is performed in the coordinates of fem.mesh; if a
/// non-identity mapping is active the caller presumably has to pass the
/// corresponding mesh-space point — confirm against callers.
///
/// @param fem    Finite element context (mesh and H1 space communicator).
/// @param u      Field to evaluate.
/// @param x_phys Query point; its first three components are used.
/// @return Value of u at the point, or 0.0 when no rank contains the point.
double EvalGridFunctionAtPhysicalPoint(const FEM& fem, const mfem::GridFunction& u, const mfem::Vector& x_phys) {
    mfem::DenseMatrix P(3, 1);
    P(0, 0) = x_phys(0);
    P(1, 0) = x_phys(1);
    P(2, 0) = x_phys(2);

    mfem::Array<int> elem_ids;
    mfem::Array<mfem::IntegrationPoint> ips;
    fem.mesh->FindPoints(P, elem_ids, ips);

    const bool found_local = (elem_ids.Size() > 0) && (elem_ids[0] >= 0);

    // local[0]: value contribution, local[1]: 1 if this rank evaluated the point.
    double local[2] = {0.0, 0.0};
    if (found_local) {
        local[0] = u.GetValue(elem_ids[0], ips[0]);
        local[1] = 1.0;
    }

    double global[2] = {0.0, 0.0};
    MPI_Allreduce(local, global, 2, MPI_DOUBLE, MPI_SUM, fem.H1_fes->GetComm());

    if (global[1] > 0.0) {
        return global[0] / global[1];
    }
    return 0.0;
}
 
/// @brief Polynomial order of the mesh geometry.
///
/// @param mesh Mesh to inspect.
/// @return The maximum element order of the nodal finite element space when
///         the mesh has a nodes grid function, otherwise 1.
int get_mesh_order(const mfem::Mesh &mesh) {
    const auto* nodes = mesh.GetNodes();
    return (nodes == nullptr) ? 1 : nodes->FESpace()->GetMaxElementOrder();
}
 
/// @brief Rescales rho in place so its integral over the stellar domain equals target_mass.
///
/// The field is left untouched when the current stellar-domain integral is
/// not larger than 1e-15, which avoids dividing by a vanishing mass.
///
/// @param fem         Finite element context used for the domain integral.
/// @param rho         Density field, scaled in place.
/// @param target_mass Desired integral of rho over Domains::STELLAR.
void conserve_mass(const FEM& fem, mfem::GridFunction& rho, const double target_mass) {
    const double current_mass = domain_integrate_grid_function(fem, rho, Domains::STELLAR);
    if (!(current_mass > 1e-15)) {
        return;
    }
    rho *= (target_mass / current_mass);
}
//endregion
 
//region Physics Functions
/// @brief Centrifugal potential -1/2 * omega^2 * (x^2 + y^2) for rotation about the z axis.
///
/// @param phys_x Position; only components 0 and 1 are read.
/// @param omega  Angular velocity.
/// @return The (non-positive) centrifugal potential at phys_x.
double centrifugal_potential(const mfem::Vector& phys_x, const double omega) {
    const double x_sq = std::pow(phys_x(0), 2);
    const double y_sq = std::pow(phys_x(1), 2);
    const double cyl_radius_sq = x_sq + y_sq;
    const double omega_sq = std::pow(omega, 2);
    return -0.5 * cyl_radius_sq * omega_sq;
}
 
/// @brief Moment of inertia about the z axis, integral of rho * (x^2 + y^2) over the stellar domain.
///
/// The squared cylindrical radius is evaluated at the physical position when
/// a mapping is present (and the integrand is wrapped in a
/// MappedScalarCoefficient), otherwise at the mesh position.
///
/// Compared with the previous implementation this
///  - sums the rank-local result over the H1 space communicator, so every
///    rank receives the global moment instead of its local share,
///  - restricts the integral to Domains::STELLAR, and
///  - integrates with fem.int_rule, as the gravity right-hand side does.
///
/// @param fem Finite element context (H1 space, mapping, quadrature rule).
/// @param rho Density field.
/// @return Global moment of inertia about the z axis.
double get_moment_of_inertia(const FEM& fem, const mfem::GridFunction& rho) {
    auto s2_func = [](const mfem::Vector& x) {
        return x(0) * x(0) + x(1) * x(1);
    };

    std::unique_ptr<mfem::Coefficient> s2_coeff;
    if (fem.has_mapping()) {
        s2_coeff = std::make_unique<PhysicalPositionFunctionCoefficient>(*fem.mapping, s2_func);
    } else {
        s2_coeff = std::make_unique<mfem::FunctionCoefficient>(s2_func);
    }

    mfem::GridFunctionCoefficient rho_coeff(&rho);
    mfem::ProductCoefficient I_integrand(rho_coeff, *s2_coeff);

    // With a mapping the integrand is wrapped once; both cases then share one assembly path.
    std::unique_ptr<mfem::Coefficient> mapped_integrand;
    mfem::Coefficient* lf_coeff = &I_integrand;
    if (fem.has_mapping()) {
        mapped_integrand = std::make_unique<MappedScalarCoefficient>(*fem.mapping, I_integrand);
        lf_coeff = mapped_integrand.get();
    }

    // The marker array must stay alive until Assemble() has run.
    mfem::Array<int> stellar_mask;
    populate_element_mask(fem, Domains::STELLAR, stellar_mask);

    mfem::LinearForm I_lf(fem.H1_fes.get());
    auto* integrator = new mfem::DomainLFIntegrator(*lf_coeff); // ownership passes to I_lf
    integrator->SetIntRule(fem.int_rule.get());
    I_lf.AddDomainIntegrator(integrator, stellar_mask);
    I_lf.Assemble();

    double local_I = I_lf.Sum();
    double I = 0.0;
    MPI_Allreduce(&local_I, &I, 1, MPI_DOUBLE, MPI_SUM, fem.H1_fes->GetComm());

    return I;
}
//endregion
 
//region potentials
/// @brief Solves the gravity problem for an analytically prescribed density coefficient.
///
/// Steps, all operating on fem.gravity_context:
///  1. zero the stored potential,
///  2. project l2_multipole_potential(fem, exact_total_mass, x) onto the
///     boundary attributes in fem.boundary_context.inf_bounds,
///  3. assemble a right-hand side from four_pi_G_coeff * exact_rho_coeff over
///     the stellar mask using fem.int_rule,
///  4. form and solve the linear system with the context's Laplacian and
///     solver, then recover the finite element solution.
///
/// @param fem              Finite element context; its gravity context is modified.
/// @param args             Not read by this function.
/// @param exact_rho_coeff  Density coefficient used for the right-hand side.
/// @param exact_total_mass Mass used in the boundary potential.
/// @return Reference to the potential stored in the gravity context.
mfem::GridFunction& test_analytic_grav_potential(FEM& fem, const Args& args, mfem::Coefficient& exact_rho_coeff, double exact_total_mass) {
    auto& grav = fem.gravity_context;
    auto& phi = *grav.phi;
    phi = 0.0;

    // Boundary values from the monopole + quadrupole expansion.
    auto far_field = [&fem, exact_total_mass](const mfem::Vector& x) {
        return l2_multipole_potential(fem, exact_total_mass, x);
    };
    PhysicalPositionFunctionCoefficient far_field_coeff(*fem.mapping, far_field);
    phi.ProjectBdrCoefficient(far_field_coeff, fem.boundary_context.inf_bounds);

    // Right-hand side: 4*pi*G*rho restricted to the stellar elements.
    mfem::ProductCoefficient source_coeff(*grav.four_pi_G_coeff, exact_rho_coeff);
    MappedScalarCoefficient mapped_source(*fem.mapping, source_coeff, MappedScalarCoefficient::EVAL_POINTS::REFERENCE);

    mfem::ParLinearForm rhs(fem.H1_fes.get());
    auto* source_integrator = new mfem::DomainLFIntegrator(mapped_source);
    source_integrator->SetIntRule(fem.int_rule.get());
    rhs.AddDomainIntegrator(source_integrator, grav.stellar_mask);
    rhs.Assemble();

    grav.ho_laplacian->FormLinearSystem(
        grav.ho_ess_tdof_list,
        phi,
        rhs,
        grav.A_ho,
        grav.X_true,
        grav.B_true
    );

    grav.solver->SetOperator(*grav.A_ho.Ptr());
    grav.solver->Mult(grav.B_true, grav.X_true);
    grav.ho_laplacian->RecoverFEMSolution(grav.X_true, rhs, phi);

    return phi;
}
 
/// @brief Solves for the gravitational potential of a density field.
///
/// The stellar mass (integral of rho over Domains::STELLAR) feeds
/// l2_multipole_potential, which is projected onto the boundary attributes in
/// fem.boundary_context.inf_bounds. The context's right-hand side is
/// re-assembled with rho as its density, the linear system is formed and
/// solved, and the finite element solution is recovered into the stored
/// potential.
///
/// @param fem      Finite element context; its gravity context is modified.
/// @param args     Not read by this function.
/// @param rho      Density field.
/// @param phi_warm When false the stored potential is zeroed first; when true
///                 its current values are kept as the starting state.
/// @return Reference to the potential stored in the gravity context.
const mfem::GridFunction &grav_potential(FEM &fem, const Args &args, const mfem::GridFunction &rho, const bool phi_warm) {
    auto& grav = fem.gravity_context;
    auto& phi = *grav.phi;

    if (not phi_warm) {
        phi = 0.0;
    }

    const double stellar_mass = domain_integrate_grid_function(fem, rho, Domains::STELLAR);

    // Boundary values from the monopole + quadrupole expansion.
    auto far_field = [&fem, stellar_mass](const mfem::Vector& x) {
        return l2_multipole_potential(fem, stellar_mass, x);
    };
    PhysicalPositionFunctionCoefficient far_field_coeff(*fem.mapping, far_field);
    phi.ProjectBdrCoefficient(far_field_coeff, fem.boundary_context.inf_bounds);

    // Point the pre-built right-hand side at the current density and re-assemble it.
    grav.rho_coeff->SetGridFunction(&rho);
    grav.b->Assemble();

    grav.ho_laplacian->FormLinearSystem(
        grav.ho_ess_tdof_list, phi, *grav.b, grav.A_ho, grav.X_true, grav.B_true
    );

    grav.solver->SetOperator(*grav.A_ho.Ptr());
    grav.solver->Mult(grav.B_true, grav.X_true);
    grav.ho_laplacian->RecoverFEMSolution(grav.X_true, *grav.b, phi);

    return phi;
}
 
/// @brief Total potential: gravitational potential plus, if enabled, the centrifugal term.
///
/// The gravitational potential is copied out of the gravity context so that
/// the rotation term can be added without touching the stored field. When
/// args.r.enabled is set, centrifugal_potential(x, args.r.omega) is projected
/// onto the H1 space (at physical positions when a mapping is present) and
/// added.
///
/// @param fem  Finite element context.
/// @param args Run arguments; args.r controls rotation.
/// @param rho  Density field.
/// @return A copy of the potential including the optional centrifugal part.
mfem::GridFunction get_potential(FEM &fem, const Args &args, const mfem::GridFunction &rho) {
    mfem::GridFunction total_phi(grav_potential(fem, args, rho));

    if (!args.r.enabled) {
        return total_phi;
    }

    auto rotation_term = [&args](const mfem::Vector& x) {
        return centrifugal_potential(x, args.r.omega);
    };

    std::unique_ptr<mfem::Coefficient> rotation_coeff;
    if (fem.has_mapping()) {
        rotation_coeff = std::make_unique<PhysicalPositionFunctionCoefficient>(*fem.mapping, rotation_term);
    } else {
        rotation_coeff = std::make_unique<mfem::FunctionCoefficient>(rotation_term);
    }

    mfem::GridFunction rotation_gf(fem.H1_fes.get());
    rotation_gf.ProjectCoefficient(*rotation_coeff);
    total_phi += rotation_gf;

    return total_phi;
}
 
/// @brief Quadrupole moment tensor Q_mn = integral of rho * (3 x'_m x'_n - delta_mn |x'|^2),
///        with x' = x - com, over all elements except those with attribute 3.
///
/// Quadrature uses fem.int_rule. The volume element is the element
/// transformation weight times the quadrature weight, scaled by the mapping's
/// det(J) when a mapping is present; positions come from the mapping when
/// present and from the element transformation otherwise. Rank-local tensors
/// are summed with MPI_Allreduce.
///
/// @param fem Finite element context.
/// @param rho Density field.
/// @param com Point the moments are taken about.
/// @return dim x dim tensor, identical on every rank.
mfem::DenseMatrix compute_quadrupole_moment_tensor(const FEM& fem, const mfem::GridFunction& rho, const mfem::Vector& com) {
    const int dim = fem.mesh->Dimension();
    const bool mapped = fem.has_mapping();

    mfem::DenseMatrix rank_Q(dim, dim);
    rank_Q = 0.0;

    const int n_elems = fem.H1_fes->GetNE();
    for (int elem = 0; elem < n_elems; ++elem) {
        // Attribute 3 elements do not contribute.
        if (fem.mesh->GetAttribute(elem) == 3) {
            continue;
        }

        mfem::ElementTransformation *tr = fem.mesh->GetElementTransformation(elem);
        const mfem::IntegrationRule &rule = *fem.int_rule;
        const int n_points = rule.GetNPoints();

        for (int q = 0; q < n_points; ++q) {
            const mfem::IntegrationPoint &ip = rule.IntPoint(q);
            tr->SetIntPoint(&ip);

            double w = tr->Weight() * ip.weight;
            if (mapped) {
                w *= fem.mapping->ComputeDetJ(*tr, ip);
            }

            const double density = rho.GetValue(elem, ip);

            mfem::Vector pos(dim);
            if (mapped) {
                fem.mapping->GetPhysicalPoint(*tr, ip, pos);
            } else {
                tr->Transform(ip, pos);
            }

            // Offset from the expansion centre and its squared length.
            mfem::Vector offset(dim);
            double dist_sq = 0.0;
            for (int axis = 0; axis < dim; ++axis) {
                offset(axis) = pos(axis) - com(axis);
                dist_sq += offset(axis) * offset(axis);
            }

            for (int row = 0; row < dim; ++row) {
                for (int col = 0; col < dim; ++col) {
                    const double kronecker = (row == col) ? 1.0 : 0.0;
                    const double kernel = 3.0 * offset(row) * offset(col) - kronecker * dist_sq;
                    rank_Q(row, col) += density * kernel * w;
                }
            }
        }
    }

    mfem::DenseMatrix total_Q(dim, dim);
    MPI_Allreduce(rank_Q.GetData(), total_Q.GetData(), dim * dim, MPI_DOUBLE, MPI_SUM, fem.H1_fes->GetComm());
    return total_Q;
}
/// @brief Exterior potential from the monopole (l=0) and quadrupole (l=2) terms.
///
/// Phi(x) = -G*M/r - G/(2 r^3) * n^T Q n, with r = |x|, n = x / r and
/// Q = fem.Q. The dipole term is omitted.
///
/// NOTE(review): r and n are measured from the coordinate origin, whereas the
/// callers in this file compute fem.Q about fem.com; the two agree only when
/// the centre of mass sits at the origin — confirm this is intended.
///
/// @param fem        Finite element context providing the mesh dimension and fem.Q.
/// @param total_mass Mass used for the monopole term.
/// @param phys_x     Evaluation point.
/// @return The potential at phys_x, or 0.0 when |phys_x| < 1e-12.
double l2_multipole_potential(const FEM &fem, const double total_mass, const mfem::Vector &phys_x) {
    const double radius = phys_x.Norml2();
    if (radius < 1e-12) {
        return 0.0;
    }

    const int dim = fem.mesh->Dimension();

    // Unit direction towards the evaluation point.
    mfem::Vector direction(phys_x);
    direction /= radius;

    // Contraction n^T Q n.
    double nQn = 0.0;
    for (int row = 0; row < dim; ++row) {
        for (int col = 0; col < dim; ++col) {
            nQn += fem.Q(row, col) * direction(row) * direction(col);
        }
    }

    const double quadrupole_term = - (G / (2.0 * std::pow(radius, 3))) * nQn;
    const double monopole_term = -G * total_mass / radius;

    return monopole_term + quadrupole_term;
}
//endregion
 
//region Tests
/// @brief Sanity checks on a freshly loaded FEM context.
///
/// Six checks are counted: fem.okay(), mesh dimension == 3, vector space true
/// size == dim * scalar true size, and the three block offsets being
/// 0, n_scalar and n_scalar + n_vector. The summary is printed through
/// RANK_GUARD; afterwards the dimension, size and offset conditions are also
/// enforced with assert.
///
/// @param fem Finite element context to check.
void test_mesh_load(const FEM& fem) {
    size_t failed = 0;
    const auto require = [&failed](const bool condition) {
        if (!condition) {
            ++failed;
        }
    };

    require(fem.okay());

    const int dim = fem.mesh->Dimension();
    require(dim == 3);

    const int n_scalar = fem.H1_fes->GetTrueVSize();
    const int n_vector = fem.Vec_H1_fes->GetTrueVSize();
    require(n_vector == dim * n_scalar);

    require(fem.block_true_offsets[0] == 0);
    require(fem.block_true_offsets[1] == n_scalar);
    require(fem.block_true_offsets[2] == n_scalar + n_vector);

    constexpr size_t num_tests = 6;
    const auto result_type = (failed == 0)
        ? TEST_RESULT_TYPE::SUCCESS
        : ((failed < num_tests) ? TEST_RESULT_TYPE::PARTIAL : TEST_RESULT_TYPE::FAILURE);

    RANK_GUARD(
        std::string test_msg = fmt_test_msg("Mesh Load Test", result_type, failed, num_tests);
        std::println("{}", test_msg);
    )

    assert(dim == 3);
    assert(n_vector == (n_scalar * dim));
    assert(fem.block_true_offsets[0] == 0);
    assert(fem.block_true_offsets[1] == n_scalar);
    assert(fem.block_true_offsets[2] == n_scalar + n_vector);
}
 
/// @brief Checks that an undisplaced mapping reproduces the mesh coordinates.
///
/// One check verifies fem.mapping->IsIdentity(). Then, for up to the first 30
/// local elements, the first point of an order-2 integration rule is
/// transformed both by the element transformation and by the mapping; the two
/// positions must agree to within 1e-12 (Euclidean norm).
///
/// @param fem Finite element context whose mapping is inspected.
void test_ref_coord_storage(const FEM& fem) {
    size_t failed = fem.mapping->IsIdentity() ? 0 : 1;

    const int elem_count = std::min(30, fem.mesh->GetNE());
    for (int elem = 0; elem < elem_count; ++elem) {
        auto* tr = fem.mesh->GetElementTransformation(elem);
        const auto& rule = mfem::IntRules.Get(tr->GetGeometryType(), 2);
        const auto& ip = rule.IntPoint(0);
        tr->SetIntPoint(&ip);

        mfem::Vector mesh_pos, mapped_pos;
        tr->Transform(ip, mesh_pos);
        fem.mapping->GetPhysicalPoint(*tr, ip, mapped_pos);

        // Reuse mesh_pos as the difference vector.
        mesh_pos -= mapped_pos;
        const bool matches = mesh_pos.Norml2() < 1e-12;
        if (!matches) {
            ++failed;
        }
    }

    const size_t num_tests = static_cast<size_t>(elem_count) + 1;
    const auto result_type = (failed == 0)
        ? TEST_RESULT_TYPE::SUCCESS
        : ((failed < num_tests) ? TEST_RESULT_TYPE::PARTIAL : TEST_RESULT_TYPE::FAILURE);

    RANK_GUARD(
        std::string test_msg = fmt_test_msg("Mesh Ref Coord", result_type, failed, num_tests);
        std::println("{}", test_msg);
    )
}
 
/// @brief Checks that integrating 1 over the stellar domain gives the volume of a sphere of radius RADIUS.
///
/// The test passes when the relative error against (4/3)*pi*RADIUS^3 is at
/// most 1e-2.
///
/// @param fem Finite element context used for the integral.
void test_reference_volume_integral(const FEM& fem) {
    mfem::GridFunction unity(fem.H1_fes.get());
    unity = 1.0;

    double vol = domain_integrate_grid_function(fem, unity, Domains::STELLAR);
    double expected = (4.0/3.0) * M_PI * std::pow(RADIUS, 3.0);
    double rel_err = std::abs(vol - expected) / expected;

    constexpr size_t num_tests = 1;
    const size_t failed = (rel_err > 1e-2) ? 1 : 0;
    const auto result_type = (failed == 0) ? TEST_RESULT_TYPE::SUCCESS : TEST_RESULT_TYPE::FAILURE;

    RANK_GUARD(
        std::println("{}", fmt_test_msg("Reference Volume Integral", result_type, failed, num_tests));

        if (result_type == TEST_RESULT_TYPE::FAILURE) {
            std::println("\tFAILURE: Volume: {}, Expected: {}, Error (rel): {}", vol, expected, rel_err);
        }
    )
}
 
/// @brief Checks that a uniform density has its centre of mass at the origin.
///
/// One check per spatial dimension: |com(d)| must not exceed 1e-12.
///
/// @param fem Finite element context.
void test_spherically_symmetric_com(const FEM& fem) {
    mfem::GridFunction uniform_rho(fem.H1_fes.get());
    uniform_rho = 1.0;

    mfem::Vector com = get_com(fem, uniform_rho);

    const int dim = fem.mesh->Dimension();
    const size_t num_tests = static_cast<size_t>(dim);

    size_t failed = 0;
    for (int axis = 0; axis < dim; ++axis) {
        const bool off_centre = std::abs(com(axis)) > 1e-12;
        if (off_centre) {
            ++failed;
        }
    }

    const auto result_type = (failed == 0)
        ? TEST_RESULT_TYPE::SUCCESS
        : ((failed < num_tests) ? TEST_RESULT_TYPE::PARTIAL : TEST_RESULT_TYPE::FAILURE);

    RANK_GUARD(
        std::println("{}", fmt_test_msg("Uniform COM", result_type, failed, num_tests));

        if (result_type == TEST_RESULT_TYPE::FAILURE) {
            std::println("\t COM=<{:+0.3E}, {:+0.3E}, {:+0.3E}>", com(0), com(1), com(2));
        }
    )
}
 
/// @brief Checks that a rigid translation of the domain moves the centre of mass with it.
///
/// Every component of the displacement field is set to 10, i.e. the domain is
/// shifted by 10 units along each axis. For a uniform density the centre of
/// mass must then be 10 in every component. The comparison is two-sided
/// (|com(d) - 10| <= 1e-12): the previous form `10 - |com(d)| > tol` accepted
/// any component with magnitude above 10 and ignored the sign.
///
/// The displacement is reset before returning.
///
/// @param fem Finite element context; its mapping is temporarily displaced.
void test_com_variance_to_displacement(const FEM& fem) {
    constexpr double shift = 10.0;
    constexpr double tol = 1e-12;

    size_t failed = 0;
    mfem::GridFunction linear_displacement(fem.Vec_H1_fes.get());
    linear_displacement = shift; // rigid translation by `shift` units along all axes

    fem.mapping->SetDisplacement(linear_displacement);

    mfem::GridFunction rho(fem.H1_fes.get());
    rho = 1.0;

    mfem::Vector mapped_com = get_com(fem, rho);

    const int dim = fem.mesh->Dimension();
    const size_t num_tests = static_cast<size_t>(dim);
    for (int dimID = 0; dimID < dim; ++dimID) {
        if (std::abs(mapped_com(dimID) - shift) > tol) ++failed;
    }

    auto result_type = TEST_RESULT_TYPE::FAILURE;
    if (failed == 0) {
        result_type = TEST_RESULT_TYPE::SUCCESS;
    } else if (failed < num_tests) {
        result_type = TEST_RESULT_TYPE::PARTIAL;
    }

    RANK_GUARD(
        std::println("{}", fmt_test_msg("COM variance to displacement", result_type, failed, num_tests));

        if (result_type == TEST_RESULT_TYPE::FAILURE) {
            std::println("\tFAILURE COM=<{:+0.2E}, {:+0.2E}, {:+0.2E}>", mapped_com(0), mapped_com(1), mapped_com(2));
        }
    )

    fem.mapping->ResetDisplacement();
}
 
/// @brief Checks that a rigid translation leaves the stellar volume unchanged.
///
/// Every component of the displacement is set to 10; the integral of 1 over
/// the stellar domain must still match (4/3)*pi*RADIUS^3 to a relative error
/// of at most 1e-2. The displacement is reset before returning.
///
/// @param fem Finite element context; its mapping is temporarily displaced.
void test_volume_invariance_to_displacement(const FEM& fem) {
    mfem::GridFunction translation(fem.Vec_H1_fes.get());
    translation = 10.0; // shifts the whole domain by 10 units along every axis
    fem.mapping->SetDisplacement(translation);

    mfem::GridFunction unity(fem.H1_fes.get());
    unity = 1.0;

    double mapped_vol = domain_integrate_grid_function(fem, unity, Domains::STELLAR);
    double expected = (4.0/3.0) * M_PI * std::pow(RADIUS, 3.0);
    double rel_err = std::abs(mapped_vol - expected) / expected;

    constexpr size_t num_tests = 1;
    const size_t failed = (rel_err > 1e-2) ? 1 : 0;
    const auto result_type = (failed == 0) ? TEST_RESULT_TYPE::SUCCESS : TEST_RESULT_TYPE::FAILURE;

    RANK_GUARD(
        std::println("{}", fmt_test_msg("Invariance of volume against translation", result_type, failed, num_tests));

        if (result_type == TEST_RESULT_TYPE::FAILURE) {
            std::println("\tFAILURE: Volume: {}, Expected: {}", mapped_vol, expected);
        }
    )

    fem.mapping->ResetDisplacement();
}
 
/// @brief Checks integrals and det(J) under an affine stretch of the reference sphere into an ellipsoid.
///
/// The displacement d = ((a-1)x, (b-1)y, (c-1)z) scales the axes by a, b, c.
/// For a reference sphere of radius RADIUS this gives
///  - Test 1: stellar volume            = (4/3)*pi*a*b*c*RADIUS^3,
///  - Test 2: integral of x_phys^2      = (4*pi/15)*a^3*b*c*RADIUS^5,
///  - Test 3: pointwise det(J)          = a*b*c.
/// Tests 1 and 2 allow a relative error of 1e-3, test 3 an absolute error of
/// 1e-10 on up to the first five non-vacuum (attribute != 3) elements.
///
/// Changes from the previous version: the det(J) loop no longer decrements
/// its own counter on vacuum elements (which never terminated once a vacuum
/// element was met), the expected values carry the RADIUS scaling (identical
/// for RADIUS == 1), and an unused projected field was removed.
///
/// The displacement is reset before returning.
///
/// @param fem Finite element context; its mapping is temporarily displaced.
void test_volume_ellipsoid_deformation(const FEM& fem) {
    size_t failed = 0;
    size_t num_tests = 0;

    constexpr double a = 2.0; // x-axis scale factor
    constexpr double b = 0.5; // y-axis scale factor
    constexpr double c = 1.5; // z-axis scale factor
    constexpr double ref_radius_cubed = RADIUS * RADIUS * RADIUS;
    constexpr double expected_vol = (4.0 / 3.0) * M_PI * a * b * c * ref_radius_cubed;

    mfem::GridFunction ellipsoid_displacement(fem.Vec_H1_fes.get());
    {
        const int dim = fem.mesh->Dimension();
        mfem::VectorFunctionCoefficient disp_coeff(dim, [&](const mfem::Vector& x, mfem::Vector& d) {
            d.SetSize(x.Size());
            d(0) = (a - 1.0) * x(0);
            d(1) = (b - 1.0) * x(1);
            d(2) = (c - 1.0) * x(2);
        });
        ellipsoid_displacement.ProjectCoefficient(disp_coeff);
    }
    fem.mapping->SetDisplacement(ellipsoid_displacement);

    // Test 1: mapped volume.
    {
        ++num_tests;
        mfem::GridFunction ones(fem.H1_fes.get());
        ones = 1.0;
        const double mapped_vol = domain_integrate_grid_function(fem, ones, Domains::STELLAR);
        const double rel_err = std::abs(mapped_vol - expected_vol) / expected_vol;
        if (rel_err > 1e-3) {
            ++failed;
            RANK_GUARD(
                std::println("\tFAILURE Test 1: Mapped volume = {:.6f}, expected = {:.6f}, rel_err = {:.2e}",
                             mapped_vol, expected_vol, rel_err);
            )
        }
    }

    // Test 2: second moment of the physical x coordinate.
    {
        ++num_tests;
        constexpr double expected_x2_integral =
            a * a * a * b * c * (4.0 * M_PI / 15.0) * ref_radius_cubed * RADIUS * RADIUS;

        mfem::GridFunction x_phys_sq(fem.H1_fes.get());
        PhysicalPositionFunctionCoefficient x_phys_sq_coeff(*fem.mapping,
            [](const mfem::Vector& x_phys) {
                return x_phys(0) * x_phys(0);
            }
        );
        x_phys_sq.ProjectCoefficient(x_phys_sq_coeff);

        const double mapped_x2_integral = domain_integrate_grid_function(fem, x_phys_sq, Domains::STELLAR);
        if (const double rel_err = std::abs(mapped_x2_integral - expected_x2_integral) / expected_x2_integral; rel_err > 1e-3) {
            ++failed;
            RANK_GUARD(std::println("\tFAILURE Test 2: integral x_phys^2 = {:.6f}, expected = {:.6f}, rel_err = {:.2e}",
                         mapped_x2_integral, expected_x2_integral, rel_err);)
        }
    }

    // Test 3: pointwise Jacobian determinant on a handful of non-vacuum elements.
    {
        ++num_tests;
        constexpr double expected_detJ = a * b * c;
        constexpr int max_checked_elems = 5;

        double max_detJ_err = 0.0;
        int checked_elems = 0;
        const int n_elems = fem.mesh->GetNE();
        for (int e = 0; e < n_elems && checked_elems < max_checked_elems; ++e) {
            if (fem.mesh->GetAttribute(e) == 3) { // vacuum elements are ignored for this test
                continue;
            }
            ++checked_elems;

            auto* trans = fem.mesh->GetElementTransformation(e);
            const auto& ir = *fem.int_rule;
            for (int q = 0; q < ir.GetNPoints(); ++q) {
                const auto& ip = ir.IntPoint(q);
                trans->SetIntPoint(&ip);
                const double detJ = fem.mapping->ComputeDetJ(*trans, ip);
                max_detJ_err = std::max(max_detJ_err, std::abs(detJ - expected_detJ));
            }
        }
        if (max_detJ_err > 1e-10) {
            ++failed;
            RANK_GUARD(std::println("\tFAILURE Test 3: max pointwise |det(J) - a*b*c| = {:.2e}", max_detJ_err);)
        }
    }

    auto result_type = TEST_RESULT_TYPE::FAILURE;
    if (failed == 0) {
        result_type = TEST_RESULT_TYPE::SUCCESS;
    } else if (failed < num_tests) {
        result_type = TEST_RESULT_TYPE::PARTIAL;
    }

    RANK_GUARD(std::println("{}", fmt_test_msg("Volume under ellipsoidal deformation", result_type, failed, num_tests));)

    fem.mapping->ResetDisplacement();
}
 
/// @brief Compares the computed potential of a uniform sphere with the analytic solution.
///
/// A uniform density MASS / ((4/3)*pi*RADIUS^3) is used on an undisplaced
/// mapping. For up to the first 30 local elements the potential is sampled at
/// the first point of an order-2 rule and compared with
///  - interior (r <= RADIUS): -G*MASS/(2*RADIUS^3) * (3*RADIUS^2 - r^2),
///  - exterior (r >  RADIUS): -G*MASS/r.
/// The exterior branch matters because the sampled elements are not filtered
/// by domain; applying the interior formula there reports spurious failures.
/// Points with r < 1e-9 are skipped and not counted as executed tests.
///
/// A sample fails when its relative error exceeds
/// APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING. Error maxima and
/// counters are reduced over the H1 space communicator before reporting.
///
/// @param fem  Finite element context; displacement is reset, com and Q are overwritten.
/// @param args Forwarded to grav_potential.
void test_uniform_potential(FEM& fem, const Args& args) {
    fem.mapping->ResetDisplacement();
    update_stiffness_matrix(fem);

    const double analytic_vol = (4.0/3.0) * M_PI * std::pow(RADIUS, 3);
    const double rho0 = MASS / analytic_vol;

    mfem::GridFunction rho_uniform(fem.H1_fes.get());
    rho_uniform = rho0;

    fem.com = get_com(fem, rho_uniform);
    fem.Q = compute_quadrupole_moment_tensor(fem, rho_uniform, fem.com);

    // The returned reference points into the gravity context; no copy is needed for read-only sampling.
    const auto& phi = grav_potential(fem, args, rho_uniform);

    double local_max_abs_err = 0.0;
    double local_max_rel_err = 0.0;
    constexpr double tol = APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING;

    size_t local_failed = 0;
    size_t local_num_tests = 0;

    const int num_elemIDs = std::min(30, fem.mesh->GetNE());
    for (int elemID = 0; elemID < num_elemIDs; ++elemID) {
        auto* trans = fem.mesh->GetElementTransformation(elemID);
        const auto& ir = mfem::IntRules.Get(trans->GetGeometryType(), 2);
        const auto& ip = ir.IntPoint(0);
        trans->SetIntPoint(&ip);

        mfem::Vector x_phys;
        fem.mapping->GetPhysicalPoint(*trans, ip, x_phys);

        const double r = x_phys.Norml2();
        if (r < 1e-9) continue;
        ++local_num_tests;

        const double phi_analytic = (r <= RADIUS)
            ? (-G * MASS / (2.0 * RADIUS * RADIUS * RADIUS)) * (3.0*RADIUS*RADIUS - r*r)
            : (-G * MASS / r);
        const double phi_fem = phi.GetValue(elemID, ip); // evaluates the rank-local part of the field

        const double abs_err = std::abs(phi_fem - phi_analytic);
        const double rel_err = abs_err / std::abs(phi_analytic);

        local_max_abs_err = std::max(local_max_abs_err, abs_err);
        local_max_rel_err = std::max(local_max_rel_err, rel_err);

        if (rel_err > tol) ++local_failed;
    }

    double global_max_abs_err = 0.0;
    double global_max_rel_err = 0.0;
    long global_failed = 0;
    long global_num_tests = 0;

    MPI_Comm comm = fem.H1_fes->GetComm();

    MPI_Allreduce(&local_max_abs_err, &global_max_abs_err, 1, MPI_DOUBLE, MPI_MAX, comm);
    MPI_Allreduce(&local_max_rel_err, &global_max_rel_err, 1, MPI_DOUBLE, MPI_MAX, comm);

    long l_failed = static_cast<long>(local_failed);
    long l_tests = static_cast<long>(local_num_tests);
    MPI_Allreduce(&l_failed, &global_failed, 1, MPI_LONG, MPI_SUM, comm);
    MPI_Allreduce(&l_tests, &global_num_tests, 1, MPI_LONG, MPI_SUM, comm);

    auto result_type = TEST_RESULT_TYPE::FAILURE;
    if (global_failed == 0) {
        result_type = TEST_RESULT_TYPE::SUCCESS;
    } else if (global_failed < global_num_tests) {
        result_type = TEST_RESULT_TYPE::PARTIAL;
    }

    RANK_GUARD(
        std::println("{}", fmt_test_msg("Test Uniform Potential", result_type, global_failed, global_num_tests));

        if (result_type == TEST_RESULT_TYPE::FAILURE || result_type == TEST_RESULT_TYPE::PARTIAL) {
            std::println("\tFAILURE: global max abs error: {:+0.2E}, global max rel error: {:+0.2E}",
                         global_max_abs_err, global_max_rel_err);
        }
    )
}
 
/// @brief Compares the computed potential of a homogeneous oblate spheroid with the analytic interior solution.
///
/// The reference sphere is stretched to semi-axes a = b = RADIUS and
/// c = 0.99*RADIUS, with uniform density MASS / ((4/3)*pi*a*b*c). With
/// e^2 = 1 - c^2/a^2 the interior potential is
///   Phi = -pi*G*rho0 * (a^2*I - A_R*R^2 - A_z*z^2),
/// using the constants I, A_R and A_z computed below.
///
/// Up to the first 50 local elements are sampled at the first point of an
/// order-2 rule. Because that formula is only valid inside the spheroid,
/// samples with R^2/a^2 + z^2/c^2 >= 1 are skipped (the Ferrers sphere test
/// filters its samples the same way). A sample fails when its relative error
/// exceeds APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING. Error details
/// are printed for any non-successful outcome.
///
/// NOTE(review): the displacement set here is left in place on return;
/// the other FEM&-taking tests visible here reset it themselves on entry.
///
/// @param fem  Finite element context; displacement, stiffness matrix, com and Q are overwritten.
/// @param args Forwarded to grav_potential.
void test_ellipsoidal_potential(FEM& fem, const Args& args) {
    constexpr double a = 1.0 * RADIUS;
    constexpr double b = a; // oblate
    constexpr double c = 0.99 * RADIUS;

    constexpr double expected_vol = (4.0 / 3.0) * M_PI * a * b * c;
    constexpr double rho0 = MASS / expected_vol;

    mfem::GridFunction ellipsoidal_disp(fem.Vec_H1_fes.get());
    mfem::VectorFunctionCoefficient disp_coeff(3, [&](const mfem::Vector& x, mfem::Vector& d) {
        d.SetSize(3);
        d(0) = (a/RADIUS - 1.0) * x(0);
        d(1) = (b/RADIUS - 1.0) * x(1);
        d(2) = (c/RADIUS - 1.0) * x(2);
    });
    ellipsoidal_disp.ProjectCoefficient(disp_coeff);
    fem.mapping->SetDisplacement(ellipsoidal_disp);
    update_stiffness_matrix(fem);

    mfem::GridFunction rho(fem.H1_fes.get());
    rho = rho0;

    fem.com = get_com(fem, rho);
    fem.Q = compute_quadrupole_moment_tensor(fem, rho, fem.com);

    // OblatePotential oblate{.use=true, .a=a, .c=c,.rho_0=rho0};
    // The returned reference points into the gravity context; no copy is needed for read-only sampling.
    const auto& phi = grav_potential(fem, args, rho);

    constexpr double e_sq = 1.0 - (c * c)/(a*a);
    const double e = std::sqrt(e_sq);

    const double I_const = (2.0 * std::sqrt(1.0 - e_sq) / e) * std::asin(e);
    const double A_R = (std::sqrt(1.0-e_sq) / std::pow(e, 3.0)) * std::asin(e) - (1.0 - e_sq)/e_sq;
    const double A_z = (2.0 / e_sq) * (1.0 - (std::sqrt(1.0-e_sq) / e) * std::asin(e));

    size_t failed = 0;
    size_t num_tests = 0;
    double max_rel_err = 0.0;
    double total_err = 0.0;
    const int check_count = std::min(50, fem.mesh->GetNE());

    for (int elemID = 0; elemID < check_count; ++elemID) {
        auto* trans = fem.mesh->GetElementTransformation(elemID);
        const auto& ip = mfem::IntRules.Get(trans->GetGeometryType(), 2).IntPoint(0);
        trans->SetIntPoint(&ip);

        mfem::Vector x_phys;
        fem.mapping->GetPhysicalPoint(*trans, ip, x_phys);

        const double R2 = x_phys(0)*x_phys(0) + x_phys(1)*x_phys(1);
        const double z2 = x_phys(2)*x_phys(2);

        // Interior formula only: skip samples on or outside the spheroid surface.
        const double m2 = R2 / (a * a) + z2 / (c * c);
        if (m2 >= 1.0) continue;

        const double phi_analytic = -M_PI * G * rho0 * (a*a*I_const - A_R * R2 - A_z * z2);

        const double phi_fem = phi.GetValue(elemID, ip);
        const double rel_err = std::abs(phi_fem - phi_analytic) / std::abs(phi_analytic);
        max_rel_err = std::max(max_rel_err, rel_err);
        total_err += rel_err;
        num_tests++;
        if (rel_err > APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING) ++failed;
    }

    auto result_type = TEST_RESULT_TYPE::FAILURE;
    if (failed == 0) {
        result_type = TEST_RESULT_TYPE::SUCCESS;
    } else if (failed < num_tests) {
        result_type = TEST_RESULT_TYPE::PARTIAL;
    }

    RANK_GUARD(
        std::println("{}", fmt_test_msg("Test Ellipsoidal Potential", result_type, failed, num_tests));
        if (result_type == TEST_RESULT_TYPE::FAILURE || result_type == TEST_RESULT_TYPE::PARTIAL) {
            std::println("\tFAILURE: max rel error: {:+0.2E}, mean rel error: {:+0.2E}", max_rel_err, total_err/static_cast<double>(num_tests));
        }
    )
}
 
/// @brief Compares the computed potential of the density rho0*(1 - r^2/R^2) with its analytic interior solution.
///
/// With R = RADIUS and rho0 = 1 on an undisplaced mapping, the interior
/// potential is
///   Phi(r) = 4*pi*G*rho0 * (r^2/6 - r^4/(20*R^2)) - pi*G*rho0*R^2.
/// Up to the first 50 local elements are sampled at the first point of an
/// order-2 rule; only samples with r < R are tested. A sample fails when its
/// relative error exceeds APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING.
///
/// @param fem  Finite element context; displacement is reset, com and Q are overwritten.
/// @param args Forwarded to grav_potential.
void test_ferrers_sphere_potential(FEM& fem, const Args& args) {
    constexpr double R = RADIUS;
    constexpr double rho0 = 1.0;

    // Total mass of the profile; kept for reference, not used by the checks below.
    [[maybe_unused]] const double expected_mass = (8.0 / 15.0) * M_PI * rho0 * std::pow(R, 3.0);

    fem.mapping->ResetDisplacement();
    update_stiffness_matrix(fem);

    mfem::FunctionCoefficient density_profile([&](const mfem::Vector& x) {
        const double radius = x.Norml2();
        if (radius > R) return 0.0;
        return rho0 * (1.0 - std::pow(radius / R, 2.0));
    });

    mfem::GridFunction rho(fem.H1_fes.get());
    rho.ProjectCoefficient(density_profile);

    fem.com = get_com(fem, rho);
    fem.Q = compute_quadrupole_moment_tensor(fem, rho, fem.com);

    const mfem::GridFunction phi = grav_potential(fem, args, rho);

    size_t failed = 0;
    size_t num_tests = 0;
    double max_rel_err = 0.0;

    // Constant offset of the interior solution: pi*G*rho0*R^2.
    constexpr double surface_offset = M_PI * G * rho0 * R*R;

    const int sample_count = std::min(50, fem.mesh->GetNE());
    for (int elem = 0; elem < sample_count; ++elem) {
        auto* tr = fem.mesh->GetElementTransformation(elem);
        const auto& ip = mfem::IntRules.Get(tr->GetGeometryType(), 2).IntPoint(0);
        tr->SetIntPoint(&ip);

        mfem::Vector x_phys;
        fem.mapping->GetPhysicalPoint(*tr, ip, x_phys);
        const double r = x_phys.Norml2();

        // The analytic expression below only holds inside the sphere.
        if (!(r < R)) {
            continue;
        }
        ++num_tests;

        const double radial_part = 4.0 * M_PI * G * rho0 * ((r*r / 6.0) - (std::pow(r, 4.0) / (20.0 * R*R)));
        const double phi_analytic = radial_part - surface_offset;

        const double phi_fem = phi.GetValue(elem, ip);
        const double rel_err = std::abs(phi_fem - phi_analytic) / std::abs(phi_analytic);

        max_rel_err = std::max(max_rel_err, rel_err);
        if (rel_err > APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING) {
            ++failed;
        }
    }

    RANK_GUARD(
        auto result_type = (failed == 0) ? TEST_RESULT_TYPE::SUCCESS : TEST_RESULT_TYPE::FAILURE;
        std::println("{}", fmt_test_msg("Test Ferrers Inhomogeneous Potential", result_type, failed, num_tests));

        if (result_type == TEST_RESULT_TYPE::FAILURE) {
            std::println("\tFAILURE: max rel error: {:+0.2E}", max_rel_err);
        }
    )
}
 
/// @brief Checks the magnitude of the potential gradient near the surface of a uniform sphere.
///
/// With rho0 = 1 and R = RADIUS on an undisplaced mapping, every point of an
/// order-2 rule with |r - R| < 0.05 is compared against
///  - inside  (r <= R): (4/3)*pi*G*rho0*r,
///  - outside (r >  R): G*M/r^2 with M = (4/3)*pi*R^3*rho0.
/// A sample fails when its relative error exceeds
/// APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING; note that this is the
/// tolerance chosen for the potential itself, while gradients usually
/// converge one order slower.
///
/// fem.com and fem.Q are recomputed for this density before solving, as the
/// other potential tests do, so the boundary data does not depend on whatever
/// test ran before. The largest relative error is printed on failure.
///
/// NOTE(review): GetGradient is evaluated with the mesh element
/// transformation; this presumably equals the physical gradient only because
/// the displacement has just been reset — confirm.
///
/// @param fem  Finite element context; displacement is reset, com and Q are overwritten.
/// @param args Forwarded to grav_potential.
void test_force_continuity(FEM& fem, const Args& args) {
    constexpr double rho0 = 1.0;
    constexpr double R = RADIUS;

    // Loop-invariant total mass of the uniform sphere.
    const double total_mass = (4.0/3.0) * M_PI * std::pow(R, 3.0) * rho0;

    mfem::GridFunction rho(fem.H1_fes.get());
    rho = rho0;

    fem.mapping->ResetDisplacement();
    update_stiffness_matrix(fem);

    fem.com = get_com(fem, rho);
    fem.Q = compute_quadrupole_moment_tensor(fem, rho, fem.com);

    // The returned reference points into the gravity context; no copy is needed for read-only sampling.
    const auto& phi = grav_potential(fem, args, rho);

    size_t failed = 0;
    size_t num_tests = 0;
    double max_jump = 0.0;

    for (int i = 0; i < fem.mesh->GetNE(); i++) {
        mfem::ElementTransformation *T = fem.mesh->GetElementTransformation(i);
        const mfem::IntegrationRule &ir = mfem::IntRules.Get(T->GetGeometryType(), 2);

        for (int j = 0; j < ir.GetNPoints(); j++) {
            const mfem::IntegrationPoint &ip = ir.IntPoint(j);
            T->SetIntPoint(&ip);

            mfem::Vector x;
            fem.mapping->GetPhysicalPoint(*T, ip, x);
            const double r = x.Norml2();

            // Only points very close to the surface r = R are of interest.
            if (std::abs(r - R) >= 0.05) continue;

            num_tests++;
            mfem::Vector grad_phi;
            phi.GetGradient(*T, grad_phi);

            const double g_mag_fem = grad_phi.Norml2();
            const double g_mag_analytic = (r <= R) ? (4.0/3.0)*M_PI*G*rho0*r : (G*total_mass)/(r*r);

            const double rel_err = std::abs(g_mag_fem - g_mag_analytic) / g_mag_analytic;
            max_jump = std::max(max_jump, rel_err);
            if (rel_err > APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING) ++failed;
        }
    }

    const auto result_type = (failed == 0) ? TEST_RESULT_TYPE::SUCCESS : TEST_RESULT_TYPE::FAILURE;

    RANK_GUARD(
        std::println("{}", fmt_test_msg("Test Force Continuity", result_type, failed, num_tests));

        if (result_type == TEST_RESULT_TYPE::FAILURE) {
            std::println("\tFAILURE: max rel error: {:+0.2E}", max_jump);
        }
    )
}
 
/**
 * @brief Validates the FEM potential of an n=1 Ferrers ellipsoid against a
 *        numerically integrated analytic reference.
 *
 * The mapping is displaced so that the reference sphere of radius RADIUS is
 * stretched to semi-axes (a, b, c) = (1.1, 1.0, 0.9) * RADIUS. The density
 * rho0 * (1 - m^2), with m^2 = (x/a)^2 + (y/b)^2 + (z/c)^2, is projected via
 * the physical position, the centre of mass and quadrupole tensor stored on
 * `fem` are refreshed, and the potential is solved.
 *
 * The reference value is -pi*G*a*b*c*rho0 * Int_0^inf 0.5*(1 - m(u)^2)^2 / Delta(u) du,
 * evaluated with composite Simpson's rule after the substitution
 * u = a^2 tan^2(theta), which maps the half-line onto [0, pi/2].
 *
 * The pass/fail verdict, the maximum error and the mean error are all built
 * from communicator-wide reductions so the report on the printing rank is
 * consistent with the global counts it prints.
 *
 * @param fem  Finite-element context; displacement, stiffness matrix, `com`
 *             and `Q` are modified. The displacement is left applied on exit.
 * @param args Options forwarded to grav_potential.
 */
void test_ferrers_ellipsoid_potential(FEM& fem, const Args& args) {
    constexpr double a = 1.1 * RADIUS;
    constexpr double b = 1.0 * RADIUS;
    constexpr double c = 0.9 * RADIUS;
    constexpr double rho0 = 1.0;

    // Linear stretch of each axis that turns the reference sphere into the ellipsoid.
    mfem::GridFunction ferrers_disp(fem.Vec_H1_fes.get());
    mfem::VectorFunctionCoefficient disp_coeff(3, [&](const mfem::Vector& x, mfem::Vector& d) {
        d.SetSize(3);
        d(0) = (a/RADIUS - 1.0) * x(0);
        d(1) = (b/RADIUS - 1.0) * x(1);
        d(2) = (c/RADIUS - 1.0) * x(2);
    });
    ferrers_disp.ProjectCoefficient(disp_coeff);
    fem.mapping->SetDisplacement(ferrers_disp);
    update_stiffness_matrix(fem);

    auto rho_func = [&](const mfem::Vector& x_phys) {
        const double xa = x_phys(0) / a;
        const double yb = x_phys(1) / b;
        const double zc = x_phys(2) / c;
        const double m2 = xa * xa + yb * yb + zc * zc;
        return (m2 < 1.0) ? rho0 * (1.0 - m2) : 0.0;
    };
    PhysicalPositionFunctionCoefficient rho_coeff(*fem.mapping, rho_func);

    mfem::GridFunction rho(fem.H1_fes.get());
    rho.ProjectCoefficient(rho_coeff);

    fem.com = get_com(fem, rho);
    fem.Q = compute_quadrupole_moment_tensor(fem, rho, fem.com);
    const auto phi = grav_potential(fem, args, rho);

    auto calc_analytic_phi = [&](const mfem::Vector& x) {
        auto integrand = [&](double theta) {
            // Endpoint values: du/dtheta vanishes at theta = 0, and the
            // integrand tends to 1/a as theta -> pi/2 (u -> infinity).
            if (theta == 0.0) return 0.0;
            if (theta >= M_PI / 2.0) return 1.0 / a;

            const double tan_t = std::tan(theta);
            const double sec_t = 1.0 / std::cos(theta);

            const double u = a * a * tan_t * tan_t;
            const double du_dtheta = 2.0 * a * a * tan_t * sec_t * sec_t;

            const double a2_u = a*a + u;
            const double b2_u = b*b + u;
            const double c2_u = c*c + u;

            const double delta = std::sqrt(a2_u * b2_u * c2_u);
            const double m_u2 = (x(0)*x(0))/a2_u + (x(1)*x(1))/b2_u + (x(2)*x(2))/c2_u;

            const double one_minus_m2 = 1.0 - m_u2;
            const double val = 0.5 * one_minus_m2 * one_minus_m2 / delta;

            return val * du_dtheta;
        };

        // Composite Simpson's rule; n_steps must stay even.
        constexpr int n_steps = 1000;
        const double dtheta = (M_PI / 2.0) / n_steps;
        double sum = integrand(0.0) + integrand(M_PI / 2.0);
        for (int i = 1; i < n_steps; ++i) {
            const double theta = i * dtheta;
            sum += (i % 2 == 0 ? 2.0 : 4.0) * integrand(theta);
        }
        sum *= dtheta / 3.0;

        return -M_PI * G * a * b * c * rho0 * sum;
    };

    long local_failed = 0;
    long local_num_tests = 0;
    double local_max_rel_err = 0.0;
    double local_total_rel_err = 0.0;

    // Only the first (up to) 50 local elements are sampled, and of those only
    // the ones carrying attribute 1.
    const int check_count = std::min(50, fem.mesh->GetNE());
    for (int elemID = 0; elemID < check_count; ++elemID) {
        if (fem.mesh->GetAttribute(elemID) != 1) continue;

        auto* trans = fem.mesh->GetElementTransformation(elemID);
        const auto& ip = mfem::IntRules.Get(trans->GetGeometryType(), 2).IntPoint(0);
        trans->SetIntPoint(&ip);

        mfem::Vector x_phys;
        fem.mapping->GetPhysicalPoint(*trans, ip, x_phys);

        const double phi_fem = phi.GetValue(elemID, ip);
        const double phi_analytic = calc_analytic_phi(x_phys);

        const double rel_err = std::abs(phi_fem - phi_analytic) / std::abs(phi_analytic);
        local_max_rel_err = std::max(local_max_rel_err, rel_err);
        local_total_rel_err += rel_err;
        ++local_num_tests;

        if (rel_err > APPROX_MAX_ACCEPTABLE_POTENTIAL_ERROR_SI_BURNING) ++local_failed;
    }

    long global_failed = 0;
    long global_num_tests = 0;
    double global_max_rel_err = 0.0;
    double global_total_rel_err = 0.0;

    MPI_Comm comm = fem.H1_fes->GetComm();
    MPI_Allreduce(&local_failed, &global_failed, 1, MPI_LONG, MPI_SUM, comm);
    MPI_Allreduce(&local_num_tests, &global_num_tests, 1, MPI_LONG, MPI_SUM, comm);
    MPI_Allreduce(&local_max_rel_err, &global_max_rel_err, 1, MPI_DOUBLE, MPI_MAX, comm);
    MPI_Allreduce(&local_total_rel_err, &global_total_rel_err, 1, MPI_DOUBLE, MPI_SUM, comm);

    // Guard the mean against the case where no element was sampled at all.
    const double mean_rel_err = (global_num_tests > 0)
                                    ? global_total_rel_err / static_cast<double>(global_num_tests)
                                    : 0.0;

    RANK_GUARD(
        // The verdict must use the reduced count: failures may live on other ranks.
        const auto result_type = (global_failed == 0) ? TEST_RESULT_TYPE::SUCCESS : TEST_RESULT_TYPE::FAILURE;
        std::println("{}", fmt_test_msg("Test Ferrers Ellipsoid (n=1)", result_type, global_failed, global_num_tests));
        if (result_type == TEST_RESULT_TYPE::FAILURE) {
            std::println("\tFAILURE: Max Rel Error: {:+0.2E}, Mean Rel Error: {:+0.2E}", global_max_rel_err, mean_rel_err);
        }
    )
}
 
/**
 * @brief Checks mass renormalisation of a density field on the undeformed
 *        mapping and on an ellipsoidally stretched one.
 *
 * Three checks are run in sequence on a parabolic density 1 - (r/R)^2:
 *  1. after rescaling, the stellar-domain integral equals MASS (rel. tol 1e-12);
 *  2. after stretching the axes to (1.5, 0.8, 0.5) * RADIUS without rescaling,
 *     the integral equals MASS * a*b*c / R^3 (rel. tol 1e-10);
 *  3. after rescaling again on the stretched mapping, the integral equals MASS
 *     (rel. tol 1e-12).
 *
 * @param fem  Finite-element context; the mapping displacement is reset on
 *             entry and left stretched on exit.
 * @param args Not used by this test.
 */
void test_mass_conservation_constraint(const FEM& fem, const Args& args) {
    constexpr double R = RADIUS;
    constexpr double target_mass = MASS;

    // Integral of a grid function over the stellar domain.
    const auto stellar_mass = [&](mfem::GridFunction& gf) -> double {
        return domain_integrate_grid_function(fem, gf, Domains::STELLAR);
    };

    // Rescales gf so that its stellar integral becomes target_mass; a
    // non-positive integral is left untouched.
    const auto enforce_mass = [&](mfem::GridFunction& gf) {
        const double current_m = stellar_mass(gf);
        if (current_m > 0.0) {
            gf *= (target_mass / current_m);
        }
    };

    // Relative deviation of `actual` from `expected`.
    const auto rel_dev = [](double actual, double expected) {
        return std::abs(actual - expected) / expected;
    };

    fem.mapping->ResetDisplacement();

    mfem::FunctionCoefficient rho_coeff([&](const mfem::Vector& x) {
        const double r = x.Norml2();
        if (!(r < R)) return 0.0;
        return 1.0 - (r/R)*(r/R);
    });

    mfem::GridFunction rho(fem.H1_fes.get());
    rho.ProjectCoefficient(rho_coeff);

    size_t num_tests = 0;
    size_t failed = 0;

    // --- Check 1: renormalisation on the undeformed mapping ---------------
    enforce_mass(rho);
    const double mass_undeformed = stellar_mass(rho);
    const double err_undeformed = rel_dev(mass_undeformed, target_mass);
    ++num_tests;
    if (err_undeformed > 1e-12) {
        ++failed;
        RANK_GUARD(
            std::println("\tFAILURE (Undeformed Conservation):");
            std::println("\t\tExpected Mass: {:.6e}", target_mass);
            std::println("\t\tActual Mass:   {:.6e}", mass_undeformed);
            std::println("\t\tRel Error:     {:+0.2E}", err_undeformed);
        )
    }

    // --- Check 2: the stretched mapping scales the integral by a*b*c/R^3 --
    constexpr double a = 1.5 * RADIUS;
    constexpr double b = 0.8 * RADIUS;
    constexpr double c = 0.5 * RADIUS;

    mfem::VectorFunctionCoefficient disp_coeff(3, [&](const mfem::Vector& x, mfem::Vector& d) {
        d.SetSize(3);
        d(0) = (a/R - 1.0) * x(0);
        d(1) = (b/R - 1.0) * x(1);
        d(2) = (c/R - 1.0) * x(2);
    });
    mfem::GridFunction disp(fem.Vec_H1_fes.get());
    disp.ProjectCoefficient(disp_coeff);
    fem.mapping->SetDisplacement(disp);

    const double expected_scaled_mass = target_mass * (a * b * c) / (R * R * R);
    const double mass_deformed_unscaled = stellar_mass(rho);
    const double err_scaled = rel_dev(mass_deformed_unscaled, expected_scaled_mass);
    ++num_tests;
    if (err_scaled > 1e-10) {
        ++failed;
        RANK_GUARD(
            std::println("\tFAILURE (Geometric Scaling):");
            std::println("\t\tDisplacement mapping did not correctly scale the physical mass integral.");
            std::println("\t\tExpected Scaled Mass: {:.6e}", expected_scaled_mass);
            std::println("\t\tActual Scaled Mass:   {:.6e}", mass_deformed_unscaled);
            std::println("\t\tRel Error:            {:+0.2E}", err_scaled);
        )
    }

    // --- Check 3: renormalisation on the stretched mapping ----------------
    enforce_mass(rho);
    const double mass_deformed_conserved = stellar_mass(rho);
    const double err_deformed = rel_dev(mass_deformed_conserved, target_mass);
    ++num_tests;
    if (err_deformed > 1e-12) {
        ++failed;
        RANK_GUARD(
            std::println("\tFAILURE (Deformed Conservation):");
            std::println("\t\tFailed to enforce mass constraint on the deformed geometry.");
            std::println("\t\tExpected Mass: {:.6e}", target_mass);
            std::println("\t\tActual Mass:   {:.6e}", mass_deformed_conserved);
            std::println("\t\tRel Error:     {:+0.2E}", err_deformed);
        )
    }

    RANK_GUARD(
        auto result_type = TEST_RESULT_TYPE::SUCCESS;
        if (failed != 0) {
            result_type = (failed < num_tests) ? TEST_RESULT_TYPE::PARTIAL : TEST_RESULT_TYPE::FAILURE;
        }
        std::println("{}", fmt_test_msg("Test Mass Conservation Constraint", result_type, failed, num_tests));
    )
}
 
/**
 * @brief Verifies that forward-mode XAD differentiation of a polytropic EOS
 *        P = K * rho^gamma reproduces the closed-form pressure and dP/drho.
 *
 * A density field max(0.1, 1 - r/RADIUS) is projected on the undeformed
 * mapping. For every point of `fem.int_rule` in the first (up to) 100 local
 * elements with attribute 1, the EOS is evaluated with an xad::FReal seeded
 * with derivative 1 and compared with the analytic value and derivative
 * (rel. tol 1e-12 each).
 *
 * Counts and worst errors are reduced over the H1 space's communicator, and
 * the verdict is derived from the reduced failure count so it agrees with the
 * totals that are printed.
 *
 * @param fem  Finite-element context; the mapping displacement is reset.
 * @param args Not used by this test.
 */
void test_xad_eos_derivative(const FEM& fem, const Args& args) {
    constexpr double K = 2.5;
    constexpr double gamma = 5.0 / 3.0;

    auto polytropic_eos = [&](const auto& rho) {
        using std::pow; // Allow ADL to find xad::pow if necessary
        return K * pow(rho, gamma);
    };

    fem.mapping->ResetDisplacement();
    mfem::FunctionCoefficient rho_coeff([&](const mfem::Vector& x) {
        const double r = x.Norml2();
        // Floor at 0.1 keeps the density strictly positive for the power law.
        return std::max(0.1, 1.0 - (r / RADIUS));
    });
    mfem::GridFunction rho(fem.H1_fes.get());
    rho.ProjectCoefficient(rho_coeff);

    long local_failed = 0;
    long local_num_tests = 0;
    double local_max_rel_err_p = 0.0;
    double local_max_rel_err_dp = 0.0;

    // The quadrature rule does not depend on the element, so fetch it once.
    const auto& ir = *fem.int_rule;

    const int check_count = std::min(100, fem.mesh->GetNE());
    for (int elemID = 0; elemID < check_count; ++elemID) {
        if (fem.mesh->GetAttribute(elemID) != 1) continue; // Only Core elements

        auto* trans = fem.mesh->GetElementTransformation(elemID);

        for (int q = 0; q < ir.GetNPoints(); ++q) {
            const auto& ip = ir.IntPoint(q);
            trans->SetIntPoint(&ip);

            const double rho_val = rho.GetValue(elemID, ip);

            const double p_analytic = K * std::pow(rho_val, gamma);
            const double dp_drho_analytic = K * gamma * std::pow(rho_val, gamma - 1.0);

            // Seed d(rho)/d(rho) = 1 so the derivative slot carries dP/drho.
            xad::FReal<double> rho_ad(rho_val, 1.0);
            xad::FReal<double> p_ad = polytropic_eos(rho_ad);

            const double p_xad = p_ad.getValue();
            const double dp_drho_xad = p_ad.getDerivative();

            const double rel_err_p = std::abs(p_xad - p_analytic) / std::abs(p_analytic);
            const double rel_err_dp = std::abs(dp_drho_xad - dp_drho_analytic) / std::abs(dp_drho_analytic);

            local_max_rel_err_p = std::max(local_max_rel_err_p, rel_err_p);
            local_max_rel_err_dp = std::max(local_max_rel_err_dp, rel_err_dp);

            ++local_num_tests;
            if (rel_err_p > 1e-12 || rel_err_dp > 1e-12) {
                ++local_failed;
            }
        }
    }

    long global_failed = 0;
    long global_num_tests = 0;
    double global_max_rel_err_p = 0.0;
    double global_max_rel_err_dp = 0.0;

    MPI_Comm comm = fem.H1_fes->GetComm();
    MPI_Allreduce(&local_failed, &global_failed, 1, MPI_LONG, MPI_SUM, comm);
    MPI_Allreduce(&local_num_tests, &global_num_tests, 1, MPI_LONG, MPI_SUM, comm);
    MPI_Allreduce(&local_max_rel_err_p, &global_max_rel_err_p, 1, MPI_DOUBLE, MPI_MAX, comm);
    MPI_Allreduce(&local_max_rel_err_dp, &global_max_rel_err_dp, 1, MPI_DOUBLE, MPI_MAX, comm);

    RANK_GUARD(
        // The verdict must use the reduced count: failures may live on other ranks.
        const auto result_type = (global_failed == 0) ? TEST_RESULT_TYPE::SUCCESS : TEST_RESULT_TYPE::FAILURE;
        std::println("{}", fmt_test_msg("Test XAD EOS Derivative", result_type, global_failed, global_num_tests));

        if (result_type == TEST_RESULT_TYPE::FAILURE) {
            std::println("\tFAILURE: Max Rel Error (P): {:+0.2E}, Max Rel Error (dP/drho): {:+0.2E}",
                         global_max_rel_err_p, global_max_rel_err_dp);
        }
    )
}
 
/**
 * @brief Checks that DomainMapper::ComputeJacobian returns reproducible
 *        results when evaluation alternates between two different elements.
 *
 * A non-trivial displacement is installed, baseline Jacobians are computed
 * for element 0 and element NE/2, and both are then recomputed five times in
 * alternation. Every recomputed Jacobian must match its baseline to 1e-12
 * (max-norm of the difference).
 *
 * Each element gets its own caller-owned IsoparametricTransformation.
 * Mesh::GetElementTransformation(int) hands out a pointer to a single object
 * shared by the mesh, so two pointers obtained from it would both describe
 * whichever element was requested last and element 0 would never actually be
 * exercised.
 *
 * @param fem Finite-element context; the mapping displacement is overwritten
 *            during the test and reset before returning.
 */
void test_domain_mapper_state_isolation(const FEM& fem) {
    size_t failed = 0;
    size_t num_tests = 0;

    mfem::GridFunction messy_disp(fem.Vec_H1_fes.get());
    mfem::VectorFunctionCoefficient disp_coeff(3, [](const mfem::Vector& x, mfem::Vector& d) {
        d.SetSize(3);
        d(0) = std::sin(x(0) * M_PI);
        d(1) = std::cos(x(1) * M_PI);
        d(2) = x(0) * x(1);
    });
    messy_disp.ProjectCoefficient(disp_coeff);
    fem.mapping->SetDisplacement(messy_disp);

    const int elem_A = 0;
    const int elem_B = fem.mesh->GetNE() / 2;

    // Independent transformation objects, one per element (see header note).
    mfem::IsoparametricTransformation trans_A;
    mfem::IsoparametricTransformation trans_B;
    fem.mesh->GetElementTransformation(elem_A, &trans_A);
    fem.mesh->GetElementTransformation(elem_B, &trans_B);

    const auto& ip = mfem::IntRules.Get(trans_A.GetGeometryType(), 2).IntPoint(0);

    mfem::DenseMatrix J_A_baseline;
    mfem::DenseMatrix J_B_baseline;
    mfem::DenseMatrix J_A_test;
    mfem::DenseMatrix J_B_test;

    trans_A.SetIntPoint(&ip);
    fem.mapping->ComputeJacobian(trans_A, J_A_baseline);

    trans_B.SetIntPoint(&ip);
    fem.mapping->ComputeJacobian(trans_B, J_B_baseline);

    constexpr int repetitions = 5;
    for (int i = 0; i < repetitions; ++i) {
        trans_A.SetIntPoint(&ip);
        fem.mapping->ComputeJacobian(trans_A, J_A_test);

        trans_B.SetIntPoint(&ip);
        fem.mapping->ComputeJacobian(trans_B, J_B_test);

        // Turn each recomputed Jacobian into its deviation from the baseline.
        J_A_test.Add(-1.0, J_A_baseline);
        ++num_tests;
        if (J_A_test.MaxMaxNorm() > 1e-12) ++failed;

        J_B_test.Add(-1.0, J_B_baseline);
        ++num_tests;
        if (J_B_test.MaxMaxNorm() > 1e-12) ++failed;
    }

    RANK_GUARD(
        std::println("{}", fmt_test_msg("Test DomainMapper State Isolation",
                     (failed == 0) ? TEST_RESULT_TYPE::SUCCESS : TEST_RESULT_TYPE::FAILURE,
                     failed, num_tests));
    )

    fem.mapping->ResetDisplacement();
}
 
/**
 * @brief Evaluates the coupled residual operator at the analytic n=1 polytrope
 *        and checks that (a) the residual is small there and (b) it grows when
 *        the state is perturbed.
 *
 * The density rho_c * sin(k r) / (k r), k = pi / R, is projected on the
 * undeformed mapping with K_poly = 2 G R^2 / pi, and the operator's Lagrange
 * multiplier is fixed to -(4 G R^2 rho_c) / pi.
 *
 * Phase 1 ("Zero Residual"): the scaled density residual norm and the
 * displacement residual norm must both be <= 1e-4.
 * Phase 2 ("Perturbation Spike"): after scaling the density by 1.05 and
 * imposing the displacement 0.01 + x, both residual norms must exceed their
 * phase-1 values.
 *
 * Each phase performs exactly two checks, so the reported totals are 2; this
 * also makes the FAILURE verdict reachable when both checks of a phase fail.
 *
 * NOTE(review): Norml2() is applied to the true-dof blocks held by the calling
 * rank; confirm whether a communicator-wide norm is intended for multi-rank runs.
 *
 * @param fem    Finite-element context; cache statistics and displacement are
 *               reset and the stiffness matrix is rebuilt.
 * @param i_args Base options; a local copy has `mass` replaced by the
 *               integrated mass of the projected density.
 */
void test_hydrostatic_zero_residual(FEM& fem, const Args& i_args) {
    fem.mapping->ResetCacheStats();
    Args args = i_args;

    constexpr double R = RADIUS;
    constexpr double k = M_PI / R;
    constexpr double rho_c = 1.0;
    constexpr double K_poly = (2.0 * G * R * R) / M_PI;

    // Evaluates to -4/pi when G, R and rho_c are all 1.
    constexpr double exact_lambda = -(4.0 * G * R * R * rho_c) / M_PI;


    // Both EOS callbacks ignore the temperature argument.
    auto eos_pressure = [&](const auto& rho_val, const auto& temp_val) {
        using std::pow;
        return K_poly * pow(rho_val, 2.0);
    };

    auto eos_enthalpy = [&](const auto& rho_val, const auto& temp_val) {
        return 2.0 * K_poly * rho_val;
    };

    fem.mapping->ResetDisplacement();
    update_stiffness_matrix(fem);

    mfem::FunctionCoefficient rho_coeff([&](const mfem::Vector& x) -> double {
        const double r = x.Norml2();
        if (r < 1e-8) return rho_c;   // limit of sin(kr)/(kr) at the centre
        if (r >= R) return 0.0;
        return rho_c * std::sin(k * r) / (k * r);
    });


    mfem::BlockVector state(fem.block_true_offsets);
    state = 0.0;

    mfem::ParGridFunction rho(fem.H1_fes.get());
    rho.ProjectCoefficient(rho_coeff);
    mfem::Vector& state_rho = state.GetBlock(0);
    rho.GetTrueDofs(state_rho);


    mfem::BlockVector residual(fem.block_true_offsets);
    residual = 0.0;

    // Use the discretely integrated mass so the mass target matches the
    // projected profile.
    const double integrated_mass = domain_integrate_grid_function(fem, rho, Domains::STELLAR);
    args.mass = integrated_mass;

    ResidualOperator<xad::AReal<double>> coupled_operator(fem, args, eos_enthalpy, eos_pressure, 1.0);

    coupled_operator.SetLambda(exact_lambda);
    coupled_operator.Mult(state, residual);

    mfem::Vector& r_rho = residual.GetBlock(0);
    mfem::Vector& r_d = residual.GetBlock(1);

    const double force_scale = K_poly * rho_c * rho_c * R * R;

    const double norm_rho = r_rho.Norml2() / force_scale;
    const double norm_d = r_d.Norml2();

    // Two checks are performed below: density residual and displacement residual.
    constexpr size_t num_tests = 2;
    size_t failed = 0;

    if (norm_rho > 1e-4) ++failed;
    if (norm_d > 1e-4) ++failed;

    RANK_GUARD(
        auto result_type = (failed == 0) ? TEST_RESULT_TYPE::SUCCESS :
                           (failed < num_tests ? TEST_RESULT_TYPE::PARTIAL : TEST_RESULT_TYPE::FAILURE);

        std::println("{}", fmt_test_msg("Test Hydrostatic Zero Residual (n=1)", result_type, failed, num_tests));

        if (result_type != TEST_RESULT_TYPE::SUCCESS) {
            std::println("\tFAILURE: The Residual Operator did not achieve equilibrium.");
            std::println("\t\tDensity Rel Residual (H + Phi - Lambda): {:+0.2E}", norm_rho);
            std::println("\t\tDisplacement Abs Residual (Stiffness - PBF): {:+0.2E}", norm_d);
        }
    )

    // ---- Phase 2: a perturbed state must produce a larger residual --------
    mfem::BlockVector perturbed_state(state);
    mfem::Vector& perturbed_rho = perturbed_state.GetBlock(0);
    mfem::Vector& perturbed_d = perturbed_state.GetBlock(1);
    perturbed_rho *= 1.05;

    mfem::ParGridFunction p_d_gf(fem.Vec_H1_fes.get());
    mfem::VectorFunctionCoefficient pd_coeff(3, [&](const mfem::Vector& x, mfem::Vector& v) {
        v.SetSize(3);
        v(0) = 0.01 + x(0);
        v(1) = 0.01 + x(1);
        v(2) = 0.01 + x(2);
    });
    p_d_gf.ProjectCoefficient(pd_coeff);
    p_d_gf.GetTrueDofs(perturbed_d);

    mfem::BlockVector perturbed_residual(fem.block_true_offsets);
    perturbed_residual = 0.0;
    coupled_operator.Mult(perturbed_state, perturbed_residual);

    mfem::Vector& p_r_rho = perturbed_residual.GetBlock(0);
    mfem::Vector& p_r_d = perturbed_residual.GetBlock(1);

    const double p_norm_rho = p_r_rho.Norml2() / force_scale;
    const double p_norm_d = p_r_d.Norml2();

    // Again exactly two checks: each residual norm must have increased.
    constexpr size_t perturb_tests = 2;
    size_t perturb_failed = 0;

    if (p_norm_rho <= norm_rho) ++perturb_failed;
    if (p_norm_d <= norm_d) ++perturb_failed;


    RANK_GUARD(
        auto perturb_result = (perturb_failed == 0) ? TEST_RESULT_TYPE::SUCCESS :
                              (perturb_failed < perturb_tests ? TEST_RESULT_TYPE::PARTIAL : TEST_RESULT_TYPE::FAILURE);

        std::println("{}", fmt_test_msg("Test Hydrostatic Perturbation Spike", perturb_result, perturb_failed, perturb_tests));

        if (perturb_result != TEST_RESULT_TYPE::SUCCESS) {
            std::println("\tFAILURE: Perturbing the state did not worsen the residual.");
            std::println("\t\tDensity Res jump:   {:.2e} -> {:.2e}", norm_rho, p_norm_rho);
            std::println("\t\tDisp Res jump:      {:.2e} -> {:.2e}", norm_d, p_norm_d);
        }
    )
}
 
/**
 * @brief Runs MFEM's L-BFGS solver on the coupled residual operator starting
 *        from a 5% over-dense n=1 polytrope and checks residual reduction and
 *        mass conservation.
 *
 * The initial density is 1.05 * rho_c * sin(k r) / (k r) with k = pi / R
 * (zero outside R); the displacement block starts at zero. The test passes
 * when the final residual norm is below 10% of the initial one and the
 * integrated stellar mass matches `args.mass` = 4 R^3 rho_c / pi to a relative
 * 1e-6.
 *
 * The density profile returns its analytic limit at r = 0 instead of
 * evaluating sin(0)/0, and the failure report prints the relative mass error
 * as a plain percentage.
 *
 * @param fem    Finite-element context used to build the operator and spaces.
 * @param i_args Base options; a local copy has `mass` overridden.
 */
void test_lbfgs_convergence(FEM& fem, const Args& i_args) {
    Args args = i_args;
    constexpr double R = RADIUS;
    constexpr double k = M_PI / R;
    constexpr double rho_c = 1.0;
    constexpr double K_poly = (2.0 * G * R * R) / M_PI;
    args.mass = (4.0 * R * R * R * rho_c) / M_PI;

    // Both EOS callbacks ignore the temperature argument.
    auto eos_pressure = [&](const auto& r, const auto& t) { return K_poly * pow(r, 2.0); };
    auto eos_enthalpy = [&](const auto& r, const auto& t) { return 2.0 * K_poly * r; };

    mfem::BlockVector u(fem.block_true_offsets);
    u = 0.0;
    mfem::FunctionCoefficient rho_coeff([&](const mfem::Vector& x) -> double {
        const double r = x.Norml2();
        // sin(kr)/(kr) -> 1 as r -> 0; evaluating it literally at r = 0 gives NaN.
        if (r < 1e-8) return 1.05 * rho_c;
        return (r < R) ? 1.05 * rho_c * std::sin(k * r) / (k * r) : 0.0;
    });
    mfem::ParGridFunction rho_init(fem.H1_fes.get());
    rho_init.ProjectCoefficient(rho_coeff);
    rho_init.GetTrueDofs(u.GetBlock(0));

    ResidualOperator<xad::AReal<double>> coupled_op(fem, args, eos_enthalpy, eos_pressure, 1.0);

    // Residual norm of the starting guess, used as the reduction baseline.
    mfem::Vector res_vec(u.Size());
    coupled_op.Mult(u, res_vec);
    const double initial_norm = res_vec.Norml2();

    mfem::LBFGSSolver lbfgs(fem.H1_fes->GetComm());
    lbfgs.SetOperator(coupled_op);
    lbfgs.SetAbsTol(1e-8);
    lbfgs.SetRelTol(1e-6);
    lbfgs.SetMaxIter(200);
    lbfgs.SetPrintLevel(1);

    mfem::Vector zero_rhs(u.Size());
    zero_rhs = 0.0;

    const auto start_time = std::chrono::steady_clock::now();

    lbfgs.Mult(zero_rhs, u);

    const auto end_time = std::chrono::steady_clock::now();
    const std::chrono::duration<double, std::milli> elapsed = end_time - start_time;

    // Evaluate the final state.
    coupled_op.Mult(u, res_vec);
    const double final_norm = res_vec.Norml2();

    mfem::ParGridFunction rho(fem.H1_fes.get());
    rho.SetFromTrueDofs(u.GetBlock(0));
    const double final_mass = domain_integrate_grid_function(fem, rho, Domains::STELLAR);
    const double mass_rel_err = std::abs(final_mass - args.mass) / args.mass;

    const bool success = (final_norm < initial_norm * 0.1) && (mass_rel_err < 1e-6);

    RANK_GUARD(
        std::println("{}", fmt_test_msg("Test L-BFGS Convergence",
                     success ? TEST_RESULT_TYPE::SUCCESS : TEST_RESULT_TYPE::FAILURE,
                     success ? 0 : 1, 1));
        if (!success) {
            std::println("Failure Report:");
            std::println("\tReduction: {:.6e} -> {:.6e} ({:.2f}%)",
                         initial_norm, final_norm, (final_norm/initial_norm)*100.0);
            std::println("\tRuntime: {:.2f} ms", elapsed.count());
            std::println("\tFinal Mass: {:.6e} ((target mass: {:+0.2E}) Rel Error: {:+0.2E}%)", final_mass, args.mass, mass_rel_err * 100.0);
        }
    )
}
 
/**
 * @brief Checks that enabling rotation adds the centrifugal-potential weak
 *        form to the density block of the residual.
 *
 * The analytic n=1 polytrope rho_c * sin(k r) / (k r) is projected on the
 * undeformed mapping and the residual operator is evaluated with rotation
 * enabled at omega = 0.20 * sqrt(G M / R^3) and the Lagrange multiplier fixed
 * to -(4 G R^2 rho_c) / pi. The density residual block is then compared with
 * a linear form assembling centrifugal_potential over the stellar elements;
 * the relative L2 difference must not exceed 1e-2.
 *
 * @param fem    Finite-element context; displacement is reset and the
 *               stiffness matrix rebuilt.
 * @param i_args Base options; a local copy gets rotation enabled and `mass`
 *               replaced (first analytically for omega, then by the
 *               integrated mass of the projected density).
 */
void test_rotational_residual_injection(FEM& fem, const Args& i_args) {
    constexpr double R = RADIUS;
    constexpr double rho_c = 1.0;
    constexpr double wave_number = M_PI / R;
    constexpr double K_poly = (2.0 * G * R * R) / M_PI;
    constexpr double exact_lambda = -(4.0 * G * R * R * rho_c) / M_PI;

    Args args = i_args;
    args.mass = (4.0 * R * R * R * rho_c) / M_PI;
    args.r.enabled = true;
    // omega is derived from the analytic mass set just above.
    args.r.omega = 0.20 * std::sqrt(G * args.mass / std::pow(R, 3.0));

    // Both EOS callbacks ignore the temperature argument.
    auto eos_pressure = [&](const auto& r, const auto& t) { return K_poly * pow(r, 2.0); };
    auto eos_enthalpy = [&](const auto& r, const auto& t) { return 2.0 * K_poly * r; };

    fem.mapping->ResetDisplacement();
    update_stiffness_matrix(fem);

    mfem::FunctionCoefficient rho_coeff([&](const mfem::Vector& x) {
        const double radius = x.Norml2();
        if (radius < 1e-8) {
            return rho_c;   // central limit of sin(kr)/(kr)
        }
        if (radius >= R) {
            return 0.0;
        }
        return rho_c * std::sin(wave_number * radius) / (wave_number * radius);
    });

    // State vector: density block from the projection, displacement block zero.
    mfem::BlockVector state(fem.block_true_offsets);
    state = 0.0;
    mfem::ParGridFunction rho(fem.H1_fes.get());
    rho.ProjectCoefficient(rho_coeff);
    rho.GetTrueDofs(state.GetBlock(0));

    // The operator is built with the discretely integrated mass.
    args.mass = domain_integrate_grid_function(fem, rho, Domains::STELLAR);
    ResidualOperator<xad::AReal<double>> coupled_operator(fem, args, eos_enthalpy, eos_pressure, 1.0);
    coupled_operator.SetLambda(exact_lambda);

    mfem::BlockVector residual(fem.block_true_offsets);
    residual = 0.0;
    coupled_operator.Mult(state, residual);
    mfem::Vector& r_rho = residual.GetBlock(0);

    // Reference: weak form of the centrifugal potential over the stellar elements.
    mfem::Array<int> stellar_mask;
    populate_element_mask(fem, Domains::STELLAR, stellar_mask);

    mfem::FunctionCoefficient rot_coeff([&](const mfem::Vector& x) {
        return centrifugal_potential(x, args.r.omega);
    });
    MappedScalarCoefficient mapped_rot(*fem.mapping, rot_coeff);

    mfem::ParLinearForm expected_lf(fem.H1_fes.get());
    // The integrator is handed to the linear form via AddDomainIntegrator.
    auto* lf_integrator = new mfem::DomainLFIntegrator(mapped_rot);
    lf_integrator->SetIntRule(fem.int_rule.get());
    expected_lf.AddDomainIntegrator(lf_integrator, stellar_mask);
    expected_lf.Assemble();

    mfem::Vector expected_true(fem.H1_fes->GetTrueVSize());
    expected_lf.ParallelAssemble(expected_true);

    mfem::Vector mismatch(r_rho);
    mismatch -= expected_true;

    const double norm_expected = expected_true.Norml2();
    const double rel_err = mismatch.Norml2() / norm_expected;

    size_t failed = 0;
    if (rel_err > 1e-2) {
        failed = 1;
    }

    RANK_GUARD(
        const bool passed = (failed == 0);
        auto result_type = passed ? TEST_RESULT_TYPE::SUCCESS : TEST_RESULT_TYPE::FAILURE;
        std::println("{}", fmt_test_msg("Test Rotational Residual Injection", result_type, failed, 1));
        if (!passed) {
            std::println("\tFAILURE: Rotation residual did not match expected weak form.");
            std::println("\t||r_rho - expected|| / ||expected|| = {:+0.2E}", rel_err);
            std::println("\t||r_rho|| = {:+0.2E}", r_rho.Norml2());
            std::println("\t||expected|| = {:+0.2E}", norm_expected);
        }
    )
}
 
/**
 * @brief Solves the rotating n=1 polytrope with L-BFGS and prints the density
 *        at half radius on the equatorial (x) axis and on the polar (z) axis.
 *
 * Rotation is enabled with omega = 0.18 * sqrt(G M / R^3), where M is the
 * analytic polytrope mass 4 R^3 rho_c / pi. The starting state is the
 * non-rotating profile rho_c * sin(k r) / (k r), k = pi / R, with zero
 * displacement. No pass/fail verdict is produced; the function only prints
 * the two densities and their ratio.
 *
 * The density profile returns its analytic limit at r = 0 instead of
 * evaluating sin(0)/0.
 *
 * NOTE(review): the final println is not wrapped in RANK_GUARD, so every rank
 * prints. Left as-is because it is not visible here which rank(s) obtain a
 * valid value from EvalGridFunctionAtPhysicalPoint — confirm before guarding.
 *
 * @param fem    Finite-element context; the mapping displacement is reset.
 * @param i_args Base options; a local copy gets rotation enabled and `mass`
 *               replaced by the integrated mass of the projected density.
 */
void test_simple_rotation(FEM& fem, const Args& i_args) {
    Args args = i_args;
    constexpr double R = RADIUS;
    constexpr double k = M_PI / R;
    constexpr double rho_c = 1.0;
    constexpr double K_poly = (2.0 * G * R * R) / M_PI;
    args.mass = (4.0 * R * R * R * rho_c) / M_PI;
    args.r.enabled = true;
    // 18% of sqrt(G M / R^3), computed from the analytic mass set just above.
    args.r.omega = 0.18 * std::sqrt(G * args.mass / std::pow(R, 3.0));

    fem.mapping->ResetDisplacement();

    // Both EOS callbacks ignore the temperature argument.
    auto eos_pressure = [&](const auto& r, const auto& t) { return K_poly * pow(r, 2.0); };
    auto eos_enthalpy = [&](const auto& r, const auto& t) { return 2.0 * K_poly * r; };

    mfem::BlockVector u(fem.block_true_offsets);
    u = 0.0;
    mfem::FunctionCoefficient rho_coeff([&](const mfem::Vector& x) -> double {
        const double r = x.Norml2();
        // sin(kr)/(kr) -> 1 as r -> 0; evaluating it literally at r = 0 gives NaN.
        if (r < 1e-8) return rho_c;
        return (r < R) ? rho_c * std::sin(k * r) / (k * r) : 0.0;
    });
    mfem::ParGridFunction rho_init(fem.H1_fes.get());
    rho_init.ProjectCoefficient(rho_coeff);
    rho_init.GetTrueDofs(u.GetBlock(0));

    // The operator is built with the discretely integrated mass.
    const double integrated_mass = domain_integrate_grid_function(fem, rho_init, Domains::STELLAR);
    args.mass = integrated_mass;

    ResidualOperator<xad::AReal<double>> coupled_op(fem, args, eos_enthalpy, eos_pressure, 1.0);

    // Initial residual; evaluated as in the original flow, its norm is not reported.
    mfem::Vector res_vec(u.Size());
    coupled_op.Mult(u, res_vec);
    [[maybe_unused]] const double initial_norm = res_vec.Norml2();

    mfem::LBFGSSolver lbfgs(fem.H1_fes->GetComm());
    lbfgs.SetOperator(coupled_op);
    lbfgs.SetAbsTol(1e-7);
    lbfgs.SetRelTol(1e-6);
    lbfgs.SetMaxIter(200);
    lbfgs.SetPrintLevel(1);

    mfem::Vector zero_rhs(u.Size());
    zero_rhs = 0.0;

    lbfgs.Mult(zero_rhs, u);

    mfem::ParGridFunction rho(fem.H1_fes.get());
    rho.SetFromTrueDofs(u.GetBlock(0));
    // Integrated on every rank as in the original flow; currently not reported.
    [[maybe_unused]] const double final_mass = domain_integrate_grid_function(fem, rho, Domains::STELLAR);

    // Probe points at half radius: x1 on the x axis, x2 on the z axis.
    mfem::Vector x1(3);
    mfem::Vector x2(3);
    x1 = 0.0;
    x2 = 0.0;
    x1(0) = 0.5 * R;
    x2(2) = 0.5 * R;

    const double equatorial_density = EvalGridFunctionAtPhysicalPoint(fem, rho, x1);
    const double polar_density = EvalGridFunctionAtPhysicalPoint(fem, rho, x2);

    std::println("Equatorial Density: {:.6e}, Polar Density: {:.6e}, Ratio: {:.4f}", equatorial_density, polar_density, equatorial_density/polar_density);
}
 
//endregion
 
/**
 * @brief Program entry point: initialises MPI and the MFEM device, parses the
 *        command line, builds the FEM context and runs the enabled tests.
 *
 * Command-line options (CLI11):
 *   -m, --mesh                  mesh file (required, must exist)
 *   --max-iters                 default 20
 *   --index                     default 1
 *   --mass                      default MASS
 *   --surface-pressure-scalar   default 1e-4 (stored in args.c)
 *   -q, --quad-boost            default 0
 *
 * Rotation is disabled and the `args.p` tolerances/iteration cap are fixed in
 * code; they are not exposed on the command line.
 */
int main(int argc, char** argv) {
    mfem::Mpi::Init(argc, argv);

    // The MFEM device is hard-wired to the CPU backend.
    std::string device_config = "cpu";
    mfem::Device device(device_config);

    const int myid = mfem::Mpi::WorldRank();
    const int num_procs = mfem::Mpi::WorldSize();

    // Announce the run once, from rank 0 only.
    if (myid == 0) {
        std::println("Starting MFEM run on {} processors.", num_procs);
    }

    Args args;

    CLI::App app{"Mapped Coordinates XAD-Enabled Non-Linear Solver"};
    app.add_option("-m,--mesh", args.mesh_file)->required()->check(CLI::ExistingFile);
    app.add_option("--max-iters", args.max_iters)->default_val(20);
    app.add_option("--index", args.index)->default_val(1);
    app.add_option("--mass", args.mass)->default_val(MASS);
    app.add_option("--surface-pressure-scalar", args.c)->default_val(1e-4);
    app.add_option("-q,--quad-boost", args.quad_boost)->default_val(0);

    // Settings that are not configurable from the command line.
    args.r.enabled = false;
    args.p.rtol = 1e-12;
    args.p.atol = 1e-11;
    args.p.max_iters = 1000;

    // Parse the command line; CLI11_PARSE returns from main on a parse error or --help.
    CLI11_PARSE(app, argc, argv);
    FEM fem = setup_fem(args.mesh_file, args);

    // The tests below are currently disabled; re-enable by uncommenting.
    // RUN_TEST("Mesh Loading", test_mesh_load(fem));
    // RUN_TEST("Test Reference Coordinates", test_ref_coord_storage(fem));
    // RUN_TEST("Test Reference Volume Integral", test_reference_volume_integral(fem));
    // RUN_TEST("Test Spherically Symmetric Center of Mass", test_spherically_symmetric_com(fem));
    //
    // RUN_TEST("Test COM variance to displacement", test_com_variance_to_displacement(fem));
    // RUN_TEST("Test Volume Invariance to Displacement", test_volume_invariance_to_displacement(fem));
    // RUN_TEST("Test Volume of Ellipsoid Deformation", test_volume_ellipsoid_deformation(fem));
    //
    // RUN_TEST("Test Uniform Potential", test_uniform_potential(fem, args));
    // RUN_TEST("Test Ellipsoidal Potential", test_ellipsoidal_potential(fem, args));
    // RUN_TEST("Test Ferrers Sphere Potential", test_ferrers_sphere_potential(fem, args));
    // RUN_TEST("Test Ferrers Ellipsoid Potential", test_ferrers_ellipsoid_potential(fem, args));
    //
    // RUN_TEST("Test Mass Conservation Constraint", test_mass_conservation_constraint(fem, args));
    // RUN_TEST("Test XAD EOS Derivative", test_xad_eos_derivative(fem, args));
    // RUN_TEST("Test Force Continuity", test_force_continuity(fem, args));
    //
    // RUN_TEST("Test Domain Mapper State Isolation", test_domain_mapper_state_isolation(fem));
    //
    // RUN_TEST("Test Hydrostatic Zero Residuals", test_hydrostatic_zero_residual(fem, args));
    // RUN_TEST("TEST L-BFGS Convergence", test_lbfgs_convergence(fem, args));

    // Currently enabled: the two rotation tests.
    RUN_TEST("Test Rotating Residual", test_rotational_residual_injection(fem, args));
    RUN_TEST("Test Rigid Rotation", test_simple_rotation(fem, args));
}