diff --git a/src/main/java/controllers/Database.java b/src/main/java/controllers/Database.java index 5d79f36..2ecc0a6 100644 --- a/src/main/java/controllers/Database.java +++ b/src/main/java/controllers/Database.java @@ -3,16 +3,15 @@ package controllers; import models.Account; import models.Conversation; import models.Message; -import org.mariadb.jdbc.MariaDbBlob; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; -import views.Client; import java.math.BigInteger; import java.sql.*; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.List; + import models.PublicKey; public class Database { @@ -24,14 +23,14 @@ public class Database { String hashedPass = encoder.encode(password); try (PreparedStatement statement = connection.prepareStatement(""" - INSERT INTO accounts(username, password, e, d, n) - VALUES (?, ?, ?, ?, ?) - """)) { + INSERT INTO accounts(username, password, e, d, n) + VALUES (?, ?, ?, ?, ?) + """)) { statement.setString(1, username); statement.setString(2, hashedPass); - statement.setBlob(3, new MariaDbBlob(e.toString().getBytes())); - statement.setBytes(4, d.getBytes()); - statement.setBlob(5, new MariaDbBlob(n.toString().getBytes())); + statement.setString(3, e.toString()); + statement.setString(4, d); + statement.setString(5, n.toString()); int rowsInserted = statement.executeUpdate(); return rowsInserted == 1; @@ -50,17 +49,17 @@ public class Database { if (connection == null) return null; try (PreparedStatement statement = connection.prepareStatement(""" - SELECT id, password, e, d, n - FROM accounts - WHERE username = ? - """)) { + SELECT id, password, e, d, n + FROM accounts + WHERE username = ? + """)) { statement.setString(1, username); ResultSet resultSet = statement.executeQuery(); if (resultSet.next()) { int id = resultSet.getInt("id"); String hash_password = resultSet.getString("password"); BigInteger e = new BigInteger(resultSet.getString("e")); - byte[] d = convertToByteArray(resultSet.getString("d")); + String d = resultSet.getString("d"); BigInteger n = new BigInteger(resultSet.getString("n")); PasswordEncoder encoder = new BCryptPasswordEncoder(); if (encoder.matches(password, hash_password)) @@ -77,19 +76,15 @@ public class Database { } } - public static byte[] convertToByteArray(String input) { - return Client.convertToByteArray(input); - } - public static int getIdFromUsername(String username) { try (Connection connection = getConnection()) { if (connection == null) return -1; try (PreparedStatement statement = connection.prepareStatement(""" - SELECT id - FROM accounts - WHERE username = ? - """)) { + SELECT id + FROM accounts + WHERE username = ? + """)) { statement.setString(1, username); ResultSet resultSet = statement.executeQuery(); if (resultSet.next()) { @@ -103,40 +98,37 @@ public class Database { } } - public static boolean addMessage(int conversationId, String message) { + public static void addMessage(int conversationId, String message) { try (Connection connection = getConnection()) { - if (connection == null) return false; + if (connection == null) return; try (PreparedStatement statement = connection.prepareStatement(""" - INSERT INTO messages(message, conversation_id, sending_date, sending_time) - VALUES (?, ?, ?, ?) - """)) { + INSERT INTO messages(message, conversation_id, sending_date, sending_time) + VALUES (?, ?, ?, ?) + """)) { statement.setString(1, message); statement.setInt(2, conversationId); statement.setDate(3, Date.valueOf(LocalDateTime.now().toLocalDate())); statement.setTime(4, Time.valueOf(LocalDateTime.now().toLocalTime())); - int rowsInserted = statement.executeUpdate(); - return rowsInserted == 1; + statement.executeUpdate(); } catch (SQLException ex) { System.err.println("Error while trying to add a message: " + ex); - return false; } } catch (SQLException ex) { System.err.println("Error while trying to open a connection: " + ex); - return false; } } - + public static PublicKey getPublicKey(int accountId) { try (Connection connection = getConnection()) { if (connection == null) return null; try (PreparedStatement statement = connection.prepareStatement(""" - SELECT e, n - FROM accounts - WHERE id = ? - """)) { + SELECT e, n + FROM accounts + WHERE id = ? + """)) { statement.setInt(1, accountId); ResultSet resultSet = statement.executeQuery(); if (resultSet.next()) { @@ -160,10 +152,10 @@ public class Database { if (connection == null) return null; try (PreparedStatement statement = connection.prepareStatement(""" - SELECT id, sender, recipient - FROM conversations - WHERE sender = ? OR recipient = ? - """)) { + SELECT id, sender, recipient + FROM conversations + WHERE sender = ? OR recipient = ? + """)) { statement.setInt(1, sender_id); statement.setInt(2, sender_id); ResultSet resultSet = statement.executeQuery(); @@ -185,42 +177,35 @@ public class Database { } } - public static boolean addConversation(int sender, int receiver) { + public static void addConversation(int sender, int receiver) { try (Connection connection = getConnection()) { - if (connection == null) return false; + if (connection == null) return; try (PreparedStatement statement = connection.prepareStatement(""" - INSERT INTO conversations(sender, recipient) - VALUES (?, ?) - """)) { - System.err.println("a"); + INSERT INTO conversations(sender, recipient) + VALUES (?, ?) + """)) { statement.setInt(1, sender); - System.err.println("a"); statement.setInt(2, receiver); - System.err.println("a"); - int rowsInserted = statement.executeUpdate(); - System.err.println("a"); - return rowsInserted == 1; + statement.executeUpdate(); } catch (SQLException ex) { System.err.println("Error while trying to add a conversation: " + ex); - return false; } } catch (SQLException ex) { System.err.println("Error while trying to open a connection: " + ex); - return false; } } - + public static int getConversationId(int sender, int receiver) { try (Connection connection = getConnection()) { if (connection == null) return -1; try (PreparedStatement statement = connection.prepareStatement(""" - SELECT id - FROM conversations - WHERE sender = ? AND recipient = ? - """)) { + SELECT id + FROM conversations + WHERE sender = ? AND recipient = ? + """)) { statement.setInt(1, sender); statement.setInt(2, receiver); ResultSet resultSet = statement.executeQuery(); @@ -243,10 +228,10 @@ public class Database { if (connection == null) return null; try (PreparedStatement statement = connection.prepareStatement(""" - SELECT message, sending_date, sending_time - FROM messages - WHERE conversation_id = ? - """)) { + SELECT message, sending_date, sending_time + FROM messages + WHERE conversation_id = ? + """)) { statement.setInt(1, conversation_id); ResultSet resultSet = statement.executeQuery(); List messages = new ArrayList<>(); @@ -269,8 +254,8 @@ public class Database { private static Connection getConnection() throws SQLException { return DriverManager.getConnection( - "jdbc:mariadb://riefolo.me:3306/chat_rsa", - "proj_sistemi", "$o8a5#diTGg8*Agk" + System.getenv("db_url"), + System.getenv("db_user"), System.getenv("db_pass") ); } } diff --git a/src/main/java/controllers/ServerThread.java b/src/main/java/controllers/ServerThread.java index c6297f6..ffe15db 100644 --- a/src/main/java/controllers/ServerThread.java +++ b/src/main/java/controllers/ServerThread.java @@ -2,14 +2,13 @@ package controllers; import models.Account; import models.Rsa; -import views.Client; import java.io.*; import java.math.BigInteger; import java.net.Socket; -import java.util.Arrays; import static controllers.Database.*; + import models.PublicKey; public class ServerThread extends Thread { @@ -20,7 +19,7 @@ public class ServerThread extends Thread { private BigInteger clientE, clientN; private int clientId; - public ServerThread (Socket socket) { + public ServerThread(Socket socket) { this.client = socket; rsa = new Rsa(1024); } @@ -51,7 +50,7 @@ public class ServerThread extends Thread { throw new RuntimeException(e); } - for (;;) { + for (; ; ) { String operation; sendEncrypted("Quale operazione vuoi effettuare? (REGISTER|LOGIN): "); try { @@ -76,18 +75,21 @@ public class ServerThread extends Thread { if (login(username, password)) { Account account = getAccount(username, password); if (account == null) { + System.err.println("fail"); sendEncrypted("FAIL"); break; } sendEncrypted("SUCCESS"); + clientId = Database.getIdFromUsername(username); sendEncrypted(String.valueOf(account.n())); sendEncrypted(String.valueOf(account.e())); - sendEncrypted(Arrays.toString(account.d())); + sendEncrypted(account.d()); break; } } else if ("REGISTER".equals(operation)) { if (register(username, password)) { sendEncrypted("SUCCESS"); + clientId = Database.getIdFromUsername(username); break; } } else { @@ -97,12 +99,12 @@ public class ServerThread extends Thread { MessageForwarder.addUser(clientId, this); - String message,dUsername; + String message, dUsername; sendEncrypted("Inserisci l'username del destinatario: "); try { - dUsername = rsa.decrypt(fromClient.readLine()); + dUsername = rsa.decrypt(fromClient.readLine()); } catch (IOException e) { throw new RuntimeException(e); } @@ -110,10 +112,11 @@ public class ServerThread extends Thread { Database.addConversation(clientId, recipientId); PublicKey pk = Database.getPublicKey(recipientId); + assert pk != null; sendEncrypted(pk.e().toString()); sendEncrypted(pk.n().toString()); - for (;;) { + for (; ; ) { try { sendEncrypted("Inserisci messaggio: "); message = fromClient.readLine(); @@ -123,10 +126,10 @@ public class ServerThread extends Thread { if (message == null || "DISCONNECT".equals(message)) break; - if ("CAMBIA_DESTINATARIO".equals(message)){ + if ("CAMBIA_DESTINATARIO".equals(message)) { try { sendEncrypted("Inserisci l'username del nuovo destinatario: "); - dUsername = rsa.decrypt(fromClient.readLine()); + dUsername = rsa.decrypt(fromClient.readLine()); } catch (IOException e) { throw new RuntimeException(e); } @@ -134,7 +137,7 @@ public class ServerThread extends Thread { Database.addConversation(clientId, recipientId); continue; } - + int convId = Database.getConversationId(clientId, recipientId); try { @@ -166,25 +169,20 @@ public class ServerThread extends Thread { } public boolean register(String username, String password) { - byte[] clientD; + String clientD; try { - String line = rsa.decrypt(fromClient.readLine()); - clientD = convertToByteArray(line); + clientD = rsa.decrypt(fromClient.readLine()); } catch (IOException e) { throw new RuntimeException(e); } - if (!registerAccount(username, password, clientE, Arrays.toString(clientD), clientN)) return false; + if (!registerAccount(username, password, clientE, clientD, clientN)) return false; clientId = Database.getIdFromUsername(username); return clientId != -1; } - public static byte[] convertToByteArray(String input) { - return Client.convertToByteArray(input); - } - public void sendMessage(String message, int sender) throws IOException { send("INCOMING"); diff --git a/src/main/java/models/Account.java b/src/main/java/models/Account.java index d748126..43ca4bd 100644 --- a/src/main/java/models/Account.java +++ b/src/main/java/models/Account.java @@ -2,5 +2,5 @@ package models; import java.math.BigInteger; -public record Account(int id, String username, BigInteger e, byte[] d, BigInteger n) { +public record Account(int id, String username, BigInteger e, String d, BigInteger n) { } diff --git a/src/main/java/models/Aes.java b/src/main/java/models/Aes.java new file mode 100644 index 0000000..5da0801 --- /dev/null +++ b/src/main/java/models/Aes.java @@ -0,0 +1,77 @@ +package models; + +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.SecretKeyFactory; +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.PBEKeySpec; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.security.spec.KeySpec; +import java.util.Base64; + +public class Aes { + private static final int KEY_LENGTH = 256; + private static final int ITERATION_COUNT = 65536; + + public static String encrypt(String strToEncrypt, String secretKey, String salt) { + + try { + + SecureRandom secureRandom = new SecureRandom(); + byte[] iv = new byte[16]; + secureRandom.nextBytes(iv); + IvParameterSpec ivspec = new IvParameterSpec(iv); + + SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256"); + KeySpec spec = new PBEKeySpec(secretKey.toCharArray(), salt.getBytes(), ITERATION_COUNT, KEY_LENGTH); + SecretKey tmp = factory.generateSecret(spec); + SecretKeySpec secretKeySpec = new SecretKeySpec(tmp.getEncoded(), "AES"); + + Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding"); + cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivspec); + + byte[] cipherText = cipher.doFinal(strToEncrypt.getBytes(StandardCharsets.UTF_8)); + byte[] encryptedData = new byte[iv.length + cipherText.length]; + System.arraycopy(iv, 0, encryptedData, 0, iv.length); + System.arraycopy(cipherText, 0, encryptedData, iv.length, cipherText.length); + + return Base64.getEncoder().encodeToString(encryptedData); + } catch (Exception e) { + // Handle the exception properly + e.printStackTrace(); + return null; + } + } + + public static String decrypt(String strToDecrypt, String secretKey, String salt) { + + try { + + byte[] encryptedData = Base64.getDecoder().decode(strToDecrypt); + byte[] iv = new byte[16]; + System.arraycopy(encryptedData, 0, iv, 0, iv.length); + IvParameterSpec ivspec = new IvParameterSpec(iv); + + SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256"); + KeySpec spec = new PBEKeySpec(secretKey.toCharArray(), salt.getBytes(), ITERATION_COUNT, KEY_LENGTH); + SecretKey tmp = factory.generateSecret(spec); + SecretKeySpec secretKeySpec = new SecretKeySpec(tmp.getEncoded(), "AES"); + + Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding"); + cipher.init(Cipher.DECRYPT_MODE, secretKeySpec, ivspec); + + byte[] cipherText = new byte[encryptedData.length - 16]; + System.arraycopy(encryptedData, 16, cipherText, 0, cipherText.length); + + byte[] decryptedText = cipher.doFinal(cipherText); + return new String(decryptedText, StandardCharsets.UTF_8); + } catch (Exception e) { + // Handle the exception properly + e.printStackTrace(); + return null; + } + } + +} diff --git a/src/main/java/models/PublicKey.java b/src/main/java/models/PublicKey.java index af7c1c8..2ad4711 100644 --- a/src/main/java/models/PublicKey.java +++ b/src/main/java/models/PublicKey.java @@ -1,14 +1,6 @@ -/* - * Click nbfs://nbhost/SystemFileSystem/Templates/Licenses/license-default.txt to change this license - * Click nbfs://nbhost/SystemFileSystem/Templates/Classes/Class.java to edit this template - */ package models; import java.math.BigInteger; -/** - * - * @author mariano.riefolo - */ public record PublicKey(BigInteger e, BigInteger n) { } diff --git a/src/main/java/views/Client.java b/src/main/java/views/Client.java index ad385c5..dc74d5a 100644 --- a/src/main/java/views/Client.java +++ b/src/main/java/views/Client.java @@ -1,21 +1,20 @@ package views; import models.*; -import org.springframework.security.crypto.encrypt.AesBytesEncryptor; import java.io.*; import java.math.BigInteger; import java.net.Socket; import java.security.NoSuchAlgorithmException; -import java.util.Arrays; +import java.util.Objects; -public class Client{ +public class Client { private static Rsa rsa; private static BufferedReader in; private static BufferedWriter out; - private static BigInteger se,sn; + private static BigInteger se, sn; private static final String salt = "50726f6653616e736f6e6e65"; public static void connect() throws Exception { @@ -46,29 +45,14 @@ public class Client{ if ("FAIL".equals(read)) return false; BigInteger n = new BigInteger(rsa.decrypt(in.readLine())); BigInteger e = new BigInteger(rsa.decrypt(in.readLine())); - BigInteger d = new BigInteger(new AesBytesEncryptor(password, salt).decrypt(convertToByteArray(rsa.decrypt(in.readLine())))); + BigInteger d = new BigInteger(Objects.requireNonNull(Aes.decrypt(rsa.decrypt(in.readLine()), password, salt))); rsa = new Rsa(e, d, n); return true; } - public static byte[] convertToByteArray(String input) { - String cleanInput = input.replaceAll("\\[|]|\\s", ""); - - String[] numbersAsString = cleanInput.split(","); - - byte[] byteArray = new byte[numbersAsString.length]; - - for (int i = 0; i < numbersAsString.length; i++) { - byteArray[i] = Byte.parseByte(numbersAsString[i].trim()); - } - - return byteArray; - } - protected static boolean register(String password) throws NoSuchAlgorithmException, IOException { BigInteger d = rsa.getD(); - //Invio Username,password e Chiave privata AESata - sendEncrypted(Arrays.toString(new AesBytesEncryptor(password, salt).encrypt(d.toByteArray()))); + sendEncrypted(Aes.encrypt(d.toString(), password, salt)); return !"FAIL".equals(in.readLine()); } @@ -84,7 +68,7 @@ public class Client{ public static void main(String[] args) throws Exception { connect(); - ClientSendThread cst = new ClientSendThread(in,out,se,sn,rsa); + ClientSendThread cst = new ClientSendThread(in, out, se, sn, rsa); cst.start(); } } diff --git a/src/main/java/views/ClientReceiveThread.java b/src/main/java/views/ClientReceiveThread.java index 66d275a..f8b8018 100644 --- a/src/main/java/views/ClientReceiveThread.java +++ b/src/main/java/views/ClientReceiveThread.java @@ -5,15 +5,17 @@ import models.Rsa; import java.io.BufferedReader; import java.io.IOException; -public class ClientReceiveThread extends Thread{ +public class ClientReceiveThread extends Thread { private final BufferedReader in; private final Rsa rsa; - public ClientReceiveThread(BufferedReader in, Rsa rsa){ - this.in=in; - this.rsa=rsa; + + public ClientReceiveThread(BufferedReader in, Rsa rsa) { + this.in = in; + this.rsa = rsa; } + public void run() { - while (true){ + while (true) { try { String s = in.readLine(); System.out.print(rsa.decrypt(s)); diff --git a/src/main/java/views/ClientSendThread.java b/src/main/java/views/ClientSendThread.java index 5417b8d..261c26b 100644 --- a/src/main/java/views/ClientSendThread.java +++ b/src/main/java/views/ClientSendThread.java @@ -8,60 +8,60 @@ import java.security.NoSuchAlgorithmException; import java.util.logging.Level; import java.util.logging.Logger; -public class ClientSendThread extends Thread{ +public class ClientSendThread extends Thread { private final BufferedWriter out; - private final BigInteger se,sn; + private final BigInteger se, sn; private final Rsa rsa; private final BufferedReader in; - public ClientSendThread(BufferedReader in,BufferedWriter out,BigInteger se, BigInteger sn,Rsa rsa){ - this.in=in; + public ClientSendThread(BufferedReader in, BufferedWriter out, BigInteger se, BigInteger sn, Rsa rsa) { + this.in = in; this.out = out; this.se = se; - this.sn=sn; - this.rsa=rsa; + this.sn = sn; + this.rsa = rsa; } - public void run(){ + public void run() { BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); + while (true) { try { String s = in.readLine(); System.out.print(rsa.decrypt(s)); s = br.readLine(); send(rsa.encrypt(s, se, sn)); - if("REGISTER".equals(s)){ - sendCredentials(br); + if ("REGISTER".equals(s)) { + s = sendCredentialsGetPassword(br); try { if (Client.register(s)) break; } catch (NoSuchAlgorithmException e) { System.err.println("Errore durante la registrazione"); } - } - else if("LOGIN".equals(s)){ - sendCredentials(br); + } else if ("LOGIN".equals(s)) { + s = sendCredentialsGetPassword(br); try { if (Client.login(s)) break; } catch (NoSuchAlgorithmException e) { System.err.println("Errore durante l'accesso"); } + } else { + System.err.println("Inserisci REGISTER o LOGIN"); } } catch (IOException e) { throw new RuntimeException(e); } } - ClientReceiveThread crt = new ClientReceiveThread(in,rsa); - crt.start(); - String s; try { + System.out.println(rsa.decrypt(in.readLine())); s = br.readLine(); send(rsa.encrypt(s, se, sn)); } catch (IOException ex) { Logger.getLogger(ClientSendThread.class.getName()).log(Level.SEVERE, null, ex); } - + BigInteger recipientE, recipientN; try { recipientE = new BigInteger(rsa.decrypt(in.readLine())); @@ -70,7 +70,10 @@ public class ClientSendThread extends Thread{ Logger.getLogger(ClientSendThread.class.getName()).log(Level.SEVERE, null, ex); return; } - + + ClientReceiveThread crt = new ClientReceiveThread(in, rsa); + crt.start(); + while (true) { try { s = br.readLine(); @@ -81,7 +84,7 @@ public class ClientSendThread extends Thread{ } } - private void sendCredentials(BufferedReader br) throws IOException { + private String sendCredentialsGetPassword(BufferedReader br) throws IOException { String s; s = in.readLine(); System.out.print(rsa.decrypt(s)); @@ -91,6 +94,7 @@ public class ClientSendThread extends Thread{ System.out.print(rsa.decrypt(s)); s = br.readLine(); send(rsa.encrypt(s, se, sn)); + return s; } public void send(String message) throws IOException {