Program Listing for File reorder_base.h

Return to documentation for file (src/sparsebase/bases/reorder_base.h)

#include <algorithm>
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>

#include "sparsebase/context/cpu_context.h"
#include "sparsebase/format/format_order_one.h"
#include "sparsebase/format/format_order_two.h"
#include "sparsebase/permute/permute_order_one.h"
#include "sparsebase/permute/permute_order_two.h"
#include "sparsebase/reorder/amd_reorder.h"
#include "sparsebase/reorder/degree_reorder.h"
#include "sparsebase/reorder/generic_reorder.h"
#include "sparsebase/reorder/gray_reorder.h"
#include "sparsebase/reorder/metis_reorder.h"
#include "sparsebase/reorder/rabbit_reorder.h"
#include "sparsebase/reorder/rcm_reorder.h"
#include "sparsebase/reorder/reorder_heatmap.h"
#include "sparsebase/reorder/reorderer.h"

#ifndef SPARSEBASE_PROJECT_REORDER_BASE_H
#define SPARSEBASE_PROJECT_REORDER_BASE_H

namespace sparsebase::bases {


class ReorderBase {
 public:

  template <template <typename, typename, typename> typename Reordering,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static AutoIDType *Reorder(
      typename Reordering<AutoIDType, AutoNNZType, AutoValueType>::ParamsType
          params,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts, bool convert_input) {
    static_assert(
        std::is_base_of_v<reorder::Reorderer<AutoIDType>,
                          Reordering<AutoIDType, AutoNNZType, AutoValueType>>,
        "You must pass a reordering function (with base Reorderer) "
        "to ReorderBase::Reorder");
    static_assert(
        !std::is_same_v<
            reorder::GenericReorder<AutoIDType, AutoNNZType, AutoValueType>,
            Reordering<AutoIDType, AutoNNZType, AutoValueType>>,
        "You must pass a reordering function (with base Reorderer) "
        "to ReorderBase::Reorder");
    Reordering<AutoIDType, AutoNNZType, AutoValueType> reordering(params);
    return reordering.GetReorder(format, contexts, convert_input);
  }

  template <template <typename, typename, typename> typename Reordering,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static std::pair<std::vector<format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                                      AutoValueType> *>,
                   AutoIDType *>
  ReorderCached(
      typename Reordering<AutoIDType, AutoNNZType, AutoValueType>::ParamsType
          params,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts) {
    static_assert(
        std::is_base_of_v<reorder::Reorderer<AutoIDType>,
                          Reordering<AutoIDType, AutoNNZType, AutoValueType>>,
        "You must pass a reordering function (with base Reorderer) "
        "to ReorderBase::Reorder");
    static_assert(
        !std::is_same_v<
            reorder::GenericReorder<AutoIDType, AutoNNZType, AutoValueType>,
            Reordering<AutoIDType, AutoNNZType, AutoValueType>>,
        "You must pass a reordering function (with base Reorderer) "
        "to ReorderBase::Reorder");
    Reordering<AutoIDType, AutoNNZType, AutoValueType> reordering(params);
    auto output = reordering.GetReorderCached(format, contexts, true);
    std::vector<
        format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>
        converted_formats;
    std::transform(
        std::get<0>(output)[0].begin(), std::get<0>(output)[0].end(),
        std::back_inserter(converted_formats),
        [](format::Format *intermediate_format) {
          return static_cast<
              format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>(
              intermediate_format);
        });
    return std::make_pair(converted_formats, std::get<1>(output));
  }


  template <template <typename, typename, typename>
            typename ReturnFormatType = format::FormatOrderTwo,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *Permute2D(
      AutoIDType *ordering,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts, bool convert_input,
      bool convert_output = false) {
    permute::PermuteOrderTwo<AutoIDType, AutoNNZType, AutoValueType> perm(
        ordering, ordering);
    auto out_format = perm.GetPermutation(format, contexts, convert_input);
    ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *output;
    if constexpr (std::is_same_v<
                      ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType>,
                      format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                             AutoValueType>>)
      output = out_format;
    else {
      if (convert_output)
        output = out_format->template Convert<ReturnFormatType>();
      else
        output = out_format->template As<ReturnFormatType>();
    }
    return output;
  }


