/* Copyright (c) 2025, Canaan Bright Sight Co., Ltd
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 1. Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

#include "api/HardwareSPI.h"

#include "drv_fpioa.h"
#include "drv_spi.h"

/*
 * Lock helpers used by beginTransaction() / endTransaction().
 * Unless CONFIG_DISABLE_HAL_LOCKS is set they lock/unlock the paramLock
 * mutex; otherwise they expand to nothing.
 */
#if !CONFIG_DISABLE_HAL_LOCKS
#define SPI_PARAM_LOCK() pthread_mutex_lock(&paramLock)
#define SPI_PARAM_UNLOCK() pthread_mutex_unlock(&paramLock)
#else
#define SPI_PARAM_LOCK()
#define SPI_PARAM_UNLOCK()
#endif

namespace arduino {

/* Pre-constructed objects for bus ids 0, 1 and 2 (the ids begin() accepts). */
SPIClass spi0(0);
SPIClass spi1(1);
SPIClass spi2(2);

/**
 * Construct an idle wrapper for bus `spi_bus`.
 *
 * No driver instance is created here: _spi starts as NULL, every pin is -1,
 * and the stored defaults are _freq = 1000000, SPI_MSBFIRST and SPI_MODE0.
 * When HAL locks are enabled the parameter mutex is initialised as well.
 *
 * @param spi_bus bus id stored in _spi_num
 */
SPIClass::SPIClass(uint8_t spi_bus)
    : _spi_num(spi_bus),
      _spi(NULL),
      _use_hw_ss(false),
      _sck(-1),
      _miso(-1),
      _mosi(-1),
      _ss(-1),
      _freq(1000000),
      _bitOrder(SPI_MSBFIRST),
      _dataMode(SPI_MODE0)
{
#if !CONFIG_DISABLE_HAL_LOCKS
    pthread_mutex_init(&paramLock, NULL);
#endif
}

/*
 * Releases the driver instance through end() and then, when HAL locks are
 * enabled, destroys the paramLock mutex initialised by the constructor.
 */
SPIClass::~SPIClass() {
  end();
#if !CONFIG_DISABLE_HAL_LOCKS
  pthread_mutex_destroy(&paramLock);
#endif
}

bool SPIClass::begin(int8_t sck, int8_t miso, int8_t mosi, int8_t ss) {

#define SPI_SCK_FUNC(id)        \
    ((id) == 0   ? OSPI_CLK     \
     : (id) == 1 ? QSPI0_CLK    \
     : (id) == 2 ? QSPI1_CLK    \
     : FUNC_MAX)

#define SPI_MOSI_FUNC(id)       \
    ((id) == 0   ? OSPI_D0      \
     : (id) == 1 ? QSPI0_D0     \
     : (id) == 2 ? QSPI1_D0     \
     : FUNC_MAX)

#define SPI_MISO_FUNC(id)       \
    ((id) == 0   ? OSPI_D1      \
     : (id) == 1 ? QSPI0_D1     \
     : (id) == 2 ? QSPI1_D1     \
     : FUNC_MAX)

#define SPI_HW_CS(id)           \
    ((id) == 0   ? 14           \
     : (id) == 1 ? 14           \
     : (id) == 2 ? 20           \
     : 14)

    int ret = 0;
    fpioa_err_t err = FPIOA_OK;

    if (_spi_num > 2 || _spi_num < 0) {
        printf("Invalid spi id for SPI%d\n", _spi_num);
        return false;
    }

    if (ss < 0) {
        printf("Invalid ss pin %d for SPI%d, so use hw spi\n", ss, _spi_num);
        _use_hw_ss = true;
    }

    fpioa_func_t f_sck = SPI_SCK_FUNC(_spi_num);
    if (FPIOA_OK != (err = drv_fpioa_validate_pin(sck, f_sck))) {
        printf("Invalid sck pin %d for SPI%d\n", sck, _spi_num);
        return false;
    }

    fpioa_func_t f_mosi = SPI_MOSI_FUNC(_spi_num);
    if (FPIOA_OK != (err = drv_fpioa_validate_pin(mosi, f_mosi))) {
        printf("Invalid mosi pin %d for SPI%d\n", mosi, _spi_num);
        return false;
    }

    fpioa_func_t f_miso = SPI_MISO_FUNC(_spi_num);
    if (FPIOA_OK != (err = drv_fpioa_validate_pin(miso, f_miso))) {
        printf("Invalid miso pin %d for SPI%d\n", miso, _spi_num);
        return false;
    }

    _sck = sck;
    _miso = miso;
    _mosi = mosi;
    _ss = ss;

    ret = drv_spi_inst_create(_spi_num, true, _dataMode, _freq, 8, SPI_HW_CS(_spi_num), SPI_HAL_DATA_LINE_1, &_spi);
    if (ret) {
        printf("%s: Failed to create spi%d instance\n", __func__, _spi_num);
        return false;
    }

    if (!_use_hw_ss) {
        ret = drv_spi_set_cs_mode(_spi, _use_hw_ss);
        if (ret) {
            drv_spi_inst_destroy(&_spi);
            _spi = NULL;
            printf("%s: Fail to set cs mode for SPI%d\n", __func__, _spi_num);
            return false;
        }
    }

    return true;
}

/*
 * Destroy the driver instance if one exists; does nothing otherwise.
 * TODO: recover fpioa ??? (the FPIOA pin setup is not touched here)
 */
void SPIClass::end() {
    if (_spi) {
        drv_spi_inst_destroy(&_spi);
        _spi = NULL;
    }
}

/**
 * Switch between hardware chip select and the software `ss` pin.
 *
 * Nothing happens when no `ss` pin was given to begin() (_ss < 0) or when
 * the requested mode is already active. Before leaving hardware CS the
 * stored `ss` pin is validated against its GPIO function. _use_hw_ss is only
 * updated once the driver accepted the new mode.
 *
 * @param use true for hardware chip select, false for the software ss pin
 */
void SPIClass::setHwCs(bool use) {
    if (_ss < 0 || use == _use_hw_ss) {
        return;
    }

    /* Like every other setter: never hand a NULL instance to the driver. */
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    if (!use && FPIOA_OK != drv_fpioa_validate_pin(_ss, GPIO0 + _ss)) {
        printf("Invalid ss pin %d for SPI%d\n", _ss, _spi_num);
        return;
    }

    if (0 != drv_spi_set_cs_mode(_spi, use)) {
        printf("%s: Fail to set cs mode %d for SPI%d\n", __func__, use, _spi_num);
        return;
    }

    _use_hw_ss = use;
}

/**
 * Forward the requested chip-select polarity to the driver.
 *
 * @param invert value passed to drv_spi_set_cs_polarity()
 */
void SPIClass::setSSInvert(bool invert) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    const bool failed = (drv_spi_set_cs_polarity(_spi, invert) != 0);
    if (failed) {
        printf("%s: Fail to set cs_polarity %d for SPI%d\n", __func__, invert, _spi_num);
    }
}

