diff --git a/src/main.rs b/src/main.rs index 344e7b4..b15675f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,25 @@ mod tsar; +mod security; use std::{env, net::{IpAddr, SocketAddr}}; use axum::{ - routing::get, + routing::{get, post}, Router, - extract::{Path, State}, http::StatusCode, + extract::{Path, State}, http::StatusCode, Json, }; -use rand::{ thread_rng, seq::IteratorRandom }; +use rand::{thread_rng, seq::IteratorRandom}; +use security::salt_and_hash; +use serde::Deserialize; use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; +use base64::prelude::*; + + +#[derive(Deserialize)] +struct SetPronounsRequest { + password: String, + new_pronouns: String, +} #[tokio::main] @@ -31,6 +42,7 @@ async fn main() { let app = Router::new() .route("/api/thethirdcan/hello", get(|| async { "hello world!" })) .route("/api/thethirdcan/pronouns/:user", get(user_pronouns)) + .route("/api/thethirdcan/pronouns/:user", post(set_pronouns)) .with_state(pool); axum::Server::bind(&SocketAddr::new(IpAddr::from([0;4]), port)) @@ -46,7 +58,7 @@ async fn user_pronouns( ) -> Result { sqlx::query!("SELECT pronouns FROM users - WHERE username == ?", + WHERE username = ?", user) .fetch_one(pool) .await @@ -61,3 +73,29 @@ async fn user_pronouns( .map(ToOwned::to_owned) .ok_or(StatusCode::INTERNAL_SERVER_ERROR) } + +async fn set_pronouns( + State(pool): State<&SqlitePool>, + Path(user): Path, + Json(payload): Json, + ) -> Result<(), StatusCode> { + + let password_data = BASE64_URL_SAFE.decode(payload.password).map_err(|_| StatusCode::BAD_REQUEST)?; + let salted_hashed_password = salt_and_hash(&password_data); + let auth_slice = salted_hashed_password.as_ref(); + + let rows_affected = sqlx::query!("UPDATE users + SET pronouns = ? + WHERE username = ? AND auth = ? + ", payload.new_pronouns, user, auth_slice) + .execute(pool) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .rows_affected(); + + if rows_affected == 0 { + return Err(StatusCode::NOT_FOUND); + } + + Ok(()) +} diff --git a/src/security.rs b/src/security.rs new file mode 100644 index 0000000..40a8a8d --- /dev/null +++ b/src/security.rs @@ -0,0 +1,8 @@ +use sha2::{Sha256, Digest}; + +pub fn salt_and_hash(input: &[u8]) -> Box<[u8]> { + let mut hasher = Sha256::new(); + hasher.update(input); + hasher.update("get salted on lmao"); + hasher.finalize().as_slice().into() +} diff --git a/src/tsar.rs b/src/tsar.rs index 9d0ccf8..d66ab7e 100644 --- a/src/tsar.rs +++ b/src/tsar.rs @@ -5,9 +5,10 @@ use base64::prelude::*; use reqwest::Client; use serde::{Serialize, Deserialize}; -use sha2::{Sha256, Digest}; use sqlx::SqlitePool; +use crate::security::salt_and_hash; + #[derive(Serialize, Deserialize, Debug)] struct LoginInfo { #[serde(rename = "user")] @@ -160,7 +161,7 @@ async fn handle_event(pool: &SqlitePool, login: &LoginInfo, thread: &MessageThre let sid = thread.sid; let mid = event.mid; let Some(uid) = thread.members.iter().find(|member| member.mid == mid).and_then(|member| member.uid) else { return Err(()) }; - let Some(username) = res.users.iter().find(|user| user.id == uid).map(|user| &user.name) else { return Err(()) }; + let Some(username) = res.users.iter().find(|user| user.id == uid).map(|user| &user.key) else { return Err(()) }; debug!("handling '{text}' from {username} in thread {sid}"); @@ -172,16 +173,10 @@ Note: this will remove your previous authentication code, if you had one."#.into } else if text.starts_with("get_authentication") { let key_data: [u8; 16] = thread_rng().gen(); let b64_key = BASE64_URL_SAFE.encode(key_data); - let hashed_key = { - let mut hasher = Sha256::new(); - hasher.update(key_data); - hasher.update("get salted on lmao"); - hasher.finalize() - }; - let hashed_key_slice = hashed_key.as_slice(); - + let hashed_key = salt_and_hash(&key_data); + let hashed_key_slice = hashed_key.as_ref(); - let user_in_db = sqlx::query!("SELECT username FROM users WHERE username == ?;", username) + let user_in_db = sqlx::query!("SELECT username FROM users WHERE username = ?;", username) .fetch_optional(pool) .await .map_err(|e| error!("failed to search database for username: {e}"))? @@ -190,7 +185,7 @@ Note: this will remove your previous authentication code, if you had one."#.into if user_in_db { sqlx::query!("UPDATE users SET auth = ? - WHERE username == ?; + WHERE username = ?; ", hashed_key_slice, username) .execute(pool) .await