[tmva][sofie] Fix output type of TopK and clean up code
Fix the output type when parsing TopK

Clean up the code in the TopK implementation and in the generated code.
Also fix the compilation warnings.
lmoneta committed Jun 26, 2024
1 parent 77c0ff2 commit 449baf1
Showing 2 changed files with 33 additions and 40 deletions.
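For context, a minimal standalone sketch (not SOFIE code) of the ONNX TopK contract this commit aligns the parser with: the Values output keeps the input element type, while the Indices output is int64.

```cpp
// Sketch of ONNX TopK semantics on a 1-D input (largest=1, sorted=1).
// Values keep the input element type (float here); Indices are int64_t,
// which is the type the parser fix below now registers for output 1.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
   std::vector<float> x = {1.f, 5.f, 3.f, 4.f};
   const size_t k = 2;

   std::vector<int64_t> idx(x.size());
   std::iota(idx.begin(), idx.end(), 0);
   std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                     [&](int64_t a, int64_t b) { return x[a] > x[b]; });

   for (size_t i = 0; i < k; i++)
      std::cout << "value " << x[idx[i]] << "  index " << idx[i] << "\n";
   // prints: value 5  index 1
   //         value 4  index 3
   return 0;
}
```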
56 changes: 27 additions & 29 deletions tmva/sofie/inc/TMVA/ROperator_TopK.hxx
@@ -19,7 +19,7 @@ private:
int fAttrLargest;
int fAttrSorted;

-   size_t k;
+   size_t fK;
std::string fNK;
std::string fNX;
std::string fNVal;
@@ -50,15 +50,9 @@ public:
}

auto shape = input[0]; // Shape format: [ m x n x o x p ... ]
-   if (fAttrAxis < 0) {
-      fAttrAxis += shape.size();
-   }
-   if (fAttrAxis < 0 || fAttrAxis >= shape.size()) {
-      throw std::runtime_error("TMVA SOFIE TopK Op axis value is out of bounds");
-   }

-   // set the dimension at the specified axis to k
-   shape[fAttrAxis] = k; // Modified shape: [ m x n x k x p ... ]
+   // set the dimension at the specified axis to k (fAttrAxis is checked beforehand to be in the correct range)
+   shape[fAttrAxis] = fK; // Modified shape: [ m x n x k x p ... ]
return {shape, shape};
}
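The shape rule exercised by ShapeInference is simple: both outputs take the input shape with the dimension at the axis replaced by k. A small sketch with plain std::vector in place of SOFIE's shape type:

```cpp
// Sketch of the TopK shape rule (assumes axis is already normalized
// and in range, as the operator checks during initialization).
#include <cstddef>
#include <vector>

std::vector<size_t> TopKShape(std::vector<size_t> shape, size_t axis, size_t k) {
   shape[axis] = k;  // [m, n, o, ...] with axis = 1 becomes [m, k, o, ...]
   return shape;     // Values and Indices share this shape
}

// e.g. TopKShape({2, 3, 4}, 1, 2) yields {2, 2, 4}
```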

@@ -74,19 +68,21 @@ public:
throw std::runtime_error("TMVA SOFIE TopK Op Input Tensor i.e. K is not found in model");
}

-   fShapeX=model.GetTensorShape(fNX);
-   auto fShapeK=model.GetTensorShape(fNK);
-   auto kptr=static_cast<int64_t *>(model.GetInitializedTensorData(fNK).get());
-   k=*kptr;
+   fShapeX = model.GetTensorShape(fNX);
+   auto fShapeK = model.GetTensorShape(fNK);
+   auto kptr = static_cast<int64_t *>(model.GetInitializedTensorData(fNK).get());
+   fK = *kptr;
fAttrAxis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
-   if(fAttrAxis>=fShapeX.size()){
+   if(static_cast<size_t>(fAttrAxis) >= fShapeX.size()){
throw
std::runtime_error("TMVA::SOFIE ONNX TopK op axis = "+ std::to_string(fAttrAxis) +" value exeeds size of tensor " +fNX+" of size "+fShapeX.size()+" .");
}
-   if(k>fShapeX[fAttrAxis]){
-      throw
-         std::runtime_error("TMVA::SOFIE ONNX TopK op k = "+ std::to_string(k) +" value exeeds value of tensor " +fNX+" of size "+fShapeX.size()+" at axis= "+std::to_string(fAttrAxis)+".");
-   }
+   // fK cannot be larger than the axis dimension
+   fK = std::min(fK, fShapeX[fAttrAxis]);
+   // if(fK>fShapeX[fAttrAxis]){
+   //    throw
+   //       std::runtime_error("TMVA::SOFIE ONNX TopK op k = "+ std::to_string(fK) +" value exeeds value of tensor " +fNX+" of size "+fShapeX.size()+" at axis= "+std::to_string(fAttrAxis)+".");
+   // }
// fShapeX = model.GetTensorShape(fNX); // [ m x n x o x p ... ]
// if(k[0]>=fShapeX.size()){
// throw
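Behavior note: the old code threw when k exceeded the axis dimension; the new code clamps it instead, so requesting the top 10 of a length-4 axis simply returns all 4 entries. A minimal sketch of the normalization and clamping, with names local to this example rather than the SOFIE members:

```cpp
// Illustrative sketch of the axis/k handling in the operator's initialization.
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

size_t NormalizeAxis(int axis, size_t rank) {
   if (axis < 0) axis += static_cast<int>(rank);  // negative axes count from the back
   if (axis < 0 || static_cast<size_t>(axis) >= rank)
      throw std::runtime_error("TopK axis out of range");
   return static_cast<size_t>(axis);
}

size_t ClampK(size_t k, const std::vector<size_t> &shape, size_t axis) {
   return std::min(k, shape[axis]);  // fK = std::min(fK, fShapeX[fAttrAxis]);
}
```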
@@ -98,6 +94,7 @@ public:
// size_t axis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
// fShapeY[axis] = k[0]; // [ 2 x m x n x K x p ... ]
fShapeY=ShapeInference({fShapeX,fShapeK})[0];

// for(int i=0;i<fShapeX.size();i++)
// std::cout<<fShapeX[i]<<" ";
// std::cout<<"\ny size -> "<<fShapeY.size()<<std::endl;
@@ -116,12 +113,13 @@ public:
}
std::stringstream out;
size_t size = fShapeX.size();
-   size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
+   size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis; // not really needed
out << "\n" << SP << "//------ TopK\n";

size_t length=ConvertShapeToLength(fShapeX);
size_t bound=1;
-   for(int i=0;i<axis;i++)bound*=fShapeX[i]; //bound decider
+   for(size_t i = 0; i < axis; i++)
+      bound *= fShapeX[i]; //bound decider

// first we create boundaries in the input
// [m,n,o,k,p] => boundary's size = length/(m*n*o)
@@ -130,18 +128,18 @@ public:
size_t jump= groupSize/fShapeX[fAttrAxis];
//candidates to check in group
size_t prod =1;

Check warning on line 130 in tmva/sofie/inc/TMVA/ROperator_TopK.hxx (GitHub Actions / alma9-clang, clang LLVM_ENABLE_ASSERTIONS=On): variable 'prod' set but not used [-Wunused-but-set-variable]
-   for(int i=fAttrAxis+1;i<size;i++){
-      prod*=fShapeX[i];
+   for(size_t i = fAttrAxis+1; i < size; i++){
+      prod *= fShapeX[i];
}
size_t numOfChecksInGrp=groupSize/jump;
size_t numOfCheckersInGrp=groupSize/numOfChecksInGrp;

// for(int i=0;i<length;i++){
// if(i==groupSize)dim=0;
// }
out<<SP<<"size_t itr = 0, p=0;\n";
out<<SP<<"size_t itr = 0, p = 0;\n";
out<<SP<<"std::vector<std::vector<std::pair<float,int>>>groupElements;\n";
out<<SP<<"for(long i=0;i<"<<length<<";i++) {\n";
out<<SP<<"for (size_t i = 0; i < "<<length<<"; i++) {\n";
//main logic
out<<SP<<SP<<"size_t tempitr=0, j=0;\n";
out<<SP<<SP<<"std::vector<std::pair<float,int>>elements;\n";
@@ -158,13 +156,13 @@ public:

out<<SP<<SP<<"itr++;\n";
out<<SP<<SP<<"std::vector<std::pair<float,int>>kelems;\n";
out<<SP<<SP<<"for(int j=0; j < "<<k<<";j++){\n kelems.push_back({elements[j].first,elements[j].second});\n"<<SP<<SP<<"}\n";
out<<SP<<SP<<"for (int j = 0; j < " << fK <<"; j++){\n kelems.push_back({elements[j].first,elements[j].second});\n"<<SP<<SP<<"}\n";
out<<SP<<SP<<"groupElements.push_back(kelems);\n";
out<<SP<<SP<<"if(itr == "<<numOfCheckersInGrp<<"){\n itr = 0;\n i += "<<groupSize-numOfCheckersInGrp/*to compensate the default i++*/<<";\n";
out<<SP<<SP<<SP<<"for(int j=0; j < groupElements[0].size();j++) {\n";
out<<SP<<SP<<SP<<SP<<"for(int k=0; k < groupElements.size();k++) {\n";
out<<SP<<SP<<SP<<SP<<SP<<"tensor_"<<fNVal<<"[p]=(groupElements[k][j].first);\n";
out<<SP<<SP<<SP<<SP<<SP<<"tensor_"<<fNInd<<"[p++]=(groupElements[k][j].second);\n";
out<<SP<<SP<<SP<<"for (size_t j = 0; j < groupElements[0].size(); j++) {\n";
out<<SP<<SP<<SP<<SP<<"for(size_t k = 0; k < groupElements.size(); k++) {\n";
out<<SP<<SP<<SP<<SP<<SP<<"tensor_"<<fNVal<<"[p] = (groupElements[k][j].first);\n";
out<<SP<<SP<<SP<<SP<<SP<<"tensor_"<<fNInd<<"[p++] = (groupElements[k][j].second);\n";
out<<SP<<SP<<SP<<SP<<"}\n";// end for
out<<SP<<SP<<SP<<"}\n";// end for
out<<SP<<SP<<SP<<"groupElements.clear();\n";
17 changes: 6 additions & 11 deletions tmva/sofie_parsers/src/ParseTopK.cxx
@@ -16,18 +16,12 @@ ParserFuncSignature ParseTopK = [](RModelParser_ONNX &parser, const onnx::NodePr
throw std::runtime_error("TMVA::SOFIE ONNX Parser TopK op has input tensor " + input_name +
" but its type is not yet registered");
}
-   ETensorType k_type = ETensorType::UNDEFINED;
    std::string k_name = nodeproto.input(1);
-   if (parser.IsRegisteredTensorType(k_name)) {
-      k_type = parser.GetTensorType(k_name);
-   } else {
+   if (!parser.IsRegisteredTensorType(k_name)) {
throw std::runtime_error("TMVA::SOFIE ONNX Parser TopK op has input tensor " + k_name +
" but its type is not yet registered");
}

-   // std::vector<int> kElem;
-   // kElem.push_back(std::stoi(k));

std::unique_ptr<ROperator> op;

std::string outputVal_name = nodeproto.output(0);
@@ -47,10 +41,11 @@ ParserFuncSignature ParseTopK = [](RModelParser_ONNX &parser, const onnx::NodePr
}
op.reset(new ROperator_TopK<float>(attr_axis, attr_largest, attr_sorted, k_name, input_name, outputVal_name, outputInd_name));

-   for(const auto& output_name:{outputVal_name,outputInd_name}){
-      if (!parser.IsRegisteredTensorType(output_name)) {
-         parser.RegisterTensorType(output_name, input_type);
-      }
+   if (!parser.IsRegisteredTensorType(outputVal_name)) {
+      parser.RegisterTensorType(outputVal_name, input_type);
+   }
+   if (!parser.IsRegisteredTensorType(outputInd_name)) {
+      parser.RegisterTensorType(outputInd_name, ETensorType::INT64);
    }

return op;
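This is the headline fix: previously both outputs were registered with the input's element type, so the indices tensor came out as float. Per the ONNX TopK spec, the Indices output is always int64 regardless of the input type. A tiny illustrative sketch (stand-in enum, not the SOFIE API) of the corrected type rule:

```cpp
// Illustrative only: the element types of TopK's two outputs.
enum class ElemType { FLOAT, INT64 };  // stand-in for SOFIE's ETensorType

ElemType TopKValuesType(ElemType inputType) { return inputType; }          // output 0
ElemType TopKIndicesType(ElemType /*unused*/) { return ElemType::INT64; }  // output 1
```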
