diff --git a/cmd/server/main.go b/cmd/server/main.go index 69d8e9e..7f3293e 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -9,6 +9,7 @@ import ( "git.tijl.dev/tijl/tijl.dev/internal/db" "git.tijl.dev/tijl/tijl.dev/internal/i18n" "git.tijl.dev/tijl/tijl.dev/internal/oidc" + "git.tijl.dev/tijl/tijl.dev/internal/sessions" "git.tijl.dev/tijl/tijl.dev/modules/logger" "git.tijl.dev/tijl/tijl.dev/static" "git.tijl.dev/tijl/tijl.dev/views" @@ -110,7 +111,13 @@ func main() { // Common functions for in templating func getCommon(c *fiber.Ctx) fiber.Map { + _, err := sessions.GetSession(c) + signedIn := false + if err == nil { + signedIn = true + } return fiber.Map{ + "SignedIn": signedIn, "Path": c.Path(), "Language": i18n.GetLanguage(c), "T": i18n.GetTranslations(i18n.GetLanguage(c)), diff --git a/internal/db/db.go b/internal/db/db.go index 287f57d..150c6ee 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -23,7 +23,7 @@ func Load() { if err != nil { log.Fatal().Err(err).Msg("failed to connect to database") } - defer DB.Close() + //defer DB.Close() Queries = dbmanager.New(DB) log.Debug().Msg("connected to database") diff --git a/internal/middleware/session.go b/internal/middleware/session.go new file mode 100644 index 0000000..c870d7c --- /dev/null +++ b/internal/middleware/session.go @@ -0,0 +1 @@ +package middleware diff --git a/internal/oidc/handler.go b/internal/oidc/handler.go index 58fdb9e..1438630 100644 --- a/internal/oidc/handler.go +++ b/internal/oidc/handler.go @@ -2,20 +2,20 @@ package oidc import ( "context" + "database/sql" "errors" "net/http" + "git.tijl.dev/tijl/tijl.dev/internal/db" + "git.tijl.dev/tijl/tijl.dev/internal/sessions" "git.tijl.dev/tijl/tijl.dev/internal/utils" + "git.tijl.dev/tijl/tijl.dev/modules/database" log "git.tijl.dev/tijl/tijl.dev/modules/logger" "github.com/gofiber/fiber/v2" - "golang.org/x/oauth2" ) func HandleRedirect(c *fiber.Ctx) error { - state, err := utils.RandString(16) - if err != nil { - return err - } + state := utils.RandString(16) setCallbackCookie(c, "state", state) return c.Redirect(Config.AuthCodeURL(state), http.StatusFound) } @@ -40,14 +40,43 @@ func HandleCallback(c *fiber.Ctx) error { return err } - userInfo, err := Provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)) + var claims struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Username string `json:"preferred_username"` + } + if err := idToken.Claims(&claims); err != nil { + log.Error().Err(err).Msg("error getting claims") + return err + } + + _, err = db.Queries.GetUser(ctx, idToken.Subject) + if err == nil { + db.Queries.UpdateUserData(ctx, database.UpdateUserDataParams{ + Uid: idToken.Subject, + Email: claims.Email, + EmailVerified: claims.EmailVerified, + Username: claims.Username, + FullName: claims.Name, + }) + } else if err == sql.ErrNoRows { + db.Queries.CreateUser(ctx, database.CreateUserParams{ + Uid: idToken.Subject, + Email: claims.Email, + EmailVerified: claims.EmailVerified, + Username: claims.Username, + FullName: claims.Name, + }) + } else { + log.Error().Err(err).Msg("error getting user") + return err + } + + _, err = sessions.NewSession(idToken.Subject, c) if err != nil { return err } - log.Debug().Interface("userInfo", userInfo).Interface("idToken", idToken).Msg("data") - - // now we can create a user account and session in the db - return c.Redirect("/") } diff --git a/internal/sessions/sessions.go b/internal/sessions/sessions.go new file mode 100644 index 0000000..535deaa --- /dev/null +++ b/internal/sessions/sessions.go @@ -0,0 +1,53 @@ +package sessions + +import ( + "context" + + "git.tijl.dev/tijl/tijl.dev/internal/db" + "git.tijl.dev/tijl/tijl.dev/internal/utils" + "git.tijl.dev/tijl/tijl.dev/modules/database" + "github.com/gofiber/fiber/v2" +) + +func NewSession(uid string, c *fiber.Ctx) (string, error) { + createSessionParams := database.CreateSessionParams{ + Uid: uid, + Token: utils.RandString(64), + } + err := db.Queries.CreateSession(context.TODO(), createSessionParams) + if err != nil { + return "", err + } + c.Cookie(&fiber.Cookie{ + Name: "session", + Value: createSessionParams.Token, + Secure: true, + }) + + err = db.Queries.QuickUpdateSession(context.TODO(), database.QuickUpdateSessionParams{ + Token: createSessionParams.Token, + IpAddress: c.IP(), + Agent: string(c.Context().UserAgent()), + }) + if err != nil { + return "", err + } + + return createSessionParams.Token, nil +} + +func GetSession(c *fiber.Ctx) (database.Session, error) { + err := db.Queries.QuickUpdateSession(context.TODO(), database.QuickUpdateSessionParams{ + Token: c.Cookies("session"), + IpAddress: c.IP(), + Agent: string(c.Context().UserAgent()), + }) + if err != nil { + return database.Session{}, err + } + session, err := db.Queries.GetSession(context.TODO(), c.Cookies("session")) + if err != nil { + return session, err + } + return session, nil +} diff --git a/internal/utils/random.go b/internal/utils/random.go index d7c7684..8fcab1f 100644 --- a/internal/utils/random.go +++ b/internal/utils/random.go @@ -4,12 +4,14 @@ import ( "crypto/rand" "encoding/base64" "io" + + log "git.tijl.dev/tijl/tijl.dev/modules/logger" ) -func RandString(nByte int) (string, error) { +func RandString(nByte int) string { b := make([]byte, nByte) if _, err := io.ReadFull(rand.Reader, b); err != nil { - return "", err + log.Fatal().Err(err) } - return base64.RawURLEncoding.EncodeToString(b), nil + return base64.RawURLEncoding.EncodeToString(b) } diff --git a/migrations/00000001_init.up.sql b/migrations/00000001_init.up.sql index 41b0eb3..cf5ad46 100644 --- a/migrations/00000001_init.up.sql +++ b/migrations/00000001_init.up.sql @@ -1,30 +1,31 @@ CREATE TABLE users ( id SERIAL PRIMARY KEY, uid VARCHAR UNIQUE NOT NULL, -- username as unique identifier - email VARCHAR UNIQUE, - full_name VARCHAR, - displayname VARCHAR, + email VARCHAR NOT NULL, + email_verified BOOLEAN NOT NULL, + full_name VARCHAR NOT NULL, + username VARCHAR NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL ); CREATE TABLE sessions ( id SERIAL PRIMARY KEY, - user_id INTEGER NOT NULL, - title VARCHAR, - token VARCHAR NOT NULL UNIQUE, - password VARCHAR, - last_activity TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + uid VARCHAR NOT NULL, + token VARCHAR NOT NULL UNIQUE NOT NULL, expires TIMESTAMP, + last_activity TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, - FOREIGN KEY (user_id) REFERENCES users (id) + FOREIGN KEY (uid) REFERENCES users (uid) ); CREATE TABLE session_ips ( id SERIAL PRIMARY KEY, session_id INTEGER NOT NULL, - ip_address INET NOT NULL, - access_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (session_id) REFERENCES sessions (id) + ip_address VARCHAR NOT NULL, + agent VARCHAR NOT NULL, + access_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + FOREIGN KEY (session_id) REFERENCES sessions (id), + CONSTRAINT session_ips_unique UNIQUE (session_id, ip_address) ); diff --git a/migrations/00000002_analytics.up.sql b/migrations/00000002_analytics.up.sql deleted file mode 100644 index e69de29..0000000 diff --git a/modules/database/models.go b/modules/database/models.go index 18e81dd..adaf94d 100644 --- a/modules/database/models.go +++ b/modules/database/models.go @@ -7,34 +7,32 @@ package database import ( "database/sql" "time" - - "github.com/sqlc-dev/pqtype" ) type Session struct { ID int32 - UserID int32 - Title sql.NullString + Uid string Token string - Password sql.NullString - LastActivity sql.NullTime Expires sql.NullTime + LastActivity time.Time CreatedAt time.Time } type SessionIp struct { ID int32 SessionID int32 - IpAddress pqtype.Inet - AccessTime sql.NullTime + IpAddress string + Agent string + AccessTime time.Time } type User struct { - ID int32 - Uid string - Email sql.NullString - FullName sql.NullString - Displayname sql.NullString - CreatedAt time.Time - UpdatedAt sql.NullTime + ID int32 + Uid string + Email string + EmailVerified bool + FullName string + Username string + CreatedAt time.Time + UpdatedAt time.Time } diff --git a/modules/database/sessions.sql b/modules/database/sessions.sql index ca3c7f0..5151ce3 100644 --- a/modules/database/sessions.sql +++ b/modules/database/sessions.sql @@ -1,18 +1,34 @@ --- name: GetSesssion :one +-- name: GetSession :one SELECT * FROM sessions WHERE token = $1; -- name: GetSessions :many -SELECT * FROM sessions WHERE user_id = $1 ORDER BY $2; +SELECT * FROM sessions WHERE uid = $1 ORDER BY $2; -- name: GetActiveSessions :many -SELECT * FROM sessions WHERE user_id = $1 AND (expires > CURRENT_TIMESTAMP OR expires IS NULL) ORDER BY $2; +SELECT * FROM sessions WHERE uid = $1 AND (expires > CURRENT_TIMESTAMP OR expires IS NULL) ORDER BY $2; -- name: CreateSession :exec -INSERT INTO sessions (user_id, title, token) VALUES ($1, $2, $3); +INSERT INTO sessions (uid, token, last_activity) VALUES ($1, $2, CURRENT_TIMESTAMP); -- name: QuickUpdateSession :exec -UPDATE sessions SET last_activity = GETDATE() WHERE id = $1; +WITH updated_session AS ( + UPDATE sessions + SET last_activity = CURRENT_TIMESTAMP + WHERE token = $1 + RETURNING id +) +INSERT INTO session_ips (session_id, ip_address, agent, access_time) +VALUES ( + (SELECT id FROM updated_session), + $2, + $3, + CURRENT_TIMESTAMP +) +ON CONFLICT (session_id, ip_address) +DO UPDATE SET + agent = EXCLUDED.agent, + access_time = CURRENT_TIMESTAMP; -- name: ExpireSession :exec -UPDATE sessions SET expires = 1 WHERE id = $1; +UPDATE sessions SET expires = 1 WHERE token = $1; diff --git a/modules/database/sessions.sql.go b/modules/database/sessions.sql.go index 87b5b19..ae85e04 100644 --- a/modules/database/sessions.sql.go +++ b/modules/database/sessions.sql.go @@ -7,44 +7,42 @@ package database import ( "context" - "database/sql" ) const createSession = `-- name: CreateSession :exec -INSERT INTO sessions (user_id, title, token) VALUES ($1, $2, $3) +INSERT INTO sessions (uid, token, last_activity) VALUES ($1, $2, CURRENT_TIMESTAMP) ` type CreateSessionParams struct { - UserID int32 - Title sql.NullString - Token string + Uid string + Token string } func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) error { - _, err := q.db.ExecContext(ctx, createSession, arg.UserID, arg.Title, arg.Token) + _, err := q.db.ExecContext(ctx, createSession, arg.Uid, arg.Token) return err } const expireSession = `-- name: ExpireSession :exec -UPDATE sessions SET expires = 1 WHERE id = $1 +UPDATE sessions SET expires = 1 WHERE token = $1 ` -func (q *Queries) ExpireSession(ctx context.Context, id int32) error { - _, err := q.db.ExecContext(ctx, expireSession, id) +func (q *Queries) ExpireSession(ctx context.Context, token string) error { + _, err := q.db.ExecContext(ctx, expireSession, token) return err } const getActiveSessions = `-- name: GetActiveSessions :many -SELECT id, user_id, title, token, password, last_activity, expires, created_at FROM sessions WHERE user_id = $1 AND (expires > CURRENT_TIMESTAMP OR expires IS NULL) ORDER BY $2 +SELECT id, uid, token, expires, last_activity, created_at FROM sessions WHERE uid = $1 AND (expires > CURRENT_TIMESTAMP OR expires IS NULL) ORDER BY $2 ` type GetActiveSessionsParams struct { - UserID int32 + Uid string Column2 interface{} } func (q *Queries) GetActiveSessions(ctx context.Context, arg GetActiveSessionsParams) ([]Session, error) { - rows, err := q.db.QueryContext(ctx, getActiveSessions, arg.UserID, arg.Column2) + rows, err := q.db.QueryContext(ctx, getActiveSessions, arg.Uid, arg.Column2) if err != nil { return nil, err } @@ -54,12 +52,10 @@ func (q *Queries) GetActiveSessions(ctx context.Context, arg GetActiveSessionsPa var i Session if err := rows.Scan( &i.ID, - &i.UserID, - &i.Title, + &i.Uid, &i.Token, - &i.Password, - &i.LastActivity, &i.Expires, + &i.LastActivity, &i.CreatedAt, ); err != nil { return nil, err @@ -75,72 +71,90 @@ func (q *Queries) GetActiveSessions(ctx context.Context, arg GetActiveSessionsPa return items, nil } -const getSessions = `-- name: GetSessions :many -SELECT id, user_id, title, token, password, last_activity, expires, created_at FROM sessions WHERE user_id = $1 ORDER BY $2 +const getSession = `-- name: GetSession :one +SELECT id, uid, token, expires, last_activity, created_at FROM sessions WHERE token = $1 ` -type GetSessionsParams struct { - UserID int32 - Column2 interface{} -} - -func (q *Queries) GetSessions(ctx context.Context, arg GetSessionsParams) ([]Session, error) { - rows, err := q.db.QueryContext(ctx, getSessions, arg.UserID, arg.Column2) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Session - for rows.Next() { - var i Session - if err := rows.Scan( - &i.ID, - &i.UserID, - &i.Title, - &i.Token, - &i.Password, - &i.LastActivity, - &i.Expires, - &i.CreatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const getSesssion = `-- name: GetSesssion :one -SELECT id, user_id, title, token, password, last_activity, expires, created_at FROM sessions WHERE token = $1 -` - -func (q *Queries) GetSesssion(ctx context.Context, token string) (Session, error) { - row := q.db.QueryRowContext(ctx, getSesssion, token) +func (q *Queries) GetSession(ctx context.Context, token string) (Session, error) { + row := q.db.QueryRowContext(ctx, getSession, token) var i Session err := row.Scan( &i.ID, - &i.UserID, - &i.Title, + &i.Uid, &i.Token, - &i.Password, - &i.LastActivity, &i.Expires, + &i.LastActivity, &i.CreatedAt, ) return i, err } -const quickUpdateSession = `-- name: QuickUpdateSession :exec -UPDATE sessions SET last_activity = GETDATE() WHERE id = $1 +const getSessions = `-- name: GetSessions :many +SELECT id, uid, token, expires, last_activity, created_at FROM sessions WHERE uid = $1 ORDER BY $2 ` -func (q *Queries) QuickUpdateSession(ctx context.Context, id int32) error { - _, err := q.db.ExecContext(ctx, quickUpdateSession, id) +type GetSessionsParams struct { + Uid string + Column2 interface{} +} + +func (q *Queries) GetSessions(ctx context.Context, arg GetSessionsParams) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, getSessions, arg.Uid, arg.Column2) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Uid, + &i.Token, + &i.Expires, + &i.LastActivity, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const quickUpdateSession = `-- name: QuickUpdateSession :exec +WITH updated_session AS ( + UPDATE sessions + SET last_activity = CURRENT_TIMESTAMP + WHERE token = $1 + RETURNING id +) +INSERT INTO session_ips (session_id, ip_address, agent, access_time) +VALUES ( + (SELECT id FROM updated_session), + $2, + $3, + CURRENT_TIMESTAMP +) +ON CONFLICT (session_id, ip_address) +DO UPDATE SET + agent = EXCLUDED.agent, + access_time = CURRENT_TIMESTAMP +` + +type QuickUpdateSessionParams struct { + Token string + IpAddress string + Agent string +} + +func (q *Queries) QuickUpdateSession(ctx context.Context, arg QuickUpdateSessionParams) error { + _, err := q.db.ExecContext(ctx, quickUpdateSession, arg.Token, arg.IpAddress, arg.Agent) return err } diff --git a/modules/database/users.sql b/modules/database/users.sql index 19337e3..3c52a3a 100644 --- a/modules/database/users.sql +++ b/modules/database/users.sql @@ -11,15 +11,16 @@ SELECT * FROM users WHERE id = $1 LIMIT 1; DELETE FROM users WHERE uid = $1; -- name: CreateUser :exec -INSERT INTO users (uid, email, full_name, displayname) -VALUES ($1, $2, $3, $4) +INSERT INTO users (uid, email, email_verified, full_name, username) +VALUES ($1, $2, $3, $4, $5) RETURNING id; -- name: UpdateUserData :exec UPDATE users SET email = COALESCE($2, email), - full_name = COALESCE($3, full_name), - displayname = COALESCE($4, displayname), + email_verified = COALESCE($3, email_verified), + full_name = COALESCE($4, full_name), + username = COALESCE($5, username), updated_at = CURRENT_TIMESTAMP WHERE uid = $1; diff --git a/modules/database/users.sql.go b/modules/database/users.sql.go index 4f62e72..070d0c4 100644 --- a/modules/database/users.sql.go +++ b/modules/database/users.sql.go @@ -7,28 +7,29 @@ package database import ( "context" - "database/sql" ) const createUser = `-- name: CreateUser :exec -INSERT INTO users (uid, email, full_name, displayname) -VALUES ($1, $2, $3, $4) +INSERT INTO users (uid, email, email_verified, full_name, username) +VALUES ($1, $2, $3, $4, $5) RETURNING id ` type CreateUserParams struct { - Uid string - Email sql.NullString - FullName sql.NullString - Displayname sql.NullString + Uid string + Email string + EmailVerified bool + FullName string + Username string } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) error { _, err := q.db.ExecContext(ctx, createUser, arg.Uid, arg.Email, + arg.EmailVerified, arg.FullName, - arg.Displayname, + arg.Username, ) return err } @@ -43,7 +44,7 @@ func (q *Queries) DeleteUser(ctx context.Context, uid string) error { } const getUser = `-- name: GetUser :one -SELECT id, uid, email, full_name, displayname, created_at, updated_at FROM users WHERE uid = $1 LIMIT 1 +SELECT id, uid, email, email_verified, full_name, username, created_at, updated_at FROM users WHERE uid = $1 LIMIT 1 ` func (q *Queries) GetUser(ctx context.Context, uid string) (User, error) { @@ -53,8 +54,9 @@ func (q *Queries) GetUser(ctx context.Context, uid string) (User, error) { &i.ID, &i.Uid, &i.Email, + &i.EmailVerified, &i.FullName, - &i.Displayname, + &i.Username, &i.CreatedAt, &i.UpdatedAt, ) @@ -62,7 +64,7 @@ func (q *Queries) GetUser(ctx context.Context, uid string) (User, error) { } const getUserById = `-- name: GetUserById :one -SELECT id, uid, email, full_name, displayname, created_at, updated_at FROM users WHERE id = $1 LIMIT 1 +SELECT id, uid, email, email_verified, full_name, username, created_at, updated_at FROM users WHERE id = $1 LIMIT 1 ` func (q *Queries) GetUserById(ctx context.Context, id int32) (User, error) { @@ -72,8 +74,9 @@ func (q *Queries) GetUserById(ctx context.Context, id int32) (User, error) { &i.ID, &i.Uid, &i.Email, + &i.EmailVerified, &i.FullName, - &i.Displayname, + &i.Username, &i.CreatedAt, &i.UpdatedAt, ) @@ -94,25 +97,28 @@ func (q *Queries) GetUserUid(ctx context.Context, id int32) (string, error) { const updateUserData = `-- name: UpdateUserData :exec UPDATE users SET email = COALESCE($2, email), - full_name = COALESCE($3, full_name), - displayname = COALESCE($4, displayname), + email_verified = COALESCE($3, email_verified), + full_name = COALESCE($4, full_name), + username = COALESCE($5, username), updated_at = CURRENT_TIMESTAMP WHERE uid = $1 ` type UpdateUserDataParams struct { - Uid string - Email sql.NullString - FullName sql.NullString - Displayname sql.NullString + Uid string + Email string + EmailVerified bool + FullName string + Username string } func (q *Queries) UpdateUserData(ctx context.Context, arg UpdateUserDataParams) error { _, err := q.db.ExecContext(ctx, updateUserData, arg.Uid, arg.Email, + arg.EmailVerified, arg.FullName, - arg.Displayname, + arg.Username, ) return err } diff --git a/views/index.html b/views/index.html index 5e79ea2..4cec2ea 100644 --- a/views/index.html +++ b/views/index.html @@ -1,3 +1,3 @@

Welcome to My Go App

This is the homepage.

-{{.T.about}} +
Signed In: {{.SignedIn}}