iOS相比Android移植TensorFlow没那么方便,要用C++来编写,接下来讲一下iOS调用TensorFlow的过程。
- 引入依赖
在Podfile中加入pod 'TensorFlow-experimental',再在terminal中cd进项目目录输入pod install即可安装依赖。
- 复制PB文件
快速开发的话直接把PB文件放在data文件夹里就行,如果正式上线的时候觉得PB文件一起打包较大的话可以放在服务器,打开APP的时候提示下载再复制进去就好。
- 引入头文件、命名空间
#import <opencv2/imgcodecs/ios.h>#include "tensorflow/cc/ops/const_op.h"#include "tensorflow/core/framework/op_kernel.h"#include "tensorflow/core/public/session.h"#include <tensorflow/core/kernels/reshape_op.h>#include <tensorflow/core/kernels/argmax_op.h>using namespace tensorflow;using namespace tensorflow::core;
- 处理数据
图像处理相比于Android的bitmap操作还是较为麻烦,iOS需要用到opencv,所以还需要引入opencv的依赖,通过cv的UIImageToMat方法吧UIImage转成cv::Mat再进行矩阵操作(类似:灰度处理、归一化、平展)
UIImage *image = [UIImage imageNamed:@"OOLU8095571.jpg"]; self.preImageView.contentMode = UIViewContentModeRedraw; UIImageToMat(image,cvMatImage); cvMatImage.convertTo(cvMatImage, CV_32F, 1.0/255., 0);//归一化 cv::Mat reshapeMat= cvMatImage.reshape(0,1);//reshape NSString* inference_result = RunModel(reshapeMat); self.urlContentTextView.text = inference_result;
RunModel(reshapeMat)就是把处理过的数据传递给TensorFlow去运算了。
- 定义常量
这里跟Android差不多,定义一些必要的常量,输入输出节点,输出输出节点数据,图像尺寸、通道等
std::string input_layer = "inputs/X"; std::string output_layer = "output/predict"; tensorflow::Tensor x( tensorflow::DT_FLOAT, tensorflow::TensorShape({wanted_height*wanted_width})); std::vector<tensorflow::Tensor> outputs; const int wanted_width = 256; const int wanted_height = 64; const int wanted_channels = 1;
- 创建session
这里跟Android不同,需要手动创建session
tensorflow::SessionOptions options; tensorflow::Session* session_pointer = nullptr; tensorflow::Status session_status = tensorflow::NewSession(options, &session_pointer); if (!session_status.ok()) { std::string status_string = session_status.ToString(); return [NSString stringWithFormat: @"Session create failed - %s", status_string.c_str()]; } std::unique_ptr<tensorflow::Session> session(session_pointer);
- 载入graph
tensorflow::GraphDef tensorflow_graph; NSString* network_path = FilePathForResourceName(@"rounded_graph", @"pb"); PortableReadFileToProto([network_path UTF8String], &tensorflow_graph); tensorflow::Status s = session->Create(tensorflow_graph); if (!s.ok()) { LOG(ERROR) << "Could not create TensorFlow Graph: " << s; return @""; }
其中FilePathForResourceName是返回graph的地址
NSString* FilePathForResourceName(NSString* name, NSString* extension) { NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; if (file_path == NULL) { LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] << "' in bundle."; } return file_path;}
PortableReadFileToProto是把graph读到内存中并赋值给tensorflow_graph,并使用session->Create(tensorflow_graph)把graph载入到session中。
- 输入数据的类型转换
输入到TensorFlow的数据不能是mat类型的所以进行mat转vector操作
vector<float> Vmat; Vmat.assign ( ( float* )ImageMat.datastart, ( float* )ImageMat.dataend ); auto dst = x.flat<float>().data(); auto img = Vmat; std::copy_n(img.begin(), wanted_width*wanted_height, dst);
- run session
tensorflow::Status run_status = session->Run({{input_layer, x}}, {output_layer}, {}, &outputs); if (!run_status.ok()) { LOG(ERROR) << "Running model failed: " << run_status; tensorflow::LogAllRegisteredKernels(); result = @"Error running model"; return result; }
- 数据变换
使用operator方法获取到tensor中的每一个元素值,重新赋值给array。
auto outputMatrix = outputs[0].flat<int64>(); array<long,11> outputArray; for(int i=0;i<11;i++){ outputArray[i]=outputMatrix.operator()(i); }NSString *predictionstr = vec2text(outputArray);
获取完之后需要对数据进行处理,比如我们做的vector转text。
NSString* vec2text(array<long,11> outputArray) { std::stringstream ss; ss.precision(12); ss <<"Prediction:"; for(int i=0;i<11;i++){ long char_idx=outputArray[i]; long char_code = 0; if (char_idx<10){ char_code = char_idx + int('0'); } else if (char_idx<36){ char_code = char_idx-10 + int('A'); } else if (char_idx<62){ char_code = char_idx + int('a'); } ss << char(char_code); } tensorflow::string predictions = ss.str(); NSString* result = [NSString stringWithFormat: @"%s", predictions.c_str()]; return result;}
iOS调用TensorFlow的基础运用就这样,高级用法可以使用MemoryMappedModel,这种方法会比较节省内存,更加优雅。
原著是一个有趣的人,若有侵权,请通知删除
还没有人抢沙发呢~