Program Listing for File function_matcher_mixin.h

Return to documentation for file (src/sparsebase/utils/function_matcher_mixin.h)

#ifndef SPARSEBASE_PROJECT_FUNCTIONMATCHERMIXIN_H
#define SPARSEBASE_PROJECT_FUNCTIONMATCHERMIXIN_H

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <limits>
#include <string>
#include <tuple>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include <vector>

#include "sparsebase/config.h"
#include "sparsebase/converter/converter.h"
#include "sparsebase/utils/parameterizable.h"
#include "sparsebase/utils/utils.h"

namespace sparsebase::utils {


/// Pointer-to-function type used as the default implementation signature in
/// FunctionMatcherMixin: the function receives the input formats and a pointer
/// to the parameters object, and returns a `ReturnType`.
template <typename ReturnType>
using PreprocessFunction = ReturnType (*)(std::vector<format::Format *> formats,
                                          utils::Parameters *params);


template <typename ReturnType, class PreprocessingImpl = Parameterizable,
          typename Function = PreprocessFunction<ReturnType>,
          typename Key = std::vector<std::type_index>,
          typename KeyHash = TypeIndexVectorHash,
          typename KeyEqualTo = std::equal_to<std::vector<std::type_index>>>
class FunctionMatcherMixin : public PreprocessingImpl {
  typedef std::unordered_map<Key, Function, KeyHash, KeyEqualTo> ConversionMap;

 public:
  std::vector<Key> GetAvailableFormats() {
    std::vector<Key> keys;
    for (auto element : map_to_function_) {
      keys.push_back(element.first);
    }
    return keys;
  }

  bool RegisterFunctionNoOverride(const Key &key_of_function,
                                  const Function &func_ptr);

  void RegisterFunction(const Key &key_of_function, const Function &func_ptr);

  bool UnregisterFunction(const Key &key_of_function);

 protected:
  using PreprocessingImpl::PreprocessingImpl;
  ConversionMap map_to_function_;

  std::tuple<Function, converter::ConversionSchema> GetFunction(
      std::vector<format::Format *> packed_formats, Key key, ConversionMap map,
      std::vector<context::Context *> contexts);

  bool CheckIfKeyMatches(ConversionMap map, Key key,
                         std::vector<format::Format *> packed_formats,
                         std::vector<context::Context *> contexts);
  template <typename Object, typename... Objects>
  std::vector<Object> PackObjects(Object object, Objects... objects);
  template <typename Object>
  std::vector<Object> PackObjects(Object object);

  template <typename F, typename... SF>
  ReturnType Execute(utils::Parameters *params,
                     std::vector<context::Context *> contexts,
                     bool convert_input, F sf, SF... sfs);

