1 Star 0 Fork 1

winie/sq

forked from unsafe-rust/sq 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
6_db.go 11.46 KB
一键复制 编辑 原始数据 按行查看 历史
unsafe-rust 提交于 2021-03-28 17:36 . update
package sq
import (
"bufio"
"bytes"
"context"
"database/sql"
"gitee.com/gopher2011/sqlx"
"log"
"os"
"reflect"
"strings"
"time"
)
type ISql interface {
// Query 参数<querySql>是查询类的SQL语句,<args>是SQL语句的参数。
QueryX(querySql string, args ...interface{}) (*sqlx.Rows, error)
// QueryRow 参数<querySql>是查询类的SQL语句,<args>是SQL语句的参数。
QueryRowX(querySql string, args ...interface{}) *sqlx.Row
// Take 查询一条记录,并将结果扫描进 <pointer>
// 参数<pointer>可以是struct/*struct。
// 参数<querySql>是查询类的SQL语句,<args>是SQL语句的参数。
Take(pointer interface{}, querySql string, args ...interface{}) error
// Select 查询多条记录,并将查询结果扫描进 <pointer>
// 参数<pointer>可以是[]struct/*[]struct。
// 参数<querySql>是查询类的SQL语句,<args>是SQL语句的参数。
Select(pointer interface{}, querySql string, args ...interface{}) error
// Exec 增、删、改操作的接口。
// 参数<querySql>是查询类的SQL语句,<args>是SQL语句的参数。
Exec(querySql string, args ...interface{}) (sql.Result, error)
PrepareX(query string) (*sqlx.Stmt, error)
// Rebind 绑定 querySql 语句。
Rebind(querySql string) string
//
ExecN(query string, arg interface{}) (sql.Result, error)
// DriverName 获取当前处于连接状态的 *sql.DB 的驱动名称。
DriverName() string
}
type BuildFunc func(b *Builder)
type DB struct {
database *sqlx.DB // *sqlx.DB 对象,该对象具备操作数据库的能力
tx *sqlx.Tx // *sqlx.Tx 对象,该对象是以事务的方式操作数据库。
logging bool // 为true:表示打印日志。
RelationMap map[string]BuildFunc // // 用来构造一个 Builder 对象的map容器。
}
// db 内部方法,返回数据库实例,如果是事务,则事务优先级较高
func (w *DB) db() ISql {
if w.tx != nil {
return w.tx.Unsafe()
}
return w.database.Unsafe()
}
// ShowSql single show sql log
func ShowSql() *DB {
w := Use(defaultLink)
w.logging = true
return w
}
func (w *DB) argsIn(query string, args []interface{}) (string, []interface{}, error) {
newArgs := make([]interface{}, 0)
newQuery, newArgs, err := sqlx.In(query, args...)
if err != nil {
return query, args, err
}
return newQuery, newArgs, nil
}
//DriverName wrapper sqlx.DriverName
func (w *DB) DriverName() string {
if w.tx != nil {
return w.tx.DriverName()
}
return w.database.DriverName()
}
//Begin begins a transaction and returns an *gosql.DB instead of an *sql.Tx.
func (w *DB) Begin() (*DB, error) {
tx, err := w.database.BeginX()
if err != nil {
return nil, err
}
return &DB{tx: tx}, nil
}
// Commit commits the transaction.
func (w *DB) Commit() error {
return w.tx.Commit()
}
// Rollback aborts the transaction.
func (w *DB) Rollback() error {
return w.tx.Rollback()
}
//Rebind wrapper sqlx.Rebind
func (w *DB) Rebind(query string) string {
return w.db().Rebind(query)
}
func (w *DB) PrepareX(query string) (*sqlx.Stmt, error) {
return w.db().PrepareX(query)
}
//Exec 执行: 增、删、改操作。
// 参数 <sqlStr> 是具体的SQL语句;参数 <args> 是SQL语句需要的参数。
func (w *DB) Exec(sqlStr string, args ...interface{}) (result sql.Result, err error) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Sql: sqlStr,
Args: args,
Err: err,
Start: start,
End: time.Now(),
}, w.logging)
}(time.Now())
return w.db().Exec(sqlStr, args...)
}
// ExecN wrapper *sqlx.ExecN
func (w *DB) ExecN(query string, args interface{}) (result sql.Result, err error) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Sql: query,
Args: args,
Err: err,
Start: start,
End: time.Now(),
}, w.logging)
}(time.Now())
return w.db().ExecN(query, args)
}
//QueryX 查询多条记录。
// 参数 <querySql> 是具体的SQL语句;参数 <args> 是SQL语句需要的参数。
// 注意: 返回的是 *sqlx.Rows 而不是 *sql.Rows
func (w *DB) QueryX(querySql string, args ...interface{}) (rows *sqlx.Rows, err error) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Sql: querySql,
Args: args,
Err: err,
Start: start,
End: time.Now(),
}, w.logging)
}(time.Now())
query, newArgs, err := w.argsIn(querySql, args)
if err != nil {
return nil, err
}
return w.db().QueryX(query, newArgs...)
}
//QueryRowX 查询一条记录。
// 参数 <querySql> 是具体的SQL语句;参数 <args> 是SQL语句需要的参数。
// 注意: 返回的是 *sqlx.Row 而不是 *sql.Row
func (w *DB) QueryRowX(querySql string, args ...interface{}) (rows *sqlx.Row) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Sql: querySql,
Args: args,
Err: rows.Err(),
Start: start,
End: time.Now(),
}, w.logging)
}(time.Now())
query, newArgs, _ := w.argsIn(querySql, args)
return w.db().QueryRowX(query, newArgs...)
}
//Take 快捷方式,查询一条记录,并将查询结果扫描进 <pointer>。
// 参数 <querySql> 是具体的SQL语句;参数 <args> 是SQL语句需要的参数。
func (w *DB) Take(pointer interface{}, querySql string, args ...interface{}) (err error) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Sql: querySql,
Args: args,
Err: err,
Start: start,
End: time.Now(),
}, w.logging)
}(time.Now())
wrapper, ok := pointer.(*ModelWrapper)
if ok {
pointer = wrapper.model
}
hook := NewHook(nil, w)
refVal := reflect.ValueOf(pointer)
hook.callMethod("BeforeFind", refVal)
query, newArgs, err := w.argsIn(querySql, args)
if err != nil {
return err
}
err = w.db().Take(pointer, query, newArgs...)
if err != nil {
return err
}
if reflect.Indirect(refVal).Kind() == reflect.Struct {
// relation data fill
err = RelationOne(wrapper, w, pointer)
}
if err != nil {
return err
}
hook.callMethod("AfterFind", refVal)
if hook.HasError() {
return hook.Error()
}
return nil
}
//Select 快捷方式,查询多条记录,并将查询结果扫描进 <pointer>。
// 参数 <querySql> 是具体的SQL语句;参数 <args> 是SQL语句需要的参数。
func (w *DB) Select(pointer interface{}, querySql string, args ...interface{}) (err error) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Sql: querySql,
Args: args,
Err: err,
Start: start,
End: time.Now(),
}, w.logging)
}(time.Now())
query, newArgs, err := w.argsIn(querySql, args)
if err != nil {
return err
}
wrapper, ok := pointer.(*ModelWrapper)
if ok {
pointer = wrapper.model
}
err = w.db().Select(pointer, query, newArgs...)
if err != nil {
return err
}
t := indirectType(reflect.TypeOf(pointer))
if t.Kind() == reflect.Slice {
if indirectType(t.Elem()).Kind() == reflect.Struct {
// relation data fill
err = RelationAll(wrapper, w, pointer)
}
}
if err != nil {
return err
}
return nil
}
// TxCtx the transaction with context
func (w *DB) TxCtx(ctx context.Context, fn func(ctx context.Context, tx *DB) error) (err error) {
tx, err := w.database.BeginTxx(ctx, nil)
if err != nil {
return err
}
defer func() {
if err != nil {
err1 := tx.Rollback()
if err1 != nil {
log.Printf("gosql rollback error:%s", err1)
}
}
}()
err = fn(ctx, &DB{tx: tx})
if err == nil {
err = tx.Commit()
}
return
}
// Tx the transaction
func (w *DB) Tx(fn func(w *DB) error) (err error) {
tx, err := w.database.BeginX()
if err != nil {
return err
}
defer func() {
if err != nil {
err1 := tx.Rollback()
if err1 != nil {
log.Printf("gosql rollback error:%s", err1)
}
}
}()
err = fn(&DB{tx: tx})
if err == nil {
err = tx.Commit()
}
return
}
// Table database handler from to table name
// for example:
// sq.Use("db2").Table("users")
func (w *DB) Table(tableName string) *Mapper {
return &Mapper{db: w, SqlBuilder: SqlBuilder{table: tableName, dialect: newDialect(w.DriverName())}}
}
//Model database handler from to struct
// for example:
// sq.Use("db2").Model(&users{})
func (w *DB) Model(pointer interface{}) *Builder {
if v1, ok := pointer.(*ModelWrapper); ok {
return &Builder{wrapper: v1, model: v1.model, db: w, SqlBuilder: SqlBuilder{dialect: newDialect(w.DriverName())}}
} else {
return &Builder{model: pointer, db: w, SqlBuilder: SqlBuilder{dialect: newDialect(w.DriverName())}}
}
}
// Model database handler from to struct with context
// for example:
// sq.Use("db2").Ctx(ctx).Model(&users{})
func (w *DB) Ctx(ctx context.Context) *Builder {
return &Builder{db: w, SqlBuilder: SqlBuilder{dialect: newDialect(w.DriverName())}, ctx: ctx}
}
//Import sql文件中的SQL DDL
func (w *DB) Import(fileName string) ([]sql.Result, error) {
file, err := os.Open(fileName)
if err != nil {
return nil, err
}
defer file.Close()
var results []sql.Result
scanner := bufio.NewScanner(file)
semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, ';'); i >= 0 {
return i + 1, data[0:i], nil
}
// If we're at EOF, we have a final, non-terminated line. Return it.
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
scanner.Split(semiColSpliter)
for scanner.Scan() {
query := strings.Trim(scanner.Text(), " \t\n\r")
if len(query) > 0 {
result, err1 := w.db().Exec(query)
results = append(results, result)
if err1 != nil {
return nil, err1
}
}
}
return results, nil
}
// Relation 关联表构建器句柄。参数 name 与 BuildFunc 一一对应。
func (w *DB) Relation(name string, fn BuildFunc) *DB {
if w.RelationMap == nil {
w.RelationMap = make(map[string]BuildFunc)
}
w.RelationMap[name] = fn
return w
}
//Begin 开始默认数据库的事务,并返回 *sq.DB而不是 *sql.Tx。
func Begin() (*DB, error) {
return Use(defaultLink).Begin()
}
//Use 变更数据库。
func Use(db string) *DB {
return &DB{database: GetDB(db)}
}
//Exec default database
func Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
return Use(defaultLink).Exec(sqlStr, args...)
}
//Exec default database
func ExecN(sqlStr string, args ...interface{}) (sql.Result, error) {
return Use(defaultLink).ExecN(sqlStr, args)
}
//Query default database
func QueryX(querySql string, args ...interface{}) (*sqlx.Rows, error) {
return Use(defaultLink).QueryX(querySql, args...)
}
//QueryRow default database
func QueryRowX(querySql string, args ...interface{}) *sqlx.Row {
return Use(defaultLink).QueryRowX(querySql, args...)
}
//TxCtx default database the transaction with context
func TxCtx(ctx context.Context, fn func(ctx context.Context, tx *DB) error) error {
return Use(defaultLink).TxCtx(ctx, fn)
}
//Tx default database the transaction
func Tx(fn func(tx *DB) error) error {
return Use(defaultLink).Tx(fn)
}
//Get default database
func Take(pointer interface{}, querySql string, args ...interface{}) error {
return Use(defaultLink).Take(pointer, querySql, args...)
}
//Select default database
func Select(pointer interface{}, querySql string, args ...interface{}) error {
return Use(defaultLink).Select(pointer, querySql, args...)
}
// Import SQL DDL from io.Reader
func Import(f string) ([]sql.Result, error) {
return Use(defaultLink).Import(f)
}
// Relation association table builder handle
func Relation(name string, fn BuildFunc) *DB {
w := Use(defaultLink)
w.RelationMap = make(map[string]BuildFunc)
w.RelationMap[name] = fn
return w
}
// SetDefaultLink set default link name
func SetDefaultLink(db string) {
defaultLink = db
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/winie_admin/sq.git
git@gitee.com:winie_admin/sq.git
winie_admin
sq
sq
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385