/**
 * Change the bus baud rate.
 *
 * The driver is only called when `freq` differs from the cached value, and
 * the cache is only updated when the driver reports success.
 *
 * @param freq new value passed to drv_spi_set_baudrate()
 */
void SPIClass::setFrequency(uint32_t freq) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    if (freq == _freq) {
        return; /* already active */
    }

    if (drv_spi_set_baudrate(_spi, freq) == 0) {
        _freq = freq;
        return;
    }

    printf("%s: Fail to set baudrate %d for SPI%d\n", __func__, freq, _spi_num);
}

/**
 * Change the SPI data mode.
 *
 * The driver is only called when `dataMode` differs from the cached value,
 * and the cache is only updated when the driver reports success.
 *
 * @param dataMode new value passed to drv_spi_set_datamode()
 */
void SPIClass::setDataMode(uint8_t dataMode) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    if (dataMode == _dataMode) {
        return; /* already active */
    }

    if (drv_spi_set_datamode(_spi, dataMode) == 0) {
        _dataMode = dataMode;
        return;
    }

    printf("%s: Fail to set data mode %d for SPI%d\n",
           __func__, dataMode, _spi_num);
}

/**
 * Remember the bit order used by the multi-byte write/transfer helpers.
 *
 * Only the cached value changes; no driver call is made here.
 *
 * @param bitOrder new bit order (compared against SPI_LSBFIRST elsewhere)
 */
void SPIClass::setBitOrder(uint8_t bitOrder) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    /* Assigning an equal value is a no-op, so no comparison is needed. */
    _bitOrder = bitOrder;
}

/**
 * Take the parameter lock and apply `settings` to the bus.
 *
 * Clock and data mode are pushed to the driver only when they differ from
 * the cached values, and each cache entry is updated only on success. The
 * bit order is just cached. The lock taken here is released by
 * endTransaction(); without a driver instance nothing is locked.
 *
 * @param settings clock, data mode and bit order for the transaction
 */
