From d584d40a5278aeb412ad84201a56d60209ca0e4f Mon Sep 17 00:00:00 2001 From: Tovi Jaeschke-Rogers Date: Mon, 21 Mar 2022 05:03:05 +1030 Subject: [PATCH] Add authentication to all required endpoints --- Api/Auth/ChangePassword.go | 53 ++++++++++++++ Api/Auth/Login.go | 5 +- Api/Auth/Session.go | 32 ++++++++- Api/JsonSerialization/VerifyJson.go | 50 ++++++++----- Api/PostImages.go | 22 +++--- Api/Posts.go | 61 ++++++++++------ Api/Posts_test.go | 37 +++++++++- Api/Users.go | 74 ++++++++++++-------- Api/Users_test.go | 104 ++++++++++++++++++++++++---- Database/Users.go | 2 +- {Api => Util}/PostHelper.go | 8 +-- {Api => Util}/PostImageHelper.go | 8 +-- {Api => Util}/ReturnJson.go | 2 +- {Api => Util}/UserHelper.go | 8 +-- 14 files changed, 353 insertions(+), 113 deletions(-) create mode 100644 Api/Auth/ChangePassword.go rename {Api => Util}/PostHelper.go (84%) rename {Api => Util}/PostImageHelper.go (84%) rename {Api => Util}/ReturnJson.go (97%) rename {Api => Util}/UserHelper.go (84%) diff --git a/Api/Auth/ChangePassword.go b/Api/Auth/ChangePassword.go new file mode 100644 index 0000000..9187665 --- /dev/null +++ b/Api/Auth/ChangePassword.go @@ -0,0 +1,53 @@ +package Auth + +import ( + "encoding/json" + "net/http" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" +) + +type ChangePassword struct { + Password string `json:"password"` + ConfirmPassword string `json:"confirm_password"` +} + +func UpdatePassword(w http.ResponseWriter, r *http.Request) { + var ( + changePasswd ChangePassword + userData Models.User + err error + ) + + userData, err = CheckCookieCurrentUser(w, r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + err = json.NewDecoder(r.Body).Decode(&changePasswd) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + if changePasswd.Password != changePasswd.ConfirmPassword { + w.WriteHeader(http.StatusBadRequest) + return + } + + userData.Password, err = HashPassword(changePasswd.Password) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + err = Database.UpdateUser(userData.ID.String(), &userData) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) +} diff --git a/Api/Auth/Login.go b/Api/Auth/Login.go index 5d42e5c..0ae4a0a 100644 --- a/Api/Auth/Login.go +++ b/Api/Auth/Login.go @@ -51,8 +51,9 @@ func Login(w http.ResponseWriter, r *http.Request) { expiresAt = time.Now().Add(1 * time.Hour) Sessions[sessionToken.String()] = Session{ - Username: userData.Email, - Expiry: expiresAt, + UserID: userData.ID.String(), + Email: userData.Email, + Expiry: expiresAt, } http.SetCookie(w, &http.Cookie{ diff --git a/Api/Auth/Session.go b/Api/Auth/Session.go index 3e2c23f..b647376 100644 --- a/Api/Auth/Session.go +++ b/Api/Auth/Session.go @@ -4,6 +4,9 @@ import ( "errors" "net/http" "time" + + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Util" ) var ( @@ -11,8 +14,9 @@ var ( ) type Session struct { - Username string - Expiry time.Time + UserID string + Email string + Expiry time.Time } func (s Session) IsExpired() bool { @@ -49,3 +53,27 @@ func CheckCookie(r *http.Request) (Session, error) { return userSession, nil } + +func CheckCookieCurrentUser(w http.ResponseWriter, r *http.Request) (Models.User, error) { + var ( + userSession Session + userData Models.User + err error + ) + + userSession, err = CheckCookie(r) + if err != nil { + return userData, err + } + + userData, err = Util.GetUserById(w, r) + if err != nil { + return userData, err + } + + if userData.ID.String() != userSession.UserID { + return userData, errors.New("Is not current user") + } + + return userData, nil +} diff --git a/Api/JsonSerialization/VerifyJson.go b/Api/JsonSerialization/VerifyJson.go index 2ad1554..862e7c8 100644 --- a/Api/JsonSerialization/VerifyJson.go +++ b/Api/JsonSerialization/VerifyJson.go @@ -7,7 +7,11 @@ import ( // isIntegerType returns whether the type is an integer and if it's unsigned. // See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L328 -func isIntegerType(t reflect.Type) (yes bool, unsigned bool) { +func isIntegerType(t reflect.Type) (bool, bool) { + var ( + yes bool + unsigned bool + ) switch t.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: yes = true @@ -16,19 +20,22 @@ func isIntegerType(t reflect.Type) (yes bool, unsigned bool) { unsigned = true } - return + return yes, unsigned } // isFloatType returns true if the type is a floating point. Note that this doesn't // care about the value -- unmarshaling the number "0" gives a float, not an int. // See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L319 -func isFloatType(t reflect.Type) (yes bool) { +func isFloatType(t reflect.Type) bool { + var ( + yes bool + ) switch t.Kind() { case reflect.Float32, reflect.Float64: yes = true } - return + return yes } // CanConvert returns whether value v is convertible to type t. @@ -38,10 +45,20 @@ func isFloatType(t reflect.Type) (yes bool) { // Modified due to not handling slices (DefaultCanConvert fails on PhotoUrls and Tags) // See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L191 func CanConvert(t reflect.Type, v reflect.Value) bool { - isPtr := t.Kind() == reflect.Ptr - isStruct := t.Kind() == reflect.Struct - isArray := t.Kind() == reflect.Array - dstType := t + var ( + isPtr bool + isStruct bool + isArray bool + dstType reflect.Type + dstInt bool + unsigned bool + f float64 + srcInt bool + ) + isPtr = t.Kind() == reflect.Ptr + isStruct = t.Kind() == reflect.Struct + isArray = t.Kind() == reflect.Array + dstType = t // Check if v is a nil value. if !v.IsValid() || (v.CanAddr() && v.IsNil()) { @@ -72,20 +89,19 @@ func CanConvert(t reflect.Type, v reflect.Value) bool { } // Handle converting to an integer type. - if dstInt, unsigned := isIntegerType(dstType); dstInt { + dstInt, unsigned = isIntegerType(dstType) + if dstInt { if isFloatType(v.Type()) { - f := v.Float() + f = v.Float() - if math.Trunc(f) != f { - return false - } else if unsigned && f < 0 { - return false - } - } else if srcInt, _ := isIntegerType(v.Type()); srcInt { - if unsigned && v.Int() < 0 { + if math.Trunc(f) != f || unsigned && f < 0 { return false } } + srcInt, _ = isIntegerType(v.Type()) + if srcInt && unsigned && v.Int() < 0 { + return false + } } return true diff --git a/Api/PostImages.go b/Api/PostImages.go index 62cc710..dea0442 100644 --- a/Api/PostImages.go +++ b/Api/PostImages.go @@ -30,10 +30,10 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { err error ) - postID, err = getPostId(r) + postID, err = Util.GetPostId(r) if err != nil { log.Printf("Error encountered getting id\n") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -42,7 +42,7 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { err = r.ParseMultipartForm(20 << 20) if err != nil { log.Printf("Error encountered parsing multipart form: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -54,7 +54,7 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { file, err = fileHeader.Open() if err != nil { log.Printf("Error encountered while post image upload: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } defer file.Close() @@ -62,14 +62,14 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { fileBytes, err = ioutil.ReadAll(file) if err != nil { log.Printf("Error encountered while post image upload: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } fileObject, err = Util.WriteFile(fileBytes, "image") if err != nil { log.Printf("Error encountered while post image upload: %s\n", err.Error()) - JsonReturn(w, 415, "Invalid filetype") + Util.JsonReturn(w, 415, "Invalid filetype") return } @@ -83,19 +83,19 @@ func createPostImage(w http.ResponseWriter, r *http.Request) { err = Database.CreatePostImage(&postImage) if err != nil { log.Printf("Error encountered while creating post_image record: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } } - postData, err = getPostById(w, r) + postData, err = Util.GetPostById(w, r) if err != nil { return } returnJson, err = json.MarshalIndent(postData, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -110,7 +110,7 @@ func deletePostImage(w http.ResponseWriter, r *http.Request) { err error ) - postImageData, err = getPostImageById(w, r) + postImageData, err = Util.GetPostImageById(w, r) if err != nil { return } @@ -118,7 +118,7 @@ func deletePostImage(w http.ResponseWriter, r *http.Request) { err = Database.DeletePostImage(&postImageData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } diff --git a/Api/Posts.go b/Api/Posts.go index c39adb5..f5aed9c 100644 --- a/Api/Posts.go +++ b/Api/Posts.go @@ -8,9 +8,11 @@ import ( "net/url" "strconv" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/Auth" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Util" ) func getPosts(w http.ResponseWriter, r *http.Request) { @@ -27,27 +29,27 @@ func getPosts(w http.ResponseWriter, r *http.Request) { page, err = strconv.Atoi(values.Get("page")) if err != nil { log.Println("Could not parse page url argument") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } page, err = strconv.Atoi(values.Get("pageSize")) if err != nil { log.Println("Could not parse pageSize url argument") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } posts, err = Database.GetPosts(page, pageSize) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(posts, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -66,13 +68,13 @@ func getFrontPagePosts(w http.ResponseWriter, r *http.Request) { posts, err = Database.GetFrontPagePosts() if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(posts, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -88,14 +90,14 @@ func getPost(w http.ResponseWriter, r *http.Request) { err error ) - postData, err = getPostById(w, r) + postData, err = Util.GetPostById(w, r) if err != nil { return } returnJson, err = json.MarshalIndent(postData, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -112,12 +114,16 @@ func createPost(w http.ResponseWriter, r *http.Request) { err error ) - // TODO: Add auth + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } requestBody, err = ioutil.ReadAll(r.Body) if err != nil { log.Printf("Error encountered reading POST body: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -129,21 +135,20 @@ func createPost(w http.ResponseWriter, r *http.Request) { "audios", }, false) if err != nil { - panic(err) log.Printf("Invalid data provided to posts API: %s\n", err.Error()) - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") return } err = Database.CreatePost(&postData) if err != nil { - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") } returnJson, err = json.MarshalIndent(postData, "", " ") if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -161,38 +166,44 @@ func updatePost(w http.ResponseWriter, r *http.Request) { err error ) - id, err = getPostId(r) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + id, err = Util.GetPostId(r) if err != nil { log.Printf("Error encountered getting id\n") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } requestBody, err = ioutil.ReadAll(r.Body) if err != nil { log.Printf("Error encountered reading POST body: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } postData, err = JsonSerialization.DeserializePost(requestBody, []string{}, true) if err != nil { log.Printf("Invalid data provided to posts API: %s\n", err.Error()) - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") return } postData, err = Database.UpdatePost(id, &postData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(postData, "", " ") if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -207,7 +218,13 @@ func deletePost(w http.ResponseWriter, r *http.Request) { err error ) - postData, err = getPostById(w, r) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + postData, err = Util.GetPostById(w, r) if err != nil { return } @@ -215,7 +232,7 @@ func deletePost(w http.ResponseWriter, r *http.Request) { err = Database.DeletePost(&postData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } diff --git a/Api/Posts_test.go b/Api/Posts_test.go index 3f381a0..6084d4a 100644 --- a/Api/Posts_test.go +++ b/Api/Posts_test.go @@ -155,7 +155,11 @@ func Test_createPost(t *testing.T) { defer ts.Close() - userData, err := createTestUser(true) + c, u, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } postJson := ` { @@ -171,14 +175,25 @@ func Test_createPost(t *testing.T) { } ` - postJson = fmt.Sprintf(postJson, userData.ID.String()) + postJson = fmt.Sprintf(postJson, u.ID.String()) + + req, err := http.NewRequest("POST", ts.URL+"/post", strings.NewReader(postJson)) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + + req.AddCookie(c) - res, err := http.Post(ts.URL+"/post", "application/json", strings.NewReader(postJson)) + res, err := http.DefaultClient.Do(req) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + return } + if res.StatusCode != http.StatusOK { t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) + return } postData := new(Models.Post) @@ -204,6 +219,12 @@ func Test_deletePost(t *testing.T) { defer ts.Close() + c, _, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + postData, err := createTestPost() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) @@ -220,6 +241,8 @@ func Test_deletePost(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } + req.AddCookie(c) + // Fetch Request res, err := http.DefaultClient.Do(req) if err != nil { @@ -242,6 +265,12 @@ func Test_updatePost(t *testing.T) { defer ts.Close() + c, _, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + return + } + postData, err := createTestPost() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) @@ -265,6 +294,8 @@ func Test_updatePost(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } + req.AddCookie(c) + // Fetch Request res, err := http.DefaultClient.Do(req) if err != nil { diff --git a/Api/Users.go b/Api/Users.go index 33f035d..9ce2e79 100644 --- a/Api/Users.go +++ b/Api/Users.go @@ -12,6 +12,7 @@ import ( "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" + "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Util" ) func getUsers(w http.ResponseWriter, r *http.Request) { @@ -23,32 +24,37 @@ func getUsers(w http.ResponseWriter, r *http.Request) { err error ) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + values = r.URL.Query() page, err = strconv.Atoi(values.Get("page")) if err != nil { log.Println("Could not parse page url argument") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } page, err = strconv.Atoi(values.Get("pageSize")) if err != nil { log.Println("Could not parse pageSize url argument") - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } - users, err = Database.GetUsers(page, pageSize) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(users, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -64,14 +70,20 @@ func getUser(w http.ResponseWriter, r *http.Request) { err error ) - userData, err = getUserById(w, r) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + userData, err = Util.GetUserById(w, r) if err != nil { return } returnJson, err = json.MarshalIndent(userData, "", " ") if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -90,7 +102,7 @@ func createUser(w http.ResponseWriter, r *http.Request) { requestBody, err = ioutil.ReadAll(r.Body) if err != nil { log.Printf("Error encountered reading POST body: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -100,30 +112,30 @@ func createUser(w http.ResponseWriter, r *http.Request) { }, false) if err != nil { log.Printf("Invalid data provided to user API: %s\n", err.Error()) - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") return } err = Database.CheckUniqueEmail(userData.Email) if err != nil { - JsonReturn(w, 405, "invalid_email") + Util.JsonReturn(w, 405, "invalid_email") return } if userData.Password != userData.ConfirmPassword { - JsonReturn(w, 405, "invalid_password") + Util.JsonReturn(w, 405, "invalid_password") return } userData.Password, err = Auth.HashPassword(userData.Password) if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } err = Database.CreateUser(&userData) if err != nil { - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -133,45 +145,44 @@ func createUser(w http.ResponseWriter, r *http.Request) { func updateUser(w http.ResponseWriter, r *http.Request) { var ( - userData Models.User - requestBody []byte - returnJson []byte - id string - err error + currentUserData Models.User + userData Models.User + requestBody []byte + returnJson []byte + err error ) - id, err = getUserId(r) + currentUserData, err = Auth.CheckCookieCurrentUser(w, r) if err != nil { - log.Printf("Error encountered getting id\n") - JsonReturn(w, 500, "An error occured") + w.WriteHeader(http.StatusUnauthorized) return } requestBody, err = ioutil.ReadAll(r.Body) if err != nil { log.Printf("Error encountered reading POST body: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } userData, err = JsonSerialization.DeserializeUser(requestBody, []string{}, true) if err != nil { log.Printf("Invalid data provided to users API: %s\n", err.Error()) - JsonReturn(w, 405, "Invalid data") + Util.JsonReturn(w, 405, "Invalid data") return } - err = Database.UpdateUser(id, &userData) + err = Database.UpdateUser(currentUserData.ID.String(), &userData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } returnJson, err = json.MarshalIndent(userData, "", " ") if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } @@ -186,15 +197,22 @@ func deleteUser(w http.ResponseWriter, r *http.Request) { err error ) - userData, err = getUserById(w, r) + _, err = Auth.CheckCookie(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + userData, err = Util.GetUserById(w, r) if err != nil { + w.WriteHeader(http.StatusNotFound) return } err = Database.DeleteUser(&userData) if err != nil { log.Printf("An error occured: %s\n", err.Error()) - JsonReturn(w, 500, "An error occured") + Util.JsonReturn(w, 500, "An error occured") return } diff --git a/Api/Users_test.go b/Api/Users_test.go index 64cb6e7..1125999 100644 --- a/Api/Users_test.go +++ b/Api/Users_test.go @@ -2,6 +2,7 @@ package Api import ( "encoding/json" + "errors" "fmt" "math/rand" "net/http" @@ -68,6 +69,47 @@ func createTestUser(random bool) (Models.User, error) { return userData, err } +func login() (*http.Cookie, Models.User, error) { + var ( + c *http.Cookie + u Models.User + ) + + r.HandleFunc("/admin/login", Auth.Login).Methods("POST") + + ts := httptest.NewServer(r) + + defer ts.Close() + + u, err := createTestUser(true) + if err != nil { + return c, u, err + } + + postJson := ` +{ + "email": "%s", + "password": "password" +} +` + postJson = fmt.Sprintf(postJson, u.Email) + + res, err := http.Post(ts.URL+"/admin/login", "application/json", strings.NewReader(postJson)) + if err != nil { + return c, u, err + } + + if res.StatusCode != http.StatusOK { + return c, u, errors.New("Invalid res.StatusCode") + } + + if len(res.Cookies()) != 1 { + return c, u, errors.New("Invalid cookies length") + } + + return res.Cookies()[0], u, nil +} + func Test_getUser(t *testing.T) { t.Log("Testing getUser...") @@ -77,22 +119,31 @@ func Test_getUser(t *testing.T) { defer ts.Close() - userData, err := createTestUser(false) + c, u, err := login() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) t.FailNow() } - res, err := http.Get(fmt.Sprintf( + req, err := http.NewRequest("GET", fmt.Sprintf( "%s/user/%s", ts.URL, - userData.ID, - )) + u.ID, + ), nil) + + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + + req.AddCookie(c) + res, err := http.DefaultClient.Do(req) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) t.FailNow() } + if res.StatusCode != http.StatusOK { t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) t.FailNow() @@ -105,18 +156,18 @@ func Test_getUser(t *testing.T) { t.FailNow() } - if getUserData.Email != "email@email.com" { - t.Errorf("Expected email \"email@email.com\", recieved %s", getUserData.Email) + if getUserData.Email != u.Email { + t.Errorf("Expected email \"%s\", recieved %s", u.Email, getUserData.Email) t.FailNow() } - if getUserData.FirstName != "Hugh" { - t.Errorf("Expected email \"Hugh\", recieved %s", getUserData.FirstName) + if getUserData.FirstName != u.FirstName { + t.Errorf("Expected email \"%s\", recieved %s", u.FirstName, getUserData.FirstName) t.FailNow() } - if getUserData.LastName != "Mann" { - t.Errorf("Expected email \"Mann\", recieved %s", getUserData.LastName) + if getUserData.LastName != u.LastName { + t.Errorf("Expected email \"%s\", recieved %s", u.LastName, getUserData.LastName) t.FailNow() } } @@ -129,17 +180,31 @@ func Test_getUsers(t *testing.T) { ts := httptest.NewServer(r) defer ts.Close() - var err error + c, _, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + for i := 0; i < 20; i++ { createTestUser(true) } - res, err := http.Get(ts.URL + "/user?page=1&pageSize=10") + req, err := http.NewRequest("GET", ts.URL+"/user?page=1&pageSize=10", nil) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) t.FailNow() } + + req.AddCookie(c) + + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + if res.StatusCode != http.StatusOK { t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) t.FailNow() @@ -201,9 +266,10 @@ func Test_updateUser(t *testing.T) { defer ts.Close() - userData, err := createTestUser(true) + c, u, err := login() if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() } email := fmt.Sprintf("%s@email.com", randString(16)) @@ -220,13 +286,15 @@ func Test_updateUser(t *testing.T) { req, err := http.NewRequest("PUT", fmt.Sprintf( "%s/user/%s", ts.URL, - userData.ID, + u.ID, ), strings.NewReader(postJson)) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) } + req.AddCookie(c) + // Fetch Request res, err := http.DefaultClient.Do(req) if err != nil { @@ -266,6 +334,12 @@ func Test_deleteUser(t *testing.T) { defer ts.Close() + c, _, err := login() + if err != nil { + t.Errorf("Expected nil, recieved %s", err.Error()) + t.FailNow() + } + userData, err := createTestUser(true) if err != nil { t.Errorf("Expected nil, recieved %s", err.Error()) @@ -282,6 +356,8 @@ func Test_deleteUser(t *testing.T) { t.Errorf("Expected nil, recieved %s", err.Error()) } + req.AddCookie(c) + // Fetch Request res, err := http.DefaultClient.Do(req) if err != nil { diff --git a/Database/Users.go b/Database/Users.go index 0289823..7dc8b03 100644 --- a/Database/Users.go +++ b/Database/Users.go @@ -108,7 +108,7 @@ func UpdateUser(id string, userData *Models.User) error { var ( err error ) - err = DB.Model(&Models.Post{}). + err = DB.Model(&Models.User{}). Select("*"). Omit("id", "created_at", "updated_at", "deleted_at"). Where("id = ?", id). diff --git a/Api/PostHelper.go b/Util/PostHelper.go similarity index 84% rename from Api/PostHelper.go rename to Util/PostHelper.go index b12d3fe..697ff44 100644 --- a/Api/PostHelper.go +++ b/Util/PostHelper.go @@ -1,4 +1,4 @@ -package Api +package Util import ( "errors" @@ -11,7 +11,7 @@ import ( "github.com/gorilla/mux" ) -func getPostId(r *http.Request) (string, error) { +func GetPostId(r *http.Request) (string, error) { var ( urlVars map[string]string id string @@ -26,14 +26,14 @@ func getPostId(r *http.Request) (string, error) { return id, nil } -func getPostById(w http.ResponseWriter, r *http.Request) (Models.Post, error) { +func GetPostById(w http.ResponseWriter, r *http.Request) (Models.Post, error) { var ( postData Models.Post id string err error ) - id, err = getPostId(r) + id, err = GetPostId(r) if err != nil { log.Printf("Error encountered getting id\n") JsonReturn(w, 500, "An error occured") diff --git a/Api/PostImageHelper.go b/Util/PostImageHelper.go similarity index 84% rename from Api/PostImageHelper.go rename to Util/PostImageHelper.go index 06142c4..9ee16b3 100644 --- a/Api/PostImageHelper.go +++ b/Util/PostImageHelper.go @@ -1,4 +1,4 @@ -package Api +package Util import ( "errors" @@ -10,7 +10,7 @@ import ( "github.com/gorilla/mux" ) -func getPostImageId(r *http.Request) (string, error) { +func GetPostImageId(r *http.Request) (string, error) { var ( urlVars map[string]string id string @@ -25,14 +25,14 @@ func getPostImageId(r *http.Request) (string, error) { return id, nil } -func getPostImageById(w http.ResponseWriter, r *http.Request) (Models.PostImage, error) { +func GetPostImageById(w http.ResponseWriter, r *http.Request) (Models.PostImage, error) { var ( postImageData Models.PostImage id string err error ) - id, err = getPostImageId(r) + id, err = GetPostImageId(r) if err != nil { log.Printf("Error encountered getting id\n") JsonReturn(w, 500, "An error occured") diff --git a/Api/ReturnJson.go b/Util/ReturnJson.go similarity index 97% rename from Api/ReturnJson.go rename to Util/ReturnJson.go index 002c091..747fcf2 100644 --- a/Api/ReturnJson.go +++ b/Util/ReturnJson.go @@ -1,4 +1,4 @@ -package Api +package Util import ( "encoding/json" diff --git a/Api/UserHelper.go b/Util/UserHelper.go similarity index 84% rename from Api/UserHelper.go rename to Util/UserHelper.go index 7a658ee..85df401 100644 --- a/Api/UserHelper.go +++ b/Util/UserHelper.go @@ -1,4 +1,4 @@ -package Api +package Util import ( "errors" @@ -11,7 +11,7 @@ import ( "github.com/gorilla/mux" ) -func getUserId(r *http.Request) (string, error) { +func GetUserId(r *http.Request) (string, error) { var ( urlVars map[string]string id string @@ -26,14 +26,14 @@ func getUserId(r *http.Request) (string, error) { return id, nil } -func getUserById(w http.ResponseWriter, r *http.Request) (Models.User, error) { +func GetUserById(w http.ResponseWriter, r *http.Request) (Models.User, error) { var ( postData Models.User id string err error ) - id, err = getUserId(r) + id, err = GetUserId(r) if err != nil { log.Printf("Error encountered getting id\n") JsonReturn(w, 500, "An error occured")