Program Listing for File sparse_file_format.h

Return to documentation for file (src/sparsebase/io/sparse_file_format.h)

/*******************************************************
 * Copyright (c) 2022 SparCity, Amro Alabsi Aljundi, Taha Atahan Akyildiz, Arda
 *Sener All rights reserved.
 *
 * This file is distributed under MIT license.
 * The complete license agreement can be obtained at:
 * https://sparcityeu.github.io/sparsebase/pages/license.html
 ********************************************************/
#ifndef SPARSEBASE_SPARSEBASE_UTILS_IO_SPARSE_FILE_FORMAT_H_
#define SPARSEBASE_SPARSEBASE_UTILS_IO_SPARSE_FILE_FORMAT_H_

#ifdef USE_PIGO
#include "sparsebase/external/pigo/pigo.hpp"
#endif

#include <algorithm>
#include <climits>
#include <cstddef>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

#include "sparsebase/external/json/json.hpp"

namespace sparsebase {

namespace io {

#ifdef USE_PIGO

class SbffWriteFile {
 private:
  pigo::WFile file;

 public:
  SbffWriteFile(std::string filename, size_t size) : file(filename, size) {}

  void Write(char *data, size_t size) { file.parallel_write(data, size); }
};

class SbffReadOnlyFile {
 private:
  pigo::ROFile file;

 public:
  explicit SbffReadOnlyFile(std::string filename) : file(filename) {}

  void Read(char *buffer, size_t size) { file.parallel_read(buffer, size); }
};

#else

// Write-side file handle used when PIGO is not available; backed by a plain
// binary std::ofstream.
class SbffWriteFile {
 public:
  // Opens `filename` for binary output. `size` is accepted so the signature
  // matches the PIGO variant, but this implementation does not use it.
  SbffWriteFile(std::string filename, size_t size)
      : out_stream_(filename, std::ios::out | std::ios::binary) {}

  // Closes the stream explicitly when the handle goes away.
  ~SbffWriteFile() { out_stream_.close(); }

  // Appends `size` bytes starting at `data` to the stream.
  void Write(char *data, size_t size) {
    out_stream_.write(data, static_cast<std::streamsize>(size));
  }

 private:
  std::ofstream out_stream_;
};

// Read-side file handle used when PIGO is not available; backed by a plain
// binary std::ifstream.
class SbffReadOnlyFile {
 public:
  // Opens `filename` for binary input.
  explicit SbffReadOnlyFile(std::string filename)
      : in_stream_(filename, std::ios::in | std::ios::binary) {}

  // Closes the stream explicitly when the handle goes away.
  ~SbffReadOnlyFile() { in_stream_.close(); }

  // Reads the next `size` bytes of the stream into `buffer`.
  void Read(char *buffer, size_t size) {
    in_stream_.read(buffer, static_cast<std::streamsize>(size));
  }

 private:
  std::ifstream in_stream_;
};

#endif

class SbffArray {
 private:
  std::string name;
  size_t array_size;
  size_t type_size;
  std::string type;
  char *data;
  std::string endian;

  SbffArray() = default;

  friend class SbffObject;

 public:
  template <typename T>
  static SbffArray Create(std::string name, T *arr, size_t size) {
    static_assert(!std::is_same_v<T, void>, "Cannot create a void array");
    SbffArray sbas_arr;
    sbas_arr.name = name;
    sbas_arr.data = (char *)arr;
    sbas_arr.array_size = size;
    sbas_arr.type_size = sizeof(T);
    sbas_arr.endian = GetEndian();

    if constexpr (std::is_floating_point_v<T>) {
      sbas_arr.type = "float";
    } else if constexpr (std::is_signed_v<T>) {
      sbas_arr.type = "signed";
    } else if constexpr (std::is_unsigned_v<T>) {
      sbas_arr.type = "unsigned";
    } else {
      throw sparsebase::utils::WriterException(std::string("Type ") +
                                               typeid(T).name() +
                                               " is not supported by SBFF");
    }

    return sbas_arr;
  }

  static nlohmann::json ReadHeader(SbffReadOnlyFile &file) {
    char header_bytes[1024];
    file.Read((char *)&header_bytes, 1024);
    nlohmann::json header = nlohmann::json::parse(header_bytes);
    return header;
  }

  static SbffArray ReadArray(SbffReadOnlyFile &file, std::string endian) {
    try {
      auto header = ReadHeader(file);

      SbffArray sbas_arr;
      sbas_arr.array_size = header.at("array_size");
      sbas_arr.type_size = header.at("type_size");
      sbas_arr.type = header.at("type");
      sbas_arr.name = header.at("name");
      sbas_arr.endian = endian;

      sbas_arr.data = new char[sbas_arr.array_size * sbas_arr.type_size];
      file.Read((char *)sbas_arr.data,
                sbas_arr.array_size * sbas_arr.type_size);

      return sbas_arr;

    } catch (sparsebase::utils::ReaderException &e) {
      throw sparsebase::utils::ReaderException(e.what());
    } catch (...) {
      throw sparsebase::utils::ReaderException("Unknown SBFF ReadArray Error");
    }
  }

  static std::vector<char> HeaderToBytes(const nlohmann::json &header) {
    std::string header_str = header.dump();
    std::vector<char> header_bytes(header_str.begin(), header_str.end());

    // Headers should have a maximum size of 1024 bytes
    if (header_bytes.size() > 1024) {
      throw sparsebase::utils::WriterException("Header size exceeds 1 KB");
    }

    // Pad the header to exactly 1024 bytes
    while (header_bytes.size() < 1024) {
      header_bytes.push_back(' ');
    }

    return header_bytes;
  }

