Skip to content

Commit

Permalink
better fermat, make pylint happier
Browse files Browse the repository at this point in the history
  • Loading branch information
teschlg committed Jun 29, 2024
1 parent d9940e8 commit 7fc9052
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 44 deletions.
7 changes: 2 additions & 5 deletions kryptools/Zmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Zmod:
0 (mod 5)
"""

def __init__(self, n: int, short: bool = False):
def __init__(self, n: int, short: bool = True):
self.n = n
self.short = short

Expand All @@ -40,10 +40,7 @@ class ZmodPoint:
"Represents a point in the ring Zmod."

def __init__(self, x: int, ring: "Zmod"):
if isinstance(x, self.__class__) and x.ring.n == ring.n:
self.x = int(x)
else:
self.x = int(x) % ring.n
self.x = int(x) % ring.n
self.ring = ring

def __repr__(self):
Expand Down
50 changes: 25 additions & 25 deletions kryptools/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,76 +311,76 @@ def order(self) -> int:
break
return order

def dlog(Q, P: "ECPoint") -> int:
def dlog(self, other: "ECPoint") -> int:
"""Compute the discrete log_P(Q) in EC."""
m = P.order()
m = other.order()
mf = factorint(m)
assert m * Q == P.curve(None, None), "DLP not solvable."
assert m * self == other.curve(None, None), "DLP not solvable."
# We first use Pohlig-Hellman to split m into powers of prime factors
mm = []
ll = []
for pj, kj in mf.items():
Pj = (m // pj**kj) * P
Qj = (m // pj**kj) * Q
Pj = (m // pj**kj) * other
Qj = (m // pj**kj) * self
l = Qj.dlog_ph(Pj, pj, kj)
if l is None:
return None
mm += [pj**kj]
ll += [l]
return crt(ll, mm)

def dlog_ph(Q, P: "ECPoint", q: int, k: int) -> int:
def dlog_ph(self, other: "ECPoint", q: int, k: int) -> int:
"""Compute the discrete log_P(Q) in EC if P has order q^k using Pohlig-Hellman reduction."""
if k == 1 or q**k < 10000:
return Q.dlog_switch(P, q**k)
Pj = q**(k - 1) * P
return self.dlog_switch(other, q**k)
Pj = q**(k - 1) * self
P1 = Pj
Qj = q**(k - 1) * Q
Qj = q**(k - 1) * other
xj = Qj.dlog_switch(P1, q)
for j in range(2, k + 1):
Pj = q**(k - j) * P
Qj = q**(k - j) * Q - xj * Pj
Pj = q**(k - j) * self
Qj = q**(k - j) * other - xj * Pj
yj = Qj.dlog_switch(P1, q)
xj = xj + q ** (j - 1) * yj % q**j
return xj

def dlog_switch(Q, P: "ECPoint", m: int) -> int:
def dlog_switch(self, other: "ECPoint", m: int) -> int:
"""Compute the discrete log_P(Q) in EC if P has order m choosing an appropriate method."""
if m < 100:
return Q.dlog_naive(P, m)
return Q.dlog_bsgs(P, m)
return self.dlog_naive(other, m)
return self.dlog_bsgs(other, m)

def dlog_naive(Q, P: "ECPoint", m: int) -> int:
def dlog_naive(self, other: "ECPoint", m: int) -> int:
"""Compute the discrete log_P(Q) in EC using an exhaustive search."""
if not Q.curve == P.curve and not isinstance(Q, P.__class__):
if not self.curve == other.curve and not isinstance(self, other.__class__):
raise ValueError("Points must be on the same curve!")
j = 0
xx, yy = None, None
while xx != Q.x:
while xx != self.x:
j += 1
xx, yy = P.curve.add(xx, yy, P.x, P.y)
xx, yy = self.curve.add(xx, yy, other.x, other.y)
if xx is None:
raise ValueError("DLP not solvabel!")
if yy == Q.y:
if yy == self.y:
return j
return m - j

def dlog_bsgs(Q, P: "ECPoint", m: int) -> int:
def dlog_bsgs(self, other: "ECPoint", m: int) -> int:
"""Compute the discrete log_P(Q) in EC if P has order m using Shanks' baby-step-giant-step algorithm."""
if not Q.curve == P.curve and not isinstance(P, Q.__class__):
if not self.curve == other.curve and not isinstance(other, self.__class__):
raise ValueError("Points must be on the same curve!")
mm = 1 + isqrt(m - 1)
m2 = mm//2 + mm % 1 # we use the group symmetry to halve the number of steps
# initialize baby_steps table
baby_steps = {}
baby_step = P
baby_step = other
for j in range(1,m2+1):
baby_steps[int(baby_step.x)] = j, int(baby_step.y)
baby_step += P
baby_step += other

