Files
qddns/cmd/qddns-server/main.go
2022-08-06 20:58:35 +02:00

126 lines
3.2 KiB
Go

package main
import (
"flag"
"fmt"
"github.com/coreos/go-systemd/activation"
"github.com/gin-gonic/gin"
"github.com/thequux/qddns/common"
"github.com/thequux/qddns/db"
"github.com/thequux/qddns/multilistener"
"net"
"net/http"
"os"
"strings"
)
// Update handles POST /update/:domain by pointing the domain's A or AAAA
// record at the requesting client's address.
//
// The request must carry an "Authorization: Bearer <token>" header whose token
// is paired with :domain in the qddns_auth table. The client address is taken
// from the X-Real-IP header (presumably set by a fronting reverse proxy —
// confirm clients can never supply it directly). An IPv4 address updates the
// "A" row of the records table, anything else the "AAAA" row. Exactly one row
// must change, otherwise nothing is committed.
//
// Responses are common.Response JSON bodies:
//   - 401 QDA0001: missing or malformed bearer token
//   - 400 QDE0003: X-Real-IP missing or not an IP address
//   - 500 QDE0001: could not start a transaction
//   - 500 QDE0002: token lookup failed
//   - 401 QDA0002: token not valid for the domain
//   - 500 QDD0001: the update or its commit failed
//   - 500 (no code): the update touched a number of rows other than one
//   - 200 "OK": only after the transaction has been committed
func Update(c *gin.Context) {
	authHdr := strings.Fields(c.GetHeader("Authorization"))
	domain := c.Param("domain")
	clientIP := c.GetHeader("X-Real-IP")
	ip := net.ParseIP(clientIP)

	// Auth scheme names are case-insensitive (RFC 7235 §2.1) and RFC 6750
	// clients conventionally send "Bearer", so compare without regard to case.
	if len(authHdr) != 2 || !strings.EqualFold(authHdr[0], "bearer") {
		c.JSON(http.StatusUnauthorized, common.Response{Status: "error", Message: "Missing bearer token", Code: "QDA0001"})
		return
	}
	token := authHdr[1]

	if ip == nil {
		c.JSON(http.StatusBadRequest, common.Response{Status: "error", Message: "Unable to determine client IP (got " + clientIP + ")", Code: "QDE0003"})
		return
	}

	// get a connection
	tx, err := db.Db.Begin(c)
	if err != nil {
		c.JSON(http.StatusInternalServerError, common.Response{Status: "error", Message: "Failed to access database", Code: "QDE0001"})
		return
	}
	// Register the rollback only once tx is known to be usable: after a failed
	// Begin, tx may be nil and a deferred call on it would panic. On the
	// success path this runs after Commit; assumes a pgx-style Tx where that
	// is a harmless no-op whose error we ignore — confirm for db.Db.
	defer tx.Rollback(c)

	// check token validity
	var rcount int
	row := tx.QueryRow(c, "SELECT COUNT(*) from qddns_auth WHERE token = $1 AND domain = $2", token, domain)
	if err := row.Scan(&rcount); err != nil {
		c.JSON(http.StatusInternalServerError, common.Response{Status: "error", Message: "Failed to check token " + err.Error(), Code: "QDE0002"})
		return
	}
	if rcount < 1 {
		c.JSON(http.StatusUnauthorized, common.Response{Status: "error", Message: "Not authorized for domain " + domain, Code: "QDA0002"})
		return
	}

	// Identify the type of address and normalise its textual form.
	recordType := "A"
	if v4 := ip.To4(); v4 != nil {
		clientIP = v4.String()
	} else {
		recordType = "AAAA"
		clientIP = ip.To16().String()
	}

	// Do the update
	tag, err := tx.Exec(c, "UPDATE records SET content = $1 WHERE name = $2 AND type = $3", clientIP, domain, recordType)
	if err != nil {
		c.JSON(http.StatusInternalServerError, common.Response{
			Status:  "error",
			Message: "Failed to update record: " + err.Error(),
			Code:    "QDD0001",
		})
		return
	}
	if tag.RowsAffected() != 1 {
		c.JSON(http.StatusInternalServerError, common.Response{
			Status:  "error",
			Message: fmt.Sprintf("Wrong number of rows affected: %v", tag.RowsAffected()),
		})
		return
	}

	// Commit before answering: a client must never be told "Success!" for an
	// update that was not actually persisted.
	if err := tx.Commit(c); err != nil {
		c.JSON(http.StatusInternalServerError, common.Response{
			Status:  "error",
			Message: "Failed to commit update: " + err.Error(),
			Code:    "QDD0001",
		})
		return
	}
	c.JSON(http.StatusOK, common.Response{
		Status:  "OK",
		Message: "Success!",
	})
}
var listen = flag.String("listen", ":8081", "Address or path on which to listen")
func main() {
flag.Parse()
if err := db.Connect(""); err != nil {
fmt.Fprintf(os.Stderr, "Unable to connect to database: %v", err)
os.Exit(1)
}
// Set up server
r := gin.Default()
r.POST("/update/:domain", Update)
var err error
if listeners, err := activation.Listeners(); err == nil && len(listeners) > 0 {
// Socket activation
var listener net.Listener
if len(listeners) > 1 {
listener, _ = multilistener.New(listeners...)
} else {
listener = listeners[0]
}
err = r.RunListener(listener)
} else if _, err = net.ResolveTCPAddr("tcp", *listen); err == nil {
err = r.Run(*listen)
} else {
// Probably a UNIX address
err = r.RunUnix(*listen)
}
if err != nil {
println("Failed to listen: " + err.Error())
os.Exit(1)
}
}