File TorchMetricLearning.hpp

namespace Acts


This file is intended for the Geometry module, as a replacement for Extent.

class TorchMetricLearning : public Acts::GraphConstructionBase
#include <Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp>

Public Functions

TorchMetricLearning(const Config &cfg, std::unique_ptr<const Logger> logger)
inline Config config() const
virtual std::tuple<std::any, std::any> operator()(std::vector<float> &inputValues, std::size_t numNodes) override

Perform the graph construction.

Parameters:

  • inputValues – Flattened input data.

  • numNodes – Number of nodes; inputValues.size() / numNodes gives the number of features per node.


Returns: a tuple (node_tensor, edge_tensor)

Private Functions

inline const auto &logger() const

Private Members

Config m_cfg
c10::DeviceType m_deviceType
std::unique_ptr<const Acts::Logger> m_logger
std::unique_ptr<torch::jit::Module> m_model
struct Config
#include <Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp>

Public Members

int embeddingDim = 8
int knnVal = 500
std::string modelPath
int numFeatures = 3
float rVal = 1.6
bool shuffleDirections = false
namespace c10
namespace torch
namespace jit