#include "k230_wave.h"
#include <iostream>
#include <cstring>
#include <algorithm>

// Format tag for PCM data in the "fmt " chunk: the only format this module
// accepts when reading and the one it always writes.
const uint16_t WAVE_FORMAT_PCM{0x0001};

// Helper functions for reading/writing binary data
namespace {

// Reads sizeof(T) bytes from `is` and assembles them as a little-endian
// unsigned integer of type T.
// The byte buffer is zero-initialised, so a short read (EOF) yields a
// deterministic value built from the bytes that were available instead of
// reading indeterminate memory. Callers that care about truncation must
// check the stream state afterwards.
template<typename T>
T read_le(std::istream& is) {
    unsigned char bytes[sizeof(T)] = {};
    is.read(reinterpret_cast<char*>(bytes), sizeof(T));

    T value = 0;
    for (size_t i = 0; i < sizeof(T); ++i) {
        // Narrow types promote to int for the shift/or; cast back explicitly.
        value = static_cast<T>(value | (static_cast<T>(bytes[i]) << (i * 8)));
    }
    return value;
}

// Writes `value` to `os` as sizeof(T) bytes, least-significant byte first.
template<typename T>
void write_le(std::ostream& os, T value) {
    for (size_t i = 0; i < sizeof(T); ++i) {
        os.put(static_cast<char>((value >> (i * 8)) & 0xFF));
    }
}

// Reads an 8-byte chunk header: a 4-character id followed by a little-endian
// uint32 size.
// Returns true and fills `name`/`size` on success. Returns false, leaving
// `name` and `size` untouched, if either field is truncated.
bool read_chunk_header(std::istream& is, std::string& name, uint32_t& size) {
    char name_buf[4];
    if (!is.read(name_buf, 4) || is.gcount() != 4) {
        return false;
    }
    const uint32_t chunk_size = read_le<uint32_t>(is);
    if (!is) {
        // Only part of the size field was available; the value is meaningless.
        return false;
    }
    name.assign(name_buf, 4);
    size = chunk_size;
    return true;
}

} // namespace

// Utility function implementation
// Cheap format probe: returns true when `filename` can be opened and its first
// 12 bytes carry the "RIFF" id at offset 0 and the "WAVE" form type at
// offset 8. No chunk parsing is done; false for unreadable or short files.
bool is_valid_wave_file(const std::string& filename) {
    std::ifstream in(filename, std::ios::binary);
    if (!in.is_open()) {
        return false;
    }

    char header[12];
    in.read(header, 12);
    if (!in || in.gcount() != 12) {
        return false;
    }

    const bool has_riff_id = std::string(header, 4) == "RIFF";
    const bool has_wave_id = std::string(header + 8, 4) == "WAVE";
    return has_riff_id && has_wave_id;
}

// WaveRead implementation
// Opens `filename` in binary mode and parses its RIFF/WAVE header via initfp().
// This object owns the stream (i_opened_the_file_ is true), so close() and the
// destructor close it.
// Throws WaveError if the file cannot be opened; initfp() throws WaveError if
// the header is malformed.
WaveRead::WaveRead(const std::string& filename)
    : file_(&file_stream_)  // our own member; only its address is taken here
    , i_opened_the_file_(true)
    , nchannels_(0)
    , sampwidth_(0)
    , framerate_(0)
    , nframes_(0)
    , comptype_("NONE")
    , compname_("not compressed")
    , framesize_(0)
    , soundpos_(0)
    , data_start_pos_(0)
    , data_size_(0)
    , data_seek_needed_(true)
    , fmt_chunk_read_(false) {

    file_stream_.open(filename, std::ios::binary);
    if (!file_stream_.is_open()) {
        throw WaveError("Cannot open file: " + filename);
    }
    initfp(file_stream_);
}

