Adding basic database persistence in go

It has been a wonderful week and a half: I became the father of a cute little boy on Monday the 21st. The baby arrived a tad sooner than we expected, catching us a little unprepared, but then who minds such surprises. The baby and mom are doing well and life is slowly getting back on track. Having a baby has been less of a life-changing and more of a sleep-shattering experience so far!

Thankfully, I have now got some time to continue learning the Go language. Continuing from the previous post, Posting on twitter, we will try to add database persistence to our application. The idea is to be able to insert, query and delete some data in a database. We will use MySQL for this example.

We will store our Twitter token and secret key in the database; earlier we were storing those values in a map.

We first need to create a database in MySQL from the mysql console. Let's name it “gotest”.

Now let's create a table named credentials in the database by issuing a create table command. (We are just concentrating on CRUD operations in Go.)

The table I created looks like this

+--------+--------------+------+-----+---------+-------+
| Field  | Type         | Null | Key | Default | Extra |
+--------+--------------+------+-----+---------+-------+
| id     | int(11)      | YES  |     | NULL    |       |
| token  | varchar(100) | YES  |     | NULL    |       |
| secret | varchar(100) | YES  |     | NULL    |       |
+--------+--------------+------+-----+---------+-------+

Now let's get down to some code.

To work with MySQL we need a MySQL driver and the SQL support package provided by Go. We will use the following two packages:

database/sql
github.com/go-sql-driver/mysql

Most of the code is the same as in the previous article, so let's concentrate on the important changes. First we define some variables we will need

// Package-level state shared by the credential helper functions.
var (
	stmtIns *sql.Stmt // prepared INSERT statement, executed by putCredentials
	stmtOut *sql.Stmt // prepared SELECT statement, queried by getCredentials
	stmtDel *sql.Stmt // prepared DELETE statement, executed by deleteCredentials
	secret string // package-level scratch string for a credential secret (shared across all requests)
)

We updated our method to save credentials to database

 // putCredentials stores cred's token and secret as a new row through the
 // prepared stmtIns statement.
 //
 // The row id is still hard-coded to 1 for every call (see the TODO below).
 // It panics if the insert fails; the panic message carries the underlying
 // database error so the cause of the failure is visible.
 func putCredentials(cred *oauth.Credentials) {
	// TODO: generate a unique id instead of the hard-coded 1.
	if _, err := stmtIns.Exec(1, cred.Token, cred.Secret); err != nil {
		panic("Unable to add credentials to database: " + err.Error())
	}
}

In the above code, stmtIns is a prepared statement which inserts values into the database. Here is how we created this prepared statement in the main() function

// Prepare the INSERT once; putCredentials executes it for every new row.
stmtIns, err = dbConnection.Prepare("INSERT INTO credentials VALUES(?,?,?)")
	if err != nil {
		// Keep the driver's error in the message so the cause is not lost.
		panic("Unable to get prepared statement for insert: " + err.Error())
	}
	defer stmtIns.Close()

We also updated the function for deleteCredentials()

// deleteCredentials removes every row whose token column equals token, using
// the prepared stmtDel statement.
//
// It panics if the delete fails; the panic message carries the underlying
// database error so the cause of the failure is visible.
func deleteCredentials(token string) {
	if _, err := stmtDel.Exec(token); err != nil {
		panic("Unable to delete the token from database: " + err.Error())
	}
}

Here stmtDel is a prepared statement defined in the main() function as

// Prepare the DELETE once; deleteCredentials executes it for every removal.
stmtDel, err = dbConnection.Prepare("DELETE FROM credentials WHERE token = ?")
	if err != nil {
		// Keep the driver's error in the message so the cause is not lost.
		panic("Unable to get prepared statement for delete: " + err.Error())
	}
	defer stmtDel.Close()

For querying the database we updated the getCredentials() method

// getCredentials looks up the secret stored for token through the prepared
// stmtOut statement and returns the matching credentials.
//
// It returns nil when no row exists for token, so callers can treat an
// unknown token as "no credentials". Any other database error causes a panic
// whose message includes the underlying error.
func getCredentials(token string) *oauth.Credentials {
	// Scan into a local variable: HTTP handlers run concurrently, so a shared
	// package-level target could be overwritten by another request before the
	// value is copied into the result.
	var secret string
	err := stmtOut.QueryRow(token).Scan(&secret)
	if err == sql.ErrNoRows {
		return nil
	}
	if err != nil {
		panic("Unable to retrieve value from database: " + err.Error())
	}
	return &oauth.Credentials{Token: token, Secret: secret}
}

