option(null: hard);

module user_identity {

// Number of random salt bytes stored in every DB hash (see mk_db_hash).
integer SALT_SIZE option(constant) = 16; // 128 bits
// Minimum password length, as measured by strlen, enforced by mk_db_hash.
// NOTE(review): 4 is far below common guidance (e.g. NIST SP 800-63B asks for
// at least 8); raising it is a policy decision and is left unchanged here.
integer MIN_PASSWORD_LEN option(constant) = 4;
// Minimum user name length, as measured by strlen, enforced by insert_user.
integer MIN_USERNAME_LEN option(constant) = 2;
// Number of random bytes behind a session token; login_user base64-encodes them.
integer TOKEN_LENGTH option(constant) = 24;


// Id of the hashing method used for newly created hashes. The id is stored as
// the first byte of every DB hash, so a stored hash is always verified with
// the method it was created with. To introduce a new (better) method,
// implement it as a new case in the switch in hash_encode and increase this
// constant. It must fit in one byte (0..255); mk_db_hash checks this.
integer BEST_METHOD option(constant) = 0;


// User table targeted by all queries in this module. Its name and the
// columns used below (NormalizedUserName, ConcurrencyStamp, ...) appear to
// follow the ASP.NET Core Identity schema.
string TABLE option(constant) = "AspNetUsers";


/*
 * Hash a password together with a salt using the given method.
 *
 * method   - hashing method id, as stored in the first byte of a DB hash.
 * password - the clear-text password.
 * salt     - the salt bytes to mix in.
 *
 * Throws if method is not a known method id.
 *
 * Adding a method: add a case to the switch below and raise BEST_METHOD
 * (see the comment by its declaration).
 *
 * NOTE(review): method 0 is a single SHA-256 call, which is cheap to
 * brute-force; a future method should use a deliberately slow KDF
 * (PBKDF2, bcrypt, scrypt or Argon2).
 */
blob hash_encode(integer method, string password, blob salt)
{
	switch (method) {
	case 0:
		// Method 0: SHA-256 of the blob built from password and salt.
		return ssl.sha256(new blob(password, salt));
	default:
		throw (E_UNSPECIFIC, "Invalid hashing method");
	}
}

// Produce a fresh random salt of SALT_SIZE bytes for a new password hash.
blob mk_salt()
{
	blob fresh_salt = ssl.random_bytes(SALT_SIZE);
	return fresh_salt;
}


/*
 * Split a DB hash (as produced by mk_db_hash) into its parts.
 *
 * Layout: byte 0 is the hashing method id, the next SALT_SIZE bytes are the
 * salt, and all remaining bytes are the password hash.
 *
 * Throws if db_hash is too short to hold a method byte, a full salt and at
 * least one hash byte. Without this check, a corrupt stored value fails with
 * an obscure index error or yields a truncated salt/hash.
 */
void split_db_hash(blob db_hash, out integer method, out blob salt, out blob pwd_hash)
{
	vector(integer) bytes = db_hash.get_bytes();
	if (v_size(bytes) < SALT_SIZE + 2) {
		throw (E_UNSPECIFIC, "Malformed password hash (too short)");
	}

	method = bytes[0];
	// Slices appear to be inclusive at both ends: [1:SALT_SIZE] is SALT_SIZE bytes.
	salt = new blob(bytes[1:SALT_SIZE]);
	pwd_hash = new blob(bytes[SALT_SIZE+1:]);
}


/*
 * Build the value stored in the PasswordHash column for a new password:
 * one method byte (BEST_METHOD), then a fresh random salt, then the hash.
 *
 * Throws if the password is shorter than MIN_PASSWORD_LEN, or if
 * BEST_METHOD cannot be stored in a single byte.
 */
blob mk_db_hash(string password)
{
	if (strlen(password) < MIN_PASSWORD_LEN)
		throw (E_UNSPECIFIC, "Password too short");

	integer method = BEST_METHOD;
	// The method id becomes the first byte of the result, so it must fit a byte.
	logical fits_in_byte = method >= 0 && method < 256;
	if (!fits_in_byte)
		throw (E_UNSPECIFIC, "Invalid BEST_METHOD (must be stored as byte)");

	blob fresh_salt = mk_salt();
	blob digest = hash_encode(method, password, fresh_salt);
	return new blob([new blob([method]), fresh_salt, digest]);
}

/*
 * Check a clear-text password against a DB hash produced by mk_db_hash.
 * The hash is recomputed with the method and salt stored in db_hash, and the
 * result is compared with ssl.secure_compare (presumably a constant-time
 * comparison).
 */
logical verify_password(string provided_password, blob db_hash)
{
	integer stored_method;
	blob stored_salt;
	blob expected_hash;
	split_db_hash(db_hash, stored_method, stored_salt, expected_hash);

	blob candidate = hash_encode(stored_method, provided_password, stored_salt);
	return ssl.secure_compare(candidate, expected_hash);
}

/*
 * Create a 16-byte GUID.
 *
 * The GUID is the first 16 bytes of a SHA-256 over the current time, the
 * user name from get_user_name(), the process id, the caller-supplied
 * string use_this, and 16 random bytes.
 *
 * The random component matters: without it, two calls with the same
 * use_this in the same process produce the same GUID whenever now_ut()
 * returns the same value, and the result is predictable from public data.
 */
blob mk_guid(string use_this)
{
	string t = str(now_ut());
	string u = get_user_name();
	string p = string(get_process_id());
	string r = base64_encode(ssl.random_bytes(16), false);

	blob b = new blob(str_join([t, u, p, use_this, r], "+"));

	blob h = ssl.sha256(b);
	// [0:15] selects bytes 0..15 inclusive, i.e. 16 bytes.
	return new blob(h.get_bytes()[0:15]);
}

// Return the lower-case hexadecimal digit for a value in 0..15.
// Throws if i is outside that range.
string hex(integer i)
{
	if (i < 0 || i >= 16)
		throw (E_UNSPECIFIC, "Arg to hex out of bounds");

	vector(string) digits = ["0", "1", "2", "3", "4", "5", "6", "7",
	                         "8", "9", "a", "b", "c", "d", "e", "f"];
	return digits[i];
}


/*
 * Format a 16-byte GUID in the canonical 8-4-4-4-12 lower-case hex form,
 * e.g. "0011aabb-ccdd-eeff-0011-223344556677" (36 characters).
 *
 * Throws if guid is not exactly 16 bytes. The fixed 36-slot output buffer
 * would otherwise be overrun (longer input) or left with null slots
 * (shorter input).
 */
string guid_to_str(blob guid)
{
	vector(integer) bytes = guid.get_bytes();
	integer n = v_size(bytes);
	if (n != 16) {
		throw (E_UNSPECIFIC, "GUID must be exactly 16 bytes");
	}

	// 32 hex digits + 4 dashes.
	vector(string) result[36];
	integer r = 0;
	for (integer i = 0; i < n; ++i) {
		// Dashes go before bytes 4, 6, 8 and 10 to give the 8-4-4-4-12 grouping.
		if (i == 4 || i == 6 || i == 8 || i == 10)
			result[r++] = "-";

		integer b = bytes[i];
		integer lo = b % 16;
		integer hi = integer((b - lo) / 16);

		result[r++] = hex(hi);
		result[r++] = hex(lo);
	}

	return strcat(result);
}

// Normalize a user name or e-mail for the Normalized* columns and lookups.
// Currently this only upper-cases; removal of accents etc. may be added later.
string normalize(string s)
{
	string upper = str_to_upper(s);
	return upper;
}

// Fetch all user names and e-mail addresses, ordered by UserName.
// users[i] and emails[i] come from the same row.
void list_users(odbc.connection conn, out vector(string) users, out vector(string) emails)
{
	odbc.statement query = new odbc.statement(
		conn, strcat(["SELECT UserName, Email FROM ", TABLE, " ORDER BY UserName"]));
	query.execute();
	query.bulk_fetch();
	users = query.bulk_get_str(0);
	emails = query.bulk_get_str(1);
}

// Look up a user's current ConcurrencyStamp, matching on the normalized user
// name. Throws "No such user" if no row matches.
string fetch_concurrency_stamp(odbc.connection conn, string username)
{
	string lookup_key = normalize(username);
	odbc.statement stmt = new odbc.statement(
		conn, strcat(["SELECT ConcurrencyStamp FROM [", TABLE, "] WHERE NormalizedUserName = ?"]));
	// Query on NormalizedUserName to take advantage of a possible DB index on it.
	stmt.set_param(0, lookup_key);
	// Presumably the result buffer size; stamps from guid_to_str are 36 chars.
	stmt.execute([40]);

	if (!stmt.fetch())
		throw (E_UNSPECIFIC, "No such user");

	return stmt.get_str(0);
}
// Read every user's PasswordHash column (base64 text), keyed by the
// normalized user name.
map_str_str fetch_user_data(odbc.connection conn)
{
	map_str_str hashes_by_user = map_str_str();

	odbc.statement stmt = new odbc.statement(
		conn, strcat(["SELECT UserName, PasswordHash FROM [", TABLE, "]"]));
	stmt.execute([128, 128]);

	while (stmt.fetch()) {
		string user_key = normalize(stmt.get_str(0));
		string hash_b64 = stmt.get_str(1);
		hashes_by_user.add(user_key, hash_b64);
	}

	return hashes_by_user;
}

/*
 * Verify a user's password against the hash stored in the DB.
 *
 * Returns false if no user has that (normalized) name, or if the password
 * does not match. On success, concurrency_stamp receives the user's current
 * ConcurrencyStamp; otherwise it is left null.
 */
logical verify_password(
	odbc.connection conn,
	string username,
	string provided_password,
	out string concurrency_stamp)
{
	concurrency_stamp = null;

	string lookup_key = normalize(username);
	odbc.statement stmt = new odbc.statement(conn, strcat(
		["SELECT PasswordHash, ConcurrencyStamp FROM [", TABLE, "] WHERE NormalizedUserName = ?"]));
	// Query on NormalizedUserName to take advantage of a possible DB index on it.
	stmt.set_param(0, lookup_key);
	stmt.execute([128, 40]);

	if (!stmt.fetch())
		return false;

	// The stored value is base64 text of a hash laid out by mk_db_hash.
	blob stored_hash = base64_decode(stmt.get_str(0));
	if (!verify_password(provided_password, stored_hash))
		return false;

	concurrency_stamp = stmt.get_str(1);
	return true;
}


/*
 * Replace a user's PasswordHash and issue a new ConcurrencyStamp.
 *
 * password_hash          - a DB hash as produced by mk_db_hash; stored base64.
 * from_concurrency_stamp - the stamp the caller last read. The update is only
 *                          attempted if the row still has this stamp
 *                          (optimistic concurrency).
 *
 * Throws "Cannot update password: ..." wrapping the DB error. The expected
 * failures are 'No such user' (52000) and 'Wrong ConcurrencyStamp' (53000).
 *
 * NOTE(review): the stamp check (IF EXISTS ... AND ConcurrencyStamp = ?) and
 * the UPDATE are separate statements, and the UPDATE does not re-check the
 * stamp. A concurrent update landing between them could therefore be
 * overwritten. Moving the stamp into the UPDATE's WHERE clause and testing
 * @@ROWCOUNT would close the gap. However, an error raised after a DML
 * statement in a batch may only surface via SQLMoreResults, depending on the
 * ODBC driver and wrapper. Verify that before changing this.
 */
void update_password(
	odbc.connection conn,
	string username,
	blob password_hash,
	string from_concurrency_stamp)
{
	string user_key = normalize(username);
	string hash_b64 = base64_encode(password_hash, false);
	string next_stamp = guid_to_str(mk_guid(strcat([username, ":concurrency:"])));

	odbc.statement stmt = new odbc.statement(conn, strcat(
		["   IF EXISTS (SELECT 1 FROM ", TABLE, " WHERE NormalizedUserName = ?)"
		 , " BEGIN"
		 , "   IF EXISTS (SELECT 1 FROM ", TABLE, " WHERE NormalizedUserName = ? AND ConcurrencyStamp = ?)"
		 , "     UPDATE ", TABLE, " SET PasswordHash = ?, ConcurrencyStamp = ?"
		 , "     WHERE NormalizedUserName = ? "
		 , "   ELSE "
		 , "     THROW 53000, 'Wrong ConcurrencyStamp', 1"
		 , " END"
		 , " ELSE"
		 , "   THROW 52000, 'No such user', 1"
			]));

	// Bind in the order the ?s appear in the batch above.
	stmt.set_param(0, user_key);
	stmt.set_param(1, user_key);
	stmt.set_param(2, from_concurrency_stamp);
	stmt.set_param(3, hash_b64);
	stmt.set_param(4, next_stamp);
	stmt.set_param(5, user_key);

	try {
		stmt.execute();
	} catch {
		throw (E_UNSPECIFIC, strcat("Cannot update password: ", err.message()));
	}
}

/*
 * Insert a new user row.
 *
 * username      - must be at least MIN_USERNAME_LEN long (by strlen).
 * email         - optional. When given, its normalized form must not match
 *                 another user's NormalizedEmail.
 * password_hash - a DB hash as produced by mk_db_hash; stored base64-encoded.
 *
 * Id and ConcurrencyStamp are freshly generated GUID strings. Throws
 * "Username too short", or "Cannot add user: ..." wrapping the DB error
 * (e.g. 'Collision on UserName or Email', 51000).
 *
 * NOTE(review): the NOT EXISTS check and the INSERT are separate statements,
 * so two concurrent inserts of the same name could both pass the check.
 * Presumably a unique index on NormalizedUserName backstops this; confirm.
 */
void insert_user(
	odbc.connection conn,
	string username,
	string option(nullable) email,
	blob password_hash)
{
	if (strlen(username) < MIN_USERNAME_LEN) {
		throw (E_UNSPECIFIC, "Username too short");
	}

	// email may be null, but it is also mixed into the GUID seeds below. A
	// null element in the strcat vector is presumably rejected under
	// option(null: hard), so use an empty string in its place.
	string email_seed = null(email) ? "" : email;

	string id_str = guid_to_str(mk_guid(strcat([username, ":id:", email_seed])));
	string normalized_username = normalize(username);
	string normalized_email = null(email) ? null : normalize(email);
	string password_base64 = base64_encode(password_hash, false);
	string concurrency_stamp = guid_to_str(mk_guid(strcat([username, ":concurrency:", email_seed])));
	string sql = strcat(
		["   IF NOT EXISTS (SELECT 1 FROM ", TABLE, " WHERE NormalizedUserName = ?"
		 , null(email) ? ")" : " OR NormalizedEmail = ?)"
		 , "   INSERT INTO ", TABLE
		 , "   (Id, UserName, NormalizedUserName, Email, NormalizedEmail, PasswordHash, ConcurrencyStamp"
		 , "    , EmailConfirmed, PhoneNumberConfirmed, TwoFactorEnabled, LockoutEnabled, AccessFailedCount) "
		 , "   VALUES (?, ?, ?, ?, ?, ?, ?, 0, 0, 0, 0, 0)"
		 , " ELSE "
		 , "   THROW 51000, 'Collision on UserName or Email', 1"
			]);

	odbc.statement stmt = new odbc.statement(conn, sql);
	integer k = 0;
	// Collision check parameters (the e-mail one only exists when email is given).
	stmt.set_param(k++, normalized_username);
	if (!null(email))
		stmt.set_param(k++, normalized_email);
	// INSERT values, in column order.
	stmt.set_param(k++, id_str);
	stmt.set_param(k++, username);
	stmt.set_param(k++, normalized_username);
	stmt.set_param(k++, email);
	stmt.set_param(k++, normalized_email);
	stmt.set_param(k++, password_base64);
	stmt.set_param(k++, concurrency_stamp);

	try {
		stmt.execute();
	} catch {
		throw (E_UNSPECIFIC, strcat("Cannot add user: ", err.message()));
	}
}

// Delete the user whose normalized user name matches username.
// Throws "Cannot delete user: ..." wrapping the DB error; 'No such user'
// (52000) is raised when no row matches.
void delete_user_from_db(odbc.connection conn, string username)
{
	string user_key = normalize(username);
	odbc.statement stmt = new odbc.statement(conn, strcat(
		["   IF EXISTS (SELECT 1 FROM ", TABLE, " WHERE NormalizedUserName = ?)"
		 , "   DELETE FROM ", TABLE, " WHERE NormalizedUserName = ?"
		 , " ELSE"
		 , "  THROW 52000, 'No such user', 1"
			]));
	// The same key is bound twice: once for the existence check, once for the DELETE.
	stmt.set_param(0, user_key);
	stmt.set_param(1, user_key);

	try {
		stmt.execute();
	} catch {
		throw (E_UNSPECIFIC, strcat("Cannot delete user: ", err.message()));
	}
}

// Hash the clear-text password (see mk_db_hash) and insert a new user with it.
void add_user_to_db(odbc.connection conn, string username, string option(nullable) email, string password)
{
	insert_user(conn, username, email, mk_db_hash(password));
}

// Hash the new clear-text password (see mk_db_hash) and store it, provided
// the user's row still has concurrency_stamp (see update_password).
void change_password(odbc.connection conn, string username, string password, string concurrency_stamp)
{
	update_password(conn, username, mk_db_hash(password), concurrency_stamp);
}

// One user's session: a fixed absolute expiry plus an optional idle timeout
// measured from the last recorded access.
class session {
public:
	// uid: the user id; lifetime_minutes: absolute lifetime from creation;
	// allowed_idleness_minutes: idle timeout, <= 0 disables idle logout.
	session(string uid, integer lifetime_minutes, integer allowed_idleness_minutes);
	string uid();
	// True if past the absolute expiry or, when enabled, idle for too long.
	logical expired();
	// Records activity; returns true if the stored last access time changed.
	logical update_last_access();

