-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathSAM.h
59 lines (54 loc) · 1.44 KB
/
SAM.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <vector>
#include <string>
#include <memory>
#include <iostream>
#include "opencv2/core.hpp"
#include "common.h"
#include "onnxruntime_cxx_api.h"
using namespace std;
using namespace cv;
/// <summary>
/// Image Transform
/// Normalization and Resize
/// </summary>
class SAM_EXPORTS Transform
{
public:
    /// <summary>
    /// Construct a transform configured with a target size.
    /// Marked explicit so an int is never silently converted into a Transform.
    /// </summary>
    /// <param name="targetwidth">Target size used when resizing (stored in m_targetWidth,
    /// presumably the length of the longer image side after resize — confirm in the .cpp).</param>
    explicit Transform(int targetwidth);

    /// <summary>
    /// Apply normalization and resize to an image.
    /// </summary>
    /// <param name="orgimg">Source image. Taken by non-const reference; NOTE(review): check
    /// whether the implementation mutates it, and if not consider a const overload.</param>
    /// <returns>The transformed image.</returns>
    cv::Mat TransformImage(cv::Mat& orgimg);

protected:
    /// <summary>
    /// Get Transformed Image Size
    /// </summary>
    /// <param name="oldw">Original image width.</param>
    /// <param name="oldh">Original image height.</param>
    /// <param name="long_side_length">Target length, presumably for the longer side — confirm.</param>
    /// <param name="neww">Output: transformed width.</param>
    /// <param name="newh">Output: transformed height.</param>
    void GetPreprocessShape(int oldw, int oldh, int long_side_length, int& neww, int& newh);

    // Default-initialized so the object never holds an indeterminate value;
    // the constructor is expected to overwrite it with targetwidth.
    int m_targetWidth = 0;
};
/// <summary>
/// Segment Anything C++ Inference
/// Owns an ONNX Runtime environment plus an image-encoder and a decoder session.
/// </summary>
class SAM_EXPORTS SAM
{
public:
    /// <summary>
    /// Construct the inference object.
    /// Marked explicit so an int is never silently converted into a SAM.
    /// </summary>
    /// <param name="targetsize">Target input size (stored in m_targetSize; presumably
    /// forwarded to the Transform — confirm in the .cpp).</param>
    explicit SAM(int targetsize);

    /// <summary>
    /// Image Encoding
    /// </summary>
    /// <param name="imgpath">Path of the image to encode. NOTE(review): presumably the
    /// result is cached in m_ImgEmbedding / m_ImgEmbeddingshape — confirm.</param>
    void ImageEncode(string imgpath);

    /// <summary>
    /// Run the decoder for a set of prompts.
    /// </summary>
    /// <param name="promotions">Prompt values (NOTE(review): presumably point coordinates).</param>
    /// <param name="labels">Per-prompt labels.</param>
    /// <param name="promotionCount">Number of prompts supplied.</param>
    void Decoder(std::vector<float>promotions,std::vector<float>labels,int promotionCount);

protected:
    // Members are destroyed in reverse declaration order. The ONNX Runtime
    // environment must outlive every session created from it, so the Env and
    // the session options are declared BEFORE the sessions that depend on them
    // (previously m_env was declared after the sessions and was destroyed first).
    std::unique_ptr<Ort::Env> m_env;
    std::unique_ptr<Ort::SessionOptions> m_sessionOption;
    std::unique_ptr<Ort::Session> m_Encoder; // Image Encoder
    std::unique_ptr<Ort::Session> m_Decoder; // Image Decoder
    std::unique_ptr<Transform> m_Transform;

    // NOTE(review): unique_ptr<float> releases with `delete`, not `delete[]`.
    // If these buffers are allocated with `new float[n]`, they must be
    // std::unique_ptr<float[]> to avoid undefined behavior — confirm in the .cpp.
    std::unique_ptr<float> m_ImgEmbedding;
    std::vector<int64_t> m_ImgEmbeddingshape;
    std::unique_ptr<float> m_resultMask;

    // Default-initialized so these never hold indeterminate values before use.
    int m_orgWid = 0;
    int m_orgHei = 0;
    int m_targetSize = 0;

    /// <summary>
    /// Load ONNX PreTrained Models
    /// </summary>
    void LoadOnnxModel();
};