Skip to content

Commit

Permalink
Check if the handle is nullptr, and fail early instead of segfaulting.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 207176253
  • Loading branch information
Akshay Modi authored and tensorflower-gardener committed Aug 2, 2018
1 parent f8afea0 commit daaaab2
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tensorflow/c/eager/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,26 +348,46 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
int result;
status->status = h->handle->NumDims(&result);
return result;
}

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
return result;
}

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = nullptr;
status->status = h->handle->OpDevice(&d);
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
tensorflow::Device* d = nullptr;
tensorflow::Device* op_device = nullptr;
Expand Down
36 changes: 36 additions & 0 deletions tensorflow/c/eager/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,42 @@ void SetAndGetOpDevices(bool async) {
TF_DeleteStatus(status);
}

TEST(CAPI, TensorHandleNullptr) {
TFE_TensorHandle* h = nullptr;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(t, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));

TF_SetStatus(status.get(), TF_OK, "");

const char* device_name = TFE_TensorHandleDeviceName(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_name, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));

TF_SetStatus(status.get(), TF_OK, "");

int num_dims = TFE_TensorHandleNumDims(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(num_dims, -1);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));

TF_SetStatus(status.get(), TF_OK, "");

int dim = TFE_TensorHandleDim(h, 0, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(dim, -1);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
}

void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
Expand Down

0 comments on commit daaaab2

Please sign in to comment.