void SPIClass::beginTransaction(SPISettings settings) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    SPI_PARAM_LOCK();

    /* Clock */
    if (settings._clock != _freq) {
        if (drv_spi_set_baudrate(_spi, settings._clock) == 0) {
            _freq = settings._clock;
        } else {
            printf("%s: Fail to set baudrate %d for SPI%d\n",
                   __func__, settings._clock, _spi_num);
        }
    }

    /* Data mode */
    if (settings._dataMode != _dataMode) {
        if (drv_spi_set_datamode(_spi, settings._dataMode) == 0) {
            _dataMode = settings._dataMode;
        } else {
            printf("%s: Fail to set data mode %d for SPI%d\n",
                   __func__, settings._dataMode, _spi_num);
        }
    }

    /* Bit order is handled in software, so caching it is enough. */
    _bitOrder = settings._bitOrder;
}

/**
 * Release the parameter lock taken by beginTransaction().
 *
 * Without a driver instance nothing is unlocked, mirroring
 * beginTransaction(), which does not lock in that case either.
 */
void SPIClass::endTransaction() {
    if (_spi) {
        SPI_PARAM_UNLOCK();
        return;
    }

    printf("have no spi instance\n");
}

/**
 * Write a single byte; a driver error is only logged.
 *
 * @param data byte to send
 */
void SPIClass::write(uint8_t data) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    const int rc = drv_spi_write(_spi, &data, sizeof(data), true);
    if (rc < 0) {
        printf("%s: drv_spi_write err: %d\n", __func__, rc);
    }
}

/**
 * Exchange a single byte.
 *
 * @param data byte to send
 * @return the received byte, or 0 when there is no driver instance or the
 *         driver reports an error (the error code is logged, not returned:
 *         narrowed to uint8_t it would be indistinguishable from data)
 */
uint8_t SPIClass::transfer(uint8_t data) {
    if (!_spi) {
        printf("have no spi instance\n");
        return 0;
    }

    uint8_t rx_data = 0;
    const int ret = drv_spi_transfer(_spi, &data, &rx_data, sizeof(data), true);
    if (ret < 0) {
        printf("%s: drv_spi_transfer err: %d\n", __func__, ret);
        return 0;
    }

    return rx_data;
}

/**
 * Write a 16-bit value; a driver error is only logged.
 *
 * When the cached bit order is SPI_LSBFIRST the value is first passed
 * through MSB_16_SET (defined outside this file).
 *
 * @param data value to send
 */
void SPIClass::write16(uint16_t data) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    const bool lsb_first = (_bitOrder == SPI_LSBFIRST);
    if (lsb_first) {
        MSB_16_SET(data, data);
    }

    const int rc = drv_spi_write(_spi, &data, sizeof(data), true);
    if (rc < 0) {
        printf("%s: drv_spi_write err: %d\n", __func__, rc);
    }
}

/**
 * Exchange a 16-bit value.
 *
 * When the cached bit order is SPI_LSBFIRST both the outgoing and the
 * received value are passed through MSB_16_SET (defined outside this file).
 *
 * @param data value to send
 * @return the received value, or 0 when there is no driver instance or the
 *         driver reports an error (the error code is logged, not returned:
 *         narrowed to uint16_t it would be indistinguishable from data)
 */
uint16_t SPIClass::transfer16(uint16_t data) {
    if (!_spi) {
        printf("have no spi instance\n");
        return 0;
    }

    if (_bitOrder == SPI_LSBFIRST) {
        MSB_16_SET(data, data);
    }

    uint16_t rx_data = 0;
    const int ret = drv_spi_transfer(_spi, &data, &rx_data, sizeof(data), true);
    if (ret < 0) {
        printf("%s: drv_spi_transfer err: %d\n", __func__, ret);
        return 0;
    }

    if (_bitOrder == SPI_LSBFIRST) {
        MSB_16_SET(rx_data, rx_data);
    }
    return rx_data;
}

/**
 * Write a 32-bit value; a driver error is only logged.
 *
 * When the cached bit order is SPI_LSBFIRST the value is first passed
 * through MSB_32_SET (defined outside this file).
 *
 * @param data value to send
 */
void SPIClass::write32(uint32_t data) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    const bool lsb_first = (_bitOrder == SPI_LSBFIRST);
    if (lsb_first) {
        MSB_32_SET(data, data);
    }

    const int rc = drv_spi_write(_spi, &data, sizeof(data), true);
    if (rc < 0) {
        printf("%s: drv_spi_write err: %d\n", __func__, rc);
    }
}

