Program Listing for File csr.cc

Return to documentation for file (src/sparsebase/format/csr.cc)

#include "sparsebase/format/csr.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "sparsebase/utils/logger.h"
namespace sparsebase::format {
template <typename IDType, typename NNZType, typename ValueType>
// Move constructor.
// Takes over rhs's col, row_ptr and vals handles (pointers together with
// their deleters) and copies its dimensions and non-zero count. rhs is left
// holding null handles with no-op deleters, so destroying rhs frees nothing.
// This object gets a fresh CPUContext; rhs's context is not touched.
CSR<IDType, NNZType, ValueType>::CSR(CSR<IDType, NNZType, ValueType> &&rhs)
    : col_(std::move(rhs.col_)),
      row_ptr_(std::move(rhs.row_ptr_)),
      vals_(std::move(rhs.vals_)) {
  this->order_ = 2;
  this->dimension_ = rhs.dimension_;
  this->nnz_ = rhs.get_num_nnz();

  // A moved-from std::function is in an unspecified state, so rhs's handles
  // are re-seated with an explicit null pointer and a no-op deleter.
  using IDHandle = std::unique_ptr<IDType, std::function<void(IDType *)>>;
  using NNZHandle = std::unique_ptr<NNZType, std::function<void(NNZType *)>>;
  using ValueHandle =
      std::unique_ptr<ValueType, std::function<void(ValueType *)>>;
  rhs.col_ = IDHandle(nullptr, BlankDeleter<IDType>());
  rhs.row_ptr_ = NNZHandle(nullptr, BlankDeleter<NNZType>());
  rhs.vals_ = ValueHandle(nullptr, BlankDeleter<ValueType>());

  this->context_.reset(new sparsebase::context::CPUContext);
}
template <typename IDType, typename NNZType, typename ValueType>
// Copy assignment: deep-copies rhs's col, row_ptr and (if present) vals
// arrays into buffers owned by this object (Deleter), and copies its
// dimensions and non-zero count. A null vals in rhs stays null here, still
// paired with the owning deleter.
//
// Every buffer is allocated and filled before any member of *this changes.
// Each buffer is held by a local handle whose owning deleter is installed
// before the allocation. So if an allocation throws, no memory leaks and
// *this keeps its previous arrays and metadata. Because the copies are taken
// from rhs before the old arrays are dropped, self-assignment is safe.
CSR<IDType, NNZType, ValueType> &CSR<IDType, NNZType, ValueType>::operator=(
    const CSR<IDType, NNZType, ValueType> &rhs) {
  const auto num_nnz = rhs.get_num_nnz();
  const auto row_ptr_len = rhs.get_dimensions()[0] + 1;

  std::unique_ptr<IDType, std::function<void(IDType *)>> col(
      nullptr, Deleter<IDType>());
  col.reset(new IDType[num_nnz]);
  std::copy(rhs.get_col(), rhs.get_col() + num_nnz, col.get());

  std::unique_ptr<NNZType, std::function<void(NNZType *)>> row_ptr(
      nullptr, Deleter<NNZType>());
  row_ptr.reset(new NNZType[row_ptr_len]);
  std::copy(rhs.get_row_ptr(), rhs.get_row_ptr() + row_ptr_len,
            row_ptr.get());

  std::unique_ptr<ValueType, std::function<void(ValueType *)>> vals(
      nullptr, Deleter<ValueType>());
  if constexpr (!std::is_same_v<ValueType, void>) {
    if (rhs.get_vals() != nullptr) {
      vals.reset(new ValueType[num_nnz]);
      std::copy(rhs.get_vals(), rhs.get_vals() + num_nnz, vals.get());
    }
  }

  // Commit. If the dimension copy throws, the local handles above still own
  // (and free) the new buffers.
  this->dimension_ = rhs.dimension_;
  this->nnz_ = rhs.nnz_;
  this->order_ = 2;
  this->col_ = std::move(col);
  this->row_ptr_ = std::move(row_ptr);
  this->vals_ = std::move(vals);
  return *this;
}
template <typename IDType, typename NNZType, typename ValueType>
// Copy constructor: deep-copies rhs's col, row_ptr and (if present) vals
// arrays into buffers owned by this object (Deleter), copies its dimensions
// and non-zero count, and creates a fresh CPUContext.
//
// The members start out null but already carry the owning deleter. Each
// buffer is reset straight into a fully-constructed member, so if a later
// allocation throws, the member destructors release the earlier buffers
// instead of leaking them.
CSR<IDType, NNZType, ValueType>::CSR(const CSR<IDType, NNZType, ValueType> &rhs)
    : col_(nullptr, Deleter<IDType>()),
      row_ptr_(nullptr, Deleter<NNZType>()),
      vals_(nullptr, Deleter<ValueType>()) {
  this->nnz_ = rhs.nnz_;
  this->order_ = 2;
  this->dimension_ = rhs.dimension_;
  const auto num_nnz = rhs.get_num_nnz();
  const auto row_ptr_len = rhs.get_dimensions()[0] + 1;

  this->col_.reset(new IDType[num_nnz]);
  std::copy(rhs.get_col(), rhs.get_col() + num_nnz, this->col_.get());

  this->row_ptr_.reset(new NNZType[row_ptr_len]);
  std::copy(rhs.get_row_ptr(), rhs.get_row_ptr() + row_ptr_len,
            this->row_ptr_.get());

  // A null vals in rhs stays null here, still paired with the owning deleter.
  if constexpr (!std::is_same_v<ValueType, void>) {
    if (rhs.get_vals() != nullptr) {
      this->vals_.reset(new ValueType[num_nnz]);
      std::copy(rhs.get_vals(), rhs.get_vals() + num_nnz,
                this->vals_.get());
    }
  }
  this->context_ = std::unique_ptr<sparsebase::context::Context>(
      new sparsebase::context::CPUContext);
}
template <typename IDType, typename NNZType, typename ValueType>
// Builds an n x m CSR matrix over caller-supplied arrays.
//
// Array sizes:
//  - row_ptr is read at indices 0..n, so it must hold n + 1 offsets.
//  - The non-zero count is taken from row_ptr[n].
//  - col (and vals, when non-null) are indexed by those offsets.
//
// Ownership: with own == kOwned the arrays are handed to the CSR with the
// owning Deleter. Otherwise they are only borrowed (BlankDeleter).
//
// Sorting: unless ignore_sort is true, each row's column indices are checked
// for non-decreasing order. If any row is out of order, a warning is logged
// and every row is sorted in place, with vals permuted alongside col. This
// writes to the caller's arrays even when they are not owned.
CSR<IDType, NNZType, ValueType>::CSR(IDType n, IDType m, NNZType *row_ptr,
                                     IDType *col, ValueType *vals,
                                     Ownership own, bool ignore_sort)
    : row_ptr_(row_ptr, BlankDeleter<NNZType>()),
      col_(col, BlankDeleter<IDType>()),
      vals_(vals, BlankDeleter<ValueType>()) {
  this->order_ = 2;
  this->dimension_ = {(DimensionType)n, (DimensionType)m};
  this->nnz_ = row_ptr[this->dimension_[0]];
  if (own == kOwned) {
    this->row_ptr_ = std::unique_ptr<NNZType, std::function<void(NNZType *)>>(
        row_ptr, Deleter<NNZType>());
    this->col_ = std::unique_ptr<IDType, std::function<void(IDType *)>>(
        col, Deleter<IDType>());
    this->vals_ = std::unique_ptr<ValueType, std::function<void(ValueType *)>>(
        vals, Deleter<ValueType>());
  }
  this->context_ = std::unique_ptr<sparsebase::context::Context>(
      new sparsebase::context::CPUContext);

  if (!ignore_sort) {
    bool not_sorted = false;

    // Rows are checked independently; any unsorted row sets the flag.
#pragma omp parallel for default(none) reduction(||            \
                                                 : not_sorted) \
    shared(col, row_ptr, n)
    for (IDType i = 0; i < n; i++) {
      if (!std::is_sorted(col + row_ptr[i], col + row_ptr[i + 1])) {
        not_sorted = true;
      }
    }

    if (not_sorted) {
      utils::Logger logger(typeid(this));
      logger.Log("CSR column array must be sorted. Sorting...",
                 utils::LOG_LVL_WARNING);

#pragma omp parallel for default(none) shared(row_ptr, col, vals, n)
      for (IDType i = 0; i < n; i++) {
        NNZType start = row_ptr[i];
        NNZType end = row_ptr[i + 1];

        if (end - start <= 1) {
          continue;
        }

        if constexpr (std::is_same_v<ValueType, void>) {
          // No values to carry along: sort the row's column indices in place.
          std::sort(col + start, col + end);
        } else {
          std::vector<std::pair<IDType, ValueType>> sort_vec;
          sort_vec.reserve(end - start);
          for (NNZType j = start; j < end; j++) {
            ValueType val = (vals != nullptr) ? vals[j] : 0;
            sort_vec.emplace_back(col[j], val);
          }
          // Order by column index only. Comparing the values as well would
          // require ValueType to provide operator<, and would break strict
          // weak ordering (UB in sorting) for duplicate columns holding NaN.
          // stable_sort keeps any duplicate columns in their input order.
          std::stable_sort(sort_vec.begin(), sort_vec.end(),
                           [](const std::pair<IDType, ValueType> &a,
                              const std::pair<IDType, ValueType> &b) {
                             return a.first < b.first;
                           });
          for (NNZType j = start; j < end; j++) {
            if (vals != nullptr) {
              vals[j] = sort_vec[j - start].second;
            }
            col[j] = sort_vec[j - start].first;
          }
        }
      }
    }
  }
}