  template <template <typename, typename, typename>
            typename ReturnFormatType = format::FormatOrderTwo,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static std::pair<std::vector<format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                                      AutoValueType> *>,
                   ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *>
  Permute2DCached(
      AutoIDType *ordering,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts, bool convert_output = false) {
    permute::PermuteOrderTwo<AutoIDType, AutoNNZType, AutoValueType> perm(
        ordering, ordering);
    auto output = perm.GetPermutationCached(format, contexts, true);
    std::vector<
        format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>
        converted_formats;
    std::transform(
        std::get<0>(output)[0].begin(), std::get<0>(output)[0].end(),
        std::back_inserter(converted_formats),
        [](format::Format *intermediate_format) {
          return static_cast<
              format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>(
              intermediate_format);
        });
    ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *output_format;
    if constexpr (std::is_same_v<
                      ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType>,
                      format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                             AutoValueType>>)
      output_format = std::get<1>(output);
    else {
      if (convert_output)
        output_format =
            std::get<1>(output)->template Convert<ReturnFormatType>();
      else
        output_format = std::get<1>(output)->template As<ReturnFormatType>();
    }
    return std::make_pair(converted_formats, output_format);
  }


  template <template <typename, typename, typename>
            typename ReturnFormatType = format::FormatOrderTwo,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static std::pair<std::vector<format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                                      AutoValueType> *>,
                   ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *>
  Permute2DRowColumnWiseCached(
      AutoIDType *row_ordering, AutoIDType *col_ordering,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts, bool convert_output = false) {
    permute::PermuteOrderTwo<AutoIDType, AutoNNZType, AutoValueType> perm(
        row_ordering, col_ordering);
    auto output = perm.GetPermutationCached(format, contexts, true);
    std::vector<
        format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>
        converted_formats;
    std::transform(
        std::get<0>(output)[0].begin(), std::get<0>(output)[0].end(),
        std::back_inserter(converted_formats),
        [](format::Format *intermediate_format) {
          return static_cast<
              format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>(
              intermediate_format);
        });
    ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *output_format;
    if constexpr (std::is_same_v<
                      ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType>,
                      format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                             AutoValueType>>)
      output_format = std::get<1>(output);
    else {
      if (convert_output)
        output_format =
            std::get<1>(output)->template Convert<ReturnFormatType>();
      else
        output_format = std::get<1>(output)->template As<ReturnFormatType>();
    }
    return std::make_pair(converted_formats, output_format);
  }


  template <template <typename, typename, typename>
            typename ReturnFormatType = format::FormatOrderTwo,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *
  Permute2DRowColumnWise(
      AutoIDType *row_ordering, AutoIDType *col_ordering,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts, bool convert_input,
      bool convert_output = false) {
    permute::PermuteOrderTwo<AutoIDType, AutoNNZType, AutoValueType> perm(
        row_ordering, col_ordering);
    auto out_format = perm.GetPermutation(format, contexts, convert_input);
    ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *output;
    if constexpr (std::is_same_v<
                      ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType>,
                      format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                             AutoValueType>>)
      output = out_format;
    else {
      if (convert_output)
        output = out_format->template Convert<ReturnFormatType>();
      else
        output = out_format->template As<ReturnFormatType>();
    }
    return output;
  }


