Class Acts::EdgeClassificationBase
class EdgeClassificationBase
Subclassed by Acts::OnnxEdgeClassifier, Acts::TorchEdgeClassifier
Public Functions
virtual ~EdgeClassificationBase() = default
virtual std::tuple<std::any, std::any, std::any> operator()(std::any nodes, std::any edges) = 0
Perform edge classification.
Parameters
nodes – Node tensor with shape (n_nodes, n_node_features), passed as std::any
edges – Edge-index tensor with shape (2, n_edges), passed as std::any
Returns
Tuple (node_tensor, edge_tensor, score_tensor), each element wrapped in a std::any