diff --git a/alembic/env.py b/alembic/env.py index c5a09b9..b79dbe5 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -79,7 +79,8 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) + context.configure(connection=connection, + target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/demo/demo.py b/demo/demo.py index 87dd1e5..87a7af7 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -7,6 +7,7 @@ from pytezos.crypto.key import blake2b_32 import pytezos + def tezos_hex(s): return f"0x{bytes(s, 'utf-8').hex()}" @@ -205,7 +206,8 @@ def generate_permit(self, key, qty=10): }]) matcher = MichelsonType.match( - self.nft_contract.entrypoints["transfer"].as_micheline_expr()["args"][0] + self.nft_contract.entrypoints["transfer"].as_micheline_expr()[ + "args"][0] ) micheline_encoded = matcher.from_micheline_value( transfer.parameters["value"][0]["args"] @@ -245,8 +247,8 @@ def generate_permit(self, key, qty=10): permit_hash = matcher2.from_micheline_value(permit_hashed_args).pack() permit_signature = key.sign(permit_hash) permit_op = self.nft_contract.permit([( - key.public_key(), - permit_signature, - transfer_hash - )]).as_transaction() + key.public_key(), + permit_signature, + transfer_hash + )]).as_transaction() return permit_op diff --git a/src/config.py b/src/config.py index 5cccec1..1d5acbb 100644 --- a/src/config.py +++ b/src/config.py @@ -18,7 +18,8 @@ assert TEZOS_RPC is not None, "Please specify a TEZOS_RPC" -assert SECRET_KEY is not None and len(SECRET_KEY) > 0, "Could not read secret key" +assert SECRET_KEY is not None and len( + SECRET_KEY) > 0, "Could not read secret key" # -- LOGGING -- diff --git a/src/crud.py b/src/crud.py index 3b7d98c..d9d3cda 100644 --- a/src/crud.py +++ b/src/crud.py @@ -57,7 +57,8 @@ def get_contracts_by_credit(db: Session, credit_id: str): Return a list of models.Contracts or raise UserNotFound exception """ return ( - db.query(models.Contract).filter(models.Contract.credit_id == credit_id).all() + db.query(models.Contract).filter( + models.Contract.credit_id == credit_id).all() ) @@ -67,7 +68,8 @@ def get_contract_by_address(db: Session, address: str): """ try: return ( - db.query(models.Contract).filter(models.Contract.address == address).one() + db.query(models.Contract).filter( + models.Contract.address == address).one() ) except NoResultFound as e: raise ContractNotFound() from e @@ -77,7 +79,8 @@ def get_contract(db: Session, contract_id: str): """ Return a models.Contract or raise ContractNotFound exception """ - db_contract: Optional[models.Contract] = db.query(models.Contract).get(contract_id) + db_contract: Optional[models.Contract] = db.query( + models.Contract).get(contract_id) if db_contract is None: raise ContractNotFound() return db_contract @@ -111,9 +114,11 @@ def create_contract(db: Session, contract: schemas.ContractCreation): # TODO rewrite this with transaction or something else better try: contract = get_contract_by_address(db, contract.address) - raise ContractAlreadyRegistered(f"Contract {contract.address} already added.") + raise ContractAlreadyRegistered( + f"Contract {contract.address} already added.") except ContractNotFound: - c = {k: v for k, v in contract.model_dump().items() if k not in ["entrypoints"]} + c = {k: v for k, v in contract.model_dump().items() if k not in [ + "entrypoints"]} db_contract = models.Contract(**c) db.add(db_contract) db.commit() @@ -145,7 +150,8 @@ def get_user_credits(db: Session, user_id: str): """ Get credits from a user. """ - db_credits = db.query(models.Credit).filter(models.Credit.owner_id == user_id).all() + db_credits = db.query(models.Credit).filter( + models.Credit.owner_id == user_id).all() return db_credits @@ -307,7 +313,8 @@ def check_calls_per_month(db, contract_id): # If max_calls is -1 means condition is disabled (NO LIMIT) if max_calls == -1: # type: ignore return True - nb_operations_already_made = get_operations_by_contracts_per_month(db, contract_id) + nb_operations_already_made = get_operations_by_contracts_per_month( + db, contract_id) return max_calls >= len(nb_operations_already_made) @@ -453,5 +460,6 @@ def update_condition(db: Session, condition: models.Condition): def get_conditions_by_vault(db: Session, vault_id: str): return ( - db.query(models.Condition).filter(models.Condition.vault_id == vault_id).all() + db.query(models.Condition).filter( + models.Condition.vault_id == vault_id).all() ) diff --git a/src/models.py b/src/models.py index b0a7220..afa0d55 100644 --- a/src/models.py +++ b/src/models.py @@ -117,14 +117,16 @@ class Operation(Base): entrypoint_id = Column(UUID(as_uuid=True), ForeignKey("entrypoints.id")) hash = Column(String) status = Column(String) # TODO Enum - created_at = Column(DateTime(timezone=True), default=datetime.datetime.utcnow()) + created_at = Column(DateTime(timezone=True), + default=datetime.datetime.utcnow()) contract = relationship("Contract", back_populates="operations") entrypoint = relationship("Entrypoint", back_populates="operations") # ------- CONDITIONS ------- # - +# Condition class with the requirement to limit the number +# of calls for new users (sponsees) class Condition(Base): __tablename__ = "conditions" @@ -141,29 +143,31 @@ class Condition(Base): ) contract_id = Column( UUID(as_uuid=True), - CheckConstraint( - "(type = 'MAX_CALLS_PER_ENTRYPOINT') = (contract_id IS NOT NULL)", - name="contract_id_not_null_constraint", - ), ForeignKey("contracts.id"), nullable=True, ) entrypoint_id = Column( UUID(as_uuid=True), - CheckConstraint( - "(type = 'MAX_CALLS_PER_ENTRYPOINT') = (entrypoint_id IS NOT NULL)", - name="entrypoint_id_not_null_constraint", - ), ForeignKey("entrypoints.id"), nullable=True, ) - vault_id = Column(UUID(as_uuid=True), ForeignKey("credits.id"), nullable=False) + vault_id = Column(UUID(as_uuid=True), ForeignKey( + "credits.id"), nullable=False) max = Column(Integer, nullable=False) current = Column(Integer, nullable=False) created_at = Column( DateTime(timezone=True), default=datetime.datetime.utcnow(), nullable=False ) + # Add a CheckConstraint to ensure consistency between type and associated fields + __table_args__ = ( + CheckConstraint( + "(type = 'MAX_CALLS_PER_ENTRYPOINT' AND entrypoint_id IS NOT NULL) \ + OR (type = 'MAX_CALLS_PER_ENTRYPOINT') AND contract_id IS NOT NULL)", + name="type_to_foreign_key_constraint", + ), + ) + contract = relationship("Contract", back_populates="conditions") entrypoint = relationship("Entrypoint", back_populates="conditions") vault = relationship("Credit", back_populates="conditions") diff --git a/src/routes.py b/src/routes.py index 1522229..4835a08 100644 --- a/src/routes.py +++ b/src/routes.py @@ -75,7 +75,8 @@ async def update_credits( amount = credits.amount is_confirmed = await tezos.confirm_deposit(op_hash, payer_address, amount) if not is_confirmed: - logging.warning(f"Could not find confirmation for {amount} with {op_hash}") + logging.warning( + f"Could not find confirmation for {amount} with {op_hash}") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Could not find confirmation for {amount} with {op_hash}", @@ -114,7 +115,8 @@ async def withdraw_credits( detail=f"Credit not found.", ) if credits.amount < withdraw.amount: - logging.warning(f"Not enough funds to withdraw credit ID {withdraw.id}.") + logging.warning( + f"Not enough funds to withdraw credit ID {withdraw.id}.") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Not enough funds to withdraw.", @@ -122,7 +124,8 @@ async def withdraw_credits( expected_counter = credits.owner.withdraw_counter or 0 if expected_counter != withdraw.withdraw_counter: - logging.warning(f"Withdraw counter provided is not the expected counter.") + logging.warning( + f"Withdraw counter provided is not the expected counter.") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Bad withdraw counter." ) @@ -145,7 +148,8 @@ async def withdraw_credits( ) result = await tezos.withdraw(tezos.tezos_manager, owner_address, withdraw.amount) if result["result"] == "ok": - logging.debug(f"Start to confirm withdraw for {result['transaction_hash']}") + logging.debug( + f"Start to confirm withdraw for {result['transaction_hash']}") # Starts a independent loop to check that the operation # has been confirmed asyncio.create_task( @@ -299,7 +303,8 @@ async def post_operation( entrypoint_name = operation["parameters"]["entrypoint"] try: - entrypoint = crud.get_entrypoint(db, str(contract.address), entrypoint_name) + entrypoint = crud.get_entrypoint( + db, str(contract.address), entrypoint_name) if not entrypoint.is_enabled: raise EntrypointDisabled() @@ -339,7 +344,8 @@ async def post_operation( logging.debug(f"Result of operation simulation : {op}") - op_estimated_fees = [(int(x["fee"]), x["destination"]) for x in op.contents] + op_estimated_fees = [(int(x["fee"]), x["destination"]) + for x in op.contents] estimated_fees = tezos.group_fees(op_estimated_fees) logging.debug(f"Estimated fees: {estimated_fees}") @@ -350,7 +356,8 @@ async def post_operation( f"Estimated fees : {estimated_fees[str(contract.address)]} mutez" ) if not crud.check_calls_per_month(db, contract.id): # type: ignore - logging.warning(f"Too many calls made for this contract this month.") + logging.warning( + f"Too many calls made for this contract this month.") raise TooManyCallsForThisMonth() result = await tezos.tezos_manager.queue_operation(call_data.sender_address, op) @@ -358,7 +365,8 @@ async def post_operation( crud.create_operation( db, schemas.CreateOperation( - user_address=call_data.sender_address, contract_id=str(contract.id), entrypoint_id=str(entrypoint.id), hash=result["transaction_hash"], status=result["result"] # type: ignore + # type: ignore + user_address=call_data.sender_address, contract_id=str(contract.id), entrypoint_id=str(entrypoint.id), hash=result["transaction_hash"], status=result["result"] ), ) except MichelsonError as e: diff --git a/src/tezos.py b/src/tezos.py index 6e8f5ea..f102fee 100644 --- a/src/tezos.py +++ b/src/tezos.py @@ -50,7 +50,8 @@ def find_fees(global_tx, payer_key): (int(z["change"]), x["destination"]) for x in op_result for y in ( - x["metadata"].get("operation_result", {}).get("balance_updates", {}), + x["metadata"].get("operation_result", {}).get( + "balance_updates", {}), x["metadata"].get("balance_updates", {}), ) for z in y @@ -232,7 +233,8 @@ async def main_loop(self): log.info(f"{n_ops} operations to process and send") # Post all the correct operations together and get the # result from the RPC to know what the real fees were - posted_tx = ptz.bulk(*acceptable_operations.values()).send() + posted_tx = ptz.bulk( + *acceptable_operations.values()).send() for i, k in enumerate(acceptable_operations): assert self.results[k] != "failing" self.results[k] = {"transaction": posted_tx}