Skip to content

Commit

Permalink
Formatting: formatting, modifying CheckConstraint for better readability
Browse files Browse the repository at this point in the history
  • Loading branch information
lykimq committed Jan 30, 2024
1 parent 4b12ba0 commit 40f3d5f
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 36 deletions.
3 changes: 2 additions & 1 deletion alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def run_migrations_online() -> None:
)

with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
context.configure(connection=connection,
target_metadata=target_metadata)

with context.begin_transaction():
context.run_migrations()
Expand Down
12 changes: 7 additions & 5 deletions demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pytezos.crypto.key import blake2b_32
import pytezos


def tezos_hex(s):
return f"0x{bytes(s, 'utf-8').hex()}"

Expand Down Expand Up @@ -205,7 +206,8 @@ def generate_permit(self, key, qty=10):
}])

matcher = MichelsonType.match(
self.nft_contract.entrypoints["transfer"].as_micheline_expr()["args"][0]
self.nft_contract.entrypoints["transfer"].as_micheline_expr()[
"args"][0]
)
micheline_encoded = matcher.from_micheline_value(
transfer.parameters["value"][0]["args"]
Expand Down Expand Up @@ -245,8 +247,8 @@ def generate_permit(self, key, qty=10):
permit_hash = matcher2.from_micheline_value(permit_hashed_args).pack()
permit_signature = key.sign(permit_hash)
permit_op = self.nft_contract.permit([(
key.public_key(),
permit_signature,
transfer_hash
)]).as_transaction()
key.public_key(),
permit_signature,
transfer_hash
)]).as_transaction()
return permit_op
3 changes: 2 additions & 1 deletion src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@


assert TEZOS_RPC is not None, "Please specify a TEZOS_RPC"
assert SECRET_KEY is not None and len(SECRET_KEY) > 0, "Could not read secret key"
assert SECRET_KEY is not None and len(
SECRET_KEY) > 0, "Could not read secret key"


# -- LOGGING --
Expand Down
24 changes: 16 additions & 8 deletions src/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def get_contracts_by_credit(db: Session, credit_id: str):
Return a list of models.Contracts or raise UserNotFound exception
"""
return (
db.query(models.Contract).filter(models.Contract.credit_id == credit_id).all()
db.query(models.Contract).filter(
models.Contract.credit_id == credit_id).all()
)


Expand All @@ -67,7 +68,8 @@ def get_contract_by_address(db: Session, address: str):
"""
try:
return (
db.query(models.Contract).filter(models.Contract.address == address).one()
db.query(models.Contract).filter(
models.Contract.address == address).one()
)
except NoResultFound as e:
raise ContractNotFound() from e
Expand All @@ -77,7 +79,8 @@ def get_contract(db: Session, contract_id: str):
"""
Return a models.Contract or raise ContractNotFound exception
"""
db_contract: Optional[models.Contract] = db.query(models.Contract).get(contract_id)
db_contract: Optional[models.Contract] = db.query(
models.Contract).get(contract_id)
if db_contract is None:
raise ContractNotFound()
return db_contract
Expand Down Expand Up @@ -111,9 +114,11 @@ def create_contract(db: Session, contract: schemas.ContractCreation):
# TODO rewrite this with transaction or something else better
try:
contract = get_contract_by_address(db, contract.address)
raise ContractAlreadyRegistered(f"Contract {contract.address} already added.")
raise ContractAlreadyRegistered(
f"Contract {contract.address} already added.")
except ContractNotFound:
c = {k: v for k, v in contract.model_dump().items() if k not in ["entrypoints"]}
c = {k: v for k, v in contract.model_dump().items() if k not in [
"entrypoints"]}
db_contract = models.Contract(**c)
db.add(db_contract)
db.commit()
Expand Down Expand Up @@ -145,7 +150,8 @@ def get_user_credits(db: Session, user_id: str):
"""
Get credits from a user.
"""
db_credits = db.query(models.Credit).filter(models.Credit.owner_id == user_id).all()
db_credits = db.query(models.Credit).filter(
models.Credit.owner_id == user_id).all()
return db_credits


Expand Down Expand Up @@ -307,7 +313,8 @@ def check_calls_per_month(db, contract_id):
# If max_calls is -1 means condition is disabled (NO LIMIT)
if max_calls == -1: # type: ignore
return True
nb_operations_already_made = get_operations_by_contracts_per_month(db, contract_id)
nb_operations_already_made = get_operations_by_contracts_per_month(
db, contract_id)
return max_calls >= len(nb_operations_already_made)


Expand Down Expand Up @@ -453,5 +460,6 @@ def update_condition(db: Session, condition: models.Condition):

def get_conditions_by_vault(db: Session, vault_id: str):
return (
db.query(models.Condition).filter(models.Condition.vault_id == vault_id).all()
db.query(models.Condition).filter(
models.Condition.vault_id == vault_id).all()
)
26 changes: 15 additions & 11 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,16 @@ class Operation(Base):
entrypoint_id = Column(UUID(as_uuid=True), ForeignKey("entrypoints.id"))
hash = Column(String)
status = Column(String) # TODO Enum
created_at = Column(DateTime(timezone=True), default=datetime.datetime.utcnow())
created_at = Column(DateTime(timezone=True),
default=datetime.datetime.utcnow())

contract = relationship("Contract", back_populates="operations")
entrypoint = relationship("Entrypoint", back_populates="operations")


# ------- CONDITIONS ------- #

# Condition class with the requirement to limit the number
# of calls for new users (sponsees)