/**
 * Exchange a 32-bit value.
 *
 * When the cached bit order is SPI_LSBFIRST both the outgoing and the
 * received value are passed through MSB_32_SET (defined outside this file).
 *
 * @param data value to send
 * @return the received value, or 0 when there is no driver instance or the
 *         driver reports an error (the error code is logged, not returned:
 *         converted to uint32_t it would be indistinguishable from data)
 */
uint32_t SPIClass::transfer32(uint32_t data) {
    if (!_spi) {
        printf("have no spi instance\n");
        return 0;
    }

    if (_bitOrder == SPI_LSBFIRST) {
        MSB_32_SET(data, data);
    }

    uint32_t rx_data = 0;
    const int ret = drv_spi_transfer(_spi, &data, &rx_data, sizeof(data), true);
    if (ret < 0) {
        printf("%s: drv_spi_transfer err: %d\n", __func__, ret);
        return 0;
    }

    if (_bitOrder == SPI_LSBFIRST) {
        MSB_32_SET(rx_data, rx_data);
    }
    return rx_data;
}

/**
 * Exchange one frame of `bits` bits (clamped to 32).
 *
 * The frame width is switched to `bits` for the transfer and always set back
 * to 8 afterwards, including when the transfer fails. With SPI_LSBFIRST the
 * outgoing value, and on success the received one, is passed through the
 * MSB_16/24/32_SET macro selected by the byte count (macros are defined
 * outside this file). A failed width change is only logged; the transfer is
 * still attempted.
 *
 * @param data value to send; bits above `bits` are masked off
 * @param out  where the driver stores the received frame; only post-processed
 *             here when non-NULL
 * @param bits frame width in bits
 */
void SPIClass::transferBits(uint32_t data, uint32_t *out, uint8_t bits) {
    if (!_spi) {
        printf("have no spi instance\n");
        return ;
    }

    if (bits > 32) {
        bits = 32;
    }

    const uint32_t bytes = (bits + 7) / 8;
    const uint32_t mask = (((uint64_t)1 << bits) - 1) & 0xFFFFFFFF;
    const bool lsb_first = (_bitOrder == SPI_LSBFIRST);

    data &= mask;
    if (lsb_first) {
        switch (bytes) {
        case 2: {
            MSB_16_SET(data, data);
            break;
        }
        case 3: {
            MSB_24_SET(data, data);
            break;
        }
        default: {
            MSB_32_SET(data, data);
            break;
        }
        }
    }

    if (drv_spi_set_data_bits(_spi, bits) != 0) {
        printf("%s: Fail to set data bits %d for SPI%d\n",
               __func__, bits, _spi_num);
    }

    const int rc = drv_spi_transfer(_spi, &data, out, 1, true);
    if (rc < 0) {
        printf("%s: drv_spi_transfer err: %d\n", __func__, rc);
    } else if (out && lsb_first) {
        switch (bytes) {
        case 2: {
            MSB_16_SET(*out, *out);
            break;
        }
        case 3: {
            MSB_24_SET(*out, *out);
            break;
        }
        default: {
            MSB_32_SET(*out, *out);
            break;
        }
        }
    }

    /* Success or failure, put the frame width back to 8 bits. */
    if (drv_spi_set_data_bits(_spi, 8) != 0) {
        printf("%s: Fail to set data bits 8 for SPI%d\n", __func__, _spi_num);
    }
}

/**
 * Write `size` bytes in a single driver call; a driver error is only logged.
 *
 * @param data uint8_t * bytes to send
 * @param size uint32_t  number of bytes
 */
void SPIClass::writeBytes(const uint8_t *data, uint32_t size) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    const int rc = drv_spi_write(_spi, data, size, true);
    if (rc < 0) {
        printf("%s: drv_spi_write err: %d\n", __func__, rc);
    }
}

/**
 * In-place exchange: `data` is used as both the transmit and the receive
 * buffer of transferBytes().
 *
 * @param data buffer sent and then overwritten with the received bytes
 * @param size number of bytes
 */
void SPIClass::transfer(void *data, uint32_t size) {
    uint8_t *const bytes = static_cast<uint8_t *>(data);
    transferBytes(bytes, bytes, size);
}

/**
 * Write a pixel buffer in bursts of at most 64 bytes. The last driver
 * argument (cs_change) stays false until the final burst.
 *
 * Every burst is copied into a local word-aligned staging buffer before any
 * byte reordering. The caller's buffer is therefore never modified (it is
 * `const`), never read or written past `size` bytes when `size` is not a
 * multiple of 4, and need not be 4-byte aligned. The bytes handed to the
 * driver are the same as before.
 *
 * With SPI_LSBFIRST each full 32-bit word goes through MSB_PIX_SET; a
 * trailing 2-byte remainder goes through MSB_16_SET and any other remainder
 * is reduced to its low byte.
 * NOTE(review): a 3-byte remainder is treated like a 1-byte one (bytes 1-2
 * are sent as zero) -- kept as is, confirm whether that is intended.
 *
 * @param data void *   source buffer; must not be NULL when size > 0
 * @param size uint32_t number of bytes to write
 */
