diff --git a/.github/workflows/ocaml-test.yaml b/.github/workflows/ocaml-test.yaml new file mode 100644 index 0000000..e8d304a --- /dev/null +++ b/.github/workflows/ocaml-test.yaml @@ -0,0 +1,26 @@ +name: OCaml Dune Test + +on: + push: + pull_request: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up OCaml + uses: ocaml/setup-ocaml@v2 + with: + ocaml-compiler: 5.2.0 + + - name: Install dependencies + run: opam install -y dune ounit2 qcheck bisect_ppx base64 csv + + - name: Build and test + run: opam exec -- dune build && opam exec -- dune test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fb26c7c --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +_build +gitlog.txt +sqaml.zip +_coverage/* +lib/storage/* diff --git a/.ocamlformat b/.ocamlformat new file mode 100644 index 0000000..e69de29 diff --git a/INSTALL.md b/INSTALL.md new file mode 100644 index 0000000..57a430b --- /dev/null +++ b/INSTALL.md @@ -0,0 +1,144 @@ +# SQamL installation instructions + +Welcome to SQamL, an OCaml-based mini SQL database! This document will guide you through the installation process of our software. + +## Prerequisites + +Before installing SQamL, you need to have the following software installed on your machine: + +- [OCaml](https://ocaml.org/docs/install.html) +- [Dune](https://dune.build/) +- [Git](https://git-scm.com/) + +For instructions on how to install and setup the above, please refer to the links provided. + +## Installation + +SQamL is available [on GitHub here](https://github.com/Destaq/sqaml/). Please clone it to your local machine by running the following command: + +```bash +git clone https://github.com/Destaq/sqaml.git +``` + +After cloning the repository, navigate to the project directory: + +```bash +cd sqaml +``` + +> **Special MS2 Note:** we have been developing on a separate branch for the work-in-progress deliverable. Please switch to this branch before proceeding by executing `git checkout backend-wip`. + +You can now build the project using Dune: + +```bash +dune build +``` + +Finally, you can run `sqaml` by executing the following command: + +```bash +dune exec sqaml +``` + +In the command-line, that should generate the following output: + +```text + _oo\ + (__/ \ _ _ + \ \/ \/ \ + ( )\ + \_______/ \ + [[] [[]] + [[] [[]] +Welcome to the SQAMLVerse! +Enter an SQL command (or 'exit;' to quit): +``` + +## Documentation + +_TODO: update with new SQL commands and examples._ + +We currently support the following SQL commands using regular MySQL syntax. + +- `CREATE TABLE` +- `INSERT INTO` +- `SELECT *` +- `SHOW TABLES` +- `DELETE FROM` +- `UPDATE` +- `DROP TABLE` + +Supported data types are: + +- `INT` +- `VARCHAR` + +(and `PRIMARY KEY` which is not technically a type) + +## Functionality demonstration + +Please note that all SQL commands must be terminated with a semicolon (`;`). Additionally, **equality conditions (i.e. `WHERE x = y` clauses) are only supported with a space around the `=`**. + +```text + _oo\ + (__/ \ _ _ + \ \/ \/ \ + ( )\ + \_______/ \ + [[] [[]] + [[] [[]] +Welcome to the SQAMLVerse! +Enter an SQL command (or 'exit;' to quit): CREATE TABLE users (id int primary key, name varchar); +id: int +name: varchar + +Enter an SQL command (or 'exit;' to quit): CREATE TABLE users (id int primary key, name varchar, age int); +Error: Table already exists + +Enter an SQL command (or 'exit;' to quit): INSERT INTO users (id, name) VALUES (1, 'Simon'); +id: int +name: varchar +1 'Simon' + +Enter an SQL command (or 'exit;' to quit): INSERT INTO users (id, name) VALUES (2, 'Alex'); +id: int +name: varchar +2 'Alex' +1 'Simon' + +Enter an SQL command (or 'exit;' to quit): SELECT * FROM users; +2 'Alex' +1 'Simon' + +Enter an SQL command (or 'exit;' to quit): DELETE FROM users WHERE name = 'Alex'; + +Enter an SQL command (or 'exit;' to quit): UPDATE users SET name = 'Clarkson' WHERE id = 1; + +Enter an SQL command (or 'exit;' to quit): SELECT * FROM users; +1 'Clarkson' + +Enter an SQL command (or 'exit;' to quit): SHOW TABLES; +Tables: +users + +Enter an SQL command (or 'exit;' to quit): DROP TABLE users; + +Enter an SQL command (or 'exit;' to quit): SHOW TABLES; +No tables in database. +``` + +## Tests + +To run test, you can run `make bisect` from the root directory. This will run the tests and generate a coverage report in the `_coverage` directory, which is automatically opened. + +You can also choose to directly run the tests through `dune test`, also from the root directory. + +```bash +dune test +``` + +Please note that our tests require a Unix-like environment to run. + +## Addendum + +Please note that this is a work-in-progress and so there may still be some bugs and missing features. We are actively working on making SQamL the best it can be, though, and anticipate a full release soon. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..ff57337 --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +bisect: + find . -name '*.coverage' | xargs rm -f + OUNIT_CI=true dune test --instrument-with bisect_ppx --force + bisect-ppx-report html + +open-bisect: + find . -name '*.coverage' | xargs rm -f + OUNIT_CI=true dune test --instrument-with bisect_ppx --force + bisect-ppx-report html + open ./_coverage/index.html + +clean: + rm -rf _coverage + dune clean + +cloc: + cloc --by-file --include-lang=OCaml . --exclude-dir=_build,.git diff --git a/README.md b/README.md index 5990adc..d85b06f 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # sqaml +Please see `INSTALL.md` for instructions on how to build, run, and use the project. + ## Collaborators 1. Simon Ilincev (sci24) diff --git a/bin/dune b/bin/dune new file mode 100644 index 0000000..9c12d0c --- /dev/null +++ b/bin/dune @@ -0,0 +1,4 @@ +(executable + (public_name sqaml) + (name main) + (libraries sqaml csv base64)) diff --git a/bin/main.ml b/bin/main.ml new file mode 100644 index 0000000..10f8618 --- /dev/null +++ b/bin/main.ml @@ -0,0 +1,45 @@ +open Sqaml.Parser +open Sqaml.Storage + +(**Main program driver to retrieve user input and call backend commands accordingly.*) +let rec main_loop () = + print_string "Enter an SQL command (or 'exit;' to quit): "; + let rec read_lines acc = + let line = read_line () in + if String.contains line ';' then + String.sub line 0 (String.index line ';') :: acc + else read_lines (line :: acc) + in + let query = String.concat " " (List.rev (read_lines [])) in + match query with + | "exit" -> sync_on_exit () + | _ -> ( + try + parse_and_execute_query query; + main_loop () + with Failure msg -> + print_endline ("Error: " ^ msg); + main_loop ()) + +(**Generate the Camel start screen for SQaml (note: this is the most important part of the project).*) +let () = + let orange = "\027[38;5;208m" in + let reset = "\027[0m" in + + let ascii_art = + orange + ^ " _oo\\\n\ + \ (__/ \\ _ _\n\ + \ \\ \\/ \\/ \\\n\ + \ ( )\\\n\ + \ \\_______/ \\\n\ + \ [[] [[]]\n\ + \ [[] [[]]" ^ reset + in + print_endline ascii_art +;; + +load_from_storage (); +print_endline "Welcome to the SQAMLVerse!"; +main_loop (); +print_endline "Goodbye!" diff --git a/docs/ms2-deliverable2.yaml b/docs/ms2-deliverable2.yaml index 0534a5e..0b94b21 100644 --- a/docs/ms2-deliverable2.yaml +++ b/docs/ms2-deliverable2.yaml @@ -1,3 +1,4 @@ +--- # Members of your group. group: - name: Eashan Vagish @@ -15,7 +16,7 @@ pm: # Set to false if you don't want your gallery entry to be public. publish: true # Pithy title -title: "SQamL" +title: "SQamL (SQL in OCaml)" # OK if this is a Cornell Github link, but public gallery viewers won't be able to see it. git-repo: "https://github.com/Destaq/sqaml" # If you have no demo screencast, replace the url string with an empty string "" diff --git a/dune-project b/dune-project new file mode 100644 index 0000000..6882a76 --- /dev/null +++ b/dune-project @@ -0,0 +1,28 @@ +(lang dune 3.14) +(name sqaml) + +(generate_opam_files true) + +(source + (github username/reponame)) + +(authors "Alex Noviello" + "Andrew Noviello" + "Simon Ilincev" + "Eashan Vagish") + +(maintainers "Maintainer Name") + +(license LICENSE) + +(documentation https://url/to/documentation) + +(package + (name sqaml) + (synopsis "SQAML") + (description "A SQL-like Database implemented completely in OCaml") + (depends ocaml dune) + (tags + (topics "to describe" your project))) + +; See the complete stanza docs at https://dune.readthedocs.io/en/stable/dune-files.html#dune-project diff --git a/lib/database.ml b/lib/database.ml new file mode 100644 index 0000000..7c2a049 --- /dev/null +++ b/lib/database.ml @@ -0,0 +1,123 @@ +(* database.ml *) +open Table +open Row + +(* Define a hash table mapping table names to references to table values *) +let tables : (string, table ref) Hashtbl.t = Hashtbl.create 10 + +(* Function to print the names of all loaded tables.*) +let show_all_tables () = + let table_names = Hashtbl.fold (fun k _ acc -> k :: acc) tables [] in + if List.length table_names > 0 then + let () = print_string "Tables:\n" in + List.iter (fun name -> print_string (name ^ "\n")) table_names + else print_string "No tables in database.\n" + +(**Get primary key field from a table.*) +let get_pk_field table = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + get_table_pk_field !table_ref + +(**Check primary key uniqueness*) +let check_pk_uniqueness table pk_field pk_value = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + check_for_pk_value !table_ref pk_field pk_value + +(**Get all columns from a table.*) +let get_table_columns table include_type = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + get_columns_lst !table_ref include_type + +(**Get column type from table name*) +let get_column_type table column = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + get_column_type !table_ref column + +(**Construct a transformation function for data updates.*) +let construct_transform columns_lst values_lst table row_data = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + construct_transform columns_lst values_lst !table_ref row_data + +(**Construct predicate for where clauses*) +let construct_predicate columns_lst match_values_lst operators_lst table + row_data = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + construct_predicate columns_lst match_values_lst operators_lst !table_ref + row_data + +(**Function to drop a table from the database.*) +let drop_table table_name = Hashtbl.remove tables table_name + +(* Function to create a new table in the database *) +let create_table columns table_name = + if Hashtbl.mem tables table_name then failwith "Table already exists" + else + let new_table = ref (create_table columns) in + Hashtbl.add tables table_name new_table + +(* Function to insert a row into a table *) +let insert_row table values row = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + insert_row !table_ref values row + +(* Function to update rows in a table *) +let update_rows table predicate transform = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + update_rows !table_ref predicate transform + +(* Function to delete rows from a table *) +let delete_rows table predicate = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + delete_rows !table_ref predicate + +(* Function to select rows from a table *) +let select_rows table fields predicate order_col = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + let selected_rows = + if List.length fields = 1 && List.hd fields = "*" then + select_rows_table !table_ref + (get_columns_lst !table_ref false) + predicate order_col + else select_rows_table !table_ref fields predicate order_col + in + selected_rows + +(**Select all data from a given table.*) +let select_all table = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + let rows = select_all !table_ref in + List.iter (fun row -> print_row row) rows + +(* Function to print a table *) +let print_table table = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else + let table_ref = Hashtbl.find tables table in + print_table !table_ref + +(**Compare two rows based on column name.*) +let sorter table col_ind r1 r2 = + if not (Hashtbl.mem tables table) then failwith "Table does not exist" + else compare_row col_ind r1 r2 diff --git a/lib/database.mli b/lib/database.mli new file mode 100644 index 0000000..6412961 --- /dev/null +++ b/lib/database.mli @@ -0,0 +1,62 @@ +(* database.mli *) + +open Table +open Row + +val construct_transform : string list -> value list -> string -> row -> row +(** [construct_transform] constructs a transformation function from a set clause. *) + +val construct_predicate : + string list -> + value list -> + (value -> value -> bool) list -> + string -> + row -> + bool +(** [construct_predicate] constructs a predicate from a where clause. *) + +val tables : (string, table ref) Hashtbl.t +(** Main tables variable to store the database tables.*) + +val get_pk_field : string -> column option +(**[get_pk_field] returns the primary key field in a table.*) + +val check_pk_uniqueness : string -> string -> value -> unit +(**[check_pk_uniqueness] throws an exception if a primary key is not unique.*) + +val get_table_columns : string -> bool -> string list +(** [get_table_columns] returns the full list of columns in a table. *) + +val get_column_type : string -> string -> column_type +(** [get_column_type t c] gets the type of column [c] of a table [t] in the database. *) + +val show_all_tables : unit -> unit +(** [show_all_tables] prints the list of tables currently in the database. *) + +val drop_table : string -> unit +(** [drop_table] drops a table from the database, by name.*) + +val create_table : column list -> string -> unit +(** [create_table columns] creates a new table with the given columns. *) + +val insert_row : string -> string list -> string list -> unit +(** [insert_row table row] inserts a row into the table. *) + +val update_rows : string -> (row -> bool) -> (row -> row) -> unit +(** [update_rows table predicate transform] updates rows based on a predicate and a transformation function. *) + +val delete_rows : string -> (row -> bool) -> unit +(** [delete_rows table predicate] deletes rows based on a predicate. *) + +val select_rows : + string -> string list -> (row -> bool) -> string -> int option * row list +(** [select_rows table fields predicate] selects rows based on a predicate. *) + +val print_table : string -> unit +(** [print_table table] prints the table. *) + +val select_all : string -> unit +(** [select_all] selects every row and column from the table*) + +val sorter : string -> int -> row -> row -> int +(**[sorter] constructs a sorting function given table name, column index, and 2 rows.*) diff --git a/lib/dune b/lib/dune new file mode 100644 index 0000000..b4aa213 --- /dev/null +++ b/lib/dune @@ -0,0 +1,6 @@ +(library + (name sqaml) + (modules row table parser database storage) + (libraries ounit2 csv base64) + (instrumentation + (backend bisect_ppx))) diff --git a/lib/parser.ml b/lib/parser.ml new file mode 100644 index 0000000..1a1cfe6 --- /dev/null +++ b/lib/parser.ml @@ -0,0 +1,417 @@ +open Table +open Database + +(**Primary types for parsing tokens.*) +type token = Identifier of string | IntKeyword | VarcharKeyword | PrimaryKey + +(**Print out string list [lst], with each element separated by [sep].*) +let rec print_string_list lst sep = + match lst with + | [] -> () + | h :: t -> + let () = print_string (h ^ sep) in + print_string_list t sep + +(**Print out a list of tokens generated by the parser.*) +let print_tokenized tokens = + List.iter + (function + | Identifier s -> Printf.printf "Identifier: %s\n" s + | IntKeyword -> print_endline "IntKeyword" + | VarcharKeyword -> print_endline "VarcharKeyword" + | PrimaryKey -> print_endline "PrimaryKey") + tokens + +(**String utility function to implement a classical replace all.*) +let replace_all str old_substring new_substring = + let rec replace_helper str old_substring new_substring start_pos = + try + let pos = String.index_from str start_pos old_substring.[0] in + if String.sub str pos (String.length old_substring) = old_substring then + let prefix = String.sub str 0 pos in + let suffix = + String.sub str + (pos + String.length old_substring) + (String.length str - (pos + String.length old_substring)) + in + let new_str = prefix ^ new_substring ^ suffix in + replace_helper new_str old_substring new_substring + (pos + String.length new_substring) + else replace_helper str old_substring new_substring (pos + 1) + with Not_found -> str + in + replace_helper str old_substring new_substring 0 + +(**Group and merge tokens based on quotations.*) +let rec quote_grouping acc cur_group in_group tokens = + match tokens with + | [] -> List.rev acc + | h :: t -> + if in_group then + if String.ends_with ~suffix:"\"" h then + quote_grouping + ((cur_group ^ " " ^ replace_all h "\"" "") :: acc) + "" false t + else + quote_grouping acc + (cur_group ^ " " ^ replace_all h "\"" "") + in_group t + else if + String.starts_with ~prefix:"\"" h + && not (String.ends_with ~suffix:"\"" h) + then quote_grouping acc (replace_all h "\"" "") true t + else quote_grouping (replace_all h "\"" "" :: acc) cur_group in_group t + +(**Split query string into a list of tokens that can be interpreted elsewhere.*) +let tokenize_query query = + let rec tokenize acc = function + | [] -> List.rev acc + | hd :: tl -> + let token = + match String.uppercase_ascii hd with + | "INT" -> IntKeyword + | "VARCHAR" -> VarcharKeyword + | "PRIMARY" -> PrimaryKey + | "KEY" -> PrimaryKey + | "TABLE" | "TABLES" | "CREATE" | "INSERT" | "INTO" | "SELECT" + | "SHOW" | "DROP" | "WHERE" | "UPDATE" | "SET" | "FROM" | "AND" + | "ORDER" | "BY" | "LIMIT" | "COLUMNS" | "DELETE" -> + Identifier (String.uppercase_ascii hd) + | _ -> Identifier hd + in + tokenize (token :: acc) tl + in + query |> String.split_on_char ' ' + |> List.filter (fun s -> s <> "") + |> quote_grouping [] "" false |> tokenize [] + +(**Verify that the order of columns passed in an insert query is correct for a given table.*) +let check_column_order table_name columns = + let actual_cols = get_table_columns table_name false in + let rec check_equiv l1 l2 = + match (l1, l2) with + | h1 :: t1, h2 :: t2 -> + if h1 = h2 then check_equiv t1 t2 + else failwith "Improper columns or order provided for insert." + | [], [] -> () + | _ -> failwith "Incorrect number of columns provided." + in + check_equiv actual_cols columns + +(**Parse a full create table query, as well as an insert query.*) +let parse_create_table tokens = + let rec parse_columns acc = function + | [] -> List.rev acc + | Identifier name :: IntKeyword :: PrimaryKey :: PrimaryKey :: tl -> + parse_columns + ({ name; col_type = Int_type; primary_key = true } :: acc) + tl + | Identifier name :: VarcharKeyword :: PrimaryKey :: PrimaryKey :: tl -> + parse_columns + ({ name; col_type = Varchar_type; primary_key = true } :: acc) + tl + | Identifier name :: IntKeyword :: tl -> + parse_columns + ({ name; col_type = Int_type; primary_key = false } :: acc) + tl + | Identifier name :: VarcharKeyword :: tl -> + parse_columns + ({ name; col_type = Varchar_type; primary_key = false } :: acc) + tl + | Identifier ")" :: tl -> parse_columns acc tl + | Identifier "(" :: tl -> parse_columns acc tl + | Identifier "," :: tl -> parse_columns acc tl + | Identifier name :: PrimaryKey :: PrimaryKey :: tl -> + parse_columns + ({ name; col_type = Int_type; primary_key = true } :: acc) + tl + | _ -> raise (Failure "Syntax error in column definition") + in + let rec parse_values acc = function + | [] -> failwith "Syntax error in column definition" + | Identifier "(" :: tl -> parse_values acc tl + | Identifier ")" :: Identifier "VALUES" :: row_values -> + (List.rev acc, row_values) + | Identifier ")" :: _ -> (List.rev acc, []) + | Identifier "," :: tl -> parse_values acc tl + | Identifier name :: tl -> parse_values (name :: acc) tl + | _ -> raise (Failure "Syntax error in column definition") + in + match tokens with + | Identifier "CREATE" :: Identifier "TABLE" :: Identifier _table_name :: tl -> + let columns = parse_columns [] tl in + create_table columns _table_name + | Identifier "INSERT" :: Identifier "INTO" :: Identifier _table_name :: tl -> + let columns, row_values = parse_values [] tl in + let () = check_column_order _table_name columns in + let row_values, _ = parse_values [] row_values in + let pk_field = get_pk_field _table_name in + + if Option.is_some pk_field then + let pk_field = Option.get pk_field in + let () = + check_pk_uniqueness _table_name pk_field.name + (Table.convert_to_value pk_field.col_type + (List.nth row_values + (Option.get + (List.find_index (fun c -> c = pk_field.name) columns)))) + in + insert_row _table_name columns row_values + else insert_row _table_name columns row_values + | _ -> raise (Failure "Syntax error in SQL query") + +(**Check whether a statement includes a where clause.*) +let rec includes_where_clause tokens = + match tokens with + | [] -> (false, []) + | h :: t -> ( + match h with + | Identifier cur_tok -> + if cur_tok = "WHERE" then (true, h :: t) else includes_where_clause t + | _ -> includes_where_clause t) + +(**Get a list of fields to update from an update query.*) +let get_update_fields_clause all_tokens = + let rec get_update_fields_clause_aux tokens acc = + match tokens with + | [] -> List.rev acc + | h :: t -> ( + match h with + | Identifier cur_tok -> + if cur_tok = "WHERE" then List.rev acc + else get_update_fields_clause_aux t (Identifier cur_tok :: acc) + | _ -> failwith "Unrecognized update clause query.") + in + get_update_fields_clause_aux all_tokens [] + +(**Return an operation function associated with a specific operation operation string.*) +let get_op_value op = + match op with + | "=" -> Row.value_equals + | ">" -> Row.value_greater_than + | "<" -> Row.value_less_than + | "<>" -> Row.value_not_equals + | _ -> failwith "Unrecognized operation string." + +(**Construct parameters from a token list to construct a predicate or where clause.*) +let construct_predicate_params table_name pred_tokens = + let pred_tokens = + List.filter + (fun elem -> match elem with Identifier "," -> false | _ -> true) + pred_tokens + in + let rec construct_pred_aux tokens col_acc val_acc op_acc = + match tokens with + | [] -> (col_acc, val_acc, op_acc) + | (Identifier "ORDER" | Identifier "LIMIT") :: _remaining_tokens -> + (col_acc, val_acc, op_acc) + | (Identifier "WHERE" | Identifier "AND") + :: Identifier field1 + :: Identifier op + :: Identifier value1 + :: _remaining_tokens -> + construct_pred_aux _remaining_tokens (field1 :: col_acc) + (Table.convert_to_value (get_column_type table_name field1) value1 + :: val_acc) + (get_op_value op :: op_acc) + | _ -> failwith "Unrecognized where clause format." + in + construct_pred_aux pred_tokens [] [] [] + +(**Construct parameters for building out a data transformation or a table update.*) +let construct_transform_params table_name update_tokens = + let update_tokens = get_update_fields_clause update_tokens in + let update_tokens = + List.filter + (fun elem -> match elem with Identifier "," -> false | _ -> true) + update_tokens + in + let rec construct_transform_aux tokens col_acc val_acc = + match tokens with + | [] -> (col_acc, val_acc) + | Identifier field1 + :: Identifier "=" + :: Identifier value1 + :: _remaining_tokens -> + construct_transform_aux _remaining_tokens (field1 :: col_acc) + (Table.convert_to_value (get_column_type table_name field1) value1 + :: val_acc) + | _ -> failwith "Unrecognized update transform clause format." + in + construct_transform_aux update_tokens [] [] + +(**Parse an update query.*) +let parse_update_table table_name update_tokens = + let transform_columns_lst, transform_values_lst = + construct_transform_params table_name update_tokens + in + let transform = + construct_transform transform_columns_lst transform_values_lst table_name + in + let has_where, where_clause = includes_where_clause update_tokens in + if has_where then + let columns_lst, values_lst, ops_lst = + construct_predicate_params table_name where_clause + in + let pred = construct_predicate columns_lst values_lst ops_lst table_name in + update_rows table_name pred transform + else update_rows table_name (fun _ -> true) transform + +(**Parse a delete query.*) +let parse_delete_records table_name delete_tokens = + let has_where, where_clause = includes_where_clause delete_tokens in + if has_where then + let columns_lst, values_lst, ops_lst = + construct_predicate_params table_name where_clause + in + let pred = construct_predicate columns_lst values_lst ops_lst table_name in + delete_rows table_name pred + else delete_rows table_name (fun _ -> true) + +(**Parse the query fields in a select query to facilitate dynamic field selection.*) +let rec parse_select_query_fields tokens acc = + match tokens with + | [] -> failwith "Please include fields in your query." + | h :: t -> ( + match h with + | Identifier cur_tok -> + if cur_tok = "FROM" then + ( acc, + match List.hd t with + | Identifier _tb_name -> _tb_name + | _ -> failwith "No table name detected." ) + else parse_select_query_fields t (h :: acc) + | _ -> + failwith "Non-identifier detected whil parsing select query fields.") + +(**Extract column names from a list of fields.*) +let rec extract_column_names tb_name fields = + match fields with + | [] -> [] + | h :: t -> ( + match h with + | Identifier cur_tok -> + if cur_tok = "*" then get_table_columns tb_name false + else if cur_tok <> "," then cur_tok :: extract_column_names tb_name t + else extract_column_names tb_name t + | _ -> failwith "Non-identifier detected in column list.") + +(**Check whether a limit clause exists and return the limit number if so.*) +let rec get_limit_info select_tokens = + match select_tokens with + | Identifier "LIMIT" :: Identifier lim :: _ -> (true, int_of_string lim) + | _ :: t -> get_limit_info t + | [] -> (false, 0) + +(**Check whether an order by clause exists and return the corresponding data if so.*) +let rec get_order_by_info select_tokens = + match select_tokens with + | Identifier "ORDER" + :: Identifier "BY" + :: Identifier col + :: Identifier dir + :: _ -> + let dir = String.uppercase_ascii dir in + if dir = "ASC" || dir = "DESC" then (true, col, dir) + else failwith "Order by direction not provided." + | _ :: t -> get_order_by_info t + | _ -> (false, "", "") + +(**Return the first n elements of a list.*) +let rec take n xs = + if not (List.length xs = 0) then + match n with 0 -> [] | _ -> List.hd xs :: take (n - 1) (List.tl xs) + else [] + +(**Construct a sorting function for order by.*) +let construct_sorter table_name column_ind r1 r2 = + sorter table_name column_ind r1 r2 + +(**Parse a select query.*) +let parse_select_records select_tokens = + let ordered, order_column, order_dir = get_order_by_info select_tokens in + let limited, limit = get_limit_info select_tokens in + let selected_fields, table_name = + parse_select_query_fields select_tokens [] + in + let selected_fields = extract_column_names table_name selected_fields in + let () = + if + order_column <> "" && order_column <> "" + && not (List.mem order_column selected_fields) + then failwith "Order column is not present in field list." + else () + in + let () = + if List.length selected_fields = 0 then + failwith "No proper fields selected in query." + else () + in + let has_where, where_clause = includes_where_clause select_tokens in + let order_ind, selected_rows = + if has_where then + let columns_lst, values_lst, ops_lst = + construct_predicate_params table_name where_clause + in + let pred = + construct_predicate columns_lst values_lst ops_lst table_name + in + select_rows table_name selected_fields pred order_column + else select_rows table_name selected_fields (fun _ -> true) order_column + in + let selected_rows = + if ordered && Option.is_some order_ind then + if order_dir = "DESC" then + List.rev + (List.sort + (construct_sorter table_name (Option.get order_ind)) + selected_rows) + else + List.sort + (construct_sorter table_name (Option.get order_ind)) + selected_rows + else selected_rows + in + let selected_rows = + if limited then take limit selected_rows else selected_rows + in + List.iter (fun row -> Row.print_row row) selected_rows + +(**Primary query parser and main gateway.*) +let parse_query query = + let query = replace_all query "," " , " in + let query = replace_all query "(" " ( " in + let query = replace_all query ")" " ) " in + let query = replace_all query "`" "" in + let query = replace_all query "'" "" in + let query = replace_all query "\n" "" in + let query = replace_all query "\r" "" in + let tokens = tokenize_query query in + match tokens with + | Identifier "CREATE" :: Identifier "TABLE" :: _ -> parse_create_table tokens + | Identifier "SHOW" + :: Identifier "COLUMNS" + :: Identifier "FROM" + :: Identifier _table_name + :: _ -> + print_string_list (get_table_columns _table_name true) "|"; + print_string "\n" + | Identifier "INSERT" :: Identifier "INTO" :: _ -> parse_create_table tokens + | Identifier "SELECT" :: select_tokens -> parse_select_records select_tokens + | Identifier "SHOW" :: Identifier "TABLES" :: _ -> show_all_tables () + | Identifier "DROP" :: Identifier "TABLE" :: Identifier _table_name :: _ -> + drop_table _table_name + | Identifier "UPDATE" + :: Identifier _table_name + :: Identifier "SET" + :: update_tokens -> + parse_update_table _table_name update_tokens + | Identifier "DELETE" + :: Identifier "FROM" + :: Identifier _table_name + :: delete_tokens -> + parse_delete_records _table_name delete_tokens + | _ -> raise (Failure "Unsupported query") + +(**Main gateway to the backend query parsing logic.*) +let parse_and_execute_query query = parse_query query diff --git a/lib/parser.mli b/lib/parser.mli new file mode 100644 index 0000000..886ffc4 --- /dev/null +++ b/lib/parser.mli @@ -0,0 +1,11 @@ +(** Type for storing type of token in query. *) +type token = Identifier of string | IntKeyword | VarcharKeyword | PrimaryKey + +val parse_and_execute_query : string -> unit +(** [parse_and_execute_query q] executes the query denoted by [q]. *) + +val print_tokenized : token list -> unit +(** [print_tokenized q] prints the tokenization of query [q]. *) + +val tokenize_query : string -> token list +(** [tokenize_query q] returns the tokenization of query [q]. *) diff --git a/lib/row.ml b/lib/row.ml new file mode 100644 index 0000000..a42ddeb --- /dev/null +++ b/lib/row.ml @@ -0,0 +1,76 @@ +(**Types of a specific SQamL value.*) +type value = + | Int of int + | Varchar of string + | Float of float + | Date of string + | Null + +type row = { values : value list } +(**Stores a full SQamL database row.*) + +(**Prints a value, according to its type.*) +let print_value v = + match v with + | Int i -> print_int i + | Varchar s -> print_string s + | Float f -> print_float f + | Date d -> print_string d + | Null -> print_string "null" + +(**Check for equality between two values.*) +let value_equals val1 val2 = + match (val1, val2) with + | Int v1, Int v2 -> v1 = v2 + | Varchar v1, Varchar v2 -> String.compare v1 v2 = 0 + | Float v1, Float v2 -> v1 = v2 + | Date v1, Date v2 -> String.compare v1 v2 = 0 + | _ -> false + +(**Check for inequality between two values.*) +let value_not_equals val1 val2 = not (value_equals val1 val2) + +(**Convert a value to a string.*) +let convert_to_string = function + | Int x -> string_of_int x + | Varchar x -> x + | _ -> failwith "Bad type." + +(*Convert a value list to a string list*) +let rec convert_values = function + | [] -> [] + | h :: t -> convert_to_string h :: convert_values t + +(**SQamL database row to list.*) +let to_list r = convert_values r.values + +(**Check whether [val1] is greater than [val2]. Requires YYYY-MM-DD format for dates.*) +let value_greater_than val1 val2 = + match (val1, val2) with + | Int v1, Int v2 -> v1 > v2 + | Varchar v1, Varchar v2 -> String.compare v1 v2 > 0 + | Float v1, Float v2 -> v1 > v2 + | Date v1, Date v2 -> String.compare v1 v2 > 0 + | _ -> false + +(**Check whether [val1] is greater than [val2].*) +let value_less_than val1 val2 = + match (val1, val2) with + | Int v1, Int v2 -> v1 < v2 + | Varchar v1, Varchar v2 -> String.compare v1 v2 < 0 + | Float v1, Float v2 -> v1 < v2 + | Date v1, Date v2 -> String.compare v1 v2 < 0 + | _ -> false + +(**Print a full SQamL database row.*) +let print_row row = + List.iter + (fun v -> + match v with + | Int i -> Printf.printf "%d " i + | Varchar s -> Printf.printf "'%s' " s + | Float f -> Printf.printf "%f " f + | Date d -> Printf.printf "%s " d + | Null -> Printf.printf "NULL ") + row.values; + Printf.printf "\n" diff --git a/lib/row.mli b/lib/row.mli new file mode 100644 index 0000000..2307e94 --- /dev/null +++ b/lib/row.mli @@ -0,0 +1,31 @@ +(** The type of a value in a row. *) +type value = + | Int of int + | Varchar of string + | Float of float + | Date of string + | Null + +type row = { values : value list } +(** The type of a row in the database. *) + +val print_value : value -> unit +(** [print_value v] prints the value [v]. *) + +val print_row : row -> unit +(** [print_row r] prints the row [r].*) + +val value_equals : value -> value -> bool +(** [value_equals v1 v2] returns true if [v1] is structurally equivalent to [v2] and false otherwise. *) + +val value_not_equals : value -> value -> bool +(** [value_not_equals v1 v2] returns true if [v1] is not structurally equivalent to [v2] and false otherwise. *) + +val value_greater_than : value -> value -> bool +(** [value_greater_than v1 v2] returns true if [v1] is greater than [v2] and false otherwise. *) + +val value_less_than : value -> value -> bool +(** [value_less_than v1 v2] returns true if [v1] is less than [v2] and false otherwise. *) + +val to_list : row -> string list +(** [to_list r] returns a string list representation of row [r]. *) diff --git a/lib/storage.ml b/lib/storage.ml new file mode 100644 index 0000000..d1fdbdb --- /dev/null +++ b/lib/storage.ml @@ -0,0 +1,146 @@ +(* storage.ml *) +open Table +open Database + +(**Load rows into the database on start.*) +let rec load_rows table columns = function + | [] -> () + | h :: t -> + Database.insert_row table columns h; + load_rows table columns t + +(**Fetch data/table storage files.*) +let fetch_files () = + try + let list_files = Sys.readdir "lib/storage/" in + List.filter + (fun x -> Filename.extension x = ".sqaml") + (Array.to_list list_files) + with Sys_error _ -> + Sys.mkdir "lib/storage/" 0o777; + [] + +(**Remove all files in a directory.*) +let remove_all_files_in_dir dir = + try + let files = Array.to_list (Sys.readdir dir) in + List.iter + (fun file -> + let file_path = Filename.concat dir file in + if Sys.is_directory file_path then () else Sys.remove file_path) + files + with Sys_error msg -> Printf.eprintf "Error: %s\n" msg + +(**Convert a string to a corresponding column type.*) +let string_to_col_type = function + | "varchar" -> Varchar_type + | "int" -> Int_type + | _ -> failwith "Error. Incorrect column type saved." + +(**Build out table columns on load, after file parsing and data extraction.*) +let build_columns column_names column_types primary_key = + let rec build_cols names types acc = + match (names, types) with + | [], [] -> List.rev acc + | name :: rest_names, col_type :: rest_types -> + let col = + { + name; + col_type = string_to_col_type col_type; + primary_key = primary_key = name; + } + in + build_cols rest_names rest_types (col :: acc) + | _ -> failwith "Column names and types have different lengths" + in + build_cols column_names column_types [] + +(**Parse the header of a storage file.*) +let header pk = function + | names :: types :: _ -> build_columns names types pk + | _ -> + failwith + "Storage format corrupted. No column names or types in storage. Please \ + purge the storage directory." + +(**Basic string utility function for extracting parts of a storage file.*) +let split_and_get_parts s = + let parts = String.split_on_char '_' s in + match parts with + | [] -> ("", "") + | [ part ] -> (part, "") + | part1 :: part2 :: _ -> (part1, part2) + +(**Main function to load a table from a specific storage file.*) +let load_table_from_file file = + let filename = Filename.remove_extension (Filename.basename file) in + let table = fst (split_and_get_parts filename) in + let pk = snd (split_and_get_parts filename) in + let data = Csv.square (Csv.load ("lib/storage/" ^ file)) in + Database.create_table (header pk data) table; + load_rows table + (match data with + | h :: _ -> h + | _ -> + failwith "Storage format corrupted. Header is not properly specified.") + (match data with + | _ :: _ :: data -> data + | _ -> + failwith "Storage format corrupted. Header is not properly specified.") + +(**Load tables from directory.*) +let rec load_tables = function + | [] -> () + | h :: t -> + load_table_from_file h; + load_tables t + +(**Fetch all storage files and load tables.*) +let load_from_storage () = + let files = fetch_files () in + load_tables files + +(**Utility function to extract all keys names from a hashtable.*) +let get_keys_from_hashtbl hashtbl = + let keys = ref [] in + Hashtbl.iter (fun key _ -> keys := key :: !keys) hashtbl; + List.rev !keys + +(**Convert SQamL database rows to lists. *) +let rec rows_to_lists = function + | [] -> [] + | h :: t -> Row.to_list h :: rows_to_lists t + +(**Convert column types back into string formats for saving in storage files.*) +let rec types_from_names table_name = function + | [] -> [] + | h :: t -> + (match get_column_type table_name h with + | Int_type -> "int" + | Varchar_type -> "varchar" + | _ -> "Incorrect column type") + :: types_from_names table_name t + +(**Get the name of the primary key in a field.*) +let get_pk_field_name table = + match get_pk_field table with + | None -> failwith "No primary key." + | Some x -> x.name + +(**Save data from tables into specific storage files (direct from main database).*) +let rec save_data = function + | [] -> () + | h :: t -> + let names = get_columns_lst !(Hashtbl.find tables h) false in + let types = types_from_names h names in + let pk = get_pk_field_name h in + Csv.save + ("lib/storage/" ^ h ^ "_" ^ pk ^ ".sqaml") + (names :: types + :: rows_to_lists (Table.select_all !(Hashtbl.find tables h))); + save_data t + +(**Sync the database to storage files on exit.*) +let sync_on_exit () = + remove_all_files_in_dir "lib/storage/"; + save_data (get_keys_from_hashtbl tables) diff --git a/lib/storage.mli b/lib/storage.mli new file mode 100644 index 0000000..a12c39f --- /dev/null +++ b/lib/storage.mli @@ -0,0 +1,5 @@ +val load_from_storage : unit -> unit +(**Load in all SQamL database data from files on start.*) + +val sync_on_exit : unit -> unit +(**Sync all SQamL database data to files on exit.*) diff --git a/lib/table.ml b/lib/table.ml new file mode 100644 index 0000000..f3fb3ef --- /dev/null +++ b/lib/table.ml @@ -0,0 +1,234 @@ +open Row + +(**Possible database column types*) +type column_type = + | Int_type + | Varchar_type + | Float_type + | Date_type + | Null_type + +type column = { name : string; col_type : column_type; primary_key : bool } +(**Full column storage.*) + +type table = { columns : column list; mutable rows : row list } +(**Full table storage.*) + +(**Convert column type to string.*) +let column_type_to_str c = + match c with + | Int_type -> "Integer" + | Varchar_type -> "Varchar" + | Float_type -> "Float" + | Date_type -> "Date" + | Null_type -> "Null" + +(**Get primary key field in a table.*) +let get_table_pk_field tb = + let rec check_fields lst = + match lst with + | [] -> None + | h :: t -> if h.primary_key then Some h else check_fields t + in + check_fields tb.columns + +(**Get string list of all columns in table.*) +let get_columns_lst table include_type = + let rec extract_column_names lst = + match lst with + | [] -> [] + | h :: t -> + (h.name + ^ + if include_type then " : " ^ column_type_to_str h.col_type ^ " " else "" + ) + :: extract_column_names t + in + extract_column_names table.columns + +(**Get type of a column.*) +let get_column_type table col_name = + let rec get_column_type_aux columns name = + match columns with + | [] -> failwith "Column not found." + | h :: t -> + if h.name = col_name then h.col_type else get_column_type_aux t name + in + get_column_type_aux table.columns col_name + +(**Get list column names.*) +let get_column_names table = + let rec get_cols_aux col_list = + match col_list with [] -> [] | h :: t -> h.name :: get_cols_aux t + in + get_cols_aux table.columns + +(*Construct row map, converting list of values into a hashtable for easier access.*) +let construct_row_map table row_data = + let column_names = get_column_names table in + if List.length column_names <> List.length row_data.values then + failwith "Number of columns does not match number of elements in row." + else + let row_map = Hashtbl.create (List.length column_names + 1) in + let rec build_map_aux cols row_data = + match (cols, row_data) with + | [], [] -> () + | col_h :: col_t, row_h :: row_t -> + let _ = Hashtbl.add row_map col_h row_h in + build_map_aux col_t row_t + | _ -> failwith "Column/row mismatch." + in + let _ = build_map_aux column_names row_data.values in + row_map + +(**Check for primary key existence.*) +let check_for_pk_value table pk_field pk_value = + let rec check_rows_for_pk rows = + match rows with + | [] -> () + | cur_row :: t -> + let row_map = construct_row_map table cur_row in + if Row.value_equals (Hashtbl.find row_map pk_field) pk_value then + failwith "Primary key already exists in the table." + else check_rows_for_pk t + in + check_rows_for_pk table.rows + +(**Function to sort a row list according to a field.*) +let compare_row column_ind r1 r2 = + if + Row.value_greater_than + (List.nth r1.values column_ind) + (List.nth r2.values column_ind) + then 1 + else if + Row.value_equals + (List.nth r1.values column_ind) + (List.nth r2.values column_ind) + then 0 + else -1 + +(**Get correct value from a data transform.*) +let rec get_new_value_from_transform columns_lst values_lst column = + match (columns_lst, values_lst) with + | [], [] -> failwith "Column not found when creating new value in transform." + | c :: c_t, v :: v_t -> + if c = column then v else get_new_value_from_transform c_t v_t column + | _ -> failwith "Column/row number mismatch creating new value in transform." + +(**Construct a transform for update reassignments.*) +let construct_transform columns_lst values_lst table row_data = + let column_names = get_column_names table in + let rec transform_aux cols vals acc = + match (cols, vals) with + | [], [] -> List.rev acc + | col_h :: col_t, val_h :: val_t -> + let cur_val = + if List.mem col_h columns_lst then + get_new_value_from_transform columns_lst values_lst col_h + else val_h + in + transform_aux col_t val_t (cur_val :: acc) + | _ -> failwith "Column/row number mismatch when constructing transform." + in + { values = transform_aux column_names row_data.values [] } + +(**Construct a predicate for filtering for where clauses.*) +let construct_predicate columns_lst match_values_lst operators_lst table + row_data = + let row_map = construct_row_map table row_data in + let rec pred_aux cols vals ops = + match (cols, vals, ops) with + | [], [], [] -> true + | col_h :: col_t, val_h :: val_t, op_h :: op_t -> + if op_h (Hashtbl.find row_map col_h) val_h = false then false + else pred_aux col_t val_t op_t + | _ -> failwith "Column/row number mismatch." + in + pred_aux columns_lst match_values_lst operators_lst + +(**Create a new table with given columns.*) +let create_table columns = + let has_primary_key = List.exists (fun col -> col.primary_key) columns in + if not has_primary_key then failwith "Table must have a primary key" + else { columns; rows = [] } + +(**Convert a string to a value, given its corresponding column type.*) +let convert_to_value col_type str = + match col_type with + | Int_type -> Int (int_of_string str) + | Varchar_type -> Varchar str + | Float_type -> Float (float_of_string str) + | Date_type -> Date str + | Null_type -> Null + +(**Insert a row into a table.*) +let insert_row table column_names values = + if List.length column_names <> List.length values then + failwith "Number of columns does not match number of values"; + + let row_values = + List.map2 + (fun col_name value -> + let column = List.find (fun col -> col.name = col_name) table.columns in + match String.trim value with + | "" -> Null + | v -> convert_to_value column.col_type v) + column_names values + in + + let new_row = row_values in + table.rows <- { values = new_row } :: table.rows + +(**Update the rows of a table according to a predicate and transformation.*) +let update_rows table pred f = + table.rows <- List.map (fun r -> if pred r then f r else r) table.rows + +(**Delete the rows of a table according to a predicate.*) +let delete_rows table pred = + table.rows <- List.filter (fun r -> not (pred r)) table.rows + +(**Select the rows of a table according to a list of requested fields, a predicate, and an ordering column.*) +let select_rows_table table column_names pred order_column = + let columns = + List.map + (fun name -> + match List.find_opt (fun c -> c.name = name) table.columns with + | Some c -> c + | None -> + let () = print_string name in + failwith "Column does not exist") + column_names + in + let order_column_ind = + if order_column <> "" then + List.find_index + (fun c -> c.name = order_column) + (List.filter (fun c -> List.mem c columns) table.columns) + else (None : int option) + in + let filter_row row = + let filtered_values = + List.combine table.columns row.values + |> List.filter (fun (name, _) -> List.mem name columns) + |> List.map snd + in + { values = filtered_values } + in + (order_column_ind, List.map filter_row (List.filter pred table.rows)) + +(**Select all data from a table.*) +let select_all table = table.rows + +(**Print a value table for viewing.*) +let print_table table = + let print_column column = + match column.col_type with + | Int_type -> Printf.printf "%s: int\n" column.name + | Varchar_type -> Printf.printf "%s: varchar\n" column.name + | Float_type -> Printf.printf "%s: float\n" column.name + | Date_type -> Printf.printf "%s: date\n" column.name + | Null_type -> Printf.printf "%s: null\n" column.name + in + List.iter print_column table.columns; + List.iter (fun row -> print_row row) table.rows diff --git a/lib/table.mli b/lib/table.mli new file mode 100644 index 0000000..7ff54e1 --- /dev/null +++ b/lib/table.mli @@ -0,0 +1,70 @@ +open Row + +(** Different types of columns. *) +type column_type = + | Int_type + | Varchar_type + | Float_type + | Date_type + | Null_type + +type column = { name : string; col_type : column_type; primary_key : bool } +(** Representation type of column. *) + +type table +(** Abstracted table type. *) + +val construct_transform : string list -> value list -> table -> row -> row +(**Construct a table data transform function for updates.*) + +val construct_predicate : + string list -> + value list -> + (value -> value -> bool) list -> + table -> + row -> + bool +(**Construct a predicate for filtering records for where clauses.*) + +val get_columns_lst : table -> bool -> string list +(**Get a list of columns in a table with types or without.*) + +val construct_row_map : table -> row -> (string, value) Hashtbl.t +(**Construct a row map, converting a list of values into a hashtable.*) + +val convert_to_value : column_type -> string -> value +(**Convert a value in string form into an actual value, given its corresponding column type.*) + +val get_column_type : table -> string -> column_type +(**Get the type of a specific column in a given table.*) + +val compare_row : int -> row -> row -> int +(**Compare row ordering based on index for oder by clauses.*) + +val get_table_pk_field : table -> column option +(**[get_pk_field] returns the primary key field in a table.*) + +val check_for_pk_value : table -> string -> value -> unit +(**[check_for_pk_value] checks for uniqueness of primary key in a table.*) + +val create_table : column list -> table +(** [create_table cl] creates a new table with the columns in [cl]. *) + +val insert_row : table -> string list -> string list -> unit +(** [insert_row t n v] inserts a row with column names [n] and values [v] into the table [t]. *) + +val update_rows : table -> (row -> bool) -> (row -> row) -> unit +(** Update rows based on a predicate and a transformation function. *) + +val delete_rows : table -> (row -> bool) -> unit +(** Delete rows based on a predicate. *) + +val select_rows_table : + table -> string list -> (row -> bool) -> string -> int option * row list +(** Select rows based on a predicate. *) + +val print_table : table -> unit +(** Print the table. *) + +val select_all : table -> row list +(** Select all rows in the table. *) diff --git a/sqaml.opam b/sqaml.opam new file mode 100644 index 0000000..80b0ab5 --- /dev/null +++ b/sqaml.opam @@ -0,0 +1,31 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: "SQAML" +description: "A SQL-like Database implemented completely in OCaml" +maintainer: ["Maintainer Name"] +authors: ["Alex Noviello" "Andrew Noviello" "Simon Ilincev" "Eashan Vagish"] +license: "LICENSE" +tags: ["topics" "to describe" "your" "project"] +homepage: "https://github.com/username/reponame" +doc: "https://url/to/documentation" +bug-reports: "https://github.com/username/reponame/issues" +depends: [ + "ocaml" + "dune" {>= "3.14"} + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/username/reponame.git" diff --git a/sqaml.sh b/sqaml.sh new file mode 100644 index 0000000..34264e3 --- /dev/null +++ b/sqaml.sh @@ -0,0 +1 @@ +dune exec ./bin/main.exe \ No newline at end of file diff --git a/test/dune b/test/dune new file mode 100644 index 0000000..c5bf677 --- /dev/null +++ b/test/dune @@ -0,0 +1,3 @@ +(test + (libraries ounit2 qcheck sqaml) + (name test_sqaml)) diff --git a/test/test_sqaml.ml b/test/test_sqaml.ml new file mode 100644 index 0000000..753a6d3 --- /dev/null +++ b/test/test_sqaml.ml @@ -0,0 +1,836 @@ +open OUnit2 + +(** [printer_wrapper s] is [s] *) +let printer_wrapper s = s + +(* TODO: investigate how to do this without Unix. *) + +(** [with_redirected_stdout f] is the output of [f] with stdout redirected to a + temporary file, useful for checking the output of some printing. *) +let with_redirected_stdout f = + (* clear any existing stdout *) + flush stdout; + let original_stdout = Unix.dup Unix.stdout in + let temp_out = open_out "temp_out" in + Unix.dup2 (Unix.descr_of_out_channel temp_out) Unix.stdout; + f (); + flush stdout; + Unix.dup2 original_stdout Unix.stdout; + close_out temp_out; + let temp_in = open_in "temp_out" in + let rec read_all_lines acc = + try + let line = input_line temp_in in + read_all_lines (acc ^ line ^ "\n") + with End_of_file -> acc + in + let output = read_all_lines "" in + output + +(** [create_tables ()] creates two tables, "test_table" and "another_table", + for the global state. *) +let create_tables () = + Sqaml.Database.create_table + [ + { name = "example"; col_type = Sqaml.Table.Int_type; primary_key = true }; + { + name = "example2"; + col_type = Sqaml.Table.Date_type; + primary_key = false; + }; + { + name = "example3"; + col_type = Sqaml.Table.Float_type; + primary_key = false; + }; + { + name = "example4"; + col_type = Sqaml.Table.Null_type; + primary_key = false; + }; + ] + "test_table"; + Sqaml.Database.create_table + [ { name = "hi"; col_type = Sqaml.Table.Float_type; primary_key = true } ] + "another_table"; + () + +(** [as_test name f] is an OUnit test with name [name] that runs [f]. *) +let as_test name f = name >:: fun _ -> f () + +(** [drop_tables ()] drops all tables in the global state. *) +let drop_tables () = + Sqaml.Database.drop_table "test_table"; + Sqaml.Database.drop_table "another_table"; + () + +(** [test_show_all_tables_with_some_tables] is an OUnit test that checks that + [show_all_tables] returns the correct output when there are some tables in the database. *) +let test_show_all_tables_with_some_tables = + as_test "test_show_all_tables_with_some_tables" (fun () -> + create_tables (); + let output = with_redirected_stdout Sqaml.Database.show_all_tables in + assert_equal ~printer:printer_wrapper + "Tables:\nanother_table\ntest_table\n" output; + drop_tables ()) + +(** [test_get_column_type_column_present] is an OUnit test that checks that + [get_column_type] returns the correct column type when the column is present. *) +let test_get_column_type_column_present = + as_test "test_get_column_type_column_present" (fun () -> + create_tables (); + let output = Sqaml.Database.get_column_type "test_table" "example" in + assert_equal output Sqaml.Table.Int_type; + drop_tables ()) + +(** [test_get_column_type_table_absent] is an OUnit test that checks that [get_column_type] + raises a custom Failure when the table is absent. *) +let test_get_column_type_table_absent = + as_test "test_get_column_type_table_absent" (fun () -> + create_tables (); + let failure_fun () = + Sqaml.Database.get_column_type "no_table" "nonexistent" + in + OUnit2.assert_raises (Failure "Table does not exist") failure_fun; + drop_tables ()) + +(** [test_get_column_type_column_absent] is an OUnit test that checks that [get_column_type] + raises a custom Failure when the asked-for column is absent. *) +let test_get_column_type_column_absent = + as_test "test_get_column_type_column_absent" (fun () -> + create_tables (); + try + let _ = Sqaml.Database.get_column_type "test_table" "nonexistent" in + drop_tables (); + assert_failure "Expected failure for nonexistent column, but got none." + with + | Failure msg -> + drop_tables (); + assert_equal ~printer:printer_wrapper "Column not found." msg + | _ -> + drop_tables (); + assert_failure + "Expected Failure exception, but got different exception.") + +(** [test_construct_transform_column_present] is an OUnit test that verifies + the correctness of [Sqaml.Database.construct_transform], a row-updating function. *) +let test_construct_transform_column_present = + as_test "test_construct_transform_column_present" (fun () -> + create_tables (); + let updated_row = + Sqaml.Database.construct_transform + [ "example"; "example2"; "example3"; "example4" ] + [ Int 1; Date "2022-12-12"; Float 4.5; Null ] + "test_table" + { values = [ Int 0; Date "2022-12-12"; Float 4.5; Null ] } + in + assert_equal updated_row.values + [ Int 1; Date "2022-12-12"; Float 4.5; Null ]; + drop_tables ()) + +(** [test_construct_transform_table_absent] is an OUnit test that verifies that + [construct_transform] raises a custom Failure when the table is absent. *) +let test_construct_transform_table_absent = + as_test "test_construct_transform_table_absent" (fun () -> + let updated_row () = + Sqaml.Database.construct_transform [ "example" ] [ Int 1 ] "test_table" + { values = [ Int 0 ] } + in + assert_raises (Failure "Table does not exist") updated_row) + +(** [test_construct_predicate_column_present] is an OUnit test that operates + in a similar manner to [test_construct_transform_column_present], but + verifies the correctness of [Sqaml.Database.construct_predicate] instead. *) +let test_construct_predicate_column_present = + as_test "test_construct_predicate_column_present" (fun () -> + create_tables (); + let predicate = + Sqaml.Database.construct_predicate + [ "example"; "example2"; "example3"; "example4" ] + [ Int 1; Date "2022-12-12"; Float 4.5; Null ] + [ (fun (x : Sqaml.Row.value) (y : Sqaml.Row.value) -> x > y) ] + "test_table" + in + let result = + predicate { values = [ Int 0; Date "2022-12-12"; Float 4.5; Null ] } + in + assert_equal result false; + drop_tables ()) + +(** [test_construct_predicate_table_absent] is an OUnit test that verifies that + [construct_predicate] raises a custom Failure when the table is absent. *) +let test_construct_predicate_table_absent = + as_test "test_construct_predicate_table_absent" (fun () -> + drop_tables (); + let predicate = + Sqaml.Database.construct_predicate [ "example" ] [ Int 1 ] + [ (fun (x : Sqaml.Row.value) (y : Sqaml.Row.value) -> x > y) ] + "asdfsdfsdf" + in + assert_raises (Failure "Table does not exist") (fun () -> + predicate { values = [ Int 0 ] })) + +(** [test_insert_row_table_exists] is an OUnit test that checks that + [Sqaml.Database.insert_row] correctly inserts a row into a table that + exists. *) +let test_insert_row_table_exists = + as_test "test_insert_row_table_exists" (fun () -> + create_tables (); + let values = [ "17"; "2022-12-12"; "4.5"; "null" ] in + let output = + with_redirected_stdout (fun () -> + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + values; + Sqaml.Database.delete_rows "test_table" (fun _ -> true)) + in + assert_equal ~printer:printer_wrapper "" output; + (* no longer showing insertion... *) + drop_tables ()) + +(** [test_insert_row_table_absent] is an OUnit test that checks that + [Sqaml.Database.insert_row] raises a custom Failure when the table does + not exist. *) +let test_insert_row_table_absent = + as_test "test_insert_row_table_absent" (fun () -> + drop_tables (); + let values = [ "12" ] in + let insert_absent_table () = + Sqaml.Database.insert_row "test_table" [ "example" ] values + in + assert_raises (Failure "Table does not exist") insert_absent_table) + +(** [test_create_table_already_exists] is an OUnit test that checks that + [Sqaml.Database.create_table] raises a custom Failure when the table + already exists. *) +let test_create_table_already_exists = + as_test "test_create_table_already_exists" (fun () -> + drop_tables (); + create_tables (); + let create_table () = + Sqaml.Database.create_table + [ + { + name = "example"; + col_type = Sqaml.Table.Int_type; + primary_key = true; + }; + ] + "test_table" + in + assert_raises (Failure "Table already exists") create_table; + drop_tables ()) + +(** [test_delete_rows_nonexistent_table] is an OUnit test that checks that + [Sqaml.Database.delete_rows] raises a custom Failure when the table does + not exist. *) +let test_delete_rows_nonexistent_table = + as_test "test_delete_rows_nonexistent_table" (fun () -> + drop_tables (); + let delete_rows () = + Sqaml.Database.delete_rows "nonexistent" (fun _ -> true) + in + assert_raises (Failure "Table does not exist") delete_rows) + +(** [test_update_rows_nonexistent_table] is an OUnit test that checks that + [Sqaml.Database.update_rows] raises a custom Failure when the table does + not exist. *) +let test_update_rows_nonexistent_table = + as_test "test_update_rows_nonexistent_table" (fun () -> + drop_tables (); + let update () = + Sqaml.Database.update_rows "example" + (fun _ -> true) + (fun _ -> { values = [ Int 1 ] }) + in + assert_raises (Failure "Table does not exist") update) + +(** [test_for_pk_value] tests for uniqueness of the primary key. *) +let test_for_pk_value = + as_test "test_for_pk_value" (fun () -> + create_tables (); + assert_equal () (* Expected value *) + (Sqaml.Database.check_pk_uniqueness "test_table" "example" (Int 0)); + drop_tables ()) + +(** [failed_test_for_pk_value] also tests for uniqueness of the primary key + but this time ensures that an error is raised when primary key duplication + is attempted. *) +let failed_test_for_pk_value = + as_test "test_for_failed_pk_value" (fun () -> + create_tables (); + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + [ "0"; "2022-12-12"; "4.5"; "null" ]; + assert_raises (Failure "Primary key already exists in the table.") + (fun () -> + Sqaml.Database.check_pk_uniqueness "test_table" "example" (Int 0)); + drop_tables ()) + +(** [failed_tableless_test_for_pk_value] also tests for uniqueness of the primary key + but this time ensures that an error is raised when one attempts to read from + a table that does not exist. *) +let failed_tableless_test_for_pk_value = + as_test "test_for_tableless_failed_pk_value" (fun () -> + assert_raises (Failure "Table does not exist") (fun () -> + Sqaml.Database.check_pk_uniqueness "no_exists" "example" (Int 0))) + +(** [test_get_pk_field] is an OUnit test that checks that [Sqaml.Database.get_pk_field] + returns the correct failure when the table does not exist. *) +let test_failed_get_pk_field = + as_test "test_failed_get_pk_field" (fun () -> + assert_raises (Failure "Table does not exist") (fun () -> + Sqaml.Database.get_pk_field "no_exists")) + +(** [test_get_table_columns] is an OUnit test that checks that + [Sqaml.Database.get_table_columns] returns the correct failure when the + table does not exist. *) +let test_failed_get_table_columns = + as_test "test_failed_get_table_columns" (fun () -> + assert_raises (Failure "Table does not exist") (fun () -> + Sqaml.Database.get_table_columns "no_exists" false)) + +(** [test_update_with_less_than] is an OUnit test that checks that + [Sqaml.Database.update_rows] correctly updates rows in a table + with a less-than predicate. *) +let test_update_with_less_than = + as_test "test_update_with_less_than" (fun () -> + create_tables (); + Sqaml.Database.insert_row "test_table" [ "example" ] [ "0" ]; + let output = + with_redirected_stdout (fun () -> + Sqaml.Database.update_rows "test_table" + (fun row -> row.values < [ Int 0 ]) + (fun _ -> { values = [ Int 1 ] }); + Sqaml.Database.print_table "test_table") + in + assert_equal ~printer:printer_wrapper + "example: int\nexample2: date\nexample3: float\nexample4: null\n0 \n" + output; + drop_tables ()) + +(** [test_normal_update_rows] is an OUnit test that checks that + [Sqaml.Database.update_rows] correctly updates rows in a table. *) +let test_normal_update_rows = + as_test "test_normal_update_rows" (fun () -> + create_tables (); + Sqaml.Database.insert_row "test_table" [ "example" ] [ "0" ]; + let output = + with_redirected_stdout (fun () -> + Sqaml.Database.update_rows "test_table" + (fun row -> row.values = [ Int 0 ]) + (fun _ -> { values = [ Int 1 ] }); + Sqaml.Database.select_all "test_table") + in + assert_equal ~printer:printer_wrapper "1 \n" output; + drop_tables ()) + +(** [test_missing_select_all_table] is an OUnit test that checks that + [Sqaml.Database.select_all] raises a custom Failure when the table does + not exist. *) +let test_missing_select_all_table = + as_test "test_missing_select_all_table" (fun () -> + drop_tables (); + let select_all () = Sqaml.Database.select_all "nonexistent" in + assert_raises (Failure "Table does not exist") select_all) + +(** [test_print_table] serves to see if [Sqaml.Database.print_table] prints + the correct representation of some table, with its type and data. *) +let test_print_table = + as_test "test_normal_update_rows" (fun () -> + create_tables (); + Sqaml.Database.insert_row "test_table" [ "example" ] [ "0" ]; + let output = + with_redirected_stdout (fun () -> + Sqaml.Database.update_rows "test_table" + (fun row -> row.values = [ Int 0 ]) + (fun _ -> { values = [ Int 1 ] }); + Sqaml.Database.print_table "test_table") + in + assert_equal ~printer:printer_wrapper + "example: int\nexample2: date\nexample3: float\nexample4: null\n1 \n" + output; + drop_tables ()) + +(** [test_print_nonexistent_table] is an OUnit test that checks that + [Sqaml.Database.print_table] raises a custom Failure when the table does + not exist. *) +let test_print_nonexistent_table = + as_test "test_print_nonexistent_table" (fun () -> + drop_tables (); + let print_table () = Sqaml.Database.print_table "nonexistent" in + assert_raises (Failure "Table does not exist") print_table) + +(** [test_select_rows] is an OUnit test that checks that [Sqaml.Database.select_rows] + returns the correct rows when the table exists. *) +let test_select_rows = + as_test "test_select_rows" (fun () -> + create_tables (); + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + [ "0"; "2022-12-12"; "4.5"; "null" ]; + assert_equal + (Sqaml.Database.select_rows "test_table" [ "example" ] + (fun _ -> true) + "") + (None, [ { values = [ Int 0 ] } ]); + drop_tables ()) + +(** [test_select_rows_nonexistent_table] is an OUnit test that checks that + [Sqaml.Database.select_rows] raises a custom Failure when the table does + not exist. *) +let test_select_rows_nonexistent_table = + as_test "test_select_rows_nonexistent_table" (fun () -> + drop_tables (); + let select_rows () = + Sqaml.Database.select_rows "nonexistent" [ "example" ] + (fun _ -> true) + "" + in + assert_raises (Failure "Table does not exist") select_rows) + +(** [test_print_value] is an OUnit test that checks that [Sqaml.Row.print_value] + prints the correct representation of a value, for all value types. *) +let test_print_value = + as_test "test_print_value" (fun () -> + let output = + with_redirected_stdout (fun () -> Sqaml.Row.print_value (Int 5)) + ^ with_redirected_stdout (fun () -> Sqaml.Row.print_value (Float 4.5)) + ^ with_redirected_stdout (fun () -> + Sqaml.Row.print_value (Varchar "hello")) + ^ with_redirected_stdout (fun () -> Sqaml.Row.print_value Null) + ^ with_redirected_stdout (fun () -> + Sqaml.Row.print_value (Date "2022-12-12")) + in + assert_equal ~printer:printer_wrapper "5\n4.5\nhello\nnull\n2022-12-12\n" + output) + +(** [test_value_equals] is an OUnit test that checks that [Sqaml.Row.value_equals] + returns the correct boolean value when comparing two values, for all supported value types. *) +let test_value_equals = + as_test "test_value_equals" (fun () -> + assert_equal (Sqaml.Row.value_equals (Int 1) (Int 1)) true; + assert_equal + (Sqaml.Row.value_equals (Varchar "test") (Varchar "test")) + true; + assert_equal (Sqaml.Row.value_equals (Float 1.0) (Float 1.0)) true; + assert_equal + (Sqaml.Row.value_equals (Date "2022-01-01") (Date "2022-01-01")) + true; + assert_equal (Sqaml.Row.value_equals (Int 1) (Int 2)) false; + assert_equal (Sqaml.Row.value_equals (Int 1) (Varchar "griffin")) false; + assert_equal (Sqaml.Row.value_not_equals (Int 1) (Int 2)) true) + +(** [test_value_less_than] is an OUnit test that checks that [Sqaml.Row.value_less_than] + returns the correct boolean value when comparing two values, for all supported value types. *) +let test_value_less_than = + as_test "test_value_less_than" (fun () -> + assert_equal (Sqaml.Row.value_less_than (Int 1) (Int 2)) true; + assert_equal (Sqaml.Row.value_less_than (Int 2) (Int 1)) false; + assert_equal (Sqaml.Row.value_less_than (Int 1) (Int 1)) false; + assert_equal (Sqaml.Row.value_less_than (Float 1.0) (Float 2.0)) true; + assert_equal (Sqaml.Row.value_less_than (Float 2.0) (Float 1.0)) false; + assert_equal (Sqaml.Row.value_less_than (Float 1.0) (Float 1.0)) false; + assert_equal + (Sqaml.Row.value_less_than (Varchar "2022-01-01") (Varchar "2022-01-02")) + true; + assert_equal + (Sqaml.Row.value_less_than (Date "2022-01-01") (Varchar "2022-01-02")) + false; + assert_equal + (Sqaml.Row.value_less_than (Date "2022-01-01") (Date "2022-01-02")) + true; + assert_equal + (Sqaml.Row.value_less_than (Date "2022-01-02") (Date "2022-01-01")) + false; + assert_equal + (Sqaml.Row.value_less_than (Date "2022-01-01") (Date "2022-01-01")) + false) + +(** [test_value_greater_than] is an OUnit test that checks that [Sqaml.Row.value_greater_than] + returns the correct boolean value when comparing two values, for all supported value types. *) +let test_value_greater_than = + as_test "test_value_greater_than" (fun () -> + assert_equal (Sqaml.Row.value_greater_than (Int 1) (Int 2)) false; + assert_equal (Sqaml.Row.value_greater_than (Int 2) (Int 1)) true; + assert_equal (Sqaml.Row.value_greater_than (Int 1) (Int 1)) false; + assert_equal (Sqaml.Row.value_greater_than (Float 1.0) (Float 2.0)) false; + assert_equal (Sqaml.Row.value_greater_than (Float 2.0) (Float 1.0)) true; + assert_equal (Sqaml.Row.value_greater_than (Float 1.0) (Float 1.0)) false; + assert_equal + (Sqaml.Row.value_greater_than (Varchar "2022-01-01") + (Varchar "2022-01-02")) + false; + assert_equal + (Sqaml.Row.value_greater_than (Date "2022-01-01") (Varchar "2022-01-02")) + false; + assert_equal + (Sqaml.Row.value_greater_than (Date "2022-01-01") (Date "2022-01-02")) + false; + assert_equal + (Sqaml.Row.value_greater_than (Date "2022-01-02") (Date "2022-01-01")) + true; + assert_equal + (Sqaml.Row.value_greater_than (Date "2022-01-01") (Date "2022-01-01")) + false) + +(** [test_tokenize_query] ensures that our tokenizer of strings can successfully + convert said strings into tokens for use by the parser. *) +let test_tokenize_query = + as_test "test_tokenize_query" (fun () -> + assert_equal + [ Sqaml.Parser.IntKeyword ] + (Sqaml.Parser.tokenize_query "INT"); + assert_equal + [ Sqaml.Parser.VarcharKeyword ] + (Sqaml.Parser.tokenize_query "VARCHAR"); + assert_equal + [ Sqaml.Parser.PrimaryKey ] + (Sqaml.Parser.tokenize_query "PRIMARY"); + assert_equal + [ Sqaml.Parser.PrimaryKey ] + (Sqaml.Parser.tokenize_query "KEY"); + assert_equal + [ Sqaml.Parser.Identifier "WHERE" ] + (Sqaml.Parser.tokenize_query "WHERE"); + assert_equal + [ Sqaml.Parser.Identifier "TABLE" ] + (Sqaml.Parser.tokenize_query "TABLE"); + assert_equal + [ Sqaml.Parser.Identifier "TABLES" ] + (Sqaml.Parser.tokenize_query "TABLES"); + assert_equal + [ Sqaml.Parser.Identifier "CREATE" ] + (Sqaml.Parser.tokenize_query "CREATE"); + assert_equal + [ Sqaml.Parser.Identifier "INSERT" ] + (Sqaml.Parser.tokenize_query "INSERT"); + assert_equal + [ Sqaml.Parser.Identifier "INTO" ] + (Sqaml.Parser.tokenize_query "INTO"); + assert_equal + [ Sqaml.Parser.Identifier "SELECT" ] + (Sqaml.Parser.tokenize_query "SELECT"); + assert_equal + [ Sqaml.Parser.Identifier "SHOW" ] + (Sqaml.Parser.tokenize_query "SHOW"); + assert_equal + [ Sqaml.Parser.Identifier "DROP" ] + (Sqaml.Parser.tokenize_query "DROP"); + assert_equal + [ Sqaml.Parser.Identifier "other" ] + (Sqaml.Parser.tokenize_query "other")) + +(** [test_print_tokenized] is an OUnit test that checks that + [Sqaml.Parser.print_tokenized] prints the correct representation of a list + of tokens. *) +let test_print_tokenized = + as_test "test_print_tokenized" (fun () -> + let output = + with_redirected_stdout (fun () -> + Sqaml.Parser.print_tokenized + [ + Sqaml.Parser.IntKeyword; + Sqaml.Parser.VarcharKeyword; + Sqaml.Parser.PrimaryKey; + Sqaml.Parser.Identifier "WHERE"; + ]) + in + assert_equal "IntKeyword\nVarcharKeyword\nPrimaryKey\nIdentifier: WHERE\n" + output) + +(** [test_create_table_tokens] is an OUnit test that checks that + [Sqaml.Parser.tokenize_query] correctly tokenizes a CREATE TABLE query. *) +let test_create_table_tokens = + as_test "test_create_table_tokens" (fun () -> + let tokens = + Sqaml.Parser.tokenize_query + "CREATE TABLE test_table (example INT PRIMARY KEY);" + in + assert_equal tokens + [ + Sqaml.Parser.Identifier "CREATE"; + Sqaml.Parser.Identifier "TABLE"; + Sqaml.Parser.Identifier "test_table"; + Sqaml.Parser.Identifier "(example"; + Sqaml.Parser.IntKeyword; + Sqaml.Parser.PrimaryKey; + Sqaml.Parser.Identifier "KEY);"; + ]) + +(** [test_compare_row] is an OUnit test that checks that [Sqaml.Table.compare_row] + returns the correct integer value when comparing two rows. *) +let test_compare_row = + as_test "test_compare_row" (fun () -> + let row1 : Sqaml.Row.row = { values = [ Int 1; Int 2; Int 3 ] } in + let row2 : Sqaml.Row.row = { values = [ Int 1; Int 3; Int 2 ] } in + assert_equal 0 (Sqaml.Table.compare_row 0 row1 row2); + assert_equal (-1) (Sqaml.Table.compare_row 1 row1 row2); + assert_equal 1 (Sqaml.Table.compare_row 2 row1 row2)) + +(** [test_to_list] confirms that string representations of rows for database storage are generated successfully. *) +let test_to_list = + as_test "test_to_list" (fun () -> + let row : Sqaml.Row.row = + { values = [ Int 1; Int 2; Varchar "Hello" ] } + in + assert_equal [ "1"; "2"; "Hello" ] (Sqaml.Row.to_list row)) + +(** [test_to_list_bad_type] confirms that string representations of rows for database storage throw an error if + a type is not matched. Currently supported types are only integers and varchars. *) +let test_to_list_fails = + as_test "test_to_list" (fun () -> + let row : Sqaml.Row.row = + { values = [ Int 1; Int 2; Varchar "Hello"; Date "2022-12-12" ] } + in + assert_raises (Failure "Bad type.") (fun () -> Sqaml.Row.to_list row)) + +(** [test_parse_and_execute_query] is a huge list of assertions that + verifies the functionality of 90+% of all possible SQL queries or failed inputs.*) +let test_parse_and_execute_query = + as_test "test_parse_and_execute_query" (fun () -> + drop_tables (); + let output_create = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR, age INT)") + in + assert_equal ~printer:printer_wrapper "" + (* also no longer showing here... *) output_create; + + assert_raises (Failure "Incorrect number of columns provided.") (fun () -> + Sqaml.Parser.parse_and_execute_query + "INSERT INTO users (id) VALUES (1)"); + + let output_create2 = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query + "CREATE TABLE another (auto PRIMARY KEY)") + in + assert_equal ~printer:printer_wrapper "" output_create2; + + Sqaml.Parser.parse_and_execute_query "DROP TABLE another"; + + let output_insert = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query + "INSERT INTO users (id, name, age) VALUES (1, \"Simon Ilincev\", \ + 25)") + in + assert_equal ~printer:printer_wrapper "" (String.trim output_insert); + + assert_raises + (Failure "Number of columns does not match number of values") (fun () -> + Sqaml.Parser.parse_and_execute_query + "INSERT INTO users (id, name, age) VALUES (4, \"Simon Ilincev\" \ + OK, 25)"); + + assert_raises (Failure "Improper columns or order provided for insert.") + (fun () -> + Sqaml.Parser.parse_and_execute_query + "INSERT INTO users (alpha, beta, gamma) VALUES (8, \"Simon \ + Ilincev\", 25)"); + + let output_show = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query "SHOW COLUMNS FROM users") + in + assert_equal ~printer:printer_wrapper + "id : Integer |name : Varchar |age : Integer |" + (String.trim output_show); + + let output_select = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query "SELECT * FROM users") + in + assert_equal ~printer:printer_wrapper "1 'Simon Ilincev' 25" + (String.trim output_select); + + assert_raises (Failure "No proper fields selected in query.") (fun () -> + Sqaml.Parser.parse_and_execute_query "SELECT FROM users"); + + let output_update = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query + "UPDATE users SET name = 'Clarkson' WHERE id = 1") + in + assert_equal ~printer:printer_wrapper "" output_update; + let output_select_updated = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query "SELECT * FROM users") + in + assert_equal ~printer:printer_wrapper "1 'Clarkson' 25" + (String.trim output_select_updated); + + let output_delete = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query + "DELETE FROM users WHERE id = 1 AND name = 'Clarkson'") + in + assert_equal ~printer:printer_wrapper "" output_delete; + let output_select_deleted = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query "SELECT * FROM users") + in + assert_equal ~printer:printer_wrapper "" output_select_deleted; + let output_delete_all = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query "DELETE FROM users") + in + assert_equal ~printer:printer_wrapper "" output_delete_all; + + let output_drop = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query "DROP TABLE users") + in + assert_equal ~printer:printer_wrapper "" output_drop; + let output_show = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query "SHOW TABLES") + in + assert_equal ~printer:printer_wrapper "No tables in database.\n" + output_show; + + create_tables (); + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + [ "0"; "2022-12-12"; "4.5"; "null" ]; + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + [ "1"; "2022-12-12"; "4.5"; "null" ]; + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + [ "2"; "2022-12-12"; "4.5"; "null" ]; + let output_order = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query + "SELECT example, example2, example3, example4 FROM test_table \ + ORDER BY example DESC") + in + assert_equal + "2 2022-12-12 4.500000 NULL \n\ + 1 2022-12-12 4.500000 NULL \n\ + 0 2022-12-12 4.500000 NULL \n" + output_order; + drop_tables (); + + create_tables (); + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + [ "0"; "2022-12-12"; "4.5"; "null" ]; + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + [ "1"; "2022-12-12"; "4.5"; "null" ]; + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + [ "2"; "2022-12-12"; "4.5"; "null" ]; + let output_order_limit_by = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query + "SELECT example, example2, example3, example4 FROM test_table \ + ORDER BY example DESC LIMIT 1") + in + assert_equal "2 2022-12-12 4.500000 NULL \n" output_order_limit_by; + drop_tables (); + + create_tables (); + Sqaml.Database.insert_row "test_table" + [ "example"; "example2"; "example3"; "example4" ] + [ "0"; "2022-12-12"; "4.5"; "null" ]; + Sqaml.Parser.parse_and_execute_query + "UPDATE test_table SET example = 1 WHERE example < 1"; + let output_update = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query "SELECT * FROM test_table") + in + assert_equal "1 2022-12-12 4.500000 NULL \n" output_update; + + Sqaml.Parser.parse_and_execute_query + "UPDATE test_table SET example = 0 WHERE example > 0"; + let output_update = + with_redirected_stdout (fun () -> + Sqaml.Parser.parse_and_execute_query "SELECT * FROM test_table") + in + assert_equal "0 2022-12-12 4.500000 NULL \n" output_update; + drop_tables (); + + let output_no_tables = + with_redirected_stdout (fun () -> + drop_tables (); + Sqaml.Database.show_all_tables ()) + in + (* TODO: investigate why this doesn't always return no tables in db. *) + assert_bool "No tables in database.\n" + (output_no_tables = "No tables in database.\n" + || String.length output_no_tables = 0); + + assert_raises (Failure "Syntax error in column definition") (fun () -> + Sqaml.Parser.parse_and_execute_query "INSERT INTO 12144"); + assert_raises (Failure "Table must have a primary key") (fun () -> + Sqaml.Parser.parse_and_execute_query "CREATE TABLE joker"); + assert_raises (Failure "Syntax error in column definition") (fun () -> + Sqaml.Parser.parse_and_execute_query "CREATE TABLE joker id"); + assert_raises (Failure "Unrecognized update transform clause format.") + (fun () -> + Sqaml.Parser.parse_and_execute_query + "UPDATE users SET name='GLORY' WHERE id=1"); + assert_raises (Failure "Table does not exist") (fun () -> + Sqaml.Parser.parse_and_execute_query "UPDATE users SET name = 'GLORY'"); + assert_raises (Failure "hd") (fun () -> + Sqaml.Parser.parse_and_execute_query "SELECT JOKE FROM"); + assert_raises (Failure "Unrecognized where clause format.") (fun () -> + Sqaml.Parser.parse_and_execute_query + "create table users (name primary key)"; + Sqaml.Parser.parse_and_execute_query "UPDATE users SET name = 1 WHERE"); + assert_raises (Failure "Syntax error in column definition") (fun () -> + Sqaml.Parser.parse_and_execute_query + "CREATE TABLE different (id INT PRIMARY KEY, name"); + Sqaml.Parser.parse_and_execute_query "DROP TABLE users"; + (* note missing query support for float, date, and null *) + assert_raises (Failure "Unsupported query") (fun () -> + Sqaml.Parser.parse_and_execute_query "GOID")) + +(** [suite] is the test suite for the SQamL module. *) +let suite = + "sqaml test suite" + >::: [ + (* test_show_all_tables_with_no_tables; (* see test_parse_and... *)*) + test_show_all_tables_with_some_tables; + test_get_column_type_column_present; + test_get_column_type_table_absent; + test_construct_transform_table_absent; + test_get_column_type_column_absent; + test_construct_transform_column_present; + test_construct_predicate_column_present; + test_construct_predicate_table_absent; + test_insert_row_table_exists; + test_insert_row_table_absent; + test_create_table_already_exists; + test_delete_rows_nonexistent_table; + test_update_rows_nonexistent_table; + test_update_with_less_than; + test_for_pk_value; + failed_test_for_pk_value; + failed_tableless_test_for_pk_value; + test_failed_get_pk_field; + test_failed_get_table_columns; + test_normal_update_rows; + test_missing_select_all_table; + test_print_table; + test_print_nonexistent_table; + test_select_rows; + test_select_rows_nonexistent_table; + test_print_value; + test_value_equals; + test_value_less_than; + test_value_greater_than; + test_tokenize_query; + test_print_tokenized; + test_create_table_tokens; + test_to_list; + test_to_list_fails; + test_parse_and_execute_query; + test_compare_row; + ] + +let () = run_test_tt_main suite