Added start of runtime, started work on support functions

This commit is contained in:
2022-05-26 02:43:01 +02:00
parent 6b00062a11
commit c90cc0a9bc
5 changed files with 305 additions and 10 deletions

View File

@@ -21,12 +21,12 @@ add_definitions(${LLVM_DEFINITIONS})
add_executable(bluebell
src/bluebell.cpp src/bluebell.hpp
src/pre-compiler.cpp src/pre-compiler.hpp src/compiler.cpp src/compiler.hpp src/spv_meta.cpp src/spv_meta.hpp)
src/pre-compiler.cpp src/pre-compiler.hpp src/compiler.cpp src/compiler.hpp src/spv_meta.cpp src/spv_meta.hpp src/runtime.cpp src/runtime.hpp)
#target_link_libraries(bluebell PkgConfig::shaderc)
#print(${LLVM_AVAILABLE_LIBS})
llvm_map_components_to_libnames(llvm_libs core native passes nativecodegen all)
llvm_map_components_to_libnames(llvm_libs core native passes nativecodegen OrcJIT)
target_link_libraries(bluebell shaderc_shared Boost::boost Boost::log ${llvm_libs})

View File

@@ -13,6 +13,13 @@
#include <llvm/Support/TargetRegistry.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Scalar/DCE.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Transforms/Utils.h>
#include <llvm/IR/Verifier.h>
#include <boost/log/trivial.hpp>
#include <boost/format.hpp>
#include <spirv-tools/libspirv.hpp>
@@ -360,6 +367,7 @@ struct CompilerImpl {
std::map<uint32_t, std::shared_ptr<SpirvType>> types;
std::map<uint32_t, std::shared_ptr<Constant>> constants;
std::map<std::string, std::shared_ptr<EntryPoint>> entry_points;
std::map<std::pair<uint32_t, spv::Decoration>, uint32_t> decorations;
std::map<uint32_t, Global> globals;
std::map<uint32_t, std::shared_ptr<FunctionContext>> functions;
llvm::StructType* global_type;
@@ -524,7 +532,19 @@ public:
//endregion
//region Annotations
case Op::OpDecorate: {
// TODO: save these for bindings
auto deco = spv::Decoration(OP_WORD(1));
switch(deco) {
case spv::Decoration::Binding:
case spv::Decoration::Location:
case spv::Decoration::DescriptorSet:
case spv::Decoration::BuiltIn:
decorations.emplace(std::make_pair(OP_WORD(0), deco), OP_WORD(2));
break;
default:
// ignore all others
break;
}
break;
}
case Op::OpMemberDecorate: {
@@ -1055,7 +1075,97 @@ public:
}
}
void generate_support() {
generate_support_setup_frame();
generate_support_get_context_size();
generate_support_render_frame();
}
void generate_support_get_context_size() const {
auto ty_fn = llvm::FunctionType::get(llvm::Type::getInt32Ty(*ctx), false);
auto fn = llvm::Function::Create(
ty_fn,
llvm::Function::LinkageTypes::ExternalLinkage,
"get_context_size",
*module);
auto bb = llvm::BasicBlock::Create(*ctx, "entry", fn);
builder->SetInsertPoint(bb);
auto global_null = llvm::ConstantPointerNull::get(global_type->getPointerTo());
auto next_ptr = builder->CreateConstGEP1_32(global_type,
global_null,
1);
auto ptr_as_int = builder->CreatePtrToInt(next_ptr, ty_fn->getReturnType());
builder->CreateRet(ptr_as_int);
}
void generate_support_setup_frame() const {
// generate type for C-based uniform block
auto ty_f = llvm::Type::getFloatTy(*ctx);
auto ty_fx3 = llvm::ArrayType::get(ty_f, 3);
auto ty_fx4 = llvm::ArrayType::get(ty_f, 4);
auto ty_fx3x4 = llvm::ArrayType::get(ty_fx3, 4);
std::vector<llvm::Type*> cuniform_members = {
ty_fx3, // iResolution
ty_f, // iTime
ty_f, // iTimeDelta
ty_f, // iFrame
ty_fx4, // iChannelTime
ty_fx4, // iMouse
ty_fx4, // iDate
ty_f, // iSampleRate
ty_fx3x4, // iChannelResolution
};
auto ty_cuniform = llvm::StructType::create(cuniform_members, "c_uniforms");
std::vector<llvm::Type*> fnArgs = {
global_type->getPointerTo(),
ty_cuniform->getPointerTo(),
};
auto ty_fn = llvm::FunctionType::get(llvm::Type::getVoidTy(*ctx), fnArgs, false);
auto fn = llvm::Function::Create(ty_fn, llvm::Function::ExternalLinkage, "setup_frame", *module);
auto bb = llvm::BasicBlock::Create(*ctx, "entry", fn);
builder->SetInsertPoint(bb);
// figure out which global to access...
int32_t uniform_idx = -1;
uint32_t deco_dset, deco_binding;
for (auto &global: globals) {
auto id = global.first;
if (get_decoration(id, spv::Decoration::DescriptorSet, deco_dset)
&& get_decoration(id, spv::Decoration::Binding, deco_binding)
&& deco_dset == 0 && deco_binding == 0) {
uniform_idx = global.second.element_no;
break;
}
}
BOOST_LOG_TRIVIAL(debug) << "Found uniforms at index " << uniform_idx;
auto uniforms = builder->CreateConstGEP2_32(global_type, fn->getArg(0), 0, uniform_idx);
make_conversion_code(uniforms, fn->getArg(1));
builder->CreateRetVoid();
}
void generate_support_render_frame() {
}
private:
bool get_decoration(uint32_t id, spv::Decoration deco, uint32_t &result) const {
auto found = decorations.find(std::make_pair(id, deco));
if (found != decorations.end()) {
BOOST_LOG_TRIVIAL(debug) << "Found deco " << (uint32_t)deco << " for " << id;
result = found->second;
return true;
}
BOOST_LOG_TRIVIAL(debug) << "No deco " << (uint32_t )deco << " for " << id;
return false;
}
llvm::Value* get_value(uint32_t id, llvm::Type* type) {
auto value_it = cur_function->values.find(id);
if (value_it != cur_function->values.end()) {
@@ -1067,6 +1177,44 @@ private:
return value;
}
}
void make_conversion_code(llvm::Value* dst, llvm::Value* src) const {
// dst and src are both pointers
// src is always an array
auto dst_ty = dst->getType()->getPointerElementType();
auto src_ty = src->getType()->getPointerElementType();
if (dst_ty->isFloatTy()) {
builder->CreateStore(builder->CreateLoad(dst_ty, src), dst);
} else if (llvm::isa<llvm::ArrayType>(dst_ty)) {
auto n_elem = dst_ty->getArrayNumElements();
for (uint64_t i = 0; i < n_elem; i++) {
auto s1 = builder->CreateConstGEP2_32(src_ty, src, 0, i);
auto d1 = builder->CreateConstGEP2_32(dst_ty, dst, 0, i);
make_conversion_code(d1, s1);
}
} else if (llvm::isa<llvm::StructType>(dst_ty)) {
auto n_elem = dst_ty->getStructNumElements();
for (uint64_t i = 0; i < n_elem; i++) {
auto s1 = builder->CreateConstGEP2_32(src_ty, src, 0, i);
auto d1 = builder->CreateConstGEP2_32(dst_ty, dst, 0, i);
make_conversion_code(d1, s1);
}
} else if (auto dst_vty = llvm::dyn_cast<llvm::VectorType>(dst_ty)) {
auto ec = dst_vty->getElementCount().getFixedValue();
llvm::Value* result = llvm::UndefValue::get(dst_vty);
auto float_ty = llvm::Type::getFloatTy(*ctx);
// auto src_v = builder->CreateLoad(src_ty, src);
for (uint64_t i = 0; i < ec; i++) {
auto septr = builder->CreateConstGEP2_32(src_ty, src, 0, i);
auto src_el = builder->CreateLoad(float_ty, septr);
// auto src_el = builder->CreateExtractValue(src_v, i);
result = builder->CreateInsertElement(result, src_el, i);
}
builder->CreateStore(result, dst);
}
}
void put_value(uint32_t id, llvm::Value* value) {
auto sv_it = cur_function->speculative_values.find(id);
if (sv_it != cur_function->speculative_values.end()) {
@@ -1076,6 +1224,7 @@ private:
cur_function->values[id] = value;
}
};
void InitLLVM() {
@@ -1085,11 +1234,15 @@ void InitLLVM() {
}
Compiler::Compiler(): impl(std::make_unique<CompilerImpl>()) {
Compiler::Compiler() {
}
bool Compiler::compile(std::vector<uint32_t> &spv_module) {
llvm::Optional<llvm::orc::ThreadSafeModule> Compiler::compile(std::vector<uint32_t> &spv_module) {
auto impl = std::make_unique<CompilerImpl>();
auto ret = impl->process_module(spv_module);
impl->generate_support();
llvm::verifyModule(*impl->module);
// impl->module->print(llvm::outs(), nullptr, false, true);
@@ -1100,9 +1253,10 @@ bool Compiler::compile(std::vector<uint32_t> &spv_module) {
auto target = llvm::TargetRegistry::lookupTarget(target_triple, error);
if (!target) {
BOOST_LOG_TRIVIAL(error) << "Failed to load target: " << error;
return false;
return {};
}
// dunno if this is actually necessary, but leaving it in for now.
llvm::TargetOptions opt;
auto CPU = llvm::sys::getHostCPUName();
auto reloc_model = llvm::Optional<llvm::Reloc::Model>();
@@ -1114,11 +1268,39 @@ bool Compiler::compile(std::vector<uint32_t> &spv_module) {
llvm::PassManagerBuilder pass_builder;
target_machine->adjustPassManager(pass_builder);
llvm::legacy::PassManager pass;
pass.add(llvm::createArgumentPromotionPass());
pass.add(llvm::createPromoteMemoryToRegisterPass());
pass.add(llvm::createInstructionCombiningPass());
pass.add(llvm::createReassociatePass());
pass.add(llvm::createGVNPass());
pass.add(llvm::createCFGSimplificationPass());
pass.add(llvm::createFunctionInliningPass(3, 0, false));
pass.add(llvm::createAggressiveDCEPass());
pass.add(llvm::createLICMPass());
pass_builder.populateModulePassManager(pass);
// add output to object file
std::error_code err_code;
llvm::raw_fd_ostream objfile("shader.o", err_code);
llvm::raw_fd_ostream asfile("shader.s", err_code);
if (err_code) {
BOOST_LOG_TRIVIAL(error) << "Failed to open object file: " << err_code.message();
return {};
}
// if (target_machine->addPassesToEmitFile(pass, asfile, nullptr, llvm::CGFT_AssemblyFile)) {
// BOOST_LOG_TRIVIAL(error) << "target can't create an assembly file :-(";
// }
if (target_machine->addPassesToEmitFile(pass, objfile, nullptr, llvm::CGFT_ObjectFile)) {
BOOST_LOG_TRIVIAL(error) << "target can't create an object file :-(";
}
pass.run(*impl->module);
objfile.flush();
asfile.flush();
// impl->module->print(llvm::outs(), nullptr, false, true);
return ret;
return {llvm::orc::ThreadSafeModule(std::move(impl->module),
std::move(impl->ctx))};
}
Compiler::~Compiler() = default;
@@ -1143,13 +1325,19 @@ FunctionContext::FunctionContext(struct CompilerImpl &compiler, llvm::FunctionTy
}
}
if (!name.empty()) {
// prefix the name
name = "sh_" + name;
}
for (auto &v: compiler.constants) {
values[v.first] = v.second->get_llvm_const(*compiler.ctx);
}
if (name.empty()) {
std::stringstream name_builder;
name_builder << "shaderfunc_" << id;
name_builder << "sh_func_" << id;
name = name_builder.str();
}

View File

@@ -6,15 +6,19 @@
#include <memory>
#include <vector>
#include "runtime.hpp"
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
struct CompilerImpl;
class Compiler {
std::unique_ptr<CompilerImpl> impl;
public:
Compiler();
bool compile(std::vector<uint32_t> &spv_module);
llvm::Optional<llvm::orc::ThreadSafeModule> compile(std::vector<uint32_t> &spv_module);
virtual ~Compiler();
};
void InitLLVM();

68
src/runtime.cpp Normal file
View File

@@ -0,0 +1,68 @@
//
// Created by thequux on 5/25/22.
//
#include "runtime.hpp"
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Scalar/DCE.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Transforms/Utils.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
struct RuntimeImpl {
std::unique_ptr<llvm::legacy::PassManager> passManager;
std::unique_ptr<llvm::orc::LLJIT> jit;
explicit RuntimeImpl() {
// construct the pass manager
passManager = std::make_unique<llvm::legacy::PassManager>();
passManager->add(llvm::createArgumentPromotionPass());
passManager->add(llvm::createPromoteMemoryToRegisterPass());
passManager->add(llvm::createInstructionCombiningPass());
passManager->add(llvm::createReassociatePass());
passManager->add(llvm::createGVNPass());
passManager->add(llvm::createCFGSimplificationPass());
passManager->add(llvm::createFunctionInliningPass(3, 0, false));
passManager->add(llvm::createAggressiveDCEPass());
passManager->add(llvm::createLICMPass());
// probably good for now, but need to look into further passes
}
llvm::Error setModule(llvm::orc::ThreadSafeModule mod) {
mod.withModuleDo([&] (llvm::Module &mod1){
passManager->run(mod1);
});
// construct the runtime
auto jit_builder = llvm::orc::LLJITBuilder();
jit = cantFail(jit_builder.create(), "Failed to create JIT");
if (auto err = jit->addIRModule(std::move(mod))) {
return err;
}
auto sym = jit->lookup("setup");
return llvm::Error::success();
}
};
Runtime::Runtime(): impl(std::make_unique<RuntimeImpl>()) {
}
llvm::Error Runtime::set_module(llvm::orc::ThreadSafeModule module) {
return impl->setModule(std::move(module));
}
void Runtime::render_frame(const Uniforms &uniforms, float buf[32][32][3]) {
}

35
src/runtime.hpp Normal file
View File

@@ -0,0 +1,35 @@
//
// Created by thequux on 5/25/22.
//
#pragma once
#include "compiler.hpp"
#include <llvm/IR/Module.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <memory>
struct Uniforms {
float iResolution[3];
float iTime;
float iTImeDelta;
float iFrame;
float iChannelTime[4];
float iMouse[4];
float iDate[4];
float iSampleRate;
float iChannelResolution[3][4];
};
class Runtime {
// the compiler will construct this from its internals
friend class Compiler;
std::unique_ptr<struct RuntimeImpl> impl;
protected:
Runtime();
public:
llvm::Error set_module(llvm::orc::ThreadSafeModule module);
void render_frame(const Uniforms &uniforms, float buf[32][32][3]);
};