From fe4b31be06a704858929d3413fcb0dace9d63eb2 Mon Sep 17 00:00:00 2001 From: Tovi Jaeschke-Rogers Date: Tue, 15 Mar 2022 21:10:20 +1030 Subject: [PATCH 1/5] Add createUser and getUsers API endpoints --- Api/JsonSerialization/DeserializeUserJson.go | 76 +++++++++++ Api/Posts.go | 2 - Api/Routes.go | 3 + Api/Users.go | 102 ++++++++++++++ Api/Users_test.go | 132 +++++++++++++++++++ Database/Init.go | 3 + Database/Users.go | 63 +++++++++ Models/Users.go | 15 +++ 8 files changed, 394 insertions(+), 2 deletions(-) create mode 100644 Api/JsonSerialization/DeserializeUserJson.go create mode 100644 Api/Users.go create mode 100644 Api/Users_test.go create mode 100644 Database/Users.go create mode 100644 Models/Users.go diff --git a/Api/JsonSerialization/DeserializeUserJson.go b/Api/JsonSerialization/DeserializeUserJson.go new file mode 100644 index 0000000..01ad7d9 --- /dev/null +++ b/Api/JsonSerialization/DeserializeUserJson.go @@ -0,0 +1,76 @@ +package JsonSerialization + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + + schema "github.com/Kangaroux/go-map-schema" +) + +func DeserializeUser(data []byte, allowMissing []string, allowAllMissing bool) (Models.User, error) { + var ( + postData Models.User = Models.User{} + jsonStructureTest map[string]interface{} = make(map[string]interface{}) + jsonStructureTestResults *schema.CompareResults + field schema.FieldMissing + allowed string + missingFields []string + i int + err error + ) + + // Verify the JSON has the correct structure + json.Unmarshal(data, &jsonStructureTest) + jsonStructureTestResults, err = schema.CompareMapToStruct( + &postData, + jsonStructureTest, + &schema.CompareOpts{ + ConvertibleFunc: CanConvert, + TypeNameFunc: schema.DetailedTypeName, + }) + if err != nil { + return postData, err + } + + if len(jsonStructureTestResults.MismatchedFields) > 0 { + return postData, errors.New(fmt.Sprintf( + "MismatchedFields found when deserializing data: %s", + jsonStructureTestResults.Errors().Error(), + )) + } + + // Remove allowed missing fields from MissingFields + for _, allowed = range allowMissing { + for i, field = range jsonStructureTestResults.MissingFields { + if allowed == field.String() { + jsonStructureTestResults.MissingFields = append( + jsonStructureTestResults.MissingFields[:i], + jsonStructureTestResults.MissingFields[i+1:]..., + ) + } + } + } + + if !allowAllMissing && len(jsonStructureTestResults.MissingFields) > 0 { + for _, field = range jsonStructureTestResults.MissingFields { + missingFields = append(missingFields, field.String()) + } + + return postData, errors.New(fmt.Sprintf( + "MissingFields found when deserializing data: %s", + strings.Join(missingFields, ", "), + )) + } + + // Deserialize the JSON into the struct + err = json.Unmarshal(data, &postData) + if err != nil { + return postData, err + } + + return postData, err +} diff --git a/Api/Posts.go b/Api/Posts.go index 60f50ec..e54af45 100644 --- a/Api/Posts.go +++ b/Api/Posts.go @@ -116,8 +116,6 @@ func createPost(w http.ResponseWriter, r *http.Request) { // TODO: Add auth - log.Printf("Posts handler recieved %s request", r.Method) - requestBody, err = ioutil.ReadAll(r.Body) if err != nil { log.Printf("Error encountered reading POST body: %s\n", err.Error()) diff --git a/Api/Routes.go b/Api/Routes.go index c023cc4..4ce2674 100644 --- a/Api/Routes.go +++ b/Api/Routes.go @@ -26,6 +26,9 @@ func InitApiEndpoints() *mux.Router { router.HandleFunc("/post/{postID}/image", createPostImage).Methods("POST") router.HandleFunc("/post/{postID}/image/{imageID}", deletePostImage).Methods("DELETE") + // Define routes for users api + router.HandleFunc("/user", createUser).Methods("POST") + //router.PathPrefix("/").Handler(http.StripPrefix("/images/", http.FileServer(http.Dir("./uploads")))) return router diff --git a/Api/Users.go b/Api/Users.go new file mode 100644 index 0000000..c343b69 --- /dev/null +++ b/Api/Users.go @@ -0,0 +1,102 @@ +package Api + +import ( + "encoding/json" + "io/ioutil" + "log" + "net/http" + "net/url" + "strconv" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" +) + +func getUsers(w http.ResponseWriter, r *http.Request) { + var ( + users []Models.User + returnJson []byte + values url.Values + page, pageSize int + err error + ) + + values = r.URL.Query() + + page, err = strconv.Atoi(values.Get("page")) + if err != nil { + log.Println("Could not parse page url argument") + JsonReturn(w, 500, "An error occured") + return + } + + page, err = strconv.Atoi(values.Get("pageSize")) + if err != nil { + log.Println("Could not parse pageSize url argument") + JsonReturn(w, 500, "An error occured") + return + } + + users, err = Database.GetUsers(page, pageSize) + if err != nil { + log.Printf("An error occured: %s\n", err.Error()) + JsonReturn(w, 500, "An error occured") + return + } + + returnJson, err = json.MarshalIndent(users, "", " ") + if err != nil { + JsonReturn(w, 500, "An error occured") + return + } + + // Return updated json + w.WriteHeader(http.StatusOK) + w.Write(returnJson) +} + +func createUser(w http.ResponseWriter, r *http.Request) { + var ( + userData Models.User + requestBody []byte + err error + ) + + requestBody, err = ioutil.ReadAll(r.Body) + if err != nil { + log.Printf("Error encountered reading POST body: %s\n", err.Error()) + JsonReturn(w, 500, "An error occured") + return + } + + userData, err = JsonSerialization.DeserializeUser(requestBody, []string{ + "id", + "last_login", + }, false) + if err != nil { + log.Printf("Invalid data provided to user API: %s\n", err.Error()) + JsonReturn(w, 405, "Invalid data") + return + } + + err = Database.CheckUniqueEmail(userData.Email) + if err != nil { + JsonReturn(w, 405, "invalid_email") + return + } + + if userData.Password != userData.ConfirmPassword { + JsonReturn(w, 500, "invalid_password") + return + } + + err = Database.CreateUser(&userData) + if err != nil { + JsonReturn(w, 405, "Invalid data") + return + } + + // Return updated json + w.WriteHeader(http.StatusOK) +} diff --git a/Api/Users_test.go b/Api/Users_test.go new file mode 100644 index 0000000..8007c12 --- /dev/null +++ b/Api/Users_test.go @@ -0,0 +1,132 @@ +package Api + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net/http" + "net/http/httptest" + "os" + "path" + "runtime" + "strings" + "testing" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + "github.com/gorilla/mux" + "gorm.io/gorm" +) + +func init() { + // Fix working directory for tests + _, filename, _, _ := runtime.Caller(0) + dir := path.Join(path.Dir(filename), "..") + err := os.Chdir(dir) + if err != nil { + panic(err) + } + + log.SetOutput(ioutil.Discard) + Database.Init() + + r = mux.NewRouter() +} + +var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func RandStringRunes(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} + +func Test_getUsers(t *testing.T) { + t.Log("Testing getUsers...") + + r.HandleFunc("/user", getUsers).Methods("GET") + + ts := httptest.NewServer(r) + defer ts.Close() + + var err error + for i := 0; i < 20; i++ { + userData := Models.User{ + Email: fmt.Sprintf( + "%s@email.com", + RandStringRunes(16), + ), + Password: "password", + ConfirmPassword: "password", + } + + err = Database.CreateUser(&userData) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + } + + defer Database.DB. + Session(&gorm.Session{FullSaveAssociations: true}). + Unscoped(). + Delete(&userData) + } + + res, err := http.Get(ts.URL + "/user?page=1&pageSize=10") + + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + } + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + } + + getUsersData := new([]Models.User) + err = json.NewDecoder(res.Body).Decode(getUsersData) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + } + + if len(*getUsersData) != 10 { + t.Errorf("Expected 10, recieved %d", len(*getUsersData)) + } +} + +func Test_createUser(t *testing.T) { + t.Log("Testing createUser...") + + r.HandleFunc("/user", createUser).Methods("POST") + + ts := httptest.NewServer(r) + + defer ts.Close() + + postJson := ` +{ + "email": "email@email.com", + "password": "password", + "confirm_password": "password", + "first_name": "Hugh", + "last_name": "Mann" +} +` + + res, err := http.Post(ts.URL+"/user", "application/json", strings.NewReader(postJson)) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + return + } + + Database.DB.Model(Models.User{}). + Select("count(*) > 0"). + Where("email = ?", "email@email.com"). + Delete(Models.User{}) +} diff --git a/Database/Init.go b/Database/Init.go index a467b64..44c1930 100644 --- a/Database/Init.go +++ b/Database/Init.go @@ -43,4 +43,7 @@ func Init() { DB.AutoMigrate(&Models.SubscriptionEmailAttachment{}) DB.AutoMigrate(&Models.SubscriptionEmail{}) DB.AutoMigrate(&Models.Subscription{}) + + log.Println("Running AutoMigrate on User tables...") + DB.AutoMigrate(&Models.User{}) } diff --git a/Database/Users.go b/Database/Users.go new file mode 100644 index 0000000..e28b898 --- /dev/null +++ b/Database/Users.go @@ -0,0 +1,63 @@ +package Database + +import ( + "errors" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + + "gorm.io/gorm" +) + +func GetUsers(page, pageSize int) ([]Models.User, error) { + var ( + users []Models.User + err error + ) + + if page == 0 { + page = 1 + } + + switch { + case pageSize > 100: + pageSize = 100 + case pageSize <= 0: + pageSize = 10 + } + + err = DB.Offset(page). + Limit(pageSize). + Find(&users). + Error + + return users, err +} + +func CheckUniqueEmail(email string) error { + var ( + exists bool + err error + ) + + err = DB.Model(Models.User{}). + Select("count(*) > 0"). + Where("email = ?", email). + Find(&exists). + Error + + if err != nil { + return err + } + + if exists { + return errors.New("Invalid email") + } + + return nil +} + +func CreateUser(userData *Models.User) error { + return DB.Session(&gorm.Session{FullSaveAssociations: true}). + Create(userData). + Error +} diff --git a/Models/Users.go b/Models/Users.go new file mode 100644 index 0000000..27401e4 --- /dev/null +++ b/Models/Users.go @@ -0,0 +1,15 @@ +package Models + +import ( + "time" +) + +type User struct { + Base + Email string `gorm:"not null;unique" json:"email"` + Password string `gorm:"not null" json:"password"` + ConfirmPassword string `gorm:"-" json:"confirm_password"` + LastLogin *time.Time `json:"last_login"` + FirstName string `gorm:"not null" json:"first_name"` + LastName string `gorm:"not null" json:"last_name"` +} -- 2.17.1 From 4cb80bbb3a4c38024099ccbdb2b290aab2aec218 Mon Sep 17 00:00:00 2001 From: Tovi Jaeschke-Rogers Date: Wed, 16 Mar 2022 20:37:10 +1030 Subject: [PATCH 2/5] Add user_id to post model Add update and delete cruds to user API --- Api/Auth/Passwords.go | 22 +++ Api/JsonSerialization/VerifyJson.go | 5 + Api/PostImages_test.go | 17 +-- Api/Posts.go | 10 +- Api/Posts_test.go | 149 +++++++------------ Api/Routes.go | 9 +- Api/UserHelper.go | 51 +++++++ Api/Users.go | 103 ++++++++++++- Api/Users_test.go | 222 ++++++++++++++++++++++++---- Database/Init.go | 51 +++++-- Database/Users.go | 53 ++++++- Models/Posts.go | 9 +- Models/Users.go | 2 +- 13 files changed, 533 insertions(+), 170 deletions(-) create mode 100644 Api/Auth/Passwords.go create mode 100644 Api/UserHelper.go diff --git a/Api/Auth/Passwords.go b/Api/Auth/Passwords.go new file mode 100644 index 0000000..779c48e --- /dev/null +++ b/Api/Auth/Passwords.go @@ -0,0 +1,22 @@ +package Auth + +import ( + "golang.org/x/crypto/bcrypt" +) + +func HashPassword(password string) (string, error) { + var ( + bytes []byte + err error + ) + bytes, err = bcrypt.GenerateFromPassword([]byte(password), 14) + return string(bytes), err +} + +func CheckPasswordHash(password, hash string) bool { + var ( + err error + ) + err = bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} diff --git a/Api/JsonSerialization/VerifyJson.go b/Api/JsonSerialization/VerifyJson.go index c0e1b17..2ad1554 100644 --- a/Api/JsonSerialization/VerifyJson.go +++ b/Api/JsonSerialization/VerifyJson.go @@ -40,6 +40,7 @@ func isFloatType(t reflect.Type) (yes bool) { func CanConvert(t reflect.Type, v reflect.Value) bool { isPtr := t.Kind() == reflect.Ptr isStruct := t.Kind() == reflect.Struct + isArray := t.Kind() == reflect.Array dstType := t // Check if v is a nil value. @@ -58,6 +59,10 @@ func CanConvert(t reflect.Type, v reflect.Value) bool { return v.Kind() == reflect.Map } + if isArray { + return v.Kind() == reflect.String + } + if t.Kind() == reflect.Slice { return v.Kind() == reflect.Slice } diff --git a/Api/PostImages_test.go b/Api/PostImages_test.go index 0bf0761..9ab1947 100644 --- a/Api/PostImages_test.go +++ b/Api/PostImages_test.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "io/ioutil" - "log" "mime/multipart" "net/http" "net/http/httptest" @@ -22,7 +21,6 @@ import ( "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" "github.com/gorilla/mux" - "gorm.io/gorm" ) func init() { @@ -34,8 +32,7 @@ func init() { panic(err) } - log.SetOutput(ioutil.Discard) - Database.Init() + Database.InitTest() r = mux.NewRouter() } @@ -181,12 +178,6 @@ func Test_createPostImages(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } - defer Database.DB. - Session(&gorm.Session{FullSaveAssociations: true}). - Unscoped(). - Select("PostLinks", "PostImages"). - Delete(&postData) - if len(updatePostData.PostImages) != 1 { t.Errorf("Expected len(updatePostData.PostImages) == 1, recieved %d", len(updatePostData.PostImages)) } @@ -243,12 +234,6 @@ func Test_deletePostImages(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } - defer Database.DB. - Session(&gorm.Session{FullSaveAssociations: true}). - Unscoped(). - Select("PostLinks", "PostImages"). - Delete(&postData) - req, err := http.NewRequest("DELETE", fmt.Sprintf( "%s/post/%s/image/%s", ts.URL, diff --git a/Api/Posts.go b/Api/Posts.go index e54af45..c39adb5 100644 --- a/Api/Posts.go +++ b/Api/Posts.go @@ -11,8 +11,6 @@ import ( "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" - - "github.com/gorilla/mux" ) func getPosts(w http.ResponseWriter, r *http.Request) { @@ -131,6 +129,7 @@ func createPost(w http.ResponseWriter, r *http.Request) { "audios", }, false) if err != nil { + panic(err) log.Printf("Invalid data provided to posts API: %s\n", err.Error()) JsonReturn(w, 405, "Invalid data") return @@ -158,15 +157,12 @@ func updatePost(w http.ResponseWriter, r *http.Request) { postData Models.Post requestBody []byte returnJson []byte - urlVars map[string]string id string - ok bool err error ) - urlVars = mux.Vars(r) - id, ok = urlVars["postID"] - if !ok { + id, err = getPostId(r) + if err != nil { log.Printf("Error encountered getting id\n") JsonReturn(w, 500, "An error occured") return diff --git a/Api/Posts_test.go b/Api/Posts_test.go index d0ce2d2..3f381a0 100644 --- a/Api/Posts_test.go +++ b/Api/Posts_test.go @@ -3,8 +3,6 @@ package Api import ( "encoding/json" "fmt" - "io/ioutil" - "log" "net/http" "net/http/httptest" "os" @@ -17,7 +15,6 @@ import ( "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" "github.com/gorilla/mux" - "gorm.io/gorm" ) var ( @@ -33,12 +30,34 @@ func init() { panic(err) } - log.SetOutput(ioutil.Discard) - Database.Init() + Database.InitTest() r = mux.NewRouter() } +func createTestPost() (Models.Post, error) { + + userData, err := createTestUser(true) + + postData := Models.Post{ + UserID: userData.ID, + Title: "Test post", + Content: "Test content", + FrontPage: true, + Order: 1, + PostLinks: []Models.PostLink{ + { + Type: "Facebook", + Link: "http://facebook.com/", + }, + }, + } + + err = Database.CreatePost(&postData) + return postData, err + +} + func Test_getPosts(t *testing.T) { t.Log("Testing getPosts...") @@ -49,29 +68,7 @@ func Test_getPosts(t *testing.T) { var err error for i := 0; i < 20; i++ { - postData := Models.Post{ - Title: "Test post", - Content: "Test content", - FrontPage: true, - Order: i, - PostLinks: []Models.PostLink{ - { - Type: "Facebook", - Link: "http://facebook.com/", - }, - }, - } - - err = Database.CreatePost(&postData) - if err != nil { - t.Errorf("Expected nil, recieved %s", err.Error()) - } - - defer Database.DB. - Session(&gorm.Session{FullSaveAssociations: true}). - Unscoped(). - Select("PostLinks"). - Delete(&postData) + createTestPost() } res, err := http.Get(ts.URL + "/post?page=1&pageSize=10") @@ -103,30 +100,12 @@ func Test_getPost(t *testing.T) { defer ts.Close() - postData := Models.Post{ - Title: "Test post", - Content: "Test content", - FrontPage: true, - Order: 1, - PostLinks: []Models.PostLink{ - { - Type: "Facebook", - Link: "http://facebook.com/", - }, - }, - } - - err := Database.CreatePost(&postData) + postData, err := createTestPost() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() } - defer Database.DB. - Session(&gorm.Session{FullSaveAssociations: true}). - Unscoped(). - Select("PostLinks"). - Delete(&postData) - res, err := http.Get(fmt.Sprintf( "%s/post/%s", ts.URL, @@ -135,16 +114,36 @@ func Test_getPost(t *testing.T) { if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() } + if res.StatusCode != http.StatusOK { t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + t.FailNow() } getPostData := new(Models.Post) err = json.NewDecoder(res.Body).Decode(getPostData) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + + if getPostData.Title != "Test post" { + t.Errorf("Expected title \"Test post\", recieved %s", getPostData.Title) + t.FailNow() } + + if getPostData.Content != "Test content" { + t.Errorf("Expected content \"Test content\", recieved %s", getPostData.Content) + t.FailNow() + } + + if len(getPostData.PostLinks) != 1 { + t.Errorf("Expected len(PostLinks) == 1, recieved %d", len(getPostData.PostLinks)) + t.FailNow() + } + } func Test_createPost(t *testing.T) { @@ -156,8 +155,11 @@ func Test_createPost(t *testing.T) { defer ts.Close() + userData, err := createTestUser(true) + postJson := ` { + "user_id": "%s", "title": "Test post", "content": "Test content", "front_page": true, @@ -169,6 +171,8 @@ func Test_createPost(t *testing.T) { } ` + postJson = fmt.Sprintf(postJson, userData.ID.String()) + res, err := http.Post(ts.URL+"/post", "application/json", strings.NewReader(postJson)) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) @@ -183,12 +187,6 @@ func Test_createPost(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } - defer Database.DB. - Session(&gorm.Session{FullSaveAssociations: true}). - Unscoped(). - Select("PostLinks"). - Delete(&postData) - if postData.Title != "Test post" { t.Errorf("Expected title \"Test post\", recieved \"%s\"", postData.Title) } @@ -206,30 +204,12 @@ func Test_deletePost(t *testing.T) { defer ts.Close() - postData := Models.Post{ - Title: "Test post", - Content: "Test content", - FrontPage: true, - Order: 1, - PostLinks: []Models.PostLink{ - { - Type: "Facebook", - Link: "http://facebook.com/", - }, - }, - } - - err := Database.CreatePost(&postData) + postData, err := createTestPost() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() } - defer Database.DB. - Session(&gorm.Session{FullSaveAssociations: true}). - Unscoped(). - Select("PostLinks"). - Delete(&postData) - req, err := http.NewRequest("DELETE", fmt.Sprintf( "%s/post/%s", ts.URL, @@ -262,20 +242,7 @@ func Test_updatePost(t *testing.T) { defer ts.Close() - postData := Models.Post{ - Title: "Test post", - Content: "Test content", - FrontPage: true, - Order: 1, - PostLinks: []Models.PostLink{ - { - Type: "Facebook", - Link: "http://facebook.com/", - }, - }, - } - - err := Database.CreatePost(&postData) + postData, err := createTestPost() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) } @@ -315,12 +282,6 @@ func Test_updatePost(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } - defer Database.DB. - Session(&gorm.Session{FullSaveAssociations: true}). - Unscoped(). - Select("PostLinks"). - Delete(&postData) - if updatePostData.Content != "New test content" { t.Errorf("Expected \"New test content\", recieved %s", updatePostData.Content) } diff --git a/Api/Routes.go b/Api/Routes.go index 4ce2674..366d8cd 100644 --- a/Api/Routes.go +++ b/Api/Routes.go @@ -17,17 +17,22 @@ func InitApiEndpoints() *mux.Router { // Define routes for posts api router.HandleFunc("/post", getPosts).Methods("GET") - router.HandleFunc("/frontPagePosts", getFrontPagePosts).Methods("GET") router.HandleFunc("/post", createPost).Methods("POST") - router.HandleFunc("/post/{postID}", createPost).Methods("GET") + router.HandleFunc("/post/{postID}", getPost).Methods("GET") router.HandleFunc("/post/{postID}", updatePost).Methods("PUT") router.HandleFunc("/post/{postID}", deletePost).Methods("DELETE") + router.HandleFunc("/frontPagePosts", getFrontPagePosts).Methods("GET") + router.HandleFunc("/post/{postID}/image", createPostImage).Methods("POST") router.HandleFunc("/post/{postID}/image/{imageID}", deletePostImage).Methods("DELETE") // Define routes for users api + router.HandleFunc("/user", getUsers).Methods("GET") router.HandleFunc("/user", createUser).Methods("POST") + router.HandleFunc("/user/{userID}", getUser).Methods("GET") + router.HandleFunc("/user/{userID}", updatePost).Methods("PUT") + router.HandleFunc("/user/{userID}", deletePost).Methods("DELETE") //router.PathPrefix("/").Handler(http.StripPrefix("/images/", http.FileServer(http.Dir("./uploads")))) diff --git a/Api/UserHelper.go b/Api/UserHelper.go new file mode 100644 index 0000000..7a658ee --- /dev/null +++ b/Api/UserHelper.go @@ -0,0 +1,51 @@ +package Api + +import ( + "errors" + "log" + "net/http" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + + "github.com/gorilla/mux" +) + +func getUserId(r *http.Request) (string, error) { + var ( + urlVars map[string]string + id string + ok bool + ) + + urlVars = mux.Vars(r) + id, ok = urlVars["userID"] + if !ok { + return id, errors.New("Could not get id") + } + return id, nil +} + +func getUserById(w http.ResponseWriter, r *http.Request) (Models.User, error) { + var ( + postData Models.User + id string + err error + ) + + id, err = getUserId(r) + if err != nil { + log.Printf("Error encountered getting id\n") + JsonReturn(w, 500, "An error occured") + return postData, err + } + + postData, err = Database.GetUserById(id) + if err != nil { + log.Printf("Could not find pet with id %s\n", id) + JsonReturn(w, 404, "Not found") + return postData, err + } + + return postData, nil +} diff --git a/Api/Users.go b/Api/Users.go index c343b69..33f035d 100644 --- a/Api/Users.go +++ b/Api/Users.go @@ -8,6 +8,7 @@ import ( "net/url" "strconv" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/Auth" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" @@ -56,6 +57,29 @@ func getUsers(w http.ResponseWriter, r *http.Request) { w.Write(returnJson) } +func getUser(w http.ResponseWriter, r *http.Request) { + var ( + userData Models.User + returnJson []byte + err error + ) + + userData, err = getUserById(w, r) + if err != nil { + return + } + + returnJson, err = json.MarshalIndent(userData, "", " ") + if err != nil { + JsonReturn(w, 500, "An error occured") + return + } + + // Return updated json + w.WriteHeader(http.StatusOK) + w.Write(returnJson) +} + func createUser(w http.ResponseWriter, r *http.Request) { var ( userData Models.User @@ -87,16 +111,93 @@ func createUser(w http.ResponseWriter, r *http.Request) { } if userData.Password != userData.ConfirmPassword { - JsonReturn(w, 500, "invalid_password") + JsonReturn(w, 405, "invalid_password") + return + } + + userData.Password, err = Auth.HashPassword(userData.Password) + if err != nil { + JsonReturn(w, 500, "An error occured") return } err = Database.CreateUser(&userData) if err != nil { + JsonReturn(w, 500, "An error occured") + return + } + + // Return updated json + w.WriteHeader(http.StatusOK) +} + +func updateUser(w http.ResponseWriter, r *http.Request) { + var ( + userData Models.User + requestBody []byte + returnJson []byte + id string + err error + ) + + id, err = getUserId(r) + if err != nil { + log.Printf("Error encountered getting id\n") + JsonReturn(w, 500, "An error occured") + return + } + + requestBody, err = ioutil.ReadAll(r.Body) + if err != nil { + log.Printf("Error encountered reading POST body: %s\n", err.Error()) + JsonReturn(w, 500, "An error occured") + return + } + + userData, err = JsonSerialization.DeserializeUser(requestBody, []string{}, true) + if err != nil { + log.Printf("Invalid data provided to users API: %s\n", err.Error()) JsonReturn(w, 405, "Invalid data") return } + err = Database.UpdateUser(id, &userData) + if err != nil { + log.Printf("An error occured: %s\n", err.Error()) + JsonReturn(w, 500, "An error occured") + return + } + + returnJson, err = json.MarshalIndent(userData, "", " ") + if err != nil { + log.Printf("An error occured: %s\n", err.Error()) + JsonReturn(w, 500, "An error occured") + return + } + + // Return updated json + w.WriteHeader(http.StatusOK) + w.Write(returnJson) +} + +func deleteUser(w http.ResponseWriter, r *http.Request) { + var ( + userData Models.User + err error + ) + + userData, err = getUserById(w, r) + if err != nil { + return + } + + err = Database.DeleteUser(&userData) + if err != nil { + log.Printf("An error occured: %s\n", err.Error()) + JsonReturn(w, 500, "An error occured") + return + } + // Return updated json w.WriteHeader(http.StatusOK) } diff --git a/Api/Users_test.go b/Api/Users_test.go index 8007c12..162e83f 100644 --- a/Api/Users_test.go +++ b/Api/Users_test.go @@ -3,8 +3,6 @@ package Api import ( "encoding/json" "fmt" - "io/ioutil" - "log" "math/rand" "net/http" "net/http/httptest" @@ -13,11 +11,12 @@ import ( "runtime" "strings" "testing" + "time" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/Auth" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" "github.com/gorilla/mux" - "gorm.io/gorm" ) func init() { @@ -29,15 +28,14 @@ func init() { panic(err) } - log.SetOutput(ioutil.Discard) - Database.Init() + Database.InitTest() r = mux.NewRouter() } var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") -func RandStringRunes(n int) string { +func randString(n int) string { b := make([]rune, n) for i := range b { b[i] = letterRunes[rand.Intn(len(letterRunes))] @@ -45,6 +43,84 @@ func RandStringRunes(n int) string { return string(b) } +func createTestUser(random bool) (Models.User, error) { + now := time.Now() + + email := "email@email.com" + if random { + email = fmt.Sprintf("%s@email.com", randString(16)) + } + + password, err := Auth.HashPassword("password") + if err != nil { + return Models.User{}, err + } + + userData := Models.User{ + Email: email, + Password: password, + LastLogin: &now, + FirstName: "Hugh", + LastName: "Mann", + } + + err = Database.CreateUser(&userData) + return userData, err +} + +func Test_getUser(t *testing.T) { + t.Log("Testing getPost...") + + r.HandleFunc("/user/{userID}", getUser).Methods("GET") + + ts := httptest.NewServer(r) + + defer ts.Close() + + userData, err := createTestUser(false) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + + res, err := http.Get(fmt.Sprintf( + "%s/user/%s", + ts.URL, + userData.ID, + )) + + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + t.FailNow() + } + + getUserData := new(Models.User) + err = json.NewDecoder(res.Body).Decode(getUserData) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + + if getUserData.Email != "email@email.com" { + t.Errorf("Expected email \"email@email.com\", recieved %s", getUserData.Email) + t.FailNow() + } + + if getUserData.FirstName != "Hugh" { + t.Errorf("Expected email \"Hugh\", recieved %s", getUserData.FirstName) + t.FailNow() + } + + if getUserData.LastName != "Mann" { + t.Errorf("Expected email \"Mann\", recieved %s", getUserData.LastName) + t.FailNow() + } +} + func Test_getUsers(t *testing.T) { t.Log("Testing getUsers...") @@ -55,43 +131,30 @@ func Test_getUsers(t *testing.T) { var err error for i := 0; i < 20; i++ { - userData := Models.User{ - Email: fmt.Sprintf( - "%s@email.com", - RandStringRunes(16), - ), - Password: "password", - ConfirmPassword: "password", - } - - err = Database.CreateUser(&userData) - if err != nil { - t.Errorf("Expected nil, recieved %s", err.Error()) - } - - defer Database.DB. - Session(&gorm.Session{FullSaveAssociations: true}). - Unscoped(). - Delete(&userData) + createTestUser(true) } res, err := http.Get(ts.URL + "/user?page=1&pageSize=10") if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() } if res.StatusCode != http.StatusOK { t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + t.FailNow() } getUsersData := new([]Models.User) err = json.NewDecoder(res.Body).Decode(getUsersData) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() } if len(*getUsersData) != 10 { t.Errorf("Expected 10, recieved %d", len(*getUsersData)) + t.FailNow() } } @@ -104,15 +167,18 @@ func Test_createUser(t *testing.T) { defer ts.Close() + email := fmt.Sprintf("%s@email.com", randString(16)) + postJson := ` { - "email": "email@email.com", + "email": "%s", "password": "password", "confirm_password": "password", "first_name": "Hugh", "last_name": "Mann" } ` + postJson = fmt.Sprintf(postJson, email) res, err := http.Post(ts.URL+"/user", "application/json", strings.NewReader(postJson)) if err != nil { @@ -124,9 +190,107 @@ func Test_createUser(t *testing.T) { t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) return } +} + +func Test_updateUser(t *testing.T) { + t.Log("Testing updateUser...") + + r.HandleFunc("/user/{userID}", updateUser).Methods("PUT") + + ts := httptest.NewServer(r) + + defer ts.Close() + + userData, err := createTestUser(true) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + } + + email := fmt.Sprintf("%s@email.com", randString(16)) + + postJson := ` +{ + "email": "%s", + "first_name": "first", + "last_name": "last" +} +` + postJson = fmt.Sprintf(postJson, email) + + req, err := http.NewRequest("PUT", fmt.Sprintf( + "%s/user/%s", + ts.URL, + userData.ID, + ), strings.NewReader(postJson)) + + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + } + + // Fetch Request + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + } + + updateUserData := new(Models.User) + err = json.NewDecoder(res.Body).Decode(updateUserData) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + } + + if updateUserData.Email != email { + t.Errorf("Expected email \"%s\", recieved %s", email, updateUserData.Email) + } + + if updateUserData.FirstName != "first" { + t.Errorf("Expected FirstName \"first\", recieved %s", updateUserData.FirstName) + } + + if updateUserData.LastName != "last" { + t.Errorf("Expected LastName \"last\", recieved %s", updateUserData.LastName) + } +} + +func Test_deleteUser(t *testing.T) { + t.Log("Testing deleteUser...") + + r.HandleFunc("/user/{userID}", deleteUser).Methods("DELETE") + + ts := httptest.NewServer(r) + + defer ts.Close() + + userData, err := createTestUser(true) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + + req, err := http.NewRequest("DELETE", fmt.Sprintf( + "%s/user/%s", + ts.URL, + userData.ID, + ), nil) + + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + } + + // Fetch Request + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + defer res.Body.Close() - Database.DB.Model(Models.User{}). - Select("count(*) > 0"). - Where("email = ?", "email@email.com"). - Delete(Models.User{}) + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + } } diff --git a/Database/Init.go b/Database/Init.go index 44c1930..20f5f92 100644 --- a/Database/Init.go +++ b/Database/Init.go @@ -10,14 +10,30 @@ import ( ) const dbUrl = "postgres://postgres:@localhost:5432/sudden_impact_records" +const dbTestUrl = "postgres://postgres:@localhost:5432/sudden_impact_records_test" var ( DB *gorm.DB ) +func GetModels() []interface{} { + return []interface{}{ + &Models.User{}, + &Models.PostImage{}, + &Models.PostVideo{}, + &Models.PostAudio{}, + &Models.PostLink{}, + &Models.Post{}, + &Models.SubscriptionEmailAttachment{}, + &Models.SubscriptionEmail{}, + &Models.Subscription{}, + } +} + func Init() { var ( - err error + model interface{} + err error ) log.Println("Initializing database...") @@ -28,22 +44,27 @@ func Init() { log.Fatalln(err) } - log.Println("Running AutoMigrate on Post tables...") + log.Println("Running AutoMigrate...") + + for _, model = range GetModels() { + DB.AutoMigrate(model) + } +} - // Post tables - DB.AutoMigrate(&Models.PostImage{}) - DB.AutoMigrate(&Models.PostVideo{}) - DB.AutoMigrate(&Models.PostAudio{}) - DB.AutoMigrate(&Models.PostLink{}) - DB.AutoMigrate(&Models.Post{}) +func InitTest() { + var ( + model interface{} + err error + ) - log.Println("Running AutoMigrate on Subscription tables...") + DB, err = gorm.Open(postgres.Open(dbTestUrl), &gorm.Config{}) - // Email subscription tables - DB.AutoMigrate(&Models.SubscriptionEmailAttachment{}) - DB.AutoMigrate(&Models.SubscriptionEmail{}) - DB.AutoMigrate(&Models.Subscription{}) + if err != nil { + log.Fatalln(err) + } - log.Println("Running AutoMigrate on User tables...") - DB.AutoMigrate(&Models.User{}) + for _, model = range GetModels() { + DB.Migrator().DropTable(model) + DB.AutoMigrate(model) + } } diff --git a/Database/Users.go b/Database/Users.go index e28b898..3f9d0a9 100644 --- a/Database/Users.go +++ b/Database/Users.go @@ -6,11 +6,28 @@ import ( "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" "gorm.io/gorm" + "gorm.io/gorm/clause" ) +func GetUserById(id string) (Models.User, error) { + var ( + userData Models.User + err error + ) + + err = DB.Preload(clause.Associations). + First(&userData, "id = ?", id). + Error + + userData.Password = "" + + return userData, err +} + func GetUsers(page, pageSize int) ([]Models.User, error) { var ( users []Models.User + i int err error ) @@ -30,6 +47,10 @@ func GetUsers(page, pageSize int) ([]Models.User, error) { Find(&users). Error + for i, _ = range users { + users[i].Password = "" + } + return users, err } @@ -57,7 +78,37 @@ func CheckUniqueEmail(email string) error { } func CreateUser(userData *Models.User) error { - return DB.Session(&gorm.Session{FullSaveAssociations: true}). + var ( + err error + ) + + err = DB.Session(&gorm.Session{FullSaveAssociations: true}). Create(userData). Error + + userData.Password = "" + + return err +} + +func UpdateUser(id string, userData *Models.User) error { + var ( + err error + ) + err = DB.Model(&Models.Post{}). + Select("*"). + Omit("id", "created_at", "updated_at", "deleted_at"). + Where("id = ?", id). + Updates(userData). + Error + + userData.Password = "" + + return err +} + +func DeleteUser(userData *Models.User) error { + return DB.Session(&gorm.Session{FullSaveAssociations: true}). + Delete(userData). + Error } diff --git a/Models/Posts.go b/Models/Posts.go index a7c5e0d..85760e7 100644 --- a/Models/Posts.go +++ b/Models/Posts.go @@ -6,10 +6,11 @@ import ( type Post struct { Base - Title string `gorm:"not null" json:"title"` - Content string `gorm:"not null" json:"content"` - FrontPage bool `gorm:"not null;type:boolean" json:"front_page"` - Order int `gorm:"not null" json:"order"` + UserID uuid.UUID `gorm:"type:uuid;column:user_id;not null;" json:"user_id"` + Title string `gorm:"not null" json:"title"` + Content string `gorm:"not null" json:"content"` + FrontPage bool `gorm:"not null;type:boolean" json:"front_page"` + Order int `gorm:"not null" json:"order"` PostLinks []PostLink `json:"links"` PostImages []PostImage `json:"images"` diff --git a/Models/Users.go b/Models/Users.go index 27401e4..ff3bc83 100644 --- a/Models/Users.go +++ b/Models/Users.go @@ -7,7 +7,7 @@ import ( type User struct { Base Email string `gorm:"not null;unique" json:"email"` - Password string `gorm:"not null" json:"password"` + Password string `gorm:"not null" json:"password,omitempty"` ConfirmPassword string `gorm:"-" json:"confirm_password"` LastLogin *time.Time `json:"last_login"` FirstName string `gorm:"not null" json:"first_name"` -- 2.17.1 From 946c4913fa72cb0be6a2ee7bdf405d5c67348b2b Mon Sep 17 00:00:00 2001 From: Tovi Jaeschke-Rogers Date: Mon, 21 Mar 2022 04:03:20 +1030 Subject: [PATCH 3/5] Add Login and Logout routes --- Api/Auth/Login.go | 65 ++++++++++++++++++++++++ Api/Auth/Login_test.go | 111 +++++++++++++++++++++++++++++++++++++++++ Api/Auth/Logout.go | 34 +++++++++++++ Api/Auth/Session.go | 51 +++++++++++++++++++ Api/Routes.go | 6 +++ Api/Users_test.go | 2 +- Database/Users.go | 13 +++++ 7 files changed, 281 insertions(+), 1 deletion(-) create mode 100644 Api/Auth/Login.go create mode 100644 Api/Auth/Login_test.go create mode 100644 Api/Auth/Logout.go create mode 100644 Api/Auth/Session.go diff --git a/Api/Auth/Login.go b/Api/Auth/Login.go new file mode 100644 index 0000000..5d42e5c --- /dev/null +++ b/Api/Auth/Login.go @@ -0,0 +1,65 @@ +package Auth + +import ( + "encoding/json" + "net/http" + "time" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + + "github.com/gofrs/uuid" +) + +type Credentials struct { + Email string `json:"email"` + Password string `json:"password"` +} + +func Login(w http.ResponseWriter, r *http.Request) { + var ( + creds Credentials + userData Models.User + sessionToken uuid.UUID + expiresAt time.Time + err error + ) + + err = json.NewDecoder(r.Body).Decode(&creds) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + userData, err = Database.GetUserByEmail(creds.Email) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + if !CheckPasswordHash(creds.Password, userData.Password) { + w.WriteHeader(http.StatusUnauthorized) + return + } + + sessionToken, err = uuid.NewV4() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + expiresAt = time.Now().Add(1 * time.Hour) + + Sessions[sessionToken.String()] = Session{ + Username: userData.Email, + Expiry: expiresAt, + } + + http.SetCookie(w, &http.Cookie{ + Name: "session_token", + Value: sessionToken.String(), + Expires: expiresAt, + }) + + w.WriteHeader(http.StatusOK) +} diff --git a/Api/Auth/Login_test.go b/Api/Auth/Login_test.go new file mode 100644 index 0000000..5861107 --- /dev/null +++ b/Api/Auth/Login_test.go @@ -0,0 +1,111 @@ +package Auth + +import ( + "fmt" + "math/rand" + "net/http" + "net/http/httptest" + "os" + "path" + "runtime" + "strings" + "testing" + "time" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + + "github.com/gorilla/mux" +) + +var ( + r *mux.Router + letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +) + +func init() { + // Fix working directory for tests + _, filename, _, _ := runtime.Caller(0) + dir := path.Join(path.Dir(filename), "..") + err := os.Chdir(dir) + if err != nil { + panic(err) + } + + Database.InitTest() + + r = mux.NewRouter() +} + +func randString(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} + +func createTestUser(random bool) (Models.User, error) { + now := time.Now() + + email := "email@email.com" + if random { + email = fmt.Sprintf("%s@email.com", randString(16)) + } + + password, err := HashPassword("password") + if err != nil { + return Models.User{}, err + } + + userData := Models.User{ + Email: email, + Password: password, + LastLogin: &now, + FirstName: "Hugh", + LastName: "Mann", + } + + err = Database.CreateUser(&userData) + return userData, err +} + +func Test_Login(t *testing.T) { + t.Log("Testing Login...") + + r.HandleFunc("/admin/login", Login).Methods("POST") + + ts := httptest.NewServer(r) + + defer ts.Close() + + userData, err := createTestUser(true) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + + postJson := ` +{ + "email": "%s", + "password": "password" +} +` + postJson = fmt.Sprintf(postJson, userData.Email) + + res, err := http.Post(ts.URL+"/admin/login", "application/json", strings.NewReader(postJson)) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + return + } + + if len(res.Cookies()) != 1 { + t.Errorf("Expected cookies len 1, recieved %d", len(res.Cookies())) + return + } +} diff --git a/Api/Auth/Logout.go b/Api/Auth/Logout.go new file mode 100644 index 0000000..822b21d --- /dev/null +++ b/Api/Auth/Logout.go @@ -0,0 +1,34 @@ +package Auth + +import ( + "net/http" + "time" +) + +func Logout(w http.ResponseWriter, r *http.Request) { + var ( + c *http.Cookie + sessionToken string + err error + ) + + c, err = r.Cookie("session_token") + if err != nil { + if err == http.ErrNoCookie { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.WriteHeader(http.StatusBadRequest) + return + } + + sessionToken = c.Value + + delete(Sessions, sessionToken) + + http.SetCookie(w, &http.Cookie{ + Name: "session_token", + Value: "", + Expires: time.Now(), + }) +} diff --git a/Api/Auth/Session.go b/Api/Auth/Session.go new file mode 100644 index 0000000..3e2c23f --- /dev/null +++ b/Api/Auth/Session.go @@ -0,0 +1,51 @@ +package Auth + +import ( + "errors" + "net/http" + "time" +) + +var ( + Sessions = map[string]Session{} +) + +type Session struct { + Username string + Expiry time.Time +} + +func (s Session) IsExpired() bool { + return s.Expiry.Before(time.Now()) +} + +func CheckCookie(r *http.Request) (Session, error) { + var ( + c *http.Cookie + sessionToken string + userSession Session + exists bool + err error + ) + + c, err = r.Cookie("session_token") + if err != nil { + return userSession, err + } + sessionToken = c.Value + + // We then get the session from our session map + userSession, exists = Sessions[sessionToken] + if !exists { + return userSession, errors.New("Cookie not found") + } + + // If the session is present, but has expired, we can delete the session, and return + // an unauthorized status + if userSession.IsExpired() { + delete(Sessions, sessionToken) + return userSession, errors.New("Cookie expired") + } + + return userSession, nil +} diff --git a/Api/Routes.go b/Api/Routes.go index 366d8cd..e29633a 100644 --- a/Api/Routes.go +++ b/Api/Routes.go @@ -3,6 +3,8 @@ package Api import ( "log" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/Auth" + "github.com/gorilla/mux" ) @@ -34,6 +36,10 @@ func InitApiEndpoints() *mux.Router { router.HandleFunc("/user/{userID}", updatePost).Methods("PUT") router.HandleFunc("/user/{userID}", deletePost).Methods("DELETE") + // Define routes for authentication + router.HandleFunc("/admin/login", Auth.Login).Methods("POST") + router.HandleFunc("/admin/logout", Auth.Logout).Methods("GET") + //router.PathPrefix("/").Handler(http.StripPrefix("/images/", http.FileServer(http.Dir("./uploads")))) return router diff --git a/Api/Users_test.go b/Api/Users_test.go index 162e83f..64cb6e7 100644 --- a/Api/Users_test.go +++ b/Api/Users_test.go @@ -69,7 +69,7 @@ func createTestUser(random bool) (Models.User, error) { } func Test_getUser(t *testing.T) { - t.Log("Testing getPost...") + t.Log("Testing getUser...") r.HandleFunc("/user/{userID}", getUser).Methods("GET") diff --git a/Database/Users.go b/Database/Users.go index 3f9d0a9..0289823 100644 --- a/Database/Users.go +++ b/Database/Users.go @@ -24,6 +24,19 @@ func GetUserById(id string) (Models.User, error) { return userData, err } +func GetUserByEmail(email string) (Models.User, error) { + var ( + userData Models.User + err error + ) + + err = DB.Preload(clause.Associations). + First(&userData, "email = ?", email). + Error + + return userData, err +} + func GetUsers(page, pageSize int) ([]Models.User, error) { var ( users []Models.User -- 2.17.1 From d584d40a5278aeb412ad84201a56d60209ca0e4f Mon Sep 17 00:00:00 2001 From: Tovi Jaeschke-Rogers Date: Mon, 21 Mar 2022 05:03:05 +1030 Subject: [PATCH 4/5] Add authentication to all required endpoints --- Api/Auth/ChangePassword.go | 53 ++++++++++++++ Api/Auth/Login.go | 5 +- Api/Auth/Session.go | 32 ++++++++- Api/JsonSerialization/VerifyJson.go | 50 ++++++++----- Api/PostImages.go | 22 +++--- Api/Posts.go | 61 ++++++++++------ Api/Posts_test.go | 37 +++++++++- Api/Users.go | 74 ++++++++++++-------- Api/Users_test.go | 104 ++++++++++++++++++++++++---- Database/Users.go | 2 +- {Api => Util}/PostHelper.go | 8 +-- {Api => Util}/PostImageHelper.go | 8 +-- {Api => Util}/ReturnJson.go | 2 +- {Api => Util}/UserHelper.go | 8 +-- 14 files changed, 353 insertions(+), 113 deletions(-) create mode 100644 Api/Auth/ChangePassword.go rename {Api => Util}/PostHelper.go (84%) rename {Api => Util}/PostImageHelper.go (84%) rename {Api => Util}/ReturnJson.go (97%) rename {Api => Util}/UserHelper.go (84%) diff --git a/Api/Auth/ChangePassword.go b/Api/Auth/ChangePassword.go new file mode 100644 index 0000000..9187665 --- /dev/null +++ b/Api/Auth/ChangePassword.go @@ -0,0 +1,53 @@ +package Auth + +import ( + "encoding/json" + "net/http" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" +) + +type ChangePassword struct { + Password string `json:"password"` + ConfirmPassword string `json:"confirm_password"` +} + +func UpdatePassword(w http.ResponseWriter, r *http.Request) { + var ( + changePasswd ChangePassword + userData Models.User + err error + ) + + userData, err = CheckCookieCurrentUser(w, r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + err = json.NewDecoder(r.Body).Decode(&changePasswd) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + if changePasswd.Password != changePasswd.ConfirmPassword { + w.WriteHeader(http.StatusBadRequest) + return + } + + userData.Password, err = HashPassword(changePasswd.Password) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + err = Database.UpdateUser(userData.ID.String(), &userData) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) +} diff --git a/Api/Auth/Login.go b/Api/Auth/Login.go index 5d42e5c..0ae4a0a 100644 --- a/Api/Auth/Login.go +++ b/Api/Auth/Login.go @@ -51,8 +51,9 @@ func Login(w http.ResponseWriter, r *http.Request) { expiresAt = time.Now().Add(1 * time.Hour) Sessions[sessionToken.String()] = Session{ - Username: userData.Email, - Expiry: expiresAt, + UserID: userData.ID.String(), + Email: userData.Email, + Expiry: expiresAt, } http.SetCookie(w, &http.Cookie{ diff --git a/Api/Auth/Session.go b/Api/Auth/Session.go index 3e2c23f..b647376 100644 --- a/Api/Auth/Session.go +++ b/Api/Auth/Session.go @@ -4,6 +4,9 @@ import ( "errors" "net/http" "time" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Util" ) var ( @@ -11,8 +14,9 @@ var ( ) type Session struct { - Username string - Expiry time.Time + UserID string + Email string + Expiry time.Time } func (s Session) IsExpired() bool { @@ -49,3 +53,27 @@ func CheckCookie(r *http.Request) (Session, error) { return userSession, nil } + +func CheckCookieCurrentUser(w http.ResponseWriter, r *http.Request) (Models.User, error) { + var ( + userSession Session + userData Models.User + err error + ) + + userSession, err = CheckCookie(r) + if err != nil { + return userData, err + } + + userData, err = Util.GetUserById(w, r) + if err != nil { + return userData, err + } + + if userData.ID.String() != userSession.UserID { + return userData, errors.New("Is not current user") + } + + return userData, nil +} diff --git a/Api/JsonSerialization/VerifyJson.go b/Api/JsonSerialization/VerifyJson.go index 2ad1554..862e7c8 100644 --- a/Api/JsonSerialization/VerifyJson.go +++ b/Api/JsonSerialization/VerifyJson.go @@ -7,7 +7,11 @@ import ( // isIntegerType returns whether the type is an integer and if it's unsigned. // See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L328 -func isIntegerType(t reflect.Type) (yes bool, unsigned bool) { +func isIntegerType(t reflect.Type) (bool, bool) { + var ( + yes bool + unsigned bool + ) switch t.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: yes = true @@ -16,19 +20,22 @@ func isIntegerType(t reflect.Type) (yes bool, unsigned bool) { unsigned = true } - return + return yes, unsigned } // isFloatType returns true if the type is a floating point. Note that this doesn't // care about the value -- unmarshaling the number "0" gives a float, not an int. // See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L319 -func isFloatType(t reflect.Type) (yes bool) { +func isFloatType(t reflect.Type) bool { + var ( + yes bool + ) switch t.Kind() { case reflect.Float32, reflect.Float64: yes = true } - return + return yes } // CanConvert returns whether value v is convertible to type t. @@ -38,10 +45,20 @@ func isFloatType(t reflect.Type) (yes bool) { // Modified due to not handling slices (DefaultCanConvert fails on PhotoUrls and Tags) // See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L191 func CanConvert(t reflect.Type, v reflect.Value) bool { - isPtr := t.Kind() == reflect.Ptr - isStruct := t.Kind() == reflect.Struct - isArray := t.Kind() == reflect.Array - dstType := t + var ( + isPtr bool + isStruct bool + isArray bool + dstType reflect.Type + dstInt bool + unsigned bool + f float64 + srcInt bool + ) + isPtr = t.Kind() == reflect.Ptr + isStruct = t.Kind() == reflect.Struct + isArray = t.Kind() == reflect.Array + dstType = t // Check if v is a nil value. if !v.IsValid() || (v.CanAddr() && v.IsNil()) { @@ -72,20 +89,19 @@ func CanConvert(t reflect.Type, v reflect.Value) bool { } // Handle converting to an integer type. - if dstInt, unsigned := isIntegerType(dstType); dstInt { + dstInt, unsigned = isIntegerType(dstType) + if dstInt { if isFloatType(v.Type()) { - f := v.Float() + f = v.Float() - if math.Trunc(f) != f { - return false - } else if unsigned && f < 0 { - return false - } - } else if srcInt, _ := isIntegerType(v.Type()); srcInt { - if unsigned && v.Int() < 0 { + if math.Trunc(f) != f || unsigned && f < 0 { return false } } + srcInt, _ = isIntegerType(v.Type()) + if srcInt && unsigned && v.Int() < 0 { + return false + } } return true diff --git a/Api/PostImages.go b/Api/PostImages.go index 62cc710..dea0442 100644 --- a/Api/PostImages.go +++ b/Api/PostImages.go @@ -30,10 +30,10 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { err error ) - postID, err = getPostId(r) + postID, err = Util.GetPostId(r) if err != nil { log.Printf("Error encountered getting id\n") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -42,7 +42,7 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { err = r.ParseMultipartForm(20 << 20) if err != nil { log.Printf("Error encountered parsing multipart form: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -54,7 +54,7 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { file, err = fileHeader.Open() if err != nil { log.Printf("Error encountered while post image upload: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } defer file.Close() @@ -62,14 +62,14 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { fileBytes, err = ioutil.ReadAll(file) if err != nil { log.Printf("Error encountered while post image upload: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } fileObject, err = Util.WriteFile(fileBytes, "image") if err != nil { log.Printf("Error encountered while post image upload: %s\n", err.Error()) - JsonReturn(w, 415, "Invalid filetype") + Util.JsonReturn(w, 415, "Invalid filetype") return } @@ -83,19 +83,19 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { err = Database.CreatePostImage(&postImage) if err != nil { log.Printf("Error encountered while creating post_image record: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } } - postData, err = getPostById(w, r) + postData, err = Util.GetPostById(w, r) if err != nil { return } returnJson, err = json.MarshalIndent(postData, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -110,7 +110,7 @@ func deletePostImage(w http.ResponseWriter, r *http.Request) { err error ) - postImageData, err = getPostImageById(w, r) + postImageData, err = Util.GetPostImageById(w, r) if err != nil { return } @@ -118,7 +118,7 @@ func deletePostImage(w http.ResponseWriter, r *http.Request) { err = Database.DeletePostImage(&postImageData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } diff --git a/Api/Posts.go b/Api/Posts.go index c39adb5..f5aed9c 100644 --- a/Api/Posts.go +++ b/Api/Posts.go @@ -8,9 +8,11 @@ import ( "net/url" "strconv" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/Auth" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Util" ) func getPosts(w http.ResponseWriter, r *http.Request) { @@ -27,27 +29,27 @@ func getPosts(w http.ResponseWriter, r *http.Request) { page, err = strconv.Atoi(values.Get("page")) if err != nil { log.Println("Could not parse page url argument") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } page, err = strconv.Atoi(values.Get("pageSize")) if err != nil { log.Println("Could not parse pageSize url argument") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } posts, err = Database.GetPosts(page, pageSize) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(posts, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -66,13 +68,13 @@ func getFrontPagePosts(w http.ResponseWriter, r *http.Request) { posts, err = Database.GetFrontPagePosts() if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(posts, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -88,14 +90,14 @@ func getPost(w http.ResponseWriter, r *http.Request) { err error ) - postData, err = getPostById(w, r) + postData, err = Util.GetPostById(w, r) if err != nil { return } returnJson, err = json.MarshalIndent(postData, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -112,12 +114,16 @@ func createPost(w http.ResponseWriter, r *http.Request) { err error ) - // TODO: Add auth + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } requestBody, err = ioutil.ReadAll(r.Body) if err != nil { log.Printf("Error encountered reading POST body: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -129,21 +135,20 @@ func createPost(w http.ResponseWriter, r *http.Request) { "audios", }, false) if err != nil { - panic(err) log.Printf("Invalid data provided to posts API: %s\n", err.Error()) - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") return } err = Database.CreatePost(&postData) if err != nil { - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") } returnJson, err = json.MarshalIndent(postData, "", " ") if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -161,38 +166,44 @@ func updatePost(w http.ResponseWriter, r *http.Request) { err error ) - id, err = getPostId(r) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + id, err = Util.GetPostId(r) if err != nil { log.Printf("Error encountered getting id\n") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } requestBody, err = ioutil.ReadAll(r.Body) if err != nil { log.Printf("Error encountered reading POST body: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } postData, err = JsonSerialization.DeserializePost(requestBody, []string{}, true) if err != nil { log.Printf("Invalid data provided to posts API: %s\n", err.Error()) - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") return } postData, err = Database.UpdatePost(id, &postData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(postData, "", " ") if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -207,7 +218,13 @@ func deletePost(w http.ResponseWriter, r *http.Request) { err error ) - postData, err = getPostById(w, r) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + postData, err = Util.GetPostById(w, r) if err != nil { return } @@ -215,7 +232,7 @@ func deletePost(w http.ResponseWriter, r *http.Request) { err = Database.DeletePost(&postData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } diff --git a/Api/Posts_test.go b/Api/Posts_test.go index 3f381a0..6084d4a 100644 --- a/Api/Posts_test.go +++ b/Api/Posts_test.go @@ -155,7 +155,11 @@ func Test_createPost(t *testing.T) { defer ts.Close() - userData, err := createTestUser(true) + c, u, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } postJson := ` { @@ -171,14 +175,25 @@ func Test_createPost(t *testing.T) { } ` - postJson = fmt.Sprintf(postJson, userData.ID.String()) + postJson = fmt.Sprintf(postJson, u.ID.String()) + + req, err := http.NewRequest("POST", ts.URL+"/post", strings.NewReader(postJson)) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + req.AddCookie(c) - res, err := http.Post(ts.URL+"/post", "application/json", strings.NewReader(postJson)) + res, err := http.DefaultClient.Do(req) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + return } + if res.StatusCode != http.StatusOK { t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + return } postData := new(Models.Post) @@ -204,6 +219,12 @@ func Test_deletePost(t *testing.T) { defer ts.Close() + c, _, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + postData, err := createTestPost() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) @@ -220,6 +241,8 @@ func Test_deletePost(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } + req.AddCookie(c) + // Fetch Request res, err := http.DefaultClient.Do(req) if err != nil { @@ -242,6 +265,12 @@ func Test_updatePost(t *testing.T) { defer ts.Close() + c, _, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + postData, err := createTestPost() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) @@ -265,6 +294,8 @@ func Test_updatePost(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } + req.AddCookie(c) + // Fetch Request res, err := http.DefaultClient.Do(req) if err != nil { diff --git a/Api/Users.go b/Api/Users.go index 33f035d..9ce2e79 100644 --- a/Api/Users.go +++ b/Api/Users.go @@ -12,6 +12,7 @@ import ( "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Util" ) func getUsers(w http.ResponseWriter, r *http.Request) { @@ -23,32 +24,37 @@ func getUsers(w http.ResponseWriter, r *http.Request) { err error ) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + values = r.URL.Query() page, err = strconv.Atoi(values.Get("page")) if err != nil { log.Println("Could not parse page url argument") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } page, err = strconv.Atoi(values.Get("pageSize")) if err != nil { log.Println("Could not parse pageSize url argument") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } - users, err = Database.GetUsers(page, pageSize) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(users, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -64,14 +70,20 @@ func getUser(w http.ResponseWriter, r *http.Request) { err error ) - userData, err = getUserById(w, r) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + userData, err = Util.GetUserById(w, r) if err != nil { return } returnJson, err = json.MarshalIndent(userData, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -90,7 +102,7 @@ func createUser(w http.ResponseWriter, r *http.Request) { requestBody, err = ioutil.ReadAll(r.Body) if err != nil { log.Printf("Error encountered reading POST body: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -100,30 +112,30 @@ func createUser(w http.ResponseWriter, r *http.Request) { }, false) if err != nil { log.Printf("Invalid data provided to user API: %s\n", err.Error()) - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") return } err = Database.CheckUniqueEmail(userData.Email) if err != nil { - JsonReturn(w, 405, "invalid_email") + Util.JsonReturn(w, 405, "invalid_email") return } if userData.Password != userData.ConfirmPassword { - JsonReturn(w, 405, "invalid_password") + Util.JsonReturn(w, 405, "invalid_password") return } userData.Password, err = Auth.HashPassword(userData.Password) if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } err = Database.CreateUser(&userData) if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -133,45 +145,44 @@ func createUser(w http.ResponseWriter, r *http.Request) { func updateUser(w http.ResponseWriter, r *http.Request) { var ( - userData Models.User - requestBody []byte - returnJson []byte - id string - err error + currentUserData Models.User + userData Models.User + requestBody []byte + returnJson []byte + err error ) - id, err = getUserId(r) + currentUserData, err = Auth.CheckCookieCurrentUser(w, r) if err != nil { - log.Printf("Error encountered getting id\n") - JsonReturn(w, 500, "An error occured") + w.WriteHeader(http.StatusUnauthorized) return } requestBody, err = ioutil.ReadAll(r.Body) if err != nil { log.Printf("Error encountered reading POST body: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } userData, err = JsonSerialization.DeserializeUser(requestBody, []string{}, true) if err != nil { log.Printf("Invalid data provided to users API: %s\n", err.Error()) - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") return } - err = Database.UpdateUser(id, &userData) + err = Database.UpdateUser(currentUserData.ID.String(), &userData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(userData, "", " ") if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -186,15 +197,22 @@ func deleteUser(w http.ResponseWriter, r *http.Request) { err error ) - userData, err = getUserById(w, r) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + userData, err = Util.GetUserById(w, r) if err != nil { + w.WriteHeader(http.StatusNotFound) return } err = Database.DeleteUser(&userData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } diff --git a/Api/Users_test.go b/Api/Users_test.go index 64cb6e7..1125999 100644 --- a/Api/Users_test.go +++ b/Api/Users_test.go @@ -2,6 +2,7 @@ package Api import ( "encoding/json" + "errors" "fmt" "math/rand" "net/http" @@ -68,6 +69,47 @@ func createTestUser(random bool) (Models.User, error) { return userData, err } +func login() (*http.Cookie, Models.User, error) { + var ( + c *http.Cookie + u Models.User + ) + + r.HandleFunc("/admin/login", Auth.Login).Methods("POST") + + ts := httptest.NewServer(r) + + defer ts.Close() + + u, err := createTestUser(true) + if err != nil { + return c, u, err + } + + postJson := ` +{ + "email": "%s", + "password": "password" +} +` + postJson = fmt.Sprintf(postJson, u.Email) + + res, err := http.Post(ts.URL+"/admin/login", "application/json", strings.NewReader(postJson)) + if err != nil { + return c, u, err + } + + if res.StatusCode != http.StatusOK { + return c, u, errors.New("Invalid res.StatusCode") + } + + if len(res.Cookies()) != 1 { + return c, u, errors.New("Invalid cookies length") + } + + return res.Cookies()[0], u, nil +} + func Test_getUser(t *testing.T) { t.Log("Testing getUser...") @@ -77,22 +119,31 @@ func Test_getUser(t *testing.T) { defer ts.Close() - userData, err := createTestUser(false) + c, u, err := login() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) t.FailNow() } - res, err := http.Get(fmt.Sprintf( + req, err := http.NewRequest("GET", fmt.Sprintf( "%s/user/%s", ts.URL, - userData.ID, - )) + u.ID, + ), nil) + + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + + req.AddCookie(c) + res, err := http.DefaultClient.Do(req) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) t.FailNow() } + if res.StatusCode != http.StatusOK { t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) t.FailNow() @@ -105,18 +156,18 @@ func Test_getUser(t *testing.T) { t.FailNow() } - if getUserData.Email != "email@email.com" { - t.Errorf("Expected email \"email@email.com\", recieved %s", getUserData.Email) + if getUserData.Email != u.Email { + t.Errorf("Expected email \"%s\", recieved %s", u.Email, getUserData.Email) t.FailNow() } - if getUserData.FirstName != "Hugh" { - t.Errorf("Expected email \"Hugh\", recieved %s", getUserData.FirstName) + if getUserData.FirstName != u.FirstName { + t.Errorf("Expected email \"%s\", recieved %s", u.FirstName, getUserData.FirstName) t.FailNow() } - if getUserData.LastName != "Mann" { - t.Errorf("Expected email \"Mann\", recieved %s", getUserData.LastName) + if getUserData.LastName != u.LastName { + t.Errorf("Expected email \"%s\", recieved %s", u.LastName, getUserData.LastName) t.FailNow() } } @@ -129,17 +180,31 @@ func Test_getUsers(t *testing.T) { ts := httptest.NewServer(r) defer ts.Close() - var err error + c, _, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + for i := 0; i < 20; i++ { createTestUser(true) } - res, err := http.Get(ts.URL + "/user?page=1&pageSize=10") + req, err := http.NewRequest("GET", ts.URL+"/user?page=1&pageSize=10", nil) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) t.FailNow() } + + req.AddCookie(c) + + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + if res.StatusCode != http.StatusOK { t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) t.FailNow() @@ -201,9 +266,10 @@ func Test_updateUser(t *testing.T) { defer ts.Close() - userData, err := createTestUser(true) + c, u, err := login() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() } email := fmt.Sprintf("%s@email.com", randString(16)) @@ -220,13 +286,15 @@ func Test_updateUser(t *testing.T) { req, err := http.NewRequest("PUT", fmt.Sprintf( "%s/user/%s", ts.URL, - userData.ID, + u.ID, ), strings.NewReader(postJson)) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) } + req.AddCookie(c) + // Fetch Request res, err := http.DefaultClient.Do(req) if err != nil { @@ -266,6 +334,12 @@ func Test_deleteUser(t *testing.T) { defer ts.Close() + c, _, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + userData, err := createTestUser(true) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) @@ -282,6 +356,8 @@ func Test_deleteUser(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } + req.AddCookie(c) + // Fetch Request res, err := http.DefaultClient.Do(req) if err != nil { diff --git a/Database/Users.go b/Database/Users.go index 0289823..7dc8b03 100644 --- a/Database/Users.go +++ b/Database/Users.go @@ -108,7 +108,7 @@ func UpdateUser(id string, userData *Models.User) error { var ( err error ) - err = DB.Model(&Models.Post{}). + err = DB.Model(&Models.User{}). Select("*"). Omit("id", "created_at", "updated_at", "deleted_at"). Where("id = ?", id). diff --git a/Api/PostHelper.go b/Util/PostHelper.go similarity index 84% rename from Api/PostHelper.go rename to Util/PostHelper.go index b12d3fe..697ff44 100644 --- a/Api/PostHelper.go +++ b/Util/PostHelper.go @@ -1,4 +1,4 @@ -package Api +package Util import ( "errors" @@ -11,7 +11,7 @@ import ( "github.com/gorilla/mux" ) -func getPostId(r *http.Request) (string, error) { +func GetPostId(r *http.Request) (string, error) { var ( urlVars map[string]string id string @@ -26,14 +26,14 @@ func getPostId(r *http.Request) (string, error) { return id, nil } -func getPostById(w http.ResponseWriter, r *http.Request) (Models.Post, error) { +func GetPostById(w http.ResponseWriter, r *http.Request) (Models.Post, error) { var ( postData Models.Post id string err error ) - id, err = getPostId(r) + id, err = GetPostId(r) if err != nil { log.Printf("Error encountered getting id\n") JsonReturn(w, 500, "An error occured") diff --git a/Api/PostImageHelper.go b/Util/PostImageHelper.go similarity index 84% rename from Api/PostImageHelper.go rename to Util/PostImageHelper.go index 06142c4..9ee16b3 100644 --- a/Api/PostImageHelper.go +++ b/Util/PostImageHelper.go @@ -1,4 +1,4 @@ -package Api +package Util import ( "errors" @@ -10,7 +10,7 @@ import ( "github.com/gorilla/mux" ) -func getPostImageId(r *http.Request) (string, error) { +func GetPostImageId(r *http.Request) (string, error) { var ( urlVars map[string]string id string @@ -25,14 +25,14 @@ func getPostImageId(r *http.Request) (string, error) { return id, nil } -func getPostImageById(w http.ResponseWriter, r *http.Request) (Models.PostImage, error) { +func GetPostImageById(w http.ResponseWriter, r *http.Request) (Models.PostImage, error) { var ( postImageData Models.PostImage id string err error ) - id, err = getPostImageId(r) + id, err = GetPostImageId(r) if err != nil { log.Printf("Error encountered getting id\n") JsonReturn(w, 500, "An error occured") diff --git a/Api/ReturnJson.go b/Util/ReturnJson.go similarity index 97% rename from Api/ReturnJson.go rename to Util/ReturnJson.go index 002c091..747fcf2 100644 --- a/Api/ReturnJson.go +++ b/Util/ReturnJson.go @@ -1,4 +1,4 @@ -package Api +package Util import ( "encoding/json" diff --git a/Api/UserHelper.go b/Util/UserHelper.go similarity index 84% rename from Api/UserHelper.go rename to Util/UserHelper.go index 7a658ee..85df401 100644 --- a/Api/UserHelper.go +++ b/Util/UserHelper.go @@ -1,4 +1,4 @@ -package Api +package Util import ( "errors" @@ -11,7 +11,7 @@ import ( "github.com/gorilla/mux" ) -func getUserId(r *http.Request) (string, error) { +func GetUserId(r *http.Request) (string, error) { var ( urlVars map[string]string id string @@ -26,14 +26,14 @@ func getUserId(r *http.Request) (string, error) { return id, nil } -func getUserById(w http.ResponseWriter, r *http.Request) (Models.User, error) { +func GetUserById(w http.ResponseWriter, r *http.Request) (Models.User, error) { var ( postData Models.User id string err error ) - id, err = getUserId(r) + id, err = GetUserId(r) if err != nil { log.Printf("Error encountered getting id\n") JsonReturn(w, 500, "An error occured") -- 2.17.1 From d2eb1c218cd25032d49c98a4faae47e5009c8bd7 Mon Sep 17 00:00:00 2001 From: Tovi Jaeschke-Rogers Date: Mon, 21 Mar 2022 05:15:23 +1030 Subject: [PATCH 5/5] Add tests for logout and UpdatePassword --- Api/Auth/Logout_test.go | 90 ++++++++++++++++ .../{ChangePassword.go => UpdatePassword.go} | 0 Api/Auth/UpdatePassword_test.go | 100 ++++++++++++++++++ Api/Routes.go | 11 +- 4 files changed, 196 insertions(+), 5 deletions(-) create mode 100644 Api/Auth/Logout_test.go rename Api/Auth/{ChangePassword.go => UpdatePassword.go} (100%) create mode 100644 Api/Auth/UpdatePassword_test.go diff --git a/Api/Auth/Logout_test.go b/Api/Auth/Logout_test.go new file mode 100644 index 0000000..56cc7f9 --- /dev/null +++ b/Api/Auth/Logout_test.go @@ -0,0 +1,90 @@ +package Auth + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "path" + "runtime" + "strings" + "testing" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" + + "github.com/gorilla/mux" +) + +func init() { + // Fix working directory for tests + _, filename, _, _ := runtime.Caller(0) + dir := path.Join(path.Dir(filename), "..") + err := os.Chdir(dir) + if err != nil { + panic(err) + } + + Database.InitTest() + + r = mux.NewRouter() +} + +func Test_Logout(t *testing.T) { + t.Log("Testing Logout...") + + r.HandleFunc("/admin/login", Logout).Methods("POST") + r.HandleFunc("/admin/logout", Logout).Methods("GET") + + ts := httptest.NewServer(r) + + defer ts.Close() + + userData, err := createTestUser(true) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + + postJson := ` +{ + "email": "%s", + "password": "password" +} +` + postJson = fmt.Sprintf(postJson, userData.Email) + + res, err := http.Post(ts.URL+"/admin/login", "application/json", strings.NewReader(postJson)) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + return + } + + if len(res.Cookies()) != 1 { + t.Errorf("Expected cookies len 1, recieved %d", len(res.Cookies())) + return + } + + req, err := http.NewRequest("GET", ts.URL+"/admin/logout", nil) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + req.AddCookie(res.Cookies()[0]) + + res, err = http.DefaultClient.Do(req) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + return + } +} diff --git a/Api/Auth/ChangePassword.go b/Api/Auth/UpdatePassword.go similarity index 100% rename from Api/Auth/ChangePassword.go rename to Api/Auth/UpdatePassword.go diff --git a/Api/Auth/UpdatePassword_test.go b/Api/Auth/UpdatePassword_test.go new file mode 100644 index 0000000..1347495 --- /dev/null +++ b/Api/Auth/UpdatePassword_test.go @@ -0,0 +1,100 @@ +package Auth + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "path" + "runtime" + "strings" + "testing" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" + + "github.com/gorilla/mux" +) + +func init() { + // Fix working directory for tests + _, filename, _, _ := runtime.Caller(0) + dir := path.Join(path.Dir(filename), "..") + err := os.Chdir(dir) + if err != nil { + panic(err) + } + + Database.InitTest() + + r = mux.NewRouter() +} + +func Test_UpdatePassword(t *testing.T) { + t.Log("Testing UpdatePassword...") + + r.HandleFunc("/admin/login", Logout).Methods("POST") + r.HandleFunc("/admin/user/{userID}/update-password", UpdatePassword).Methods("PUT") + + ts := httptest.NewServer(r) + + defer ts.Close() + + userData, err := createTestUser(true) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + + postJson := ` +{ + "email": "%s", + "password": "password" +} +` + postJson = fmt.Sprintf(postJson, userData.Email) + + res, err := http.Post(ts.URL+"/admin/login", "application/json", strings.NewReader(postJson)) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + return + } + + if len(res.Cookies()) != 1 { + t.Errorf("Expected cookies len 1, recieved %d", len(res.Cookies())) + return + } + + postJson = ` +{ + "password": "new_password", + "confirm_password": "new_password" +} +` + req, err := http.NewRequest("PUT", fmt.Sprintf( + "%s/admin/user/%s/update-password", + ts.URL, + userData.ID, + ), strings.NewReader(postJson)) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + req.AddCookie(res.Cookies()[0]) + + res, err = http.DefaultClient.Do(req) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + if res.StatusCode != http.StatusOK { + t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + return + } +} diff --git a/Api/Routes.go b/Api/Routes.go index e29633a..e71544f 100644 --- a/Api/Routes.go +++ b/Api/Routes.go @@ -30,11 +30,12 @@ func InitApiEndpoints() *mux.Router { router.HandleFunc("/post/{postID}/image/{imageID}", deletePostImage).Methods("DELETE") // Define routes for users api - router.HandleFunc("/user", getUsers).Methods("GET") - router.HandleFunc("/user", createUser).Methods("POST") - router.HandleFunc("/user/{userID}", getUser).Methods("GET") - router.HandleFunc("/user/{userID}", updatePost).Methods("PUT") - router.HandleFunc("/user/{userID}", deletePost).Methods("DELETE") + router.HandleFunc("/admin/user", getUsers).Methods("GET") + router.HandleFunc("/admin/user", createUser).Methods("POST") + router.HandleFunc("/admin/user/{userID}", getUser).Methods("GET") + router.HandleFunc("/admin/user/{userID}", updatePost).Methods("PUT") + router.HandleFunc("/admin/user/{userID}", deletePost).Methods("DELETE") + router.HandleFunc("/admin/user/{userID}/update-password", Auth.UpdatePassword).Methods("PUT") // Define routes for authentication router.HandleFunc("/admin/login", Auth.Login).Methods("POST") -- 2.17.1