Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added PyTorch Geometric functionality. #26

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Added PyTorch Geometric functionality.
  • Loading branch information
stefanvanberkum committed Aug 21, 2023
commit 20b602b006dbb9a7d8f561fef9bdbd2d631a2e19
2 changes: 2 additions & 0 deletions bindings/pyroot/pythonizations/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -65,6 +65,8 @@ if(tmva)
ROOT/_pythonization/_tmva/_rtensor.py
ROOT/_pythonization/_tmva/_tree_inference.py
ROOT/_pythonization/_tmva/_utils.py)
list(APPEND PYROOT_EXTRA_PY3_SOURCE
ROOT/_pythonization/_tmva/_torchgnn.py)
endif()

if(PYTHON_VERSION_STRING_Development_Main VERSION_GREATER_EQUAL 3.8 AND dataframe)
Original file line number Diff line number Diff line change
@@ -22,6 +22,8 @@

from ._rbdt import Compute, pythonize_rbdt

from ._torchgnn import RModel_TorchGNN

if sys.version_info >= (3, 8):
from ._batchgenerator import (
CreateNumPyGenerators,
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Helper functions for Python TorchGNN.
Author: Stefan van Berkum
"""

from .. import pythonization
from cppyy.gbl.std import vector, map


class RModel_TorchGNN():
    def ExtractParameters(self, model):
        """Extract the parameters from a PyTorch model.

        For this to work, the parameterized module names in ROOT must match
        the keys of the PyTorch state dictionary, which are derived from the
        class attribute names.

        For example:
            Torch: self.linear_1 = torch.nn.Linear(5, 20)
            ROOT: model.AddModule(ROOT.TMVA.Experimental.SOFIE.RModule_Linear('X',
            5, 20), 'linear_1')

        :param model: The PyTorch model.
        """

        # Convert the Python state dict to a C++ map<string, vector<float>>
        # and hand it to the C++ LoadParameters overload.
        params = map[str, vector[float]]()
        for name, tensor in model.state_dict().items():
            params[name] = tensor.cpu().numpy().flatten().tolist()
        self.LoadParameters(params)


@pythonization("RModel_TorchGNN", ns="TMVA::Experimental::SOFIE")
def pythonize_torchgnn_extractparameters(klass):
    # Attach the Python helper as a method of the C++ RModel_TorchGNN class.
    klass.ExtractParameters = RModel_TorchGNN.ExtractParameters
15 changes: 15 additions & 0 deletions tmva/sofie/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -44,9 +44,24 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
TMVA/ROperator_Erf.hxx
TMVA/SOFIE_common.hxx
TMVA/SOFIEHelpers.hxx

TMVA/TorchGNN/modules/RModule_Add.hxx
TMVA/TorchGNN/modules/RModule_Cat.hxx
TMVA/TorchGNN/modules/RModule_GCNConv.hxx
TMVA/TorchGNN/modules/RModule_GlobalMeanPool.hxx
TMVA/TorchGNN/modules/RModule_Input.hxx
TMVA/TorchGNN/modules/RModule_Linear.hxx
TMVA/TorchGNN/modules/RModule_ReLU.hxx
TMVA/TorchGNN/modules/RModule_Reshape.hxx
TMVA/TorchGNN/modules/RModule_Softmax.hxx
TMVA/TorchGNN/modules/RModule.hxx

TMVA/TorchGNN/RModel_TorchGNN.hxx
SOURCES
src/RModel.cxx
src/SOFIE_common.cxx

src/TorchGNN/RModel_TorchGNN.cxx
DEPENDENCIES
TMVA
)
2 changes: 2 additions & 0 deletions tmva/sofie/inc/LinkDef.h
Original file line number Diff line number Diff line change
@@ -14,5 +14,7 @@
#pragma link C++ struct TMVA::Experimental::SOFIE::TensorInfo+;
#pragma link C++ struct TMVA::Experimental::SOFIE::InputTensorInfo+;
#pragma link C++ struct TMVA::Experimental::SOFIE::Dim+;
#pragma link C++ class TMVA::Experimental::SOFIE::RModule+;
#pragma link C++ class TMVA::Experimental::SOFIE::RModel_TorchGNN+;

#endif
203 changes: 203 additions & 0 deletions tmva/sofie/inc/TMVA/TorchGNN/RModel_TorchGNN.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// @(#)root/tmva/sofie:$Id$
// Author: Stefan van Berkum

/**
* Header file for PyTorch Geometric models.
*
* Models are created by the user and parameters can then be loaded into each layer.
*
 * IMPORTANT: Changes to the format (e.g., namespaces) may affect the code
 * emission implemented in RModel_TorchGNN.cxx (Save).
*/

#ifndef TMVA_SOFIE_RMODEL_TORCHGNN_H_
#define TMVA_SOFIE_RMODEL_TORCHGNN_H_

#include "TMVA/TorchGNN/modules/RModule.hxx"
#include "TMVA/TorchGNN/modules/RModule_Input.hxx"

#include <algorithm>
#include <ctime>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

namespace TMVA {
namespace Experimental {
namespace SOFIE {

class RModel_TorchGNN {
   public:
      /** Model constructor without inputs. */
      RModel_TorchGNN() {}

      /**
       * Model constructor with manual input names.
       *
       * @param input_names Vector of input names.
       * @param input_shapes Vector of input shapes. Each element may contain
       * at most one wildcard (-1).
       * @throws std::invalid_argument If a shape contains a zero dimension,
       * a negative entry other than the wildcard, or more than one wildcard.
       */
      RModel_TorchGNN(std::vector<std::string> input_names, std::vector<std::vector<int>> input_shapes) {
         fInputs = input_names;
         fShapes = input_shapes;

         // Generate one input layer per declared input.
         for (std::size_t i = 0; i < input_names.size(); i++) {
            // Validate the shape before constructing the module.
            if (std::any_of(input_shapes[i].begin(), input_shapes[i].end(), [](int j){return j == 0;})) {
               throw std::invalid_argument("Invalid input shape for input " + input_names[i] + ". Dimension cannot be zero.");
            }
            if (std::any_of(input_shapes[i].begin(), input_shapes[i].end(), [](int j){return j < -1;})) {
               throw std::invalid_argument("Invalid input shape for input " + input_names[i] + ". Shape cannot have negative entries (except for the wildcard dimension).");
            }
            if (std::count(input_shapes[i].begin(), input_shapes[i].end(), -1) > 1) {
               throw std::invalid_argument("Invalid input shape for input " + input_names[i] + ". Shape may have at most one wildcard.");
            }
            AddModule(RModule_Input(input_shapes[i]), input_names[i]);
         }
      }

      /**
       * Add a module to the module list.
       *
       * Modules are executed in insertion order. Duplicate names are
       * disambiguated by appending a numeric discriminator (e.g., GCNConv_1).
       *
       * @param module Module to add.
       * @param name Module name. Defaults to the module type with a count
       * value (e.g., GCNConv_1).
       */
      template<typename T>
      void AddModule(T module, std::string name="") {
         std::string requested = (name == "") ? std::string(module.GetOperation()) : name;
         std::string new_name = requested;

         // std::map::operator[] value-initializes missing entries to 0, so
         // this lookup is safe for names that have never been seen.
         int &seen = fModuleCounts[requested];
         if (seen > 0) {
            // Name already taken: append the current count as a
            // discriminator. The count of the *requested* name is bumped
            // below so the next duplicate gets a fresh suffix. (Bumping the
            // count of the renamed entry instead — as the previous code did —
            // hands out the same suffix twice and silently overwrites the
            // earlier module's entry in fModuleMap.)
            new_name += "_" + std::to_string(seen);

            if (name != "") {
               // Warn the user that their explicit name was changed.
               std::cout << "WARNING: Module with duplicate name \"" << name << "\" renamed to \"" << new_name << "\"." << std::endl;
            }
         }
         seen++;
         if (new_name != requested) {
            // Reserve the generated name too, so a later explicit request
            // for it is also detected as a duplicate.
            fModuleCounts[new_name]++;
         }
         module.SetName(new_name);

         // Let the module resolve its inputs against the current graph.
         module.Initialize(fModules, fModuleMap);

         // Register the module.
         fModules.push_back(std::make_shared<T>(module));
         fModuleMap[std::string(module.GetName())] = fModuleCount;
         fModuleCount++;
      }

      /**
       * Run the forward function.
       *
       * Arguments are bound positionally to the input modules, which are
       * assumed to be the first sizeof...(args) entries of the module list
       * (as set up by the constructor).
       *
       * @param args Any number of input arguments.
       * @returns The output of the last layer.
       */
      template<class... Types>
      std::vector<float> Forward(Types... args) {
         auto input = std::make_tuple(args...);

         // Bind each argument to the corresponding input layer, in order.
         int k = 0;
         std::apply(
            [&](auto&... in) {
               ((std::dynamic_pointer_cast<RModule_Input>(fModules[k++]) -> SetParams(in)), ...);
            }, input);

         // Execute the modules in insertion order.
         for (const std::shared_ptr<RModule> &module : fModules) {
            module -> Execute();
         }

         // Return a copy of the last layer's output so the caller owns it.
         return fModules.back() -> GetOutput();
      }

      /**
       * Load parameters from PyTorch state dictionary for all modules.
       *
       * @param state_dict The state dictionary.
       */
      void LoadParameters(std::map<std::string, std::vector<float>> state_dict) {
         for (const std::shared_ptr<RModule> &module : fModules) {
            module -> LoadParameters(state_dict);
         }
      }

      /**
       * Load saved parameters for all modules.
       */
      void LoadParameters() {
         for (const std::shared_ptr<RModule> &module : fModules) {
            module -> LoadParameters();
         }
      }

      /**
       * Save the model as standalone inference code.
       *
       * @param path Path to save location.
       * @param name Model name.
       * @param overwrite True if any existing directory should be
       * overwritten. Defaults to false.
       */
      void Save(std::string path, std::string name, bool overwrite=false);
   private:
      /**
       * Get a timestamp.
       *
       * NOTE: std::localtime is not thread-safe; acceptable here since it is
       * only used while saving a model.
       *
       * @returns The timestamp in string format.
       */
      static std::string GetTimestamp() {
         time_t rawtime;
         struct tm * timeinfo;
         char timestamp [80];
         time(&rawtime);
         timeinfo = localtime(&rawtime);
         strftime(timestamp, 80, "Timestamp: %d-%m-%Y %T.", timeinfo);
         return timestamp;
      }

      /**
       * Write the methods to create a self-contained package.
       *
       * @param dir Directory to save to.
       * @param name Model name.
       * @param timestamp Timestamp.
       */
      void WriteMethods(std::string dir, std::string name, std::string timestamp);

      /**
       * Write the model to a file.
       *
       * @param dir Directory to save to.
       * @param name Model name.
       * @param timestamp Timestamp.
       */
      void WriteModel(std::string dir, std::string name, std::string timestamp);

      /**
       * Write the CMakeLists file.
       *
       * @param dir Directory to save to.
       * @param name Model name.
       * @param timestamp Timestamp.
       */
      void WriteCMakeLists(std::string dir, std::string name, std::string timestamp);

      std::vector<std::string> fInputs;  // Vector of input names.
      std::vector<std::vector<int>> fShapes;  // Vector of input shapes.
      std::map<std::string, int> fModuleCounts;  // Map from requested module name to the number of modules that requested it.
      std::vector<std::shared_ptr<RModule>> fModules;  // Vector containing the modules.
      std::map<std::string, int> fModuleMap;  // Map from module name to module index (in fModules).
      int fModuleCount = 0;  // Number of modules (also the next module index).
};

} // SOFIE.
} // Experimental.
} // TMVA.

#endif // TMVA_SOFIE_RMODEL_TORCHGNN_H_
Loading