  void WriteArray(SbffWriteFile &file) {
    nlohmann::json header;
    header["name"] = name;
    header["type"] = type;
    header["type_size"] = type_size;
    header["array_size"] = array_size;

    file.Write((char *)HeaderToBytes(header).data(), 1024);
    file.Write((char *)data, array_size * type_size);
  }

  // This will fail if sizeof(int) == 1
  // which might be the case on some embedded systems
  static std::string GetEndian() {
    const int value{0x01};
    const void *address = static_cast<const void *>(&value);
    const auto *least_significant_address =
        static_cast<const unsigned char *>(address);
    return (*least_significant_address == 0x01) ? "little" : "big";
  }

  template <typename T>
  static T SwapEndian(T u) {
    static_assert(CHAR_BIT == 8, "CHAR_BIT != 8");

    union {
      T u;
      unsigned char u8[sizeof(T)];
    } source, dest;

    source.u = u;

    for (size_t k = 0; k < sizeof(T); k++)
      dest.u8[k] = source.u8[sizeof(T) - k - 1];

    return dest.u;
  }
};

struct SbffObject {
 private:
  std::string name;
  std::unordered_map<std::string, SbffArray> arrays;
  std::vector<int> dimensions;
  size_t total_size = 1024;

 public:
  explicit SbffObject(std::string name) : name(name) {}

  void AddDimensions(const std::vector<format::DimensionType> &dims) {
    dimensions.insert(dimensions.end(), dims.begin(), dims.end());
  }

  template <typename T>
  void AddArray(std::string array_name, T *arr, size_t size) {
    auto sbas_arr = SbffArray::Create(array_name, arr, size);
    arrays.emplace(array_name, sbas_arr);
    total_size += 1024 + sizeof(T) * size;
  }

  void AddArray(SbffArray sbas_arr) {
    arrays.emplace(sbas_arr.name, sbas_arr);
    total_size += 1024 + sbas_arr.array_size * sbas_arr.type_size;
  }

  template <typename T>
  size_t GetArray(std::string array_name, T *&ptr) {
    try {
      SbffArray &arr = arrays.at(array_name);

      if (arr.type == "float" && !std::is_floating_point_v<T>) {
        throw sparsebase::utils::ReaderException(
            "Type mismatch, array type is float");
      }

      if (arr.type == "signed" && !std::is_signed_v<T>) {
        throw sparsebase::utils::ReaderException(
            "Type mismatch, array type is signed");
      }

      if (arr.type == "unsigned" && !std::is_unsigned_v<T>) {
        throw sparsebase::utils::ReaderException(
            "Type mismatch, array type is unsigned");
      }

      if (arr.type_size != sizeof(T)) {
        throw sparsebase::utils::ReaderException(
            std::string("Type mismatch, array type has size ") +
            std::to_string(arr.type_size));
      }

      ptr = (T *)arr.data;

      if (arr.endian != SbffArray::GetEndian()) {
#pragma omp parallel for shared(ptr, arr) default(none)
        for (size_t i = 0; i < arr.array_size; i++) {
          ptr[i] = SbffArray::SwapEndian(ptr[i]);
        }
      }

      return arr.array_size;
    } catch (sparsebase::utils::ReaderException &e) {
      throw sparsebase::utils::ReaderException(e.what());
    } catch (...) {
      throw sparsebase::utils::ReaderException("Unknown SBFF ReadArray Error");
    }
  }

  void WriteObject(std::string filename) {
    SbffWriteFile file(filename, total_size);
    WriteObject(file);
  }

  void WriteObject(SbffWriteFile &file) {
    nlohmann::json header;
    header["name"] = name;
    header["array_count"] = arrays.size();
    header["dimensions"] = dimensions;
    header["endian"] = SbffArray::GetEndian();

    file.Write((char *)SbffArray::HeaderToBytes(header).data(), 1024);

    for (auto arr : arrays) {
      arr.second.WriteArray(file);
    }
  }

  static SbffObject ReadObject(SbffReadOnlyFile &file) {
    try {
      auto header = SbffArray::ReadHeader(file);

      SbffObject obj("temp");
      obj.name = header.at("name");
      std::string endian = header.at("endian");
      size_t array_count = header.at("array_count");
      auto dims = header.at("dimensions");
      obj.dimensions.insert(obj.dimensions.end(), dims.begin(), dims.end());

      for (size_t i = 0; i < array_count; i++)
        obj.AddArray(SbffArray::ReadArray(file, endian));

      return obj;

    } catch (sparsebase::utils::ReaderException &e) {
      throw sparsebase::utils::ReaderException(e.what());
    } catch (...) {
      throw sparsebase::utils::ReaderException("Unknown SBFF ReadArray Error");
    }
  }

  static SbffObject ReadObject(std::string filename) {
    SbffReadOnlyFile file(filename);
    return ReadObject(file);
  }

  std::string get_name() { return name; }

  size_t get_array_count() { return arrays.size(); }

  std::vector<int> get_dimensions() { return dimensions; }
};

}  // namespace io

}  // namespace sparsebase

#endif  // SPARSEBASE_SPARSEBASE_UTILS_IO_SPARSE_FILE_FORMAT_H_