Added missing instructions to get Flame working

This commit is contained in:
2022-05-25 21:04:46 +02:00
parent a0b0ceb6d5
commit 6b00062a11

View File

@@ -26,7 +26,7 @@
#pragma GCC diagnostic ignored "-Wunknown-pragmas"
//region Type classes
class SpirvType {
protected:
llvm::Type *llvm = nullptr;
@@ -221,6 +221,7 @@ protected:
return llvm;
}
};
//endregion
std::string decode_string_arg(const uint32_t* args) {
std::string result;
@@ -238,6 +239,7 @@ std::string decode_string_arg(const uint32_t* args) {
#define OP_BYTES(i) (insn->words + insn->operands[i].offset)
#define OP_WORD(i) (*OP_BYTES(i))
#define OP_VALUE(i) (cur_function->values[OP_WORD(i)])
struct EntryPoint {
std::string name;
@@ -245,7 +247,7 @@ struct EntryPoint {
std::vector<uint32_t> interface_vars;
public:
EntryPoint(const spv_parsed_instruction_t *insn) {
explicit EntryPoint(const spv_parsed_instruction_t *insn) {
func_id = OP_WORD(1);
name = decode_string_arg(OP_BYTES(2));
for (int i = 3; i < insn->num_operands; i++) {
@@ -344,6 +346,7 @@ struct FunctionContext {
);
void gen_prologue(struct CompilerImpl &compiler);
llvm::BasicBlock *gen_bb(struct CompilerImpl& compiler, uint32_t id);
};
@@ -408,7 +411,7 @@ public:
build_global_type();
}
spv_result_t
static spv_result_t
handle_spv_header(spv_endianness_t endianness, uint32_t magic, uint32_t version,
uint32_t generator,
uint32_t id_bound,
@@ -420,7 +423,7 @@ public:
// TODO: allocate structures based on id_bound.
}
std::string format_insn(const spv_parsed_instruction_t* insn) {
static std::string format_insn(const spv_parsed_instruction_t* insn) {
std::stringstream msg_stream;
msg_stream << "Opcode(" << insn_names[insn->opcode] << "/" << insn->opcode << "):";
@@ -443,7 +446,6 @@ public:
uint32_t rty;
bool has_rid = false;
bool has_rty = false;
uint32_t first_operand = insn->num_operands;
auto idata = insn->words;
for (int i = 0; i < insn->num_operands; i++) {
@@ -455,7 +457,6 @@ public:
rty = OP_WORD(i);
has_rty = true;
} else {
first_operand = i;
break;
}
}
@@ -591,31 +592,114 @@ public:
case Op::OpExtInst: {
auto ext_set = ext_insts[OP_WORD(2)];
auto ext_inst = OP_WORD(3);
switch (ext_set) {
case SPV_EXT_INST_TYPE_GLSL_STD_450: {
switch (GLSLstd450(ext_inst)) {
case GLSLstd450Cos: {
auto val = cur_function->values[OP_WORD(4)];
auto ty = val->getType();//->getScalarType();
if (ext_set == SPV_EXT_INST_TYPE_GLSL_STD_450) {
switch (GLSLstd450(ext_inst)) {
put_value(rid, builder->CreateIntrinsic(llvm::Intrinsic::cos, ty, val));
break;
}
// case GLSLstd450Atanh: {
// auto val = cur_function->values[OP_WORD(4)];
// auto ty = val->getType()->getScalarType();
// put_value(rid, builder->CreateIntrinsic(llvm::Intrinsic::))
// }
default: {
put_value(rid, builder->CreateFreeze(llvm::UndefValue::get(types[rty]->get_llvm_type())));
BOOST_LOG_TRIVIAL(warning) << "Unhandled GLSL extinst " << ext_inst;
}
case GLSLstd450FAbs: { // 4
auto res = builder->CreateUnaryIntrinsic(llvm::Intrinsic::fabs, OP_VALUE(4));
put_value(rid, res);
break;
}
case GLSLstd450Floor: { // 8
auto val = OP_VALUE(4);
auto res = builder->CreateIntrinsic(llvm::Intrinsic::floor, val->getType(), val);
put_value(rid, res);
break;
}
case GLSLstd450Sin: { // 13
auto val = OP_VALUE(4);
auto res = builder->CreateIntrinsic(llvm::Intrinsic::sin, val->getType(), val);
put_value(rid, res);
break;
}
case GLSLstd450Cos: { // 14
auto val = cur_function->values[OP_WORD(4)];
auto ty = val->getType();//->getScalarType();
put_value(rid, builder->CreateIntrinsic(llvm::Intrinsic::cos, ty, val));
break;
}
case GLSLstd450Pow: { // 26
auto res = builder->CreateBinaryIntrinsic(llvm::Intrinsic::pow, OP_VALUE(4), OP_VALUE(5));
put_value(rid, res);
break;
}
case GLSLstd450Sqrt: { // 31
put_value(rid, builder->CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, OP_VALUE(4)));
break;
}
case GLSLstd450FMin: { // 37
auto select = builder->CreateFCmpOLT(OP_VALUE(4), OP_VALUE(5));
auto res = builder->CreateSelect(select, OP_VALUE(4), OP_VALUE(5));
put_value(rid, res);
break;
}
case GLSLstd450FMax: { // 40
    // Lower FMax to an ordered greater-than compare followed by a select:
    // the first operand is picked when it compares greater, the second otherwise.
    auto lhs = OP_VALUE(4);
    auto rhs = OP_VALUE(5);
    auto lhs_is_greater = builder->CreateFCmpOGT(lhs, rhs);
    auto res = builder->CreateSelect(lhs_is_greater, lhs, rhs);
    put_value(rid, res);
    break;
}
case GLSLstd450FMix: { // 46
    // Linear blend: result = x * (1 - a) + y * a.
    auto x = OP_VALUE(4);
    auto y = OP_VALUE(5);
    auto a = OP_VALUE(6);
    // ConstantFP::get broadcasts the value when handed a vector type, so the
    // constant always has exactly the type of `a`. This removes the unchecked
    // dyn_cast that was previously guarded by a test on `x` rather than `a`.
    llvm::Constant *ones = llvm::ConstantFP::get(a->getType(), 1.);
    auto one_min_a = builder->CreateFSub(ones, a);
    auto x_part = builder->CreateFMul(x, one_min_a);
    auto y_part = builder->CreateFMul(y, a);
    auto res = builder->CreateFAdd(x_part, y_part);
    put_value(rid, res);
    break;
}
case GLSLstd450Length: { // 66
    // length(v) = sqrt(sum of squared components).
    auto v = OP_VALUE(4);
    llvm::Value *res;
    if (v->getType()->isVectorTy()) {
        auto sty = v->getType()->getScalarType();
        auto vsq = builder->CreateFMul(v, v);
        // Horizontal add of the squared components, starting from 0.
        auto vsum = builder->CreateFAddReduce(llvm::ConstantFP::get(sty, 0), vsq);
        res = builder->CreateIntrinsic(llvm::Intrinsic::sqrt, sty, vsum);
    } else {
        // For a scalar, sqrt(v * v) is simply |v|. The previous sqrt(v) was
        // wrong for every input other than 0 and 1 (and NaN for negatives).
        res = builder->CreateUnaryIntrinsic(llvm::Intrinsic::fabs, v);
    }
    put_value(rid, res);
    break;
}
case GLSLstd450Normalize: { // 69
    // normalize(v) = v / length(v).
    auto v = OP_VALUE(4);
    auto vty = v->getType();
    llvm::Value *len;
    if (vty->isVectorTy()) {
        auto sty = vty->getScalarType();
        auto vsq = builder->CreateFMul(v, v);
        // Horizontal add of the squared components, starting from 0.
        auto vsum = builder->CreateFAddReduce(llvm::ConstantFP::get(sty, 0), vsq);
        auto scalar_len = builder->CreateIntrinsic(llvm::Intrinsic::sqrt, sty, vsum);
        // fdiv needs both operands to have the same type, so broadcast the
        // scalar length across all lanes before dividing.
        auto lanes = llvm::cast<llvm::VectorType>(vty)->getElementCount();
        len = builder->CreateVectorSplat(lanes, scalar_len);
    } else {
        // The length of a scalar is |v| (sqrt(v * v)), not sqrt(v).
        len = builder->CreateUnaryIntrinsic(llvm::Intrinsic::fabs, v);
    }
    put_value(rid, builder->CreateFDiv(v, len));
    break;
}
// case GLSLstd450Atanh: {
// auto val = cur_function->values[OP_WORD(4)];
// auto ty = val->getType()->getScalarType();
// put_value(rid, builder->CreateIntrinsic(llvm::Intrinsic::))
// }
default: {
put_value(rid, builder->CreateFreeze(llvm::UndefValue::get(types[rty]->get_llvm_type())));
BOOST_LOG_TRIVIAL(warning) << "Unhandled GLSL extinst " << ext_inst;
}
break;
}
default: {
// ignore
}
break;
} else {
// ignore
}
break;
}
@@ -657,7 +741,6 @@ public:
break;
}
} else if (itype != nullptr) {
uint64_t val = 0;
switch (itype->width) {
case 64:
constants[rid] = std::make_shared<Constant>((uint64_t)OP_BYTES(2)[0]
@@ -670,9 +753,9 @@ public:
break;
}
case Op::OpConstantComposite: {
auto &type = types[rty];
type->build_llvm_type(*ctx);
auto &typ_id = typeid(*type);
auto &type = *types[rty];
type.build_llvm_type(*ctx);
auto &typ_id = typeid(type);
std::vector<llvm::Constant*> elements;
for (int i = 2; i < insn->num_operands; i++) {
elements.push_back(constants[OP_WORD(i)]->get_llvm_const(*ctx));
@@ -680,7 +763,7 @@ public:
llvm::Constant *result = nullptr;
if (typ_id == typeid(SpirvStructType)) {
result = llvm::ConstantStruct::get((llvm::StructType*)type->get_llvm_type(),
result = llvm::ConstantStruct::get((llvm::StructType*)type.get_llvm_type(),
elements);
} else if (typ_id == typeid(SpirvVectorType)) {
result = llvm::ConstantVector::get(
@@ -688,7 +771,7 @@ public:
);
} else if (typ_id == typeid(SpirvArrayType)) {
result = llvm::ConstantArray::get(
(llvm::ArrayType*)type->get_llvm_type(),
(llvm::ArrayType*)type.get_llvm_type(),
elements
);
}
@@ -781,6 +864,12 @@ public:
break;
}
//endregion
//region 3.42.11 Conversion instructions
case Op::OpConvertSToF: {
put_value(rid, builder->CreateSIToFP(OP_VALUE(2), types[rty]->get_llvm_type()));
break;
}
//endregion
// region 3.42.12 Composite instructions
case Op::OpVectorShuffle: {
std::vector<int> mask(OP_BYTES(4), OP_BYTES(insn->num_operands));
@@ -832,12 +921,32 @@ public:
}
//endregion
//region 3.42.13 Arithmetic instructions
case Op::OpIAdd: {
put_value(rid, builder->CreateAdd(OP_VALUE(2), OP_VALUE(3)));
break;
}
case Op::OpFNegate: {
put_value(rid, builder->CreateFNeg(OP_VALUE(2)));
break;
}
case Op::OpFAdd: {
auto value = builder->CreateFAdd(cur_function->values[OP_WORD(2)],
cur_function->values[OP_WORD(3)]);
put_value(rid, value);
break;
}
case Op::OpFSub: {
auto value = builder->CreateFSub(cur_function->values[OP_WORD(2)],
cur_function->values[OP_WORD(3)]);
put_value(rid, value);
break;
}
case Op::OpFMul: {
auto value = builder->CreateFMul(cur_function->values[OP_WORD(2)],
cur_function->values[OP_WORD(3)]);
put_value(rid, value);
break;
}
case Op::OpFDiv: {
auto value =
builder->CreateFDiv(cur_function->values[OP_WORD(2)],
@@ -845,6 +954,11 @@ public:
put_value(rid, value);
break;
}
case Op::OpFMod: {
    // OpFMod yields a remainder whose sign matches operand 2 (the divisor),
    // whereas LLVM's frem follows the sign of the dividend. Compute frem and
    // add the divisor back when the remainder is non-zero and its sign
    // disagrees with the divisor's.
    auto dividend = OP_VALUE(2);
    auto divisor = OP_VALUE(3);
    auto rem = builder->CreateFRem(dividend, divisor);
    // Null value of a floating-point (or FP vector) type is +0.0 in every lane.
    auto zero = llvm::Constant::getNullValue(rem->getType());
    auto rem_nonzero = builder->CreateFCmpONE(rem, zero);
    auto rem_negative = builder->CreateFCmpOLT(rem, zero);
    auto divisor_negative = builder->CreateFCmpOLT(divisor, zero);
    auto signs_differ = builder->CreateXor(rem_negative, divisor_negative);
    auto needs_adjust = builder->CreateAnd(rem_nonzero, signs_differ);
    auto adjusted = builder->CreateFAdd(rem, divisor);
    auto value = builder->CreateSelect(needs_adjust, adjusted, rem);
    put_value(rid, value);
    break;
}
case Op::OpVectorTimesScalar: {
auto vty = (llvm::VectorType*)(types[rty]->get_llvm_type());
auto scalar = builder->CreateVectorSplat(vty->getElementCount(), cur_function->values[OP_WORD(3)]);
@@ -852,14 +966,43 @@ public:
put_value(rid, result);
break;
}
case Op::OpDot: {
auto vtv = builder->CreateFMul(cur_function->values[OP_WORD(2)],
cur_function->values[OP_WORD(3)]);
auto value = builder->CreateFAddReduce(llvm::ConstantFP::get(vtv->getType()->getScalarType(), 0.0),
vtv);
put_value(rid, value);
break;
}
//endregion
//region 3.42.15 Relational and logical instructions
case Op::OpSLessThan: {
put_value(rid, builder->CreateICmpSLT(OP_VALUE(2), OP_VALUE(3)));
break;
}
case Op::OpFOrdGreaterThan: {
put_value(rid, builder->CreateFCmpOGT(OP_VALUE(2), OP_VALUE(3)));
break;
}
case Op::OpFOrdLessThan: {
put_value(rid, builder->CreateFCmpOLT(OP_VALUE(2), OP_VALUE(3)));
break;
}
//endregion
//region 3.42.17 Control Flow instructions
case Op::OpBranch: {
auto target = cur_function->gen_bb(*this, OP_WORD(0));
builder->CreateBr(target);
break;
}
case Op::OpBranchConditional: {
auto ifTrue = cur_function->gen_bb(*this, OP_WORD(1));
auto ifFalse = cur_function->gen_bb(*this, OP_WORD(2));
builder->CreateCondBr(cur_function->values[OP_WORD(0)], ifTrue, ifFalse);
break;
}
case Op::OpLabel: {
llvm::BasicBlock *new_bb = cur_function->basic_blocks[rid];
if (new_bb == nullptr) {
auto label = boost::format("label_%1%") % rid;
new_bb = cur_function->basic_blocks[rid] = llvm::BasicBlock::Create(*ctx, label.str(), cur_function->function);
}
auto new_bb = cur_function->gen_bb(*this, rid);
builder->SetInsertPoint(new_bb);
if (cur_function->in_prologue) {
cur_function->gen_prologue(*this);
@@ -871,6 +1014,10 @@ public:
builder->CreateRetVoid();
break;
}
case Op::OpReturnValue: {
builder->CreateRet(cur_function->values[OP_WORD(0)]);
break;
}
//endregion
default:{
@@ -944,7 +1091,7 @@ Compiler::Compiler(): impl(std::make_unique<CompilerImpl>()) {
bool Compiler::compile(std::vector<uint32_t> &spv_module) {
auto ret = impl->process_module(spv_module);
impl->module->print(llvm::outs(), nullptr, false, true);
// impl->module->print(llvm::outs(), nullptr, false, true);
BOOST_LOG_TRIVIAL(debug) << "Optimizing";
// start generating machine code
@@ -970,8 +1117,7 @@ bool Compiler::compile(std::vector<uint32_t> &spv_module) {
pass_builder.populateModulePassManager(pass);
pass.run(*impl->module);
impl->module->print(llvm::outs(), nullptr, false, true);
// impl->module->print(llvm::outs(), nullptr, false, true);
return ret;
}
@@ -1031,6 +1177,15 @@ void FunctionContext::gen_prologue(struct CompilerImpl &compiler) {
}
}
/// Returns the basic block registered under `id` in `basic_blocks`.
///
/// If no block is stored for `id` yet, a new one named "label_<id>" is created
/// in the compiler's LLVM context with `function` as its parent, recorded in
/// `basic_blocks`, and returned. Calling this again with the same id returns
/// the stored block.
llvm::BasicBlock *FunctionContext::gen_bb(CompilerImpl &compiler, uint32_t id) {
    // Fast path: the block was already created for this id.
    if (llvm::BasicBlock *existing = basic_blocks[id]) {
        return existing;
    }
    const auto block_name = (boost::format("label_%1%") % id).str();
    llvm::BasicBlock *created = llvm::BasicBlock::Create(*compiler.ctx, block_name, function);
    basic_blocks[id] = created;
    return created;
}
spv_result_t
impl_parse_header(void *user_data, spv_endianness_t endian, uint32_t magic, uint32_t version, uint32_t generator,
uint32_t id_bound, uint32_t reserved) {