Skip to content

Commit

Permalink
Add support for CUDA capability 8.6 (halide#6334)
Browse files Browse the repository at this point in the history
* Add support for CUDA capability 8.6

* Add an assertion to guard against LLVM versions that lack sm_86 support

* Fall back to sm_80 if LLVM < 13.0
  • Loading branch information
TH3CHARLie authored Oct 21, 2021
1 parent 27f975f commit ecf69b0
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/CodeGen_PTX_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,10 @@ string CodeGen_PTX_Dev::march() const {
}

string CodeGen_PTX_Dev::mcpu() const {
if (target.has_feature(Target::CUDACapability80)) {
if (target.has_feature(Target::CUDACapability86)) {
user_assert(LLVM_VERSION >= 130) << "The linked LLVM version does not support cuda compute capability 8.6\n";
return "sm_86";
} else if (target.has_feature(Target::CUDACapability80)) {
return "sm_80";
} else if (target.has_feature(Target::CUDACapability75)) {
return "sm_75";
Expand All @@ -565,7 +568,9 @@ string CodeGen_PTX_Dev::mcpu() const {
}

string CodeGen_PTX_Dev::mattrs() const {
if (target.has_feature(Target::CUDACapability80)) {
if (target.has_feature(Target::CUDACapability86)) {
return "+ptx71";
} else if (target.has_feature(Target::CUDACapability80)) {
return "+ptx70";
} else if (target.has_feature(Target::CUDACapability70) ||
target.has_feature(Target::CUDACapability75)) {
Expand Down
15 changes: 13 additions & 2 deletions src/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,10 @@ Target::Feature calculate_host_cuda_capability(Target t) {
return Target::CUDACapability70;
} else if (ver < 80) {
return Target::CUDACapability75;
} else {
} else if (ver < 86 || LLVM_VERSION < 130) {
return Target::CUDACapability80;
} else {
return Target::CUDACapability86;
}
}

Expand Down Expand Up @@ -332,6 +334,7 @@ const std::map<std::string, Target::Feature> feature_name_map = {
{"cuda_capability_70", Target::CUDACapability70},
{"cuda_capability_75", Target::CUDACapability75},
{"cuda_capability_80", Target::CUDACapability80},
{"cuda_capability_86", Target::CUDACapability86},
{"opencl", Target::OpenCL},
{"cl_doubles", Target::CLDoubles},
{"cl_half", Target::CLHalf},
Expand Down Expand Up @@ -497,7 +500,8 @@ bool merge_string(Target &t, const std::string &target) {
!t.has_feature(Target::CUDACapability61) &&
!t.has_feature(Target::CUDACapability70) &&
!t.has_feature(Target::CUDACapability75) &&
!t.has_feature(Target::CUDACapability80)) {
!t.has_feature(Target::CUDACapability80) &&
!t.has_feature(Target::CUDACapability86)) {
// Detect host cuda capability
t.set_feature(get_host_cuda_capability(t));
}
Expand Down Expand Up @@ -770,6 +774,9 @@ int Target::get_cuda_capability_lower_bound() const {
if (has_feature(Target::CUDACapability80)) {
return 80;
}
if (has_feature(Target::CUDACapability86)) {
return 86;
}
return 20;
}

Expand Down Expand Up @@ -961,6 +968,7 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result)
CUDACapability70,
CUDACapability75,
CUDACapability80,
CUDACapability86,
HVX_v62,
HVX_v65,
HVX_v66,
Expand Down Expand Up @@ -1074,6 +1082,9 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result)
if (cuda_capability < 80) {
output.features.reset(CUDACapability80);
}
if (cuda_capability < 86) {
output.features.reset(CUDACapability86);
}

// Pick tight lower bound for HVX version. Use fall-through to clear redundant features
int hvx_a = get_hvx_lower_bound(*this);
Expand Down
1 change: 1 addition & 0 deletions src/Target.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ struct Target {
CUDACapability70 = halide_target_feature_cuda_capability70,
CUDACapability75 = halide_target_feature_cuda_capability75,
CUDACapability80 = halide_target_feature_cuda_capability80,
CUDACapability86 = halide_target_feature_cuda_capability86,
OpenCL = halide_target_feature_opencl,
CLDoubles = halide_target_feature_cl_doubles,
CLHalf = halide_target_feature_cl_half,
Expand Down
1 change: 1 addition & 0 deletions src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,7 @@ typedef enum halide_target_feature_t {
halide_target_feature_cuda_capability70, ///< Enable CUDA compute capability 7.0 (Volta)
halide_target_feature_cuda_capability75, ///< Enable CUDA compute capability 7.5 (Turing)
halide_target_feature_cuda_capability80, ///< Enable CUDA compute capability 8.0 (Ampere)
halide_target_feature_cuda_capability86, ///< Enable CUDA compute capability 8.6 (Ampere)

halide_target_feature_opencl, ///< Enable the OpenCL runtime.
halide_target_feature_cl_doubles, ///< Enable double support on OpenCL targets
Expand Down

0 comments on commit ecf69b0

Please sign in to comment.