Skip to content

Commit

Permalink
Resize Bicubic
Browse files Browse the repository at this point in the history
  • Loading branch information
mochen.bmc committed Nov 6, 2023
1 parent 4f83816 commit 28a958a
Show file tree
Hide file tree
Showing 15 changed files with 821 additions and 7 deletions.
2 changes: 1 addition & 1 deletion xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"this flag to false."));
flag_list->push_back(tsl::Flag(
"xla_multiheap_size_constraint_per_heap",
int32_setter_for(
int64_setter_for(
&DebugOptions::set_xla_multiheap_size_constraint_per_heap),
debug_options->xla_multiheap_size_constraint_per_heap(),
"Generates multiple heaps (i.e., temp buffers) with a size "
Expand Down
4 changes: 3 additions & 1 deletion xla/pjrt/gpu/gpu_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ void EnablePeerAccess(absl::Span<se::StreamExecutor* const> executors) {

// Builds a BFCAllocator for all local GPUs.
StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
se::StreamExecutor* executor, double memory_fraction, bool preallocate) {
se::StreamExecutor* executor, double memory_fraction, bool preallocate,
bool garbage_collection) {
bool enable_unified_memory;
Status status = tsl::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY", false,
&enable_unified_memory);
Expand Down Expand Up @@ -111,6 +112,7 @@ StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(

tsl::BFCAllocator::Options opts;
opts.allow_growth = !preallocate;
opts.garbage_collection = garbage_collection;
return std::make_unique<tsl::BFCAllocator>(
std::move(sub_allocator), allocator_memory,
absl::StrCat("GPU_", device_ordinal, "_bfc"), opts);
Expand Down
6 changes: 5 additions & 1 deletion xla/pjrt/gpu/gpu_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ struct GpuAllocatorConfig {
// fragmentation, allowing more of the total memory to be used. If false, the
// allocator will allocate more memory as allocations are requested.
bool preallocate = true;

// Whether to enable garbage collection in the BFC allocator.
bool garbage_collection = false;
};

std::unique_ptr<tsl::BFCAllocator> GetGpuHostAllocator(
se::StreamExecutor* executor);

// Builds a BFCAllocator for all local GPUs.
StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
se::StreamExecutor* executor, double memory_fraction, bool preallocate);
se::StreamExecutor* executor, double memory_fraction, bool preallocate,
bool garbage_collection);

} // namespace xla

Expand Down
3 changes: 2 additions & 1 deletion xla/pjrt/gpu/se_gpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,8 @@ GetStreamExecutorGpuDeviceAllocator(
auto bfc_allocator,
CreateBFCAllocator(ordinal_and_device.second->executor(),
allocator_config.memory_fraction,
allocator_config.preallocate));
allocator_config.preallocate,
allocator_config.garbage_collection));
allocators_and_streams.emplace_back(
std::move(bfc_allocator),
ordinal_and_device.second->compute_stream());
Expand Down
2 changes: 1 addition & 1 deletion xla/service/buffer_assignment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2019,7 +2019,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
buffers_to_assign_sequentially.size() == global_computations.size();
VLOG(2) << "Running whole module heap simulation: "
<< run_whole_module_heap_simulation;
const int32_t multiheap_size_constraint_per_heap =
const int64_t multiheap_size_constraint_per_heap =
module->config().debug_options().xla_multiheap_size_constraint_per_heap();
VLOG(2) << "Multiheap per heap size limit: "
<< multiheap_size_constraint_per_heap;
Expand Down
2 changes: 1 addition & 1 deletion xla/service/buffer_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ class BufferAssignment {
color_alignment_(std::move(color_alignment)),
alias_analysis_(std::move(alias_analysis)),
hlo_live_range_(std::move(hlo_live_range)) {
int32_t raw_value = module->config()
int64_t raw_value = module->config()
.debug_options()
.xla_multiheap_size_constraint_per_heap();
// -1 means no constraint.
Expand Down
80 changes: 80 additions & 0 deletions xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ cc_library(
":stream_synchronization",
":support",
":topk",
":resize_bicubic",
":tracing",
"//xla:statusor",
"//xla:xla_proto_cc",
Expand Down Expand Up @@ -417,6 +418,85 @@ cc_library(
],
)

# Host-side launcher for the bicubic resize CUDA kernels. srcs/hdrs are empty
# unless CUDA is configured, so this target is a no-op on non-CUDA builds.
cc_library(
    name = "resize_bicubic_kernel",
    srcs = if_cuda_is_configured(
        [
            "resize_bicubic_kernel.cc",
        ],
    ),
    hdrs = if_cuda_is_configured(["resize_bicubic_kernel.h"]),
    compatible_with = [],
    deps = [
        # Device-side kernel implementations compiled by nvcc.
        ":resize_bicubic_kernel_cuda",
        # "//xla:shape_util",
        "//xla:xla_proto_cc",
        "//xla:xla_data_proto_cc",
        "//xla/runtime:memref_view",
        "//xla/stream_executor:platform",
        "//xla/stream_executor:stream_executor_headers",  # build_cleaner: keep
        "//xla/stream_executor/gpu:gpu_stream_header",
        "//xla/stream_executor/gpu:gpu_types_header",
        "@com_google_absl//absl/numeric:bits",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)

# Device-side (.cu.cc) bicubic resize kernels, compiled with the CUDA
# toolchain. Kept separate from the cc_library launcher so only this target
# needs nvcc.
cuda_library(
    name = "resize_bicubic_kernel_cuda",
    srcs = if_cuda_is_configured(
        [
            "resize_bicubic_kernel.cu.cc",
        ],
    ),
    hdrs = if_cuda_is_configured(["resize_bicubic_kernel_common.h"]),
    compatible_with = [],
    deps = [
        "@eigen_archive//:eigen3",
        "@local_config_cuda//cuda:cuda_headers",
        "@com_google_absl//absl/types:span",
    ],
)


# XLA GPU runtime custom-call wrappers for bicubic resize (forward and grad).
# The header is always built so RegisterResizeBicubicCustomCall can be
# referenced unconditionally; the implementation (and the kernel dep) is only
# compiled when CUDA is configured.
cc_library(
    name = "resize_bicubic",
    srcs = if_cuda_is_configured(
        ["resize_bicubic.cc"],
    ),
    hdrs = ["resize_bicubic.h"],
    deps = if_cuda_is_configured([":resize_bicubic_kernel"]) + [
        ":support",
        "//xla:executable_run_options",
        # "//xla:shape_util",
        "//xla:status",
        "//xla:statusor",
        # "//xla:types",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        # "//xla/mlir/runtime/transforms:custom_call_encoding",
        "//xla/runtime:custom_call",
        "//xla/runtime:custom_call_registry",
        "//xla/runtime:executable",
        "//xla/runtime:state",
        # "//xla/runtime/ffi:ffi_api",
        # "//xla/runtime/ffi:ffi_c_api_hdrs",
        "//xla/service:executable",
        "//xla/service:hlo_pass",
        "//xla/service:tuple_util",
        "//xla/stream_executor/gpu:gpu_stream_header",
        "//xla/stream_executor/gpu:gpu_types_header",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@tsl//tsl/platform:statusor",
    ],
)

cc_library(
name = "gemm",
srcs = ["gemm.cc"],
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ limitations under the License.
#include "xla/service/gpu/runtime/stream_synchronization.h"
#include "xla/service/gpu/runtime/support.h"
#include "xla/service/gpu/runtime/topk.h"
#include "xla/service/gpu/runtime/resize_bicubic.h"
#include "xla/service/gpu/runtime/tracing.h"
#include "xla/service/gpu/thunk.h"
#include "xla/service/service_executable_run_options.h"
Expand Down Expand Up @@ -87,6 +88,7 @@ void RegisterXlaGpuRuntimeCustomCalls(DirectCustomCallRegistry& registry) {
RegisterMemsetCustomCalls(registry);
RegisterSendRecvCustomCalls(registry);
RegisterTopkCustomCall(registry);
RegisterResizeBicubicCustomCall(registry);

#if GOOGLE_CUDA || TF_HIPBLASLT
RegisterMatmulCustomCalls(registry);
Expand Down
88 changes: 88 additions & 0 deletions xla/service/gpu/runtime/resize_bicubic.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/runtime/resize_bicubic.h"

#include <stdint.h>

#include <cstddef>

#include "absl/status/status.h"
#include "absl/types/span.h"
#include "xla/runtime/custom_call.h"
#include "xla/runtime/executable.h"
#include "xla/service/gpu/runtime/resize_bicubic_kernel.h"
// #include "xla/runtime/custom_call_registry.h"

#include "xla/service/gpu/runtime/support.h"
#include "xla/service/service_executable_run_options.h"
#include "xla/xla_data.pb.h"

namespace xla::gpu {
using ::xla::runtime::CustomCall;
using ::xla::runtime::StridedMemrefView;

// Implements the "__gpu$ResizeBicubic" custom call: launches the CUDA bicubic
// resize kernel on the stream carried by `run_options`.
//
// `input` and `output` are assumed to be rank-4 (N, C, H, W) device buffers —
// TODO(review): confirm layout against the lowering that emits this call. The
// height/width scale factors are the ratios of the spatial dimensions.
// Returns InvalidArgumentError instead of reading out of bounds / dividing by
// zero when the buffers are not rank 4 or the input spatial dims are zero.
static absl::Status ResizeBicubicImpl(
    const ServiceExecutableRunOptions* run_options, StridedMemrefView input,
    StridedMemrefView output, bool align_corners) {
  if (input.sizes.size() != 4 || output.sizes.size() != 4) {
    return absl::InvalidArgumentError(
        "ResizeBicubic expects rank-4 (N, C, H, W) input and output buffers");
  }
  if (input.sizes[2] == 0 || input.sizes[3] == 0) {
    return absl::InvalidArgumentError(
        "ResizeBicubic input spatial dimensions must be non-zero");
  }
  float scales_h =
      static_cast<float>(output.sizes[2]) / static_cast<float>(input.sizes[2]);
  float scales_w =
      static_cast<float>(output.sizes[3]) / static_cast<float>(input.sizes[3]);
  se::StreamExecutor* executor = run_options->stream()->parent();
  return RunResizeBicubicImpl(
      se::gpu::AsGpuStreamValue(run_options->stream()),
      executor->GetDeviceDescription().threads_per_block_limit(), input, output,
      align_corners, scales_h, scales_w);
}

// Implements the "__gpu$ResizeBicubicGrad" custom call: launches the CUDA
// bicubic resize backward kernel on the stream carried by `run_options`.
//
// `grad_output` / `grad_input` are assumed to be rank-4 (N, C, H, W) device
// buffers — TODO(review): confirm layout against the lowering. The scale
// factors match the forward pass (output spatial size over input spatial
// size). Returns InvalidArgumentError instead of reading out of bounds /
// dividing by zero when ranks or spatial dims are invalid.
static absl::Status ResizeBicubicGradImpl(
    const ServiceExecutableRunOptions* run_options,
    StridedMemrefView grad_output, StridedMemrefView grad_input,
    bool align_corners) {
  if (grad_output.sizes.size() != 4 || grad_input.sizes.size() != 4) {
    return absl::InvalidArgumentError(
        "ResizeBicubicGrad expects rank-4 (N, C, H, W) buffers");
  }
  if (grad_input.sizes[2] == 0 || grad_input.sizes[3] == 0) {
    return absl::InvalidArgumentError(
        "ResizeBicubicGrad input spatial dimensions must be non-zero");
  }
  float scales_h = static_cast<float>(grad_output.sizes[2]) /
                   static_cast<float>(grad_input.sizes[2]);
  float scales_w = static_cast<float>(grad_output.sizes[3]) /
                   static_cast<float>(grad_input.sizes[3]);
  se::StreamExecutor* executor = run_options->stream()->parent();
  return RunResizeBicubicGradImpl(
      se::gpu::AsGpuStreamValue(run_options->stream()),
      executor->GetDeviceDescription().threads_per_block_limit(), grad_input,
      grad_output, align_corners, scales_h, scales_w);
}

// Declares the "__gpu$ResizeBicubic" direct custom call handler: two strided
// memref arguments (source and destination buffers) and a boolean
// "align_corners" attribute; the stream comes in via
// ServiceExecutableRunOptions user data.
XLA_RUNTIME_DEFINE_CUSTOM_CALL(
    ResizeBicubic, FunctionWrapper<ResizeBicubicImpl>(), checks,
    CustomCall::Bind("__gpu$ResizeBicubic")
        .UserData<const ServiceExecutableRunOptions*>()
        .Arg<StridedMemrefView>()  // input
        .Arg<StridedMemrefView>()  // output
        .Attr<bool>("align_corners"));

// Declares the "__gpu$ResizeBicubicGrad" direct custom call handler (backward
// pass). Argument order mirrors ResizeBicubicGradImpl: incoming gradient
// first, then the buffer receiving the input gradient.
XLA_RUNTIME_DEFINE_CUSTOM_CALL(
    ResizeBicubicGrad, FunctionWrapper<ResizeBicubicGradImpl>(), checks,
    CustomCall::Bind("__gpu$ResizeBicubicGrad")
        .UserData<const ServiceExecutableRunOptions*>()
        .Arg<StridedMemrefView>()  // grad_output
        .Arg<StridedMemrefView>()  // grad_input
        .Attr<bool>("align_corners"));

// Adds the bicubic-resize forward and backward handlers to the GPU runtime's
// direct custom call registry under their "__gpu$..." names.
void RegisterResizeBicubicCustomCall(
    runtime::DirectCustomCallRegistry& registry) {
  registry.Register("__gpu$ResizeBicubicGrad", ResizeBicubicGrad);
  registry.Register("__gpu$ResizeBicubic", ResizeBicubic);
}

} // namespace xla::gpu
28 changes: 28 additions & 0 deletions xla/service/gpu/runtime/resize_bicubic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_RUNTIME_RESIZE_BICUBIC_H_
#define XLA_SERVICE_GPU_RUNTIME_RESIZE_BICUBIC_H_

#include "xla/runtime/custom_call_registry.h"

namespace xla::gpu {

// Registers the XLA GPU runtime ResizeBicubic custom calls
// ("__gpu$ResizeBicubic" and "__gpu$ResizeBicubicGrad") with the registry.
void RegisterResizeBicubicCustomCall(runtime::DirectCustomCallRegistry& registry);

}  // namespace xla::gpu

#endif  // XLA_SERVICE_GPU_RUNTIME_RESIZE_BICUBIC_H_
Loading

0 comments on commit 28a958a

Please sign in to comment.