//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Device/Coord/CoordSystem1D.cpp
//! @brief     Implements CoordSystem1D class and derived classes.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Device/Coord/CoordSystem1D.h"
#include "Base/Axis/MakeScale.h"
#include "Base/Axis/Scale.h"
#include "Base/Const/Units.h"
#include "Base/Util/Assert.h"
#include <cmath>
#include <numbers>
#include <stdexcept>

using std::numbers::pi;

namespace {

//! Returns q = 4 pi sin(angle) / wavelength, the quantity shown on QSPACE axes.
//! `angle` is in radians (it goes straight into std::sin); q carries the inverse of the
//! wavelength's length unit.
double getQ(double wavelength, double angle)
{
    return 4.0 * pi * std::sin(angle) / wavelength;
}

//! Inverse of getQ: returns the angle in radians, within [-pi/2, pi/2], for which
//! getQ(wavelength, angle) equals `q`.
//! @throws std::runtime_error if |q * wavelength / (4 pi)| > 1 or the ratio is NaN, i.e. if `q`
//!         cannot be reached at this wavelength. The check is explicit because std::asin would
//!         return NaN here, and NaN compares false against any bound, so it would slip past
//!         ordinary range checks downstream.
double getInvQ(double wavelength, double q)
{
    const double sin_angle = q * wavelength / (4.0 * pi);
    // Negated "inside" test so that a NaN ratio is rejected as well.
    if (!(std::abs(sin_angle) <= 1.0))
        throw std::runtime_error(
            "Error in CoordSystem1D: q value is out of range for the given wavelength");
    return std::asin(sin_angle);
}

//! Converts a single axis value expressed in `coords` into radians.
//! Only RADIANS, DEGREES and QSPACE are handled; any other unit triggers ASSERT.
//! `wavelength` is used only for the QSPACE conversion.
double backTransform(double value, Coords coords, double wavelength)
{
    switch (coords) {
    case Coords::RADIANS:
        return value;
    case Coords::DEGREES:
        return Units::deg2rad(value);
    case Coords::QSPACE:
        return getInvQ(wavelength, value);
    default:
        ASSERT(false);
    }
}

//! Builds a list-scan axis named `name` whose points are the bin centers of `axis`,
//! each converted from `coords` to radians by backTransform.
//! The returned pointer comes directly from newListScan; in this file it is handed to the
//! CoordSystem1D constructor.
Scale* createAxisFrom(const Scale& axis, Coords coords, const std::string& name, double wavelength)
{
    std::vector<double> ret;
    ret.reserve(axis.size());
    for (const double value : axis.binCenters())
        ret.emplace_back(backTransform(value, coords, wavelength));
    return newListScan(name, ret);
}

} // namespace


//  ************************************************************************************************
//  class CoordSystem1D
//  ************************************************************************************************

//! Creates a one-dimensional coordinate system around `axis`, which is passed to
//! ICoordSystem as the single entry of its axis list. Ownership presumably moves to
//! ICoordSystem: the derived copy constructors pass freshly cloned axes here. TODO confirm.
CoordSystem1D::CoordSystem1D(const Scale* axis)
    : ICoordSystem({axis}) {}

double CoordSystem1D::calculateMin(size_t i_axis, Coords units) const
{
    ASSERT(i_axis == 0);
    units = substituteDefaultUnits(units);
    if (units == Coords::NBINS)
        return 0.0;
    auto translator = getTraslatorTo(units);
    return translator(m_axes[0]->binCenter(0));
}

double CoordSystem1D::calculateMax(size_t i_axis, Coords units) const
{
    ASSERT(i_axis == 0);
    units = substituteDefaultUnits(units);
    if (units == Coords::NBINS)
        return static_cast<double>(m_axes[0]->size());
    auto translator = getTraslatorTo(units);
    return translator(m_axes[0]->binCenter(m_axes[0]->size() - 1));
}

//! Returns a new axis expressed in `units` (after default-unit substitution).
//! NBINS yields an equidistant axis with one bin per original bin, spanning
//! [calculateMin, calculateMax]. Any other unit yields a list scan made of the original
//! bin centers passed through the corresponding translator.
Scale* CoordSystem1D::convertedAxis(size_t i_axis, Coords units) const
{
    ASSERT(i_axis == 0);
    const Coords target = substituteDefaultUnits(units);

    if (target == Coords::NBINS) {
        const double lo = calculateMin(0, target);
        const double hi = calculateMax(0, target);
        return newEquiDivision(nameOfAxis(0, target), m_axes[0]->size(), lo, hi);
    }

    const std::function<double(double)> toTarget = getTraslatorTo(target);
    auto points = m_axes[0]->binCenters();
    for (double& p : points)
        p = toTarget(p);
    return newListScan(nameOfAxis(0, target), points);
}


//  ************************************************************************************************
//  class AngularReflectometryCoords
//  ************************************************************************************************