void SPIClass::writePixels(const void *data, uint32_t size) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    if (size == 0) {
        return;
    }

    if (!data) {
        printf("%s: data buffer is NULL\n", __func__);
        return;
    }

    const uint8_t *src = static_cast<const uint8_t *>(data);
    uint32_t staging[16];
    const uint32_t burst_max = sizeof(staging); /* 64 bytes per driver call */

    while (size) {
        const uint32_t c_len = (size > burst_max) ? burst_max : size;
        const uint32_t c_longs = (c_len + 3) / 4;
        const uint32_t l_bytes = c_len & 3;

        /* Zero the last word first so a partial tail is padded with zeros. */
        staging[c_longs - 1] = 0;
        std::memcpy(staging, src, c_len);

        if (_bitOrder == SPI_LSBFIRST) {
            for (uint32_t i = 0; i < c_longs; i++) {
                if (l_bytes && i == (c_longs - 1)) {
                    if (l_bytes == 2) {
                        MSB_16_SET(staging[i], staging[i]);
                    } else {
                        staging[i] = staging[i] & 0xFF;
                    }
                } else {
                    MSB_PIX_SET(staging[i], staging[i]);
                }
            }
        }

        /* Only the final burst passes cs_change == true. */
        const bool cs_change = (size == c_len);

        const int ret = drv_spi_write(_spi, staging, c_len, cs_change);
        if (ret < 0) {
            printf("drv_spi_write err:%d\n", ret);
        }

        src += c_len;
        size -= c_len;
    }
}

/**
 * Exchange `size` bytes in a single driver call; a driver error is only
 * logged.
 *
 * @param data uint8_t * data buffer. can be NULL for Read Only operation
 * @param out  uint8_t * output buffer. can be NULL for Write Only operation
 * @param size uint32_t  number of bytes
 */
void SPIClass::transferBytes(const uint8_t *data, uint8_t *out, uint32_t size) {
    if (!_spi) {
        printf("have no spi instance\n");
        return;
    }

    const int rc = drv_spi_transfer(_spi, data, out, size, true);
    if (rc < 0) {
        printf("%s: drv_spi_transfer err: %d\n", __func__, rc);
    }
}

/**
 * Send `data` (`size` bytes) `repeat` times, batching as many whole patterns
 * as fit into one 64-byte burst.
 *
 * Nothing is sent when `data` is NULL, `size` is 0 (which used to divide by
 * zero) or `size` exceeds 64. The loop counts remaining repetitions rather
 * than remaining bytes, so `size * repeat` can no longer wrap around 32 bits.
 *
 * @param data uint8_t * pattern bytes
 * @param size uint8_t   pattern length, 1..64 bytes
 * @param repeat uint32_t number of times the pattern is sent
 */
void SPIClass::writePattern(const uint8_t *data, uint8_t size, uint32_t repeat) {
    if (!data || size == 0 || size > 64) {
        return;  //max Hardware FIFO
    }

    // Max number of whole patterns that fit into one 64-byte burst
    const uint8_t per_burst = static_cast<uint8_t>(64 / size);

    while (repeat) {
        const uint8_t count = (repeat > per_burst) ? per_burst : static_cast<uint8_t>(repeat);
        writePattern_(data, size, count);
        repeat -= count;
    }
}

/**
 * Build `repeat` back-to-back copies of the pattern in a 64-byte stack buffer
 * and send them with one writeBytes() call.
 *
 * The caller must keep size * repeat within the 64-byte buffer; no check is
 * made here.
 *
 * @param data   pattern bytes
 * @param size   pattern length in bytes
 * @param repeat number of copies to place in the buffer
 */
void SPIClass::writePattern_(const uint8_t *data, uint8_t size, uint8_t repeat) {
    uint8_t staging[64];
    const uint8_t total = static_cast<uint8_t>(size * repeat);

    size_t filled = 0;
    for (uint8_t rep = 0; rep < repeat; ++rep) {
        for (uint8_t k = 0; k < size; ++k) {
            staging[filled++] = data[k];
        }
    }

    writeBytes(staging, total);
}

} // namespace arduino
