package main import ( "flag" "fmt" "github.com/coreos/go-systemd/activation" "github.com/gin-gonic/gin" "github.com/thequux/qddns/common" "github.com/thequux/qddns/db" "github.com/thequux/qddns/multilistener" "net" "net/http" "os" "strings" ) func Update(c *gin.Context) { authHdr := strings.Fields(c.GetHeader("Authorization")) domain := c.Param("domain") clientIP := c.GetHeader("X-Real-IP") ip := net.ParseIP(clientIP) var token string if len(authHdr) == 2 && authHdr[0] == "bearer" { // we have a token; check if it's valid for domain token = authHdr[1] } else { c.JSON(http.StatusUnauthorized, common.Response{Status: "error", Message: "Missing bearer token", Code: "QDA0001"}) return } if ip == nil { c.JSON(http.StatusBadRequest, common.Response{Status: "error", Message: "Unable to determine client IP (got " + clientIP + ")", Code: "QDE0003"}) return } // get a connection tx, err := db.Db.Begin(c) defer tx.Rollback(c) if err != nil { c.JSON(http.StatusInternalServerError, common.Response{Status: "error", Message: "Failed to access database", Code: "QDE0001"}) return } // check token validity row := tx.QueryRow(c, "SELECT COUNT(*) from qddns_auth WHERE token = $1 AND domain = $2", token, domain) var rcount int if err := row.Scan(&rcount); err != nil { c.JSON(http.StatusInternalServerError, common.Response{Status: "error", Message: "Failed to check token " + err.Error(), Code: "QDE0002"}) return } else if rcount < 1 { c.JSON(http.StatusUnauthorized, common.Response{Status: "error", Message: "Not authorized for domain " + domain, Code: "QDA0002"}) return } // Identify the type of address var recordType string if ip.To4() == nil { recordType = "AAAA" clientIP = ip.To16().String() } else { recordType = "A" clientIP = ip.To4().String() } // Do the update tag, err := tx.Exec(c, "UPDATE records SET content = $1 WHERE name = $2 AND type = $3", clientIP, domain, recordType) if err != nil { c.JSON(http.StatusInternalServerError, common.Response{ Status: "error", Message: "Failed to update record: " + err.Error(), Code: "QDD0001", }) return } if tag.RowsAffected() != 1 { c.JSON(http.StatusInternalServerError, common.Response{ Status: "error", Message: fmt.Sprintf("Wrong number of rows affected: %v", tag.RowsAffected()), }) return } tx.Commit(c) } var listen = flag.String("listen", ":8081", "Address or path on which to listen") func main() { flag.Parse() if err := db.Connect(""); err != nil { fmt.Fprintf(os.Stderr, "Unable to connect to database: %v", err) os.Exit(1) } // Set up server r := gin.Default() r.POST("/update/:domain", Update) var err error if listeners, err := activation.Listeners(); err == nil && len(listeners) > 0 { // Socket activation var listener net.Listener if len(listeners) > 1 { listener, _ = multilistener.New(listeners...) } else { listener = listeners[0] } err = r.RunListener(listener) } else if _, err = net.ResolveTCPAddr("tcp", *listen); err == nil { err = r.Run(*listen) } else { // Probably a UNIX address err = r.RunUnix(*listen) } if err != nil { println("Failed to listen: " + err.Error()) os.Exit(1) } }