// Parses a RIFF/WAVE header from a caller-supplied stream.
// The stream is only referenced by pointer (file_) and is never closed by this
// object, so it must stay valid for as long as frames are read from it.
// initfp() clears the stream's error flags and seeks it back to position 0
// before parsing; it throws WaveError if the header is malformed.
WaveRead::WaveRead(std::istream& stream)
    : file_(&stream)
    , i_opened_the_file_(false)
    , nchannels_(0)
    , sampwidth_(0)
    , framerate_(0)
    , nframes_(0)
    , comptype_("NONE")
    , compname_("not compressed")
    , framesize_(0)
    , soundpos_(0)
    , data_start_pos_(0)
    , data_size_(0)
    , data_seek_needed_(true)
    , fmt_chunk_read_(false) {

    initfp(stream);
}

// Move constructor. Takes over `other`'s stream and all parsed header state,
// leaving `other` detached (file_ == nullptr, non-owning) so that its
// destructor does nothing.
//
// When `other` owned its stream, other.file_ pointed at other.file_stream_.
// That stream object is moved into this->file_stream_, so file_ must be
// re-pointed at our own member. Otherwise it would keep pointing into `other`
// and dangle once `other` is destroyed. A caller-supplied stream pointer is
// copied unchanged.
WaveRead::WaveRead(WaveRead&& other) noexcept
    : file_(other.file_ == &other.file_stream_ ? &file_stream_ : other.file_)
    , i_opened_the_file_(other.i_opened_the_file_)
    , file_stream_(std::move(other.file_stream_))
    , nchannels_(other.nchannels_)
    , sampwidth_(other.sampwidth_)
    , framerate_(other.framerate_)
    , nframes_(other.nframes_)
    , comptype_(std::move(other.comptype_))
    , compname_(std::move(other.compname_))
    , framesize_(other.framesize_)
    , soundpos_(other.soundpos_)
    , data_start_pos_(other.data_start_pos_)
    , data_size_(other.data_size_)
    , data_seek_needed_(other.data_seek_needed_)
    , fmt_chunk_read_(other.fmt_chunk_read_) {

    other.file_ = nullptr;
    other.i_opened_the_file_ = false;
}

// Move assignment. First closes this reader (closing its stream if owned),
// then takes over `other`'s stream and header state, leaving `other` detached
// (file_ == nullptr, non-owning).
//
// If `other` owned its stream, other.file_ pointed at other.file_stream_,
// which is moved into our file_stream_. file_ is therefore re-pointed at our
// own member instead of being copied, so it never points into `other`.
WaveRead& WaveRead::operator=(WaveRead&& other) noexcept {
    if (this != &other) {
        close();
        file_ = (other.file_ == &other.file_stream_) ? &file_stream_ : other.file_;
        i_opened_the_file_ = other.i_opened_the_file_;
        file_stream_ = std::move(other.file_stream_);
        nchannels_ = other.nchannels_;
        sampwidth_ = other.sampwidth_;
        framerate_ = other.framerate_;
        nframes_ = other.nframes_;
        comptype_ = std::move(other.comptype_);
        compname_ = std::move(other.compname_);
        framesize_ = other.framesize_;
        soundpos_ = other.soundpos_;
        data_start_pos_ = other.data_start_pos_;
        data_size_ = other.data_size_;
        data_seek_needed_ = other.data_seek_needed_;
        fmt_chunk_read_ = other.fmt_chunk_read_;

        other.file_ = nullptr;
        other.i_opened_the_file_ = false;
    }
    return *this;
}

// Destructor: releases the stream through close(). A stream this object
// opened is closed; a caller-supplied stream is left open.
WaveRead::~WaveRead() {
    this->close();
}