	timestamp created_at();
	timestamp expires_at();
	// Last access time, truncated to whole minutes.
	timestamp last_access();

private:
	string uid_;
	timestamp created_at_;
	timestamp expiry_;          // created_at_ + lifetime
	integer allowed_idleness_;  // minutes * 60000; 0 = idle logout not in use
	timestamp last_access_;     // whole minutes only

	// NOTE(review): not read or written by any session method in this file;
	// save deadlines are kept in user_set.uid2save_until_. Possibly dead,
	// confirm before removing.
	timestamp save_req_resp_until_;

	void __dbg_print(__dbg_label l);
};

// Create a session for uid that ends lifetime_minutes after now. If
// allowed_idleness_minutes is positive, the session also ends after that many
// minutes without recorded activity; otherwise idle logout is disabled.
// Durations are added to timestamps as minutes * 60000.
session.session(string uid, integer lifetime_minutes, integer allowed_idleness_minutes)
: uid_(uid), created_at_(now_ut())
{
	integer per_minute = 60000;
	expiry_ = created_at_ + lifetime_minutes * per_minute;
	// Only whole minutes are stored for the last access time.
	last_access_ = timestamp(date(created_at_), hour(created_at_), minute(created_at_), 0);
	// 0 means idle logout not in use.
	allowed_idleness_ = allowed_idleness_minutes > 0 ? allowed_idleness_minutes * per_minute : 0;
}

// The user id this session belongs to.
string session.uid() { return uid_; }
// True if the absolute lifetime has passed, or if idle logout is enabled and
// the session has been idle for longer than allowed.
logical session.expired()
{
	// Absolute lifetime.
	if (now_ut() > expiry_)
		return true;

	// Idle logout disabled.
	if (allowed_idleness_ == 0)
		return false;

	// Allow 60 sec leeway, since last_access can be up to one minute behind.
	return now_ut() > last_access_ + allowed_idleness_ + 60000;
}

// Record activity at the current time. last_access_ is kept at whole-minute
// precision, so it only moves once at least a minute has passed since the
// stored value. Returns true if last_access_ was changed.
logical session.update_last_access()
{
	timestamp current = now_ut();
	logical same_minute = current - last_access_ < 60000;
	if (!same_minute) {
		last_access_ = timestamp(date(current), hour(current), minute(current), 0);
	}
	return !same_minute;
}

// When the session was created.
timestamp session.created_at() { return created_at_; }
// When the session's absolute lifetime ends.
timestamp session.expires_at() { return expiry_; }
// Last recorded activity, truncated to whole minutes.
timestamp session.last_access() { return last_access_; }

// Debugger label: uid, last access (time only if it is today, otherwise date
// and time), expiry, and the remaining time as formatted by pretty_duration.
// NOTE(review): last_access_ comes from now_ut() but is compared to today();
// if today() is local time, the "is today" test may be off near midnight.
void session.__dbg_print(__dbg_label l)
{
	string fmt = date(last_access_) == today() ? "%H:%M" : "%Y%m%d %H%M";
	string last_access_text = timestamp_to_str(last_access_, fmt);

	l.set_text(strcat(["session { \"", uid_, "\", Last access: ", last_access_text, ", Expires: ", str(expiry_), ", ", pretty_duration(now_ut(), expiry_), " }"]));
}

// In-memory registry of sessions, at most one per user id, addressed by random
// tokens. Also tracks per-user deadlines for saving requests/responses.
class user_set {
public:
	user_set();

