Program Listing for File converter.h

Return to documentation for file (src/sparsebase/converter/converter.h)

/*******************************************************
 * Copyright (c) 2022 SparCity, Amro Alabsi Aljundi, Taha Atahan Akyildiz, Arda
 *Sener All rights reserved.
 *
 * This file is distributed under MIT license.
 * The complete license agreement can be obtained at:
 * https://sparcityeu.github.io/sparsebase/pages/license.html
 ********************************************************/
#ifndef SPARSEBASE_SPARSEBASE_UTILS_CONVERTER_CONVERTER_H_
#define SPARSEBASE_SPARSEBASE_UTILS_CONVERTER_CONVERTER_H_

#include <algorithm>
#include <cstring>
#include <fstream>
#include <functional>
#include <memory>
#include <optional>
#include <tuple>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

#include "sparsebase/config.h"
#include "sparsebase/utils/utils.h"

// Forward declarations for the `Convert` functions in
// sparsebase/format/format.h
namespace sparsebase {
// The declarations that follow refer to these types only through pointers,
// so forward declarations are sufficient here.
namespace context {
class Context;
}  // namespace context

namespace format {
class Format;
}  // namespace format

namespace converter {


typedef std::function<format::Format *(format::Format *, context::Context *)>
    ConversionFunction;

typedef std::function<bool(context::Context *, context::Context *)>
    ConversionCondition;

typedef std::tuple<ConversionFunction, context::Context *, utils::CostType>
    ConversionStep;

typedef std::optional<std::tuple<std::vector<ConversionStep>, utils::CostType>>
    ConversionChain;

typedef std::vector<ConversionChain> ConversionSchema;

typedef std::unordered_map<
    std::type_index,
    std::unordered_map<
        std::type_index,
        std::vector<std::tuple<ConversionCondition, ConversionFunction>>>>
    ConversionMap;
class Converter {
 private:
  ConversionMap copy_conversion_map_;
  ConversionMap move_conversion_map_;

  ConversionMap *get_conversion_map(bool is_move_conversion);

  const ConversionMap *get_conversion_map(bool is_move_conversion) const;


  static std::vector<ConversionStep> ConversionBFS(
      std::type_index from_type, context::Context *from_context,
      std::type_index to_type,
      const std::vector<context::Context *> &to_contexts,
      const ConversionMap *map);

 public:

  void RegisterConversionFunction(std::type_index from_type,
                                  std::type_index to_type,
                                  ConversionFunction conv_func,
                                  ConversionCondition edge_condition,
                                  bool is_move_conversion = false);


  format::Format *Convert(format::Format *source, std::type_index to_type,
                          context::Context *to_context,
                          bool is_move_conversion = false) const;


  std::vector<format::Format *> ConvertCached(
      format::Format *source, std::type_index to_type,
      context::Context *to_context, bool is_move_conversion = false) const;


  format::Format *Convert(format::Format *source, std::type_index to_type,
                          std::vector<context::Context *> to_contexts,
                          bool is_move_conversion = false) const;


  std::vector<format::Format *> ConvertCached(
      format::Format *source, std::type_index to_type,
      std::vector<context::Context *> to_context,
      bool is_move_conversion = false) const;

  template <typename FormatType>
  FormatType *Convert(format::Format *source, context::Context *to_context,
                      bool is_move_conversion = false) const {
    auto *res = this->Convert(source, FormatType::get_id_static(), to_context,
                              is_move_conversion);
    return res->template AsAbsolute<FormatType>();
  }


  template <typename FormatType>
  FormatType *Convert(format::Format *source,
                      std::vector<context::Context *> to_contexts,
                      bool is_move_conversion = false) const {
    auto *res = this->Convert(source, FormatType::get_id_static(), to_contexts,
                              is_move_conversion);
    return res->template AsAbsolute<FormatType>();
  }


  ConversionChain GetConversionChain(
      std::type_index from_type, context::Context *from_context,
      std::type_index to_type,
      const std::vector<context::Context *> &to_contexts,
      bool is_move_conversion = false) const;


  bool CanConvert(std::type_index from_type, context::Context *from_context,
                  std::type_index to_type, context::Context *to_context,
                  bool is_move_conversion = false) const;


  bool CanConvert(std::type_index from_type, context::Context *from_context,
                  std::type_index to_type,
                  const std::vector<context::Context *> &to_contexts,
                  bool is_move_conversion = false);

  void ClearConversionFunctions(std::type_index from_type,
                                std::type_index to_type,
                                bool move_conversion = false);


  void ClearConversionFunctions(bool move_conversion = false);


  static std::vector<std::vector<format::Format *>> ApplyConversionSchema(
      const ConversionSchema &cs,
      const std::vector<format::Format *> &packed_sfs, bool clear_intermediate);


  static std::vector<format::Format *> ApplyConversionChain(
      const ConversionChain &chain, format::Format *, bool clear_intermediate);

  virtual std::type_index get_converter_type() const = 0;
  virtual Converter *Clone() const = 0;
  virtual void Reset() = 0;
  virtual ~Converter();
};

// Helper base class that implements `Converter::get_converter_type()` by
// returning the type of its template argument.
// NOTE(review): presumably used CRTP-style, i.e. a concrete converter `Foo`
// derives from `ConverterImpl<Foo>` — confirm against the derived classes.
template <class ConverterType>
class ConverterImpl : public Converter {
 public:
  // Returns `typeid(ConverterType)` as a `std::type_index`. Marked
  // `override` so a signature drift in the base is a compile error.
  std::type_index get_converter_type() const override {
    return typeid(ConverterType);
  }
};

}  // namespace converter

}  // namespace sparsebase
#ifdef USE_CUDA
#include "converter_cuda.cuh"
#ifdef _HEADER_ONLY
#include "converter_cuda.cu"
#endif
#endif

#ifdef _HEADER_ONLY
#include "sparsebase/converter/converter.cc"
#endif

#endif  // SPARSEBASE_SPARSEBASE_UTILS_CONVERTER_CONVERTER_H_