// Parses the RIFF/WAVE structure of `file`, starting from its beginning.
//
// Validates the 12-byte preamble ("RIFF", a size field, "WAVE"), then walks
// the chunk list:
//   - the "fmt " chunk is decoded by read_fmt_chunk();
//   - the first "data" chunk sets data_start_pos_, data_size_ and nframes_;
//   - every other chunk is skipped.
// On return the read position is logically frame 0. The physical seek to the
// sample data is deferred to read_frames().
//
// Throws WaveError for:
//   - a short or invalid preamble;
//   - a data chunk that precedes the fmt chunk;
//   - a zero frame size;
//   - either required chunk being missing.
void WaveRead::initfp(std::istream& file) {
    // Start from a clean state at the very beginning of the stream.
    file.clear();
    file.seekg(0, std::ios::beg);

    char riff_header[12];
    if (!file.read(riff_header, 12) || file.gcount() != 12) {
        throw WaveError("Cannot read RIFF header");
    }
    if (std::string(riff_header, 4) != "RIFF") {
        throw WaveError("File does not start with RIFF id");
    }
    if (std::string(riff_header + 8, 4) != "WAVE") {
        throw WaveError("Not a WAVE file");
    }

    bool data_chunk_found = false;

    while (file.good()) {
        std::string chunk_name;
        uint32_t chunk_size = 0;
        if (!read_chunk_header(file, chunk_name, chunk_size)) {
            break;
        }

        const std::streampos chunk_start = file.tellg();

        if (chunk_name == "fmt ") {
            read_fmt_chunk(file, chunk_size);
            fmt_chunk_read_ = true;
        } else if (chunk_name == "data") {
            if (!fmt_chunk_read_) {
                throw WaveError("Data chunk before fmt chunk");
            }
            if (framesize_ == 0) {
                // Guards the division below against a degenerate fmt chunk
                // (zero channels or zero-width samples).
                throw WaveError("Bad frame size in fmt chunk");
            }
            data_start_pos_ = chunk_start;
            data_size_ = chunk_size;
            nframes_ = chunk_size / framesize_;
            data_chunk_found = true;
            // Only the first data chunk is used. read_frames() seeks back to
            // data_start_pos_, so there is no need to skip past it here.
            break;
        }

        // Chunks are word-aligned: an odd-sized chunk is followed by a pad byte.
        const std::streamoff skip =
            static_cast<std::streamoff>(chunk_size) + static_cast<std::streamoff>(chunk_size % 2);
        file.seekg(chunk_start + skip);
    }

    if (!fmt_chunk_read_ || !data_chunk_found) {
        throw WaveError("Fmt chunk and/or data chunk missing");
    }

    // Defer the seek to the sample data until the first read_frames() call.
    data_seek_needed_ = true;
    soundpos_ = 0;
}

void WaveRead::read_fmt_chunk(std::istream& chunk, uint32_t chunk_size) {
    if (chunk_size < 16) {
        throw WaveError("Fmt chunk too small");
    }

    uint16_t format_tag = read_le<uint16_t>(chunk);
    nchannels_ = read_le<uint16_t>(chunk);
    framerate_ = read_le<uint32_t>(chunk);
    /* uint32_t avg_bytes_per_sec = */ read_le<uint32_t>(chunk);
    uint16_t block_align = read_le<uint16_t>(chunk);

    if (format_tag == WAVE_FORMAT_PCM) {
        uint16_t bits_per_sample = read_le<uint16_t>(chunk);
        sampwidth_ = (bits_per_sample + 7) / 8;
    } else {
        throw WaveError("Unknown format: " + std::to_string(format_tag));
    }

    framesize_ = nchannels_ * sampwidth_;
    comptype_ = "NONE";
    compname_ = "not compressed";
}

// Resets the logical read position to frame 0. The physical seek back to the
// start of the data chunk happens lazily in the next read_frames() call.
void WaveRead::rewind() {
    soundpos_ = 0;
    data_seek_needed_ = true;
}

// Detaches the reader from its stream. The stream is closed only when this
// object opened it; a caller-supplied stream is left untouched.
void WaveRead::close() {
    const bool owns_open_stream = i_opened_the_file_ && file_stream_.is_open();
    if (owns_open_stream) {
        file_stream_.close();
    }
    file_ = nullptr;
}

int WaveRead::tell() const {
    return static_cast<int>(soundpos_);
}

int WaveRead::get_channels() const {
    return nchannels_;
}

int WaveRead::get_frames() const {
    return nframes_;
}

int WaveRead::get_sampwidth() const {
    return sampwidth_;
}

