Skip to content

Commit

Permalink
Src: limit the number of calls per new users
Browse files Browse the repository at this point in the history
  • Loading branch information
lykimq committed Jan 30, 2024
1 parent 40f3d5f commit d706308
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 32 deletions.
37 changes: 22 additions & 15 deletions src/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,22 +321,24 @@ def check_calls_per_month(db, contract_id):
def create_max_calls_per_sponsee_condition(
db: Session, condition: schemas.CreateMaxCallsPerSponseeCondition
):
# If a condition still exists, do not create a new one
"""
Function to limit the number of calls for new users
rather than for a specific sponsee.
"""
existing_condition = (
db.query(models.Condition)
.filter(models.Condition.sponsee_address == condition.sponsee_address)
.filter(models.Condition.vault_id == condition.vault_id)
.filter(models.Condition.current < models.Condition.max)
.one_or_none()
)
if existing_condition is not None:
raise ConditionAlreadyExists(
"A condition with maximum calls per sponsee already exists and the maximum is not reached. Cannot create a new one."
"A condition with maximum calls per sponsee already exists for this vault. Cannot create a new one."
)
# Create a new condition for the max calls per sponsee
db_condition = models.Condition(
**{
"type": schemas.ConditionType.MAX_CALLS_PER_SPONSEE,
"sponsee_address": condition.sponsee_address,
"vault_id": condition.vault_id,
"max": condition.max,
"current": 0,
Expand All @@ -346,7 +348,6 @@ def create_max_calls_per_sponsee_condition(
db.commit()
db.refresh(db_condition)
return schemas.MaxCallsPerSponseeCondition(
sponsee_address=db_condition.sponsee_address,
vault_id=db_condition.vault_id,
max=db_condition.max,
current=db_condition.current,
Expand Down Expand Up @@ -396,12 +397,15 @@ def create_max_calls_per_entrypoint_condition(
id=db_condition.id,
)

# The functions retrieves the condition that sets the maximum limit
# on the number of calls for any address or new user associated with the
# specified vault_id


def check_max_calls_per_sponsee(db: Session, sponsee_address: str, vault_id: UUID4):
def check_max_calls_per_sponsee(db: Session, vault_id: UUID4):
return (
db.query(models.Condition)
.filter(models.Condition.type == schemas.ConditionType.MAX_CALLS_PER_SPONSEE)
.filter(models.Condition.sponsee_address == sponsee_address)
.filter(models.Condition.vault_id == vault_id)
.one_or_none()
)
Expand All @@ -419,23 +423,26 @@ def check_max_calls_per_entrypoint(
.one_or_none()
)

# New condition for limitting calls for new users


def check_conditions(db: Session, datas: schemas.CheckConditions):
print(datas)
sponsee_condition = check_max_calls_per_sponsee(
db, datas.sponsee_address, datas.vault_id
)

max_calls_condition = check_max_calls_per_sponsee
(db, datas.vault_id)

entrypoint_condition = check_max_calls_per_entrypoint(
db, datas.contract_id, datas.entrypoint_id, datas.vault_id
)

# No condition registered
if sponsee_condition is None and entrypoint_condition is None:
if max_calls_condition is None and entrypoint_condition is None:
return True
# One of condition is excedeed
if (
sponsee_condition is not None
and (sponsee_condition.current >= sponsee_condition.max)
max_calls_condition is not None
and (max_calls_condition.current_calls_per_user >= max_calls_condition.max_calls_per_user)
) or (
entrypoint_condition is not None
and (entrypoint_condition.current >= entrypoint_condition.max)
Expand All @@ -445,8 +452,8 @@ def check_conditions(db: Session, datas: schemas.CheckConditions):
# Update conditions
# TODO - Rewrite with list

if sponsee_condition:
update_condition(db, sponsee_condition)
if max_calls_condition:
update_condition(db, max_calls_condition) # type: ignore
if entrypoint_condition:
update_condition(db, entrypoint_condition)
return True
Expand Down
14 changes: 4 additions & 10 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,6 @@ class Condition(Base):

id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
type = Column(Enum(ConditionType))
sponsee_address = Column(
String,
CheckConstraint(
"(type = 'MAX_CALLS_PER_SPONSEE') = (sponsee_address IS NOT NULL)",
name="sponsee_address_not_null_constraint",
),
nullable=True,
)
contract_id = Column(
UUID(as_uuid=True),
ForeignKey("contracts.id"),
Expand All @@ -153,8 +145,10 @@ class Condition(Base):
)
vault_id = Column(UUID(as_uuid=True), ForeignKey(
"credits.id"), nullable=False)
max = Column(Integer, nullable=False)
current = Column(Integer, nullable=False)
# New field to store the maximum allowed calls per user
max_calls_per_user = Column(Integer, nullable=False)
# New field to track current calls per user
current_calls_per_user = Column(Integer, nullable=False, default=0)
created_at = Column(
DateTime(timezone=True), default=datetime.datetime.utcnow(), nullable=False
)
Expand Down
21 changes: 17 additions & 4 deletions src/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,11 @@ async def get_entrypoint(
# Operations
@router.post("/operation")
async def post_operation(
# call_data: contains information about the operations to be performed
call_data: schemas.UnsignedCall, db: Session = Depends(database.get_db)
):
# It first checks if the list of operations in call_data is empty.
# If it is, raises exception
if len(call_data.operations) == 0:
logging.warning(f"Operations list is empty")
raise HTTPException(
Expand All @@ -292,6 +295,8 @@ async def post_operation(
detail=f"Target {contract_address} is not allowed",
)
try:
# try to retrieve the contract associated with the destination address
# of each operation from the database
contract = crud.get_contract_by_address(db, contract_address)
except ContractNotFound:
logging.warning(f"{contract_address} is not found")
Expand All @@ -303,15 +308,19 @@ async def post_operation(
entrypoint_name = operation["parameters"]["entrypoint"]

try:
# check if the specified entrypoint for each operation exists
# and is enabled in the contract
entrypoint = crud.get_entrypoint(
db, str(contract.address), entrypoint_name)
if not entrypoint.is_enabled:
raise EntrypointDisabled()

# Check if certain conditions are met for each operation,
# if not, raises corresponding exceptions
if not crud.check_conditions(
db,
schemas.CheckConditions(
sponsee_address=call_data.sender_address,
sponsee_address=None, # Remove sponsee address
contract_id=contract.id,
entrypoint_id=entrypoint.id,
vault_id=contract.credit_id,
Expand All @@ -338,7 +347,7 @@ async def post_operation(
)

try:
# Simulate the operation alone without sending it
# Simulate the transaction to estimate fees without actually sending it
# TODO: log the result
op = tezos.simulate_transaction(call_data.operations)

Expand All @@ -350,16 +359,20 @@ async def post_operation(

logging.debug(f"Estimated fees: {estimated_fees}")

# Check if there are enough funds to pay the estimated fees.
if not tezos.check_credits(db, estimated_fees):
logging.warning(f"Not enough funds to pay estimated fees.")
raise NotEnoughFunds(
f"Estimated fees : {estimated_fees[str(contract.address)]} mutez"
)
# Check if there have been too many calls made for the contract in the current month
if not crud.check_calls_per_month(db, contract.id): # type: ignore
logging.warning(
f"Too many calls made for this contract this month.")
raise TooManyCallsForThisMonth()

# If everything succeeds, it queues the operation for execution, records the
# operation in the database, and returns the result
result = await tezos.tezos_manager.queue_operation(call_data.sender_address, op)

crud.create_operation(
Expand All @@ -369,6 +382,7 @@ async def post_operation(
user_address=call_data.sender_address, contract_id=str(contract.id), entrypoint_id=str(entrypoint.id), hash=result["transaction_hash"], status=result["result"]
),
)
# Handle exceptions
except MichelsonError as e:
print("Received failing operation, discarding")
logging.error(f"Invalid operation {e}")
Expand Down Expand Up @@ -453,12 +467,11 @@ async def create_condition(
)
elif (
body.type == ConditionType.MAX_CALLS_PER_SPONSEE
and body.sponsee_address is not None
):
# Adjusted to create a condition for maximum calls per user
return crud.create_max_calls_per_sponsee_condition(
db,
schemas.CreateMaxCallsPerSponseeCondition(
sponsee_address=body.sponsee_address,
vault_id=body.vault_id,
max=body.max,
),
Expand Down
7 changes: 4 additions & 3 deletions src/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,12 @@ class CreateMaxCallsPerEntrypointCondition(BaseModel):


class CreateMaxCallsPerSponseeCondition(BaseModel):
sponsee_address: str
vault_id: UUID4
max: int


class CheckConditions(BaseModel):
sponsee_address: str
sponsee_address: Optional[str]
contract_id: UUID4
entrypoint_id: UUID4
vault_id: UUID4
Expand All @@ -179,6 +178,8 @@ class MaxCallsPerEntrypointCondition(ConditionBase):
contract_id: UUID4
entrypoint_id: UUID4

# Remove sponsee_address, because it does not needed


class MaxCallsPerSponseeCondition(ConditionBase):
sponsee_address: str
pass

0 comments on commit d706308

Please sign in to comment.