diff mbox

[FFmpeg-devel,2/2] vf_dnn_processing: add support for more formats gray8 and grayf32

Message ID 1577435660-11904-1-git-send-email-yejun.guo@intel.com
State New
Headers show

Commit Message

Guo, Yejun Dec. 27, 2019, 8:34 a.m. UTC
The following Python script halves the pixel values of a grayscale
image. It demonstrates how to set up and execute a DNN model with Python + TensorFlow.
It also generates the .pb file that will be used by ffmpeg.

# Demo: build a model made of a single 1x1 convolution whose weight is 0.5,
# i.e. it halves every pixel of a one-channel image. Run it with TensorFlow
# and export the frozen graph as halve_gray_float.pb for ffmpeg's
# dnn_processing filter.
# NOTE: uses the TensorFlow 1.x graph/session API (tf.placeholder, tf.Session).
import numpy as np
import tensorflow as tf
from skimage import color
from skimage import io
from skimage.util import img_as_float, img_as_ubyte

in_img = io.imread('input.jpg')
# Reduce to a single channel in [0, 1]. Recent scikit-image versions reject
# 2-D input to rgb2gray, so an image that is already grayscale is only
# rescaled to float.
if in_img.ndim == 3:
    in_img = color.rgb2gray(in_img)
else:
    in_img = img_as_float(in_img)
# Convert explicitly to 8-bit before writing the JPEG instead of relying on
# imsave's implicit (warning-emitting) lossy float conversion.
io.imsave('ori_gray.jpg', img_as_ubyte(in_img))

# NHWC batch of one single-channel image, matching the placeholder's dtype.
in_data = in_img[np.newaxis, :, :, np.newaxis].astype(np.float32)

# conv2d kernel layout is [height, width, in_channels, out_channels]:
# a 1x1 kernel mapping 1 channel to 1 channel with weight 0.5.
halve_kernel = np.full((1, 1, 1, 1), 0.5, dtype=np.float32)
kernel = tf.Variable(halve_kernel)
x = tf.placeholder(tf.float32, shape=[1, None, None, 1], name='dnn_in')
y = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID', name='dnn_out')

# The context manager releases the session's resources when done.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fold the variable into a constant so the .pb is self-contained.
    graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['dnn_out'])
    tf.train.write_graph(graph_def, '.', 'halve_gray_float.pb', as_text=False)
    print("halve_gray_float.pb generated, please use "
          "path_to_ffmpeg/tools/python/convert.py to generate halve_gray_float.model\n")
    output = sess.run(y, feed_dict={x: in_data})

# Back to 8-bit: round rather than truncate, and clip so that any
# out-of-range model output saturates instead of wrapping around in uint8.
output = np.clip(np.rint(output * 255.0), 0, 255).astype(np.uint8)
io.imsave("out.jpg", np.squeeze(output))

To do the same thing with ffmpeg:
- generate halve_gray_float.pb with the above script
- generate halve_gray_float.model with tools/python/convert.py
- try the following commands
  ./ffmpeg -i input.jpg -vf format=grayf32,dnn_processing=model=halve_gray_float.model:input=dnn_in:output=dnn_out:dnn_backend=native out.native.png
  ./ffmpeg -i input.jpg -vf format=grayf32,dnn_processing=model=halve_gray_float.pb:input=dnn_in:output=dnn_out:dnn_backend=tensorflow out.tf.png

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
---
 doc/filters.texi                |   6 ++
 libavfilter/vf_dnn_processing.c | 168 ++++++++++++++++++++++++++++++----------
 2 files changed, 132 insertions(+), 42 deletions(-)

Comments

Guo, Yejun Jan. 7, 2020, 6:37 a.m. UTC | #1
> -----Original Message-----
> From: Guo, Yejun
> Sent: Friday, December 27, 2019 4:34 PM
> To: ffmpeg-devel@ffmpeg.org
> Cc: Guo, Yejun <yejun.guo@intel.com>
> Subject: [PATCH 2/2] vf_dnn_processing: add support for more formats gray8
> and grayf32

Please review this patch set, thanks.