int WaveRead::get_framerate() const {
    return framerate_;
}

std::string WaveRead::get_comptype() const {
    return comptype_;
}

std::string WaveRead::get_compname() const {
    return compname_;
}

WaveParams WaveRead::get_params() const {
    return WaveParams{
        get_channels(),
        get_sampwidth(),
        get_framerate(),
        get_frames(),
        get_comptype(),
        get_compname()
    };
}

std::vector<char> WaveRead::read_frames(int nframes) {
    if (nframes <= 0) {
        return {};
    }

    if (data_seek_needed_) {
        file_->seekg(data_start_pos_);
        data_seek_needed_ = false;
    }

    size_t bytes_to_read = nframes * framesize_;

    // Don't read past the end of data
    size_t max_bytes = data_size_ - (soundpos_ * framesize_);
    if (bytes_to_read > max_bytes) {
        bytes_to_read = max_bytes;
        nframes = bytes_to_read / framesize_;
    }

    if (bytes_to_read == 0) {
        return {};
    }

    std::vector<char> data(bytes_to_read);
    file_->read(data.data(), bytes_to_read);

    size_t bytes_read = file_->gcount();
    if (bytes_read < bytes_to_read) {
        data.resize(bytes_read);
    }

    soundpos_ += data.size() / framesize_;
    return data;
}

// WaveWrite implementation
// Opens `filename` for writing in binary mode. This object owns the stream
// (i_opened_the_file_ is true), so close() and the destructor close it.
// Nothing is written yet: the header is emitted on the first frame write or on
// close(), once channels, sample width and frame rate have been set.
// Throws WaveError if the file cannot be opened.
WaveWrite::WaveWrite(const std::string& filename)
    : file_(&file_stream_)  // our own member; only its address is taken here
    , i_opened_the_file_(true)
    , nchannels_(0)
    , sampwidth_(0)
    , framerate_(0)
    , nframes_(0)
    , comptype_("NONE")
    , compname_("not compressed")
    , nframes_written_(0)
    , data_written_(0)
    , data_length_(0)
    , header_written_(false)
    , form_length_pos_(0)
    , data_length_pos_(0) {

    file_stream_.open(filename, std::ios::binary);
    if (!file_stream_.is_open()) {
        throw WaveError("Cannot open file for writing: " + filename);
    }
}

// Writes WAVE data to a caller-supplied stream. The stream is only referenced
// by pointer (file_) and is never closed by this object, so it must stay valid
// until close() or destruction.
// The header's size fields are rewritten with seekp() when the amount of data
// written differs from the initial estimate. For a non-seekable stream, set the
// exact frame count with set_frames() beforehand.
WaveWrite::WaveWrite(std::ostream& stream)
    : file_(&stream)
    , i_opened_the_file_(false)
    , nchannels_(0)
    , sampwidth_(0)
    , framerate_(0)
    , nframes_(0)
    , comptype_("NONE")
    , compname_("not compressed")
    , nframes_written_(0)
    , data_written_(0)
    , data_length_(0)
    , header_written_(false)
    , form_length_pos_(0)
    , data_length_pos_(0) {
}

// Move constructor. Takes over `other`'s stream and all writer state, leaving
// `other` detached (file_ == nullptr, non-owning). Its destructor therefore
// neither finalizes the header nor closes anything.
//
// When `other` owned its stream, other.file_ pointed at other.file_stream_.
// That stream object is moved into this->file_stream_, so file_ must be
// re-pointed at our own member. Otherwise it would keep pointing into `other`
// and dangle once `other` is destroyed. A caller-supplied stream pointer is
// copied unchanged.
WaveWrite::WaveWrite(WaveWrite&& other) noexcept
    : file_(other.file_ == &other.file_stream_ ? &file_stream_ : other.file_)
    , i_opened_the_file_(other.i_opened_the_file_)
    , file_stream_(std::move(other.file_stream_))
    , nchannels_(other.nchannels_)
    , sampwidth_(other.sampwidth_)
    , framerate_(other.framerate_)
    , nframes_(other.nframes_)
    , comptype_(std::move(other.comptype_))
    , compname_(std::move(other.compname_))
    , nframes_written_(other.nframes_written_)
    , data_written_(other.data_written_)
    , data_length_(other.data_length_)
    , header_written_(other.header_written_)
    , form_length_pos_(other.form_length_pos_)
    , data_length_pos_(other.data_length_pos_) {

    other.file_ = nullptr;
    other.i_opened_the_file_ = false;
}