  template <template <typename, typename, typename>
            typename ReturnFormatType = format::FormatOrderTwo,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *
  Permute2DRowWise(
      AutoIDType *ordering,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts, bool convert_input,
      bool convert_output = false) {
    permute::PermuteOrderTwo<AutoIDType, AutoNNZType, AutoValueType> perm(
        ordering, nullptr);
    auto out_format = perm.GetPermutation(format, contexts, convert_input);
    ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *output;
    if constexpr (std::is_same_v<
                      ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType>,
                      format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                             AutoValueType>>)
      output = out_format;
    else {
      if (convert_output)
        output = out_format->template Convert<ReturnFormatType>();
      else
        output = out_format->template As<ReturnFormatType>();
    }
    return output;
  }


  template <template <typename, typename, typename>
            typename RelativeReturnFormatType = format::FormatOrderTwo,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static std::pair<
      std::vector<
          format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>,
      RelativeReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *>
  Permute2DRowWiseCached(
      AutoIDType *ordering,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts, bool convert_output = false) {
    permute::PermuteOrderTwo<AutoIDType, AutoNNZType, AutoValueType> perm(
        ordering, nullptr);
    auto output = perm.GetPermutationCached(format, contexts, true);
    std::vector<
        format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>
        converted_formats;
    std::transform(
        std::get<0>(output)[0].begin(), std::get<0>(output)[0].end(),
        std::back_inserter(converted_formats),
        [](format::Format *intermediate_format) {
          return static_cast<
              format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>(
              intermediate_format);
        });
    RelativeReturnFormatType<AutoIDType, AutoNNZType, AutoValueType>
        *output_format;
    if constexpr (std::is_same_v<RelativeReturnFormatType<
                                     AutoIDType, AutoNNZType, AutoValueType>,
                                 format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                                        AutoValueType>>)
      output_format = std::get<1>(output);
    else {
      if (convert_output)
        output_format =
            std::get<1>(output)->template Convert<RelativeReturnFormatType>();
      else
        output_format =
            std::get<1>(output)->template As<RelativeReturnFormatType>();
    }
    return std::make_pair(converted_formats, output_format);
  }


  template <template <typename, typename, typename>
            typename ReturnFormatType = format::FormatOrderTwo,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *
  Permute2DColWise(
      AutoIDType *ordering,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts, bool convert_input,
      bool convert_output = false) {
    permute::PermuteOrderTwo<AutoIDType, AutoNNZType, AutoValueType> perm(
        nullptr, ordering);
    auto out_format = perm.GetPermutation(format, contexts, convert_input);
    ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *output;
    if constexpr (std::is_same_v<
                      ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType>,
                      format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                             AutoValueType>>)
      output = out_format;
    else {
      if (convert_output)
        output = out_format->template Convert<ReturnFormatType>();
      else
        output = out_format->template As<ReturnFormatType>();
    }
    return output;
  }


  template <template <typename, typename, typename>
            typename ReturnFormatType = format::FormatOrderTwo,
            typename AutoIDType, typename AutoNNZType, typename AutoValueType>
  static std::pair<std::vector<format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                                      AutoValueType> *>,
                   ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *>
  Permute2DColWiseCached(
      AutoIDType *ordering,
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      std::vector<context::Context *> contexts, bool convert_output = false) {
    permute::PermuteOrderTwo<AutoIDType, AutoNNZType, AutoValueType> perm(
        nullptr, ordering);
    auto output = perm.GetPermutationCached(format, contexts, true);
    std::vector<
        format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>
        converted_formats;
    std::transform(
        std::get<0>(output)[0].begin(), std::get<0>(output)[0].end(),
        std::back_inserter(converted_formats),
        [](format::Format *intermediate_format) {
          return static_cast<
              format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *>(
              intermediate_format);
        });
    ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType> *output_format;
    if constexpr (std::is_same_v<
                      ReturnFormatType<AutoIDType, AutoNNZType, AutoValueType>,
                      format::FormatOrderTwo<AutoIDType, AutoNNZType,
                                             AutoValueType>>)
      output_format = std::get<1>(output);
    else {
      if (convert_output)
        output_format =
            std::get<1>(output)->template Convert<ReturnFormatType>();
      else
        output_format = std::get<1>(output)->template As<ReturnFormatType>();
    }
    return std::make_pair(converted_formats, output_format);
  }


