.. _program_listing_file_src_sparsebase_format_csr.cc:

Program Listing for File csr.cc
===============================

|exhale_lsh| :ref:`Return to documentation for file <file_src_sparsebase_format_csr.cc>` (``src/sparsebase/format/csr.cc``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #include "sparsebase/format/csr.h"

   #include "sparsebase/utils/logger.h"

   namespace sparsebase::format {

   // Move constructor.
   // Takes over rhs's col / row_ptr / vals pointers (together with their
   // deleters), copies nnz and the dimensions, and then re-seats rhs's three
   // pointers to null with a BlankDeleter. A fresh CPUContext is attached to
   // the new object.
   template <typename IDType, typename NNZType, typename ValueType>
   CSR<IDType, NNZType, ValueType>::CSR(CSR<IDType, NNZType, ValueType> &&rhs)
       : col_(std::move(rhs.col_)),
         row_ptr_(std::move(rhs.row_ptr_)),
         vals_(std::move(rhs.vals_)) {
     this->nnz_ = rhs.get_num_nnz();
     this->order_ = 2;
     this->dimension_ = rhs.dimension_;
     rhs.col_ = std::unique_ptr<IDType[], std::function<void(IDType *)>>(
         nullptr, BlankDeleter<IDType>());
     rhs.row_ptr_ = std::unique_ptr<NNZType[], std::function<void(NNZType *)>>(
         nullptr, BlankDeleter<NNZType>());
     rhs.vals_ = std::unique_ptr<ValueType[], std::function<void(ValueType *)>>(
         nullptr, BlankDeleter<ValueType>());
     // NOTE(review): pointee type of context_ assumed to be context::Context;
     // confirm against the declaration in csr.h.
     this->context_ = std::unique_ptr<sparsebase::context::Context>(
         new sparsebase::context::CPUContext);
   }

   // Copy assignment.
   // Allocates new arrays and copies rhs's col (nnz entries), row_ptr
   // (dimensions[0] + 1 entries) and, when ValueType is not void and rhs has a
   // values array, vals (nnz entries). The copies are stored with Deleter, so
   // the *IsOwned() queries report them as owned. context_ is left untouched.
   template <typename IDType, typename NNZType, typename ValueType>
   CSR<IDType, NNZType, ValueType> &CSR<IDType, NNZType, ValueType>::operator=(
       const CSR<IDType, NNZType, ValueType> &rhs) {
     this->nnz_ = rhs.nnz_;
     this->order_ = 2;
     this->dimension_ = rhs.dimension_;
     auto col = new IDType[rhs.get_num_nnz()];
     std::copy(rhs.get_col(), rhs.get_col() + rhs.get_num_nnz(), col);
     auto row_ptr = new NNZType[(rhs.get_dimensions()[0] + 1)];
     std::copy(rhs.get_row_ptr(),
               rhs.get_row_ptr() + (rhs.get_dimensions()[0] + 1), row_ptr);
     ValueType *vals = nullptr;
     // A void ValueType means there is no values array to allocate or copy.
     if constexpr (!std::is_same_v<void, ValueType>) {
       if (rhs.get_vals() != nullptr) {
         vals = new ValueType[rhs.get_num_nnz()];
         std::copy(rhs.get_vals(), rhs.get_vals() + rhs.get_num_nnz(), vals);
       }
     }
     this->col_ = std::unique_ptr<IDType[], std::function<void(IDType *)>>(
         col, Deleter<IDType>());
     this->row_ptr_ = std::unique_ptr<NNZType[], std::function<void(NNZType *)>>(
         row_ptr, Deleter<NNZType>());
     this->vals_ = std::unique_ptr<ValueType[], std::function<void(ValueType *)>>(
         vals, Deleter<ValueType>());
     return *this;
   }

   // Copy constructor.
   // Performs the same deep copy as copy assignment (col, row_ptr and, when
   // present, vals) and additionally attaches a fresh CPUContext.
   template <typename IDType, typename NNZType, typename ValueType>
   CSR<IDType, NNZType, ValueType>::CSR(const CSR<IDType, NNZType, ValueType> &rhs)
       : col_(nullptr, BlankDeleter<IDType>()),
         row_ptr_(nullptr, BlankDeleter<NNZType>()),
         vals_(nullptr, BlankDeleter<ValueType>()) {
     this->nnz_ = rhs.nnz_;
     this->order_ = 2;
     this->dimension_ = rhs.dimension_;
     auto col = new IDType[rhs.get_num_nnz()];
     std::copy(rhs.get_col(), rhs.get_col() + rhs.get_num_nnz(), col);
     auto row_ptr = new NNZType[(rhs.get_dimensions()[0] + 1)];
     std::copy(rhs.get_row_ptr(),
               rhs.get_row_ptr() + (rhs.get_dimensions()[0] + 1), row_ptr);
     ValueType *vals = nullptr;
     // A void ValueType means there is no values array to allocate or copy.
     if constexpr (!std::is_same_v<void, ValueType>) {
       if (rhs.get_vals() != nullptr) {
         vals = new ValueType[rhs.get_num_nnz()];
         std::copy(rhs.get_vals(), rhs.get_vals() + rhs.get_num_nnz(), vals);
       }
     }
     this->col_ = std::unique_ptr<IDType[], std::function<void(IDType *)>>(
         col, Deleter<IDType>());
     this->row_ptr_ = std::unique_ptr<NNZType[], std::function<void(NNZType *)>>(
         row_ptr, Deleter<NNZType>());
     this->vals_ = std::unique_ptr<ValueType[], std::function<void(ValueType *)>>(
         vals, Deleter<ValueType>());
     // NOTE(review): pointee type of context_ assumed to be context::Context;
     // confirm against the declaration in csr.h.
     this->context_ = std::unique_ptr<sparsebase::context::Context>(
         new sparsebase::context::CPUContext);
   }

   // Constructs a CSR over caller-provided arrays.
   //
   //   n, m         dimensions stored in dimension_ (n is also the number of
   //                rows walked by the sortedness check)
   //   row_ptr      array of n + 1 offsets; row_ptr[n] is taken as nnz, so it
   //                must not be null
   //   col          column index array
   //   vals         values array; may be null
   //   own          kOwned re-wraps the three arrays with Deleter; any other
   //                value keeps the BlankDeleter set in the initializer list
   //   ignore_sort  when false, every row's column indices are checked for
   //                ascending order and, if any row is out of order, all rows
   //                with more than one entry are sorted IN PLACE in the
   //                caller's col (and vals) arrays
   template <typename IDType, typename NNZType, typename ValueType>
   CSR<IDType, NNZType, ValueType>::CSR(IDType n, IDType m, NNZType *row_ptr,
                                        IDType *col, ValueType *vals,
                                        Ownership own, bool ignore_sort)
       : row_ptr_(row_ptr, BlankDeleter<NNZType>()),
         col_(col, BlankDeleter<IDType>()),
         vals_(vals, BlankDeleter<ValueType>()) {
     this->order_ = 2;
     this->dimension_ = {(DimensionType)n, (DimensionType)m};
     this->nnz_ = row_ptr[this->dimension_[0]];
     if (own == kOwned) {
       this->row_ptr_ =
           std::unique_ptr<NNZType[], std::function<void(NNZType *)>>(
               row_ptr, Deleter<NNZType>());
       this->col_ = std::unique_ptr<IDType[], std::function<void(IDType *)>>(
           col, Deleter<IDType>());
       this->vals_ =
           std::unique_ptr<ValueType[], std::function<void(ValueType *)>>(
               vals, Deleter<ValueType>());
     }
     // NOTE(review): pointee type of context_ assumed to be context::Context;
     // confirm against the declaration in csr.h.
     this->context_ = std::unique_ptr<sparsebase::context::Context>(
         new sparsebase::context::CPUContext);

     if (!ignore_sort) {
       // Pass 1: detect whether any row has a column index smaller than its
       // predecessor. Each thread's flag is OR-reduced into not_sorted.
       bool not_sorted = false;
   #pragma omp parallel for default(none) reduction(|| : not_sorted) \
       shared(col, row_ptr, n)
       for (IDType i = 0; i < n; i++) {
         NNZType start = row_ptr[i];
         NNZType end = row_ptr[i + 1];
         IDType prev_value = 0;
         for (NNZType j = start; j < end; j++) {
           if (col[j] < prev_value) {
             not_sorted = true;
             break;  // leaves only the inner per-row scan
           }
           prev_value = col[j];
         }
       }

       if (not_sorted) {
         utils::Logger logger(typeid(this));
         logger.Log("CSR column array must be sorted. Sorting...",
                    utils::LOG_LVL_WARNING);

         // Pass 2: sort every row independently; rows never overlap, so the
         // iterations write to disjoint ranges of col / vals.
   #pragma omp parallel for default(none) shared(row_ptr, col, vals, n)
         for (IDType i = 0; i < n; i++) {
           NNZType start = row_ptr[i];
           NNZType end = row_ptr[i + 1];

           // Rows with zero or one entry are already sorted.
           if (end - start <= 1) {
             continue;
           }

           if constexpr (std::is_same_v<ValueType, void>) {
             // No values: sort the column indices alone.
             std::vector<IDType> sort_vec;
             for (NNZType j = start; j < end; j++) {
               sort_vec.emplace_back(col[j]);
             }
             std::sort(sort_vec.begin(), sort_vec.end(), std::less<IDType>());
             for (NNZType j = start; j < end; j++) {
               col[j] = sort_vec[j - start];
             }
           } else {
             // Sort (column, value) pairs so values follow their columns.
             // When vals is null a 0 placeholder is paired and never written
             // back.
             std::vector<std::pair<IDType, ValueType>> sort_vec;
             for (NNZType j = start; j < end; j++) {
               ValueType val = (vals != nullptr) ? vals[j] : 0;
               sort_vec.emplace_back(col[j], val);
             }
             std::sort(sort_vec.begin(), sort_vec.end(),
                       std::less<std::pair<IDType, ValueType>>());
             for (NNZType j = start; j < end; j++) {
               if (vals != nullptr) {
                 vals[j] = sort_vec[j - start].second;
               }
               col[j] = sort_vec[j - start].first;
             }
           }
         }
       }
     }
   }

   // Returns a heap-allocated copy of this object made with the copy
   // constructor; the caller receives a raw pointer.
   template <typename IDType, typename NNZType, typename ValueType>
   Format *CSR<IDType, NNZType, ValueType>::Clone() const {
     return new CSR<IDType, NNZType, ValueType>(*this);
   }

   // Returns the raw column index array without changing how it is held.
   template <typename IDType, typename NNZType, typename ValueType>
   IDType *CSR<IDType, NNZType, ValueType>::get_col() const {
     return col_.get();
   }

   // Returns the raw row pointer array without changing how it is held.
   template <typename IDType, typename NNZType, typename ValueType>
   NNZType *CSR<IDType, NNZType, ValueType>::get_row_ptr() const {
     return row_ptr_.get();
   }

   // Returns the raw values array (may be null) without changing how it is
   // held.
   template <typename IDType, typename NNZType, typename ValueType>
   ValueType *CSR<IDType, NNZType, ValueType>::get_vals() const {
     return vals_.get();
   }

   // Returns the column array and re-wraps the same pointer with a
   // BlankDeleter: the object keeps pointing at the array, but ColIsOwned()
   // reports false afterwards.
   template <typename IDType, typename NNZType, typename ValueType>
   IDType *CSR<IDType, NNZType, ValueType>::release_col() {
     auto col = col_.release();
     this->col_ = std::unique_ptr<IDType[], std::function<void(IDType *)>>(
         col, BlankDeleter<IDType>());
     return col;
   }

   // Returns the row pointer array and re-wraps the same pointer with a
   // BlankDeleter: the object keeps pointing at the array, but
   // RowPtrIsOwned() reports false afterwards.
   template <typename IDType, typename NNZType, typename ValueType>
   NNZType *CSR<IDType, NNZType, ValueType>::release_row_ptr() {
     auto row_ptr = row_ptr_.release();
     this->row_ptr_ = std::unique_ptr<NNZType[], std::function<void(NNZType *)>>(
         row_ptr, BlankDeleter<NNZType>());
     return row_ptr;
   }

   // Returns the values array and re-wraps the same pointer with a
   // BlankDeleter: the object keeps pointing at the array, but ValsIsOwned()
   // reports false afterwards.
   template <typename IDType, typename NNZType, typename ValueType>
   ValueType *CSR<IDType, NNZType, ValueType>::release_vals() {
     auto vals = vals_.release();
     this->vals_ = std::unique_ptr<ValueType[], std::function<void(ValueType *)>>(
         vals, BlankDeleter<ValueType>());
     return vals;
   }

   // Replaces the column array. kOwned stores it with Deleter, any other
   // Ownership value with BlankDeleter. nnz_ and dimension_ are not updated.
   template <typename IDType, typename NNZType, typename ValueType>
   void CSR<IDType, NNZType, ValueType>::set_col(IDType *col, Ownership own) {
     if (own == kOwned) {
       this->col_ = std::unique_ptr<IDType[], std::function<void(IDType *)>>(
           col, Deleter<IDType>());
     } else {
       this->col_ = std::unique_ptr<IDType[], std::function<void(IDType *)>>(
           col, BlankDeleter<IDType>());
     }
   }

   // Replaces the row pointer array. kOwned stores it with Deleter, any other
   // Ownership value with BlankDeleter. nnz_ and dimension_ are not updated.
   template <typename IDType, typename NNZType, typename ValueType>
   void CSR<IDType, NNZType, ValueType>::set_row_ptr(NNZType *row_ptr,
                                                     Ownership own) {
     if (own == kOwned) {
       this->row_ptr_ =
           std::unique_ptr<NNZType[], std::function<void(NNZType *)>>(
               row_ptr, Deleter<NNZType>());
     } else {
       this->row_ptr_ =
           std::unique_ptr<NNZType[], std::function<void(NNZType *)>>(
               row_ptr, BlankDeleter<NNZType>());
     }
   }

   // Replaces the values array. kOwned stores it with Deleter, any other
   // Ownership value with BlankDeleter. nnz_ and dimension_ are not updated.
   template <typename IDType, typename NNZType, typename ValueType>
   void CSR<IDType, NNZType, ValueType>::set_vals(ValueType *vals,
                                                  Ownership own) {
     if (own == kOwned) {
       this->vals_ =
           std::unique_ptr<ValueType[], std::function<void(ValueType *)>>(
               vals, Deleter<ValueType>());
     } else {
       this->vals_ =
           std::unique_ptr<ValueType[], std::function<void(ValueType *)>>(
               vals, BlankDeleter<ValueType>());
     }
   }

   // True unless row_ptr_'s deleter currently wraps a BlankDeleter.
   template <typename IDType, typename NNZType, typename ValueType>
   bool CSR<IDType, NNZType, ValueType>::RowPtrIsOwned() {
     return (this->row_ptr_.get_deleter().target_type() !=
             typeid(BlankDeleter<NNZType>));
   }

   // True unless col_'s deleter currently wraps a BlankDeleter.
   template <typename IDType, typename NNZType, typename ValueType>
   bool CSR<IDType, NNZType, ValueType>::ColIsOwned() {
     return (this->col_.get_deleter().target_type() !=
             typeid(BlankDeleter<IDType>));
   }

   // True unless vals_'s deleter currently wraps a BlankDeleter.
   template <typename IDType, typename NNZType, typename ValueType>
   bool CSR<IDType, NNZType, ValueType>::ValsIsOwned() {
     return (this->vals_.get_deleter().target_type() !=
             typeid(BlankDeleter<ValueType>));
   }

   // Nothing to do explicitly: the unique_ptr members invoke whichever deleter
   // they were given.
   template <typename IDType, typename NNZType, typename ValueType>
   CSR<IDType, NNZType, ValueType>::~CSR() {}

   // Explicit template instantiations are pulled in unless the library is
   // built header-only.
   #ifndef _HEADER_ONLY
   #include "init/csr.inc"
   #endif
   }  // namespace sparsebase::format