The idea of the challenge was to use a meet-in-the-middle attack.

Basically, instead of trying all combinations of the 48-bit key, we generate every possible output of the encryption function for each 24-bit key applied to the known plaintext "16 bit plaintext", and every possible output of the decryption function for each 24-bit key applied to the ciphertext \x04\x67\xa5\x2a\xfa\x8f\x15\xcf\xb8\xf0\xea\x40\x36\x5a\x66\x92.

We then look for a value that appears both in the collection of encrypted outputs and in the collection of decrypted outputs; such a value represents the “ct1” variable of the given code.

This reduces the number of cipher evaluations from 2^48 to 2^25 (2^24 encryptions plus 2^24 decryptions), which should be doable in a reasonable time on a normal home computer.

The decryption version of the fun-function is not provided, so we need to create one.

The function consists of rounds of simple xor, substitution and transposition actions, so a decryption function is easy to create by reversing the order of the actions and slightly modifying the transposition and substitution steps.

The provided code was not optimal, so I needed to make a few changes to it. The runtime of the final brute-forcing code was ~1h on my computer, which was not too bad.

    from hashlib import md5
    from binascii import unhexlify 
    import time

    # Cipher parameters.
    BLOCK_LENGTH = 16  # number of bytes processed per block by xor/fun/rev_fun
    KEY_LENGTH = 3     # size in bytes of each brute-forced short key
    ROUND_COUNT = 16   # number of xor / substitute / transpose rounds

    # Byte substitution table used by fun(): a permutation of 0..255.
    sbox = [210, 213, 115, 178, 122, 4, 94, 164, 199, 230, 237, 248, 54,
            217, 156, 202, 212, 177, 132, 36, 245, 31, 163, 49, 68, 107,
            91, 251, 134, 242, 59, 46, 37, 124, 185, 25, 41, 184, 221,
            63, 10, 42, 28, 104, 56, 155, 43, 250, 161, 22, 92, 81,
            201, 229, 183, 214, 208, 66, 128, 162, 172, 147, 1, 74, 15,
            151, 227, 247, 114, 47, 53, 203, 170, 228, 226, 239, 44, 119,
            123, 67, 11, 175, 240, 13, 52, 255, 143, 88, 219, 188, 99,
            82, 158, 14, 241, 78, 33, 108, 198, 85, 72, 192, 236, 129,
            131, 220, 96, 71, 98, 75, 127, 3, 120, 243, 109, 23, 48,
            97, 234, 187, 244, 12, 139, 18, 101, 126, 38, 216, 90, 125,
            106, 24, 235, 207, 186, 190, 84, 171, 113, 232, 2, 105, 200,
            70, 137, 152, 165, 19, 166, 154, 112, 142, 180, 167, 57, 153,
            174, 8, 146, 194, 26, 150, 206, 141, 39, 60, 102, 9, 65,
            176, 79, 61, 62, 110, 111, 30, 218, 197, 140, 168, 196, 83,
            223, 144, 55, 58, 157, 173, 133, 191, 145, 27, 103, 40, 246,
            169, 73, 179, 160, 253, 225, 51, 32, 224, 29, 34, 77, 117,
            100, 233, 181, 76, 21, 5, 149, 204, 182, 138, 211, 16, 231,
            0, 238, 254, 252, 6, 195, 89, 69, 136, 87, 209, 118, 222,
            20, 249, 64, 130, 35, 86, 116, 193, 7, 121, 135, 189, 215,
            50, 148, 159, 93, 80, 45, 17, 205, 95]

    # Inverse substitution table used by rev_fun(): sbox2[sbox[x]] == x.
    # Filled in a single pass over sbox instead of one linear sbox.index()
    # search per entry.
    sbox2 = [0] * len(sbox)
    for _plain, _substituted in enumerate(sbox):
        sbox2[_substituted] = _plain

    # Byte transposition: after a round of fun(), position i holds the byte
    # that was at position p[i]; rev_fun() writes byte i back to position p[i].
    p = [3, 9, 0, 1, 8, 7, 15, 2, 5, 6, 13, 10, 4, 12, 11, 14]

            
    def xor(a, b):
        """XOR the first BLOCK_LENGTH bytes of ``b`` into ``a`` and return ``a``.

        ``a`` is modified in place, so it must be a mutable sequence such as a
        bytearray. Both arguments need at least BLOCK_LENGTH items, otherwise
        indexing raises IndexError.
        """
        for pos in range(BLOCK_LENGTH):
            a[pos] = a[pos] ^ b[pos]
        return a


    def fun(key, pt):
        """Encrypt one block of ``pt`` under ``key`` and return it as a bytearray.

        Each of the ROUND_COUNT rounds xors the key into the state, substitutes
        every byte through ``sbox`` and then transposes the bytes according to
        ``p``. ``pt`` itself is left untouched because the work is done on a copy.
        """
        state = bytearray(pt)
        for _ in range(ROUND_COUNT):
            state = xor(state, key)
            # Substitution followed by transposition, done in one pass:
            # output position i receives the substituted byte from position p[i].
            state = bytearray(
                sbox[state[p[dst]]] for dst in range(BLOCK_LENGTH)
            )
        return state
        
    def rev_fun(key, pt):
        """Invert ``fun``: decrypt one block under ``key`` and return a bytearray.

        Despite its name, ``pt`` is the block to be decrypted (the ciphertext).
        Every round undoes the steps of ``fun`` in reverse order: the
        transposition by ``p`` is reverted, each byte is mapped back through the
        inverse table ``sbox2``, and finally the key is xored out.
        """
        state = bytearray(pt)
        for _ in range(ROUND_COUNT):
            # Undo the transposition: byte i goes back to position p[i].
            restored = bytearray(BLOCK_LENGTH)
            for src, dst in enumerate(p):
                restored[dst] = state[src]
            # Undo the substitution, then remove the round key.
            state = xor(bytearray(sbox2[value] for value in restored), key)
        return state
        
    # Meet-in-the-middle search.
    #
    # Every 3-byte short key is expanded with MD5 into a 16-byte round key.
    # The known plaintext is encrypted under it and the known ciphertext is
    # decrypted under it. An intermediate block that appears in both tables
    # links the short key of the first stage to the short key of the second.
    #
    # Only two dicts are kept: their keys already serve as the lookup sets, so
    # separate set objects would merely store each 16-byte block a second time.
    forward = {}   # encrypted known plaintext -> short key that produced it
    backward = {}  # decrypted known ciphertext -> short key that produced it
    key1 = bytearray(KEY_LENGTH)
    u = unhexlify(b'0467a52afa8f15cfb8f0ea40365a6692')
    known_plaintext = b"16 bit plaintext"

    print("Phase1")
    for i in range(256):
        t = time.perf_counter()
        key1[0] = i
        for j in range(256):
            key1[1] = j
            for k in range(256):
                key1[2] = k
                key = md5(key1).digest()
                # Snapshot the mutable counter once and share it between tables.
                short_key = bytes(key1)
                forward[bytes(fun(key, known_plaintext))] = short_key
                backward[bytes(rev_fun(key, u))] = short_key
        # Progress report: outer byte value and seconds spent on it.
        print(i, time.perf_counter() - t)

    print("Phase2")
    # Dict key views support set intersection directly.
    for middle in forward.keys() & backward.keys():
        print(forward[middle], backward[middle])

Author: jusbsec