BTW, I'll add the FATE test after this patch set is reviewed.
Pedro Arthur Jan. 7, 2020, 1:57 p.m. UTC | #2
Em sex., 27 de dez. de 2019 às 05:42, Guo, Yejun <yejun.guo@intel.com> escreveu:
>
> The following is a python script to halve the value of the gray
> image. It demos how to setup and execute dnn model with python+tensorflow.
> It also generates .pb file which will be used by ffmpeg.
>
> import tensorflow as tf
> import numpy as np
> from skimage import color
> from skimage import io
> in_img = io.imread('input.jpg')
> in_img = color.rgb2gray(in_img)
> io.imsave('ori_gray.jpg', np.squeeze(in_img))
> in_data = np.expand_dims(in_img, axis=0)
> in_data = np.expand_dims(in_data, axis=3)
> filter_data = np.array([0.5]).reshape(1,1,1,1).astype(np.float32)
> filter = tf.Variable(filter_data)
> x = tf.placeholder(tf.float32, shape=[1, None, None, 1], name='dnn_in')
> y = tf.nn.conv2d(x, filter, strides=[1, 1, 1, 1], padding='VALID', name='dnn_out')
> sess=tf.Session()
> sess.run(tf.global_variables_initializer())
> graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['dnn_out'])
> tf.train.write_graph(graph_def, '.', 'halve_gray_float.pb', as_text=False)
> print("halve_gray_float.pb generated, please use \
> path_to_ffmpeg/tools/python/convert.py to generate halve_gray_float.model\n")
> output = sess.run(y, feed_dict={x: in_data})
> output = output * 255.0
> output = output.astype(np.uint8)
> io.imsave("out.jpg", np.squeeze(output))
>
> To do the same thing with ffmpeg:
> - generate halve_gray_float.pb with the above script
> - generate halve_gray_float.model with tools/python/convert.py
> - try with following commands
>   ./ffmpeg -i input.jpg -vf format=grayf32,dnn_processing=model=halve_gray_float.model:input=dnn_in:output=dnn_out:dnn_backend=native out.native.png
>   ./ffmpeg -i input.jpg -vf format=grayf32,dnn_processing=model=halve_gray_float.pb:input=dnn_in:output=dnn_out:dnn_backend=tensorflow out.tf.png
>
> Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
> ---
>  doc/filters.texi                |   6 ++
>  libavfilter/vf_dnn_processing.c | 168 ++++++++++++++++++++++++++++++----------
>  2 files changed, 132 insertions(+), 42 deletions(-)
>
> diff --git a/doc/filters.texi b/doc/filters.texi
> index f467378..57a129d 100644
> --- a/doc/filters.texi
> +++ b/doc/filters.texi
> @@ -9075,6 +9075,12 @@ Halve the red channle of the frame with format rgb24:
>  ffmpeg -i input.jpg -vf format=rgb24,dnn_processing=model=halve_first_channel.model:input=dnn_in:output=dnn_out:dnn_backend=native out.native.png
>  @end example
>
> +@item
> +Halve the pixel values of the frame with format grayf32:
> +@example
> +ffmpeg -i input.jpg -vf format=grayf32,dnn_processing=model=halve_gray_float.model:input=dnn_in:output=dnn_out:dnn_backend=native -y out.native.png
> +@end example
> +
>  @end itemize
>
>  @section drawbox
> diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
> index 4a6b900..13273f2 100644
> --- a/libavfilter/vf_dnn_processing.c
> +++ b/libavfilter/vf_dnn_processing.c
> @@ -104,12 +104,20 @@ static int query_formats(AVFilterContext *context)
>  {
>      static const enum AVPixelFormat pix_fmts[] = {
>          AV_PIX_FMT_RGB24, AV_PIX_FMT_BGR24,
> +        AV_PIX_FMT_GRAY8, AV_PIX_FMT_GRAYF32,
>          AV_PIX_FMT_NONE
>      };
>      AVFilterFormats *fmts_list = ff_make_format_list(pix_fmts);
>      return ff_set_common_formats(context, fmts_list);
>  }
>
> +#define LOG_FORMAT_CHANNEL_MISMATCH()                       \
> +    av_log(ctx, AV_LOG_ERROR,                               \
> +           "the frame's format %s does not match "          \
> +           "the model input channel %d\n",                  \
> +           av_get_pix_fmt_name(fmt),                        \
> +           model_input->channels);
> +
>  static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLink *inlink)
>  {
>      AVFilterContext *ctx   = inlink->dst;
> @@ -131,17 +139,34 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
>      case AV_PIX_FMT_RGB24:
>      case AV_PIX_FMT_BGR24:
>          if (model_input->channels != 3) {
> -            av_log(ctx, AV_LOG_ERROR, "the frame's input format %s does not match "
> -                                       "the model input channels %d\n",
> -                                       av_get_pix_fmt_name(fmt),
> -                                       model_input->channels);
> +            LOG_FORMAT_CHANNEL_MISMATCH();
>              return AVERROR(EIO);
>          }
>          if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) {
>              av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
>              return AVERROR(EIO);
>          }
> -        break;
> +        return 0;
> +    case AV_PIX_FMT_GRAY8:
> +        if (model_input->channels != 1) {
> +            LOG_FORMAT_CHANNEL_MISMATCH();
> +            return AVERROR(EIO);
> +        }
> +        if (model_input->dt != DNN_UINT8) {
> +            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n");
> +            return AVERROR(EIO);
> +        }
> +        return 0;
> +    case AV_PIX_FMT_GRAYF32:
> +        if (model_input->channels != 1) {
> +            LOG_FORMAT_CHANNEL_MISMATCH();
> +            return AVERROR(EIO);
> +        }
> +        if (model_input->dt != DNN_FLOAT) {
> +            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n");
> +            return AVERROR(EIO);
> +        }
> +        return 0;
>      default:
>          av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt));
>          return AVERROR(EIO);
> @@ -206,28 +231,58 @@ static int config_output(AVFilterLink *outlink)
>
>  static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
>  {
> -    // extend this function to support more formats
> -    av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
> -
> -    if (dnn_input->dt == DNN_FLOAT) {
> -        float *dnn_input_data = dnn_input->data;
> -        for (int i = 0; i < frame->height; i++) {
> -            for(int j = 0; j < frame->width * 3; j++) {
> -                int k = i * frame->linesize[0] + j;
> -                int t = i * frame->width * 3 + j;
> -                dnn_input_data[t] = frame->data[0][k] / 255.0f;
> +    switch (frame->format) {
> +    case AV_PIX_FMT_RGB24:
> +    case AV_PIX_FMT_BGR24:
> +        if (dnn_input->dt == DNN_FLOAT) {
> +            float *dnn_input_data = dnn_input->data;
> +            for (int i = 0; i < frame->height; i++) {
> +                for(int j = 0; j < frame->width * 3; j++) {
> +                    int k = i * frame->linesize[0] + j;
> +                    int t = i * frame->width * 3 + j;
> +                    dnn_input_data[t] = frame->data[0][k] / 255.0f;
> +                }
> +            }
> +        } else {
> +            uint8_t *dnn_input_data = dnn_input->data;
> +            av_assert0(dnn_input->dt == DNN_UINT8);
> +            for (int i = 0; i < frame->height; i++) {
> +                for(int j = 0; j < frame->width * 3; j++) {
> +                    int k = i * frame->linesize[0] + j;
> +                    int t = i * frame->width * 3 + j;
> +                    dnn_input_data[t] = frame->data[0][k];
> +                }
>              }
>          }
> -    } else {
> -        uint8_t *dnn_input_data = dnn_input->data;
> -        av_assert0(dnn_input->dt == DNN_UINT8);
> -        for (int i = 0; i < frame->height; i++) {
> -            for(int j = 0; j < frame->width * 3; j++) {
> -                int k = i * frame->linesize[0] + j;
> -                int t = i * frame->width * 3 + j;
> -                dnn_input_data[t] = frame->data[0][k];
> +        return 0;
> +    case AV_PIX_FMT_GRAY8:
> +        {
> +            uint8_t *dnn_input_data = dnn_input->data;
> +            av_assert0(dnn_input->dt == DNN_UINT8);
> +            for (int i = 0; i < frame->height; i++) {
> +                for(int j = 0; j < frame->width; j++) {
> +                    int k = i * frame->linesize[0] + j;
> +                    int t = i * frame->width + j;
> +                    dnn_input_data[t] = frame->data[0][k];
> +                }
>              }
>          }
> +        return 0;
> +    case AV_PIX_FMT_GRAYF32:
> +        {
> +            float *dnn_input_data = dnn_input->data;
> +            av_assert0(dnn_input->dt == DNN_FLOAT);
> +            for (int i = 0; i < frame->height; i++) {
> +                for(int j = 0; j < frame->width; j++) {
> +                    int k = i * frame->linesize[0] + j * sizeof(float);
> +                    int t = i * frame->width + j;
> +                    dnn_input_data[t] = *(float*)(frame->data[0] + k);
> +                }
> +            }
> +        }
> +        return 0;
> +    default:
> +        return AVERROR(EIO);
>      }
>
>      return 0;
> @@ -235,28 +290,58 @@ static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
>
>  static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData *dnn_output)
>  {
> -    // extend this function to support more formats
> -    av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
> -
> -    if (dnn_output->dt == DNN_FLOAT) {
> -        float *dnn_output_data = dnn_output->data;
> -        for (int i = 0; i < frame->height; i++) {
> -            for(int j = 0; j < frame->width * 3; j++) {
> -                int k = i * frame->linesize[0] + j;
> -                int t = i * frame->width * 3 + j;
> -                frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
> +    switch (frame->format) {
> +    case AV_PIX_FMT_RGB24:
> +    case AV_PIX_FMT_BGR24:
> +        if (dnn_output->dt == DNN_FLOAT) {
> +            float *dnn_output_data = dnn_output->data;
> +            for (int i = 0; i < frame->height; i++) {
> +                for(int j = 0; j < frame->width * 3; j++) {
> +                    int k = i * frame->linesize[0] + j;
> +                    int t = i * frame->width * 3 + j;
> +                    frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
> +                }
> +            }
> +        } else {
> +            uint8_t *dnn_output_data = dnn_output->data;
> +            av_assert0(dnn_output->dt == DNN_UINT8);
> +            for (int i = 0; i < frame->height; i++) {
> +                for(int j = 0; j < frame->width * 3; j++) {
> +                    int k = i * frame->linesize[0] + j;
> +                    int t = i * frame->width * 3 + j;
> +                    frame->data[0][k] = dnn_output_data[t];
> +                }
> +            }
> +        }
> +        return 0;
> +    case AV_PIX_FMT_GRAY8:
> +        {
> +            uint8_t *dnn_output_data = dnn_output->data;
> +            av_assert0(dnn_output->dt == DNN_UINT8);
> +            for (int i = 0; i < frame->height; i++) {
> +                for(int j = 0; j < frame->width; j++) {
> +                    int k = i * frame->linesize[0] + j;
> +                    int t = i * frame->width + j;
> +                    frame->data[0][k] = dnn_output_data[t];
> +                }
>              }
>          }
> -    } else {
> -        uint8_t *dnn_output_data = dnn_output->data;
> -        av_assert0(dnn_output->dt == DNN_UINT8);
> -        for (int i = 0; i < frame->height; i++) {
> -            for(int j = 0; j < frame->width * 3; j++) {
> -                int k = i * frame->linesize[0] + j;
> -                int t = i * frame->width * 3 + j;
> -                frame->data[0][k] = dnn_output_data[t];
> +        return 0;
> +    case AV_PIX_FMT_GRAYF32:
> +        {
> +            float *dnn_output_data = dnn_output->data;
> +            av_assert0(dnn_output->dt == DNN_FLOAT);
> +            for (int i = 0; i < frame->height; i++) {
> +                for(int j = 0; j < frame->width; j++) {
> +                    int k = i * frame->linesize[0] + j * sizeof(float);
> +                    int t = i * frame->width + j;
> +                    *(float*)(frame->data[0] + k) = dnn_output_data[t];
> +                }
>              }
>          }
> +        return 0;
> +    default:
> +        return AVERROR(EIO);
>      }
>
>      return 0;
> @@ -278,7 +363,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
>          av_frame_free(&in);
>          return AVERROR(EIO);
>      }
> -    av_assert0(ctx->output.channels == 3);
>
>      out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
>      if (!out) {
> --
> 2.7.4
>
LGTM,
pushed, thanks.

> _______________________________________________
> ffmpeg-devel mailing list
> ffmpeg-devel@ffmpeg.org
> https://ffmpeg.org/mailman/listinfo/ffmpeg-devel
>
> To unsubscribe, visit link above, or email
> ffmpeg-devel-request@ffmpeg.org with subject "unsubscribe".
diff mbox

Patch

diff --git a/doc/filters.texi b/doc/filters.texi
index f467378..57a129d 100644
--- a/doc/filters.texi
+++ b/doc/filters.texi
@@ -9075,6 +9075,12 @@  Halve the red channle of the frame with format rgb24:
 ffmpeg -i input.jpg -vf format=rgb24,dnn_processing=model=halve_first_channel.model:input=dnn_in:output=dnn_out:dnn_backend=native out.native.png
 @end example
 
+@item
+Halve the pixel values of the frame with format grayf32:
+@example
+ffmpeg -i input.jpg -vf format=grayf32,dnn_processing=model=halve_gray_float.model:input=dnn_in:output=dnn_out:dnn_backend=native -y out.native.png
+@end example
+
 @end itemize
 
 @section drawbox
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index 4a6b900..13273f2 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -104,12 +104,20 @@  static int query_formats(AVFilterContext *context)
 {
     static const enum AVPixelFormat pix_fmts[] = {
         AV_PIX_FMT_RGB24, AV_PIX_FMT_BGR24,
+        AV_PIX_FMT_GRAY8, AV_PIX_FMT_GRAYF32,
         AV_PIX_FMT_NONE
     };
     AVFilterFormats *fmts_list = ff_make_format_list(pix_fmts);
     return ff_set_common_formats(context, fmts_list);
 }
 
+#define LOG_FORMAT_CHANNEL_MISMATCH()                       \
+    av_log(ctx, AV_LOG_ERROR,                               \
+           "the frame's format %s does not match "          \
+           "the model input channel %d\n",                  \
+           av_get_pix_fmt_name(fmt),                        \
+           model_input->channels);
+
 static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLink *inlink)
 {
     AVFilterContext *ctx   = inlink->dst;
@@ -131,17 +139,34 @@  static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
     case AV_PIX_FMT_RGB24:
     case AV_PIX_FMT_BGR24:
         if (model_input->channels != 3) {
-            av_log(ctx, AV_LOG_ERROR, "the frame's input format %s does not match "
-                                       "the model input channels %d\n",
-                                       av_get_pix_fmt_name(fmt),
-                                       model_input->channels);
+            LOG_FORMAT_CHANNEL_MISMATCH();
             return AVERROR(EIO);
         }
         if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) {
             av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
             return AVERROR(EIO);
         }
-        break;
+        return 0;
+    case AV_PIX_FMT_GRAY8:
+        if (model_input->channels != 1) {
+            LOG_FORMAT_CHANNEL_MISMATCH();
+            return AVERROR(EIO);
+        }
+        if (model_input->dt != DNN_UINT8) {
+            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n");
+            return AVERROR(EIO);
+        }
+        return 0;
+    case AV_PIX_FMT_GRAYF32:
+        if (model_input->channels != 1) {
+            LOG_FORMAT_CHANNEL_MISMATCH();
+            return AVERROR(EIO);
+        }
+        if (model_input->dt != DNN_FLOAT) {
+            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n");
+            return AVERROR(EIO);
+        }
+        return 0;
     default:
         av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt));
         return AVERROR(EIO);
@@ -206,28 +231,58 @@  static int config_output(AVFilterLink *outlink)
 
 static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
 {
-    // extend this function to support more formats
-    av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
-
-    if (dnn_input->dt == DNN_FLOAT) {
-        float *dnn_input_data = dnn_input->data;
-        for (int i = 0; i < frame->height; i++) {
-            for(int j = 0; j < frame->width * 3; j++) {
-                int k = i * frame->linesize[0] + j;
-                int t = i * frame->width * 3 + j;
-                dnn_input_data[t] = frame->data[0][k] / 255.0f;
+    switch (frame->format) {
+    case AV_PIX_FMT_RGB24:
+    case AV_PIX_FMT_BGR24:
+        if (dnn_input->dt == DNN_FLOAT) {
+            float *dnn_input_data = dnn_input->data;
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width * 3; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width * 3 + j;
+                    dnn_input_data[t] = frame->data[0][k] / 255.0f;
+                }
+            }
+        } else {
+            uint8_t *dnn_input_data = dnn_input->data;
+            av_assert0(dnn_input->dt == DNN_UINT8);
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width * 3; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width * 3 + j;
+                    dnn_input_data[t] = frame->data[0][k];
+                }
             }
         }
-    } else {
-        uint8_t *dnn_input_data = dnn_input->data;
-        av_assert0(dnn_input->dt == DNN_UINT8);
-        for (int i = 0; i < frame->height; i++) {
-            for(int j = 0; j < frame->width * 3; j++) {
-                int k = i * frame->linesize[0] + j;
-                int t = i * frame->width * 3 + j;
-                dnn_input_data[t] = frame->data[0][k];
+        return 0;
+    case AV_PIX_FMT_GRAY8:
+        {
+            uint8_t *dnn_input_data = dnn_input->data;
+            av_assert0(dnn_input->dt == DNN_UINT8);
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width + j;
+                    dnn_input_data[t] = frame->data[0][k];
+                }
             }
         }
+        return 0;
+    case AV_PIX_FMT_GRAYF32:
+        {
+            float *dnn_input_data = dnn_input->data;
+            av_assert0(dnn_input->dt == DNN_FLOAT);
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width; j++) {
+                    int k = i * frame->linesize[0] + j * sizeof(float);
+                    int t = i * frame->width + j;
+                    dnn_input_data[t] = *(float*)(frame->data[0] + k);
+                }
+            }
+        }
+        return 0;
+    default:
+        return AVERROR(EIO);
     }
 
     return 0;
@@ -235,28 +290,58 @@  static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
 
 static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData *dnn_output)
 {
-    // extend this function to support more formats
-    av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
-
-    if (dnn_output->dt == DNN_FLOAT) {
-        float *dnn_output_data = dnn_output->data;
-        for (int i = 0; i < frame->height; i++) {
-            for(int j = 0; j < frame->width * 3; j++) {
-                int k = i * frame->linesize[0] + j;
-                int t = i * frame->width * 3 + j;
-                frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
+    switch (frame->format) {
+    case AV_PIX_FMT_RGB24:
+    case AV_PIX_FMT_BGR24:
+        if (dnn_output->dt == DNN_FLOAT) {
+            float *dnn_output_data = dnn_output->data;
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width * 3; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width * 3 + j;
+                    frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
+                }
+            }
+        } else {
+            uint8_t *dnn_output_data = dnn_output->data;
+            av_assert0(dnn_output->dt == DNN_UINT8);
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width * 3; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width * 3 + j;
+                    frame->data[0][k] = dnn_output_data[t];
+                }
+            }
+        }
+        return 0;
+    case AV_PIX_FMT_GRAY8:
+        {
+            uint8_t *dnn_output_data = dnn_output->data;
+            av_assert0(dnn_output->dt == DNN_UINT8);
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width + j;
+                    frame->data[0][k] = dnn_output_data[t];
+                }
             }
         }
-    } else {
-        uint8_t *dnn_output_data = dnn_output->data;
-        av_assert0(dnn_output->dt == DNN_UINT8);
-        for (int i = 0; i < frame->height; i++) {
-            for(int j = 0; j < frame->width * 3; j++) {
-                int k = i * frame->linesize[0] + j;
-                int t = i * frame->width * 3 + j;
-                frame->data[0][k] = dnn_output_data[t];
+        return 0;
+    case AV_PIX_FMT_GRAYF32:
+        {
+            float *dnn_output_data = dnn_output->data;
+            av_assert0(dnn_output->dt == DNN_FLOAT);
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width; j++) {
+                    int k = i * frame->linesize[0] + j * sizeof(float);
+                    int t = i * frame->width + j;
+                    *(float*)(frame->data[0] + k) = dnn_output_data[t];
+                }
             }
         }
+        return 0;
+    default:
+        return AVERROR(EIO);
     }
 
     return 0;
@@ -278,7 +363,6 @@  static int filter_frame(AVFilterLink *inlink, AVFrame *in)
         av_frame_free(&in);
         return AVERROR(EIO);
     }
-    av_assert0(ctx->output.channels == 3);
 
     out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     if (!out) {