Skip to content

Commit

Permalink
Make gpu_executor.h only used by RocmExecutor and CudaExecutor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695872718
  • Loading branch information
klucke authored and Google-ML-Automation committed Nov 14, 2024
1 parent f25e9a0 commit 6924030
Show file tree
Hide file tree
Showing 11 changed files with 21 additions and 24 deletions.
2 changes: 1 addition & 1 deletion xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1343,9 +1343,9 @@ cc_library(
"//xla/stream_executor:launch_dim",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:typed_kernel_factory",
"//xla/stream_executor/gpu:gpu_command_buffer",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor/gpu:scoped_gpu_graph_exec",
"//xla/stream_executor/gpu:scoped_update_mode",
Expand Down
4 changes: 2 additions & 2 deletions xla/stream_executor/cuda/cuda_command_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_status.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/scoped_update_mode.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/semantic_version.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/typed_kernel_factory.h" // IWYU pragma: keep
#include "tsl/platform/casts.h"
#include "tsl/platform/env.h"
Expand Down Expand Up @@ -135,7 +135,7 @@ absl::Status GraphInstantiate(CUgraphExec* exec, CUgraph graph) {
} // namespace

absl::StatusOr<std::unique_ptr<CudaCommandBuffer>> CudaCommandBuffer::Create(
Mode mode, GpuExecutor* parent, CudaContext* cuda_context) {
Mode mode, StreamExecutor* parent, CudaContext* cuda_context) {
TF_ASSIGN_OR_RETURN(CUgraph graph, CreateGraph());
return std::unique_ptr<CudaCommandBuffer>(
new CudaCommandBuffer(mode, parent, cuda_context, graph,
Expand Down
11 changes: 6 additions & 5 deletions xla/stream_executor/cuda/cuda_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_context.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/scoped_gpu_graph_exec.h"
#include "xla/stream_executor/gpu/scoped_update_mode.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"

namespace stream_executor::gpu {

Expand All @@ -46,13 +46,14 @@ class CudaCommandBuffer final : public GpuCommandBuffer {
public:
// Creates a new CUDA command buffer and the underlying CUDA graph.
static absl::StatusOr<std::unique_ptr<CudaCommandBuffer>> Create(
Mode mode, GpuExecutor* parent, CudaContext* cuda_context);
Mode mode, StreamExecutor* parent, CudaContext* cuda_context);

~CudaCommandBuffer() override;

private:
CudaCommandBuffer(Mode mode, GpuExecutor* parent, CudaContext* cuda_context,
CUgraph graph, bool is_owned_graph)
CudaCommandBuffer(Mode mode, StreamExecutor* parent,
CudaContext* cuda_context, CUgraph graph,
bool is_owned_graph)
: GpuCommandBuffer(mode, parent),
parent_(parent),
cuda_context_(cuda_context),
Expand Down Expand Up @@ -184,7 +185,7 @@ class CudaCommandBuffer final : public GpuCommandBuffer {
SetWhileConditionKernel set_while_condition_kernel_;
NoOpKernel noop_kernel_;

GpuExecutor* parent_;
StreamExecutor* parent_;

CudaContext* cuda_context_;

Expand Down
4 changes: 1 addition & 3 deletions xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ gpu_only_cc_library(
],
deps = [
":gpu_executor_header",
":gpu_types_header",
":scoped_update_mode",
"//xla/stream_executor:bit_pattern",
"//xla/stream_executor:command_buffer",
Expand All @@ -176,8 +175,8 @@ gpu_only_cc_library(
"//xla/stream_executor:launch_dim",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/cuda:cuda_platform_id",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log",
Expand All @@ -187,7 +186,6 @@ gpu_only_cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@local_config_cuda//cuda:cuda_headers",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/gpu/gpu_command_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ static std::atomic<int64_t> alive_execs(0);
// GpuCommandBuffer implementation
//===----------------------------------------------------------------------===//

GpuCommandBuffer::GpuCommandBuffer(Mode mode, GpuExecutor* parent)
GpuCommandBuffer::GpuCommandBuffer(Mode mode, StreamExecutor* parent)
: mode_(mode), parent_(parent) {
execution_scopes_.try_emplace(kDefaulExecutionScope);
}
Expand Down
6 changes: 3 additions & 3 deletions xla/stream_executor/gpu/gpu_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ limitations under the License.
#include "xla/stream_executor/bit_pattern.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/scoped_update_mode.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/stream_executor.h"

namespace stream_executor::gpu {

Expand Down Expand Up @@ -96,7 +96,7 @@ class GpuCommandBuffer : public CommandBuffer {
size_t nodes_offset = 0;
};

GpuCommandBuffer(Mode mode, GpuExecutor* parent);
GpuCommandBuffer(Mode mode, StreamExecutor* parent);

absl::Status Barrier(ExecutionScopeId execution_scope_id) override;

Expand Down Expand Up @@ -319,7 +319,7 @@ class GpuCommandBuffer : public CommandBuffer {
Mode mode_;
State state_ = State::kCreate;

GpuExecutor* parent_; // not owned, must outlive *this
StreamExecutor* parent_; // not owned, must outlive *this

private:
// ExecutionScope holds the state of an underlying CUDA graph (nodes an
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1145,8 +1145,8 @@ cc_library(
"//xla/stream_executor:device_memory",
"//xla/stream_executor:kernel",
"//xla/stream_executor:launch_dim",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_command_buffer",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor/gpu:scoped_gpu_graph_exec",
"//xla/stream_executor/gpu:scoped_update_mode",
Expand Down
4 changes: 2 additions & 2 deletions xla/stream_executor/rocm/rocm_command_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ limitations under the License.
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/scoped_update_mode.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/rocm/rocm_driver_wrapper.h"
#include "xla/stream_executor/rocm/rocm_kernel.h"
#include "xla/stream_executor/rocm/rocm_status.h"
#include "xla/stream_executor/stream_executor.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -94,7 +94,7 @@ GraphNodeHandle FromHipGraphHandle(hipGraphNode_t handle) {
} // namespace

absl::StatusOr<std::unique_ptr<RocmCommandBuffer>> RocmCommandBuffer::Create(
Mode mode, GpuExecutor* parent) {
Mode mode, StreamExecutor* parent) {
TF_ASSIGN_OR_RETURN(hipGraph_t graph, CreateGraph());
return std::unique_ptr<RocmCommandBuffer>(
new RocmCommandBuffer(mode, parent, graph,
Expand Down
8 changes: 4 additions & 4 deletions xla/stream_executor/rocm/rocm_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ limitations under the License.
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/scoped_gpu_graph_exec.h"
#include "xla/stream_executor/gpu/scoped_update_mode.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/stream_executor.h"

namespace stream_executor::gpu {

Expand All @@ -43,12 +43,12 @@ class RocmCommandBuffer : public GpuCommandBuffer {
public:
// Creates a new ROCm command buffer and the underlying HIP graph.
static absl::StatusOr<std::unique_ptr<RocmCommandBuffer>> Create(
Mode mode, GpuExecutor* parent);
Mode mode, StreamExecutor* parent);

~RocmCommandBuffer() override;

private:
RocmCommandBuffer(Mode mode, GpuExecutor* parent, hipGraph_t graph,
RocmCommandBuffer(Mode mode, StreamExecutor* parent, hipGraph_t graph,
bool is_owned_graph)
: GpuCommandBuffer(mode, parent),
parent_(parent),
Expand Down Expand Up @@ -145,7 +145,7 @@ class RocmCommandBuffer : public GpuCommandBuffer {
absl::StatusOr<std::vector<GraphNodeHandle>> GetNodeDependencies(
GraphNodeHandle node) override;

GpuExecutor* parent_;
StreamExecutor* parent_;

static_assert(std::is_pointer_v<hipGraph_t>, "hipGraph_t must be a pointer");
static_assert(std::is_pointer_v<hipGraphExec_t>,
Expand Down
1 change: 0 additions & 1 deletion xla/stream_executor/sycl/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ cc_library(
"//xla/stream_executor:executor_cache",
"//xla/stream_executor/platform:initialize",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/stream_executor/gpu:gpu_executor_header",
"@tsl//tsl/platform:errors",
]),
alwayslink = True, # Registers itself with the PlatformManager.
Expand Down
1 change: 0 additions & 1 deletion xla/stream_executor/sycl/sycl_platform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform/initialize.h"
#include "xla/stream_executor/platform_manager.h"
Expand Down

0 comments on commit 6924030

Please sign in to comment.