diff --git a/attestation/ras/clientapi/clientapi.go b/attestation/ras/clientapi/clientapi.go index 7539badc9c2ef961f88ad297adcc9d990af08b8f..3251b41c43c2377875ed01dabc8cede4f7e3dba5 100644 --- a/attestation/ras/clientapi/clientapi.go +++ b/attestation/ras/clientapi/clientapi.go @@ -82,13 +82,11 @@ func (s *service) RegisterClient(ctx context.Context, in *RegisterClientRequest) s.Unlock() // get client config - c, err := config.CreateConfig("") - if err != nil { - return nil, err - } + c := config.GetDefault() hd := c.GetHBDuration() td := c.GetTrustDuration() + config.Save() return &RegisterClientReply{ ClientId: clientID, ClientConfig: &ClientConfig{ diff --git a/attestation/ras/config/config.go b/attestation/ras/config/config.go index 1ee15150a430cc36958189d57ba2e3edf09b7b9d..993d3f00449dd93d66ad9173e6a1d5b549063132 100644 --- a/attestation/ras/config/config.go +++ b/attestation/ras/config/config.go @@ -1,5 +1,5 @@ /* -Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2021. kunpengsecl licensed under the Mulan PSL v2. You can use this software according to the terms and conditions of the Mulan PSL v2. You may obtain a copy of Mulan PSL v2 at: @@ -17,148 +17,188 @@ Description: Store RAS and RAC configurations. package config import ( - "errors" - "fmt" - "github.com/spf13/viper" "time" + + "github.com/spf13/viper" ) var defaultConfigPath = []string{ - "./", + ".", + "./ras/config", + "$HOME/.config/attestation", + "/usr/lib/attestation", + "/etc/attestation", } -type Config struct { - path string - rasConfig RASConfig - racConfig RACConfig -} type ( - RASConfig struct { + dbConfig struct { + host string + dbName string + user string + password string + port int + } + rasConfig struct { mgrStrategy string changeTime time.Time } - - RACConfig struct { + racConfig struct { hbDuration time.Duration // heartbeat duration trustDuration time.Duration // trust state duration } + config struct { + dbConfig + rasConfig + racConfig + } ) -var config *Config +var cfg *config /* - CreateConfig creates Config object if it was not initialized. - The path can't be empty when CreateConfig is called first time. Normally it is called when app initializes. - The path can't be modified when app running. - If the path is empty, CreateConfig return current Config object. So after called for the first time, - the parameter is always "". + GetDefault returns the global default config object. + It searches the defaultConfigPath to find the first matched config.yaml. + if it doesn't find any one, it returns the default values by code. */ -func CreateConfig(path string) (*Config, error) { - if config != nil && path == "" { - return config, nil - } - if config != nil && path != "" { - return config, errors.New("the parameter must be empty because config has been initialized") - } - if config == nil && path == "" { - return config, errors.New("the parameter can't be empty for initializing config") +func GetDefault() *config { + if cfg != nil { + return cfg } + viper.SetConfigName("config") viper.SetConfigType("yaml") - viper.AddConfigPath(path) - err := viper.ReadInConfig() - if err != nil { - fmt.Println("failed to read config file") - return nil, err - } - hbDuration := viper.GetDuration("racConfig.hbDuration") - trustDuration := viper.GetDuration("racConfig.trustDuration") - mgrStrategy := viper.GetString("rasConfig.mgrStrategy") - changeTime := viper.GetTime("rasConfig.changeTime") - c := &Config{ - path: path, - rasConfig: RASConfig{ - mgrStrategy: mgrStrategy, - changeTime: changeTime, - }, - racConfig: RACConfig{ - hbDuration: hbDuration, - trustDuration: trustDuration, - }, + for _, s := range defaultConfigPath { + viper.AddConfigPath(s) } - return c, nil -} -// GetDefault returns the global default config object. -func GetDefault() *Config { - var config *Config - for _, p := range defaultConfigPath { - config, _ = CreateConfig(p) + err := viper.ReadInConfig() + if err == nil { + cfg = &config{ + dbConfig: dbConfig{ + host: viper.GetString("database.host"), + dbName: viper.GetString("database.dbname"), + user: viper.GetString("database.user"), + password: viper.GetString("database.password"), + port: viper.GetInt("database.port"), + }, + rasConfig: rasConfig{ + mgrStrategy: viper.GetString("rasConfig.mgrStrategy"), + changeTime: viper.GetTime("rasConfig.changeTime"), + }, + racConfig: racConfig{ + hbDuration: viper.GetDuration("racConfig.hbDuration"), + trustDuration: viper.GetDuration("racConfig.trustDuration"), + }, + } + return cfg } - if config == nil { - config = &Config{ - rasConfig: RASConfig{ + + if cfg == nil { + cfg = &config{ + dbConfig: dbConfig{ + host: "localhost", + dbName: "test", + user: "", + password: "", + port: 5432, + }, + rasConfig: rasConfig{ mgrStrategy: "auto", changeTime: time.Now(), }, - racConfig: RACConfig{ + racConfig: racConfig{ hbDuration: 10 * time.Second, trustDuration: 120 * time.Second, }, } } - return config + return cfg } -func (c *Config) GetHBDuration() time.Duration { - return c.racConfig.hbDuration +// Save saves all config variables to the config.yaml file. +func Save() { + if cfg != nil { + viper.Set("database.host", cfg.host) + viper.Set("database.dbname", cfg.dbName) + viper.Set("database.user", cfg.user) + viper.Set("database.password", cfg.password) + viper.Set("database.port", cfg.port) + viper.Set("racConfig.hbDuration", cfg.hbDuration) + viper.Set("racConfig.trustDuration", cfg.trustDuration) + viper.Set("rasConfig.mgrStrategy", cfg.mgrStrategy) + viper.Set("rasConfig.changeTime", cfg.changeTime) + err := viper.WriteConfig() + if err != nil { + _ = viper.SafeWriteConfig() + } + } } -/* - SetHBDuration just set hbDuration for now, it can't change the config file. - If you want to change the config file, please use ChangeConfig. -*/ -func (c *Config) SetHBDuration(d time.Duration) { - c.racConfig.hbDuration = d +func (c *config) GetHost() string { + return c.host } -/* - ChangeConfig just change the config file, it can't set config value for now. - If you want to set current config value, please use SetXXX function. -*/ -func (c *Config) ChangeConfig(hbDuration time.Duration, trustDuration time.Duration, mgrStrategy string) error { - viper.Set("racConfig.hbDuration", hbDuration) - viper.Set("racConfig.trustDuration", trustDuration) - viper.Set("rasConfig.mgrStrategy", mgrStrategy) - viper.Set("rasConfig.changeTime", time.Now()) - err := viper.WriteConfig() - if err != nil { - return err - } - return nil +func (c *config) SetHost(host string) { + c.host = host +} + +func (c *config) GetDBName() string { + return c.dbName +} + +func (c *config) SetDBName(dbName string) { + c.dbName = dbName +} + +func (c *config) GetUser() string { + return c.user +} + +func (c *config) SetUser(user string) { + c.user = user +} + +func (c *config) GetPassword() string { + return c.password +} + +func (c *config) SetPassword(password string) { + c.password = password +} + +func (c *config) GetPort() int { + return c.port +} + +func (c *config) SetPort(port int) { + c.port = port +} + +func (c *config) GetHBDuration() time.Duration { + return c.hbDuration } -func (c *Config) GetTrustDuration() time.Duration { - return c.racConfig.trustDuration +func (c *config) SetHBDuration(d time.Duration) { + c.hbDuration = d } -func (c *Config) SetTrustDuration(d time.Duration) { - c.racConfig.trustDuration = d +func (c *config) GetTrustDuration() time.Duration { + return c.trustDuration } -func (c *Config) GetMgrStrategy() string { - return c.rasConfig.mgrStrategy +func (c *config) SetTrustDuration(d time.Duration) { + c.trustDuration = d } -func (c *Config) SetMgrStrategy(s string) { - c.rasConfig.mgrStrategy = s - c.rasConfig.changeTime = time.Now() +func (c *config) GetMgrStrategy() string { + return c.mgrStrategy } -func (c *Config) GetChangeTime() time.Time { - return c.rasConfig.changeTime +func (c *config) SetMgrStrategy(s string) { + c.mgrStrategy = s + c.changeTime = time.Now() } -func (c *Config) SetChangeTime(n time.Time) { - c.rasConfig.changeTime = n +func (c *config) GetChangeTime() time.Time { + return c.changeTime } diff --git a/attestation/ras/config/config.yaml b/attestation/ras/config/config.yaml index b97ff302d779bd9a9f50ed44b0e5ba737ea7a52d..58b0908c62c50b7af62c664ab539c8b7da1a92f0 100644 --- a/attestation/ras/config/config.yaml +++ b/attestation/ras/config/config.yaml @@ -1,11 +1,11 @@ database: dbname: kunpengsecl host: localhost - password: + password: "" port: 5432 - user: + user: "" racconfig: - hbduration: 2s + hbduration: 3s trustduration: 2m0s rasconfig: changetime: 2021-09-30T11:53:24.0581136+08:00 diff --git a/attestation/ras/config/config_test.go b/attestation/ras/config/config_test.go index f24d224adf6de196c56f6517bab60d2e04333ded..d223e3e4c303113506979b839bb6b62670a130ef 100644 --- a/attestation/ras/config/config_test.go +++ b/attestation/ras/config/config_test.go @@ -22,22 +22,6 @@ func TestRASConfig(t *testing.T) { t.Errorf("test mgrStrategy error at case %d\n", i) } } - - now := time.Now() - hLate := now.Add(time.Hour * 12) - testCases2 := []struct { - input time.Time - result time.Time - }{ - {now, now}, - {hLate, hLate}, - } - for i := 0; i < len(testCases2); i++ { - config.SetChangeTime(testCases2[i].input) - if config.GetChangeTime() != testCases2[i].result { - t.Errorf("test changeTime error at case %d\n", i) - } - } } func TestRACConfig(t *testing.T) { diff --git a/attestation/ras/dao/postgresqldao.go b/attestation/ras/dao/postgresqldao.go index cd50b97a1e5bf911936de9677abe9f39ba3d34df..a697ce63dee7dbd91649e099de9184071b8dc111 100644 --- a/attestation/ras/dao/postgresqldao.go +++ b/attestation/ras/dao/postgresqldao.go @@ -3,23 +3,24 @@ package dao import ( "context" "fmt" + "time" + + "gitee.com/openeuler/kunpengsecl/attestation/ras/config" "gitee.com/openeuler/kunpengsecl/attestation/ras/entity" "github.com/jackc/pgx/v4" - "github.com/spf13/viper" - "os" - "time" ) /* PostgreSqlDAO implements dao. conn is a connection initialized in CreatePostgreSqlDAO() and destroyed in Destroy() - */ +*/ type PostgreSqlDAO struct { conn *pgx.Conn } + /* SaveReport use conn to execute a transaction for insert data into tables - */ +*/ func (psd *PostgreSqlDAO) SaveReport(report *entity.Report) error { var reportId int64 tx, err := psd.conn.Begin(context.Background()) @@ -98,7 +99,7 @@ func (psd *PostgreSqlDAO) SaveReport(report *entity.Report) error { /* RegisterClient start a transaction, first insert clientInfo into table client_info, get the client_info_id and then insert data into table register_client. - */ +*/ func (psd *PostgreSqlDAO) RegisterClient(clientInfo *entity.ClientInfo, ic string) (int64, error) { var clientInfoId int64 var clientId int64 @@ -112,7 +113,7 @@ func (psd *PostgreSqlDAO) RegisterClient(clientInfo *entity.ClientInfo, ic strin Using serial to generate client_info_id is impossible because clientInfo is a map, several rows have the same client_info_id. Here we use table client_info_id to record client_info_id value. - */ + */ err = tx.QueryRow(context.Background(), "INSERT INTO client_info_id(online) VALUES ($1) RETURNING id", true).Scan(&clientInfoId) if err != nil { @@ -133,8 +134,8 @@ func (psd *PostgreSqlDAO) RegisterClient(clientInfo *entity.ClientInfo, ic strin // Insert data into register_client err = tx.QueryRow(context.Background(), - "INSERT INTO register_client(client_info_id, register_time, ak_certificate) " + - "VALUES ($1, $2, $3) RETURNING client_id", clientInfoId, time.Now(), ic).Scan(&clientId) + "INSERT INTO register_client(client_info_id, register_time, ak_certificate) "+ + "VALUES ($1, $2, $3) RETURNING client_id", clientInfoId, time.Now(), ic).Scan(&clientId) if err != nil { tx.Rollback(context.Background()) return 0, err @@ -148,7 +149,7 @@ func (psd *PostgreSqlDAO) RegisterClient(clientInfo *entity.ClientInfo, ic strin /* UnRegisterClient delete client data from table client_info, client_info_id and register_client. - */ +*/ func (psd *PostgreSqlDAO) UnRegisterClient(clientId int64) error { var clientInfoId int64 tx, err := psd.conn.Begin(context.Background()) @@ -184,7 +185,7 @@ func (psd *PostgreSqlDAO) UnRegisterClient(clientId int64) error { return nil } -func (psd *PostgreSqlDAO) SelectAllClientIds() ([]int64, error){ +func (psd *PostgreSqlDAO) SelectAllClientIds() ([]int64, error) { var clientIds []int64 rows, err := psd.conn.Query(context.Background(), "SELECT client_id FROM register_client") @@ -202,42 +203,20 @@ func (psd *PostgreSqlDAO) SelectAllClientIds() ([]int64, error){ return clientIds, nil } -func (psd *PostgreSqlDAO) init() error { - // use viper to read config for connecting database - path, err:= os.Getwd() - if err != nil { - return err - } - path = path + "\\..\\config" - viper.SetConfigName("config") - viper.SetConfigType("yaml") - viper.AddConfigPath(path) - err = viper.ReadInConfig() - if err != nil { - fmt.Println("读取配置文件失败") - return err - } - host := viper.GetString("database.host") - port := viper.GetString("database.port") - dbname := viper.GetString("database.dbname") - user := viper.GetString("database.user") - password := viper.GetString("database.password") - url := "postgres://"+user+":"+password+"@"+host+":"+port+"/"+dbname - psd.conn, err = pgx.Connect(context.Background(), url) + +// CreatePostgreSqlDAO creates a postgre database connection to read and store data. +func CreatePostgreSqlDAO() (*PostgreSqlDAO, error) { + host := config.GetDefault().GetHost() + port := config.GetDefault().GetPort() + dbname := config.GetDefault().GetDBName() + user := config.GetDefault().GetUser() + password := config.GetDefault().GetPassword() + url := fmt.Sprintf("postgres://%s:%s@%s:%d/%s", user, password, host, port, dbname) + c, err := pgx.Connect(context.Background(), url) if err != nil { - fmt.Println("数据库连接失败") - return err + return nil, err } - return nil -} - -/* - use factory mode to get a pointer of PostgreSqlDAO - */ -func CreatePostgreSqlDAO() *PostgreSqlDAO { - psd := new(PostgreSqlDAO) - psd.init() - return psd + return &PostgreSqlDAO{conn: c}, nil } func (psd *PostgreSqlDAO) Destroy() { diff --git a/attestation/ras/dao/postgresqldao_test.go b/attestation/ras/dao/postgresqldao_test.go index c36eb8f6168b0ef2f9adb14ba285fd0b7e50b2e9..62f3623d6c20b11d54c867aba1f7b404f02ffb4c 100644 --- a/attestation/ras/dao/postgresqldao_test.go +++ b/attestation/ras/dao/postgresqldao_test.go @@ -2,17 +2,21 @@ package dao import ( "fmt" + "testing" + "gitee.com/openeuler/kunpengsecl/attestation/ras/entity" "github.com/stretchr/testify/assert" - "testing" ) func TestPostgreSqlDAOSaveReport(t *testing.T) { - psd := CreatePostgreSqlDAO() - + psd, err := CreatePostgreSqlDAO() + if err != nil { + t.Fatalf("%v", err) + return + } pcrInfo := entity.PcrInfo{ Algorithm: 1, - Values: []entity.PcrValue{ + Values: []entity.PcrValue{ 1: { Id: 1, Value: "pcr value 1", @@ -22,7 +26,7 @@ func TestPostgreSqlDAOSaveReport(t *testing.T) { Value: "pcr value 2", }, }, - Quote: []byte("test quote"), + Quote: []byte("test quote"), } biosItem1 := entity.ManifestItem{ @@ -44,7 +48,7 @@ func TestPostgreSqlDAOSaveReport(t *testing.T) { } biosManifest := entity.Manifest{ - Type: "bios", + Type: "bios", Items: []entity.ManifestItem{ 1: biosItem1, 2: biosItem2, @@ -52,23 +56,23 @@ func TestPostgreSqlDAOSaveReport(t *testing.T) { } imaManifest := entity.Manifest{ - Type: "ima", + Type: "ima", Items: []entity.ManifestItem{ 1: imaItem1, }, } testReport := &entity.Report{ - PcrInfo: pcrInfo, - Manifest: []entity.Manifest{ + PcrInfo: pcrInfo, + Manifest: []entity.Manifest{ 1: biosManifest, 2: imaManifest, }, - ClientId: 1, + ClientId: 1, ClientInfo: entity.ClientInfo{ Info: map[string]string{ - "client_name": "test_client", - "client_type": "test_type", + "client_name": "test_client", + "client_type": "test_type", "client_description": "test description", }, }, @@ -83,22 +87,30 @@ func TestPostgreSqlDAOSaveReport(t *testing.T) { } func TestRegisterClient(t *testing.T) { - psd := CreatePostgreSqlDAO() + psd, err := CreatePostgreSqlDAO() + if err != nil { + t.Fatalf("%v", err) + return + } ci := &entity.ClientInfo{ Info: map[string]string{ - "info name1" : "info value1", - "info name2" : "info value2", + "info name1": "info value1", + "info name2": "info value2", }, } ic := "test ic" - _, err := psd.RegisterClient(ci, ic) - if err != nil { + _, err2 := psd.RegisterClient(ci, ic) + if err2 != nil { t.FailNow() } } func TestUnRegisterClient(t *testing.T) { - psd := CreatePostgreSqlDAO() + psd, err := CreatePostgreSqlDAO() + if err != nil { + t.Fatalf("%v", err) + return + } clientIds, err := psd.SelectAllClientIds() if err != nil { fmt.Println(err) @@ -115,4 +127,4 @@ func TestUnRegisterClient(t *testing.T) { t.FailNow() } assert.NotEqual(t, clientIds[0], newClientIds[0]) -} \ No newline at end of file +}