Skip to content

Commit

Permalink
Update the FAB management tool for next quarter
Browse files Browse the repository at this point in the history
  • Loading branch information
elisescu committed Dec 18, 2024
1 parent 6e3410c commit 68bfdbd
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 35 deletions.
2 changes: 1 addition & 1 deletion fab_management/templates/fab_management.html
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ <h2 class="text-xl my-2">Private comments </h2>
<div>
<h2 class="text-xl mt-5 mb-2">Questions submission</h2>
<label for="doc_id">Google doc/spreadsheet ID</label>
<input type="text" id="doc_id" name="doc_id" value="16CvRkYMv7tq70SSFcz03cNMS1a5bmcbtmNFnfe-S3Zo">
<input type="text" id="doc_id" name="doc_id" value="18SvH1FjOsY3ZBm9S90Q4kIvYxIsFAUs6MVJaWnCCiSc">
</div>

<div>
Expand Down
223 changes: 189 additions & 34 deletions fab_management/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from questions.models import Question
from posts.models import Post
from posts.tasks import run_post_indexing
from posts.services.common import trigger_update_post_translations

logger = logging.getLogger(__name__)

Expand All @@ -45,33 +46,152 @@ def convert_to_timestamp(date, hour, minute):
datetime.strptime(date, "%m/%d/%y").replace(hour=hour, minute=minute)
)

# Convert the EST datetime to UTC
return aware_date_est.astimezone(utc_tz)