//! Builds the coordinate system from `axis`, whose values are given in `axis_units`
//! (RADIANS, DEGREES or QSPACE). The stored axis holds the bin centers converted to radians.
//! @throws std::runtime_error if the converted axis is not within [0, pi/2]; NaN limits are
//!         rejected too.
AngularReflectometryCoords::AngularReflectometryCoords(double wavelength, const Scale& axis,
                                                       Coords axis_units)
    : CoordSystem1D(createAxisFrom(axis, axis_units, nameOfAxis0(axis_units), wavelength))
    , m_wavelength(wavelength)
{
    const double lo = m_axes[0]->min();
    const double hi = m_axes[0]->max();
    // Written as a negated "inside" test: every comparison with NaN is false, so the former
    // `lo < 0 || hi > pi/2` form silently accepted NaN limits.
    if (!(lo >= 0 && hi <= pi / 2))
        throw std::runtime_error("Error in CoordSystem1D: input axis range is out of bounds");
}

//! Copy constructor: clones the stored (radian) axis and copies the wavelength.
AngularReflectometryCoords::AngularReflectometryCoords(const AngularReflectometryCoords& other)
    : CoordSystem1D(other.m_axes[0]->clone()), m_wavelength(other.m_wavelength)
{
}

AngularReflectometryCoords::~AngularReflectometryCoords() = default;

//! Returns a newly allocated copy, produced by the copy constructor.
AngularReflectometryCoords* AngularReflectometryCoords::clone() const
{
    auto* copy = new AngularReflectometryCoords(*this);
    return copy;
}

std::vector<Coords> AngularReflectometryCoords::availableUnits() const
{
    return availableUnits0();
}

std::vector<Coords> AngularReflectometryCoords::availableUnits0() // static
{
    return {Coords::NBINS, Coords::RADIANS, Coords::DEGREES, Coords::QSPACE};
}

//! Axis label for the single axis (i_axis must be 0); forwards to nameOfAxis0().
std::string AngularReflectometryCoords::nameOfAxis(size_t i_axis, const Coords units) const
{
    ASSERT(i_axis == 0);
    return AngularReflectometryCoords::nameOfAxis0(units);
}

//! Static axis label for `units`. DEGREES, and any unit not listed explicitly, fall back
//! to the degree label.
std::string AngularReflectometryCoords::nameOfAxis0(const Coords units) // static
{
    if (units == Coords::NBINS)
        return "X [nbins]";
    if (units == Coords::RADIANS)
        return "alpha_i [rad]";
    if (units == Coords::QSPACE)
        return "Q [1/nm]";
    return "alpha_i [deg]";
}

//! Returns a function mapping a stored axis value (radians) to `units`.
//! Supports RADIANS (identity), DEGREES and QSPACE (using the stored wavelength);
//! any other unit triggers ASSERT.
std::function<double(double)> AngularReflectometryCoords::getTraslatorTo(Coords units) const
{
    if (units == Coords::RADIANS)
        return [](double rad) { return rad; };
    if (units == Coords::DEGREES)
        return [](double rad) { return Units::rad2deg(rad); };
    if (units == Coords::QSPACE) {
        const double wl = m_wavelength;
        return [wl](double rad) { return getQ(wl, rad); };
    }
    ASSERT(false);
}


//  ************************************************************************************************
//  class WavenumberReflectometryCoords
//  ************************************************************************************************

//! Wraps `axis` directly; its values are treated as q (see getTraslatorTo).
WavenumberReflectometryCoords::WavenumberReflectometryCoords(const Scale* axis)
    : CoordSystem1D(axis) {}

//! Copy constructor: clones the stored axis.
WavenumberReflectometryCoords::WavenumberReflectometryCoords(
    const WavenumberReflectometryCoords& other)
    : CoordSystem1D(other.m_axes[0]->clone()) {}

WavenumberReflectometryCoords::~WavenumberReflectometryCoords() = default;

//! Returns a newly allocated copy, produced by the copy constructor.
WavenumberReflectometryCoords* WavenumberReflectometryCoords::clone() const
{
    auto* copy = new WavenumberReflectometryCoords(*this);
    return copy;
}

//! Returns the list of all available units
std::vector<Coords> WavenumberReflectometryCoords::availableUnits() const
{
    return {Coords::NBINS, Coords::QSPACE};
}

//! Axis label for the single axis (i_axis must be 0): "X [nbins]" for NBINS,
//! "Q [1/nm]" for QSPACE and for any other unit.
std::string WavenumberReflectometryCoords::nameOfAxis(size_t i_axis, const Coords units) const
{
    ASSERT(i_axis == 0);
    return units == Coords::NBINS ? "X [nbins]" : "Q [1/nm]";
}

//! Returns translating functional (inv. nm --> desired units).
//! Only QSPACE is supported, and it maps to the identity; any other unit triggers ASSERT.
std::function<double(double)> WavenumberReflectometryCoords::getTraslatorTo(Coords units) const
{
    if (units != Coords::QSPACE) {
        ASSERT(false);
    }
    return [](double q) { return q; };
}
