Roughlt Secure Algorithm

Category: Crypto

Points: 100

Author: matta

Description:

Crypto is hard? well...

Write-up

간단한 문제였는데, python에서 매우 큰 숫자를 출력시키면 overhead가 굉장히 커서 안된다는 사실을 몰라서 ... 못푼 문제;

그래도 python의 print를 조심해서 써야한다는 것을 배운 좋은 문제다ㅎ...

python 스크립트를 하나 주는데, 열어보면 RSA문제라는 것을 알 수 있다.

p = getPrime(1024)
q = gmpy2.next_prime(p)
e = 65537
N = p * q
phiN = (p - 1) * (q - 1)
d = gmpy2.powmod(e, -1, phiN)

p와 q가 독립적으로 선택되는 것이 아닌 p의 next prime으로 q를 사용하고 있다.

p/q의 값이 1에 근사하면 주어진 RSA 암호화 방식은 Fermat factorization 공격에 취약하기때문에 쉽게 공격할 수 있다.

문제파일을 실행시켜 p와 q를 보면 128bytes중 마지막 1~2바이트만 다르다... 확실히 p/q는 1에 가깝다..

messages = ["Do U know RSA?", "The format of flag is: SCTF{}", flag]

def encrypt(m):
msg = bytes_to_long(m)
ct = gmpy2.powmod(msg, e, N)
ct = long_to_bytes(ct)
return ct.encode("hex")

open("ciphertext.txt", "w").write(", ".join(map(encrypt, messages)))

문제를 보면 알려진 평문 2개와 암호문 3개가 주어진다.

문제가 신기한게 N값을 알려주지않아서 우리가 직접 구해야하는데, 이 N값을 구하는 것은 RSA 알고리즘을 알고 있다면 쉽게 할 수 있다.

암호화의 대상인 평문 세개를 m1, m2, m3라고 하면 암호문 c1, c2는 각각

c1 = m1e mod N, c2 = m2e mod N

m1e - c1 = 0 mod N

m2e - c2 = 0 mod N

이므로, m1e - c1. m2e - c2 N의 배수이다.

그래서 이 두 수의 공약수를 구하면 N의 배수가 된다. 공약수를 구하는 것은 아래와 같이 수행할 수 있다.

from Crypto.Util.number import *
import gmpy2

e = 65537
ct = [bytes_to_long(c.decode("hex")) for c in open("ciphertext.txt").read().split(", ")]
pt = map(bytes_to_long, ["Do U know RSA?", "The format of flag is: SCTF{}"])

k1N = pow(pt[0], e) - ct[0]
k2N = pow(pt[1], e) - ct[1]

N = gmpy2.gcd(k1N, k2N)

assert(gmpy2.powmod(pt[0], e, N) == ct[0])
assert(gmpy2.powmod(pt[1], e, N) == ct[1])
print hex(N)

N을 찾았으니 이제 p, q만 찾으면 문제를 해결할 수 있을 것 같다.

p와 q는 아주 큰 수라서, 하위 1~2 byte 정도 차이는 상대적으로 작은 부분이라고 볼 수 있다. 그러니 p ≈ q라고 쓰자.

N = pq ≈ p2이므로, sqrt(N) ≈ p이라고 쓸 수 있다.(sqrt()는 squre root 함수)

p와 q 중에 하나는 sqrt(N) 보다 크고 하나는 작을 것인데, 우리 입장에서는 p 또는 q 중에하나만 알면 되니 sqrt(N) 부터 시작해서 brute force를 시도하면 금방 N의 약수를 찾을 수 있을 것이다.

N의 약수 하나(편의상 p 라고 하자)를 찾으면 N으로 나누어 q를 계산할 수 있다. p와 q를 알게 되면 문제 코드 처음에 써있던 대로 비밀키 d를 구할 수 있고, flag를 복호화 할 수 있게 된다.


ver1.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from Crypto.Util.number import *
import gmpy2 
 
def xgcd(b, n):
    x0, x1, y0, y1 = 1001
    while n != 0:
        q, b, n = b // n, n, b % n
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return b, x0, y0
    
def mul_inv(b, n):
    g, x, _ = xgcd(b, n)
    if g == 1:
        return x % n
    
def fermat_factor(n):
    assert n % 2 != 0
    
    a = gmpy2.isqrt(n)
    b2 = gmpy2.square(a) - n
    
    while not gmpy2.is_square(b2):
        a += 1
        b2 = gmpy2.square(a) - n
    p = a + gmpy2.isqrt(b2)
    q = a - gmpy2.isqrt(b2)
    
    return int(p), int(q)
 
 
= 65537
ct = [bytes_to_long(c.decode("hex")) for c in open("ciphertext.txt").read().split(", ")]
pt = map(bytes_to_long, ["Do U know RSA?""The format of flag is: SCTF{}"])
 