question_columns = [
# touples with: column header, question field name, type/conversion function
("title", "title", "string"),
("description", "description", "string"),
("resolution_criteria", "resolution_criteria", "string"),
("fine_print", "fine_print", "string"),
("publish_time", "open_time", lambda val: convert_to_timestamp(val, 10, 30)),
def get_number_or_none(row_values, field_name):
value_str = row_values["range_max"]
try:
return float(value_str)
except ValueError:
return None


def get_string_value(row_value, field_name):
return row_value[field_name]


def get_timestamp(row_values, field_name, default_time="00:00:00"):
value_str = row_values.get(field_name, "")
if not value_str:
raise ValueError(f"Missing date in field '{field_name}'")

if len(value_str.split()) == 1:
value_str += f" {default_time}"

formats = ["%m/%d/%Y %H:%M:%S", "%m/%d/%Y %H:%M", "%m/%d/%Y"]
naive_datetime = None

for fmt in formats:
try:
naive_datetime = datetime.strptime(value_str, fmt)
break
except ValueError:
continue

if not naive_datetime:
raise ValueError(f"Invalid date format for field '{field_name}': {value_str}")

timezone = pytz.timezone("UTC")
aware_datetime = timezone.localize(naive_datetime)
return aware_datetime


def get_type(row_values):
value_str = row_values["type"]
if value_str not in ["binary", "numeric", "multiple_choice"]:
raise ValueError("Unknown value for the question type")
return value_str


def get_question_weight(row_values):
value_str = row_values["question_weight"]
if not value_str:
return 1.0
value = float(value_str)
if value <= 1 and value > 0:
return value
raise ValueError("Invalid value for the question_weight field")


def get_range_max(row_values):
value_str = row_values["range_max"]
value = float(value_str)
return value


def get_range_min(row_values):
value_str = row_values["range_min"]
value = float(value_str)
return value


def get_zero_point(row_values):
value_str = row_values["zero_point"]
if not value_str:
return None
value = float(value_str)
return value


def get_open_lower_bound(row_values):
return row_values["open_lower_bound"].lower() == "true"


def get_open_upper_bound(row_values):
return row_values["open_lower_bound"].lower() == "true"


def get_group_variable(row_values):
return row_values["group_variable"]


def get_options(row_values):
return row_values["options"].split("|")


def get_author(row_values):
username = row_values["author"]
try:
return User.objects.get(username=username)
except User.DoesNotExist:
raise ValueError(f"Author username is not valud '{username}'")


common_fields = [
("title", lambda row_values: get_string_value(row_values, "title")),
("type", get_type),
(
"resolution_criteria",
lambda row_values: get_string_value(row_values, "resolution_criteria"),
),
("fine_print", lambda row_values: get_string_value(row_values, "fine_print")),
("description", lambda row_values: get_string_value(row_values, "description")),
("question_weight", get_question_weight),
("open_time", lambda row_values: get_timestamp(row_values, "open_time")),
(
"close_time",
"scheduled_close_time",
lambda val: convert_to_timestamp(val, 10, 30),
lambda row_values: get_timestamp(row_values, "scheduled_close_time"),
),
(
"resolve_time",
"scheduled_resolve_time",
lambda val: convert_to_timestamp(val, 10, 30),
lambda row_values: get_timestamp(row_values, "scheduled_resolve_time"),
),
]

# post_columsn = [
# ("publish_time", "published_at", lambda val: convert_to_timestamp(val, 10, 30)),
# ("author", "author", lambda username: User.objects.get(username=username)),
# ]
numeric_q_fields = [
("range_min", get_range_min),
("range_max", get_range_max),
("zero_point", get_zero_point),
("open_lower_bound", get_open_lower_bound),
("open_upper_bound", get_open_upper_bound),
]

multiple_choice_q_fields = [
(
"group_variable",
lambda row_values: get_string_value(row_values, "group_variable"),
),
("options", get_options),
]

other_fields = [
("parent_url", lambda row_values: get_string_value(row_values, "parent_url")),
("author", lambda row_values: get_string_value(row_values, "author")),
]


all_fields = common_fields + numeric_q_fields + multiple_choice_q_fields + other_fields


def submit_questions(
Expand All @@ -80,59 +200,93 @@ def submit_questions(
worksheet = get_worksheet(doc_id, sheet_name)
tournament = Project.objects.get(pk=tournament_id)
messages: list[tuple[str, str]] = []
field_name = None

rollback = dry_run

def log_error(msg: str):
return messages.append(("error", msg))

def log_info(msg: str):
return messages.append(("info", msg))

try:
start_str, end_str = rows_range.split(":")
start, end = int(start_str), int(end_str)
if start > end:
raise ValueError()
# [int(l) for l in rows_range.split(":")]
except Exception:
log_error(
f"Invalid value for rows range: {rows_range}. Should be in the format start:end"
)
return messages

log_info(
f"Importing these questions to tournament [{tournament.id}/{tournament.name}], from sheet [{worksheet.title}]-> {rows_range}: "
)
top_columns = worksheet.get("A1:T1")[0]
top_columns = worksheet.get("A2:T2")[0]

def get_val(row, col_name):
def get_raw_val(row, col_name):
col_idx = top_columns.index(col_name)
# row = worksheet.get(f"A{row_idx}:T{row_idx}", maintain_size=True)[0]
val = row[col_idx]
return val

col_name = None
def get_values_dict(row):
values_dict = {}
for field_name, _ in all_fields:
idx = top_columns.index(field_name)
values_dict[field_name] = row[idx]
return values_dict

with transaction.atomic():
for row, row_idx in rows_iterator(worksheet, rows_range):
question_data = {"type": "binary"}
row_values = get_values_dict(row)
question_data = {}
try:
parent_url = get_val(row, "parent_url")
parent_url = get_raw_val(row, "parent_url")
parent_post = get_parent_post(parent_url) if parent_url else None

col_name = None
for col_name, question_field, format_fn in question_columns:
val = get_val(row, col_name)
type = get_type(row_values)
question_fields = common_fields
author = get_author(row_values)

if type == "numeric":
question_fields += numeric_q_fields

if type == "multiple_choice":
question_fields += multiple_choice_q_fields

for field_name, field_get_value_fn in question_fields:
val = get_raw_val(row, field_name)

if val == ".p":
if parent_post is None:
log_error(
f"Error on row {row_idx}, col '{col_name}': question has no parent, but '.p' was used for field"
f"Error on row {row_idx}, col '{field_name}': question has no parent, but '.p' was used for field"
)
continue
val = getattr(parent_post.question, question_field)
elif callable(format_fn):
val = format_fn(val)
question_data[question_field] = val
val = getattr(parent_post.question, field_name)
elif callable(field_get_value_fn):
val = field_get_value_fn(row_values)

question_data[field_name] = val
except Exception as e:
log_error(
f"Error on row {row_idx}{', col ' + col_name if col_name else ''}: {str(e)}"
f"Error on row {row_idx}{', col ' + field_name if field_name else ''}: {str(e)}"
)
continue
rollback = True
break

question = Question(**question_data)

question.cp_reveal_time = question.scheduled_close_time
question.include_bots_in_aggregates = True
question.save()
post = Post(
title=question.title,
author=User.objects.get(username=get_val(row, "author")),
author=author,
published_at=question.open_time,
open_time=question.open_time,
created_at=question.created_at,
Expand All @@ -144,10 +298,11 @@ def get_val(row, col_name):
)
post.save()
run_post_indexing.send(post.id)
trigger_update_post_translations(post)
log_info(f" - added question [{question.title}] to {tournament.name}")
if dry_run:
if rollback:
transaction.set_rollback(True)
log_info("****UNDO all actions, dry-run mode was ON****")
log_info("****UNDO all actions, nothing was saved to the DB****")
return messages


Expand Down

0 comments on commit 68bfdbd

Please sign in to comment.