-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy path: bulkelab.lean
165 lines (156 loc) · 6.95 KB
/
bulkelab.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import Lean.Meta
import LeanCodePrompts
import LeanCodePrompts.BatchTranslate
import LeanAide.Config
import Cli
open Lean Cli LeanAide.Meta LeanAide Translator Translate
set_option maxHeartbeats 10000000
set_option maxRecDepth 1000
set_option compiler.extract_closed false
/--
Entry point for the `bulkelab` command.

Reads the command-line options in `p` and assembles:
- the prompt-example builder;
- the chat server and its parameters;
- optionally, cached query responses from the `--query_data` jsonlines file.

It then imports the elaboration environment and runs
`Translator.checkTranslatedThmsM` on the positional `input`, using the
`docString`, `description` and `concise-description` embedding data loaded via
`withUnpickle`.

The resulting JSON is written to `--output` if given. Otherwise it goes to
`results/<model>/<input>-elab-<signature>-<n>-<temp>.json`, with an extra
`<gitHash>` directory after `<model>` when `--tag` is set.

Returns `0` on success and `1` if the translation run raised an exception.
Throws an `IO.userError` if the query-data file contains an unparsable line or
an entry without a doc-string key or `choices` field.
-/
unsafe def runBulkElab (p : Parsed) : IO UInt32 := do
  initSearchPath (← Lean.findSysroot) initFiles
  -- Positional input name; also used as the stem of the default output file name.
  let input_file :=
    p.positionalArg? "input" |>.map (fun s => s.as! String) |>.getD "thm"
  -- Number of embedding-based prompt examples of each kind.
  let numSim := p.flag? "prompts" |>.map (fun s => s.as! Nat) |>.getD 10
  let numConcise :=
    p.flag? "concise_descriptions" |>.map (fun s => s.as! Nat) |>.getD 2
  let numDesc := p.flag? "descriptions" |>.map (fun s => s.as! Nat) |>.getD 2
  let pb₁ :=
    PromptExampleBuilder.embedBuilder numSim numConcise numDesc |>.simplify
  -- Number of search-engine-based prompt examples.
  let numLeanSearch :=
    p.flag? "leansearch_prompts" |>.map (fun s => s.as! Nat) |>.getD 0
  let numMoogle := p.flag? "moogle_prompts" |>.map (fun s => s.as! Nat) |>.getD 2
  let pb₂ :=
    PromptExampleBuilder.searchBuilder numLeanSearch numMoogle |>.simplify
  let includeFixed := p.hasFlag "include_fixed"
  let pb :=
    if includeFixed then pb₁ ++ pb₂ ++ PromptExampleBuilder.proofNetPromptBuilder
    else pb₁ ++ pb₂
  let queryNum := p.flag? "responses" |>.map (fun s => s.as! Nat) |>.getD 5
  -- The CLI takes the temperature scaled by 10. `JsonNumber` is mantissa · 10^(-exponent),
  -- so ⟨temp10, 1⟩ is temp10 / 10 (e.g. 8 ↦ 0.8).
  let temp10 := p.flag? "temperature" |>.map (fun s => s.as! Nat) |>.getD 8
  let temp : JsonNumber := ⟨temp10, 1⟩
  let model := p.flag? "model" |>.map (fun s => s.as! String) |>.getD "gpt-4o"
  let delay := p.flag? "delay" |>.map (fun s => s.as! Nat) |>.getD 20
  let repeats := p.flag? "repeats" |>.map (fun s => s.as! Nat) |>.getD 0
  let maxTokens := p.flag? "max_tokens" |>.map (fun s => s.as! Nat) |>.getD 1600
  let azure := p.hasFlag "azure"
  let tag := p.hasFlag "tag"
  let roundtrip := p.hasFlag "roundtrip"
  let url? := p.flag? "url" |>.map (fun s => s.as! String)
  let sysLess := p.hasFlag "no_sysprompt"
  -- `--azure` takes precedence over `--url`; with neither, the OpenAI server is used.
  let chatServer :=
    if azure then ChatServer.azure (model := model)
    else
      match url? with
      | some url => ChatServer.generic model url !sysLess
      | none => ChatServer.openAI model
  let chatParams : ChatParams :=
    let params : ChatParams := {temp := temp, n := queryNum, maxTokens := maxTokens}
    params.withoutStop (p.hasFlag "no_stop")
  -- Optional cache of earlier responses: one JSON object per line, keyed by its
  -- doc string, with the `choices` field as the cached value.
  let queryData? : Option (Std.HashMap String Json) ←
    p.flag? "query_data" |>.map (fun s => s.as! String) |>.mapM
      fun filename => do
        let lines ← IO.FS.lines filename
        let mut qdMap : Std.HashMap String Json := Std.HashMap.empty
        for l in lines do
          match Json.parse l with
          | Except.ok json =>
            -- Accept both the `docString` and `doc_string` spellings of the key.
            let doc? := (json.getObjValAs? String "docString").toOption.orElse
              (fun _ => (json.getObjValAs? String "doc_string").toOption)
            -- `Option.get!` would only panic and continue with a default value,
            -- silently inserting a bogus entry; fail loudly instead.
            let some doc := doc? |
              throw <| IO.userError
                s!"Query data entry has no `docString` or `doc_string` field: {l}"
            let some out := (json.getObjValAs? Json "choices").toOption |
              throw <| IO.userError s!"Query data entry has no `choices` field: {l}"
            qdMap := qdMap.insert doc out
            IO.println doc
          | Except.error e =>
            throw <| IO.userError s!"Error parsing query data file: {e}"
        pure qdMap
  let gitHash ← gitHash
  let dir :=
    if tag then System.mkFilePath ["results", model, gitHash]
    else System.mkFilePath ["results", model]
  if !(← dir.pathExists) then
    IO.FS.createDirAll dir
  let defaultName :=
    s!"{input_file}-elab-{pb.signature}-{chatParams.n}-{chatParams.temp.mantissa}.json"
  let outFile : System.FilePath :=
    match p.flag? "output" with
    | some f => System.FilePath.mk (f.as! String)
    | none => dir / defaultName
  -- Ensure the output location exists *before* the long model-querying run,
  -- so a bad `--output` directory does not waste the whole run.
  if let some parent := outFile.parent then
    IO.FS.createDirAll parent
  let env ←
    importModules #[{module := `Mathlib},
      {module := `LeanAide.TheoremElab},
      {module := `LeanCodePrompts.Translate}] {}
  withUnpickle (← picklePath "docString") <| fun (docStringData : EmbedData) => do
    withUnpickle (← picklePath "description") <| fun (descData : EmbedData) => do
      withUnpickle (← picklePath "concise-description") <| fun (concDescData : EmbedData) => do
        IO.eprintln "Loading hashmap"
        let dataMap : EmbedMap :=
          Std.HashMap.ofList [("docString", docStringData), ("description", descData),
            ("concise-description", concDescData)]
        IO.eprintln "Loaded hashmap"
        let translator : Translator :=
          {pb := pb, server := chatServer, params := chatParams, roundTrip := roundtrip}
        let core :=
          translator.checkTranslatedThmsM input_file delay repeats queryData? tag
            |>.runWithEmbeddings dataMap
        let io? :=
          core.run' {fileName := "", fileMap := {source := "", positions := #[]},
            maxHeartbeats := 100000000000, maxRecDepth := 1000000} {env := env}
        match ← io?.toIO' with
        | Except.ok js =>
          IO.println "Success"
          IO.FS.writeFile outFile js.pretty
          IO.println s!"Written to file {outFile}"
          return 0
        | Except.error e =>
          -- Report failure on stderr and through the exit code, so scripts can detect it.
          IO.eprintln "Ran with error"
          let msg ← e.toMessageData.toString
          IO.eprintln msg
          return 1
