Added missing instructions to get Flame working
This commit is contained in:
239
src/compiler.cpp
239
src/compiler.cpp
@@ -26,7 +26,7 @@
|
||||
#pragma GCC diagnostic ignored "-Wunknown-pragmas"
|
||||
|
||||
|
||||
|
||||
//region Type classes
|
||||
class SpirvType {
|
||||
protected:
|
||||
llvm::Type *llvm = nullptr;
|
||||
@@ -221,6 +221,7 @@ protected:
|
||||
return llvm;
|
||||
}
|
||||
};
|
||||
//endregion
|
||||
|
||||
std::string decode_string_arg(const uint32_t* args) {
|
||||
std::string result;
|
||||
@@ -238,6 +239,7 @@ std::string decode_string_arg(const uint32_t* args) {
|
||||
|
||||
#define OP_BYTES(i) (insn->words + insn->operands[i].offset)
|
||||
#define OP_WORD(i) (*OP_BYTES(i))
|
||||
#define OP_VALUE(i) (cur_function->values[OP_WORD(i)])
|
||||
|
||||
struct EntryPoint {
|
||||
std::string name;
|
||||
@@ -245,7 +247,7 @@ struct EntryPoint {
|
||||
std::vector<uint32_t> interface_vars;
|
||||
|
||||
public:
|
||||
EntryPoint(const spv_parsed_instruction_t *insn) {
|
||||
explicit EntryPoint(const spv_parsed_instruction_t *insn) {
|
||||
func_id = OP_WORD(1);
|
||||
name = decode_string_arg(OP_BYTES(2));
|
||||
for (int i = 3; i < insn->num_operands; i++) {
|
||||
@@ -344,6 +346,7 @@ struct FunctionContext {
|
||||
);
|
||||
|
||||
void gen_prologue(struct CompilerImpl &compiler);
|
||||
llvm::BasicBlock *gen_bb(struct CompilerImpl& compiler, uint32_t id);
|
||||
};
|
||||
|
||||
|
||||
@@ -408,7 +411,7 @@ public:
|
||||
build_global_type();
|
||||
}
|
||||
|
||||
spv_result_t
|
||||
static spv_result_t
|
||||
handle_spv_header(spv_endianness_t endianness, uint32_t magic, uint32_t version,
|
||||
uint32_t generator,
|
||||
uint32_t id_bound,
|
||||
@@ -420,7 +423,7 @@ public:
|
||||
// TODO: allocate structures based on id_bound.
|
||||
}
|
||||
|
||||
std::string format_insn(const spv_parsed_instruction_t* insn) {
|
||||
static std::string format_insn(const spv_parsed_instruction_t* insn) {
|
||||
|
||||
std::stringstream msg_stream;
|
||||
msg_stream << "Opcode(" << insn_names[insn->opcode] << "/" << insn->opcode << "):";
|
||||
@@ -443,7 +446,6 @@ public:
|
||||
uint32_t rty;
|
||||
bool has_rid = false;
|
||||
bool has_rty = false;
|
||||
uint32_t first_operand = insn->num_operands;
|
||||
auto idata = insn->words;
|
||||
|
||||
for (int i = 0; i < insn->num_operands; i++) {
|
||||
@@ -455,7 +457,6 @@ public:
|
||||
rty = OP_WORD(i);
|
||||
has_rty = true;
|
||||
} else {
|
||||
first_operand = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -591,31 +592,114 @@ public:
|
||||
case Op::OpExtInst: {
|
||||
auto ext_set = ext_insts[OP_WORD(2)];
|
||||
auto ext_inst = OP_WORD(3);
|
||||
switch (ext_set) {
|
||||
case SPV_EXT_INST_TYPE_GLSL_STD_450: {
|
||||
switch (GLSLstd450(ext_inst)) {
|
||||
case GLSLstd450Cos: {
|
||||
auto val = cur_function->values[OP_WORD(4)];
|
||||
auto ty = val->getType();//->getScalarType();
|
||||
if (ext_set == SPV_EXT_INST_TYPE_GLSL_STD_450) {
|
||||
switch (GLSLstd450(ext_inst)) {
|
||||
|
||||
put_value(rid, builder->CreateIntrinsic(llvm::Intrinsic::cos, ty, val));
|
||||
break;
|
||||
}
|
||||
// case GLSLstd450Atanh: {
|
||||
// auto val = cur_function->values[OP_WORD(4)];
|
||||
// auto ty = val->getType()->getScalarType();
|
||||
// put_value(rid, builder->CreateIntrinsic(llvm::Intrinsic::))
|
||||
// }
|
||||
default: {
|
||||
put_value(rid, builder->CreateFreeze(llvm::UndefValue::get(types[rty]->get_llvm_type())));
|
||||
BOOST_LOG_TRIVIAL(warning) << "Unhandled GLSL extinst " << ext_inst;
|
||||
}
|
||||
case GLSLstd450FAbs: { // 4
|
||||
auto res = builder->CreateUnaryIntrinsic(llvm::Intrinsic::fabs, OP_VALUE(4));
|
||||
put_value(rid, res);
|
||||
break;
|
||||
}
|
||||
case GLSLstd450Floor: { // 8
|
||||
auto val = OP_VALUE(4);
|
||||
auto res = builder->CreateIntrinsic(llvm::Intrinsic::floor, val->getType(), val);
|
||||
put_value(rid, res);
|
||||
break;
|
||||
}
|
||||
case GLSLstd450Sin: { // 13
|
||||
auto val = OP_VALUE(4);
|
||||
auto res = builder->CreateIntrinsic(llvm::Intrinsic::sin, val->getType(), val);
|
||||
put_value(rid, res);
|
||||
break;
|
||||
}
|
||||
case GLSLstd450Cos: { // 14
|
||||
auto val = cur_function->values[OP_WORD(4)];
|
||||
auto ty = val->getType();//->getScalarType();
|
||||
|
||||
put_value(rid, builder->CreateIntrinsic(llvm::Intrinsic::cos, ty, val));
|
||||
break;
|
||||
}
|
||||
case GLSLstd450Pow: { // 26
|
||||
auto res = builder->CreateBinaryIntrinsic(llvm::Intrinsic::pow, OP_VALUE(4), OP_VALUE(5));
|
||||
put_value(rid, res);
|
||||
break;
|
||||
}
|
||||
case GLSLstd450Sqrt: { // 31
|
||||
put_value(rid, builder->CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, OP_VALUE(4)));
|
||||
break;
|
||||
}
|
||||
case GLSLstd450FMin: { // 37
|
||||
auto select = builder->CreateFCmpOLT(OP_VALUE(4), OP_VALUE(5));
|
||||
auto res = builder->CreateSelect(select, OP_VALUE(4), OP_VALUE(5));
|
||||
put_value(rid, res);
|
||||
break;
|
||||
}
|
||||
case GLSLstd450FMax: { // 40
|
||||
auto select = builder->CreateFCmpOGT(OP_VALUE(4), OP_VALUE(5));
|
||||
auto res = builder->CreateSelect(select, OP_VALUE(4), OP_VALUE(5));
|
||||
put_value(rid, res);
|
||||
break;
|
||||
}
|
||||
case GLSLstd450FMix: { // 46
    // Linear blend of the two inputs: x * (1 - a) + y * a.
    auto x = OP_VALUE(4);
    auto y = OP_VALUE(5);
    auto a = OP_VALUE(6);
    auto stype = a->getType()->getScalarType();
    llvm::Constant *ones = llvm::ConstantFP::get(stype, 1.);

    // Widen the constant 1.0 to the shape of `a`. The vector check is made on
    // the very type that is cast, so the cast result is never dereferenced
    // when null (previously x's type was tested but a's type was cast).
    if (auto vty = llvm::dyn_cast<llvm::VectorType>(a->getType())) {
        ones = llvm::ConstantVector::getSplat(vty->getElementCount(), ones);
    }
    auto one_min_a = builder->CreateFSub(ones, a);
    x = builder->CreateFMul(x, one_min_a);
    y = builder->CreateFMul(y, a);
    auto res = builder->CreateFAdd(x, y);
    put_value(rid, res);
    break;
}
|
||||
case GLSLstd450Length: { // 66
    // Euclidean length of the operand.
    auto v = OP_VALUE(4);
    llvm::Value *res;
    if (v->getType()->isVectorTy()) {
        // Vector: sqrt of the sum of the squared components.
        auto sty = v->getType()->getScalarType();
        auto vsq = builder->CreateFMul(v, v);
        auto vsum = builder->CreateFAddReduce(llvm::ConstantFP::get(sty, 0), vsq);
        res = builder->CreateIntrinsic(llvm::Intrinsic::sqrt, sty, vsum);
    } else {
        // Scalar: sqrt(v * v) is |v|. Emitting sqrt(v) here would be wrong
        // for every input except 0 and 1 (and NaN for negative inputs).
        res = builder->CreateUnaryIntrinsic(llvm::Intrinsic::fabs, v);
    }
    put_value(rid, res);
    break;
}
|
||||
case GLSLstd450Normalize: { // 69
    // Scales the operand to unit length: v / length(v).
    auto v = OP_VALUE(4);
    llvm::Value *len;
    if (v->getType()->isVectorTy()) {
        // Vector length: sqrt of the sum of the squared components.
        auto sty = v->getType()->getScalarType();
        auto vsq = builder->CreateFMul(v, v);
        auto vsum = builder->CreateFAddReduce(llvm::ConstantFP::get(sty, 0), vsq);
        len = builder->CreateIntrinsic(llvm::Intrinsic::sqrt, sty, vsum);
        // The length is a scalar; fdiv needs both operands to have the same
        // type, so broadcast it across all lanes before dividing.
        auto ec = llvm::cast<llvm::VectorType>(v->getType())->getElementCount();
        len = builder->CreateVectorSplat(ec, len);
    } else {
        // Scalar length is |v| (not sqrt(v)), giving +/-1 for non-zero input.
        len = builder->CreateUnaryIntrinsic(llvm::Intrinsic::fabs, v);
    }
    auto res = builder->CreateFDiv(v, len);
    put_value(rid, res);
    break;
}
|
||||
// case GLSLstd450Atanh: {
|
||||
// auto val = cur_function->values[OP_WORD(4)];
|
||||
// auto ty = val->getType()->getScalarType();
|
||||
// put_value(rid, builder->CreateIntrinsic(llvm::Intrinsic::))
|
||||
// }
|
||||
default: {
|
||||
put_value(rid, builder->CreateFreeze(llvm::UndefValue::get(types[rty]->get_llvm_type())));
|
||||
BOOST_LOG_TRIVIAL(warning) << "Unhandled GLSL extinst " << ext_inst;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// ignore
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
// ignore
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -657,7 +741,6 @@ public:
|
||||
break;
|
||||
}
|
||||
} else if (itype != nullptr) {
|
||||
uint64_t val = 0;
|
||||
switch (itype->width) {
|
||||
case 64:
|
||||
constants[rid] = std::make_shared<Constant>((uint64_t)OP_BYTES(2)[0]
|
||||
@@ -670,9 +753,9 @@ public:
|
||||
break;
|
||||
}
|
||||
case Op::OpConstantComposite: {
|
||||
auto &type = types[rty];
|
||||
type->build_llvm_type(*ctx);
|
||||
auto &typ_id = typeid(*type);
|
||||
auto &type = *types[rty];
|
||||
type.build_llvm_type(*ctx);
|
||||
auto &typ_id = typeid(type);
|
||||
std::vector<llvm::Constant*> elements;
|
||||
for (int i = 2; i < insn->num_operands; i++) {
|
||||
elements.push_back(constants[OP_WORD(i)]->get_llvm_const(*ctx));
|
||||
@@ -680,7 +763,7 @@ public:
|
||||
|
||||
llvm::Constant *result = nullptr;
|
||||
if (typ_id == typeid(SpirvStructType)) {
|
||||
result = llvm::ConstantStruct::get((llvm::StructType*)type->get_llvm_type(),
|
||||
result = llvm::ConstantStruct::get((llvm::StructType*)type.get_llvm_type(),
|
||||
elements);
|
||||
} else if (typ_id == typeid(SpirvVectorType)) {
|
||||
result = llvm::ConstantVector::get(
|
||||
@@ -688,7 +771,7 @@ public:
|
||||
);
|
||||
} else if (typ_id == typeid(SpirvArrayType)) {
|
||||
result = llvm::ConstantArray::get(
|
||||
(llvm::ArrayType*)type->get_llvm_type(),
|
||||
(llvm::ArrayType*)type.get_llvm_type(),
|
||||
elements
|
||||
);
|
||||
}
|
||||
@@ -781,6 +864,12 @@ public:
|
||||
break;
|
||||
}
|
||||
//endregion
|
||||
//region 3.42.11 Conversion instructions
|
||||
case Op::OpConvertSToF: {
|
||||
put_value(rid, builder->CreateSIToFP(OP_VALUE(2), types[rty]->get_llvm_type()));
|
||||
break;
|
||||
}
|
||||
//endregion
|
||||
// region 3.42.12 Composite instructions
|
||||
case Op::OpVectorShuffle: {
|
||||
std::vector<int> mask(OP_BYTES(4), OP_BYTES(insn->num_operands));
|
||||
@@ -832,12 +921,32 @@ public:
|
||||
}
|
||||
//endregion
|
||||
//region 3.42.13 Arithmetic instructions
|
||||
case Op::OpIAdd: {
|
||||
put_value(rid, builder->CreateAdd(OP_VALUE(2), OP_VALUE(3)));
|
||||
break;
|
||||
}
|
||||
case Op::OpFNegate: {
|
||||
put_value(rid, builder->CreateFNeg(OP_VALUE(2)));
|
||||
break;
|
||||
}
|
||||
case Op::OpFAdd: {
|
||||
auto value = builder->CreateFAdd(cur_function->values[OP_WORD(2)],
|
||||
cur_function->values[OP_WORD(3)]);
|
||||
put_value(rid, value);
|
||||
break;
|
||||
}
|
||||
case Op::OpFSub: {
|
||||
auto value = builder->CreateFSub(cur_function->values[OP_WORD(2)],
|
||||
cur_function->values[OP_WORD(3)]);
|
||||
put_value(rid, value);
|
||||
break;
|
||||
}
|
||||
case Op::OpFMul: {
|
||||
auto value = builder->CreateFMul(cur_function->values[OP_WORD(2)],
|
||||
cur_function->values[OP_WORD(3)]);
|
||||
put_value(rid, value);
|
||||
break;
|
||||
}
|
||||
case Op::OpFDiv: {
|
||||
auto value =
|
||||
builder->CreateFDiv(cur_function->values[OP_WORD(2)],
|
||||
@@ -845,6 +954,11 @@ public:
|
||||
put_value(rid, value);
|
||||
break;
|
||||
}
|
||||
case Op::OpFMod: {
    // OpFMod yields a remainder whose sign matches Operand 2 (the divisor),
    // whereas LLVM's frem yields one whose sign matches the dividend (that is
    // OpFRem). Compute frem, then shift the result by one divisor whenever it
    // is non-zero and its sign disagrees with the divisor's sign.
    auto lhs = OP_VALUE(2);
    auto rhs = OP_VALUE(3);
    auto rem = builder->CreateFRem(lhs, rhs);
    // Null value of the result type: 0.0 for scalars, all-zero for vectors.
    auto zero = llvm::Constant::getNullValue(rem->getType());
    auto rem_nonzero = builder->CreateFCmpONE(rem, zero);
    auto rem_negative = builder->CreateFCmpOLT(rem, zero);
    auto rhs_negative = builder->CreateFCmpOLT(rhs, zero);
    auto sign_differs = builder->CreateXor(rem_negative, rhs_negative);
    auto needs_shift = builder->CreateAnd(rem_nonzero, sign_differs);
    auto shifted = builder->CreateFAdd(rem, rhs);
    auto value = builder->CreateSelect(needs_shift, shifted, rem);
    put_value(rid, value);
    break;
}
|
||||
case Op::OpVectorTimesScalar: {
|
||||
auto vty = (llvm::VectorType*)(types[rty]->get_llvm_type());
|
||||
auto scalar = builder->CreateVectorSplat(vty->getElementCount(), cur_function->values[OP_WORD(3)]);
|
||||
@@ -852,14 +966,43 @@ public:
|
||||
put_value(rid, result);
|
||||
break;
|
||||
}
|
||||
case Op::OpDot: {
|
||||
auto vtv = builder->CreateFMul(cur_function->values[OP_WORD(2)],
|
||||
cur_function->values[OP_WORD(3)]);
|
||||
auto value = builder->CreateFAddReduce(llvm::ConstantFP::get(vtv->getType()->getScalarType(), 0.0),
|
||||
vtv);
|
||||
put_value(rid, value);
|
||||
break;
|
||||
}
|
||||
//endregion
|
||||
//region 3.42.15 Relational and logical instructions
|
||||
case Op::OpSLessThan: {
|
||||
put_value(rid, builder->CreateICmpSLT(OP_VALUE(2), OP_VALUE(3)));
|
||||
break;
|
||||
}
|
||||
case Op::OpFOrdGreaterThan: {
|
||||
put_value(rid, builder->CreateFCmpOGT(OP_VALUE(2), OP_VALUE(3)));
|
||||
break;
|
||||
}
|
||||
case Op::OpFOrdLessThan: {
|
||||
put_value(rid, builder->CreateFCmpOLT(OP_VALUE(2), OP_VALUE(3)));
|
||||
break;
|
||||
}
|
||||
//endregion
|
||||
//region 3.42.17 Control Flow instructions
|
||||
case Op::OpBranch: {
|
||||
auto target = cur_function->gen_bb(*this, OP_WORD(0));
|
||||
builder->CreateBr(target);
|
||||
break;
|
||||
}
|
||||
case Op::OpBranchConditional: {
|
||||
auto ifTrue = cur_function->gen_bb(*this, OP_WORD(1));
|
||||
auto ifFalse = cur_function->gen_bb(*this, OP_WORD(2));
|
||||
builder->CreateCondBr(cur_function->values[OP_WORD(0)], ifTrue, ifFalse);
|
||||
break;
|
||||
}
|
||||
case Op::OpLabel: {
|
||||
llvm::BasicBlock *new_bb = cur_function->basic_blocks[rid];
|
||||
if (new_bb == nullptr) {
|
||||
auto label = boost::format("label_%1%") % rid;
|
||||
new_bb = cur_function->basic_blocks[rid] = llvm::BasicBlock::Create(*ctx, label.str(), cur_function->function);
|
||||
}
|
||||
auto new_bb = cur_function->gen_bb(*this, rid);
|
||||
builder->SetInsertPoint(new_bb);
|
||||
if (cur_function->in_prologue) {
|
||||
cur_function->gen_prologue(*this);
|
||||
@@ -871,6 +1014,10 @@ public:
|
||||
builder->CreateRetVoid();
|
||||
break;
|
||||
}
|
||||
case Op::OpReturnValue: {
|
||||
builder->CreateRet(cur_function->values[OP_WORD(0)]);
|
||||
break;
|
||||
}
|
||||
|
||||
//endregion
|
||||
default:{
|
||||
@@ -944,7 +1091,7 @@ Compiler::Compiler(): impl(std::make_unique<CompilerImpl>()) {
|
||||
bool Compiler::compile(std::vector<uint32_t> &spv_module) {
|
||||
auto ret = impl->process_module(spv_module);
|
||||
|
||||
impl->module->print(llvm::outs(), nullptr, false, true);
|
||||
// impl->module->print(llvm::outs(), nullptr, false, true);
|
||||
|
||||
BOOST_LOG_TRIVIAL(debug) << "Optimizing";
|
||||
// start generating machine code
|
||||
@@ -970,8 +1117,7 @@ bool Compiler::compile(std::vector<uint32_t> &spv_module) {
|
||||
pass_builder.populateModulePassManager(pass);
|
||||
pass.run(*impl->module);
|
||||
|
||||
impl->module->print(llvm::outs(), nullptr, false, true);
|
||||
|
||||
// impl->module->print(llvm::outs(), nullptr, false, true);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -1031,6 +1177,15 @@ void FunctionContext::gen_prologue(struct CompilerImpl &compiler) {
|
||||
}
|
||||
}
|
||||
|
||||
// Looks up the basic block stored under `id` in `basic_blocks`. When the slot
// is still null, a new block named "label_<id>" is created with `function`
// passed as its parent, stored in the slot, and returned; otherwise the
// existing block is returned unchanged.
llvm::BasicBlock *FunctionContext::gen_bb(CompilerImpl &compiler, uint32_t id) {
    if (llvm::BasicBlock *existing = basic_blocks[id]) {
        return existing;
    }

    const auto block_name = (boost::format("label_%1%") % id).str();
    llvm::BasicBlock *created = llvm::BasicBlock::Create(*compiler.ctx, block_name, function);
    basic_blocks[id] = created;
    return created;
}
|
||||
|
||||
spv_result_t
|
||||
impl_parse_header(void *user_data, spv_endianness_t endian, uint32_t magic, uint32_t version, uint32_t generator,
|
||||
uint32_t id_bound, uint32_t reserved) {
|
||||
|
||||
Reference in New Issue
Block a user