Respect log_device_placement when op is executed remotely.
For logging copies, we can set the device_policy to DEVICE_PLACEMENT_WARN

PiperOrigin-RevId: 207186848
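
As context for the message above: the WARN policy is exposed through the eager C API. Below is a minimal sketch of opting into it, assuming the TFE_* C API of this era (exact signatures have shifted between releases; context teardown is elided):

#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // Warn (rather than error) when a tensor must be copied across devices;
  // the copy then shows up in the logs.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_WARN);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  // ... create and execute ops; with log_device_placement set in the
  // serialized ConfigProto passed via TFE_ContextOptionsSetConfig, op
  // placements (now including remote ones) are logged.
  TFE_DeleteContextOptions(opts);
  TF_DeleteStatus(status);
  return 0;
}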
Akshay Modi authored and tensorflower-gardener committed Aug 2, 2018
1 parent eb31b80 commit 4bc5c6c
229 changes: 5 additions & 224 deletions tensorflow/core/common_runtime/eager/execute.cc
@@ -211,222 +211,6 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
ndef.DebugString());
}

#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
// primitive op (e.g. matmul).
//
// The wrapper function conforms to the function signature expected by
// XlaLaunch, with input params ordered by <constants, (variable) args and
// resources>. For example, if the op has input params <Const1, Arg2, Const3,
// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
// Resource4> as the input params to the synthesized function.
//
// It populates `const_input_types`, `arg_input_types` and
// `op_input_to_func_input` based on the reordering results, so that the
// caller can use them to build an XlaLaunch. On error, it returns NULL, and
// sets `status` accordingly.
const FunctionDef* OpToFunction(TFE_Op* op,
std::vector<TF_DataType>* const_input_types,
std::vector<TF_DataType>* arg_input_types,
gtl::FlatMap<int, int>* op_input_to_func_input,
TF_Status* status) {
DCHECK(!op->operation.is_function());

FunctionDef fdef;

// Get the OpDef of the op we are trying to encapsulate.
TFE_Context* ctx = op->operation.ctx;
const OpRegistrationData* op_data;
{
status->status = ctx->context.FindFunctionOpData(op->operation.Name(), &op_data);
if (!status->status.ok()) {
return nullptr;
}
}
const OpDef& op_def = op_data->op_def;

OpDef* signature = fdef.mutable_signature();

// Handle constant inputs.
const std::unordered_set<string> const_inputs(
*XlaOpRegistry::CompileTimeConstantInputs(op->operation.Name()));

// First add place holders for the input args, so that we can refer to them
// by position in the next loop. Also tally up the resource inputs.
int num_resource_inputs = 0;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
if (op_def.input_arg(i).type() == DT_RESOURCE) {
++num_resource_inputs;
}
signature->add_input_arg();
}

// Now we map the input params from `op_def` to `signature`, where the param
// ordering for `signature` is: <constants, args, resources>.
int const_index = 0;
int arg_index = const_inputs.size();
int resource_index = op_def.input_arg_size() - num_resource_inputs;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
OpDef::ArgDef* func_input_arg = nullptr;
if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
VLOG(1) << "For const input, mapping op input " << i << " to func input "
<< const_index;
(*op_input_to_func_input)[i] = const_index;
func_input_arg = signature->mutable_input_arg(const_index++);
const_input_types->push_back(
static_cast<TF_DataType>(op->operation.Inputs()[i]->dtype));
} else if (op_input_arg.type() == DT_RESOURCE) {
VLOG(1) << "For resource input, mapping op input " << i
<< " to func input " << resource_index;
(*op_input_to_func_input)[i] = resource_index;
func_input_arg = signature->mutable_input_arg(resource_index++);
} else {
VLOG(1) << "For arg input, mapping op input " << i << " to func input "
<< arg_index;
(*op_input_to_func_input)[i] = arg_index;
func_input_arg = signature->mutable_input_arg(arg_index++);
arg_input_types->push_back(
static_cast<TF_DataType>(op->operation.Inputs()[i]->dtype));
}

func_input_arg->set_name(op_input_arg.name());
func_input_arg->set_type(op->operation.Inputs()[i]->dtype);
}
VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();

// Resource args are at the end of the function input params, and we should
// have iterated over all of them.
DCHECK_EQ(signature->input_arg_size(), resource_index);

// Make the synthesized function's name unique.
signature->set_name(
strings::StrCat(op_def.name(), func_id_generator.fetch_add(1)));

// Add the node def and set its input names to match op_def's names.
const NodeDef& ndef = op->operation.MutableAttrs()->BuildNodeDef();
DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
*fdef.add_node_def() = ndef;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
}
VLOG(1) << "Added NodeDef: " << fdef.DebugString();

// Fix the output names and set output types.
for (int i = 0; i < op_def.output_arg_size(); ++i) {
OpDef::ArgDef* arg = signature->add_output_arg();
const OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
const string& out_tensor_name =
strings::StrCat(ndef.name(), ":", op_def_arg.name(), ":", 0);
arg->set_name(op_def_arg.name());
(*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
const string& type_attr = op_def_arg.type_attr();
if (!type_attr.empty()) {
auto iter = ndef.attr().find(type_attr);
if (iter == ndef.attr().end()) {
status->status = errors::InvalidArgument(
strings::StrCat("Could not find attr ", type_attr, " in NodeDef ",
ndef.DebugString()));
return nullptr;
}
arg->set_type(iter->second.type());
}
}
VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();

status->status = ctx->context.AddFunctionDef(fdef);
if (!status->status.ok()) return nullptr;
const auto ret = ctx->context.FindFunctionDef(signature->name());
DCHECK(ret != nullptr);
return ret;
}
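
As an aside, the reordering logic above is easiest to see on the example from the function comment. The following standalone sketch (hypothetical, not part of the diff) reproduces the index mapping for <Const1, Arg2, Const3, Resource4, Arg5>:

#include <iostream>
#include <map>
#include <vector>

// Standalone illustration of OpToFunction's input reordering:
// constants first, then plain args, then resources.
enum class Kind { kConst, kArg, kResource };

std::map<int, int> ReorderInputs(const std::vector<Kind>& inputs) {
  int num_consts = 0, num_resources = 0;
  for (Kind k : inputs) {
    if (k == Kind::kConst) ++num_consts;
    if (k == Kind::kResource) ++num_resources;
  }
  int const_index = 0;         // constants start at func input 0
  int arg_index = num_consts;  // plain args follow the constants
  int resource_index = static_cast<int>(inputs.size()) - num_resources;
  std::map<int, int> op_input_to_func_input;
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    if (inputs[i] == Kind::kConst) {
      op_input_to_func_input[i] = const_index++;
    } else if (inputs[i] == Kind::kResource) {
      op_input_to_func_input[i] = resource_index++;
    } else {
      op_input_to_func_input[i] = arg_index++;
    }
  }
  return op_input_to_func_input;
}

int main() {
  // <Const1, Arg2, Const3, Resource4, Arg5> from the comment above.
  const std::vector<Kind> inputs = {Kind::kConst, Kind::kArg, Kind::kConst,
                                    Kind::kResource, Kind::kArg};
  for (const auto& entry : ReorderInputs(inputs)) {
    std::cout << "op input " << entry.first << " -> func input "
              << entry.second << "\n";
  }
  // Prints 0->0, 1->2, 2->1, 3->4, 4->3, i.e. the reordered signature
  // <Const1, Const3, Arg2, Arg5, Resource4>.
  return 0;
}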

// Builds an XlaLaunch as a wrapper over 'op', so that 'op' can be executed
// via XLA.
std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
VLOG(1) << "Creating XlaLaunch for TFE_Op " << op->operation.Name();
auto launch_op = std::unique_ptr<TFE_Op>(
TFE_NewOp(op->operation.ctx, "XlaLaunch", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
if (op->operation.device) {
TFE_OpSetDevice(launch_op.get(), op->operation.device->name().c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}

const FunctionDef* fdef =
op->operation.ctx->context.FindFunctionDef(op->operation.Name());
std::vector<TF_DataType> const_input_types;
std::vector<TF_DataType> arg_input_types;
gtl::FlatMap<int, int> op_input_to_func_input;
if (fdef == nullptr) {
// See if this is a primitive op, and if so create a function for it, so
// that XlaLaunch can access it.
fdef = OpToFunction(op, &const_input_types, &arg_input_types,
&op_input_to_func_input, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
} else {
// TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work
// for functions, so we need to find another way to handle constant
// inputs.
for (int i = const_input_types.size();
i < fdef->signature().input_arg_size(); ++i) {
VLOG(1) << "Adding Targs from input arg " << i;
const OpDef::ArgDef& arg = fdef->signature().input_arg(i);
arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
}
}
DCHECK(fdef != nullptr);

// Copy inputs and their devices.
// Since input param reordering may have occurred between `op` and
// `launch_op` via `op_input_to_func_input`, adjust the actual inputs
// accordingly.
*launch_op->operation.MutableInputs() = op->operation.Inputs();
for (TensorHandle* h : launch_op->operation.Inputs()) {
h->Ref();
}
if (!op_input_to_func_input.empty()) {
DCHECK_EQ(op->operation.Inputs().size(), op_input_to_func_input.size());
for (int i = 0; i < op_input_to_func_input.size(); ++i) {
VLOG(1) << "mapping op input " << i << " to func input "
<< op_input_to_func_input[i];

(*launch_op->operation.MutableInputs())[op_input_to_func_input[i]] =
op->operation.Inputs()[i];
}
}
launch_op->operation.MutableAttrs()->NumInputs(op->operation.Inputs().size());

TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
const_input_types.size());

// Set Targs and Nresources attrs.
TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
arg_input_types.size());
const int num_resource_inputs = fdef->signature().input_arg_size() -
const_input_types.size() -
arg_input_types.size();
TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);

// Set Tresults attr.
std::vector<TF_DataType> tresults;
for (const OpDef::ArgDef& arg : fdef->signature().output_arg()) {
tresults.push_back(static_cast<TF_DataType>(arg.type()));
}
TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
tresults.size());

// Set function attr.
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(fdef->signature().name());
launch_op->operation.MutableAttrs()->Set("function", attr_value);

return launch_op;
}
#endif // TENSORFLOW_EAGER_USE_XLA
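
BuildXlaLaunch, for its part, describes the wrapped computation to XlaLaunch entirely through attrs. A hedged sketch of that attr bookkeeping follows; MakeXlaLaunch and the dtype values are hypothetical, mirroring the <2 constants, 2 args, 1 resource> example above, and only use eager C API calls:

#include "tensorflow/c/eager/c_api.h"

// Illustrative only: describe a wrapped function to XlaLaunch via attrs.
TFE_Op* MakeXlaLaunch(TFE_Context* ctx, TF_Status* status) {
  TFE_Op* launch = TFE_NewOp(ctx, "XlaLaunch", status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TF_DataType tconstants[] = {TF_INT32, TF_INT32};  // hypothetical dtypes
  TF_DataType targs[] = {TF_FLOAT, TF_FLOAT};
  TF_DataType tresults[] = {TF_FLOAT};

  TFE_OpSetAttrTypeList(launch, "Tconstants", tconstants, 2);
  TFE_OpSetAttrTypeList(launch, "Targs", targs, 2);
  // Nresources is whatever remains: total inputs - constants - args.
  TFE_OpSetAttrInt(launch, "Nresources", 5 - 2 - 2);
  TFE_OpSetAttrTypeList(launch, "Tresults", tresults, 1);
  // The "function" attr itself is a NameAttrList naming the synthesized
  // FunctionDef, set exactly as in the diff above.
  return launch;
}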

Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
const auto& node_def = op->MutableAttrs()->BuildNodeDef();
const OpDef* op_def = nullptr;
@@ -467,14 +251,6 @@ Status EagerLocalExecute(EagerOperation* op,
EagerContext* ctx = op->EagerContext();
auto status = ctx->GetStatus();
if (!status.ok()) return status;
#ifdef TENSORFLOW_EAGER_USE_XLA
std::unique_ptr<TFE_Op> xla_launch_op;
if (op->UseXla() && op->Name() != "XlaLaunch") {
xla_launch_op = BuildXlaLaunch(op, status);
if (!status.ok()) return status;
op = xla_launch_op.get();
}
#endif // TENSORFLOW_EAGER_USE_XLA
// Ensure all resource-touching ops run in the device the resource is,
// regardless of anything else that has been specified. This is identical to
// the graph mode behavior.
@@ -827,6 +603,11 @@ Status EagerExecute(EagerOperation* op,
return EagerLocalExecute(op, retvals, num_retvals);
}

if (op->EagerContext()->LogDevicePlacement()) {
LOG(INFO) << "Executing op " << op->Name() << " in device "
<< op->Device()->name();
}

return EagerRemoteExecute(op, retvals->data(), num_retvals);
}
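
With the added block, enabling log_device_placement now also covers the remote execution path, emitting a line of the form shown below (illustrative only; the actual device string depends on the cluster):

Executing op MatMul in device /job:worker/replica:0/task:1/device:GPU:0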

