diff --git a/flag/flag.c b/flag/flag.c index 41ff6cd..cd16497 100644 --- a/flag/flag.c +++ b/flag/flag.c @@ -1,7 +1,7 @@ +#include "flag.h" #include "../bstr/bstr.h" #include "../vec/vector.h" #include -#include "flag.h" #include #include #include @@ -109,9 +109,7 @@ static bool arg_is_flag(bstr *arg) { return true; } -bool flag_parsed(flag_Parser *p) { - return p->parsed; -} +bool flag_parsed(flag_Parser *p) { return p->parsed; } void flag_parse(flag_Parser *p) { for (int curr_arg_idx = 0; curr_arg_idx < p->argc; curr_arg_idx++) { @@ -121,39 +119,70 @@ void flag_parse(flag_Parser *p) { continue; } + bstr flag_name, flag_value; + size_t sep_idx = 0; + if (bstr_index(&sep_idx, arg, bstr_new("=")) || + bstr_index(&sep_idx, arg, bstr_new(":"))) { + flag_value = bstr_new(arg.cstr + sep_idx + 1); + flag_name = arg; + flag_name.len -= flag_value.len + 1; + } else { + flag_name = arg; + flag_value = bstr_new(""); + } + struct flag_flag *curr_flag = NULL; for (size_t j = 0; j < vec_len(p->flags); j++) { vec_get(p->flags, j, &curr_flag); - if (bstr_equal(arg, curr_flag->name)) { - if (++curr_arg_idx >= p->argc) { - break; - } + if (bstr_equal(flag_name, curr_flag->name)) { + char *flag_val = NULL; + + if (flag_value.len == 0 && curr_flag->type != FLAG_BOOL) { + if (++curr_arg_idx >= p->argc) { + break; + } - char *flag_val = p->argv[curr_arg_idx]; + flag_val = p->argv[curr_arg_idx]; + } else { + flag_val = flag_value.cstr; + } switch (curr_flag->type) { case FLAG_BOOL: - curr_flag->bool_flag = true; + if (flag_value.len == 0) { + curr_flag->bool_flag = !curr_flag->bool_flag; + } else if (bstr_equal(flag_value, bstr_new("true"))) { + curr_flag->bool_flag = true; + } else if (bstr_equal(flag_value, bstr_new("false"))) { + curr_flag->bool_flag = false; + } break; + case FLAG_STR: curr_flag->str_flag = flag_val; break; + case FLAG_LONG: curr_flag->long_flag = strtol(flag_val, NULL, 10); break; + case FLAG_ULONG: curr_flag->ulong_flag = strtoul(flag_val, NULL, 10); break; + case FLAG_LONGLONG: curr_flag->longlong_flag = strtoll(flag_val, NULL, 10); break; + case FLAG_ULONGLONG: curr_flag->ulonglong_flag = strtoull(flag_val, NULL, 10); break; + case FLAG_DOUBLE: curr_flag->double_flag = strtod(flag_val, NULL); break; + case FLAG_FLOAT: curr_flag->float_flag = strtof(flag_val, NULL); break; diff --git a/flag/maintest.c b/flag/maintest.c index 7f4d75a..a8c6368 100644 --- a/flag/maintest.c +++ b/flag/maintest.c @@ -4,7 +4,7 @@ #include int main(void) { - char *argv[] = {"--funnyno", "69420", "-pi", "3.14159265359", "--hello", "world"}; + char *argv[] = {"--test=77", "--funnyno", "69420", "--is-true", "-pi", "3.14159265359", "--hello:world"}; int argc = sizeof(argv) / sizeof(argv[0]); flag_Parser p = flag_new(argc, argv); @@ -12,7 +12,9 @@ int main(void) { long *funny_num = flag_long(&p, "funnyno", 0, "the funny number"); char **hello = flag_str(&p, "hello", "test", "says hello"); double *pi = flag_double(&p, "pi", 0, "the number for pi"); - assert(flag_nflags(&p) == 3); + long *test = flag_long(&p, "test", 99, "get number 99 or something else"); + bool *is_true = flag_bool(&p, "is-true", false, "returns true when used"); + assert(flag_nflags(&p) == 5); assert(flag_parsed(&p) == false); flag_parse(&p); @@ -21,6 +23,8 @@ int main(void) { assert(*funny_num == 69420); assert(strcmp(*hello, "world") == 0); assert(*pi == 3.14159265359); + assert(*test == 77); + assert(*is_true); flag_free(&p); puts("test passed.");