// Move assignment. First closes this writer, which finalizes its header on a
// best-effort basis and closes its stream if owned. It then takes over
// `other`'s stream and writer state, leaving `other` detached
// (file_ == nullptr, non-owning).
//
// If `other` owned its stream, other.file_ pointed at other.file_stream_,
// which is moved into our file_stream_. file_ is therefore re-pointed at our
// own member instead of being copied, so it never points into `other`.
WaveWrite& WaveWrite::operator=(WaveWrite&& other) noexcept {
    if (this != &other) {
        close();
        file_ = (other.file_ == &other.file_stream_) ? &file_stream_ : other.file_;
        i_opened_the_file_ = other.i_opened_the_file_;
        file_stream_ = std::move(other.file_stream_);
        nchannels_ = other.nchannels_;
        sampwidth_ = other.sampwidth_;
        framerate_ = other.framerate_;
        nframes_ = other.nframes_;
        comptype_ = std::move(other.comptype_);
        compname_ = std::move(other.compname_);
        nframes_written_ = other.nframes_written_;
        data_written_ = other.data_written_;
        data_length_ = other.data_length_;
        header_written_ = other.header_written_;
        form_length_pos_ = other.form_length_pos_;
        data_length_pos_ = other.data_length_pos_;

        other.file_ = nullptr;
        other.i_opened_the_file_ = false;
    }
    return *this;
}

// Destructor: delegates to close(), which finalizes the header on a
// best-effort basis (errors are swallowed) and closes an owned stream.
WaveWrite::~WaveWrite() {
    this->close();
}

// Sets the channel count. Only allowed before any frame data is written.
// Throws WaveError if data was already written or nchannels < 1.
void WaveWrite::set_channels(int nchannels) {
    if (data_written_ > 0) {
        throw WaveError("Cannot change parameters after starting to write");
    }
    if (nchannels >= 1) {
        nchannels_ = nchannels;
        return;
    }
    throw WaveError("Bad number of channels");
}

// Returns the channel count; throws WaveError if it has not been set.
int WaveWrite::get_channels() const {
    if (nchannels_ != 0) {
        return nchannels_;
    }
    throw WaveError("Number of channels not set");
}

// Sets the sample width in bytes (1 to 4). Only allowed before any frame
// data is written. Throws WaveError otherwise or when out of range.
void WaveWrite::set_sampwidth(int sampwidth) {
    if (data_written_ > 0) {
        throw WaveError("Cannot change parameters after starting to write");
    }
    const bool in_range = sampwidth >= 1 && sampwidth <= 4;
    if (!in_range) {
        throw WaveError("Bad sample width");
    }
    sampwidth_ = sampwidth;
}

// Returns the sample width in bytes; throws WaveError if it has not been set.
int WaveWrite::get_sampwidth() const {
    if (sampwidth_ != 0) {
        return sampwidth_;
    }
    throw WaveError("Sample width not set");
}

// Sets the sampling rate (must be positive). Only allowed before any frame
// data is written. Throws WaveError otherwise.
void WaveWrite::set_framerate(int framerate) {
    if (data_written_ > 0) {
        throw WaveError("Cannot change parameters after starting to write");
    }
    if (framerate > 0) {
        framerate_ = framerate;
        return;
    }
    throw WaveError("Bad frame rate");
}

// Returns the sampling rate; throws WaveError if it has not been set.
int WaveWrite::get_framerate() const {
    if (framerate_ != 0) {
        return framerate_;
    }
    throw WaveError("Frame rate not set");
}

