mcfx's blog

Solutions, writeups, travel notes, and ramblings

HITCON CTF 2019 Quals Writeup

Misc

EmojiiVM

Printing each character directly would make the program too long, but we can push 1,1,1,2,...,9,9 onto the stack and write a short loop that prints the table.
Unfortunately my blog doesn't support emojis, so the solution isn't included here.
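
The stack-plus-loop idea can be sketched in plain Python instead of emoji bytecode. This is only an illustration: the output format (`a * b = c`) and the helper names `build_stack` and `render_table` are assumptions, not the actual challenge program.

```python
def build_stack():
    """Push the operand pairs (1,1), (1,2), ..., (9,9) onto a stack."""
    stack = []
    for a in range(1, 10):
        for b in range(1, 10):
            stack.append((a, b))
    # Reverse so that popping yields the pairs in ascending order.
    stack.reverse()
    return stack


def render_table():
    """Pop one pair per iteration and format a table line for it."""
    stack = build_stack()
    lines = []
    while stack:  # one short loop body instead of 81 hand-written prints
        a, b = stack.pop()
        lines.append("%d * %d = %d" % (a, b, a * b))
    return lines


table = render_table()
print("\n".join(table[:3]))
```

The point is that the data (the operands) costs little to push, while the printing logic is written only once.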

heXDump

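xxd only overwrites the first bytes of the file and leaves the rest untouched, so we can recover the flag one byte at a time: write a guessed prefix over the flag and check whether the service's output stays the same.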

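# Python 2 / pwntools script.
from pwn import *

r=remote('13.113.205.160',21700)

# Ask the service for its output on the current file (menu option 2).
def get():
    r.recvuntil('0) quit\n')
    r.send('2\n')
    return r.recvuntil('\n')[:-1]

# Send option 1337, then record the service's output as the reference value.
r.recvuntil('0) quit\n')
r.send('1337\n')

fh=get()

known='hitcon{'

while True:
    flag=False
    # Try every printable ASCII byte as the next flag character.
    for j in range(32,127):
        print j
        t=known+chr(j)
        r.recvuntil('0) quit\n')
        r.send('1\n')
        r.recvuntil('format)\n')
        # Write the guessed prefix (hex-encoded) over the start of the file.
        r.send(t.encode('hex')+'\n')
        # Unchanged output means the guessed prefix matches the flag.
        if get()==fh:
            known=t
            flag=True
            break
    print known
    # No candidate matched: the whole flag has been recovered.
    if not flag: break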
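
The same byte-by-byte search can be tried offline against a small stand-in for the service. Everything here is hypothetical: the class `FakeService`, the use of MD5 as the output, and the demo flag are assumptions made for illustration, and the code is Python 3 rather than the Python 2 used in the script above.

```python
import hashlib


class FakeService:
    """Local stand-in: a file whose leading bytes can be overwritten."""

    def __init__(self, secret: bytes):
        self._secret = secret
        self.content = b""

    def load_secret(self):
        # Plays the role of option 1337 in the real service (assumption).
        self.content = self._secret

    def write(self, data: bytes):
        # Only the first len(data) bytes change; the tail is kept.
        self.content = data + self.content[len(data):]

    def digest(self) -> str:
        return hashlib.md5(self.content).hexdigest()


def recover(svc: FakeService, prefix: bytes = b"hitcon{") -> bytes:
    svc.load_secret()
    target = svc.digest()
    known = prefix
    while True:
        for c in range(32, 127):
            guess = known + bytes([c])
            svc.write(guess)
            if svc.digest() == target:  # file is back to its original state
                known = guess
                break
        else:
            return known  # no byte matched, so the flag is complete


print(recover(FakeService(b"hitcon{demo_flag}")))
```

A wrong guess only damages bytes that the next guess overwrites again, which is why no reset is needed between attempts.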

Crypto

Very Simple Haskell

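If a bit of the flag is 1, the result is multiplied by the square of a specific prime number.
In general we could use a meet-in-the-middle search to find the answer. Here, however, N is so large that the product of the unknown bits never wraps around the modulus, so it is enough to test which primes divide it.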

from gmpy2 import *
from Crypto.Util.number import isPrime

primes=[]
pinv={}
for i in range(2,5000):
    if isPrime(i):
        pinv[i]=len(primes)
        primes.append(i)
print(len(primes))

n=134896036104102133446208954973118530800743044711419303630456535295204304771800100892609593430702833309387082353959992161865438523195671760946142657809228938824313865760630832980160727407084204864544706387890655083179518455155520501821681606874346463698215916627632418223019328444607858743434475109717014763667

base=129105988525739869308153101831605950072860268575706582195774923614094296354415364173823406181109200888049609207238266506466864447780824680862439187440797565555486108716502098901182492654356397840996322893263870349262138909453630565384869193972124927953237311411285678188486737576555535085444384901167109670365

req=84329776255618646348016649734028295037597157542985867506958273359305624184282146866144159754298613694885173220275408231387000884549683819822991588176788392625802461171856762214917805903544785532328453620624644896107723229373581460638987146506975123149045044762903664396325969329482406959546962473688947985096

req=req*invert(base,n)%n

flag='hitcon{'
for i in range(6):
    t=0
    for j in range(8):
        r=28+i*8-j
        if req%primes[r]==0:
            t|=1<<j
    flag+=chr(t)
print flag+'}'
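
A toy version of the same recovery, assuming a simplified encoding in which bit i contributes the square of the i-th prime. The modulus, the bit pattern, and the function names below are made up for illustration and are not the challenge's real parameters (Python 3).

```python
def first_primes(count):
    """Return the first `count` primes by trial division."""
    primes, cand = [], 2
    while len(primes) < count:
        if all(cand % p for p in primes):
            primes.append(cand)
        cand += 1
    return primes


def encode(bits, primes, n):
    """Multiply in p*p (mod n) for every set bit."""
    acc = 1
    for bit, p in zip(bits, primes):
        if bit:
            acc = acc * p * p % n
    return acc


def decode(value, primes):
    """Works only while the product never exceeded n (no modular wrap)."""
    return [1 if value % p == 0 else 0 for p in primes]


n = 2**521 - 1  # toy modulus, far larger than any product formed here
primes = first_primes(16)
bits = [1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1]
value = encode(bits, primes, n)
print(decode(value, primes) == bits)
```

If the product did wrap around the modulus, divisibility would carry no information and a search such as meet-in-the-middle would be needed instead.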

Rᴀɴᴅᴏᴍʟʏ Sᴇʟᴇᴄᴛ A 🐈🐱🐾

During signing, whatever we send to the server, the deflated data is about 90 bytes long, so after padding all the signed values share the same high bits.
Now we have equations of the form $a^3=bN+c+K$ ($a<N$, $c<\sqrt{N}$, and $K$ is constant). This means $a^3-K$ is very close to a multiple of $N$. We can take linear combinations of these $a^3-K$ values to produce a number very close to $N$.
Assume we have a set of such near-multiples of $N$. We repeatedly take two numbers from the set, compute their difference, and insert it back into the set, keeping only the smallest numbers; eventually we get a number close to $N$. We also know which linear combinations of the $a^3-K$ values produced it, so Gaussian elimination on these relations gives $b$ in the original equations, and from that $N$ itself.
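
To make the relation $a^3=bN+c+K$ concrete, here is a toy check with a tiny RSA modulus and public exponent 3. The primes, the value of `K`, and the value of `c` are invented for the example; the real modulus is 2048 bits (Python 3.8+ for the modular inverse via `pow`).

```python
# Toy parameters (assumptions for illustration only).
p, q = 1000000007, 998244353        # both are 2 mod 3, so e = 3 is invertible
N = p * q
d = pow(3, -1, (p - 1) * (q - 1))   # private exponent matching e = 3

K = 0xCA << 48                      # stands in for the constant high padding
c = 123456                          # stands in for the small varying low part

a = pow(K + c, d, N)                # the "signature": a cube root of K + c mod N
b, remainder = divmod(a**3 - K, N)  # a^3 - K = b*N + remainder

print(remainder == c, c * c < N)
```

Because the remainder is tiny compared with N, every such `a**3 - K` sits just above a multiple of N, which is what the difference-taking procedure exploits.
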
Once N is known we can send any command to the server, but we still need to sign it.
Signing requires the cube root of a padded string. Bleichenbacher's low-exponent attack would apply, but the deflated string is too long, so it will probably fail.
So I implemented the deflate algorithm myself, changed the Huffman tree to reduce the size, and then simply enumerated nonces until one worked.
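
The cube-root step can be sketched on its own: round the cube root of the padded value up and check whether cubing it leaves the prefix intact. This is a minimal Python 3 sketch with a made-up payload; `icbrt_ceil` and `forge` are hypothetical helper names, and the 255-byte block length mirrors the padding used in the signing script below.

```python
def icbrt_ceil(v):
    """Smallest integer g with g**3 >= v, found by binary search."""
    lo, hi = 0, 1 << (v.bit_length() // 3 + 2)
    while lo < hi:
        mid = (lo + hi) // 2
        if mid ** 3 >= v:
            hi = mid
        else:
            lo = mid + 1
    return lo


def forge(prefix: bytes, total_len: int = 255):
    """Return (g, ok): ok is True when g**3 still starts with prefix."""
    padded = prefix + b"\x00" * (total_len - len(prefix))
    g = icbrt_ceil(int.from_bytes(padded, "big"))
    cube = (g ** 3).to_bytes(total_len, "big")
    return g, cube.startswith(prefix)


g, ok = forge(b"\xCA\xFE\x12\x04" + b"A" * 60)
print(ok)
```

Rounding up changes the value by less than about 3g^2, which is roughly 2^1362 for a 2040-bit block, so only a prefix of up to about 84 bytes is guaranteed to survive. A longer prefix survives only by luck, which is why shrinking the deflated data and retrying nonces both help.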

Signing script:

import sys
import zlib,hashlib,random
from gmpy2 import *
import numba as nb
from multiprocessing import Pool


@nb.jit(nopython=True)
def get_hdict(codelengths):
    mp={}
    nextcode = 0
    for codelength in range(1, max(codelengths) + 1):
        nextcode <<= 1
        startbit = 1 << codelength
        for (symbol, cl) in enumerate(codelengths):
            if cl != codelength:
                continue
            mp[symbol]=startbit | nextcode
            nextcode += 1
    return mp

@nb.jit(nopython=True)
def upd(a,au,b,c,d):
    if not b in a:
        a[b]=c
        au[b]=d
    else:
        if c<a[b]:
            a[b]=c
            au[b]=d

@nb.jit(nopython=True)
def cal_huffman(s):
    ps=[0]
    for i in s:
        ps.append(ps[-1]+i)
    f=[{(0,2):0}]
    for i in range(19):
        f.append({(100+i,0):0})
    g=[{(0,2):(0,0)}]
    for i in range(19):
        g.append({(100+i,0):(0,0)})
    best=10**10
    bp=0,0,0
    LIM=450 if len(s)>15 else 120
    for i in range(1,16):
        for j in f[i-1]:
            x,y=j
            if y>len(s)-x:
                continue
            v=f[i-1][j]
            for k in range(1+min(y,len(s)-x)):
                if (y-k)*2>len(s)-(x+k): continue
                nv=(ps[x+k]-ps[x])*i+v
                upd(f[i],g[i],(x+k,(y-k)*2),nv,j)
                if x+k==len(s):
                    if nv<best:
                        best=nv
                        bp=i,x+k,y-k
    res=[0]
    t,x,y=bp
    for i in range(t,0,-1):
        ox,oy=g[i][x,y]
        res+=[i]*(x-ox)
        x,y=ox,oy
    return res[:0:-1]

@nb.jit(nopython=True)
def getkl(s):
    s2=[]
    for ii in s:
        if len(s2)==0 or s2[-1][0]!=ii:
            s2.append([ii,1])
        else:
            s2[-1][1]+=1
    re=[(0,0)]
    for i in s2:
        if i[0]==0:
            if i[1]<3:
                for j in range(i[1]):
                    re.append((0,0))
            elif i[1]<11:
                re.append((17,i[1]-3))
            else:
                re.append((18,i[1]-11))
            continue
        if i[1]<=3:
            for j in range(i[1]):
                re.append((i[0],0))
            continue
        if i[1]<=7:
            re.append((i[0],0))
            re.append((16,i[1]-4))
        else:
            re.append((i[0],0))
            re.append((16,3))
            if i[1]-7<=2:
                for j in range(i[1]-7):
                    re.append((i[0],0))
            else:
                re.append((16,i[1]-10))
    return re[1:]

@nb.jit(nopython=True)
def ad(s,l,x):
    if l!=-1:
        for i in range(l):
            s.append(x>>i&1)
    else:
        assert x
        t=0
        while x>>t:
            t+=1
        for i in range(t-2,-1,-1):
            s.append(x>>i&1)


@nb.jit(nopython=True)
def compress(s):
    t={-1:-1}
    t.pop(-1)
    for i in s:
        t[i]=0
    for i in s:
        t[i]+=1
    for i in range(10):
        t[i+48]+=100
    t[34]+=40
    t2=[]
    for i in t:
        t2.append((t[i],i))
    t=t2
    t.sort()
    t=t[::-1]
    tx=[]
    for i2 in t: tx.append(i2[0])
    tx.append(1)
    tp=[4]*11+[5]*8+[6]*4
    tpx=[0 for i in range(257)]
    for i in range(len(t)):
        tpx[t[i][1]]=tp[i]
    tpx[256]=tp[-1]
    tpc=get_hdict(tpx)
    tl=getkl(tpx+[0])
    tlu={}
    for ii in tl:
        tlu[ii[0]]=0
    for ii in tl:
        tlu[ii[0]]+=1
    tlux=[]
    for i in tlu:
        tlux.append((tlu[i],i))
    tlux.sort()
    tlux=tlux[::-1]
    tlk_=[]
    for ii in tlux:
        tlk_.append(ii[0])
    tlkl=cal_huffman(tlk_)
    tlklu=[0]*19
    for i in range(len(tlux)):
        tlklu[tlux[i][1]]=tlkl[i]

    res=[0]
    res.pop()
    res+=[1] #final
    ad(res,2,2) #type
    ad(res,5,0) #hlit
    ad(res,5,0) #hdist
    tlklc=get_hdict(tlklu)

    rest=[0]
    rest.pop()
    ad(rest,3,tlklu[16])
    ad(rest,3,tlklu[17])
    ad(rest,3,tlklu[18])
    ad(rest,3,tlklu[0])
    tlklu[16]=0
    tlklu[17]=0
    tlklu[18]=0
    tlklu[0]=0
    for i in range(100):
        su=0
        for j in tlklu:
            su+=j
        if su==0:break
        j = (8 + i // 2) if (i % 2 == 0) else (7 - i // 2)
        ad(rest,3,tlklu[j])
        tlklu[j]=0
    ad(res,4,(len(rest)-12)/3) #hclen
    res+=rest

    for ii in tl:
        ad(res,-1,tlklc[ii[0]])
        if ii[0]==16:
            ad(res,2,ii[1])
        elif ii[0]==17:
            ad(res,3,ii[1])
        elif ii[0]==18:
            ad(res,7,ii[1])

    for i in s:
        ad(res,-1,tpc[i])
    ad(res,-1,tpc[256])
    while len(res)%8:
        res+=[0]
    fin=[]
    for i in range(0,len(res),8):
        tuu=0
        for j in range(8):
            tuu+=res[i+j]<<j
        fin.append(tuu)
    return fin

def real_comp(s):
    fin=''.join(map(chr,compress(list(map(ord,list(s))))))
    ss=zlib.compress(s)
    return '789c'.decode('hex')+fin+ss[-4:]

def sign(m_):
    m,id=m_
    h=int(hashlib.sha256(m).hexdigest(),16)
    ni=0
    hh=str(h)
    mi=10**15
    posi=0
    while True:
        a=ni%(len(hh)-7);b=ni/(len(hh)-7)
        t='{"hash":%d,"nonce":"%s"}'%(h,''.join([random.choice(list('0123456789'))for _ in range(8)]))
        x=real_comp(t)
        if len(x)<=83:posi+=1
        u='\xCA\xFE\x12\x04'+x+'\0'*(251-len(x))
        v=int(u.encode('hex'),16)
        g=iroot(v,3)
        g=g[0] if g[1] else g[0]+1
        v=g**3
        rv=('%x'%v).decode('hex')
        ut=int(rv[4:4+len(x)].encode('hex'),16)^int(x.encode('hex'),16)
        if ut<mi:
            mi=ut
            print 'cur best(%d):'%id,mi
        if rv[4:4+len(x)]==x:
            open('result.txt','ab').write(m+' '+rv+'\n')
            print rv
            return rv
        ni+=1
        if ni%1000==0:
            print ni,posi

if __name__ == '__main__':
    pool = Pool(processes=4)
    pool.map(sign,[('meow*',_) for _ in range(4)])

Attack script:

from pwn import *
from gmpy2 import gcd,iroot
import zlib,hashlib

r=remote('54.92.6.97', 3239)

def geto():
    r.recvuntil('meow?\n')
    r.send('meow~\n')
    r.recvuntil('meow~\n')
    r.send('a\n')
    return int(r.recvuntil('\n'))

diff=0xcafe1204<<2008

s=[]
for i in range(12):
    print 'retrieve:',i
    s.append(geto())

s=list(map(lambda x:x**3-diff,s))
s.sort()
os=s[0].bit_length()
print(os)
sold=s
st=[]
for i in range(len(s)):
    st.append((s[i],[1 if _==i else 0 for _ in range(len(s))]))
s=st

K=1000

cnt=0

def sub(a,b):
    t=[]
    for i in range(len(a)):
        t.append(a[i]-b[i])
    return t

def mul(a,b):
    t=[]
    for i in range(len(a)):
        t.append(a[i]*b)
    return t

def subx(a,b):
    return (a[0]-b[0],sub(a[1],b[1]))

for i in range(5000):
    t=[]
    if len(s)<K*.9:
        for j in range(len(s)):
            t.append(s[j])
            for k in range(j+1,min(len(s),j+4)):
                t.append(subx(s[k],s[j]))
    else:
        for j in range(len(s)-1):
            t.append(subx(s[j+1],s[j]))
        for j in range(min(len(s),1000)):
            t.append(s[j])
    t.sort()
    t2=[]
    for j in t:
        if j[0]<(1<<1900):continue
        if len(t2)==0 or j[0]>t2[-1][0]+(1<<1900) or (t2[-1][0].bit_length()==2048 and j[0].bit_length()==2048 and t2[-1]!=j):
            t2.append(j)
    t=t2
    t=t[:K]
    tb=t[0][0].bit_length()
    if i%10==0: print i+1,os-tb,'%.5f'%((os-tb)/(i+1.))
    s=t
    if t[0][0].bit_length()==2048:
        cnt+=1
        if cnt>=15:
            break

full=(1<<2048)-1
mask=full
for i in range(len(s)-1):
    if s[i+1][0].bit_length()!=2048: continue
    t=full^s[i][0]^s[i+1][0]
    if (t>>2045)!=7: continue
    mask&=t
mask^=full
t=0
while (1<<t)<mask:
    t+=1
mask=full^(1<<t)-1
print(t)
req=s[0][0]&mask
sn=[]
for i in s:
    if (i[0]&mask)==req:
        sn.append(i[1])

s=[]
for i in range(len(sn)-1):
    s.append(sub(sn[i],sn[i+1]))

while len(s[0])>2:
    sn=[]
    st=[]
    for j in s:
        if j[-1]:
            st.append(j)
        else:
            sn.append(j[:-1])
    for j in range(len(st)-1):
        a=st[j]
        b=st[j+1]
        at=mul(a,b[-1])
        bt=mul(b,a[-1])
        t=sub(at,bt)
        assert t[-1]==0
        sn.append(t[:-1])
    s=sn
for i in s:
    if i[0]:
        a,b=i
g=gcd(a,b)
a/=g;b/=g
tn=sold[0]/abs(b)
tn*=req/tn if req%tn<tn/2 else req/tn+1
n=tn
print 'found n:',n

def encrypt(s):
    return '%0512x'%pow(int(('\xCA\xFE\x12\x04'+'\0'*(251-len(s))+s).encode('hex'),16),3,n)

keys={}
keys['meow*']='%0512x'%4643124907324364176541919631092611537462168169113368694062754706716345623066850854804612072567302665612238408034253132873785414030172160601569222092854027911483401638265250179534654654959278137810816171464

r.recvuntil('meow?\n')
r.send('meow!\n')
r.recvuntil('meow meow~\n')
e=encrypt('meow*')
r.send(e+'\n')
r.recvuntil('meow meow meow?\n')
r.send(keys['meow*']+'\n')
r.interactive()

Reverse

EmojiVM

The program requires a token before it prints the flag, and the token is stored in encrypted form.

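# Encrypted token bytes; the transform to undo depends on the index modulo 4.
s=[142, 99, 205, 18, 75, 88, 21, 23, 81, 34, 217, 4, 81, 44, 25, 21, 134, 44, 209, 76, 132, 46, 32, 6, 0]
s2=[]
for i in range(25):
    if i%4==0:
        s2.append(s[i]-30)
    elif i%4==1:
        s2.append((s[i]^7)+8)
    elif i%4==2:
        s2.append(((s[i]+4)^68)-44)
    else:
        s2.append(s[i]^4^101)
# The last entry is not part of the token, so it is dropped before printing.
print(''.join(map(chr,s2[:-1])))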

The script above recovers the token. Pasting the token into the program then prints the flag.
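
As a sanity check, the decryption can be paired with its inverse and verified by a round trip. The `encrypt` function below is inferred by undoing each step of the script above; it is an assumption about what the VM program computes, not code taken from it (Python 3).

```python
# Same bytes as in the script above, without the trailing 0 entry.
ENCRYPTED = [142, 99, 205, 18, 75, 88, 21, 23, 81, 34, 217, 4, 81, 44, 25, 21,
             134, 44, 209, 76, 132, 46, 32, 6]


def decrypt(data):
    """Undo the per-position transform, selected by index modulo 4."""
    out = []
    for i, v in enumerate(data):
        if i % 4 == 0:
            out.append(v - 30)
        elif i % 4 == 1:
            out.append((v ^ 7) + 8)
        elif i % 4 == 2:
            out.append(((v + 4) ^ 68) - 44)
        else:
            out.append(v ^ 4 ^ 101)
    return bytes(out)


def encrypt(token):
    """Inverse of decrypt(): each step undone in reverse order (inferred)."""
    out = []
    for i, c in enumerate(token):
        if i % 4 == 0:
            out.append(c + 30)
        elif i % 4 == 1:
            out.append((c - 8) ^ 7)
        elif i % 4 == 2:
            out.append(((c + 44) ^ 68) - 4)
        else:
            out.append(c ^ 4 ^ 101)
    return out


token = decrypt(ENCRYPTED)
print(token.decode(), encrypt(token) == ENCRYPTED)
```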

Suicune

It requires a key in the range 0~0xffff to encrypt a given message of length N.
The encryption has 16 rounds. In each round, the key is used to shuffle a permutation of 0~255, then the first N bytes of the permutation are processed and XORed into the original message.
The processing treats those N bytes as a permutation and advances it to the next permutation K times, where K is derived from the key.

def fcount(s):
    res=0
    for i in range(len(s)):
        c=0
        for j in range(i+1,len(s)):
            c+=s[j]>s[i]
        res+=c*fac[len(s)-i-1]
    return res

def fkth(n,k):
    s=[0 for i in range(n)]
    for i in range(n-1,-1,-1):
        s[i]=n-i-1-(k%fac[n-i]//fac[n-i-1])
        for j in range(i+1,n):
            if s[j]>=s[i]:
                s[j]+=1
    return s

def fwalk(s,k):
    t=fcount(s)
    t=max(0,t-k)
    st=fkth(len(s),t)
    s.sort()
    res=[]
    for i in st:
        res.append(s[i])
    return res

fac=[1]
for i in range(1,100):
    fac.append(fac[-1]*i)

def ROR4(x,y):
    x%=(1<<32)
    assert y>=0 and y<32
    return ((x>>y) | (x<<(32-y))) % (1<<32)

def keystr(s):
    r='%016x'%s
    u=''
    for i in range(16,0,-2):
        u+=r[i-2:i]
    return u

L=49
enc='04dd5a70faea88b76e4733d0fa346b086e2c0efd7d2815e3b6ca118ab945719970642b2929b18a71b28d87855796e344d8'

for key in range(1<<16):
    raw=[0]*L
    okey=key
    key=(6364136223846793005*(key%0x10000)+6364136223846793006)%(1<<64)

    for T in range(16):
        Carr=[_ for _ in range(256)]

        for v41 in range(255,0,-1):
            v154=v41+1
            v51=key
            if ((1<<32)-v154)%v154:
                v155=key
                while True:
                    v155 = (1 + 6364136223846793005 * v155) % (1<<64)
                    v156 = ROR4((v51 ^ (v51 >> 18)) >> 27, v51 >> 59)
                    v51 = v155
                    if v156 < (1<<32) - (((1<<32)-v154) % v154):
                        break
            else:
                v155 = (1 + 6364136223846793005 * v51) % (1<<64)
                v157 = (v51 ^ (v51 >> 18)) >> 27
                v51 >>= 59
                v156 = ROR4(v157, v51)
            key = v155
            v158 = v156 % v154
            Carr[v41],Carr[v158]=Carr[v158],Carr[v41]
        dest = 0
        for i in range(0,64,32):
            dest += ROR4((key ^ (key >> 18)) >> 27, key >> 59) << i
            key = (1 + 6364136223846793005 * key) % (1<<64)

        Darr = Carr[:L]

        Darr=fwalk(Darr,dest)

        tmp=[]
        for i in range(L-1,-1,-1):
            tmp.append(raw[i]^Darr[i])
        raw=tmp

    res=''
    ok=True
    for i in range(0,L*2,2):
        t=int(enc[i:i+2],16)^raw[i//2]
        res+=chr(t)
        ok=ok and t>=32 and t<=127
    if ok:
        print(res)
    if okey%100==0:
        print(okey)
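
The shuffle inside the loop above can be separated into small functions. The multiplier and the xorshift-then-rotate output step resemble the PCG32 generator; that identification is an assumption here and is not stated in the solver. The function names are hypothetical, and the logic is a Python 3 restatement of the script rather than code from the binary.

```python
MASK64 = (1 << 64) - 1
MULT = 6364136223846793005  # multiplier used in the solver script


def rotr32(x, r):
    """Rotate the low 32 bits of x right by r (0 <= r < 32)."""
    x &= 0xFFFFFFFF
    return ((x >> r) | (x << (32 - r))) & 0xFFFFFFFF


def seed_state(key):
    """Initial 64-bit state for a 16-bit key, as computed in the solver."""
    return (MULT * (key % 0x10000) + MULT + 1) & MASK64


def next_u32(state):
    """Return (output, new_state) for one generator step."""
    out = rotr32((state ^ (state >> 18)) >> 27, state >> 59)
    return out, (state * MULT + 1) & MASK64


def bounded(state, bound):
    """Draw a value in [0, bound) by rejection sampling."""
    limit = (1 << 32) - ((1 << 32) - bound) % bound
    while True:
        out, state = next_u32(state)
        if out < limit:
            return out % bound, state


def shuffle256(state):
    """Shuffle 0..255 from the back, one bounded draw per position."""
    arr = list(range(256))
    for i in range(255, 0, -1):
        j, state = bounded(state, i + 1)
        arr[i], arr[j] = arr[j], arr[i]
    return arr, state


arr, st = shuffle256(seed_state(0x1234))
print(arr[:8])
```

Since the key space is only 16 bits, running this shuffle 16 times per candidate key is cheap enough for exhaustive search.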

Date: 2019-10-15

Tags: CTF Writeup

This is an old post; the original article and its comments can be viewed at https://oldblog.mcfx.us/archives/268/.