Extended subgroups - use 128bit masks (#1215)

* Extended subgroups - use 128bit masks

* Refactoring to avoid kernels code duplication

* unification kernel names as test_ prefix +subgroups function name
* use string literals that improve readability
* use kernel templates that limit code duplication
* WorkGroupParams allows define default kernel - kernel template for multiple functions
* WorkGroupParams allows define  kernel for specific one subgroup function

Co-authored-by: Stuart Brady <stuart.brady@arm.com>
This commit is contained in:
Grzegorz Wawiorko
2021-10-01 12:28:37 +02:00
committed by GitHub
parent 903f1bf65d
commit 92844bead1
12 changed files with 592 additions and 1054 deletions

View File

@@ -24,31 +24,172 @@
#include <limits>
#include <vector>
#include <type_traits>
#include <bitset>
#include <regex>
#include <map>
#define NR_OF_ACTIVE_WORK_ITEMS 4
extern MTdata gMTdata;
typedef std::bitset<128> bs128;
extern cl_half_rounding_mode g_rounding_mode;
struct WorkGroupParams
{
WorkGroupParams(size_t gws, size_t lws,
const std::vector<uint32_t> &all_wim = {})
bool use_mask = false)
: global_workgroup_size(gws), local_workgroup_size(lws),
all_work_item_masks(all_wim)
use_masks(use_mask)
{
subgroup_size = 0;
work_items_mask = 0;
use_core_subgroups = true;
dynsc = 0;
load_masks();
}
size_t global_workgroup_size;
size_t local_workgroup_size;
size_t subgroup_size;
uint32_t work_items_mask;
bs128 work_items_mask;
int dynsc;
bool use_core_subgroups;
std::vector<uint32_t> all_work_item_masks;
std::vector<bs128> all_work_item_masks;
bool use_masks;
void save_kernel_source(const std::string &source, std::string name = "")
{
if (name == "")
{
name = "default";
}
if (kernel_function_name.find(name) != kernel_function_name.end())
{
log_info("Kernel definition duplication. Source will be "
"overwritten for function name %s",
name.c_str());
}
kernel_function_name[name] = source;
};
// return specific defined kernel or default.
std::string get_kernel_source(std::string name)
{
if (kernel_function_name.find(name) == kernel_function_name.end())
{
return kernel_function_name["default"];
}
return kernel_function_name[name];
}
private:
std::map<std::string, std::string> kernel_function_name;
void load_masks()
{
if (use_masks)
{
// 1 in string will be set 1, 0 will be set 0
bs128 mask_0xf0f0f0f0("11110000111100001111000011110000"
"11110000111100001111000011110000"
"11110000111100001111000011110000"
"11110000111100001111000011110000",
128, '0', '1');
all_work_item_masks.push_back(mask_0xf0f0f0f0);
// 1 in string will be set 0, 0 will be set 1
bs128 mask_0x0f0f0f0f("11110000111100001111000011110000"
"11110000111100001111000011110000"
"11110000111100001111000011110000"
"11110000111100001111000011110000",
128, '1', '0');
all_work_item_masks.push_back(mask_0x0f0f0f0f);
bs128 mask_0x5555aaaa("10101010101010101010101010101010"
"10101010101010101010101010101010"
"10101010101010101010101010101010"
"10101010101010101010101010101010",
128, '0', '1');
all_work_item_masks.push_back(mask_0x5555aaaa);
bs128 mask_0xaaaa5555("10101010101010101010101010101010"
"10101010101010101010101010101010"
"10101010101010101010101010101010"
"10101010101010101010101010101010",
128, '1', '0');
all_work_item_masks.push_back(mask_0xaaaa5555);
// 0x0f0ff0f0
bs128 mask_0x0f0ff0f0("00001111000011111111000011110000"
"00001111000011111111000011110000"
"00001111000011111111000011110000"
"00001111000011111111000011110000",
128, '0', '1');
all_work_item_masks.push_back(mask_0x0f0ff0f0);
// 0xff0000ff
bs128 mask_0xff0000ff("11111111000000000000000011111111"
"11111111000000000000000011111111"
"11111111000000000000000011111111"
"11111111000000000000000011111111",
128, '0', '1');
all_work_item_masks.push_back(mask_0xff0000ff);
// 0xff00ff00
bs128 mask_0xff00ff00("11111111000000001111111100000000"
"11111111000000001111111100000000"
"11111111000000001111111100000000"
"11111111000000001111111100000000",
128, '0', '1');
all_work_item_masks.push_back(mask_0xff00ff00);
// 0x00ffff00
bs128 mask_0x00ffff00("00000000111111111111111100000000"
"00000000111111111111111100000000"
"00000000111111111111111100000000"
"00000000111111111111111100000000",
128, '0', '1');
all_work_item_masks.push_back(mask_0x00ffff00);
// 0x80 1 workitem highest id for 8 subgroup size
bs128 mask_0x80808080("10000000100000001000000010000000"
"10000000100000001000000010000000"
"10000000100000001000000010000000"
"10000000100000001000000010000000",
128, '0', '1');
all_work_item_masks.push_back(mask_0x80808080);
// 0x8000 1 workitem highest id for 16 subgroup size
bs128 mask_0x80008000("10000000000000001000000000000000"
"10000000000000001000000000000000"
"10000000000000001000000000000000"
"10000000000000001000000000000000",
128, '0', '1');
all_work_item_masks.push_back(mask_0x80008000);
// 0x80000000 1 workitem highest id for 32 subgroup size
bs128 mask_0x80000000("10000000000000000000000000000000"
"10000000000000000000000000000000"
"10000000000000000000000000000000"
"10000000000000000000000000000000",
128, '0', '1');
all_work_item_masks.push_back(mask_0x80000000);
// 0x80000000 00000000 1 workitem highest id for 64 subgroup size
// 0x80000000 1 workitem highest id for 32 subgroup size
bs128 mask_0x8000000000000000("10000000000000000000000000000000"
"00000000000000000000000000000000"
"10000000000000000000000000000000"
"00000000000000000000000000000000",
128, '0', '1');
all_work_item_masks.push_back(mask_0x8000000000000000);
// 0x80000000 00000000 00000000 00000000 1 workitem highest id for
// 128 subgroup size
bs128 mask_0x80000000000000000000000000000000(
"10000000000000000000000000000000"
"00000000000000000000000000000000"
"00000000000000000000000000000000"
"00000000000000000000000000000000",
128, '0', '1');
all_work_item_masks.push_back(
mask_0x80000000000000000000000000000000);
bs128 mask_0xffffffff("11111111111111111111111111111111"
"11111111111111111111111111111111"
"11111111111111111111111111111111"
"11111111111111111111111111111111",
128, '0', '1');
all_work_item_masks.push_back(mask_0xffffffff);
}
}
};
enum class SubgroupsBroadcastOp
@@ -1267,11 +1408,23 @@ template <typename Ty, typename Fns, size_t TSIZE = 0> struct test
std::vector<Ty> mapout;
mapout.resize(local);
std::stringstream kernel_sstr;
if (test_params.work_items_mask != 0)
if (test_params.use_masks)
{
kernel_sstr << "#define WORK_ITEMS_MASK ";
kernel_sstr << "0x" << std::hex << test_params.work_items_mask
<< "\n";
// Prapare uint4 type to store bitmask on kernel OpenCL C side
// To keep order the first characet in string is the lowest bit
// there was a need to give such offset to bitset constructor
// (first highest offset = 96)
std::bitset<32> bits_1_32(test_params.work_items_mask.to_string(),
96, 32);
std::bitset<32> bits_33_64(test_params.work_items_mask.to_string(),
64, 32);
std::bitset<32> bits_65_96(test_params.work_items_mask.to_string(),
32, 32);
std::bitset<32> bits_97_128(test_params.work_items_mask.to_string(),
0, 32);
kernel_sstr << "global uint4 work_item_mask_vector = (uint4)(0b"
<< bits_1_32 << ",0b" << bits_33_64 << ",0b"
<< bits_65_96 << ",0b" << bits_97_128 << ");\n";
}
@@ -1452,18 +1605,24 @@ struct RunTestForType
num_elements_(num_elements), test_params_(test_params)
{}
template <typename T, typename U>
int run_impl(const char *kernel_name, const char *source)
int run_impl(const std::string &function_name)
{
int error = TEST_PASS;
std::string source =
std::regex_replace(test_params_.get_kernel_source(function_name),
std::regex("\\%s"), function_name);
std::string kernel_name = "test_" + function_name;
if (test_params_.all_work_item_masks.size() > 0)
{
error = test<T, U>::mrun(device_, context_, queue_, num_elements_,
kernel_name, source, test_params_);
kernel_name.c_str(), source.c_str(),
test_params_);
}
else
{
error = test<T, U>::run(device_, context_, queue_, num_elements_,
kernel_name, source, test_params_);
kernel_name.c_str(), source.c_str(),
test_params_);
}
return error;