Write-up/Crypto

[2019 SSTF OpenCTF] Roughlt Secure Algorithm

MyriaBreak 2019. 8. 28. 00:38

Roughlt Secure Algorithm

Category: Crypto

Points: 100

Author: matta

Description:

Crypto is hard? well...

Write-up

간단한 문제였는데, python에서 매우 큰 숫자를 출력시키면 overhead가 굉장히 커서 안된다는 사실을 몰라서 ... 못푼 문제;

그래도 python의 print를 조심해서 써야한다는 것을 배운 좋은 문제다ㅎ...

python 스크립트를 하나 주는데, 열어보면 RSA문제라는 것을 알 수 있다.

p = getPrime(1024)
q = gmpy2.next_prime(p)
e = 65537
N = p * q
phiN = (p - 1) * (q - 1)
d = gmpy2.powmod(e, -1, phiN)

p와 q가 독립적으로 선택되는 것이 아닌 p의 next prime으로 q를 사용하고 있다.

p/q의 값이 1에 근사하면 주어진 RSA 암호화 방식은 Fermat factorization 공격에 취약하기때문에 쉽게 공격할 수 있다.

문제파일을 실행시켜 p와 q를 보면 128bytes중 마지막 1~2바이트만 다르다... 확실히 p/q는 1에 가깝다..

messages = ["Do U know RSA?", "The format of flag is: SCTF{}", flag]

def encrypt(m):
msg = bytes_to_long(m)
ct = gmpy2.powmod(msg, e, N)
ct = long_to_bytes(ct)
return ct.encode("hex")

open("ciphertext.txt", "w").write(", ".join(map(encrypt, messages)))

문제를 보면 알려진 평문 2개와 암호문 3개가 주어진다.

문제가 신기한게 N값을 알려주지않아서 우리가 직접 구해야하는데, 이 N값을 구하는 것은 RSA 알고리즘을 알고 있다면 쉽게 할 수 있다.

암호화의 대상인 평문 세개를 m1, m2, m3라고 하면 암호문 c1, c2는 각각

c1 = m1e mod N, c2 = m2e mod N

m1e - c1 = 0 mod N

m2e - c2 = 0 mod N

이므로, m1e - c1. m2e - c2 N의 배수이다.

그래서 이 두 수의 공약수를 구하면 N의 배수가 된다. 공약수를 구하는 것은 아래와 같이 수행할 수 있다.

from Crypto.Util.number import *
import gmpy2

e = 65537
ct = [bytes_to_long(c.decode("hex")) for c in open("ciphertext.txt").read().split(", ")]
pt = map(bytes_to_long, ["Do U know RSA?", "The format of flag is: SCTF{}"])

k1N = pow(pt[0], e) - ct[0]
k2N = pow(pt[1], e) - ct[1]

N = gmpy2.gcd(k1N, k2N)

assert(gmpy2.powmod(pt[0], e, N) == ct[0])
assert(gmpy2.powmod(pt[1], e, N) == ct[1])
print hex(N)

N을 찾았으니 이제 p, q만 찾으면 문제를 해결할 수 있을 것 같다.

p와 q는 아주 큰 수라서, 하위 1~2 byte 정도 차이는 상대적으로 작은 부분이라고 볼 수 있다. 그러니 p ≈ q라고 쓰자.

N = pq ≈ p2이므로, sqrt(N) ≈ p이라고 쓸 수 있다.(sqrt()는 squre root 함수)

p와 q 중에 하나는 sqrt(N) 보다 크고 하나는 작을 것인데, 우리 입장에서는 p 또는 q 중에하나만 알면 되니 sqrt(N) 부터 시작해서 brute force를 시도하면 금방 N의 약수를 찾을 수 있을 것이다.

N의 약수 하나(편의상 p 라고 하자)를 찾으면 N으로 나누어 q를 계산할 수 있다. p와 q를 알게 되면 문제 코드 처음에 써있던 대로 비밀키 d를 구할 수 있고, flag를 복호화 할 수 있게 된다.


ver1.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from Crypto.Util.number import *
import gmpy2 
 
def xgcd(b, n):
    x0, x1, y0, y1 = 1001
    while n != 0:
        q, b, n = b // n, n, b % n
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return b, x0, y0
    
def mul_inv(b, n):
    g, x, _ = xgcd(b, n)
    if g == 1:
        return x % n
    
def fermat_factor(n):
    assert n % 2 != 0
    
    a = gmpy2.isqrt(n)
    b2 = gmpy2.square(a) - n
    
    while not gmpy2.is_square(b2):
        a += 1
        b2 = gmpy2.square(a) - n
    p = a + gmpy2.isqrt(b2)
    q = a - gmpy2.isqrt(b2)
    
    return int(p), int(q)
 
 
= 65537
ct = [bytes_to_long(c.decode("hex")) for c in open("ciphertext.txt").read().split(", ")]
pt = map(bytes_to_long, ["Do U know RSA?""The format of flag is: SCTF{}"])
 
k1N = pow(pt[0], e) - ct[0]
k2N = pow(pt[1], e) - ct[1]
 
= gmpy2.gcd(k1N, k2N)
= ct[2]
 
p, q = fermat_factor(n)
phi = (p-1)*(q-1)
 
= mul_inv(e, phi)
plain = pow(c, d, n)
 
print("p :" + str(p))
print("q :" + str(q))
print("p/q :" +str(p/float(q)))
 
print("flag : " + long_to_bytes(plain))
 
cs



ver2.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from Crypto.Util.number import *
import gmpy2
 
= 65537
ct = [bytes_to_long(c.decode("hex")) for c in open("ciphertext.txt").read().split(", ")]
pt = map(bytes_to_long, ["Do U know RSA?""The format of flag is: SCTF{}"])
 
'''
pt[0] ^ e mod N = ct[0]
pt[1] ^ e mod N = ct[1]
pt[0] ^ e - ct[0] = k1 * N
pt[1] ^ e - ct[1] = k2 * N
'''
 
k1N = pow(pt[0], e) - ct[0]
k2N = pow(pt[1], e) - ct[1]
 
= gmpy2.gcd(k1N, k2N)
 
for i in range(2100):
    if N % i == 0 \
    and gmpy2.powmod(pt[0], e, N // i) == ct[0] \
    and gmpy2.powmod(pt[1], e, N // i) == ct[1]:
        print "reduced:", i
        N //= i
 
assert(gmpy2.powmod(pt[0], e, N) == ct[0])
assert(gmpy2.powmod(pt[1], e, N) == ct[1])
 
= gmpy2.isqrt(N)
 
while True:
    q, r = gmpy2.t_divmod(N, p)
    if (r == 0):
        break
    p += 1
 
phiN = (p - 1* (q - 1)
= gmpy2.powmod(e, -1, phiN)
 
flag = long_to_bytes(gmpy2.powmod(ct[2], d, N))
print flag
 
cs