Added Server classes and modified Database class

This commit is contained in:
Mariano Riefolo 2024-03-20 23:24:56 +01:00
parent 36cd890d0f
commit 08d1179a8c
4 changed files with 198 additions and 2 deletions

View File

@ -9,7 +9,7 @@ import java.math.BigInteger;
import java.sql.*;
public class Database {
public static boolean register(String username, String password, BigInteger e, BigInteger d, BigInteger n) {
public static boolean registerAccount(String username, String password, BigInteger e, BigInteger d, BigInteger n) {
try (Connection connection = getConnection()) {
if (connection == null) return false;
@ -38,7 +38,7 @@ public class Database {
}
}
public static Account login(String username, CharSequence password) {
public static Account getAccount(String username, CharSequence password) {
try (Connection connection = getConnection()) {
if (connection == null) return null;
@ -70,6 +70,28 @@ public class Database {
}
}
public static int getIdFromUsername(String username) {
try (Connection connection = getConnection()) {
if (connection == null) return -1;
try (PreparedStatement statement = connection.prepareStatement("""
SELECT id
FROM accounts
WHERE username = ?
""")) {
statement.setString(1, username);
ResultSet resultSet = statement.executeQuery();
if (resultSet.next()) {
return resultSet.getInt(1);
}
return -1;
}
} catch (SQLException e) {
System.err.println("Error while trying to retrieve the id from the username");
return -1;
}
}
private static Connection getConnection() throws SQLException {
return DriverManager.getConnection(
System.getenv("db_url"),

View File

@ -0,0 +1,20 @@
package controllers;
import java.util.HashMap;
import java.util.Map;
public class MessageForwarder {
private static final Map<Integer, ServerThread> idThread;
static {
idThread = new HashMap<>();
}
public static void addUser(int id, ServerThread thread) {
idThread.put(id, thread);
}
public static void sendTo(int id, String message, int senderId) {
idThread.get(id).sendMessage(message, senderId);
}
}

View File

@ -0,0 +1,18 @@
package controllers;
import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
public class Server {
public static void main(String[] args) {
try (ServerSocket serverSocket = new ServerSocket(21324)) {
Socket socket = serverSocket.accept();
ServerThread serverThread = new ServerThread(socket);
serverThread.start();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,136 @@
package controllers;
import models.Account;
import models.Rsa;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.math.BigInteger;
import java.net.Socket;
import static controllers.Database.getAccount;
import static controllers.Database.registerAccount;
public class ServerThread extends Thread {
private final Socket client;
private BufferedReader fromClient;
private DataOutputStream toClient;
private final Rsa rsa;
private BigInteger clientE, clientN;
private int clientId;
public ServerThread (Socket socket) {
this.client = socket;
rsa = new Rsa(1024);
}
public void run() {
try {
fromClient = new BufferedReader(new InputStreamReader(client.getInputStream()));
toClient = new DataOutputStream(client.getOutputStream());
} catch (IOException e) {
throw new RuntimeException(e);
}
try {
toClient.writeBytes(rsa.getE().toString());
toClient.writeBytes(rsa.getN().toString());
} catch (IOException e) {
throw new RuntimeException(e);
}
try {
clientE = new BigInteger(fromClient.readLine());
clientN = new BigInteger(fromClient.readLine());
} catch (IOException e) {
throw new RuntimeException(e);
}
for (;;) {
String operation;
try {
operation = rsa.decrypt(new BigInteger(fromClient.readLine()));
} catch (IOException e) {
throw new RuntimeException(e);
}
String username, password;
try {
username = rsa.decrypt(new BigInteger(fromClient.readLine()));
password = rsa.decrypt(new BigInteger(fromClient.readLine()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try {
if ("LOGIN".equals(operation)) {
if (login(username, password)) {
toClient.writeBytes(rsa.encrypt("SUCCESS", clientE, clientN).toString());
break;
}
} else if ("REGISTER".equals(operation)) {
if (register(username, password)) {
toClient.writeBytes(rsa.encrypt("SUCCESS", clientE, clientN).toString());
break;
}
} else {
toClient.writeBytes(rsa.encrypt("FAIL", clientE, clientN).toString());
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}
MessageForwarder.addUser(clientId, this);
int recipientId;
String message;
for (;;) {
try {
recipientId = Integer.parseInt(rsa.decrypt(new BigInteger(fromClient.readLine())));
message = rsa.decrypt(new BigInteger(fromClient.readLine()));
} catch (IOException e) {
throw new RuntimeException(e);
}
if (message == null || "DISCONNECT".equals(message)) break;
MessageForwarder.sendTo(recipientId, message, clientId);
}
}
private boolean login(String username, String password) {
Account account = getAccount(username, password);
return account != null;
}
public boolean register(String username, String password) {
BigInteger clientD;
try {
clientD = new BigInteger(fromClient.readLine());
} catch (IOException e) {
throw new RuntimeException(e);
}
if (!registerAccount(username, password, clientE, clientD, clientN)) return false;
clientId = Database.getIdFromUsername(username);
return clientId != -1;
}
public void sendMessage(String message, int sender) {
try {
toClient.writeBytes("INCOMING");
toClient.writeInt(sender);
toClient.writeBytes(message);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}