Here stmtOut is a prepared statement defined as

// Prepare the SELECT once; getCredentials queries it for every lookup.
stmtOut, err = dbConnection.Prepare("SELECT secret FROM credentials WHERE token = ?")
	if err != nil {
		// Keep the driver's error in the message so the cause is not lost.
		panic("Unable to get prepared statement for select: " + err.Error())
	}
	defer stmtOut.Close()

Here is what the entire code looks like. Note that the code is just a hint of how we can add database persistence and is not production ready.

package main

import ("fmt"
	"net/http"
	"log"
	"text/template"
	"github.com/garyburd/go-oauth/oauth"
	"net/url"
	"io/ioutil"
	"encoding/json"
	"time"
	"database/sql"
	_ "github.com/go-sql-driver/mysql"
)

// oauthClient carries the application's own OAuth credentials together with
// the Twitter endpoints for the temporary-credential, authorization and
// token-request steps. The Token and Secret values below look like
// placeholders to be replaced with real application keys.
var oauthClient = oauth.Client{
	Credentials : oauth.Credentials{Token: "your_App_token", Secret: "your_app_secret"},
	TemporaryCredentialRequestURI: "https://api.twitter.com/oauth/request_token",
	ResourceOwnerAuthorizationURI: "https://api.twitter.com/oauth/authorize",
	TokenRequestURI:               "https://api.twitter.com/oauth/access_token",
}

// Package-level state shared by the credential helper functions. The three
// statements are prepared in main before the server starts.
var (
	stmtIns *sql.Stmt // prepared INSERT statement, executed by putCredentials
	stmtOut *sql.Stmt // prepared SELECT statement, queried by getCredentials
	stmtDel *sql.Stmt // prepared DELETE statement, executed by deleteCredentials
	secret string // package-level scratch string for a credential secret (shared across all requests)
)

// myHandler adapts a credentials-aware handler function to http.Handler; its
// ServeHTTP method resolves the credentials before calling handler.
type myHandler struct {
	// handler receives the request together with the credentials looked up
	// from the "auth" cookie; c is nil when the request has no such cookie.
	handler  func(w http.ResponseWriter, r *http.Request, c *oauth.Credentials)
}

// serveHomePage renders the landing page for the exact path "/" and answers
// 404 for anything else. The logged-out template is used when cred is nil,
// the regular home template otherwise.
func serveHomePage(w http.ResponseWriter, r *http.Request, cred *oauth.Credentials) {
	if r.URL.Path != "/" {
		http.NotFound(w, r)
		return
	}

	// Pick the page by login state, then render it once.
	page := homeTmpl
	if cred == nil {
		page = homeLoggedOutTmpl
	}
	respond(w, page, nil)
}

// serveOAuthCallback completes the OAuth flow after the provider redirects
// back: it exchanges the stored temporary credentials plus the verifier for
// token credentials, persists them, and sets the "auth" cookie to the token.
func serveOAuthCallback(w http.ResponseWriter, r *http.Request) {
	requestToken := r.FormValue("oauth_token")
	tempCred := getCredentials(requestToken)
	if tempCred == nil {
		http.Error(w, "Unknown oauth_token.", http.StatusInternalServerError)
		return
	}

	// The temporary credentials are single-use; drop them before the exchange.
	deleteCredentials(tempCred.Token)

	verifier := r.FormValue("oauth_verifier")
	tokenCred, _, err := oauthClient.RequestToken(http.DefaultClient, tempCred, verifier)
	if err != nil {
		msg := "Error getting request token, " + err.Error()
		http.Error(w, msg, http.StatusInternalServerError)
		return
	}
	putCredentials(tokenCred)

	authCookie := http.Cookie{
		Name:     "auth",
		Value:    tokenCred.Token,
		Path:     "/",
		HttpOnly: true,
	}
	http.SetCookie(w, &authCookie)
	http.Redirect(w, r, "/", http.StatusFound)
}