template <typename IDType, typename NNZType, typename ValueType>
// Returns a heap-allocated deep copy of this matrix, made with the copy
// constructor. The caller owns the returned object.
Format *CSR<IDType, NNZType, ValueType>::Clone() const {
  auto *copy = new CSR<IDType, NNZType, ValueType>(*this);
  return copy;
}
template <typename IDType, typename NNZType, typename ValueType>
// Returns the column-index array without transferring ownership.
IDType *CSR<IDType, NNZType, ValueType>::get_col() const {
  IDType *const col = this->col_.get();
  return col;
}
template <typename IDType, typename NNZType, typename ValueType>
// Returns the row-offset array without transferring ownership.
NNZType *CSR<IDType, NNZType, ValueType>::get_row_ptr() const {
  NNZType *const row_ptr = this->row_ptr_.get();
  return row_ptr;
}
template <typename IDType, typename NNZType, typename ValueType>
// Returns the values array (possibly null) without transferring ownership.
ValueType *CSR<IDType, NNZType, ValueType>::get_vals() const {
  ValueType *const vals = this->vals_.get();
  return vals;
}
template <typename IDType, typename NNZType, typename ValueType>
// Gives up ownership of the column array but keeps pointing at it.
// The deleter is replaced with a no-op one, so the CSR still exposes the
// array but will never free it; the caller becomes responsible for it.
IDType *CSR<IDType, NNZType, ValueType>::release_col() {
  this->col_.get_deleter() = BlankDeleter<IDType>();
  return this->col_.get();
}
template <typename IDType, typename NNZType, typename ValueType>
// Gives up ownership of the row-offset array but keeps pointing at it
// (see release_col).
NNZType *CSR<IDType, NNZType, ValueType>::release_row_ptr() {
  this->row_ptr_.get_deleter() = BlankDeleter<NNZType>();
  return this->row_ptr_.get();
}
template <typename IDType, typename NNZType, typename ValueType>
// Gives up ownership of the values array but keeps pointing at it
// (see release_col).
ValueType *CSR<IDType, NNZType, ValueType>::release_vals() {
  this->vals_.get_deleter() = BlankDeleter<ValueType>();
  return this->vals_.get();
}