class Condition(Base):
__tablename__ = "conditions"
Expand All @@ -141,29 +143,31 @@ class Condition(Base):
)
contract_id = Column(
UUID(as_uuid=True),
CheckConstraint(
"(type = 'MAX_CALLS_PER_ENTRYPOINT') = (contract_id IS NOT NULL)",
name="contract_id_not_null_constraint",
),
ForeignKey("contracts.id"),
nullable=True,
)
entrypoint_id = Column(
UUID(as_uuid=True),
CheckConstraint(
"(type = 'MAX_CALLS_PER_ENTRYPOINT') = (entrypoint_id IS NOT NULL)",
name="entrypoint_id_not_null_constraint",
),
ForeignKey("entrypoints.id"),
nullable=True,
)
vault_id = Column(UUID(as_uuid=True), ForeignKey("credits.id"), nullable=False)
vault_id = Column(UUID(as_uuid=True), ForeignKey(
"credits.id"), nullable=False)
max = Column(Integer, nullable=False)
current = Column(Integer, nullable=False)
created_at = Column(
DateTime(timezone=True), default=datetime.datetime.utcnow(), nullable=False
)

# Add a CheckConstraint to ensure consistency between type and associated fields
__table_args__ = (
CheckConstraint(
"(type = 'MAX_CALLS_PER_ENTRYPOINT' AND entrypoint_id IS NOT NULL) \
OR (type = 'MAX_CALLS_PER_ENTRYPOINT') AND contract_id IS NOT NULL)",
name="type_to_foreign_key_constraint",
),
)

contract = relationship("Contract", back_populates="conditions")
entrypoint = relationship("Entrypoint", back_populates="conditions")
vault = relationship("Credit", back_populates="conditions")
24 changes: 16 additions & 8 deletions src/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ async def update_credits(
amount = credits.amount
is_confirmed = await tezos.confirm_deposit(op_hash, payer_address, amount)
if not is_confirmed:
logging.warning(f"Could not find confirmation for {amount} with {op_hash}")
logging.warning(
f"Could not find confirmation for {amount} with {op_hash}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Could not find confirmation for {amount} with {op_hash}",
Expand Down Expand Up @@ -114,15 +115,17 @@ async def withdraw_credits(
detail=f"Credit not found.",
)
if credits.amount < withdraw.amount:
logging.warning(f"Not enough funds to withdraw credit ID {withdraw.id}.")
logging.warning(
f"Not enough funds to withdraw credit ID {withdraw.id}.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Not enough funds to withdraw.",
)

expected_counter = credits.owner.withdraw_counter or 0
if expected_counter != withdraw.withdraw_counter:
logging.warning(f"Withdraw counter provided is not the expected counter.")
logging.warning(
f"Withdraw counter provided is not the expected counter.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Bad withdraw counter."
)
Expand All @@ -145,7 +148,8 @@ async def withdraw_credits(
)
result = await tezos.withdraw(tezos.tezos_manager, owner_address, withdraw.amount)
if result["result"] == "ok":
logging.debug(f"Start to confirm withdraw for {result['transaction_hash']}")
logging.debug(
f"Start to confirm withdraw for {result['transaction_hash']}")
# Starts a independent loop to check that the operation
# has been confirmed
asyncio.create_task(
Expand Down Expand Up @@ -299,7 +303,8 @@ async def post_operation(
entrypoint_name = operation["parameters"]["entrypoint"]

try:
entrypoint = crud.get_entrypoint(db, str(contract.address), entrypoint_name)
entrypoint = crud.get_entrypoint(
db, str(contract.address), entrypoint_name)
if not entrypoint.is_enabled:
raise EntrypointDisabled()

Expand Down Expand Up @@ -339,7 +344,8 @@ async def post_operation(

logging.debug(f"Result of operation simulation : {op}")

op_estimated_fees = [(int(x["fee"]), x["destination"]) for x in op.contents]
op_estimated_fees = [(int(x["fee"]), x["destination"])
for x in op.contents]
estimated_fees = tezos.group_fees(op_estimated_fees)

logging.debug(f"Estimated fees: {estimated_fees}")
Expand All @@ -350,15 +356,17 @@ async def post_operation(
f"Estimated fees : {estimated_fees[str(contract.address)]} mutez"
)
if not crud.check_calls_per_month(db, contract.id): # type: ignore
logging.warning(f"Too many calls made for this contract this month.")
logging.warning(
f"Too many calls made for this contract this month.")
raise TooManyCallsForThisMonth()

result = await tezos.tezos_manager.queue_operation(call_data.sender_address, op)

crud.create_operation(
db,
schemas.CreateOperation(
user_address=call_data.sender_address, contract_id=str(contract.id), entrypoint_id=str(entrypoint.id), hash=result["transaction_hash"], status=result["result"] # type: ignore
# type: ignore
user_address=call_data.sender_address, contract_id=str(contract.id), entrypoint_id=str(entrypoint.id), hash=result["transaction_hash"], status=result["result"]
),
)
except MichelsonError as e:
Expand Down
6 changes: 4 additions & 2 deletions src/tezos.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def find_fees(global_tx, payer_key):
(int(z["change"]), x["destination"])
for x in op_result
for y in (
x["metadata"].get("operation_result", {}).get("balance_updates", {}),
x["metadata"].get("operation_result", {}).get(
"balance_updates", {}),
x["metadata"].get("balance_updates", {}),
)
for z in y
Expand Down Expand Up @@ -232,7 +233,8 @@ async def main_loop(self):
log.info(f"{n_ops} operations to process and send")
# Post all the correct operations together and get the
# result from the RPC to know what the real fees were
posted_tx = ptz.bulk(*acceptable_operations.values()).send()
posted_tx = ptz.bulk(
*acceptable_operations.values()).send()
for i, k in enumerate(acceptable_operations):
assert self.results[k] != "failing"
self.results[k] = {"transaction": posted_tx}
Expand Down

0 comments on commit 40f3d5f

Please sign in to comment.