diff --git a/Haruspex.java b/Haruspex.java
index 2fc03dc..25a9da6 100644
--- a/Haruspex.java
+++ b/Haruspex.java
@@ -41,6 +41,10 @@
 import java.util.ArrayList;
 import java.io.FileWriter;
 import java.io.PrintWriter;
+import java.io.File;
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
 
 import ghidra.app.script.GhidraScript;
 import ghidra.program.model.symbol.*;
@@ -48,92 +52,191 @@
 import ghidra.app.decompiler.DecompInterface;
 import ghidra.app.decompiler.DecompileOptions;
 import ghidra.app.decompiler.DecompileResults;
+import ghidra.util.exception.CancelledException;
 
-public class Haruspex extends GhidraScript
-{
-    List functions;
-    DecompileOptions options;
-    DecompInterface decomp;
-    String outputPath = "/tmp/haruspex.out";
-    static int TIMEOUT = 60;
-
-    @Override
-    public void run() throws Exception
-    {
-        printf("\nHaruspex.java - Extract Ghidra decompiler's pseudo-code\n");
-        printf("Copyright (c) 2022 Marco Ivaldi \n\n");
-
-        // ask for output directory path
-        try {
-            outputPath = askString("Output directory path", "Enter the path of the output directory:");
-        } catch (Exception e) {
-            printf("Output directory not supplied, using default \"%s\".\n", outputPath);
-        }
-
-        // get all functions
-        functions = new ArrayList<>();
-        getAllFunctions();
-
-        // extract pseudo-code of all functions (using default options)
-        decomp = new DecompInterface();
-        options = new DecompileOptions();
-        decomp.setOptions(options);
-        decomp.toggleCCode(true);
-        decomp.toggleSyntaxTree(true);
-        decomp.setSimplificationStyle("decompile");
-        if (!decomp.openProgram(currentProgram)) {
-            printf("Could not initialize the decompiler, exiting.\n\n");
-            return;
-        }
-        printf("Extracting pseudo-code from %d functions...\n\n", functions.size());
-        functions.forEach(f -> extractPseudoCode(f));
-    }
-
-    // collect all Function objects into a global ArrayList
-    public void getAllFunctions()
-    {
-        SymbolTable st = currentProgram.getSymbolTable();
-        SymbolIterator si = st.getSymbolIterator();
-
-        while (si.hasNext()) {
-            Symbol s = si.next();
-            if ( (s.getSymbolType() == SymbolType.FUNCTION) && (!s.isExternal()) ) {
-                Function fun = getFunctionAt(s.getAddress());
-                if (!fun.isThunk()) {
-                    functions.add(fun);
-                }
-            }
-        }
-    }
-
-    // extract the pseudo-code of a function
-    // @param func target function
-    public void extractPseudoCode(Function func)
-    {
-        DecompileResults res = decomp.decompileFunction(func, TIMEOUT, monitor);
-        if(res.getDecompiledFunction() != null){
-            saveToFile(outputPath, func.getName() + "@" + func.getEntryPoint() + ".c", res.getDecompiledFunction().getC());
-        }
-        else{
-            printf("Can't decompile %s\n\n", func.getName());
-        }
-    }
-
-    // save results to file
-    // @param path name of the output directory
-    // @param name name of the output file
-    // @param output content to save to file
-    public void saveToFile(String path, String name, String output)
-    {
-        try {
-            FileWriter fw = new FileWriter(path + "/" + name);
-            PrintWriter pw = new PrintWriter(fw);
-            pw.write(output);
-            pw.close();
-
-        } catch (Exception e) {
-            printf("Cannot write to output file \"%s\".\n\n", path + "/" + name);
-            return;
-        }
-    }
-}
+public class Haruspex extends GhidraScript {
+    List<Function> functions;
+    DecompileOptions options;
+    DecompInterface decomp;
+    String outputPath;
+    String functionsFilePath;
+    static int TIMEOUT = 60; // timeout value passed to decompileFunction()
+    static int BATCHSIZE = 25; // functions handed to batchExtractPseudoCode() per call
+
+    @Override
+    public void run() throws Exception {
+        printf("\nHaruspex.java - Extract Ghidra decompiler's pseudo-code\n");
+        printf("Copyright (c) 2022 Marco Ivaldi \n\n");
+
+        // ask for the output directory path; cancelling this prompt aborts the script
+        try {
+            outputPath = askString("Output directory path", "Enter the path of the output directory:");
+        } catch (CancelledException e) {
+            printf("Script execution canceled by the user.\n");
+            return;
+        }
+
+        // ask for the optional function names file; cancelling this prompt selects all functions
+        try {
+            functionsFilePath = askString("Function Names File Path", "Enter the path of the function names file:");
+        } catch (CancelledException e) {
+            functionsFilePath = null;
+        }
+
+        // create the output directory if it doesn't exist
+        if (!createOutputDirectory()) {
+            return;
+        }
+        printf("Output directory: \"%s\".\n", outputPath);
+
+        // collect the functions named in the file, or all functions if no file was given
+        functions = new ArrayList<>();
+        if (functionsFilePath != null && !functionsFilePath.isEmpty()) {
+            printf("Reading function names from: \"%s\".\n", functionsFilePath);
+            try {
+                getFunctionsByName(readFunctionNames(functionsFilePath));
+            } catch (IOException e) {
+                printf("Error reading function names file: %s\n", e.getMessage());
+                return;
+            }
+        } else {
+            getAllFunctions();
+        }
+
+        // extract pseudo-code of the collected functions (using default options)
+        extractPseudoCode();
+    }
+
+    // Creates the output directory (and any missing parents) if needed
+    // @return true if outputPath is a usable directory, false otherwise
+    private boolean createOutputDirectory() {
+        File outputDir = new File(outputPath);
+        try {
+            // mkdirs() reports failure through its return value, so check it
+            if (outputDir.isDirectory() || outputDir.mkdirs()) {
+                return true;
+            }
+            printf("Error creating output directory \"%s\".\n", outputPath);
+        } catch (SecurityException e) {
+            printf("Error creating output directory: %s\n", e.getMessage());
+        }
+        return false;
+    }
+
+    // Reads function names from a file, one name per line
+    // @param filePath path of the file containing the function names
+    // @return the trimmed, non-empty lines of the file
+    // @throws IOException if the file cannot be opened or read
+    private List<String> readFunctionNames(String filePath) throws IOException {
+        List<String> names = new ArrayList<>();
+        try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
+            String line;
+            while ((line = reader.readLine()) != null) {
+                String name = line.trim();
+                // blank lines cannot name a function
+                if (!name.isEmpty()) {
+                    names.add(name);
+                }
+            }
+        }
+        return names;
+    }
+
+    // Collects the Function objects with the given names into the global ArrayList
+    // @param functionNames names of the functions to look up in the symbol table
+    private void getFunctionsByName(List<String> functionNames) {
+        SymbolTable st = currentProgram.getSymbolTable();
+        for (String functionName : functionNames) {
+            int before = functions.size();
+            SymbolIterator si = st.getSymbols(functionName);
+            while (si.hasNext()) {
+                addFunction(si.next());
+            }
+            // tell the user about names that selected nothing (e.g. typos)
+            if (functions.size() == before) {
+                printf("No extractable function named \"%s\" found, skipping.\n", functionName);
+            }
+        }
+    }
+
+    // Collects all Function objects into the global ArrayList
+    public void getAllFunctions() {
+        SymbolIterator si = currentProgram.getSymbolTable().getSymbolIterator();
+        while (si.hasNext()) {
+            addFunction(si.next());
+        }
+    }
+
+    // Adds the function behind a symbol to the global ArrayList, skipping
+    // non-function symbols, external symbols and thunks
+    // @param s candidate symbol
+    private void addFunction(Symbol s) {
+        if (s.getSymbolType() != SymbolType.FUNCTION || s.isExternal()) {
+            return;
+        }
+        Function fun = getFunctionAt(s.getAddress());
+        // getFunctionAt() may return null, so guard before dereferencing
+        if (fun != null && !fun.isThunk()) {
+            functions.add(fun);
+        }
+    }
+
+    // Initializes the decompiler and extracts the pseudo-code of the collected functions
+    private void extractPseudoCode() {
+        decomp = new DecompInterface();
+        try {
+            options = new DecompileOptions();
+            decomp.setOptions(options);
+            decomp.toggleCCode(true);
+            decomp.toggleSyntaxTree(true);
+            decomp.setSimplificationStyle("decompile");
+
+            if (!decomp.openProgram(currentProgram)) {
+                printf("Could not initialize the decompiler, exiting.\n\n");
+                return;
+            }
+
+            printf("Extracting pseudo-code from %d functions...\n\n", functions.size());
+
+            // process the functions in batches of BATCHSIZE, stopping if the user cancels
+            for (int i = 0; i < functions.size() && !monitor.isCancelled(); i += BATCHSIZE) {
+                List<Function> batch = functions.subList(i, Math.min(i + BATCHSIZE, functions.size()));
+                batchExtractPseudoCode(batch);
+            }
+        } finally {
+            // release the decompiler even when extraction fails or is cancelled
+            decomp.dispose();
+        }
+    }
+
+    // extract the pseudo-code of a batch of functions
+    // @param batch target functions
+    public void batchExtractPseudoCode(List<Function> batch) {
+        for (Function func : batch) {
+            if (monitor.isCancelled()) {
+                return;
+            }
+            DecompileResults res = decomp.decompileFunction(func, TIMEOUT, monitor);
+            if (res.getDecompiledFunction() != null) {
+                saveToFile(outputPath, func.getName() + "@" + func.getEntryPoint() + ".c", res.getDecompiledFunction().getC());
+            } else {
+                printf("Can't decompile %s\n\n", func.getName());
+            }
+        }
+    }
+
+    // save results to file
+    // @param path name of the output directory
+    // @param name name of the output file
+    // @param output content to save to file
+    public void saveToFile(String path, String name, String output) {
+        // function names may contain path separators (e.g. "operator/"), which cannot appear in a file name
+        String fullPath = path + File.separator + name.replace('/', '_').replace(File.separatorChar, '_');
+        // try-with-resources closes the writer even if the write fails
+        try (FileWriter fw = new FileWriter(fullPath)) {
+            fw.write(output);
+        } catch (IOException e) {
+            printf("Error writing to output file \"%s\". %s\n\n", fullPath, e.getMessage());
+        }
+    }
+}