-
Notifications
You must be signed in to change notification settings - Fork 25
/
extractOpenAIEmbeddings.m
71 lines (61 loc) · 2.59 KB
/
extractOpenAIEmbeddings.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
function [emb, response] = extractOpenAIEmbeddings(text, nvp)
% EXTRACTOPENAIEMBEDDINGS Generate text embeddings using the OpenAI API
%
% emb = EXTRACTOPENAIEMBEDDINGS(text) generates an embedding of the input
% TEXT using the OpenAI API. TEXT may be a character vector, a string
% array, or a cell array of character vectors.
%
% emb = EXTRACTOPENAIEMBEDDINGS(text,Name=Value) specifies additional
% options using one or more name-value pairs:
%
% 'ModelName' - The ID of the model to use. One of
% "text-embedding-ada-002" (default),
% "text-embedding-3-large", or
% "text-embedding-3-small".
%
% 'APIKey' - OpenAI API token. It can also be specified by
% setting the environment variable OPENAI_API_KEY
%
% 'TimeOut' - Connection Timeout in seconds (default: 10 secs)
%
% 'Dimensions' - Number of dimensions the resulting output
% embeddings should have. Not supported for
% "text-embedding-ada-002". Must not exceed 3072 for
% "text-embedding-3-large" or 1536 for
% "text-embedding-3-small".
%
% [emb, response] = EXTRACTOPENAIEMBEDDINGS(...) also returns the full
% response from the OpenAI API call. If the response body has no "data"
% field, emb is returned as [] and RESPONSE can be inspected for details.
%
% Copyright 2023-2024 The MathWorks, Inc.
arguments
    text (1,:) {mustBeNonzeroLengthText}
    nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ...
        "text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002"
    nvp.TimeOut (1,1) {mustBeNumeric,mustBeReal,mustBePositive} = 10
    nvp.Dimensions (1,1) {mustBeNumeric,mustBeInteger,mustBePositive}
    nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar}
end

END_POINT = "https://api.openai.com/v1/embeddings";

% Key comes from the APIKey option or the OPENAI_API_KEY environment variable.
key = llms.internal.getApiKeyFromNvpOrEnv(nvp,"OPENAI_API_KEY");

% Wrap TEXT in a cell so that struct() stores it as a single field value.
% Passing a cell array of character vectors directly would make struct()
% build a struct array (one element per cell), which would be sent as an
% array of request objects instead of one request with an input array.
parameters = struct("input",{text},"model",nvp.ModelName);

if isfield(nvp, "Dimensions")
    % The Dimensions option is rejected for text-embedding-ada-002.
    if nvp.ModelName=="text-embedding-ada-002"
        error("llms:invalidOptionForModel", ...
            llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Dimensions", nvp.ModelName));
    end
    mustBeCorrectDimensions(nvp.Dimensions,nvp.ModelName);
    parameters.dimensions = nvp.Dimensions;
end

response = llms.internal.sendRequestWrapper(parameters,key, END_POINT, nvp.TimeOut);

if isfield(response.Body.Data, "data")
    % Concatenate the embedding of every returned data element and transpose.
    % Presumably each embedding decodes as a column vector, giving one row
    % per input text - TODO confirm against the decoded response.
    emb = [response.Body.Data.data.embedding].';
else
    % No "data" field: return empty; the caller can examine RESPONSE.
    emb = [];
end
end
function mustBeCorrectDimensions(dimensions,modelName)
% MUSTBECORRECTDIMENSIONS Validate a requested embedding size for a model.
%
% mustBeCorrectDimensions(dimensions,modelName) errors if DIMENSIONS is
% not numeric. It raises "llms:dimensionsMustBeSmallerThan" if DIMENSIONS
% is greater than the maximum listed for MODELNAME.
%
% Only "text-embedding-3-large" (maximum 3072) and "text-embedding-3-small"
% (maximum 1536) are listed. A DIMENSIONS value equal to the maximum is
% accepted. Any other MODELNAME fails on the dictionary lookup.
model2dim = dictionary(["text-embedding-3-large", "text-embedding-3-small"], ...
    [3072,1536]);
mustBeNumeric(dimensions);
% Look up the model's limit once and reuse it for the check and the message.
maxDimensions = model2dim(modelName);
if dimensions > maxDimensions
    error("llms:dimensionsMustBeSmallerThan", ...
        llms.utils.errorMessageCatalog.getMessage("llms:dimensionsMustBeSmallerThan", ...
        string(maxDimensions)));
end
end