From 08d1179a8cee525b1e1cb09dcae2ddd787b039d4 Mon Sep 17 00:00:00 2001 From: Mariano Riefolo Date: Wed, 20 Mar 2024 23:24:56 +0100 Subject: [PATCH] Added Server classes and modified Database class --- src/main/java/controllers/Database.java | 26 +++- .../java/controllers/MessageForwarder.java | 20 +++ src/main/java/controllers/Server.java | 18 +++ src/main/java/controllers/ServerThread.java | 136 ++++++++++++++++++ 4 files changed, 198 insertions(+), 2 deletions(-) create mode 100644 src/main/java/controllers/MessageForwarder.java create mode 100644 src/main/java/controllers/Server.java create mode 100644 src/main/java/controllers/ServerThread.java diff --git a/src/main/java/controllers/Database.java b/src/main/java/controllers/Database.java index 55e0337..47ebb64 100644 --- a/src/main/java/controllers/Database.java +++ b/src/main/java/controllers/Database.java @@ -9,7 +9,7 @@ import java.math.BigInteger; import java.sql.*; public class Database { - public static boolean register(String username, String password, BigInteger e, BigInteger d, BigInteger n) { + public static boolean registerAccount(String username, String password, BigInteger e, BigInteger d, BigInteger n) { try (Connection connection = getConnection()) { if (connection == null) return false; @@ -38,7 +38,7 @@ public class Database { } } - public static Account login(String username, CharSequence password) { + public static Account getAccount(String username, CharSequence password) { try (Connection connection = getConnection()) { if (connection == null) return null; @@ -70,6 +70,28 @@ public class Database { } } + public static int getIdFromUsername(String username) { + try (Connection connection = getConnection()) { + if (connection == null) return -1; + + try (PreparedStatement statement = connection.prepareStatement(""" + SELECT id + FROM accounts + WHERE username = ? + """)) { + statement.setString(1, username); + ResultSet resultSet = statement.executeQuery(); + if (resultSet.next()) { + return resultSet.getInt(1); + } + return -1; + } + } catch (SQLException e) { + System.err.println("Error while trying to retrieve the id from the username"); + return -1; + } + } + private static Connection getConnection() throws SQLException { return DriverManager.getConnection( System.getenv("db_url"), diff --git a/src/main/java/controllers/MessageForwarder.java b/src/main/java/controllers/MessageForwarder.java new file mode 100644 index 0000000..2e6f781 --- /dev/null +++ b/src/main/java/controllers/MessageForwarder.java @@ -0,0 +1,20 @@ +package controllers; + +import java.util.HashMap; +import java.util.Map; + +public class MessageForwarder { + private static final Map idThread; + + static { + idThread = new HashMap<>(); + } + + public static void addUser(int id, ServerThread thread) { + idThread.put(id, thread); + } + + public static void sendTo(int id, String message, int senderId) { + idThread.get(id).sendMessage(message, senderId); + } +} diff --git a/src/main/java/controllers/Server.java b/src/main/java/controllers/Server.java new file mode 100644 index 0000000..31ed172 --- /dev/null +++ b/src/main/java/controllers/Server.java @@ -0,0 +1,18 @@ +package controllers; + +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; + +public class Server { + public static void main(String[] args) { + try (ServerSocket serverSocket = new ServerSocket(21324)) { + Socket socket = serverSocket.accept(); + ServerThread serverThread = new ServerThread(socket); + serverThread.start(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } +} diff --git a/src/main/java/controllers/ServerThread.java b/src/main/java/controllers/ServerThread.java new file mode 100644 index 0000000..fb3d85b --- /dev/null +++ b/src/main/java/controllers/ServerThread.java @@ -0,0 +1,136 @@ +package controllers; + +import models.Account; +import models.Rsa; + +import java.io.BufferedReader; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.math.BigInteger; +import java.net.Socket; + +import static controllers.Database.getAccount; +import static controllers.Database.registerAccount; + +public class ServerThread extends Thread { + private final Socket client; + private BufferedReader fromClient; + private DataOutputStream toClient; + private final Rsa rsa; + private BigInteger clientE, clientN; + private int clientId; + + public ServerThread (Socket socket) { + this.client = socket; + rsa = new Rsa(1024); + } + + public void run() { + try { + fromClient = new BufferedReader(new InputStreamReader(client.getInputStream())); + toClient = new DataOutputStream(client.getOutputStream()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + try { + toClient.writeBytes(rsa.getE().toString()); + toClient.writeBytes(rsa.getN().toString()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + try { + clientE = new BigInteger(fromClient.readLine()); + clientN = new BigInteger(fromClient.readLine()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + for (;;) { + String operation; + try { + operation = rsa.decrypt(new BigInteger(fromClient.readLine())); + } catch (IOException e) { + throw new RuntimeException(e); + } + + String username, password; + + try { + username = rsa.decrypt(new BigInteger(fromClient.readLine())); + password = rsa.decrypt(new BigInteger(fromClient.readLine())); + } catch (IOException e) { + throw new RuntimeException(e); + } + + + try { + if ("LOGIN".equals(operation)) { + if (login(username, password)) { + toClient.writeBytes(rsa.encrypt("SUCCESS", clientE, clientN).toString()); + break; + } + } else if ("REGISTER".equals(operation)) { + if (register(username, password)) { + toClient.writeBytes(rsa.encrypt("SUCCESS", clientE, clientN).toString()); + break; + } + } else { + toClient.writeBytes(rsa.encrypt("FAIL", clientE, clientN).toString()); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + MessageForwarder.addUser(clientId, this); + + int recipientId; + String message; + + for (;;) { + try { + recipientId = Integer.parseInt(rsa.decrypt(new BigInteger(fromClient.readLine()))); + message = rsa.decrypt(new BigInteger(fromClient.readLine())); + } catch (IOException e) { + throw new RuntimeException(e); + } + + if (message == null || "DISCONNECT".equals(message)) break; + + MessageForwarder.sendTo(recipientId, message, clientId); + } + } + + private boolean login(String username, String password) { + Account account = getAccount(username, password); + return account != null; + } + + public boolean register(String username, String password) { + BigInteger clientD; + + try { + clientD = new BigInteger(fromClient.readLine()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + if (!registerAccount(username, password, clientE, clientD, clientN)) return false; + + clientId = Database.getIdFromUsername(username); + return clientId != -1; + } + + public void sendMessage(String message, int sender) { + try { + toClient.writeBytes("INCOMING"); + toClient.writeInt(sender); + toClient.writeBytes(message); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +}