summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- src/cupp11.hpp 12
1 files changed, 6 insertions, 6 deletions
diff --git a/src/cupp11.hpp b/src/cupp11.hpp
index 71fdc3cd..ec21c5b1 100644
--- a/src/cupp11.hpp
+++ b/src/cupp11.hpp
@@ -411,6 +411,7 @@ public:
}
auto status = nvrtcCompileProgram(*program_, raw_options.size(), raw_options.data());
CLCudaAPINVRTCError::Check(status, "nvrtcCompileProgram");
+ CheckError(cuModuleLoadDataEx(&module_, GetIR().data(), 0, nullptr, nullptr));
}
// Confirms whether a certain status code is an actual compilation error or warning
@@ -440,10 +441,12 @@ public:
return result;
}
- // Accessor to the private data-member
+ // Accessor to the private data-members
+ const CUmodule GetModule() const { return module_; }
const nvrtcProgram& operator()() const { return *program_; }
private:
std::shared_ptr<nvrtcProgram> program_;
+ CUmodule module_;
std::string source_;
bool from_binary_;
};
@@ -665,16 +668,14 @@ class Kernel {
public:
// Constructor based on the regular CUDA data-type: memory management is handled elsewhere
- explicit Kernel(const CUmodule module, const CUfunction kernel):
+ explicit Kernel(const CUfunction kernel):
name_("unknown"),
- module_(module),
kernel_(kernel) {
}
// Regular constructor with memory management
explicit Kernel(const Program &program, const std::string &name): name_(name) {
- CheckError(cuModuleLoadDataEx(&module_, program.GetIR().data(), 0, nullptr, nullptr));
- CheckError(cuModuleGetFunction(&kernel_, module_, name.c_str()));
+ CheckError(cuModuleGetFunction(&kernel_, program.GetModule(), name.c_str()));
}
// Sets a kernel argument at the indicated position. This stores both the value of the argument
@@ -758,7 +759,6 @@ public:
CUfunction operator()() { return kernel_; }
private:
const std::string name_;
- CUmodule module_;
CUfunction kernel_;
std::vector<size_t> arguments_indices_; // Indices of the arguments
std::vector<char> arguments_data_; // The arguments data as raw bytes