@@ -331,6 +331,8 @@ typedef CUresult CUDAAPI tcuMemcpy2DAsync_v2(const CUDA_MEMCPY2D *pcopy, CUstrea
typedef CUresult CUDAAPI tcuGetErrorName(CUresult error, const char** pstr);
typedef CUresult CUDAAPI tcuGetErrorString(CUresult error, const char** pstr);
typedef CUresult CUDAAPI tcuCtxGetDevice(CUdevice *device);
+typedef CUresult CUDAAPI tcuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev);
+typedef CUresult CUDAAPI tcuDevicePrimaryCtxRelease(CUdevice dev);
typedef CUresult CUDAAPI tcuStreamCreate(CUstream *phStream, unsigned int flags);
typedef CUresult CUDAAPI tcuStreamQuery(CUstream hStream);
@@ -157,6 +157,8 @@ typedef struct CudaFunctions {
tcuGetErrorName *cuGetErrorName;
tcuGetErrorString *cuGetErrorString;
tcuCtxGetDevice *cuCtxGetDevice;
+ tcuDevicePrimaryCtxRetain *cuDevicePrimaryCtxRetain;
+ tcuDevicePrimaryCtxRelease *cuDevicePrimaryCtxRelease;
tcuStreamCreate *cuStreamCreate;
tcuStreamQuery *cuStreamQuery;
@@ -282,6 +284,8 @@ static inline int cuda_load_functions(CudaFunctions **functions, void *logctx)
LOAD_SYMBOL(cuGetErrorName, tcuGetErrorName, "cuGetErrorName");
LOAD_SYMBOL(cuGetErrorString, tcuGetErrorString, "cuGetErrorString");
LOAD_SYMBOL(cuCtxGetDevice, tcuCtxGetDevice, "cuCtxGetDevice");
+ LOAD_SYMBOL(cuDevicePrimaryCtxRetain, tcuDevicePrimaryCtxRetain, "cuDevicePrimaryCtxRetain");
+ LOAD_SYMBOL(cuDevicePrimaryCtxRelease, tcuDevicePrimaryCtxRelease, "cuDevicePrimaryCtxRelease");
LOAD_SYMBOL(cuStreamCreate, tcuStreamCreate, "cuStreamCreate");
LOAD_SYMBOL(cuStreamQuery, tcuStreamQuery, "cuStreamQuery");