// Sets the expected number of frames, used to size the header when it is
// first written. 0 means "derive it from the size of the first write".
// Only allowed before any frame data is written.
// Throws WaveError if data was already written, or if nframes is negative:
// a negative count would produce a bogus data length in the header.
void WaveWrite::set_frames(int nframes) {
    if (data_written_ > 0) {
        throw WaveError("Cannot change parameters after starting to write");
    }
    if (nframes < 0) {
        throw WaveError("Bad number of frames");
    }
    nframes_ = nframes;
}

// Returns the number of frames actually written so far. This is not the value
// passed to set_frames().
int WaveWrite::get_frames() const {
    return nframes_written_;
}

// Sets the compression type and name. Only "NONE" is supported, and only
// before any frame data is written. Throws WaveError otherwise.
void WaveWrite::set_comptype(const std::string& comptype, const std::string& compname) {
    if (data_written_ > 0) {
        throw WaveError("Cannot change parameters after starting to write");
    }
    const bool supported = (comptype == "NONE");
    if (!supported) {
        throw WaveError("Unsupported compression type");
    }
    comptype_ = comptype;
    compname_ = compname;
}

// Compression type as last set (defaults to "NONE").
std::string WaveWrite::get_comptype() const {
    std::string result = comptype_;
    return result;
}

// Compression name as last set (defaults to "not compressed").
std::string WaveWrite::get_compname() const {
    std::string result = compname_;
    return result;
}

void WaveWrite::set_params(const WaveParams& params) {
    if (data_written_ > 0) {
        throw WaveError("Cannot change parameters after starting to write");
    }
    set_channels(params.nchannels);
    set_sampwidth(params.sampwidth);
    set_framerate(params.framerate);
    set_frames(params.nframes);
    set_comptype(params.comptype, params.compname);
}

// Returns the configured parameters. The frame count is the value from
// set_frames() (or as derived when the header was written).
// Throws WaveError unless channels, sample width and frame rate are all set.
WaveParams WaveWrite::get_params() const {
    const bool all_set = nchannels_ != 0 && sampwidth_ != 0 && framerate_ != 0;
    if (!all_set) {
        throw WaveError("Not all parameters set");
    }
    WaveParams params{nchannels_, sampwidth_, framerate_, nframes_, comptype_, compname_};
    return params;
}

// Number of frames written so far.
int WaveWrite::tell() const {
    const int written = nframes_written_;
    return written;
}

void WaveWrite::write_frames_raw(const std::vector<char>& data) {
    ensure_header_written(data.size());

    size_t framesize = nchannels_ * sampwidth_;
    int nframes = data.size() / framesize;

    file_->write(data.data(), data.size());
    data_written_ += data.size();
    nframes_written_ += nframes;
}

// Writes frame bytes, then rewrites the header's size fields if they no
// longer match the amount of data written.
void WaveWrite::write_frames(const std::vector<char>& data) {
    write_frames_raw(data);
    const bool header_stale = (data_length_ != data_written_);
    if (header_stale) {
        patch_header();
    }
}

// Finalizes and detaches the writer. While the stream is still usable, this:
//   - makes sure the header exists;
//   - patches its size fields if needed;
//   - flushes the stream.
// This is best effort: any exception is swallowed, because the destructor
// calls close(). The stream is then closed if this object opened it, and
// file_ is cleared.
void WaveWrite::close() {
    const bool can_finalize = (file_ != nullptr) && file_->good();
    if (can_finalize) {
        try {
            ensure_header_written(0);
            if (data_length_ != data_written_) {
                patch_header();
            }
            file_->flush();
        } catch (...) {
            // Deliberately ignored: close() must not throw.
        }
    }

    const bool owns_open_stream = i_opened_the_file_ && file_stream_.is_open();
    if (owns_open_stream) {
        file_stream_.close();
    }
    file_ = nullptr;
}