/--
CLI specification for the `bulkelab` command, dispatched to `runBulkElab`.
The defaults quoted in the help strings match the `getD` fallbacks used in
`runBulkElab`.
-/
unsafe def bulkElab : Cmd := `[Cli|
  bulkelab VIA runBulkElab;
  "Elaborate a set of inputs and report whether successful and the result if successful."

  FLAGS:
    include_fixed; "Include the 'Lean Chat' fixed prompts."
    o, output : String; "Output file (default `results/{model}/{input}-elab-{signature}-{responses}-{temperature}.json`, with a git-hash directory after `{model}` when `--tag` is set)."
    roundtrip; "Translate back to natural language and compare."
    -- NOTE(review): `prompt_examples` is not read by `runBulkElab`; kept so existing invocations still parse.
    prompt_examples : String; "Example prompts in Json"
    p, prompts : Nat; "Number of example prompts (default 10)."
    concise_descriptions : Nat; "Number of example concise descriptions (default 2)."
    descriptions : Nat; "Number of example descriptions (default 2)."
    leansearch_prompts : Nat; "Number of examples from LeanSearch (default 0)."
    moogle_prompts : Nat; "Number of examples from Moogle (default 2)."
    r, responses : Nat; "Number of responses to ask for (default 5)."
    t, temperature : Nat; "Scaled temperature `t*10` for temperature `t` (default 8)."
    m, model : String; "Model to be used (default `gpt-4o`)"
    d, delay : Nat; "Delay between requests in seconds (default 20)."
    query_data : String; "Query data jsonlines file if cached queries are to be used; should have the result of the 'choices' field of the output and a 'docString' field for the query."
    repeats : Nat; "Number of times to repeat the request (default 0)."
    azure; "Use Azure instead of OpenAI."
    url : String; "URL to query (for a local server)."
    tag; "Include the git hash in the results filepath"
    no_stop; "Don't use `:=` as a stop token."
    max_tokens : Nat; "Maximum tokens to use in the translation (default 1600)."
    no_sysprompt; "The model has no system prompt (not relevant for GPT models)."

  ARGS:
    input : String; "The input file in the `data` folder."
]
/--
Program entry point. Hands the raw command-line `args` to the `bulkelab` CLI
specification, which parses and validates them and runs `runBulkElab`. The
resulting exit code is passed through.
-/
unsafe def main (args : List String) : IO UInt32 := do
  let exitCode ← bulkElab.validate args
  pure exitCode