  template <template <typename>
            typename ReturnFormatType = format::FormatOrderOne,
            typename AutoIDType, typename AutoValueType>
  static ReturnFormatType<AutoValueType> *Permute1D(
      AutoIDType *ordering, format::FormatOrderOne<AutoValueType> *format,
      std::vector<context::Context *> context, bool convert_inputs,
      bool convert_output = false) {
    permute::PermuteOrderOne<AutoIDType, AutoValueType> perm(ordering);
    auto out_format = perm.GetPermutation(format, context, convert_inputs);
    ReturnFormatType<AutoValueType> *output;
    if constexpr (std::is_same_v<ReturnFormatType<AutoValueType>,
                                 format::FormatOrderOne<AutoValueType>>)
      output = out_format;
    else {
      if (convert_output)
        output = out_format->template Convert<ReturnFormatType>();
      else
        output = out_format->template As<ReturnFormatType>();
    }
    return output;
  }


  template <template <typename>
            typename ReturnFormatType = format::FormatOrderOne,
            typename AutoIDType, typename AutoValueType>
  static std::pair<std::vector<format::FormatOrderOne<AutoValueType> *>,
                   ReturnFormatType<AutoValueType> *>
  Permute1DCached(AutoIDType *ordering,
                  format::FormatOrderOne<AutoValueType> *format,
                  std::vector<context::Context *> context,
                  bool convert_output = false) {
    permute::PermuteOrderOne<AutoIDType, AutoValueType> perm(ordering);
    auto output = perm.GetPermutationCached(format, context, true);
    std::vector<format::FormatOrderOne<AutoValueType> *> converted_formats;
    std::transform(
        std::get<0>(output)[0].begin(), std::get<0>(output)[0].end(),
        std::back_inserter(converted_formats),
        [](format::Format *intermediate_format) {
          return static_cast<format::FormatOrderOne<AutoValueType> *>(
              intermediate_format);
        });
    ReturnFormatType<AutoValueType> *output_format;
    if constexpr (std::is_same_v<ReturnFormatType<AutoValueType>,
                                 format::FormatOrderOne<AutoValueType>>)
      output_format = std::get<1>(output);
    else {
      if (convert_output)
        output_format =
            std::get<1>(output)->template Convert<ReturnFormatType>();
      else
        output_format = std::get<1>(output)->template As<ReturnFormatType>();
    }
    return std::make_pair(converted_formats, output_format);
  }


  template <typename AutoIDType, typename AutoNumType>
  static AutoIDType *InversePermutation(AutoIDType *perm, AutoNumType length) {
    static_assert(std::is_integral_v<AutoNumType>,
                  "Length of the permutation array must be an integer");
    auto inv_perm = new AutoIDType[length];
    for (AutoIDType i = 0; i < length; i++) {
      inv_perm[perm[i]] = i;
    }
    return inv_perm;
  }

  template <typename FloatType, typename AutoIDType, typename AutoNNZType,
            typename AutoValueType>
  static sparsebase::format::Array<FloatType> *Heatmap(
      format::FormatOrderTwo<AutoIDType, AutoNNZType, AutoValueType> *format,
      format::FormatOrderOne<AutoIDType> *permutation_r,
      format::FormatOrderOne<AutoIDType> *permutation_c, int num_parts,
      std::vector<context::Context *> contexts, bool convert_input) {
    reorder::ReorderHeatmap<AutoIDType, AutoNNZType, AutoValueType, FloatType>
        heatmapper(num_parts);
    format::FormatOrderOne<FloatType> *arr = heatmapper.Get(
        format, permutation_r, permutation_c, contexts, convert_input);
    return arr->template Convert<sparsebase::format::Array>();
  }
};
}  // namespace sparsebase::bases

#endif  // SPARSEBASE_PROJECT_REORDER_BASE_H