File TorchEdgeClassifier.hpp

namespace Acts


This file is intended for the Geometry module, as a replacement for Extent.

class TorchEdgeClassifier : public Acts::EdgeClassificationBase
#include <Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp>

Public Functions

TorchEdgeClassifier(const Config &cfg)
inline Config config() const
virtual std::tuple<std::any, std::any, std::any> operator()(std::any nodes, std::any edges, const Logger &logger) override

Perform edge classification.

Parameters:

  • nodes – Node tensor with shape (n_nodes, n_node_features)

  • edges – Edge-index tensor with shape (2, n_edges)

  • logger – Logger instance


Returns:
  A tuple (node_tensor, edge_tensor, score_tensor).

Private Members

Config m_cfg
c10::DeviceType m_deviceType
std::unique_ptr<torch::jit::Module> m_model
struct Config
#include <Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp>

Public Members

float cut = 0.21
std::string modelPath
int nChunks = 1
namespace c10
namespace jit