@@ -250,6 +250,10 @@ static int th_start_inference(void *args)
av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
return DNN_GENERIC_ERROR;
}
+ // Move the input tensor to the device that holds the model's parameters.
+ c10::Device device = (*th_model->jit_model->parameters().begin()).device();
+ if (infer_request->input_tensor->device() != device)
+ *infer_request->input_tensor = infer_request->input_tensor->to(device);
inputs.push_back(*infer_request->input_tensor);
*infer_request->output = th_model->jit_model->forward(inputs).toTensor();
@@ -285,6 +289,9 @@ static void infer_completion_callback(void *args) {
switch (th_model->model.func_type) {
case DFT_PROCESS_FRAME:
if (task->do_ioproc) {
+ // Post-processing can only access CPU memory, so copy the output to the CPU first.
+ if (output->device() != torch::kCPU)
+ *output = output->to(torch::kCPU);
outputs.scale = 255;
outputs.data = output->data_ptr();
if (th_model->model.frame_post_proc != NULL) {
@@ -424,7 +431,13 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
th_model->ctx = ctx;
c10::Device device = c10::Device(device_name);
- if (!device.is_cpu()) {
+ if (device.is_xpu()) {
+ if (!at::hasXPU()) {
+ av_log(ctx, AV_LOG_ERROR, "No XPU device found\n");
+ goto fail;
+ }
+ at::detail::getXPUHooks().initXPU();
+ } else if (!device.is_cpu()) {
av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", device_name);
goto fail;
}
@@ -432,6 +445,7 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
try {
th_model->jit_model = new torch::jit::Module;
(*th_model->jit_model) = torch::jit::load(ctx->model_filename);
+ th_model->jit_model->to(device);
} catch (const c10::Error& e) {
av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
goto fail;