Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DF] Check column types in GetColumnReadersImpl() #17221

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tree/dataframe/inc/ROOT/RNTupleDS.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class RNTupleDS final : public ROOT::RDF::RDataSource {
std::unordered_map<ROOT::Experimental::DescriptorId_t, std::string> fFieldId2QualifiedName;
std::vector<std::string> fColumnNames;
std::vector<std::string> fColumnTypes;
/// Applies TClassEdit::GetNormalizedName to fColumnTypes
std::vector<std::string> fNormalizedColumnTypes;
/// List of column readers returned by GetColumnReaders() organized by slot. Used to reconnect readers
/// to new page sources when the files in the chain change.
std::vector<std::vector<Internal::RNTupleColumnReader *>> fActiveColumnReaders;
Expand Down
39 changes: 34 additions & 5 deletions tree/dataframe/src/RNTupleDS.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,11 @@ void RNTupleDS::AddField(const RNTupleDescriptor &desc, std::string_view colName
auto cardinalityField = std::make_unique<ROOT::Experimental::Internal::RRDFCardinalityField>();
cardinalityField->SetOnDiskId(fieldId);
fColumnNames.emplace_back("R_rdf_sizeof_" + std::string(colName));
fColumnTypes.emplace_back(cardinalityField->GetTypeName());
const auto typeName = cardinalityField->GetTypeName();
fColumnTypes.emplace_back(typeName);
std::string normalized;
TClassEdit::GetNormalizedName(normalized, typeName);
fNormalizedColumnTypes.emplace_back(normalized);
Comment on lines +280 to +284
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and further down: I wonder if instead of adding the normalized type name everywhere where we append to fColumnTypes, perhaps we should loop once over all the column types when they are fully build, i.e. at the end of the constructor.

fProtoFields.emplace_back(std::move(cardinalityField));

for (const auto &f : desc.GetFieldIterable(fieldDesc.GetId())) {
Expand Down Expand Up @@ -359,13 +363,21 @@ void RNTupleDS::AddField(const RNTupleDescriptor &desc, std::string_view colName

if (cardinalityField) {
fColumnNames.emplace_back("R_rdf_sizeof_" + std::string(colName));
fColumnTypes.emplace_back(cardinalityField->GetTypeName());
const auto typeName = cardinalityField->GetTypeName();
fColumnTypes.emplace_back(typeName);
std::string normalized;
TClassEdit::GetNormalizedName(normalized, typeName);
fNormalizedColumnTypes.emplace_back(normalized);
fProtoFields.emplace_back(std::move(cardinalityField));
}

fieldInfos.emplace_back(fieldId, nRepetitions);
fColumnNames.emplace_back(colName);
fColumnTypes.emplace_back(valueField->GetTypeName());
const auto typeName = valueField->GetTypeName();
fColumnTypes.emplace_back(typeName);
std::string normalized;
TClassEdit::GetNormalizedName(normalized, typeName);
fNormalizedColumnTypes.emplace_back(normalized);
fProtoFields.emplace_back(std::move(valueField));
}

Expand Down Expand Up @@ -431,13 +443,30 @@ RDF::RDataSource::Record_t RNTupleDS::GetColumnReadersImpl(std::string_view /* n
}

std::unique_ptr<ROOT::Detail::RDF::RColumnReaderBase>
RNTupleDS::GetColumnReaders(unsigned int slot, std::string_view name, const std::type_info & /*tid*/)
RNTupleDS::GetColumnReaders(unsigned int slot, std::string_view name, const std::type_info & tid)
{
// At this point we can assume that `name` will be found in fColumnNames
// TODO(jblomer): check incoming type
const auto index = std::distance(fColumnNames.begin(), std::find(fColumnNames.begin(), fColumnNames.end(), name));
auto field = fProtoFields[index].get();

std::string demangled = ROOT::Internal::RDF::DemangleTypeIdName(tid);
std::string normalized;
TClassEdit::GetNormalizedName(normalized, demangled.c_str());
if (normalized != fNormalizedColumnTypes[index]) {
std::string err = "The type of column \"";
err += name;
err += "\" is ";
err += fColumnTypes[index];
if (fColumnTypes[index] != fNormalizedColumnTypes[index])
err += " (= " + fNormalizedColumnTypes[index] + ")";
err += " but ";
err += demangled;
if (demangled != normalized)
err += " (= " + normalized + ")";
err += " has been selected";
throw std::runtime_error(err);
}

// Map the field's and subfields' IDs to qualified names so that we can later connect the fields to
// other page sources from the chain
fFieldId2QualifiedName[field->GetOnDiskId()] = fPrincipalDescriptor->GetQualifiedFieldName(field->GetOnDiskId());
Expand Down
19 changes: 18 additions & 1 deletion tree/dataframe/test/datasource_ntuple.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class RNTupleDSTest : public ::testing::Test {

void SetUp() override {
auto model = RNTupleModel::Create();
*model->MakeField<std::uint32_t>("nevent") = 1;
*model->MakeField<float>("pt") = 42;
*model->MakeField<float>("energy") = 7;
*model->MakeField<std::string>("tag") = "xyz";
Expand Down Expand Up @@ -81,8 +82,9 @@ TEST_F(RNTupleDSTest, ColTypeNames)
RNTupleDS ds(fNtplName, fFileName);

auto colNames = ds.GetColumnNames();
ASSERT_EQ(15, colNames.size());
ASSERT_EQ(16, colNames.size());

EXPECT_TRUE(ds.HasColumn("nevent"));
EXPECT_TRUE(ds.HasColumn("pt"));
EXPECT_TRUE(ds.HasColumn("energy"));
EXPECT_TRUE(ds.HasColumn("rvec"));
Expand Down Expand Up @@ -132,6 +134,21 @@ TEST_F(RNTupleDSTest, CardinalityColumn)
EXPECT_EQ(3, *max_rvec2);
}

// TODO(jblomer): this test will change once collections are read as RVecs in RNTupleDS
TEST_F(RNTupleDSTest, ReadRVec)
{
auto df = ROOT::RDF::Experimental::FromRNTuple(fNtplName, fFileName);

// Allow use of float and Float_t interchangibly
EXPECT_DOUBLE_EQ(3.0, *df.Sum<std::vector<Float_t>>("jets"));
// Allow use of std int types and ROOT int types interchangibly
EXPECT_EQ(1U, df.Take<std::uint32_t>("nevent").GetValue()[0]);
EXPECT_EQ(1U, df.Take<UInt_t>("nevent").GetValue()[0]);
// jets is currently exposed as std::vector<float> and thus not usable as ROOT::RVec<float>
EXPECT_ANY_THROW(df.Sum<ROOT::RVec<float>>("jets"));
// EXPECT_THROW(df.Sum<ROOT::RVec<float>>("jets"), std::runtime_error); // This does not work directly, maybe due to jitting ?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should understand this before merging.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, the test is still failing even after that change, so the attempt was not useful

}

static void ReadTest(const std::string &name, const std::string &fname)
{
auto df = ROOT::RDF::Experimental::FromRNTuple(name, fname);
Expand Down
12 changes: 6 additions & 6 deletions tutorials/io/ntuple/ntpl011_global_temperatures.C
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ void Analyze()
auto max_value = df.Max("AverageTemperature");

// Functions to filter by each season from date formatted "1944-12-01."
auto fnWinter = [](int month) { return month == 12 || month == 1 || month == 2; };
auto fnSpring = [](int month) { return month == 3 || month == 4 || month == 5; };
auto fnSummer = [](int month) { return month == 6 || month == 7 || month == 8; };
auto fnFall = [](int month) { return month == 9 || month == 10 || month == 11; };
auto fnWinter = [](std::uint32_t month) { return month == 12 || month == 1 || month == 2; };
auto fnSpring = [](std::uint32_t month) { return month == 3 || month == 4 || month == 5; };
auto fnSummer = [](std::uint32_t month) { return month == 6 || month == 7 || month == 8; };
auto fnFall = [](std::uint32_t month) { return month == 9 || month == 10 || month == 11; };

// Create a RDataFrame per season.
auto dfWinter = df.Filter(fnWinter, {"Month"});
Expand All @@ -164,8 +164,8 @@ void Analyze()
auto fallCount = dfFall.Count();

// Functions to filter for the time period between 2003-2013, and 1993-2002.
auto fn1993_to_2002 = [](int year) { return year >= 1993 && year <= 2002; };
auto fn2003_to_2013 = [](int year) { return year >= 2003 && year <= 2013; };
auto fn1993_to_2002 = [](std::uint32_t year) { return year >= 1993 && year <= 2002; };
auto fn2003_to_2013 = [](std::uint32_t year) { return year >= 2003 && year <= 2013; };

// Create a RDataFrame for decades 1993_to_2002 & 2003_to_2013.
auto df1993_to_2002 = df.Filter(fn1993_to_2002, {"Year"});
Expand Down
Loading