# now take the giant steps
giant_stride = -mm * P
giant_step = Q
giant_stride = -mm * other
giant_step = self
for l in range(mm+1):
if giant_step.x is None:
return l * mm
Expand Down
2 changes: 1 addition & 1 deletion kryptools/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _factor_fermat(n: int, steps: int = 10) -> list:
start = isqrt(n - 1) + 1
step, mod = parameters[n % 24]
start += (mod - start) % step
for a in range(start, max(start + steps * step,(n + 9) // 6) + 1, step):
for a in range(start, min(start + steps * step,(n + 9) // 6) + 1, step):
b = isqrt(a * a - n)
if b * b == a * a - n:
return a - b
Expand Down
1 change: 1 addition & 0 deletions kryptools/factor_ecm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from math import gcd, isqrt, log
from random import randint, seed
from .primes import sieve_eratosthenes
seed(0)

# Crandall and Pomerance: Primes (doi=10.1007/0-387-28979-8)
# Algorithm 7.2.7
Expand Down
2 changes: 1 addition & 1 deletion kryptools/factor_fmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from math import isqrt


def factor_fermat2(n: int) -> list:
def factor_fermat(n: int) -> list:
"""Find factors of n using the method of Fermat."""
factors = []
parameters = {11: (12, 6), 23: (12, 0),
Expand Down
4 changes: 2 additions & 2 deletions kryptools/la.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def eye(self, m: int = None, n: int = None):
"Returns an identity matrix of the same dimension"
def delta(i, j):
if i == j:
return 1
return 0
return one
return zero
if not m and not n:
n, m = self.cols, self.rows
elif not n:
Expand Down
14 changes: 7 additions & 7 deletions kryptools/nt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def lcm(a: int, b: int) -> int:


def egcd(a: int, b: int) -> (int, int, int):
"""Perform the extended Euclidean agorithm. Returns gcd, x, y such that a x + b y = gcd."""
"""Perform the extended Euclidean agorithm. Returns `gcd`, `x`, `y` such that `a x + b y = gcd`."""
r0, r1 = a, b
x0, x1, y0, y1 = 1, 0, 0, 1
while r1 != 0:
Expand All @@ -36,8 +36,8 @@ def egcd(a: int, b: int) -> (int, int, int):

# Chinese remainder theorem

def crt(a: list, m: list) -> int:
"""Solve given linear congruences x[j] % m[j] == a[j] using the Chinese Remainder Theorem."""
def crt(a: list[int], m: list[int]) -> int:
"""Solve given linear congruences x % m[j] == a[j] using the Chinese Remainder Theorem."""
l = len(a)
assert len(m) == l, "The lists of numbers and modules must have equal length."
M = prod(m)
Expand Down Expand Up @@ -179,7 +179,7 @@ def jacobi_symbol(a: int, n: int) -> int:
return 0

def sqrt_mod(a: int, p: int) -> list:
"Compute a square root of a modulo p unsing Cipolla's algorithm."
"Compute a square root of `a` modulo `p` unsing Cipolla's algorithm."
a %= p
if a == 0 or a == 1:
return a
Expand Down Expand Up @@ -211,13 +211,13 @@ def sqrt_mod(a: int, p: int) -> list:


def euler_phi(n: int) -> int:
"""Euler's phi function of n."""
"""Euler's phi function of `n`."""
k = factorint(n)
return prod([(p - 1) * p ** (k[p] - 1) for p in k])


def carmichael_lambda(n: int) -> int:
"""Carmichael's lambda function of n."""
"""Carmichael's lambda function of `n`."""
k = factorint(n)
lam_all = [] # values corresponding to the prime factors
for p in k:
Expand All @@ -233,7 +233,7 @@ def carmichael_lambda(n: int) -> int:
# Order in Z_p^*

def order(a: int, n: int, factor=False) -> int:
"""Compute the order of a in the group Z_n^*."""
"""Compute the order of `a` in the group Z_n^*."""
a %= n
assert a != 0 and gcd(a, n) == 1, f"{a} and {n} are not coprime!"
factors = dict() # We compute euler_phi(n) and its factorization in one pass
Expand Down
18 changes: 15 additions & 3 deletions kryptools/poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,18 @@ def __init__(self, coeff: list, ring = None, modulus: list = None):
if modulus:
self.mod(modulus)

def __call__(self, x):
return sum(c * x**j for j, c in enumerate(self.coeff))

def __getitem__(self, item):
return self.coeff[item]

def __setitem__(self, item, value):
self.coeff[item] = value

def __len__(self):
return len(self.coeff)

def __repr__(self):
def prx(i: int):
if i == 0:
Expand All @@ -37,7 +46,7 @@ def prx(i: int):
return "x^" + str(i)

if len(self.coeff) == 1:
return str(int(self.coeff[0]))
return str(self.coeff[0])
plus = ""
tmp = ""
for i in reversed(range(len(self.coeff))):
Expand Down Expand Up @@ -76,9 +85,11 @@ def __bool__(self):
return bool(self.degree()) or bool(self.coeff[0])

def degree(self):
"Return the degree."
return len(self.coeff) - 1

def map(self, func):
"Apply a given function to all coefficients."
self.coeff = list(map(func, self.coeff))

def __add__(self, other: "Poly") -> "Poly":
Expand Down Expand Up @@ -184,7 +195,7 @@ def __pow__(self, i: int) -> "Poly":
return res

def divmod(self, other: "Poly") -> ("Poly", "Poly"):
"Polynom division with remainder"
"Polynom division with remainder."
if isinstance(other, list):
other = self.__class__(other)
elif not isinstance(other, self.__class__):
Expand Down Expand Up @@ -215,7 +226,7 @@ def divmod(self, other: "Poly") -> ("Poly", "Poly"):
)

def mod(self, other: "Poly") -> None:
"Remainder of polynom division"
"Reduce with respect to a given polynomial."
if isinstance(other, list):
other = self.__class__(other)
elif not isinstance(other, self.__class__):
Expand All @@ -241,6 +252,7 @@ def mod(self, other: "Poly") -> None:
self.coeff.pop(i)

def inv(self, other: "Poly" = None) -> "Poly":
"Inverse modulo a given polynomial."
if not other:
other = self.modulus
if isinstance(other, list):
Expand Down

0 comments on commit 7fc9052

Please sign in to comment.