Program Listing for File permute_order_two.cc

Return to documentation for file (src/sparsebase/permute/permute_order_two.cc)

#include "sparsebase/permute/permute_order_two.h"

#include "sparsebase/format/array.h"
#include "sparsebase/format/csr.h"

namespace sparsebase::permute {
// Builds a two-dimensional permutation from explicit row and column orders.
// The CSR implementation is registered as the handler for CSR inputs, and the
// two order pointers are stored (not copied) in the parameter object.
template <typename IDType, typename NNZType, typename ValueType>
PermuteOrderTwo<IDType, NNZType, ValueType>::PermuteOrderTwo(
    IDType *row_order, IDType *col_order) {
  using CSRFormat = format::CSR<IDType, NNZType, ValueType>;
  using ParamsType = PermuteOrderTwoParams<IDType>;

  // CSR is the only format this constructor registers a function for.
  this->RegisterFunction({CSRFormat::get_id_static()}, PermuteOrderTwoCSR);

  this->params_ = std::make_unique<ParamsType>(row_order, col_order);
}
// Builds the permutation from a parameter struct by delegating to the
// (row_order, col_order) constructor.
//
// The delegation must happen in the mem-initializer list: writing
// `PermuteOrderTwo(a, b);` as a statement in the body only constructs and
// discards a temporary, leaving *this without a registered function and
// without params_.
template <typename IDType, typename NNZType, typename ValueType>
PermuteOrderTwo<IDType, NNZType, ValueType>::PermuteOrderTwo(
    PermuteOrderTwoParams<IDType> params)
    : PermuteOrderTwo(params.row_order, params.col_order) {}
// Applies a row and/or column permutation to a CSR matrix and returns a newly
// allocated CSR holding the result.
//
// formats[0] must be a CSR<IDType, NNZType, ValueType>; params must point to a
// PermuteOrderTwoParams<IDType>.
//   - row_order[i] is the new position of original row i. A null row_order
//     leaves the rows in place.
//   - col_order[j] is the new id of original column j. A null col_order leaves
//     column ids unchanged.
// Column ids are relabelled in place within each row; they are not re-sorted.
// The value array is copied only when ValueType is not void and the input
// actually carries values; otherwise the result's value pointer is null.
template <typename IDType, typename NNZType, typename ValueType>
format::FormatOrderTwo<IDType, NNZType, ValueType>
    *PermuteOrderTwo<IDType, NNZType, ValueType>::PermuteOrderTwoCSR(
        std::vector<format::Format *> formats, utils::Parameters *params) {
  auto *sp = formats[0]->AsAbsolute<format::CSR<IDType, NNZType, ValueType>>();
  auto *order_params = static_cast<PermuteOrderTwoParams<IDType> *>(params);
  auto row_order = order_params->row_order;
  auto col_order = order_params->col_order;

  std::vector<format::DimensionType> dimensions = sp->get_dimensions();
  IDType n = dimensions[0];
  IDType m = dimensions[1];
  NNZType nnz = sp->get_num_nnz();
  NNZType *xadj = sp->get_row_ptr();
  IDType *adj = sp->get_col();
  ValueType *vals = sp->get_vals();

  // Output arrays, value-initialized so nxadj[0] starts at zero.
  NNZType *nxadj = new NNZType[n + 1]();
  IDType *nadj = new IDType[nnz]();
  ValueType *nvals = nullptr;
  if constexpr (!std::is_same_v<void, ValueType>) {
    if (vals != nullptr) nvals = new ValueType[nnz]();
  }

  // inverse_row_order[p] is the original row that lands at position p. It is
  // only materialized when a row order is given and stays null otherwise, so
  // the cleanup below is safe on both paths.
  IDType *inverse_row_order = nullptr;
  if (row_order != nullptr) {
    inverse_row_order = new IDType[n]();
    for (IDType i = 0; i < n; i++) inverse_row_order[row_order[i]] = i;
  }

  NNZType c = 0;
  for (IDType i = 0; i < n; i++) {
    // Original row whose entries fill row i of the output.
    IDType u = (inverse_row_order != nullptr) ? inverse_row_order[i] : i;
    nxadj[i + 1] = nxadj[i] + (xadj[u + 1] - xadj[u]);
    for (NNZType v = xadj[u]; v < xadj[u + 1]; v++) {
      nadj[c] = (col_order != nullptr) ? col_order[adj[v]] : adj[v];
      if constexpr (!std::is_same_v<void, ValueType>) {
        if (nvals != nullptr) nvals[c] = vals[v];
      }
      c++;
    }
  }

  // delete[] on a null pointer is a no-op, so no condition is needed. The
  // previous check was inverted: it freed an uninitialized pointer when no row
  // order was given and leaked the scratch array when one was.
  delete[] inverse_row_order;

  // NOTE(review): nxadj/nadj/nvals are passed without an explicit ownership
  // argument; confirm the CSR constructor's default takes ownership of them,
  // otherwise these arrays are never freed.
  format::CSR<IDType, NNZType, ValueType> *csr =
      new format::CSR(n, m, nxadj, nadj, nvals);
  return csr;
}

#if !defined(_HEADER_ONLY)
#include "init/permute_order_two.inc"
#endif
}  // namespace sparsebase::permute