diff --git a/.gitignore b/.gitignore index aca9d81..5a91053 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +__pycache__ .vscode smallchat.dSYM smallchat-server diff --git a/Makefile b/Makefile index fe764e7..611a75e 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,10 @@ smallchat-server: smallchat-server.c chatlib.c smallchat-client: smallchat-client.c chatlib.c $(CC) smallchat-client.c chatlib.c -o smallchat-client $(CFLAGS) +test: smallchat-server smallchat-client + python3 -m unittest process.py -v + python3 -m unittest tests.py -v + clean: rm -f smallchat-server rm -f smallchat-client diff --git a/README.md b/README.md index 2c66bfc..795a87d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # Smallchat +Inspired by Salvatore Sanfilippo's series on creating a chat +Implemented in python + + TLDR: This is just a programming example for a few friends of mine. It somehow turned into a set of programming videos, continuing one project I started some time ago: Writing System Software videos series. 1. [First episode](https://www.youtube.com/watch?v=eT02gzeLmF0), how the basic server works. diff --git a/process.py b/process.py new file mode 100644 index 0000000..253418e --- /dev/null +++ b/process.py @@ -0,0 +1,116 @@ +import subprocess +import sys +import unittest + + +class Process: + def __init__(self, args): + self.proc = subprocess.Popen( + args, stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + self.stop() + self.proc.stdin.close() + self.proc.stdout.close() + + def read(self): + return self.proc.stdout.readline().strip() + + def stop(self): + self.terminate() + self.wait() + + def terminate(self): + self.proc.terminate() + + def wait(self): + self.proc.wait() + + def write(self, msg): + # print(f"Process.write msg: {msg}", file=sys.stderr) + self.proc.stdin.write(msg + b"\n") + self.proc.stdin.flush() + + +class TestProcess(unittest.TestCase): + def setUp(self): + self.p = Process([sys.executable, __file__]) + + def tearDown(self): + self.p.close() + + def test_stdout(self): + line = self.p.read() + self.assertEqual(line, b"started") + self.p.stop() + line = self.p.read() + self.assertFalse(line) + + def test_stdin(self): + line = self.p.read() + self.assertEqual(line, b"started") + self.p.write(b"test-request") + line = self.p.read() + self.assertEqual(line, b"test-request") + self.p.stop() + line = self.p.read() + self.assertFalse(line) + + def test_cycle(self): + line = self.p.read() + self.assertEqual(line, b"started") + self.p.write(b"test-request-1") + line = self.p.read() + self.assertEqual(line, b"test-request-1") + self.p.write(b"test-request-2") + line = self.p.read() + self.assertEqual(line, b"test-request-2") + self.p.write(b"/exit") + self.p.wait() + line = self.p.read() + self.assertFalse(line) + + + def test_long(self): + line = self.p.read() + self.assertEqual(line, b"started") + request = b"0123456789" * 10000 + self.p.write(request) + line = self.p.read() + self.assertEqual(line, request) + self.p.stop() + line = self.p.read() + self.assertFalse(line) + + +class TestProcessContext(unittest.TestCase): + def test_stdout(self): + with Process([sys.executable, __file__]) as p: + line = p.read() + self.assertEqual(line, b"started") + p.stop() + line = p.read() + self.assertFalse(line) + + +def test_main(): + sys.stdout.write("started\n") + sys.stdout.flush() + while True: + # print("read", file=sys.stderr) + request = sys.stdin.readline().strip() + # print("request", request, file=sys.stderr) + if request == '/exit': + break + sys.stdout.write(request + "\n") + sys.stdout.flush() + + +if __name__ == '__main__': + test_main(*sys.argv[1:]) diff --git a/smallchat.py b/smallchat.py new file mode 100644 index 0000000..225f7f6 --- /dev/null +++ b/smallchat.py @@ -0,0 +1,187 @@ +import select +import socket +import sys + +WELCOME = b"Welcome to Simple Chat! Use /nick to set your nick." +PREFIX = b"/nick " + + +class Client: + def __init__(self, conn, protocol_cls, notify_receive, notify_close): + self.conn = conn + self.notify_receive = notify_receive + self.notify_close = notify_close + self.fd = conn.fileno() + self.protocol = protocol_cls(self.notify_receive) + self.out_buffer = bytearray() + + def raw_receive(self): + data = self.conn.recv(1024) + if not data: + self.notify_close(self) + self.conn.close() + else: + self.protocol.decode(data) + + def raw_send(self): + if self.out_buffer: + sent = self.conn.send(self.out_buffer) + self.out_buffer = self.out_buffer[sent:] + + def send(self, msg): + # print(f"send msg: {msg}", file=sys.stderr) + self.out_buffer += self.protocol.encode(msg) + + +class Clients: + def __init__(self, inputs, outputs): + self.inputs = inputs + self.outputs = outputs + self.clients = {} + + def add(self, client): + self.clients[client.fd] = client + self.inputs.append(client.conn) + self.outputs.append(client.conn) + + def delete(self, client): + self.clients.pop(client.fd) + self.inputs.remove(client.conn) + self.outputs.remove(client.conn) + + def get(self, conn): + return self.clients[conn.fileno()] + + +class ChatClient(Client): + def __init__(self, conn, protocol_cls, publish, notify_close): + super().__init__(conn, protocol_cls, self._received, notify_close) + self.publish = publish + self.nick = f"user:{conn.fileno()}" + + def _received(self, msg): + # print(f"received msg: {msg}", file=sys.stderr) + if msg.startswith(PREFIX): + self.nick = msg[len(PREFIX):].decode() + else: + self.publish(self, msg) + + +class ChatClients(Clients): + def add(self, client): + super().add(client) + # print(f"Connected client fd={client.fd}, nick={client.nick}") + client.send(WELCOME) + + def delete(self, client): + super().delete(client) + # print(f"Disconnected client fd={client.fd}, nick={client.nick}") + + def publish(self, sender, msg): + response = sender.nick.encode() + b"> " + msg + for client in self.clients.values(): + if client != sender: + client.send(response) + + +class Protocol: + END = b"\n" + + def __init__(self, notify, end=None): + self.notify = notify + self.end = end or self.END + self.buff = bytearray() + + @classmethod + def encode(cls, msg): + assert not cls.END in msg + return msg + cls.END + + def decode(self, data): + for car in data: + if car == ord(self.END): + self.notify(self.buff) + self.buff.clear() + else: + self.buff.append(car) + + +class Stream: + def __init__(self, stdin, stdout): + self.stdin = stdin + self.stdout = stdout + self.closed = False + self.send = None + + def close(self): + self.closed = True + + def raw_receive(self): + msg = self.stdin.readline().rstrip() + self.send(msg.encode()) + + def receive(self, msg): + self.stdout.write(msg.decode() + "\n") + self.stdout.flush() + + def raw_send(self): + pass # noop + + +def _main_client(address): + stream = Stream(sys.stdin, sys.stdout) + + with socket.socket() as conn: + conn.connect(address) + client = Client(conn, Protocol, notify_receive=stream.receive, notify_close=stream.close) + stream.protocol = Protocol(client.send, "\n") + stream.send = client.send + inputs = [conn, sys.stdin] + outputs = [conn, sys.stdout] + clients = {conn: client, sys.stdin: stream, sys.stdout: stream} + while not stream.closed: + inputready, outputready, exceptready = select.select(inputs, outputs, []) + for s in inputready: + clients.get(s).raw_receive() + for s in outputready: + if s.fileno() <= 0: + # sockets already closed during reception/recv + continue + clients.get(s).raw_send() + + +def _main_server(address): + with socket.socket() as sl: + sl.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sl.bind(address) + sl.listen() + sys.stdout.write(f"Server started address={address}\n") + sys.stdout.flush() + inputs = [sl] + outputs = [] + clients = ChatClients(inputs, outputs) + while True: + inputready, outputready, exceptready = select.select(inputs, outputs, []) + for s in inputready: + if s == sl: + conn, addr = sl.accept() + client = ChatClient(conn, Protocol, clients.publish, clients.delete) + clients.add(client) + else: + clients.get(s).raw_receive() + for s in outputready: + if s.fileno() <= 0: + # sockets already closed during reception/recv + continue + clients.get(s).raw_send() + + +def main(role, host, port): + fun = globals()["_main_" + role] + address = (host, int(port)) + fun(address) + + +if __name__ == '__main__': + main(*sys.argv[1:]) + diff --git a/tests.py b/tests.py new file mode 100644 index 0000000..44a1428 --- /dev/null +++ b/tests.py @@ -0,0 +1,151 @@ +from sys import executable +from time import sleep +from unittest import TestCase, skip + +from process import Process +from smallchat import WELCOME + +HOST = "localhost" +PORT = "7711" + + +class TestServer: + def _wait_client_receive(self): + pass + + def _wait_start_server(self): + line = self.server.read() + assert line.startswith(b"Server started") + + def setUp(self): + self.server = Process(self.SERVER) + self._wait_start_server() + + def tearDown(self): + self.server.close() + + +class TestIntegration(TestServer): + BIG_MESSAGE_BODY = b"Hi, it's " + b"me" * 10000 + b"." + CLIENT = ["nc", HOST, PORT] + CONSECUTIVE_MESSAGES_COUNT = 100 + CONTEMPORARY_CLIENTS_COUNT = 50 + + def test_minimal(self): + c_first = Process(self.CLIENT) + c_second = Process(self.CLIENT) + l = c_first.read() + self.assertEqual(l, WELCOME) + l = c_second.read() + self.assertEqual(l, WELCOME) + c_first.write(b"/nick test-me") + self._wait_client_receive() + c_first.write(b"Hi!") + l_second = c_second.read() + self.assertEqual(l_second, b"test-me> Hi!") + c_first.close() + c_second.close() + + def test_disconnected(self): + c_first = Process(self.CLIENT) + c_second = Process(self.CLIENT) + c_third = Process(self.CLIENT) + self.assertEqual(c_first.read(), WELCOME) + self.assertEqual(c_second.read(), WELCOME) + self.assertEqual(c_third.read(), WELCOME) + c_third.close() + c_first.write(b"/nick test-me") + c_first.write(b"/nick test-me") + self._wait_client_receive() + c_first.write(b"Hi!") + l_second = c_second.read() + self.assertEqual(l_second, b"test-me> Hi!") + c_first.close() + c_second.close() + + def test_very_long_message(self): + c_first = Process(self.CLIENT) + c_second = Process(self.CLIENT) + l = c_first.read() + self.assertEqual(l, WELCOME) + l = c_second.read() + self.assertEqual(l, WELCOME) + c_first.write(b"/nick test-me") + self._wait_client_receive() + msg = self.BIG_MESSAGE_BODY + c_first.write(msg) + l_second = c_second.read() + self.assertEqual(l_second, b"test-me> " + msg) + c_first.close() + c_second.close() + + def test_many_consecutive_messages(self): + c_first = Process(self.CLIENT) + c_second = Process(self.CLIENT) + self.assertEqual(c_first.read(), WELCOME) + self.assertEqual(c_second.read(), WELCOME) + c_first.write(b"/nick test-me") + self._wait_client_receive() + msg = self.BIG_MESSAGE_BODY + for idx in range(self.CONSECUTIVE_MESSAGES_COUNT): + c_first.write(msg) + for idx in range(self.CONSECUTIVE_MESSAGES_COUNT): + l_second = c_second.read() + self.assertEqual(l_second, b"test-me> " + msg) + c_first.close() + c_second.close() + + def test_many_clients(self): + clients = [Process(self.CLIENT) for idx in range(self.CONTEMPORARY_CLIENTS_COUNT)] + for client in clients: + self.assertEqual(client.read(), WELCOME) + c_first, c_second, *c_others = clients + c_first.write(b"/nick test-me-1") + c_first.write(b"Hi, I'm the first!") + c_second.write(b"/nick test-me-2") + c_second.write(b"Hi, it's me, I'm the second!") + self.assertEqual(c_first.read(), b"test-me-2> Hi, it's me, I'm the second!") + self.assertEqual(c_second.read(), b"test-me-1> Hi, I'm the first!") + for c_other in c_others: + msgs = set([c_other.read(), c_other.read()]) + self.assertIn(b"test-me-1> Hi, I'm the first!", msgs) + self.assertIn(b"test-me-2> Hi, it's me, I'm the second!", msgs) + for client in clients: + client.close() + + + +class TestIntegrationPy(TestIntegration, TestCase): + SERVER = [executable, "smallchat.py", "server", HOST, PORT] + + +class TestIntegrationClientPy(TestIntegration, TestCase): + CLIENT = [executable, "smallchat.py", "client", HOST, PORT] + CONTEMPORARY_CLIENTS_COUNT = 2 + SERVER = [executable, "smallchat.py", "server", HOST, PORT] + + @skip("TODO") + def test_many_clients(self): + super().test_many_clients() + + +class TestIntegrationC(TestIntegration, TestCase): + SERVER = ["./smallchat-server"] + + def _wait_client_receive(self): + sleep(.1) + + def _wait_start_server(self): + sleep(.01) + + @skip("TODO") + def test_very_long_message(self): + super().test_very_long_message() + + @skip("TODO") + def test_many_consecutive_messages(self): + super().test_many_consecutive_messages() + + @skip("TODO") + def test_many_clients(self): + super().test_many_clients()