	// Replaces any existing session of uid and returns the new token.
	string login_user(string uid, integer session_length, integer idleness_length);
	// Throws if uid has no session.
	void logout_user(string uid);

	string get_user_info(string token, out logical last_access_updated); // Returns the uid; throws if token is unknown or session is expired
	session get_session_by_token(string token); // Does not check session at all
	vector(session) get_sessions_by_user_id(string user_id); // Does not check session

	// Request/response saving: enable for a number of seconds, disable
	// (returns whether it was enabled), and query (expired entries are dropped).
	void enable_save_req_resp(string uid, integer timeout_sec);
	logical disable_save_req_resp(string uid);
	logical should_save_req_resp(string uid);

	// All sessions held, including expired ones not yet replaced or logged out.
	vector(session) get_sessions();

private:
	// Session stuff
	map_str_obj<session> token2session_;  // token -> session
	map_str_str uid2token_;               // uid -> token of that user's session
	map_str_timestamp uid2save_until_;    // uid -> req/resp saving deadline

	// Private funcs
	session _get_session(string token, out logical last_access_updated);
	void __dbg_print(__dbg_label l);
};

// Start with no sessions and no request/response-saving deadlines.
user_set.user_set()
:
token2session_(new map_str_obj<session>()),
uid2token_(map_str_str()),
uid2save_until_(map_str_timestamp())
{}

// Look up the session for token and record activity on it.
// Throws "Invalid token" for an unknown token and "Session expired" for an
// expired session. last_access_updated reports whether the session's last
// access time was advanced (see session.update_last_access).
session user_set._get_session(string token, out logical last_access_updated)
{
	session found = token2session_.find(token);
	if (null(found))
		throw (E_UNSPECIFIC, "Invalid token");
	if (found.expired())
		throw (E_UNSPECIFIC, "Session expired");

	last_access_updated = found.update_last_access();
	return found;
}

// Start a new session for uid and return its token. Any existing session of
// the same user is dropped first, so a user has at most one session.
// session_length and idleness_length are passed to the session constructor
// as its lifetime and idle-timeout minutes.
string user_set.login_user(string uid, integer session_length, integer idleness_length)
{
	if (uid == "")
		throw (E_UNSPECIFIC, "uid may not be empty string");

	string previous_token = uid2token_.find(uid);
	if (!null(previous_token)) {
		uid2token_.remove(uid);
		token2session_.remove(previous_token);
	}

	// Tokens are TOKEN_LENGTH random bytes, base64-encoded.
	string fresh_token = base64_encode(ssl.random_bytes(TOKEN_LENGTH), true);
	token2session_.add(fresh_token, new session(uid, session_length, idleness_length));
	uid2token_.add(uid, fresh_token);
	return fresh_token;
}

// End uid's session. Throws "No active session" if uid has none.
void user_set.logout_user(string uid)
{
	string active_token = uid2token_.find(uid);
	if (!null(active_token)) {
		token2session_.remove(active_token);
		uid2token_.remove(uid);
		return;
	}

	throw (E_UNSPECIFIC, "No active session");
}

// Return the uid owning token, recording activity on its session.
// Throws for unknown tokens and expired sessions (see _get_session).
string user_set.get_user_info(string token, out logical last_access_updated)
{
	return _get_session(token, last_access_updated).uid();
}

// Raw lookup of the session for token; no expiry check, no activity update.
session user_set.get_session_by_token(string token) { return token2session_.find(token); }

// Sessions belonging to user_id; currently zero or one, and not checked for
// expiry. A vector is returned for forward compatibility: users may later be
// allowed several active sessions, and this function will still work (order
// would then probably be creation time, or unspecified).
vector(session) user_set.get_sessions_by_user_id(string user_id)
{
	string token = uid2token_.find(user_id);
	if (!null(token))
		return [token2session_.find(token)];

	return vector(i:0; null<session>);
}

// Turn on request/response saving for uid for the next timeout_sec seconds.
// Calling it again for the same uid replaces the previous deadline.
void user_set.enable_save_req_resp(string uid, integer timeout_sec)
{
	timestamp save_req_resp_until = now() + 1000*timeout_sec;
	// Drop any existing deadline first, so a repeated call does not depend on
	// how add() treats an existing key. login_user also removes before
	// re-adding. Guard with find(), as disable_save_req_resp does, before
	// calling remove().
	if (!null(uid2save_until_.find(uid)))
		uid2save_until_.remove(uid);
	uid2save_until_.add(uid, save_req_resp_until);
}

// Turn off request/response saving for uid. Returns true if it was enabled.
logical user_set.disable_save_req_resp(string uid)
{
	logical was_enabled = !null(uid2save_until_.find(uid));
	if (was_enabled)
		uid2save_until_.remove(uid);
	return was_enabled;
}

// True while uid's request/response-saving deadline has not passed. An
// expired deadline is removed on first query after it passes.
logical user_set.should_save_req_resp(string uid)
{
	timestamp deadline = uid2save_until_.find(uid);
	if (null(deadline))
		return false;

	logical still_active = now() <= deadline;
	if (!still_active)
		uid2save_until_.remove(uid);
	return still_active;
}

// All sessions currently held, expired or not; tokens are not exposed.
vector(session) user_set.get_sessions()
{
	vector(string) ignored_tokens;
	vector(session) all_sessions;
	token2session_.get_content(ignored_tokens, all_sessions);
	return all_sessions;
}

// Debugger label showing how many sessions are held.
void user_set.__dbg_print(__dbg_label l)
{
	string session_count = str(token2session_.size());
	l.set_text(strcat(["user_set { sessions=", session_count, " }"]));
}

}
