From ff93430c5df52518b8f404e3c0c17bef071cdbdd Mon Sep 17 00:00:00 2001 From: Carson Fleming Date: Tue, 28 Mar 2023 11:02:03 -0700 Subject: Fix some transition bugs; pty still hangs trying to pull streaming key --- crypto.py | 56 ++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 20 deletions(-) (limited to 'crypto.py') diff --git a/crypto.py b/crypto.py index 7de099e..7994fad 100644 --- a/crypto.py +++ b/crypto.py @@ -7,8 +7,8 @@ class Crypto: def byte_length (i): return -(-i.bit_length()//8) - def i2b (i): - return i.to_bytes(Crypto.byte_length(i), 'big') + def i2b (i, sz=None): + return i.to_bytes(sz if sz else Crypto.byte_length(i), 'big') def b2i (b): return int.from_bytes(b, 'big') @@ -55,32 +55,36 @@ class Crypto: def pad(m, bits): headsz = Crypto.headsize(bits) - chunks, nbytes = [], bits//8-headsz-Crypto.rand_pad-1 + chunks, nbytes = [], -(-bits//8)-headsz-Crypto.rand_pad-1 while len(m) > nbytes: - chunk = Crypto.i2b(nbytes) + m[:nbytes] + chunk = Crypto.i2b(nbytes, headsz) + m[:nbytes] + secrets.token_bytes(Crypto.rand_pad) m = m[nbytes:] chunks.append(Crypto.b2i(chunk)) - chunk = Crypto.i2b(len(m)) + m + secrets.token_bytes(nbytes - len(m) + Crypto.rand_pad) - chunks.append(int.from_bytes(chunk, 'big')) + chunk = Crypto.i2b(len(m), headsz) + m + secrets.token_bytes(nbytes - len(m) + Crypto.rand_pad) + chunks.append(Crypto.b2i(chunk)) return chunks def unpad(p, bits): - m_chunks, headsz = [], bits//8-1, Crypto.headsize(bits) + nbytes = -(-bits//8) - 1 + m_chunks, headsz = [], Crypto.headsize(bits) for p_chunk in p: - chunk = Crypto.i2b(p_chunk) + chunk = Crypto.i2b(p_chunk, nbytes) chunksz = Crypto.b2i(chunk[:headsz]) m_chunks.append(chunk[headsz:headsz+chunksz]) return b''.join(m_chunks) def encrypt(m, e, n, bits): - return [Crypto.i2b(pow(p, e, n)) for p in Crypto.pad(m, bits)] + nbytes = -(-bits//8) + return [Crypto.i2b(pow(p, e, n), nbytes) for p in Crypto.pad(m, bits)] def decrypt(p, d, n, bits): return Crypto.unpad([pow(Crypto.b2i(c), d, n) for c in p], bits) def keygen(bits=2048): - p, q = Crypto.pgen(bits >> 1), Crypto.pgen(bits >> 1) + #p, q = Crypto.pgen(bits >> 1), Crypto.pgen(bits >> 1) + p = 24255437060933278568327701135893465111281929996020213283563016322587538898307222832676648856557095025772204413206980826822975131457046546497523091381599846376591186469347076218208800469787851326371447052505831886762813382071182380079565671280034572474822918684998169291052805586193101619898239325398349038870346792747848088462443207267957917761324031277878527517794286421236832567949642803568424614512153948596887559236723641422745633542467824845355764889656771928727895545492980157734692593529905622895318376798465152794082339714326113478586369772394086511968434509576965140800722212375216417198365996986997884522143 + q = 19768976408982925938974295924028441291658237346176654371253117928175549332214544184956873153703986985390178340362842925095394171723247001148381841608957916206559941167745803865591902614480793421765383057879093169950619962622549993022767924551441062799960937735273562540998394372954466434657886766982923637838307867760458002688852634523140205959841808840422174672818444274740218202285908370264201317843186892579386846520240651724201292430263438230872027782342582077071631797326347450149692183612550505452085046398227369191676244454566325746209776357595496074364519226187538979254031364601406997434831185228055585873117 n, e, d = p*q, Crypto.exp, pow(Crypto.exp, -1, (p-1)*(q-1)) return p, q, n, e, d @@ -142,23 +146,25 @@ class PKSock: assert(type(b) == bytes) if len(b) < 1: return - + if self.streaming and not force_normal: if not self.osk or self.oskp >= len(self.osk): self.push_sk() - while len(b) > len(self.osk) - self.oskp: + while len(b) >= len(self.osk) - self.oskp: b_frag = b[:len(self.osk)-self.oskp] - k = self.sk[self.skp:self.skp+len(b_frag)] + k = self.osk[self.oskp:self.oskp+len(b_frag)] self.oskp += len(b_frag) self.sock.sendall(bytes([b_frag[i] ^ k[i] for i in range(len(b_frag))])) b = b[len(b_frag):] self.push_sk() - k = self.osk[self.oskp:self.oskp+len(b)] - self.oskp += len(b) - self.sock.sendall(bytes([b[i] ^ k[i] for i in range(len(b))])) + if len(b) > 0: + k = self.osk[self.oskp:self.oskp+len(b)] + self.sock.sendall(bytes([b[i] ^ k[i] for i in range(len(b))])) + self.oskp += len(b) + assert(self.oskp < len(self.osk)) else: p = Crypto.encrypt(b, self.rpk['e'], self.rpk['n'], self.bits) - self.raw_send(Crypto.i2b(len(p))) + self.raw_send(Crypto.i2b(len(p), self.headsz)) for chunk in p: self.raw_send(chunk) @@ -180,8 +186,9 @@ class PKSock: def handshake_client (self): rnbytes = Crypto.b2i(self.raw_recv(self.headsz)) - self.raw_send(Crypto.i2b(self.nbytes)) + self.raw_send(Crypto.i2b(self.nbytes, self.headsz)) if self.nbytes != rnbytes: + print('nbytes mismatch: %d vs %d' % (self.nbytes, rnbytes)) return False self.rpk = {'n': Crypto.b2i(self.recv()), 'e': Crypto.exp} @@ -189,17 +196,26 @@ class PKSock: def handshake_server (self, server_pk): self.rpk = server_pk - self.raw_send(Crypto.i2b(self.nbytes)) + self.raw_send(Crypto.i2b(self.nbytes, self.headsz)) rnbytes = Crypto.b2i(self.raw_recv(self.headsz)) if self.nbytes != rnbytes: + print('nbytes mismatch: %d vs %d' % (self.nbytes, rnbytes)) return False - self.send(Crypto.i2b(self.priv['n'])) + self.send(Crypto.i2b(self.priv['n'], self.nbytes)) + return True def push_sk (self): + print('PUSH_SK') self.osk = secrets.token_bytes(self.sksz) self.oskp = 0 self.send(self.osk, force_normal=True) + print('RETURNED') def pull_sk (self): + print('PULL_SK') self.isk = self.recv(force_normal=True) self.iskp = 0 + print('RETURNED') + + def close (self): + self.sock.close() -- cgit v1.2.3