Class Acts::EdgeClassificationBase

class EdgeClassificationBase

Subclassed by Acts::OnnxEdgeClassifier, Acts::TorchEdgeClassifier
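Because both subclasses share this abstract interface, calling code can be written once against EdgeClassificationBase, and a concrete classifier can then be swapped in. The sketch below illustrates this under explicit assumptions. It re-declares the interface locally with the signatures documented on this page instead of including the ACTS header. ConstantScoreClassifier, scoreEdges and makeDefaultClassifier are hypothetical names. The use of std::vector<std::vector<long>> as the edge-index payload is also an assumption, since this page does not say which concrete types the real subclasses expect inside the std::any.

```cpp
#include <any>
#include <cassert>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

namespace Acts {
// Local stand-in with the signatures documented on this page (assumption:
// real code would include the ACTS header that declares this class instead).
class EdgeClassificationBase {
 public:
  virtual ~EdgeClassificationBase() = default;
  virtual std::tuple<std::any, std::any, std::any> operator()(std::any nodes,
                                                              std::any edges) = 0;
};
}  // namespace Acts

// Hypothetical backend that assigns the same score to every edge.
// Assumes the edge index arrives as std::vector<std::vector<long>> of shape (2, n_edges).
class ConstantScoreClassifier final : public Acts::EdgeClassificationBase {
 public:
  explicit ConstantScoreClassifier(float score) : m_score(score) {}

  std::tuple<std::any, std::any, std::any> operator()(std::any nodes,
                                                      std::any edges) override {
    // Throws std::bad_any_cast if the payload has a different type.
    const auto& idx = std::any_cast<const std::vector<std::vector<long>>&>(edges);
    assert(idx.size() == 2 && "edge index must have shape (2, n_edges)");
    std::vector<float> scores(idx.at(1).size(), m_score);  // one score per edge
    return {std::move(nodes), std::move(edges), std::move(scores)};
  }

 private:
  float m_score;
};

// Caller code depends only on the base class, so any backend can be plugged in.
std::vector<float> scoreEdges(Acts::EdgeClassificationBase& classifier,
                              std::any nodes, std::any edges) {
  auto result = classifier(std::move(nodes), std::move(edges));
  return std::any_cast<std::vector<float>>(std::get<2>(result));
}

// Owning a backend through a base-class pointer is safe because the
// destructor is virtual.
std::unique_ptr<Acts::EdgeClassificationBase> makeDefaultClassifier() {
  return std::make_unique<ConstantScoreClassifier>(0.5f);
}
```

For example, scoreEdges(*makeDefaultClassifier(), std::any{}, edges) returns one score of 0.5 for each column of a (2, n_edges) edge index. The caller never names the concrete backend type.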

Public Functions

virtual ~EdgeClassificationBase() = default
virtual std::tuple<std::any, std::any, std::any> operator()(std::any nodes, std::any edges) = 0

Perform edge classification, i.e. compute scores for the edges of the input graph.

Parameters
  • nodes – Node tensor with shape (n_nodes, n_node_features)

  • edges – Edge-index tensor with shape (2, n_edges)

Returns

A tuple (node_tensor, edge_tensor, score_tensor), with each element wrapped in std::any
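The contract above can be illustrated with a toy implementation. It unpacks the two std::any arguments, checks the documented shapes, and returns the inputs together with one score per edge. This is only a sketch. Tensor2D and DotProductEdgeClassifier are hypothetical, and so is the scoring rule (a sigmoid of the dot product of the two endpoint feature vectors), which is not what the ONNX or Torch classifiers compute. Returning the scores as a flat std::vector<float> of length n_edges is likewise an assumption, because the score layout is not specified here. The interface is re-declared locally so that the block compiles on its own.

```cpp
#include <any>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>

namespace Acts {
// Stand-in with the documented signature (assumption: not the real header).
class EdgeClassificationBase {
 public:
  virtual ~EdgeClassificationBase() = default;
  virtual std::tuple<std::any, std::any, std::any> operator()(std::any nodes,
                                                              std::any edges) = 0;
};
}  // namespace Acts

// Hypothetical row-major 2D tensor used as the std::any payload.
template <typename T>
struct Tensor2D {
  std::size_t rows = 0;
  std::size_t cols = 0;
  std::vector<T> data;  // rows * cols elements, row-major
  T at(std::size_t r, std::size_t c) const { return data.at(r * cols + c); }
};

// Hypothetical classifier: score(e) = sigmoid(dot(x[src(e)], x[dst(e)])).
class DotProductEdgeClassifier final : public Acts::EdgeClassificationBase {
 public:
  std::tuple<std::any, std::any, std::any> operator()(std::any nodes,
                                                      std::any edges) override {
    // Throw std::bad_any_cast if the caller passed other payload types.
    const auto& x = std::any_cast<const Tensor2D<float>&>(nodes);            // (n_nodes, n_node_features)
    const auto& idx = std::any_cast<const Tensor2D<std::int64_t>&>(edges);   // (2, n_edges)
    assert(x.data.size() == x.rows * x.cols && "node tensor data/shape mismatch");
    if (idx.rows != 2) {
      throw std::invalid_argument("edge index must have shape (2, n_edges)");
    }
    const std::size_t nEdges = idx.cols;
    std::vector<float> scores(nEdges);
    for (std::size_t e = 0; e < nEdges; ++e) {
      // Row 0 holds source node indices, row 1 target node indices.
      const auto src = static_cast<std::size_t>(idx.at(0, e));
      const auto dst = static_cast<std::size_t>(idx.at(1, e));
      float dot = 0.f;
      for (std::size_t f = 0; f < x.cols; ++f) {
        dot += x.at(src, f) * x.at(dst, f);  // bounds-checked via vector::at
      }
      scores[e] = 1.f / (1.f + std::exp(-dot));
    }
    // Pass the inputs through unchanged and append the scores.
    return {std::move(nodes), std::move(edges), std::move(scores)};
  }
};
```

Because the arguments are type-erased, passing a payload of the wrong type is only detected at run time, as std::bad_any_cast. That is the cost of letting each backend keep its own tensor type behind a single interface.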