summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--crypto.py49
-rw-r--r--pkd_stub.py66
2 files changed, 49 insertions, 66 deletions
diff --git a/crypto.py b/crypto.py
index 77d795f..7de099e 100644
--- a/crypto.py
+++ b/crypto.py
@@ -94,20 +94,27 @@ class PKSock:
self.nbytes = -(-bits//8)
self.headsz = Crypto.byte_length(self.nbytes)
self.streaming = False
- self.sk = None
- self.skp = 0
+ # TODO: need to separate isk and osk
+ self.isk = None
+ self.iskp = 0
+ self.osk = None
+ self.oskp = 0
self.sksz = self.nbytes - self.headsz - 1
self.buffer = b''
def start_stream (self):
- self.sk = None
- self.skp = 0
+ self.isk = None
+ self.osk = None
+ self.iskp = 0
+ self.oskp = 0
self.streaming = True
def stop_stream (self, backtrack=0):
assert (backtrack <= self.nbytes)
- self.sk = None
- self.skp = 0
+ self.isk = None
+ self.osk = None
+ self.iskp = 0
+ self.oskp = 0
self.buffer = self.buffer[-backtrack:]
self.streaming = False
@@ -137,17 +144,17 @@ class PKSock:
return
if self.streaming and not force_normal:
- if not self.sk or self.skp >= len(self.sk):
+ if not self.osk or self.oskp >= len(self.osk):
self.push_sk()
- while len(b) > len(self.sk) - self.skp:
- b_frag = b[:len(self.sk)-self.skp]
+ while len(b) > len(self.osk) - self.oskp:
+ b_frag = b[:len(self.osk)-self.oskp]
k = self.sk[self.skp:self.skp+len(b_frag)]
- self.skp += len(b_frag)
+ self.oskp += len(b_frag)
self.sock.sendall(bytes([b_frag[i] ^ k[i] for i in range(len(b_frag))]))
b = b[len(b_frag):]
self.push_sk()
- k = self.sk[self.skp:self.skp+len(b)]
- self.skp += len(b)
+ k = self.osk[self.oskp:self.oskp+len(b)]
+ self.oskp += len(b)
self.sock.sendall(bytes([b[i] ^ k[i] for i in range(len(b))]))
else:
p = Crypto.encrypt(b, self.rpk['e'], self.rpk['n'], self.bits)
@@ -157,13 +164,13 @@ class PKSock:
def recv (self, force_normal=False):
if self.streaming and not force_normal:
- if not self.sk or self.skp >= len(self.sk):
+ if not self.isk or self.iskp >= len(self.isk):
self.pull_sk()
# TODO: this could use some work because we can split opcodes etc
- c = self.sock.recv(len(self.sk) - self.skp)
+ c = self.sock.recv(len(self.isk) - self.iskp)
self.raw_cache(c)
- k = self.sk[self.skp : self.skp+len(c)]
- self.skp += len(c)
+ k = self.isk[self.iskp : self.iskp+len(c)]
+ self.iskp += len(c)
return bytes([c[i] ^ k[i] for i in range(len(c))])
else:
chunks, nchunks = [], Crypto.b2i(self.raw_recv(self.headsz))
@@ -189,10 +196,10 @@ class PKSock:
self.send(Crypto.i2b(self.priv['n']))
def push_sk (self):
- self.sk = secrets.token_bytes(self.sksz)
- self.skp = 0
- self.send(self.sk, force_normal=True)
+ self.osk = secrets.token_bytes(self.sksz)
+ self.oskp = 0
+ self.send(self.osk, force_normal=True)
def pull_sk (self):
- self.sk = self.recv(force_normal=True)
- self.skp = 0
+ self.isk = self.recv(force_normal=True)
+ self.iskp = 0
diff --git a/pkd_stub.py b/pkd_stub.py
index d81899e..cac525c 100644
--- a/pkd_stub.py
+++ b/pkd_stub.py
@@ -32,12 +32,8 @@ def showcrypto():
global privkey
return '[warcrypto] Server public key:\n{"n": %d, "e": %d}' % (privkey['n'], privkey['e'])
-def dispatch_command(sock, command, rpubkey):
- global bits
- send_encrypted(sock, command, rpubkey['e'], rpubkey['n'], bits=bits)
-
def dispatch_ccmd(client, command):
- dispatch_command(client['sock'], command, client['pubkey'])
+ client['sock'].send(command)
def brint(*args, sep=' ', end='\n', prompt=True):
s = '%s%s' % (sep.join(map(lambda s: betterstr(s), args)), end)
@@ -71,7 +67,7 @@ def broadcast_screens(s, skip=set(), sv_prompt=False, ctd_prompt=False):
def blast_command(cmd, orig_screen, targets=set()):
global cmdq
tstr = betterstr(targets)
- if tstr == 'set()':
+ if len(targets) < 1:
tstr = 'all clients'
print('[INFO] Blasting command: %s to %s.' % (betterstr(cmd), tstr))
if type(cmd) != bytes:
@@ -128,19 +124,18 @@ def screens_detach(sel, screen):
brint('[INFO] Screen detaching: %d' % idx)
def screens_pty(sel, screen, client):
- screen['pty'] = client
- client['pty'] = screen
- client['osc'] = OutStreamCipher(client['sock'], client['pubkey'], bits=bits)
- client['isc'] = InStreamCipher(client['sock'], privkey, bits=bits)
-
try:
dispatch_ccmd(client, b'pty')
+ client['sock'].start_stream()
+ client['pty'] = screen
+ screen['pty'] = client
if 'TERM' not in os.environ:
os.environ['TERM'] = 'xterm-256color'
- client['osc'].send(bytes(os.environ['TERM'], 'utf-8'))
+ client['sock'].send(bytes(os.environ['TERM'], 'utf-8'))
except:
tcp_unpty(sel, client, catchup=False)
tcp_disconnect(sel, client)
+ return
try:
screen['sock'].sendall(b'\xc0\xdepty')
@@ -165,7 +160,7 @@ def screens_read(sel, sock, screen):
if screen['pty']:
try:
- screen['pty']['osc'].send(data)
+ screen['pty']['sock'].send(data)
except:
tcp_unpty(sel, client, catchup=False)
tcp_disconnect(sel, client)
@@ -306,11 +301,11 @@ def tcp_dumpq(sel, client):
def tcp_send_npty(sel, client):
try:
- client['osc'].send(b'\xc0\xdenpty')
+ client['sock'].send(b'\xc0\xdenpty')
except:
tcp_disconnect(sel, client)
-def tcp_unpty(sel, client, catchup=True):
+def tcp_unpty(sel, client, catchup=True, backtrack=0):
if type(client['pty']) == dict:
client['pty']['pty'] = False
if client['pty']['alive']:
@@ -320,12 +315,11 @@ def tcp_unpty(sel, client, catchup=True):
screens_detach(sel, client['pty'])
try:
- client['osc'].send(b'\xc0\xdeack')
+ client['sock'].send(b'\xc0\xdeack')
except:
tcp_disconnect(sel, client)
- # this will become stop_stream(backtrack)
- del client['isc']
- del client['osc']
+
+ client['sock'].stop_stream(backtrack)
client['pty'] = False
if catchup:
@@ -336,8 +330,7 @@ def tcp_transport(sel, sock, client):
if not client['alive']:
return
try:
- data = client['isc'].recv() if client['pty'] else\
- recv_encrypted(sock, privkey['d'], privkey['n'], bits=bits)
+ data = client['sock'].recv()
except:
data = False
if not data or data == b'\xde\xad':
@@ -348,7 +341,7 @@ def tcp_transport(sel, sock, client):
elif not client['pty']:
brint('[%d]' % tcp_clients.index(client), data, end='', prompt=False)
elif data[:6] == b'\xc0\xdenpty':
- tcp_unpty(sel, client, catchup=True)
+ tcp_unpty(sel, client, catchup=True, backtrack=len(data[6:]))
print('[INFO] npty acknowledged')
else:
try:
@@ -357,21 +350,6 @@ def tcp_transport(sel, sock, client):
screens_detach(sel, client['pty'])
tcp_send_npty(sel, client)
-def tcp_handshake(sock):
- global privkey, bits, exp
- nbytes, headsz = bits//8, 2
- rnbytes = int.from_bytes(sock.recv(headsz), 'big')
- sock.sendall(nbytes.to_bytes(headsz, 'big'))
-
- if rnbytes != nbytes:
- brint('[ERROR] nbytes mismatch with client: %d vs %d' % (rnbytes, nbytes))
- return False
-
- rpubkey = { 'n': int.from_bytes(recv_encrypted(sock, privkey['d'], privkey['n'], bits=bits),\
- 'big'), 'e': exp }
-
- return rpubkey
-
def tcp_close(sock, client):
try:
dispatch_ccmd(client, b'tunnel')
@@ -379,7 +357,7 @@ def tcp_close(sock, client):
pass
def tcp_accept(sel, sock):
- global tcp_clients
+ global tcp_clients, privkey, bits
try:
cs, ca = sock.accept()
except:
@@ -388,22 +366,20 @@ def tcp_accept(sel, sock):
client = {
'alive': True,
- 'sock': cs,
+ 'sock': PKSock(cs, privkey, bits),
'addr': ca,
'qidx': 0,
'pty': False
}
try:
- rpk = tcp_handshake(cs)
+ success = client['sock'].handshake_client()
except:
- rpk = False
- finally:
- pass
- if not rpk:
+ success = False
+
+ if not success:
brint('[WARNING] TCP handshake failed from', client['addr'])
cs.close()
return
- client['pubkey'] = rpk
tcp_clients.append(client)
sel.register(cs, selectors.EVENT_READ, {'callback': tcp_transport, 'close': tcp_close, 'args': [client]})