// Emits the header on first use. Throws WaveError if channels, sample width or
// frame rate have not been set. `datasize` is forwarded to write_header() as
// the initial data-length estimate.
void WaveWrite::ensure_header_written(size_t datasize) {
    if (header_written_) {
        return;
    }
    if (nchannels_ == 0) {
        throw WaveError("Number of channels not specified");
    }
    if (sampwidth_ == 0) {
        throw WaveError("Sample width not specified");
    }
    if (framerate_ == 0) {
        throw WaveError("Sampling rate not specified");
    }
    write_header(datasize);
}

// Writes the 44-byte RIFF/WAVE header, in this order:
//   - "RIFF" + size;
//   - "WAVE";
//   - a 16-byte "fmt " chunk describing PCM data;
//   - the "data" chunk id + size.
// Does nothing if the header was already written.
//
// `initlength` is the byte count of the first write. When no frame count was
// set (nframes_ == 0) it determines nframes_. Both size fields are computed
// from data_length_. Their stream positions are recorded so that
// patch_header() can rewrite them if the data actually written differs.
// NOTE(review): if the stream cannot report its position, tellp() returns -1
// and that value is stored as-is.
void WaveWrite::write_header(size_t initlength) {
    if (header_written_) return;

    file_->write("RIFF", 4);

    // No frame count supplied: derive it from the size of the first write.
    if (nframes_ == 0) {
        nframes_ = initlength / (nchannels_ * sampwidth_);
    }
    data_length_ = nframes_ * nchannels_ * sampwidth_;

    // Save position for later patching
    form_length_pos_ = file_->tellp();

    // Write file size (will be updated later).
    // 36 = "WAVE"(4) + fmt chunk header and body (8 + 16) + data chunk header (8).
    write_le<uint32_t>(*file_, 36 + data_length_);

    file_->write("WAVE", 4);

    // Write fmt chunk
    file_->write("fmt ", 4);
    write_le<uint32_t>(*file_, 16); // fmt chunk size
    write_le<uint16_t>(*file_, WAVE_FORMAT_PCM);
    write_le<uint16_t>(*file_, nchannels_);
    write_le<uint32_t>(*file_, framerate_);
    write_le<uint32_t>(*file_, nchannels_ * framerate_ * sampwidth_); // avg bytes per sec
    write_le<uint16_t>(*file_, nchannels_ * sampwidth_); // block align
    write_le<uint16_t>(*file_, sampwidth_ * 8); // bits per sample

    // Write data chunk header
    file_->write("data", 4);
    data_length_pos_ = file_->tellp();
    write_le<uint32_t>(*file_, data_length_);

    header_written_ = true;
}

void WaveWrite::patch_header() {
    if (!header_written_) return;
    if (data_written_ == data_length_) return;

    auto curpos = file_->tellp();

    // Update RIFF chunk size
    file_->seekp(form_length_pos_);
    write_le<uint32_t>(*file_, 36 + data_written_);

    // Update data chunk size
    file_->seekp(data_length_pos_);
    write_le<uint32_t>(*file_, data_written_);

    file_->seekp(curpos);
    data_length_ = data_written_;
}

// Open functions implementation
std::unique_ptr<WaveRead> open_read(const std::string& filename) {
    return std::make_unique<WaveRead>(filename);
}

std::unique_ptr<WaveRead> open_read(std::istream& stream) {
    return std::make_unique<WaveRead>(stream);
}

std::unique_ptr<WaveWrite> open_write(const std::string& filename) {
    return std::make_unique<WaveWrite>(filename);
}

std::unique_ptr<WaveWrite> open_write(std::ostream& stream) {
    return std::make_unique<WaveWrite>(stream);
}

// Opens `filename` as a WAVE reader ("r"/"rb") or writer ("w"/"wb").
// Throws WaveError for any other mode string.
std::unique_ptr<WaveFile> open(const std::string& filename, const std::string& mode) {
    const bool want_read = (mode == "r" || mode == "rb");
    if (want_read) {
        return std::make_unique<WaveReadFile>(filename);
    }

    const bool want_write = (mode == "w" || mode == "wb");
    if (!want_write) {
        throw WaveError("Mode must be 'r', 'rb', 'w', or 'wb'");
    }
    return std::make_unique<WaveWriteFile>(filename);
}