// getCredentials looks up the secret stored for token through the prepared
// stmtOut statement and returns the matching credentials.
//
// It returns nil when no row exists for token, so callers can treat an
// unknown token as "no credentials". Any other database error causes a panic
// whose message includes the underlying error.
func getCredentials(token string) *oauth.Credentials {
	// Scan into a local variable: HTTP handlers run concurrently, so a shared
	// package-level target could be overwritten by another request before the
	// value is copied into the result.
	var secret string
	err := stmtOut.QueryRow(token).Scan(&secret)
	if err == sql.ErrNoRows {
		return nil
	}
	if err != nil {
		panic("Unable to retrieve value from database: " + err.Error())
	}
	return &oauth.Credentials{Token: token, Secret: secret}
}

// deleteCredentials removes every row whose token column equals token, using
// the prepared stmtDel statement.
//
// It panics if the delete fails; the panic message carries the underlying
// database error so the cause of the failure is visible.
func deleteCredentials(token string) {
	if _, err := stmtDel.Exec(token); err != nil {
		panic("Unable to delete the token from database: " + err.Error())
	}
}

// serveAuthorizationPage starts the OAuth flow: it requests temporary
// credentials with a callback pointing at this host's /callback path, stores
// them, and redirects the browser to the provider's authorization URL.
func serveAuthorizationPage(w http.ResponseWriter, r *http.Request) {
	callbackURL := "http://" + r.Host + "/callback"

	tempCred, err := oauthClient.RequestTemporaryCredentials(http.DefaultClient, callbackURL, nil)
	if err != nil {
		msg := "Error getting temp cred, " + err.Error()
		http.Error(w, msg, http.StatusInternalServerError)
		return
	}

	// Persist the temporary credentials so the callback can find them again.
	putCredentials(tempCred)

	authURL := oauthClient.AuthorizationURL(tempCred, nil)
	http.Redirect(w, r, authURL, http.StatusFound)
}


// postTweet posts a fixed status message on behalf of the logged-in user and
// renders the confirmation page on success.
//
// cred is nil when the request carried no "auth" cookie; such requests are
// answered with 401 instead of being sent on to Twitter. A failed API call is
// reported to the client with status 500.
func postTweet(w http.ResponseWriter, r *http.Request, cred *oauth.Credentials) {
	if cred == nil {
		http.Error(w, "Not authorized, please log in first.", http.StatusUnauthorized)
		return
	}

	var profile map[string]interface{}
	if err := apiPost(
		cred,
		"https://api.twitter.com/1.1/statuses/update.json",
		url.Values{"status": {"This is another test tweet sent from #golang application "}},
		&profile); err != nil {
		http.Error(w, "Error posting tweet, "+err.Error(), http.StatusInternalServerError)
		return
	}
	respond(w, tweetedTmpl, profile)
}

// apiPost sends a signed POST with the given form to urlStr using cred, then
// hands the response to decodeResponse, which fills data from the JSON body.
// The response body is closed before apiPost returns.
func apiPost(cred *oauth.Credentials, urlStr string, form url.Values, data interface{}) error {
	resp, postErr := oauthClient.Post(http.DefaultClient, cred, urlStr, form)
	if postErr != nil {
		return postErr
	}
	defer resp.Body.Close()

	decodeErr := decodeResponse(resp, data)
	return decodeErr
}

// decodeResponse unmarshals a successful (200) response body as JSON into
// data. For any other status it returns an error that names the request URL,
// the status code and whatever body text could be read.
func decodeResponse(resp *http.Response, data interface{}) error {
	if resp.StatusCode == http.StatusOK {
		decoder := json.NewDecoder(resp.Body)
		return decoder.Decode(data)
	}

	// Best effort: the body only enriches the error text, so a read failure
	// is deliberately ignored here.
	body, _ := ioutil.ReadAll(resp.Body)
	return fmt.Errorf("get %s returned status %d, %s", resp.Request.URL, resp.StatusCode, body)
}

// putCredentials stores cred's token and secret as a new row through the
// prepared stmtIns statement.
//
// The row id is still hard-coded to 1 for every call (see the TODO below).
// It panics if the insert fails; the panic message carries the underlying
// database error so the cause of the failure is visible.
func putCredentials(cred *oauth.Credentials) {
	// TODO: generate a unique id instead of the hard-coded 1.
	if _, err := stmtIns.Exec(1, cred.Token, cred.Secret); err != nil {
		panic("Unable to add credentials to database: " + err.Error())
	}
}


// respond marks the response as UTF-8 HTML and executes template t with data
// straight into w. A template execution error is logged, not returned.
func respond(w http.ResponseWriter, t *template.Template, data interface{}) {
	w.Header().Set("Content-Type", "text/html; charset=utf-8")

	err := t.Execute(w, data)
	if err == nil {
		return
	}
	log.Print(err)
}