  template <typename F, typename... SF>
  std::tuple<std::vector<std::vector<format::Format *>>, ReturnType>
  CachedExecute(utils::Parameters *params,
                std::vector<context::Context *> contexts, bool convert_input,
                bool clear_intermediate, F format, SF... formats);
};

/// Finds the function that matches the given formats, applies whatever
/// conversions that function requires, calls it, and returns its result
/// together with the formats produced by the conversions.
///
/// @param params             parameters forwarded to the selected function
/// @param contexts           contexts in which the function may be executed
/// @param convert_input      when false, throws instead of converting inputs
/// @param clear_intermediate forwarded to Converter::ApplyConversionSchema
/// @param format, formats    the input formats
/// @return tuple of (per-input vector of formats produced by conversion,
///         value returned by the selected function)
/// @throws DirectExecutionNotAvailableException if `convert_input` is false
///         and at least one input would have to be converted
template <typename ReturnType, class PreprocessingImpl, typename Function,
          typename Key, typename KeyHash, typename KeyEqualTo>
template <typename F, typename... SF>
std::tuple<std::vector<std::vector<format::Format *>>, ReturnType>
FunctionMatcherMixin<ReturnType, PreprocessingImpl, Function, Key, KeyHash,
                     KeyEqualTo>::CachedExecute(utils::Parameters *params,
                                                std::vector<context::Context *>
                                                    contexts,
                                                bool convert_input,
                                                bool clear_intermediate,
                                                F format, SF... formats) {
  // pack the Formats into a vector
  std::vector<format::Format *> packed_formats =
      PackObjects(format, formats...);
  // pack the types of Formats into a vector
  std::vector<std::type_index> packed_format_types;
  packed_format_types.reserve(packed_formats.size());
  for (format::Format *f : packed_formats)
    packed_format_types.push_back(f->get_id());
  // get the function and its conversion schema; GetFunction takes the map by
  // value, so the member is passed directly rather than copied into a local
  // first (which would copy the whole map twice)
  std::tuple<Function, converter::ConversionSchema> ret = GetFunction(
      packed_formats, packed_format_types, this->map_to_function_, contexts);
  Function func = std::get<0>(ret);
  converter::ConversionSchema &cs = std::get<1>(ret);
  // if conversions are not allowed, any non-empty entry in the schema means
  // the function cannot be executed directly on the inputs
  if (!convert_input) {
    for (const auto &conversion_chain : cs) {
      if (conversion_chain)
        throw utils::DirectExecutionNotAvailableException(
            packed_format_types, this->GetAvailableFormats());
    }
  }
  // carry out conversion
  std::vector<std::vector<format::Format *>> all_formats =
      sparsebase::converter::Converter::ApplyConversionSchema(
          cs, packed_formats, clear_intermediate);
  // final_formats: the formats handed to the implementation function (last
  // element of every chain).
  // intermediate_formats: everything after the first element of every chain,
  // i.e. the formats that were produced to get to the final formats.
  std::vector<format::Format *> final_formats;
  std::vector<std::vector<format::Format *>> intermediate_formats;
  final_formats.reserve(all_formats.size());
  intermediate_formats.reserve(all_formats.size());
  for (const std::vector<format::Format *> &conversion_chain : all_formats) {
    final_formats.push_back(conversion_chain.back());
    if (conversion_chain.size() > 1) {
      intermediate_formats.emplace_back(conversion_chain.begin() + 1,
                                        conversion_chain.end());
    } else {
      intermediate_formats.emplace_back();
    }
  }
  // carry out the correct call
  return std::make_tuple(std::move(intermediate_formats),
                         func(final_formats, params));
}

/// Runs the function matching the given formats and returns its result.
///
/// Delegates to CachedExecute (with `clear_intermediate` set to true) and then
/// deletes every format that CachedExecute reports as produced by conversion.
///
/// @param params        parameters forwarded to the selected function
/// @param contexts      contexts in which the function may be executed
/// @param convert_input whether inputs may be converted to find a match
/// @param sf, sfs       the input formats
/// @return the value returned by the selected function
template <typename ReturnType, class PreprocessingImpl, typename Function,
          typename Key, typename KeyHash, typename KeyEqualTo>
template <typename F, typename... SF>
ReturnType FunctionMatcherMixin<
    ReturnType, PreprocessingImpl, Function, Key, KeyHash,
    KeyEqualTo>::Execute(utils::Parameters *params,
                         std::vector<context::Context *> contexts,
                         bool convert_input, F sf, SF... sfs) {
  auto cached_output =
      CachedExecute(params, contexts, convert_input, true, sf, sfs...);
  // Release the formats created by conversion; iterate over the tuple element
  // in place instead of copying the chains out first.
  for (const auto &converted_format_chain : std::get<0>(cached_output)) {
    for (format::Format *converted_format : converted_format_chain)
      delete converted_format;
  }
  // The tuple is a local about to be destroyed, so move the result out of it
  // rather than copying it.
  return std::move(std::get<1>(cached_output));
}

/// Base case of object packing: wraps a single object in a one-element vector.
template <typename ReturnType, class PreprocessingImpl, typename Function,
          typename Key, typename KeyHash, typename KeyEqualTo>
template <typename Object>
std::vector<Object>
FunctionMatcherMixin<ReturnType, PreprocessingImpl, Function, Key, KeyHash,
                     KeyEqualTo>::PackObjects(Object object) {
  std::vector<Object> packed;
  packed.push_back(object);
  return packed;
}
/// Packs all arguments, in the order given, into a vector whose element type
/// is the type of the first argument. The remaining arguments must be
/// implicitly convertible to that type.
template <typename ReturnType, class PreprocessingImpl, typename Function,
          typename Key, typename KeyHash, typename KeyEqualTo>
template <typename Object, typename... Objects>
std::vector<Object>
FunctionMatcherMixin<ReturnType, PreprocessingImpl, Function, Key, KeyHash,
                     KeyEqualTo>::PackObjects(Object object, Objects... objects) {
  std::vector<Object> packed;
  packed.reserve(1 + sizeof...(objects));
  packed.push_back(object);
  // Fold over the comma operator: appends each remaining argument
  // left-to-right in a single pass, instead of recursing and re-copying the
  // tail vector at every level.
  (packed.push_back(objects), ...);
  return packed;
}
/// Adds `func_ptr` to the registry under `key_of_function`, but only when no
/// function is registered for that key yet.
/// @return true if the function was added, false if the key was already taken
template <typename ReturnType, class Preprocess, typename Function,
          typename Key, typename KeyHash, typename KeyEqualTo>
bool FunctionMatcherMixin<
    ReturnType, Preprocess, Function, Key, KeyHash,
    KeyEqualTo>::RegisterFunctionNoOverride(const Key &key_of_function,
                                            const Function &func_ptr) {
  // emplace() leaves an existing entry untouched and reports, through the
  // second member of its result, whether a new entry was actually added.
  const auto insertion = map_to_function_.emplace(key_of_function, func_ptr);
  return insertion.second;
}

/// Stores `func_ptr` in the registry under `key_of_function`, replacing the
/// function previously registered for that key, if any.
template <typename ReturnType, class Preprocess, typename Function,
          typename Key, typename KeyHash, typename KeyEqualTo>
void FunctionMatcherMixin<
    ReturnType, Preprocess, Function, Key, KeyHash,
    KeyEqualTo>::RegisterFunction(const Key &key_of_function,
                                  const Function &func_ptr) {
  map_to_function_.insert_or_assign(key_of_function, func_ptr);
}
/// Removes the function registered under `key_of_function`.
/// @return true if a function was removed, false if no function was
///         registered for this key
template <typename ReturnType, class Preprocess, typename Function,
          typename Key, typename KeyHash, typename KeyEqualTo>
bool FunctionMatcherMixin<ReturnType, Preprocess, Function, Key, KeyHash,
                          KeyEqualTo>::UnregisterFunction(const Key &
                                                              key_of_function) {
  // erase() returns the number of entries removed, which is 0 when no
  // function is registered for this key; this avoids a separate find().
  return map_to_function_.erase(key_of_function) > 0;
}
/// Checks whether the function registered for `key` can be called directly on
/// the given formats.
///
/// @return true only if `map` contains `key` and, for every format in
///         `packed_sfs`, the format's context is equivalent to at least one
///         of `contexts`. With no formats, the result depends only on whether
///         `key` is in `map`.
template <typename ReturnType, class PreprocessingImpl, typename Function,
          typename Key, typename KeyHash, typename KeyEqualTo>
bool FunctionMatcherMixin<
    ReturnType, PreprocessingImpl, Function, Key, KeyHash,
    KeyEqualTo>::CheckIfKeyMatches(ConversionMap map, Key key,
                                   std::vector<format::Format *> packed_sfs,
                                   std::vector<context::Context *> contexts) {
  if (map.find(key) == map.end()) return false;
  for (format::Format *sf : packed_sfs) {
    bool found_context = false;
    for (context::Context *context : contexts) {
      if (sf->get_context()->IsEquivalent(context)) {
        found_context = true;
        break;  // one equivalent context is enough for this format
      }
    }
    // A single format without an acceptable context rules out the match, so
    // there is no need to look at the remaining formats.
    if (!found_context) return false;
  }
  return true;
}

/// Selects the function to execute for the given formats, together with the
/// conversion schema that must be applied to the formats before calling it.
///
/// If a function is registered for `key` and all formats are in acceptable
/// contexts, that function is returned with a schema of empty entries (no
/// conversions). Otherwise every registered key of the same length is
/// considered: a key is usable when each position either already has the
/// required type or has a conversion chain to it. The usable key with the
/// lowest total conversion cost is chosen; ties go to the key encountered
/// first while iterating over `map`.
///
/// @param packed_sfs the input formats
/// @param key        the types of the input formats
/// @param map        registry of available functions
/// @param contexts   contexts in which the function may be executed
/// @return tuple of (selected function, conversion schema for the inputs)
/// @throws FunctionNotFoundException if no registered key is usable
template <typename ReturnType, class PreprocessingImpl, typename Function,
          typename Key, typename KeyHash, typename KeyEqualTo>
std::tuple<Function, converter::ConversionSchema> FunctionMatcherMixin<
    ReturnType, PreprocessingImpl, Function, Key, KeyHash,
    KeyEqualTo>::GetFunction(std::vector<format::Format *> packed_sfs, Key key,
                             ConversionMap map,
                             std::vector<context::Context *> contexts) {
  // When function and conversion costs are added,
  // this 'if' should be removed  -- a conversion might be
  // cheaper than direct call to matching key
  if (CheckIfKeyMatches(map, key, packed_sfs, contexts)) {
    converter::ConversionSchema cs;
    // one empty entry per input: nothing needs to be converted
    for (std::size_t i = 0; i < key.size(); ++i) {
      cs.push_back({});
    }
    return std::make_tuple(map.at(key), cs);
  }
  // Look through the keys of all the available functions for the cheapest one
  // that can run with this input. The best candidate is tracked while
  // iterating, so candidate keys and schemas are not copied into a temporary
  // list.
  Function best_func = nullptr;
  converter::ConversionSchema best_cs;
  unsigned int best_cost = 0;
  bool found_usable_key = false;
  for (const auto &key_func : map) {
    const Key &potential_key = key_func.first;
    if (potential_key.size() != key.size()) continue;
    converter::ConversionSchema candidate_cs;
    bool is_usable = true;
    int conversion_cost = 0;
    for (std::size_t i = 0; i < potential_key.size(); ++i) {
      if (key[i] == potential_key[i]) {
        // NOTE(review): only the type is compared here, so a format whose
        // type matches is passed through unconverted even if the direct match
        // above failed because of its context -- confirm this is intended.
        candidate_cs.push_back({});  // no conversion, no cost
        continue;
      }
      auto sc = packed_sfs[i]->get_converter();
      auto conversion_chain = sc->GetConversionChain(
          key[i], packed_sfs[i]->get_context(), potential_key[i], contexts);
      if (!conversion_chain) {
        // one input that cannot be converted makes the whole key unusable
        is_usable = false;
        break;
      }
      candidate_cs.push_back(*conversion_chain);
      conversion_cost += std::get<1>(*conversion_chain);
    }
    if (!is_usable) continue;
    // At this point, we can add the cost of the function with key
    // "potential_key"
    const auto total_cost = static_cast<unsigned int>(
        conversion_cost);  // add function cost in the future
    // strict '<' keeps the first key found among equally cheap ones
    if (!found_usable_key || total_cost < best_cost) {
      found_usable_key = true;
      best_cost = total_cost;
      best_cs = std::move(candidate_cs);
      best_func = key_func.second;
    }
  }
  if (!found_usable_key) {
    std::string message;
    message = "Could not find a function that matches the formats: {";
    for (auto f : packed_sfs) {
      message += f->get_name();
      message += " ";
    }
    message += "} using the contexts {";
    for (auto c : contexts) {
      message += c->get_name();
      message += " ";
    }
    message += "}";

    throw sparsebase::utils::FunctionNotFoundException(
        message);  // TODO: add a custom exception type
  }
  return std::make_tuple(best_func, best_cs);
}
}  // namespace sparsebase::utils
#endif  // SPARSEBASE_PROJECT_FUNCTIONMATCHERMIXIN_H