template <typename IDType, typename NNZType, typename ValueType>
// Replaces the column array. With own == kOwned the CSR frees it later
// (Deleter); otherwise it is only borrowed (BlankDeleter). The previous array
// is released according to its own deleter.
void CSR<IDType, NNZType, ValueType>::set_col(IDType *col, Ownership own) {
  using Handle = std::unique_ptr<IDType, std::function<void(IDType *)>>;
  this->col_ = (own == kOwned) ? Handle(col, Deleter<IDType>())
                               : Handle(col, BlankDeleter<IDType>());
}

template <typename IDType, typename NNZType, typename ValueType>
// Replaces the row-offset array; ownership semantics as in set_col.
void CSR<IDType, NNZType, ValueType>::set_row_ptr(NNZType *row_ptr,
                                                  Ownership own) {
  using Handle = std::unique_ptr<NNZType, std::function<void(NNZType *)>>;
  this->row_ptr_ = (own == kOwned) ? Handle(row_ptr, Deleter<NNZType>())
                                   : Handle(row_ptr, BlankDeleter<NNZType>());
}

template <typename IDType, typename NNZType, typename ValueType>
// Replaces the values array; ownership semantics as in set_col.
void CSR<IDType, NNZType, ValueType>::set_vals(ValueType *vals, Ownership own) {
  using Handle = std::unique_ptr<ValueType, std::function<void(ValueType *)>>;
  this->vals_ = (own == kOwned) ? Handle(vals, Deleter<ValueType>())
                                : Handle(vals, BlankDeleter<ValueType>());
}

template <typename IDType, typename NNZType, typename ValueType>
// Reports whether row_ptr is owned, i.e. whether its deleter is anything
// other than the no-op BlankDeleter.
bool CSR<IDType, NNZType, ValueType>::RowPtrIsOwned() {
  const auto &deleter = this->row_ptr_.get_deleter();
  const bool is_blank = deleter.target_type() == typeid(BlankDeleter<NNZType>);
  return !is_blank;
}

template <typename IDType, typename NNZType, typename ValueType>
// Reports whether col is owned (deleter is not BlankDeleter).
bool CSR<IDType, NNZType, ValueType>::ColIsOwned() {
  const auto &deleter = this->col_.get_deleter();
  const bool is_blank = deleter.target_type() == typeid(BlankDeleter<IDType>);
  return !is_blank;
}

template <typename IDType, typename NNZType, typename ValueType>
// Reports whether vals is owned (deleter is not BlankDeleter).
bool CSR<IDType, NNZType, ValueType>::ValsIsOwned() {
  const auto &deleter = this->vals_.get_deleter();
  const bool is_blank =
      deleter.target_type() == typeid(BlankDeleter<ValueType>);
  return !is_blank;
}
template <typename IDType, typename NNZType, typename ValueType>
// Destructor. The array handles release their buffers through their own
// deleters (owning or no-op), so no explicit cleanup is needed here.
CSR<IDType, NNZType, ValueType>::~CSR() = default;

#ifndef _HEADER_ONLY
#include "init/csr.inc"
#endif
}  // namespace sparsebase::format