// ServeHTTP resolves the caller's credentials from the "auth" cookie, when
// one was sent, and passes them to the wrapped handler. The credentials stay
// nil for requests without that cookie.
func (h *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The error is ignored on purpose: a missing cookie simply yields nil.
	authCookie, _ := r.Cookie("auth")

	var cred *oauth.Credentials
	if authCookie != nil {
		cred = getCredentials(authCookie.Value)
	}
	h.handler(w, r, cred)
}

// serveLogout expires the "auth" cookie (negative MaxAge plus an Expires time
// one hour in the past) and redirects the browser to the home page.
func serveLogout(w http.ResponseWriter, r *http.Request) {
	expired := http.Cookie{Name: "auth", Path: "/", HttpOnly: true}
	expired.MaxAge = -1
	expired.Expires = time.Now().Add(-time.Hour)

	http.SetCookie(w, &expired)
	http.Redirect(w, r, "/", http.StatusFound)
}

// homeTmpl is the home page that serveHomePage renders when credentials are
// present; it links to the authorize, post and logout routes.
var homeTmpl = template.Must(template.New("home").Parse(
`<html>
<head>
</head>
<body>
<a href="/authorize"> Authorize for Twitter</a>
<a href="/post">Post the tweet</a>
<a href="/logout">Logout</a>
</body>
</html>`))

// tweetedTmpl is the confirmation page that postTweet renders after the
// status update call succeeds.
var tweetedTmpl = template.Must(template.New("tweetedTmpl").Parse(
`<html>
<head>
</head>
<body>
 The post has been tweeted
<a href="/"> Home </a>
<br/>
<a href="/logout"> Logout </a>
</body>
</html>`))

// homeLoggedOutTmpl is the home page that serveHomePage renders when no
// credentials are available; it only offers the authorize link.
var homeLoggedOutTmpl = template.Must(template.New("loggedOut").Parse(
`<html>
<head>
</head>
<body>
You are logged out
<a href="/authorize">Authorize</a>
</body>
</html>`))

// main opens the MySQL database, prepares the three statements used by the
// credential helpers, registers the HTTP routes and serves on port 8080.
//
// Startup fails fast: an error from opening or pinging the database, or from
// preparing a statement, panics with the underlying error in the message.
func main() {
	fmt.Println("Starting Server")

	// NOTE(review): the database user and password are hard-coded in the DSN;
	// consider moving them to configuration before any real deployment.
	dbConnection, err := sql.Open("mysql", "root:igdefault@/gotest")
	if err != nil {
		panic("Unable to open database: " + err.Error())
	}
	// Defer Close only after the error check, so it never runs on a handle
	// that failed to open.
	defer dbConnection.Close()

	// Ping verifies the connection; without it the statements below cannot
	// be prepared, so there is no point in continuing on failure.
	if err = dbConnection.Ping(); err != nil {
		panic("There is an error in connecting to database: " + err.Error())
	}
	fmt.Println("Connected to database")

	stmtIns, err = dbConnection.Prepare("INSERT INTO credentials VALUES(?,?,?)")
	if err != nil {
		panic("Unable to get prepared statement for insert: " + err.Error())
	}
	defer stmtIns.Close()

	stmtOut, err = dbConnection.Prepare("SELECT secret FROM credentials WHERE token = ?")
	if err != nil {
		panic("Unable to get prepared statement for select: " + err.Error())
	}
	defer stmtOut.Close()

	stmtDel, err = dbConnection.Prepare("DELETE FROM credentials WHERE token = ?")
	if err != nil {
		panic("Unable to get prepared statement for delete: " + err.Error())
	}
	defer stmtDel.Close()

	// Routes wrapped in myHandler receive the credentials resolved from the
	// "auth" cookie; the others handle the OAuth flow and logout directly.
	http.Handle("/", &myHandler{handler: serveHomePage})
	http.HandleFunc("/authorize", serveAuthorizationPage)
	http.HandleFunc("/callback", serveOAuthCallback)
	http.Handle("/post", &myHandler{handler: postTweet})
	http.HandleFunc("/logout", serveLogout)

	if err = http.ListenAndServe(":8080", nil); err != nil {
		log.Fatalf("Error listening, %v", err)
	}
}

~~ Whizdumb ~~
Email : sachin.xpert@gmail.com