k1N = pow(pt[0], e) - ct[0]
k2N = pow(pt[1], e) - ct[1]
 
= gmpy2.gcd(k1N, k2N)
= ct[2]
 
p, q = fermat_factor(n)
phi = (p-1)*(q-1)
 
= mul_inv(e, phi)
plain = pow(c, d, n)
 
print("p :" + str(p))
print("q :" + str(q))
print("p/q :" +str(p/float(q)))
 
print("flag : " + long_to_bytes(plain))
 
cs



ver2.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from Crypto.Util.number import *
import gmpy2
 
= 65537
ct = [bytes_to_long(c.decode("hex")) for c in open("ciphertext.txt").read().split(", ")]
pt = map(bytes_to_long, ["Do U know RSA?""The format of flag is: SCTF{}"])
 
'''
pt[0] ^ e mod N = ct[0]
pt[1] ^ e mod N = ct[1]
pt[0] ^ e - ct[0] = k1 * N
pt[1] ^ e - ct[1] = k2 * N
'''
 
k1N = pow(pt[0], e) - ct[0]
k2N = pow(pt[1], e) - ct[1]
 
= gmpy2.gcd(k1N, k2N)
 
for i in range(2100):
    if N % i == 0 \
    and gmpy2.powmod(pt[0], e, N // i) == ct[0] \
    and gmpy2.powmod(pt[1], e, N // i) == ct[1]:
        print "reduced:", i
        N //= i
 
assert(gmpy2.powmod(pt[0], e, N) == ct[0])
assert(gmpy2.powmod(pt[1], e, N) == ct[1])
 
= gmpy2.isqrt(N)
 
while True:
    q, r = gmpy2.t_divmod(N, p)
    if (r == 0):
        break
    p += 1
 
phiN = (p - 1* (q - 1)
= gmpy2.powmod(e, -1, phiN)
 
flag = long_to_bytes(gmpy2.powmod(ct[2], d, N))
print flag
 
cs


SSTF OpenCTF에 나왔던 LSB Oracle 문제.

익스코드

ver1.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import decimal
from pwn import *
from Crypto.Util.number import long_to_bytes
 
conn = remote("certainparts.sstf.site"12345)
 
= 125560377595624869696322630015882810288481844943661657098054331638111628210374072413173109448166779450170382439644694193024664710888205521543909902890609386144989825610230813823505041837614119263169879059174244133031673428386864683895197298571580602944596661705915969281038713739672744280271836542084404846309
= 65537
enc = 118647304114971068925032683768641917858857901141412816512618918698813600434373504387159513526475575029510748903851883089892473885809651616328335069672142793018627751022608056294812828777586302633454779414048616431461446403924692378415772959218687182740655957635595144364189435787610949804907648326594886012169

= N.bit_length()
decimal.getcontext().prec = k
lower = decimal.Decimal(0)
upper = decimal.Decimal(N)
 
 
p2 = pow(2, e, N)
lower = decimal.Decimal(0)
upper = decimal.Decimal(N)
= p2
 
for i in xrange(k):
    mid = (lower + upper) / 2
    conn.readuntil('Ciphertext: ')
    conn.sendline(hex(enc * p % N)[2:].strip("L"))
    conn.recvuntil("LSB: ")
    cur = int(conn.readline().strip())
    if cur == 0:
        upper = mid
    else:
        lower = mid
    p = p * p2 % N
    print(int(upper))
print long_to_bytes(int(upper))
conn.interactive()
cs


ver2.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from Time import *
from Base import *
from Factorization import *
from Prime import *
 
SECRET = None
PUBLIC = None
COUNTER = 0
 
def GetInstance(l, flag=None):
    p1 = RandomPrime(l//2); p2 = RandomPrime(l//2)
    n = p1*p2
    q1 = p2*InverseMod(p2, p1)
    q2 = p1*InverseMod(p1, p2)
    m = (p1-1)*(p2-1)
    e = 0x10001
    d = InverseMod(e, m)
    l = l//8
    if flag == None:
        r = RandomInteger(0, n)
    else:
        r = (int.from_bytes(flag[:l].encode("utf-8"), "big"))%n
    y = pow(r, e, n)
    global SECRET, PUBLIC
    SECRET = [p1, p2, q1, q2, d, r]
    PUBLIC = [n, e, l, y]
    return
 
def LSBOracle(y):
    global SECRET, PUBLIC, COUNTER
    COUNTER += 1
    #x = pow(y, SECRET[2], PUBLIC[0])
    d1 = SECRET[4]%(SECRET[0]-1); d2 = SECRET[4]%(SECRET[1]-1)
    y1 = y%SECRET[0]; y2 = y%SECRET[1]
    x1 = pow(y1, d1, SECRET[0])
    x2 = pow(y2, d2, SECRET[1])
    x = (x1*SECRET[2+ x2*SECRET[3])%PUBLIC[0]
    return x%2
 
 
def ChosenCiphertext(n, e, y, r=None):
    if r == None:
        r = RandomInteger(0, n)
    else:
        r = r%n
    z = (pow(r, e, n)*y)%n
    return (r, z)
 
def DivideByTwo(a, y):
    #assert DecryptionOracle_LSB(u) == 0
    c = InverseMod(2, n)
    z = (y*pow(c, e, n))%n
    b = (a*c)%n if a%2 == 1 else a//2
    return (b, z)
 
def Attack(n, e, y):
    (a, u) = ChosenCiphertext(n, e, y)
    (b, v) = ChosenCiphertext(n, e, y)
    while u != 1 and v!= 1:
        if u == v or u == 0 or v == 0:
            print("Attack Failed!: the GCD is not equal to 1")
            return None
        lsb_u = LSBOracle(u)
        lsb_v = LSBOracle(v)
        while lsb_u*lsb_v == 0:
            if lsb_u == 0:
                (a, u) = DivideByTwo(a, u)
                lsb_u = LSBOracle(u)
            if lsb_v == 0:
                (b, v) = DivideByTwo(b, v)
                lsb_v = LSBOracle(v)
        (c, u) = ChosenCiphertext(n, e, y, a+b)
        if LSBOracle(u) != 0:
            print("Attack Failed: the sum is greater than modulus")
            return None
        (c, u) = DivideByTwo(c, u) 
        (d, v) = ChosenCiphertext(n, e, y, a-b)
        if LSBOracle(v) == 1:
            (d, v) = ChosenCiphertext(n, e, y, b-a)
        (d, v) = DivideByTwo(d, v)
        a = c; b = d
    return InverseMod(a, n) if u == 1 else InverseMod(b, n)
 
cs


'Write-up > Crypto' 카테고리의 다른 글

[2019 SSTF OpenCTF] BlackHackerService  (0) 2019.08.28
[2019 SSTF OpenCTF] Roughlt Secure Algorithm  (0) 2019.08.28
[RedpwnCTF] Binary (RSA LSB Oracle Attack)  (0) 2019.08.17
[PlaidCTF 2019] R u SAd?  (0) 2019.04.18
[VolgaCTF 2019] Blind  (0) 2019.04.03

010000100110100101101110011000010111001001111001

Binary

Written by: Tux


0100100100100000011001100110111101110101011011100110010000100000011101000110100001101001011100110010000001110111011001010110100101110010011001000010000001110011011001010111001001110110011010010110001101100101001011100010111000101110

I found this weird service...


nc chall2.2019.redpwn.net 5001


Hint: 010010010111001100100000011010010111010000100000011001010111011001100101011011100010000001101111011100100010000001101111011001000110010000111111

Is it even or odd?

  


아는 분이 RSA 문제 소개시켜주어서... 잠깐 풀어보았는데

롸업은 나중에 쓰고 일단 익스코드 나중에 쓰일거같아서 저장하려고 ㅎㅎ...

설명은 나중에 올려야지


RSA LSB Oracle Attack 기법을 사용해서 풀 수 있다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import decimal
from pwn import *
from Crypto.Util.number import long_to_bytes
 
conn = remote("chall2.2019.redpwn.net"5001)
 
def decode_binary(ut):
    msg = conn.recvuntil(ut)[:-1]
    msg = int(msg,2)
    msg = long_to_bytes(msg)
    result = conn.recvline()
    return msg, result
 
print(decode_binary("\n")[0])
print(decode_binary("\n")[0])
 
msg, result = decode_binary(":")
N, e = result.strip()[1:-1].split(",")
N=int(N,2)
e=int(e,2)
 
print(msg + " : " + str(N) + ", " + str(e))
 
conn.recvline()
 
msg, result = decode_binary(":")
enc = int(result,2)
print(msg + " : " + str(enc))
 
 
= N.bit_length()
decimal.getcontext().prec = k
lower = decimal.Decimal(0)
upper = decimal.Decimal(N)
 
 
p2 = pow(2, e, N)
lower = decimal.Decimal(0)
upper = decimal.Decimal(N)
= p2
 
for i in xrange(k):
    mid = (lower + upper) / 2
    conn.readuntil('> ')
    conn.sendline(bin(enc * p % N)[2:])
    cur = int(conn.readline().strip())
    if cur == 0:
        upper = mid
    else:
        lower = mid
    p = p * p2 % N
    print(int(upper))
print long_to_bytes(int(upper))
conn.interactive()
cs


+ Recent posts