diff --git a/.github/workflows/go-check.yml b/.github/workflows/go-check.yml new file mode 100644 index 0000000..0a4f17f --- /dev/null +++ b/.github/workflows/go-check.yml @@ -0,0 +1,45 @@ +name: Go Test cmd/checks + +on: + push: + branches: [ "*" ] + paths: + - cmd/checks/** + - internal/** + pull_request: + branches: [ "main" ] + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + - name: Test + run: go test -v ./cmd/check/... + + go_sec: + runs-on: ubuntu-latest + steps: + - name: Checkout Source + uses: actions/checkout@v3 + - name: Run Gosec Security Scanner + uses: securego/gosec@master + with: + args: ./cmd/check/... + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + + - name: Build + run: go build -v ./cmd/check/... diff --git a/.github/workflows/go-dns.yml b/.github/workflows/go-dns.yml new file mode 100644 index 0000000..63ee950 --- /dev/null +++ b/.github/workflows/go-dns.yml @@ -0,0 +1,45 @@ +name: Go Test cmd/dns-sniffer + +on: + push: + branches: [ "*" ] + paths: + - cmd/dns-sniffer/** + - internal/** + pull_request: + branches: [ "main" ] + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + - name: Test + run: go test -v ./cmd/dns-sniffer/... + + go_sec: + runs-on: ubuntu-latest + steps: + - name: Checkout Source + uses: actions/checkout@v3 + - name: Run Gosec Security Scanner + uses: securego/gosec@master + with: + args: ./cmd/dns-sniffer/... + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + + - name: Build + run: go build -v ./cmd/dns-sniffer/... diff --git a/.github/workflows/go-dpi.yml b/.github/workflows/go-dpi.yml new file mode 100644 index 0000000..5279585 --- /dev/null +++ b/.github/workflows/go-dpi.yml @@ -0,0 +1,45 @@ +name: Go Test cmd/dpi-sniffer + +on: + push: + branches: [ "*" ] + paths: + - cmd/dpi-sniffer/** + - internal/** + pull_request: + branches: [ "main" ] + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + - name: Test + run: go test -v ./cmd/dpi-sniffer/... + + go_sec: + runs-on: ubuntu-latest + steps: + - name: Checkout Source + uses: actions/checkout@v3 + - name: Run Gosec Security Scanner + uses: securego/gosec@master + with: + args: ./cmd/dpi-sniffer/... + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + + - name: Build + run: go build -v ./cmd/dpi-sniffer/... diff --git a/.github/workflows/go-get-rkn.yml b/.github/workflows/go-get-rkn.yml new file mode 100644 index 0000000..46e2f32 --- /dev/null +++ b/.github/workflows/go-get-rkn.yml @@ -0,0 +1,45 @@ +name: Go Test cmd/get_rkn + +on: + push: + branches: [ "*" ] + paths: + - cmd/get_rkn/** + - internal/** + pull_request: + branches: [ "main" ] + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + - name: Test + run: go test -v ./cmd/get_rkn/... + + go_sec: + runs-on: ubuntu-latest + steps: + - name: Checkout Source + uses: actions/checkout@v3 + - name: Run Gosec Security Scanner + uses: securego/gosec@master + with: + args: ./cmd/get_rkn/... + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + + - name: Build + run: go build -v ./cmd/get_rkn/... diff --git a/README.md b/README.md index ae52070..4673a90 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,15 @@ -## DNS/DPI sniffers and NFT-tables rules +# DNS/DPI sniffers and NFT-tables rules +[![Go Test cmd/checks](https://github.com/sir-go/rkn-rejects/actions/workflows/go-check.yml/badge.svg)](https://github.com/sir-go/rkn-rejects/actions/workflows/go-check.yml) +[![Go Test cmd/dns](https://github.com/sir-go/rkn-rejects/actions/workflows/go-dns.yml/badge.svg)](https://github.com/sir-go/rkn-rejects/actions/workflows/go-dns.yml) +[![Go Test cmd/dpi](https://github.com/sir-go/rkn-rejects/actions/workflows/go-dpi.yml/badge.svg)](https://github.com/sir-go/rkn-rejects/actions/workflows/go-dpi.yml) +[![Go Test cmd/get-rkn](https://github.com/sir-go/rkn-rejects/actions/workflows/go-get-rkn.yml/badge.svg)](https://github.com/sir-go/rkn-rejects/actions/workflows/go-get-rkn.yml) + The parental control project contains four utilities to get white and black lists from RKN service and completely isolate one certain host from denied resources. Utilities installed at the router between the host and uplink. -### Utilities - +## Utilities - [get_rkn](cmd/get_rkn) SOAP-client for [RKN service](https://vigruzki.rkn.gov.ru/services/OperatorRequest/?wsdl), @@ -28,8 +32,7 @@ the host and uplink. It can be run from the docker container for routing all traffic through the router's firewall. -### NF-tables - +## NF-tables Traffic to sniffers redirects by the `nf_queue` kernel module. All traffic rejects by default except DNS requests and answers. NF-tables rules have a list of allowed IP addresses. Every record in the list has a TTL and deletes when this time is expired. diff --git a/cmd/check/README.md b/cmd/check/README.md index d9c10f6..a42f5b4 100644 --- a/cmd/check/README.md +++ b/cmd/check/README.md @@ -1,7 +1,7 @@ -## Check - -### What it does +# Check +[![Go](https://github.com/sir-go/rkn-rejects/actions/workflows/go-check.yml/badge.svg)](https://github.com/sir-go/rkn-rejects/actions/workflows/go-check.yml) +## What it does - get targets for checking from the redis set - run bunch of workers - wait for all workers are done @@ -9,28 +9,16 @@ Each worker is an HTTP client that tries to get data from the target resource and log the result to the logfile if the target is accessible. -### Build +## Tests ```bash -go mod download -go build -o check ./cmd/check; +go test -v ./cmd/check/... +gosec ./cmd/check/... ``` -If the check will run on the same host that the firewall does, -it should run from the docker container. - +## Docker ```bash docker build . -t check -``` - -### Run - -Standalone -```bash -check -w 25 t 20s -lt 10s -d /tmp/checks -``` -Docker -```bash docker run -it --rm \ -v /tmp/checks:/var/log/checks \ --dns 195.208.4.1 \ @@ -39,6 +27,19 @@ docker run -it --rm \ -d /var/log/checks ``` +## Build +```bash +go mod download +go build -o check ./cmd/check; +``` +If the check will run on the same host that the firewall does, +it should run from the docker container. + +## Run +```bash +check -w 25 t 20s -lt 10s -d /tmp/checks +``` + ### Flags | key | default | description | @@ -56,5 +57,4 @@ docker run -it --rm \ | -m | -1(inf) | checks amount limit | | -t | 3s | check TCP timeout | | -lt | 10s | log polling interval | -| -o | stdout | buffered log output | | -d | /tmp | path to the logs for each check | diff --git a/cmd/check/checker.go b/cmd/check/checker.go index c6417ec..0800860 100644 --- a/cmd/check/checker.go +++ b/cmd/check/checker.go @@ -1,5 +1,7 @@ package main +// Checking target worker. Gets target from the channel, makes a request and issues a verdict + import ( "io/ioutil" "net/http" @@ -12,6 +14,7 @@ import ( ) type ( + // is resource accessible verdict struct { opened bool hash string @@ -20,15 +23,18 @@ type ( } ) +// uncommented rows in the list var reTarget = regexp.MustCompile(`^[^#]((.*)\|)?(.*://)?(.*)`) +// dump saves a verdict to a file in the specified directory func (v *verdict) dump(vDir string) { - err := ioutil.WriteFile(path.Join(vDir, v.hash), v.raw, 0666) + err := ioutil.WriteFile(path.Join(vDir, v.hash), v.raw, 0600) if err != nil { log.Panicln("dump verdict", err) } } +// check does check the target address accessibility and returns a verdict struct func check(target string, timeout time.Duration) (v verdict) { var ( err error @@ -60,8 +66,12 @@ func check(target string, timeout time.Duration) (v verdict) { return } -func Checker(wg *sync.WaitGroup, timeout time.Duration, targets <-chan string, - verdicts chan<- verdict) { +// Checker starts a checking process, reads a target from the targets channel +//and stores verdicts to the verdicts channel +func Checker(wg *sync.WaitGroup, timeout time.Duration, targets <-chan string, verdicts chan<- verdict) { + if wg == nil { + return + } for t := range targets { verdicts <- check(t, timeout) time.Sleep(CFG.Sleeps) diff --git a/cmd/check/checker_test.go b/cmd/check/checker_test.go new file mode 100644 index 0000000..515ed0e --- /dev/null +++ b/cmd/check/checker_test.go @@ -0,0 +1,93 @@ +package main + +import ( + "bytes" + "io/ioutil" + "os" + "path/filepath" + "testing" + "time" +) + +// checks if dump file is created and dumped data equals for verdict's raw data +func Test_verdict_dump(t *testing.T) { + type args struct { + vDir string + } + tests := []struct { + name string + args args + verdict verdict + }{ + {"e2e", + args{"_verdicts_dump_test"}, + verdict{ + true, + "some-verdict-hash", + "some-target", + []byte("some response content")}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ( + tmpDirName string + err error + ) + if tmpDirName, err = os.MkdirTemp("", tt.args.vDir); err != nil { + t.Errorf("can't make a directory %s. %v", tt.args.vDir, err) + } + tt.verdict.dump(tmpDirName) + dumpFile := filepath.Clean(filepath.Join(tmpDirName, tt.verdict.hash)) + verdictBytes, err := ioutil.ReadFile(dumpFile) + if err != nil { + t.Errorf("can't read the dump file %s, %v", dumpFile, err) + } + if !bytes.Equal(verdictBytes, tt.verdict.raw) { + t.Errorf("dumped data %s, want %s", verdictBytes, tt.verdict.raw) + } + }) + } +} + +func Test_check(t *testing.T) { + type args struct { + target string + timeout time.Duration + } + tests := []struct { + name string + args args + wantV verdict + }{ + {"google.com", args{"google.com", 10 * time.Second}, verdict{ + opened: true, + target: "google.com", + raw: []byte("200 OK\n"), + }}, + {"example.com", args{"example.com", 10 * time.Second}, verdict{ + opened: true, + target: "example.com", + raw: []byte("200 OK\n"), + }}, + {"non-exist", args{"some-non-exist-target-url.es", 10 * time.Second}, verdict{ + opened: false, + target: "", + raw: nil, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotV := check(tt.args.target, tt.args.timeout) + if gotV.target != tt.wantV.target { + t.Errorf("check(); verdict.target = %v, want %v", gotV.target, tt.wantV.target) + } + if gotV.opened != tt.wantV.opened { + t.Errorf("check(); verdict.opened = %v, want %v", gotV.opened, tt.wantV.opened) + } + if !bytes.HasPrefix(gotV.raw, tt.wantV.raw) { + t.Errorf("check(); verdict.raw begins with %s, want %s", gotV.raw, tt.wantV.raw) + } + }) + } +} diff --git a/cmd/check/config.go b/cmd/check/config.go index ab0d590..adbde8c 100644 --- a/cmd/check/config.go +++ b/cmd/check/config.go @@ -1,5 +1,7 @@ package main +// Running configuration. Parses running flags to the config struct. + import ( "encoding/json" "flag" @@ -9,16 +11,29 @@ import ( log "github.com/sirupsen/logrus" ) +// WorkersMax - hardcoded maximum amount of checkers const WorkersMax = 75 type ( + // Cfg stores the whole running configuration Cfg struct { - Workers int `json:"workers,omitempty"` - Sleeps time.Duration `json:"sleeps,omitempty"` - Max int `json:"limit,omitempty"` - Timeout time.Duration `json:"timeout,omitempty"` + // checking workers amount + Workers int `json:"workers,omitempty"` + + // pause between checks for the worker + Sleeps time.Duration `json:"sleeps,omitempty"` + + // maximum amount of targets to check (limits the target list) + Max int `json:"limit,omitempty"` + + // timeout for the response for each target + Timeout time.Duration `json:"timeout,omitempty"` + + // how often to store the buffered log LogInterval time.Duration `json:"log_interval,omitempty"` - Redis struct { + + // redis b connection parameters + Redis struct { Host string `json:"host,omitempty"` Port int `json:"port,omitempty"` Password string `json:"-,omitempty"` @@ -27,9 +42,12 @@ type ( TimeoutConn time.Duration `json:"timeout_conn,omitempty"` TimeoutRead time.Duration `json:"timeout_read,omitempty"` } `json:"redis"` - Out string `json:"out,omitempty"` + + // path of the directory to store verdicts VerdictsOutDir string `json:"verdicts_out_dir,omitempty"` - LogLevel string `json:"log_level,omitempty"` + + // logging level + LogLevel string `json:"log_level,omitempty"` } ) @@ -64,10 +82,10 @@ func initConfig() *Cfg { "redis set key for checks") flag.DurationVar(&cfg.Redis.TimeoutConn, "rtc", time.Second*15, - "radis connection timeout") + "redis connection timeout") flag.DurationVar(&cfg.Redis.TimeoutRead, "rtr", time.Second*15, - "radis read timeout") + "redis read timeout") flag.StringVar(&cfg.LogLevel, "log", "info", "log level [panic < fatal < error < warn < info < debug < trace]") @@ -87,9 +105,6 @@ func initConfig() *Cfg { flag.DurationVar(&cfg.LogInterval, "lt", time.Second*10, "log progress interval") - flag.StringVar(&cfg.Out, "o", "stdout", - "buffered log file path") - flag.StringVar(&cfg.VerdictsOutDir, "d", "/tmp/", "directory for per-record verdict logs") diff --git a/cmd/check/delayLog.go b/cmd/check/delayLog.go index 2e4f2b0..881c769 100644 --- a/cmd/check/delayLog.go +++ b/cmd/check/delayLog.go @@ -1,16 +1,20 @@ package main +// Buffered logs storage. Collects the log records and flushes them to the writer. + import ( "io" log "github.com/sirupsen/logrus" ) +// LogBuff stores logging records and periodically flushes them to the writer type LogBuff struct { records []string w io.Writer } +// add pushes a log message to the buffer records array func (l *LogBuff) add(msg string) { for _, r := range l.records { if r == msg { @@ -20,6 +24,7 @@ func (l *LogBuff) add(msg string) { l.records = append(l.records, msg) } +// flush writes all of the stored records to the writer func (l *LogBuff) flush() { for _, r := range l.records { if _, err := io.WriteString(l.w, r+"\n"); err != nil { diff --git a/cmd/check/delayLog_test.go b/cmd/check/delayLog_test.go new file mode 100644 index 0000000..8de0f4f --- /dev/null +++ b/cmd/check/delayLog_test.go @@ -0,0 +1,53 @@ +package main + +import ( + "bytes" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestLogBuff_add(t *testing.T) { + type args struct { + msg string + } + tests := []struct { + name string + args args + wantRecords []string + }{ + {"empty", args{""}, []string{""}}, + {"msg", args{"some message"}, []string{"some message"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lB := &LogBuff{} + lB.add(tt.args.msg) + if !cmp.Equal(lB.records, tt.wantRecords) { + t.Errorf("after LogBuff.add() records are %v, want %v", lB.records, tt.wantRecords) + } + }) + } +} + +func TestLogBuff_flush(t *testing.T) { + tests := []struct { + name string + content []string + }{ + {"empty", []string{""}}, + {"ok", []string{"some record", "another one record"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stringsBuf := bytes.NewBufferString("") + lB := &LogBuff{tt.content, stringsBuf} + lB.flush() + wantFlushed := strings.Join(tt.content, "\n") + "\n" + if stringsBuf.String() != wantFlushed { + t.Errorf("after LogBuff.flush() flushed data %s, want %s", stringsBuf.String(), wantFlushed) + } + }) + } +} diff --git a/cmd/check/init.go b/cmd/check/init.go index 8d38d8a..7b96c98 100644 --- a/cmd/check/init.go +++ b/cmd/check/init.go @@ -1,5 +1,7 @@ package main +// Initialize logging and interruptions. + import ( "fmt" "os" @@ -22,8 +24,8 @@ func (l *LogFormat) Format(entry *log.Entry) ([]byte, error) { } var ( - CFG *Cfg - lBuff LogBuff + CFG *Cfg // global config + lBuff LogBuff // buffered logging ) func InitInterrupt(tearDown func()) { @@ -52,13 +54,4 @@ func init() { log.SetFormatter(&LogFormat{}) log.SetLevel(log.DebugLevel) InitInterrupt(Stop) - CFG = initConfig() - - if CFG.LogLevel != "debug" { - logLevel, err := log.ParseLevel(CFG.LogLevel) - if err != nil { - log.Panicln("parsing LogLevel error", err) - } - log.SetLevel(logLevel) - } } diff --git a/cmd/check/main.go b/cmd/check/main.go index 03f9efb..ca1ab52 100644 --- a/cmd/check/main.go +++ b/cmd/check/main.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "path/filepath" "sync" "time" @@ -14,15 +15,10 @@ import ( var lBuffPath string -func main() { - var ( - err error - wg sync.WaitGroup - ) - - logWriter := log.StandardLogger().Writer() +func prepareLogBuffer() (logWriter *io.PipeWriter) { + logWriter = log.StandardLogger().Writer() if lBuffPath != "" && lBuffPath != "stdout" { - fd, err := os.Create(lBuffPath) + fd, err := os.Create(filepath.Clean(lBuffPath)) if err != nil { log.Panicln("can't create log buffer out file", err) } @@ -35,17 +31,11 @@ func main() { } else { lBuff.w = logWriter } + return +} - wg.Add(CFG.Workers) - - targets := make(chan string) - verdicts := make(chan verdict) - waitCh := make(chan struct{}) - - for i := CFG.Workers; i > 0; i-- { - go Checker(&wg, CFG.Timeout, targets, verdicts) - } - +func getTargets() []string { + // init redis db connection rdb := redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%s:%d", CFG.Redis.Host, CFG.Redis.Port), Password: CFG.Redis.Password, @@ -68,7 +58,40 @@ func main() { if CFG.Max > 0 && len(checkRecords) > CFG.Max { checkRecords = checkRecords[:CFG.Max] } + return checkRecords +} +func main() { + CFG = initConfig() + + if CFG.LogLevel != "debug" { + logLevel, err := log.ParseLevel(CFG.LogLevel) + if err != nil { + log.Panicln("parsing LogLevel error", err) + } + log.SetLevel(logLevel) + } + + var ( + err error + wg sync.WaitGroup + ) + logWriter := prepareLogBuffer() + + // init channels and start workers + wg.Add(CFG.Workers) + + targets := make(chan string) + verdicts := make(chan verdict) + waitCh := make(chan struct{}) + + for i := CFG.Workers; i > 0; i-- { + go Checker(&wg, CFG.Timeout, targets, verdicts) + } + + checkRecords := getTargets() + + // push records to targets channel go func() { for _, chR := range checkRecords { targets <- chR @@ -83,13 +106,18 @@ func main() { progressOpened := 0 tick := time.NewTicker(CFG.LogInterval) + // main loop for { select { + + // processing verdicts case v := <-verdicts: progress++ if !v.opened { continue } + + // push the verdict to the logging buffer if a target has been opened if lBuff.w != logWriter { _, err = io.WriteString(lBuff.w, v.target+"\n") if err != nil { @@ -99,12 +127,18 @@ func main() { lBuff.add(v.target) } progressOpened++ + + // dump verdict to its file for the farther analyzing v.dump(CFG.VerdictsOutDir) + + // all workers are done case <-waitCh: log.Infof("%d (! %d)", progress, progressOpened) lBuff.flush() log.Infof("-- done --") return + + // update the progress case <-tick.C: log.Infof("%d (! %d)", progress, progressOpened) } diff --git a/cmd/dns-sniffer/README.md b/cmd/dns-sniffer/README.md index 3b3f743..1d484b8 100644 --- a/cmd/dns-sniffer/README.md +++ b/cmd/dns-sniffer/README.md @@ -1,21 +1,27 @@ -## DNS sniffer - -### What it does +# DNS sniffer +[![Go Test cmd/dns](https://github.com/sir-go/rkn-rejects/actions/workflows/go-dns.yml/badge.svg)](https://github.com/sir-go/rkn-rejects/actions/workflows/go-dns.yml) +## What it does - read all DNS answer packets from `nf_qeueue` - check a hostname in the answer - add IPs from A-record to the allowed nftables set with TTL if all IP addresses in the answer and the hostname are not found in the denied lists -### Build +## Tests +```bash +go test -v ./cmd/dns-sniffer/... +gosec ./cmd/dns-sniffer/... +``` + +## Build ```bash go mod download go build -o dns ./cmd/dns-sniffer; ``` +## Run ### Flags - | key | default | description | |--------|---------------|--------------------------------------| | -nfq | 100-103 | nf queue num range | diff --git a/cmd/dns-sniffer/cache.go b/cmd/dns-sniffer/cache.go index 5f7261b..0a11e52 100644 --- a/cmd/dns-sniffer/cache.go +++ b/cmd/dns-sniffer/cache.go @@ -1,5 +1,8 @@ package main +// Redis-based DNS answers cache, +// stores answers (resolved IP strings arrays) in a ZLists where score is an record expiration time + import ( "context" "strconv" @@ -9,20 +12,24 @@ import ( log "github.com/sirupsen/logrus" ) +// AllowsCache stores a db connection pointer and a db key type AllowsCache struct { rdb *redis.Client key string } +// NewAllows creates a new cache keeper structure func NewAllows(rdb *redis.Client, key string) *AllowsCache { return &AllowsCache{rdb, key} } +// sanitize removes all zero-score records func (ac *AllowsCache) sanitize(score int64) error { return ac.rdb.ZRemRangeByScore(context.Background(), ac.key, "0", strconv.FormatInt(score, 10)).Err() } +// Add stores a record (IP addr strings with score - expiration time) func (ac *AllowsCache) Add(val string, score int64) error { z := redis.Z{Score: float64(score), Member: val} return ac.rdb.ZAddArgs(context.Background(), @@ -32,14 +39,17 @@ func (ac *AllowsCache) Add(val string, score int64) error { }).Err() } +// Has checks if the cache stores the given IP string func (ac *AllowsCache) Has(val string) bool { return ac.rdb.ZScore(context.Background(), ac.key, val).Val() != 0 } +// Del removes the IP address from the cache func (ac *AllowsCache) Del(val string) error { return ac.rdb.ZRem(context.Background(), ac.key, val).Err() } +// RunSanitizer gets time ticks and cleanup cache (remove all records with expired time) func (ac *AllowsCache) RunSanitizer(timeout time.Duration) { tick := time.NewTicker(timeout + time.Second) for { diff --git a/cmd/dns-sniffer/config.go b/cmd/dns-sniffer/config.go index 44c199c..74811ad 100644 --- a/cmd/dns-sniffer/config.go +++ b/cmd/dns-sniffer/config.go @@ -4,19 +4,26 @@ import ( "encoding/json" "flag" "os" - "strconv" - "strings" "time" log "github.com/sirupsen/logrus" + + "rkn-rejects/internal/tools" ) type ( Cfg struct { - Queues []uint16 `json:"queues,omitempty"` - QMaxLen uint32 `json:"q_max_len,omitempty"` - MarkDone int `json:"mark,omitempty"` - Redis struct { + // nf queue IDs + Queues []uint16 `json:"queues,omitempty"` + + // maximum queue capacity + QMaxLen uint32 `json:"q_max_len,omitempty"` + + // 'done' packet marker ID + MarkDone int `json:"mark,omitempty"` + + // redis db connection parameters + Redis struct { Host string `json:"host,omitempty"` Port int `json:"port,omitempty"` Password string `json:"-,omitempty"` @@ -25,15 +32,22 @@ type ( TimeoutConn time.Duration `json:"timeout_conn,omitempty"` TimeoutRead time.Duration `json:"timeout_read,omitempty"` } `json:"redis"` + + // logging level LogLevel string `json:"log_level,omitempty"` - Dry bool `json:"dry,omitempty"` - Nf struct { + + // just configure and exit + Dry bool `json:"dry,omitempty"` + + // netfilter configuration + Nf struct { Table string `json:"table,omitempty"` SetName string `json:"set_name,omitempty"` } `json:"nf,omitempty"` } ) +// Config stringer func (c *Cfg) String() string { var ( b []byte @@ -51,34 +65,6 @@ func (c *Cfg) String() string { return string(b) } -// 100-103 -> [100 101 102 103] -func parseRange(s string, a *[]uint16) { - var ( - v0, v1 uint64 - err error - ) - if !strings.ContainsRune(s, '-') { - v0, err = strconv.ParseUint(s, 10, 16) - if err != nil { - log.Panicln("range parsing error:", err.Error()) - } - *a = []uint16{uint16(v0)} - return - } - - p := strings.Split(s, "-") - if v0, err = strconv.ParseUint(p[0], 10, 16); err != nil { - log.Panicln("range begin parsing error:", err.Error()) - } - if v1, err = strconv.ParseUint(p[1], 10, 16); err != nil { - log.Panicln("range end parsing error:", err.Error()) - } - for v0 <= v1 { - *a = append(*a, uint16(v0)) - v0++ - } -} - func initConfig() *Cfg { var ( queuesStr string @@ -129,7 +115,9 @@ func initConfig() *Cfg { flag.Parse() - parseRange(queuesStr, &cfg.Queues) + if err := tools.ParseRange(queuesStr, &cfg.Queues); err != nil { + panic(err) + } cfg.QMaxLen = uint32(qMaxLen) log.Info(cfg) if cfg.Dry { diff --git a/cmd/dns-sniffer/init.go b/cmd/dns-sniffer/init.go index 3f787f6..f0104b9 100644 --- a/cmd/dns-sniffer/init.go +++ b/cmd/dns-sniffer/init.go @@ -1,5 +1,7 @@ package main +// Initialize logging and interruptions. + import ( "fmt" "net" @@ -23,9 +25,9 @@ func (l *LogFormat) Format(entry *log.Entry) ([]byte, error) { } var ( - CFG *Cfg - BogusSubnets []*net.IPNet - td []func() + CFG *Cfg // global config pointer + BogusSubnets []*net.IPNet // pointers to addresses of bogus networks + td []func() // teardown callbacks array ) func InitInterrupt() { @@ -49,16 +51,8 @@ func init() { log.SetFormatter(&LogFormat{}) log.SetLevel(log.DebugLevel) InitInterrupt() - CFG = initConfig() - - if CFG.LogLevel != "debug" { - logLevel, err := log.ParseLevel(CFG.LogLevel) - if err != nil { - log.Panic("parsing LogLevel error") - } - log.SetLevel(logLevel) - } + // fill the bogus subnets array for _, ipS := range []string{ "0.0.0.0/8", // 0.0.0.0 - 0.255.255.255 "10.0.0.0/8", // 10.0.0.0 - 10.255.255.255 diff --git a/cmd/dns-sniffer/main.go b/cmd/dns-sniffer/main.go index 3583966..914b1fe 100644 --- a/cmd/dns-sniffer/main.go +++ b/cmd/dns-sniffer/main.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" ) +// regQ register queue processing function func regQ(q uint16, rdb *redis.Client, ac *AllowsCache) { config := nfqueue.Config{ NfQueue: q, @@ -27,12 +28,14 @@ func regQ(q uint16, rdb *redis.Client, ac *AllowsCache) { log.Panicln("nfqueue opening", err) } + // add queue closing to teardowns td = append(td, func() { if e := nf.Close(); e != nil { log.Errorln("nfqueue closing", err) } }) + // create nf_queue processing function fn := func(a nfqueue.Attribute) int { packetsHook(a, rdb, ac) err = nf.SetVerdictWithMark(*a.PacketID, nfqueue.NfRepeat, CFG.MarkDone) @@ -43,6 +46,8 @@ func regQ(q uint16, rdb *redis.Client, ac *AllowsCache) { } return 0 } + + // register queue processing function log.Debugln("register hook func...") err = nf.RegisterWithErrorFunc( context.Background(), fn, func(e error) int { return 1 }) @@ -52,8 +57,19 @@ func regQ(q uint16, rdb *redis.Client, ac *AllowsCache) { } func main() { + CFG = initConfig() + + // set logging params + if CFG.LogLevel != "debug" { + logLevel, err := log.ParseLevel(CFG.LogLevel) + if err != nil { + log.Panic("parsing LogLevel error") + } + log.SetLevel(logLevel) + } defer log.Warn("-- done --") + // setup Redis connection log.Debugln("make redis client...") rdb := redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%s:%d", CFG.Redis.Host, CFG.Redis.Port), @@ -68,9 +84,13 @@ func main() { } }() + // create captured answers caching allows := NewAllows(rdb, "allows") + + // recheck records live times and cleanup go allows.RunSanitizer(time.Second * 5) + // run nf queue processors for _, qn := range CFG.Queues { go regQ(qn, rdb, allows) } diff --git a/cmd/dns-sniffer/traffic.go b/cmd/dns-sniffer/traffic.go index 245cbc7..2113e47 100644 --- a/cmd/dns-sniffer/traffic.go +++ b/cmd/dns-sniffer/traffic.go @@ -1,8 +1,6 @@ package main import ( - "context" - "net" "time" "github.com/florianl/go-nfqueue" @@ -12,47 +10,15 @@ import ( log "github.com/sirupsen/logrus" "rkn-rejects/internal/fw" + "rkn-rejects/internal/tools" ) -func ipInSubnets(ip net.IP, subnets []*net.IPNet) (bool, string) { - for _, sn := range subnets { - if sn.Contains(ip) { - return true, sn.String() - } - } - return false, "" -} - -func getUpperDomains(d string) (res []string) { - var dotIndexes []int - for idx, r := range d { - if r == '.' { - dotIndexes = append(dotIndexes, idx) - } - } - for i := len(dotIndexes) - 2; i > -1; i-- { - res = append(res, d[dotIndexes[i]+1:]) - } - return append(res, d) -} - -func isHostDenied(h string, rdb *redis.Client) bool { - ctx := context.Background() - for _, ud := range getUpperDomains(h) { - rRes := rdb.SIsMember(ctx, CFG.Redis.SetKey, ud) - if err := rRes.Err(); err != nil { - log.Errorln("redis sismember", CFG.Redis.SetKey, err) - return false - } - if rRes.Val() { - return true - } - } - return false -} - +// queue processing function +// packetsHook gets DNS answer packet from NF, parses it, checks if the hostname denied, +// then checks if resolved IP addresses in the allowed IP addresses cache, +// if the host is denied and IP in the allowed cache - remove IP fromm the cache, +// if the host isn't denied and the cache hasn't it's IP - add IP to the cache func packetsHook(a nfqueue.Attribute, rdb *redis.Client, ac *AllowsCache) { - //log.Debugln("packetID: ", *a.PacketID) p := gopacket.NewPacket(*a.Payload, layers.LayerTypeIPv4, gopacket.DecodeOptions{ Lazy: true, @@ -61,6 +27,7 @@ func packetsHook(a nfqueue.Attribute, rdb *redis.Client, ac *AllowsCache) { DecodeStreamsAsDatagrams: true, }) + // parse the packet as DNS answer l7 := p.ApplicationLayer() if l7 == nil { return @@ -76,14 +43,18 @@ func packetsHook(a nfqueue.Attribute, rdb *redis.Client, ac *AllowsCache) { } questedName := string(dns.Questions[0].Name) - isDenied := isHostDenied(questedName, rdb) + // checks if the hostname is denied + isDenied := tools.IsHostDenied(questedName, rdb, CFG.Redis.SetKey) + + // check if IP in bogus networks, add to allowed cache if host isn't denied, and remove if it is for _, answer := range dns.Answers { if answer.Type != layers.DNSTypeA { continue } - if bogus, bNet := ipInSubnets(answer.IP, BogusSubnets); bogus { + // if the IP in bogus networks - skip it + if bogus, bNet := tools.IpInSubnets(answer.IP, BogusSubnets); bogus { log.Debugf("DNS Q: %s A: bogus IP: %s [in %s]", questedName, answer.IP.String(), bNet) continue @@ -95,9 +66,13 @@ func packetsHook(a nfqueue.Attribute, rdb *redis.Client, ac *AllowsCache) { Comment: questedName, } + // check if cache already has the IP inCache := ac.Has(el.Ip) if isDenied { if inCache { + + // host is denied and the IP is in the allowed addresses cache - remove it + log.Infoln("bad : ", questedName) if err := fw.Del(CFG.Nf.Table, CFG.Nf.SetName, el); err != nil { log.Error(err) @@ -108,6 +83,9 @@ func packetsHook(a nfqueue.Attribute, rdb *redis.Client, ac *AllowsCache) { } } else { if !inCache { + + // host is not denied and the IP is not in the allowed addresses cache - add it + log.Infoln("good: ", questedName) if err := fw.Add(CFG.Nf.Table, CFG.Nf.SetName, el); err != nil { log.Error(err) diff --git a/cmd/dpi-sniffer/README.md b/cmd/dpi-sniffer/README.md index 5390e36..7952517 100644 --- a/cmd/dpi-sniffer/README.md +++ b/cmd/dpi-sniffer/README.md @@ -1,12 +1,19 @@ -## DPI sniffer +# DPI sniffer +[![Go Test cmd/dpi](https://github.com/sir-go/rkn-rejects/actions/workflows/go-dpi.yml/badge.svg)](https://github.com/sir-go/rkn-rejects/actions/workflows/go-dpi.yml) -### What it does +## What it does - read all packets from `nf_qeueue` - check TLS SNI extension, HTTP headers, and the payload of packet - if IP address is denied or found a denied hostname then the packet marks as "bad" end returns to the firewall (it will be rejected) +## Tests +```bash +go test -v ./cmd/dpi-sniffer/... +gosec ./cmd/dpi-sniffer/... +``` + ### Build ```bash go mod download diff --git a/cmd/dpi-sniffer/config.go b/cmd/dpi-sniffer/config.go index 1762af7..6c8bb71 100644 --- a/cmd/dpi-sniffer/config.go +++ b/cmd/dpi-sniffer/config.go @@ -4,11 +4,11 @@ import ( "encoding/json" "flag" "os" - "strconv" - "strings" "time" log "github.com/sirupsen/logrus" + + "rkn-rejects/internal/tools" ) type ( @@ -49,34 +49,6 @@ func (c *Cfg) String() string { return string(b) } -// 100-103 -> [100 101 102 103] -func parseRange(s string, a *[]uint16) { - var ( - v0, v1 uint64 - err error - ) - if !strings.ContainsRune(s, '-') { - v0, err = strconv.ParseUint(s, 10, 16) - if err != nil { - log.Panicln("range parsing error:", err.Error()) - } - *a = []uint16{uint16(v0)} - return - } - - p := strings.Split(s, "-") - if v0, err = strconv.ParseUint(p[0], 10, 16); err != nil { - log.Panicln("range begin parsing error:", err.Error()) - } - if v1, err = strconv.ParseUint(p[1], 10, 16); err != nil { - log.Panicln("range end parsing error:", err.Error()) - } - for v0 <= v1 { - *a = append(*a, uint16(v0)) - v0++ - } -} - func initConfig() *Cfg { var ( queuesStr string @@ -114,7 +86,9 @@ func initConfig() *Cfg { "just pretty print config") flag.Parse() - parseRange(queuesStr, &cfg.Queues) + if err := tools.ParseRange(queuesStr, &cfg.Queues); err != nil { + panic(err) + } cfg.QMaxLen = uint32(qMaxLen) log.Info(cfg) if cfg.Dry { diff --git a/cmd/dpi-sniffer/init.go b/cmd/dpi-sniffer/init.go index 9453843..0bfa1bc 100644 --- a/cmd/dpi-sniffer/init.go +++ b/cmd/dpi-sniffer/init.go @@ -22,8 +22,8 @@ func (l *LogFormat) Format(entry *log.Entry) ([]byte, error) { } var ( - CFG *Cfg - td []func() + CFG *Cfg // global config pointer + td []func() // teardown callbacks array ) func InitInterrupt() { @@ -47,13 +47,4 @@ func init() { log.SetFormatter(&LogFormat{}) log.SetLevel(log.DebugLevel) InitInterrupt() - CFG = initConfig() - - if CFG.LogLevel != "debug" { - logLevel, err := log.ParseLevel(CFG.LogLevel) - if err != nil { - log.Panic("parsing LogLevel error") - } - log.SetLevel(logLevel) - } } diff --git a/cmd/dpi-sniffer/main.go b/cmd/dpi-sniffer/main.go index 209f914..1cb9f91 100644 --- a/cmd/dpi-sniffer/main.go +++ b/cmd/dpi-sniffer/main.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" ) +// regQ register queue processing function func regQ(q uint16, rc *redis.Client) { config := nfqueue.Config{ NfQueue: q, @@ -26,12 +27,14 @@ func regQ(q uint16, rc *redis.Client) { log.Panicln("nfqueue opening", err) } + // add queue closing to teardowns td = append(td, func() { if e := nf.Close(); e != nil { log.Errorln("nfqueue closing", err) } }) + // create nf_queue processing function fn := func(a nfqueue.Attribute) int { m := CFG.MarkDone if packetsHook(a, rc) { @@ -44,6 +47,9 @@ func regQ(q uint16, rc *redis.Client) { } return 0 } + + // register queue processing function + log.Debugln("register hook func...") err = nf.RegisterWithErrorFunc( context.Background(), fn, func(e error) int { return 1 }) if err != nil { @@ -52,8 +58,21 @@ func regQ(q uint16, rc *redis.Client) { } func main() { + + // init running flags + + CFG = initConfig() + if CFG.LogLevel != "debug" { + logLevel, err := log.ParseLevel(CFG.LogLevel) + if err != nil { + log.Panic("parsing LogLevel error") + } + log.SetLevel(logLevel) + } defer log.Warn("-- done --") + // setup Redis connection + rdb := redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%s:%d", CFG.Redis.Host, CFG.Redis.Port), Password: CFG.Redis.Password, @@ -67,6 +86,8 @@ func main() { } }() + // run nf queue processors + for _, qn := range CFG.Queues { go regQ(qn, rdb) } diff --git a/cmd/dpi-sniffer/sni.go b/cmd/dpi-sniffer/sni.go index c056e6f..9c8d4f2 100644 --- a/cmd/dpi-sniffer/sni.go +++ b/cmd/dpi-sniffer/sni.go @@ -10,6 +10,7 @@ type TLSPayload struct { pos int } +// GetLenW returns the value of the len-word func (tp *TLSPayload) GetLenW() (int, error) { if tp.len < tp.pos+2 { return 0, fmt.Errorf("too small payload len (%d)", tp.len) @@ -20,6 +21,7 @@ func (tp *TLSPayload) GetLenW() (int, error) { return (b1 << 8) + b2, nil } +// GetLenB returns the value of the len-byte func (tp *TLSPayload) GetLenB() (int, error) { if tp.len < tp.pos+1 { return 0, fmt.Errorf("too small payload len (%d)", tp.len) @@ -28,6 +30,7 @@ func (tp *TLSPayload) GetLenB() (int, error) { return int(tp.raw[tp.pos-1]), nil } +// Skip moves the pos pointer to n bytes forward func (tp *TLSPayload) Skip(n int) error { if tp.len < tp.pos+n { return fmt.Errorf("too small payload len (%d)", tp.len) @@ -36,6 +39,7 @@ func (tp *TLSPayload) Skip(n int) error { return nil } +// GetString returns the data from current pos to n of payload as a string func (tp *TLSPayload) GetString(n int) (res string, err error) { if tp.len < tp.pos+n { return "", fmt.Errorf("too small payload len (%d)", tp.len) @@ -45,6 +49,8 @@ func (tp *TLSPayload) GetString(n int) (res string, err error) { return } +// GetSNIForced parses the packet payload as TLS handshake packet +// and grabs the SNI server_name values if it presents func GetSNIForced(d []byte) (sni string, err error) { // SessionIdLength offset = 43 pl := TLSPayload{len: len(d), raw: d, pos: 43} diff --git a/cmd/dpi-sniffer/sni_test.go b/cmd/dpi-sniffer/sni_test.go new file mode 100644 index 0000000..0456b1c --- /dev/null +++ b/cmd/dpi-sniffer/sni_test.go @@ -0,0 +1,150 @@ +package main + +import ( + "io/ioutil" + "testing" +) + +func TestTLSPayload_GetLenW(t *testing.T) { + tests := []struct { + name string + pl TLSPayload + want int + wantErr bool + }{ + {"empty", TLSPayload{}, 0, true}, + {"tooSmall", TLSPayload{10, []byte{}, 16}, 0, true}, + {"ok", TLSPayload{6, []byte{10, 12, 2, 1, 0, 3}, 2}, 513, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.pl.GetLenW() + if (err != nil) != tt.wantErr { + t.Errorf("GetLenW() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetLenW() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTLSPayload_GetLenB(t *testing.T) { + tests := []struct { + name string + pl TLSPayload + want int + wantErr bool + }{ + {"empty", TLSPayload{}, 0, true}, + {"tooSmall", TLSPayload{10, []byte{}, 16}, 0, true}, + {"ok", TLSPayload{6, []byte{10, 12, 2, 1, 0, 3}, 2}, 2, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.pl.GetLenB() + if (err != nil) != tt.wantErr { + t.Errorf("GetLenB() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetLenB() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTLSPayload_Skip(t *testing.T) { + type args struct { + pl *TLSPayload + n int + } + tests := []struct { + name string + args args + wantErr bool + wantPos int + }{ + {"empty", args{&TLSPayload{}, 0}, false, 0}, + {"tooSmall", args{&TLSPayload{}, 5}, true, 0}, + {"ok", args{&TLSPayload{20, nil, 3}, 5}, false, 8}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.args.pl.Skip(tt.args.n); (err != nil) != tt.wantErr { + t.Errorf("Skip() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.args.pl.pos != tt.wantPos { + t.Errorf("after Skip() pos = %v, want %v", tt.args.pl.pos, tt.wantPos) + } + }) + } +} + +func TestTLSPayload_GetString(t *testing.T) { + type args struct { + pl *TLSPayload + n int + } + tests := []struct { + name string + args args + wantRes string + wantErr bool + }{ + {"empty", args{&TLSPayload{}, 0}, "", false}, + {"tooSmall", args{&TLSPayload{}, 12}, "", true}, + {"ok", args{&TLSPayload{20, []byte("some string content"), 4}, 15}, + " string content", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotRes, err := tt.args.pl.GetString(tt.args.n) + if (err != nil) != tt.wantErr { + t.Errorf("GetString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotRes != tt.wantRes { + t.Errorf("GetString() gotRes = %v, want %v", gotRes, tt.wantRes) + } + }) + } +} + +func TestGetSNIForced(t *testing.T) { + sniDumpVk, err := ioutil.ReadFile("../../testdata/sni_dump_vk.pcap") + if err != nil { + panic(err) + } + + sniDumpTtnet, err := ioutil.ReadFile("../../testdata/sni_dump_ttnet.pcap") + if err != nil { + panic(err) + } + type args struct { + d []byte + } + tests := []struct { + name string + args args + wantSni string + wantErr bool + }{ + {"empty", args{nil}, "", true}, + {"vk", args{sniDumpVk}, "vk.com", false}, + {"ttnet", args{sniDumpTtnet}, "ttnet.ru", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotSni, err := GetSNIForced(tt.args.d) + if (err != nil) != tt.wantErr { + t.Errorf("GetSNIForced() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotSni != tt.wantSni { + t.Errorf("GetSNIForced() gotSni = %v, want %v", gotSni, tt.wantSni) + } + }) + } +} diff --git a/cmd/dpi-sniffer/traffic.go b/cmd/dpi-sniffer/traffic.go index d1c12ec..f8eea4b 100644 --- a/cmd/dpi-sniffer/traffic.go +++ b/cmd/dpi-sniffer/traffic.go @@ -1,7 +1,6 @@ package main import ( - "context" "regexp" "strings" @@ -10,43 +9,24 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" log "github.com/sirupsen/logrus" + + "rkn-rejects/internal/tools" ) //goland:noinspection SpellCheckingInspection var ( + // ReLooksLikeDomain matches strings tooks like domain name ReLooksLikeDomain = regexp.MustCompile(`([a-zA-Z]+[a-zA-Z0-9._-]*\.)+([a-zA-Z]+[a-zA-Z0-9._-]*)`) - ReHostname = regexp.MustCompile(`(?i)host:\s([^:/?#\s]+).*\s`) - ReHTTP = regexp.MustCompile(`^(GET|POST|PUT|PATCH|DELETE|TRACE|CONNECT|HEAD|OPTIONS)\s`) -) -func getUpperDomains(d string) (res []string) { - var dotIndexes []int - for idx, r := range d { - if r == '.' { - dotIndexes = append(dotIndexes, idx) - } - } - for i := len(dotIndexes) - 2; i > -1; i-- { - res = append(res, d[dotIndexes[i]+1:]) - } - return append(res, d) -} + // ReHostname matches strings tooks like host name + ReHostname = regexp.MustCompile(`(?i)host:\s([^:/?#\s]+).*\s`) -func isHostDenied(h string, rdb *redis.Client) bool { - ctx := context.Background() - for _, ud := range getUpperDomains(h) { - rRes := rdb.SIsMember(ctx, CFG.Redis.SetKey, ud) - if err := rRes.Err(); err != nil { - log.Errorln("redis sismember", CFG.Redis.SetKey, err) - return false - } - if rRes.Val() { - return true - } - } - return false -} + // ReHTTP matches strings tooks like HTTP header + ReHTTP = regexp.MustCompile(`^(GET|POST|PUT|PATCH|DELETE|TRACE|CONNECT|HEAD|OPTIONS)\s`) +) +// snatchRegexp gets a host or domain name from the given slice +// pos argument sets a position in the matched results func snatchRegexp(b []byte, re *regexp.Regexp, pos int) (s string) { if reResult := re.FindSubmatch(b); reResult != nil && len(reResult) > 1 { return strings.TrimRightFunc(string(reResult[pos]), func(r rune) bool { return r == '.' }) @@ -54,17 +34,22 @@ func snatchRegexp(b []byte, re *regexp.Regexp, pos int) (s string) { return "" } +// packetsHook - nf queue packet processor func packetsHook(a nfqueue.Attribute, rdb *redis.Client) bool { + + // nf verdicts mapping const ( accept = false reject = true ) + // if -x (rejects all) flag is set - reject the packet without any checks if CFG.RejectAll { log.Warn("R-ALL: any =>X any") return reject } + // parse the packet as IPv4 p := gopacket.NewPacket(*a.Payload, layers.LayerTypeIPv4, gopacket.DecodeOptions{ Lazy: true, NoCopy: true, @@ -90,7 +75,7 @@ func packetsHook(a nfqueue.Attribute, rdb *redis.Client) bool { //log.Debugln("check TLS") var hostname string if hostname, _ = GetSNIForced(p.TransportLayer().LayerPayload()); len(hostname) > 1 { - if !isHostDenied(hostname, rdb) { + if !tools.IsHostDenied(hostname, rdb, CFG.Redis.SetKey) { log.Debugf("TLS : %-15s ==> %s", srcIP, hostname) return accept } @@ -101,7 +86,7 @@ func packetsHook(a nfqueue.Attribute, rdb *redis.Client) bool { // check HTTP hostname //log.Debugln("check hostname") if hostname = snatchRegexp(*a.Payload, ReHostname, 1); hostname != "" { - if !isHostDenied(hostname, rdb) { + if !tools.IsHostDenied(hostname, rdb, CFG.Redis.SetKey) { log.Debugf("HTTP: %-15s ==> %s", srcIP, hostname) return accept } @@ -119,7 +104,7 @@ func packetsHook(a nfqueue.Attribute, rdb *redis.Client) bool { // DPI seek anything looks like domain name //log.Debugln("check DPI") if hostname = snatchRegexp(*a.Payload, ReLooksLikeDomain, 0); hostname != "" { - if !isHostDenied(hostname, rdb) { + if !tools.IsHostDenied(hostname, rdb, CFG.Redis.SetKey) { log.Debugf("DPI : %-15s ==> %s", srcIP, hostname) return accept } diff --git a/cmd/dpi-sniffer/traffic_test.go b/cmd/dpi-sniffer/traffic_test.go new file mode 100644 index 0000000..0a81cfa --- /dev/null +++ b/cmd/dpi-sniffer/traffic_test.go @@ -0,0 +1,44 @@ +package main + +import ( + "regexp" + "testing" +) + +func Test_snatchRegexp(t *testing.T) { + type args struct { + b []byte + pos int + } + tests := []struct { + name string + args args + want map[*regexp.Regexp]string + }{ + {"empty", args{[]byte(""), 0}, map[*regexp.Regexp]string{ + ReHTTP: "", + ReHostname: "", + ReLooksLikeDomain: "", + }}, + {"domain", args{[]byte("host: example.com "), 0}, map[*regexp.Regexp]string{ + ReHTTP: "", + ReHostname: "host: example.com ", + ReLooksLikeDomain: "example.com", + }}, + {"host", args{[]byte("host: example.com "), 1}, map[*regexp.Regexp]string{ + ReHTTP: "", + ReHostname: "example.com", + ReLooksLikeDomain: "example", + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for rxp, wantString := range tt.want { + if got := snatchRegexp(tt.args.b, rxp, tt.args.pos); got != wantString { + t.Errorf("snatchRegexp(%s, %v, %v) = %v, want %v", + tt.args.b, rxp, tt.args.pos, got, wantString) + } + } + }) + } +} diff --git a/cmd/get_rkn/0_version.go b/cmd/get_rkn/0_version.go index d7aeda0..0db02c4 100644 --- a/cmd/get_rkn/0_version.go +++ b/cmd/get_rkn/0_version.go @@ -1,5 +1,7 @@ package main +// API and a documentation versions checking + import ( "encoding/xml" "net/http" @@ -15,6 +17,7 @@ type ( time.Time } + // SResVersion - `getLastDumpDateEx` response structure SResVersion struct { LastDumpDate *TimestampMs `xml:"lastDumpDate"` LastDumpDateUrgently *TimestampMs `xml:"lastDumpDateUrgently"` @@ -24,6 +27,7 @@ type ( } ) +// UnmarshalXML parses an XML element to a time.Time contained struct func (p *TimestampMs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { var v int64 if err := d.DecodeElement(&v, &start); err != nil { @@ -33,20 +37,23 @@ func (p *TimestampMs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error return nil } +// GetDumpVersion fetches an XML with API version info and returns structured data func GetDumpVersion() *SResVersion { log.Println("get actual RKN service versions...") - httpClient := &http.Client{Timeout: CFG.Web.TcpTimeout.Duration} + httpClient := &http.Client{Timeout: CFG.Web.TcpTimeout} defer func() { httpClient.CloseIdleConnections() }() + // create soap client soap, err := gosoap.SoapClient(CFG.Web.SoapUrl, httpClient) if err != nil { log.Panicln("can't make soap client", err) } + // make request & get response soapResp := new(gosoap.Response) err = retry( CFG.Web.Attempts, - time.Second*CFG.Web.TcpTimeout.Duration, + time.Second*CFG.Web.TcpTimeout, func() error { soapResp, err = soap.Call("getLastDumpDateEx", nil) return err @@ -55,6 +62,7 @@ func GetDumpVersion() *SResVersion { log.Panicln("can't call soap method getLastDumpDateEx", err) } + // decode the answer v := new(SResVersion) if err = soapResp.Unmarshal(v); err != nil { log.Panicln("can't unmarshal soap response", err) @@ -63,6 +71,8 @@ func GetDumpVersion() *SResVersion { return v } +// CheckVersions compares fetched versions and given in the config. +// If versions dont match - throw an error func CheckVersions() { log.Info("check versions") version := GetDumpVersion() diff --git a/cmd/get_rkn/1_request.go b/cmd/get_rkn/1_request.go index f6f42b8..539317c 100644 --- a/cmd/get_rkn/1_request.go +++ b/cmd/get_rkn/1_request.go @@ -1,5 +1,7 @@ package main +// Soap request building and send + import ( "encoding/base64" "fmt" @@ -16,6 +18,7 @@ import ( ) type ( + // SResReq - `sendRequest` response structure SResReq struct { Result bool `xml:"result"` Comment string `xml:"resultComment"` @@ -23,8 +26,9 @@ type ( } ) +// genRequest builds an XML request and stores it to the file func genRequest() { - log.Info("gen rewuest") + log.Info("gen request") var err error req := fmt.Sprintf( ` @@ -46,10 +50,10 @@ func genRequest() { dumpTo(&CFG.Req.File, req, "dump request XML to:") } +// sign calls a request file signing external process func sign() { log.Info("sign request") - //goland:noinspection SpellCheckingInspection - cmd := exec.Command(CFG.Sign.Script) + cmd := exec.Command(CFG.Sign.Script) //#nosec cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr log.Debug(cmd) @@ -58,9 +62,10 @@ func sign() { } } +// sendRequest makes a request to the API and gets a scheduled task ID func sendRequest() (taskId string) { log.Info("send request") - httpClient := &http.Client{Timeout: CFG.Web.TcpTimeout.Duration} + httpClient := &http.Client{Timeout: CFG.Web.TcpTimeout} defer func() { httpClient.CloseIdleConnections() }() soap, err := gosoap.SoapClient(CFG.Web.SoapUrl, httpClient) @@ -81,7 +86,7 @@ func sendRequest() (taskId string) { soapResp := new(gosoap.Response) err = retry( CFG.Web.Attempts, - time.Second*CFG.Web.TcpTimeout.Duration, + time.Second*CFG.Web.TcpTimeout, func() error { soapResp, err = soap.Call("sendRequest", gosoap.Params{ "requestFile": base64.StdEncoding.EncodeToString(req), diff --git a/cmd/get_rkn/2_result.go b/cmd/get_rkn/2_result.go index 7a744e4..e5cd9d5 100644 --- a/cmd/get_rkn/2_result.go +++ b/cmd/get_rkn/2_result.go @@ -1,5 +1,7 @@ package main +// Fetching response as a zip-file bytes + import ( "encoding/base64" "errors" @@ -11,6 +13,7 @@ import ( ) type ( + // SResResult - `getResult` response structure SResResult struct { Result bool `xml:"result"` Code int `xml:"resultCode"` @@ -22,9 +25,11 @@ type ( } ) +// getResult continuously tries to fetch the result of the given task ID +// stores the result to a file (if path is presented in the config) and returns a zip dump bytes func getResult(code string) (zipDump []byte) { log.Info("get result") - httpClient := &http.Client{Timeout: CFG.Res.GetTimeout.Duration} + httpClient := &http.Client{Timeout: CFG.Res.GetTimeout} defer func() { httpClient.CloseIdleConnections() }() soap, err := gosoap.SoapClient(CFG.Web.SoapUrl, httpClient) @@ -36,7 +41,7 @@ func getResult(code string) (zipDump []byte) { soapResp := new(gosoap.Response) err = retry( CFG.Res.Attempts, - CFG.Res.RetryTimeout.Duration, + CFG.Res.RetryTimeout, func() error { soapResp, err = soap.Call("getResult", gosoap.Params{"code": code}) diff --git a/cmd/get_rkn/3_parse.go b/cmd/get_rkn/3_parse.go index a3c0e35..5e3ec85 100644 --- a/cmd/get_rkn/3_parse.go +++ b/cmd/get_rkn/3_parse.go @@ -1,5 +1,7 @@ package main +// Unpack a zip dump and parses contained XML files + import ( "bufio" "context" @@ -21,6 +23,7 @@ var ( reHost *regexp.Regexp ) +// domainIsOk checks if a domain itself is in the white list or of any of it's children is in the whitelist func domainIsOk(d string) bool { if b, s := isItUpDomainOf(d, CFG.WListDomains); b { if s == d { @@ -35,43 +38,65 @@ func domainIsOk(d string) bool { return true } +// processRegElement parses an XML element of dump and toss it to the certain db func processRegElement( elChan chan *xmlparser.XMLElement, rdb *redis.Client, wg *sync.WaitGroup) { defer wg.Done() ctx := context.Background() + + // create new buffer of requests for IPs, domains and check URLs buff := NewRBuff("ip_", "domains_", "check_") for el := range elChan { hsh := el.Attrs["hash"] switch el.Attrs["blockType"] { + // el has blockType=ip attr case "ip": + + // parse all of the ipSubnet child elements for _, ch := range el.Childs["ipSubnet"] { for _, h := range CIDRHosts(ch.InnerText) { + // make the hash and toss all of IPs of the parsed subnet to the checks buff.add("check_", fmt.Sprintf("%s|%s", hsh, h)) + // toss all IPs of the parsed subnet to the IP list buff.add("ip_", h) } } + + // parse all of the ip child elements for _, ch := range el.Childs["ip"] { for _, h := range CIDRHosts(ch.InnerText) { + // make the hash and toss all of IPs of the parsed subnet to the checks buff.add("check_", fmt.Sprintf("%s|%s", hsh, h)) + // toss all IPs of the parsed subnet to the IP list buff.add("ip_", h) } } + + // el has blockType=domain or blockType=domain-mask attr case "domain": fallthrough case "domain-mask": + // parse all of the domain child elements for _, ch := range el.Childs["domain"] { + // clean the domain names domain := sanitizeDomain(ch.InnerText) + // check if the domain is allowed if domainIsOk(domain) { + // make the hash and toss the domain to the checks buff.add("check_", fmt.Sprintf("%s|%s", hsh, domain)) + //toss the domain to the domains list buff.add("domains_", domain) } } + // el has blockType=url or anything else attr default: + // parse all of the url child elements for _, ch := range el.Childs["url"] { + // get the hostname from the element content subStrings := reHost.FindStringSubmatch(ch.InnerText) if len(subStrings) < 2 { log.Debugf("[%s] can't get host from url `%s`", @@ -79,7 +104,11 @@ func processRegElement( continue } + // got the hostname host := subStrings[1] + + // if it an IP address - parse it as a subnet + // and toss all of parsed addresses to the check and ip lists if ip := net.ParseIP(host); ip != nil { for _, h := range CIDRHosts(host) { buff.add("check_", fmt.Sprintf("%s|%s", @@ -89,6 +118,7 @@ func processRegElement( continue } + // if it's not an IP address then clean the hostname and toss it to the checks and domains lists if domain := sanitizeDomain(host); domainIsOk(domain) { buff.add("check_", fmt.Sprintf("%s|%s", hsh, ch.InnerText)) @@ -97,19 +127,27 @@ func processRegElement( } } + // if request buffer is full - send it to the redis if buff.count > CFG.Parse.Redis.ChunkSize { buff.send(rdb, ctx) } } + + // send the rest of requests to the redis if buff.count > 0 { buff.send(rdb, ctx) } } +// Parse reads data from an unpacked dump with chunks sized by `size`, decodes it as an XML, +//and parses all of the elements with the progress updating func Parse(r io.ReadCloser, size uint64) { log.Info("start parsing") + + // regexp for get a hostname from the url reHost = regexp.MustCompile(`^[a-zA-Z\d]+://([^/\n:\\]+)`) + // create a new redis connection rdb := redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%s:%d", CFG.Parse.Redis.Host, @@ -117,9 +155,9 @@ func Parse(r io.ReadCloser, size uint64) { Password: CFG.Parse.Redis.Password, DB: CFG.Parse.Redis.Db, MaxRetries: 99, - DialTimeout: CFG.Parse.Redis.TimeoutConn.Duration, - ReadTimeout: CFG.Parse.Redis.TimeoutRead.Duration, - WriteTimeout: CFG.Parse.Redis.TimeoutRead.Duration, + DialTimeout: CFG.Parse.Redis.TimeoutConn, + ReadTimeout: CFG.Parse.Redis.TimeoutRead, + WriteTimeout: CFG.Parse.Redis.TimeoutRead, }) defer func() { if err := rdb.Close(); err != nil { @@ -127,19 +165,26 @@ func Parse(r io.ReadCloser, size uint64) { } }() + // dumps saved in the Windows-1251 encoding dec := charmap.Windows1251.NewDecoder() + // prepare the XML parser parser := xmlparser.NewXMLParser( bufio.NewReader(dec.Reader(r)), "content") elChan := parser.Stream() + // prepare a waiting group for goroutines var wg sync.WaitGroup waitCh := make(chan struct{}) + // set capacity of the waiting group by the workers amount wg.Add(CFG.Parse.Redis.Workers) - tick := time.NewTicker(CFG.Parse.ProgressPollTimeout.Duration) + // create a ticker for the progressbar updating + tick := time.NewTicker(CFG.Parse.ProgressPollTimeout) + + // run by-element parsing workers go func() { for i := CFG.Parse.Redis.Workers; i > 0; i-- { go processRegElement(elChan, rdb, &wg) @@ -152,14 +197,20 @@ func Parse(r io.ReadCloser, size uint64) { progress uint64 err error ) + + // a loop for goroutines processing for { select { + + // all workers are done case <-waitCh: if progress < 100 { log.Info("100%") } ctx := context.Background() + + // rename the temporary list on the redis (remove a _ prefix) rKeys := rdb.Keys(ctx, "*_") if err = rKeys.Err(); err != nil { log.Panicln("redis get keys", err) @@ -171,6 +222,8 @@ func Parse(r io.ReadCloser, size uint64) { } } return + + // update the progress by the ticker case <-tick.C: progress = parser.TotalReadSize / (size / 100.0) if progress > 100 { diff --git a/cmd/get_rkn/4_add_blacklists.go b/cmd/get_rkn/4_add_blacklists.go index a96df65..c4058db 100644 --- a/cmd/get_rkn/4_add_blacklists.go +++ b/cmd/get_rkn/4_add_blacklists.go @@ -1,5 +1,7 @@ package main +// Add predefined in the config blacklist records to the domains list on the redis + import ( "context" "fmt" @@ -10,6 +12,8 @@ import ( func InjectBlacklist() { log.Info("inject domains from blacklist") + + // create a new redis connection rdb := redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%s:%d", CFG.Parse.Redis.Host, @@ -19,7 +23,7 @@ func InjectBlacklist() { }) defer func() { if err := rdb.Close(); err != nil { - log.Panicln("clode redis conn", err) + log.Panicln("close redis conn", err) } }() @@ -27,6 +31,7 @@ func InjectBlacklist() { rb := NewRBuff("domains") defer func() { rb.send(rdb, ctx) }() + // toss the domains from the config to the redis list for _, d := range CFG.BListDomains { if domain := sanitizeDomain(d); domainIsOk(domain) { rb.add("domains", domain) diff --git a/cmd/get_rkn/5_ip_to_nftables.go b/cmd/get_rkn/5_ip_to_nftables.go index e6939c3..82f9266 100644 --- a/cmd/get_rkn/5_ip_to_nftables.go +++ b/cmd/get_rkn/5_ip_to_nftables.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "path/filepath" "github.com/go-redis/redis/v8" log "github.com/sirupsen/logrus" @@ -12,13 +13,17 @@ import ( "rkn-rejects/internal/fw" ) +// fPuts formats a string with the format `f` and puts it to the writer `fd` func fPuts(fd io.Writer, f string, s ...interface{}) { if _, err := fmt.Fprintf(fd, f, s...); err != nil { log.Panicln("can't write to file ", err) } } +// dumpIpsToFile gets IP addresses from the redis, generates a netfilter script, and stores it to the file func dumpIpsToFile(fileName, tableName, setName string) { + + // create new redis connection rdb := redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%s:%d", CFG.Parse.Redis.Host, @@ -32,13 +37,15 @@ func dumpIpsToFile(fileName, tableName, setName string) { } }() + // fetch IP addresses ctx := context.Background() res, err := rdb.SMembers(ctx, "ip").Result() if err != nil { log.Panicln("redis get ip members", err) } - fd, err := os.Create(fileName) + // open a script file + fd, err := os.Create(filepath.Clean(fileName)) if err != nil { log.Panicln("create nft file", err) } @@ -48,6 +55,7 @@ func dumpIpsToFile(fileName, tableName, setName string) { } }() + // generate and write a script to the file fPuts(fd, "#!/usr/sbin/nft -f \n\n") fPuts(fd, "# this file generated by the get-rkn tool \n\n") fPuts(fd, "flush set %s %s \n\n", tableName, setName) @@ -58,6 +66,7 @@ func dumpIpsToFile(fileName, tableName, setName string) { fPuts(fd, "}") } +// ip2nft generates a nftables script with denied IP addresses, flushes the rules table, and applies the script func ip2nft() { log.Info("gen deny by ip nftables rules") dumpIpsToFile(CFG.Fw.IpDenyFile, CFG.Fw.IpDenyTable, CFG.Fw.IpDenySet) diff --git a/cmd/get_rkn/README.md b/cmd/get_rkn/README.md index f52650a..025d688 100644 --- a/cmd/get_rkn/README.md +++ b/cmd/get_rkn/README.md @@ -1,12 +1,13 @@ -## get_rkn - +# get_rkn +[![Go Test cmd/get-rkn](https://github.com/sir-go/rkn-rejects/actions/workflows/go-get-rkn.yml/badge.svg)](https://github.com/sir-go/rkn-rejects/actions/workflows/go-get-rkn.yml) + [CryptoPRO](https://www.cryptopro.ru/products/csp/downloads) for sign requests is required. SSH tuning `.ssh/config` ``` KexAlgorithms +diffie-hellman-group-exchange-sha1,diffie-hellman-group14-sha1 ``` -### What it does +## What it does - check versions of the service, dump and documentation - generate XML request, sign it and send to the service - get the request UUID and wait the result @@ -15,17 +16,7 @@ KexAlgorithms +diffie-hellman-group-exchange-sha1,diffie-hellman-group14-sha1 - fill the redis DB - generate and run script for configuring the firewall -### Build -```bash -go mod download -go build -o get_rkn rkn-rejects/cmd/get_rkn; -``` - -### Flags - -`-c ` - path to `*.toml` config file - -### Config +## Config ```toml log_level = "debug" @@ -99,7 +90,14 @@ log_level = "debug" ip_deny_set = "deny_rkn" # nftables denied addresses set name ``` -### Run +## Build +```bash +go mod download +go build -o get_rkn rkn-rejects/cmd/get_rkn; +``` +## Run ```bash ./get_rkn -c rkn.toml ``` +## Flags +`-c ` - path to `*.toml` config file (default is `rkn.toml`) diff --git a/cmd/get_rkn/config.go b/cmd/get_rkn/config.go index 35a1a51..06ad041 100644 --- a/cmd/get_rkn/config.go +++ b/cmd/get_rkn/config.go @@ -12,8 +12,8 @@ import ( ) const ( - ConfFileEnvVar = "RKN_CONF" - DefaultConfFile = "rkn.toml" + ConfFileEnvVar = "RKN_CONF" // env variable stores a path to the config file + DefaultConfFile = "rkn.toml" // default path to the config file ) type ( @@ -21,9 +21,7 @@ type ( *net.IPNet } - Duration struct { - time.Duration - } + // Cfg is the global config Cfg struct { ActualVersions struct { @@ -32,10 +30,10 @@ type ( Doc string `toml:"doc"` } `toml:"actual_versions"` Web struct { - SoapUrl string `toml:"soap_url"` - DocUrl string `toml:"doc_url"` - TcpTimeout *Duration `toml:"tcp_timeout"` - Attempts int `toml:"attempts"` + SoapUrl string `toml:"soap_url"` + DocUrl string `toml:"doc_url"` + TcpTimeout time.Duration `toml:"tcp_timeout"` + Attempts int `toml:"attempts"` } `toml:"web"` Req struct { File string `toml:"file"` @@ -51,27 +49,27 @@ type ( Script string `toml:"script"` } `toml:"sign"` Res struct { - DumpTo *string `toml:"dump_to"` - Attempts int `toml:"attempts"` - GetTimeout *Duration `toml:"download_timeout"` - RetryTimeout *Duration `toml:"retry_timeout"` + DumpTo *string `toml:"dump_to"` + Attempts int `toml:"attempts"` + GetTimeout time.Duration `toml:"download_timeout"` + RetryTimeout time.Duration `toml:"retry_timeout"` } `toml:"res"` Parse struct { - FromDump *string `toml:"from_dump"` - ProgressPollTimeout *Duration `toml:"progress_poll_timeout"` + FromDump *string `toml:"from_dump"` + ProgressPollTimeout time.Duration `toml:"progress_poll_timeout"` BogusIp struct { Subnets []Net `toml:"subnets"` MinMask int `toml:"min_mask"` } `toml:"bogus_ip"` Redis struct { - Host string `toml:"host"` - Port int `toml:"port"` - Password string `toml:"password"` - Db int `toml:"db"` - ChunkSize int `toml:"chunk_size"` - Workers int `toml:"workers"` - TimeoutConn *Duration `toml:"timeout_conn,omitempty"` - TimeoutRead *Duration `toml:"timeout_read,omitempty"` + Host string `toml:"host"` + Port int `toml:"port"` + Password string `toml:"password"` + Db int `toml:"db"` + ChunkSize int `toml:"chunk_size"` + Workers int `toml:"workers"` + TimeoutConn time.Duration `toml:"timeout_conn,omitempty"` + TimeoutRead time.Duration `toml:"timeout_read,omitempty"` } `toml:"redis"` } `toml:"parse"` Lists struct { @@ -89,12 +87,6 @@ type ( } ) -func (d *Duration) UnmarshalText(text []byte) error { - var err error - d.Duration, err = time.ParseDuration(string(text)) - return err -} - func (n *Net) UnmarshalText(text []byte) error { return n.parse(string(text)) } diff --git a/cmd/get_rkn/config_test.go b/cmd/get_rkn/config_test.go index f90f5d3..df30c1e 100644 --- a/cmd/get_rkn/config_test.go +++ b/cmd/get_rkn/config_test.go @@ -1,48 +1,37 @@ package main import ( + "net" + "reflect" "testing" ) -func Test_sanitizeConfLine(t *testing.T) { +func TestNet_parse(t *testing.T) { type args struct { - s string + text string } tests := []struct { - name string - args args - wantS_ string + name string + args args + wantErr bool + want *net.IPNet }{ - {"", args{s: ""}, ""}, - {"", args{s: " "}, ""}, - {"", args{s: " "}, ""}, - {"", args{s: " #"}, ""}, - {"", args{s: " # "}, ""}, - {"", args{s: "# "}, ""}, - {"", args{s: "#"}, ""}, - {"", args{s: "abc"}, "abc"}, - {"", args{s: " abc"}, "abc"}, - {"", args{s: "abc "}, "abc"}, - {"", args{s: " abc "}, "abc"}, - {"", args{s: "ab cd"}, "ab cd"}, - {"", args{s: " ab cd"}, "ab cd"}, - {"", args{s: " ab cd "}, "ab cd"}, - {"", args{s: "# ab cd "}, ""}, - {"", args{s: " ab # cd "}, "ab"}, - {"", args{s: " ab cd # "}, "ab cd"}, - {"", args{s: " ab cd # asd "}, "ab cd"}, - {"", args{s: " # asd "}, ""}, - {"", args{s: "abc#asd"}, "abc"}, - {"", args{s: " abc#asd"}, "abc"}, - {"", args{s: " abc #asd"}, "abc"}, - {"", args{s: "abc #asd"}, "abc"}, - {"", args{s: "abc# asd"}, "abc"}, - {"", args{s: "abc# asd # adf"}, "abc"}, + {"empty", args{""}, true, nil}, + {"ok-Zero", args{"0.0.0.0/0"}, false, + &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}}, + {"ok-Net", args{"192.168.5.0/30"}, false, + &net.IPNet{IP: net.IP{192, 168, 5, 0}, Mask: net.IPMask{255, 255, 255, 252}}}, + {"ok-Host", args{"10.10.6.36/32"}, false, + &net.IPNet{IP: net.IP{10, 10, 6, 36}, Mask: net.IPMask{255, 255, 255, 255}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if gotS_ := sanitizeConfLine(tt.args.s); gotS_ != tt.wantS_ { - t.Errorf("sanitizeConfLine() = %v, want %v", gotS_, tt.wantS_) + ipN := &Net{} + if err := ipN.parse(tt.args.text); (err != nil) != tt.wantErr { + t.Errorf("parse() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(ipN.IPNet.String(), tt.want.String()) { + t.Errorf("after parse() ipNet = %v, want %v", ipN.IPNet, tt.want) } }) } diff --git a/cmd/get_rkn/domain_math.go b/cmd/get_rkn/domain_math.go index 8fddee3..73ae588 100644 --- a/cmd/get_rkn/domain_math.go +++ b/cmd/get_rkn/domain_math.go @@ -1,5 +1,7 @@ package main +// Some tools for domain names working with + import ( "strings" @@ -7,6 +9,10 @@ import ( "golang.org/x/net/idna" ) +// isASCII checks if the whole string is ASCII encoded +// example -> true +// пример -> false +// some-пример -> false func isASCII(s string) bool { for _, c := range s { if c > 127 { @@ -16,6 +22,9 @@ func isASCII(s string) bool { return true } +// forcePUnicode encodes a domain as punycode if it has certain signature +// сайт.рф -> xn--80aswg.xn--p1ai +// example.com -> example.com func forcePUnicode(d string) (s string) { var err error if !isASCII(d) { @@ -29,6 +38,9 @@ func forcePUnicode(d string) (s string) { return } +// forcePDecode decodes a domain as punycode if it has certain signature +// xn--80aswg.xn--p1ai -> сайт.рф +// example.com -> example.com func forcePDecode(d string) (s string) { if !strings.HasPrefix(d, "xn--") { return d @@ -41,6 +53,10 @@ func forcePDecode(d string) (s string) { return } +// sanitizeDomain clears a domain name +//www.domain.com -> domain.com +//*.sub.domain.net -> sub.domain.net +//domain.com/?params=444 -> domain.com func sanitizeDomain(d string) (s string) { s = forcePUnicode(d) s = strings.TrimPrefix(s, "www.") @@ -52,6 +68,8 @@ func sanitizeDomain(d string) (s string) { return } +// isItUpDomainOf checks if any of the given subdomains is a child of the given domain. +// Returns a boolean (is the domain is the parent) and subdomain string func isItUpDomainOf(domain string, subDomains []string) (bool, string) { for _, sd := range subDomains { if strings.HasSuffix(sd, "."+domain) || sd == domain { diff --git a/cmd/get_rkn/domain_math_test.go b/cmd/get_rkn/domain_math_test.go new file mode 100644 index 0000000..5b4de65 --- /dev/null +++ b/cmd/get_rkn/domain_math_test.go @@ -0,0 +1,136 @@ +package main + +import ( + "testing" +) + +func Test_isItUpDomainOf(t *testing.T) { + type args struct { + domain string + subDomains []string + } + tests := []struct { + name string + args args + wantResult bool + wantSubdomain string + }{ + {"empty", args{"", nil}, + false, ""}, + {"self", args{"example.com", []string{"example.com"}}, + true, "example.com"}, + {"child", args{"example.com", []string{"subdomain.example.com"}}, + true, "subdomain.example.com"}, + {"3rd", args{"subdomain.example.com", []string{"sub.subdomain.example.com"}}, + true, "sub.subdomain.example.com"}, + {"not", args{"subdomain.example.com", []string{"example.com"}}, + false, ""}, + {"many", args{"example.com", []string{ + "sub.another.com", + "ech1.example.com", + "bx.example.com", + }}, + true, "ech1.example.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotResult, gotSubdomain := isItUpDomainOf(tt.args.domain, tt.args.subDomains) + if gotResult != tt.wantResult || gotSubdomain != tt.wantSubdomain { + t.Errorf("isItUpDomainOf() got = (%v, %v), want (%v, %v)", + gotResult, gotSubdomain, tt.wantResult, tt.wantSubdomain) + } + }) + } +} + +func Test_sanitizeDomain(t *testing.T) { + type args struct { + d string + } + tests := []struct { + name string + args args + wantS string + }{ + {"empty", args{""}, ""}, + {"www", args{"www.ya.ru"}, "ya.ru"}, + {"wildcard", args{"*.maps.ya.ru"}, "maps.ya.ru"}, + {"params", args{"www.maps.ya.ru/page?params=4&l=3"}, "maps.ya.ru"}, + {"anchors", args{"*.maps.ya.ru/#:page"}, "maps.ya.ru"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotS := sanitizeDomain(tt.args.d); gotS != tt.wantS { + t.Errorf("sanitizeDomain() = %v, want %v", gotS, tt.wantS) + } + }) + } +} + +func Test_forcePDecode(t *testing.T) { + type args struct { + d string + } + //goland:noinspection SpellCheckingInspection + tests := []struct { + name string + args args + wantS string + }{ + {"empty", args{""}, ""}, + {"national", args{"xn--e1afmkfd.xn--p1ai"}, "пример.рф"}, + {"eng", args{"example.com"}, "example.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotS := forcePDecode(tt.args.d); gotS != tt.wantS { + t.Errorf("forcePDecode() = %v, want %v", gotS, tt.wantS) + } + }) + } +} + +func Test_forcePUnicode(t *testing.T) { + type args struct { + d string + } + //goland:noinspection SpellCheckingInspection + tests := []struct { + name string + args args + wantS string + }{ + {"empty", args{""}, ""}, + {"national", args{"пример.рф"}, "xn--e1afmkfd.xn--p1ai"}, + {"eng", args{"example.com"}, "example.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotS := forcePUnicode(tt.args.d); gotS != tt.wantS { + t.Errorf("forcePUnicode() = %v, want %v", gotS, tt.wantS) + } + }) + } +} + +func Test_isASCII(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want bool + }{ + {"empty", args{""}, true}, + {"yes", args{"latin string"}, true}, + {"no", args{"latin строка"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isASCII(tt.args.s); got != tt.want { + t.Errorf("isASCII() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/get_rkn/dump_to.go b/cmd/get_rkn/dump_to.go index 394941e..6f9d3df 100644 --- a/cmd/get_rkn/dump_to.go +++ b/cmd/get_rkn/dump_to.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" ) +// dumpTo saves the given data to a file if the path is presented func dumpTo(path *string, data interface{}, decr string) { if path == nil || *path == "" { return diff --git a/cmd/get_rkn/dump_to_test.go b/cmd/get_rkn/dump_to_test.go new file mode 100644 index 0000000..3f9bc18 --- /dev/null +++ b/cmd/get_rkn/dump_to_test.go @@ -0,0 +1,39 @@ +package main + +import ( + "bytes" + "io/ioutil" + "path/filepath" + "testing" +) + +func Test_dumpTo(t *testing.T) { + type args struct { + path string + data interface{} + decr string + } + tests := []struct { + name string + args args + wantContent []byte + }{ + {"empty", args{"", nil, ""}, nil}, + {"ok", args{"/tmp/tmp-file-name", "some content", "description"}, + []byte("some content")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dumpTo(&tt.args.path, tt.args.data, tt.args.decr) + if tt.args.path != "" { + content, err := ioutil.ReadFile(filepath.Clean(tt.args.path)) + if err != nil { + t.Errorf("read tmp file error %v", err) + } + if !bytes.Equal(content, tt.wantContent) { + t.Errorf("dumpTo() wrote an unexpected contant %v, want %v", content, tt.wantContent) + } + } + }) + } +} diff --git a/cmd/get_rkn/includes.go b/cmd/get_rkn/includes.go index c2cd206..9df747e 100644 --- a/cmd/get_rkn/includes.go +++ b/cmd/get_rkn/includes.go @@ -3,11 +3,13 @@ package main import ( "bufio" "os" + "path/filepath" "strings" log "github.com/sirupsen/logrus" ) +// sanitizeConfLine removes from the given string spaces and comments func sanitizeConfLine(s string) string { if s == "" { return "" @@ -23,8 +25,9 @@ func sanitizeConfLine(s string) string { } } +// includesReadDomains reads a file with domains, recode them and pushes to a string slice func includesReadDomains(path string) (d []string) { - fd, err := os.Open(path) + fd, err := os.Open(filepath.Clean(path)) if err != nil { log.Panicln("can't open domains file", err) } diff --git a/cmd/get_rkn/includes_test.go b/cmd/get_rkn/includes_test.go new file mode 100644 index 0000000..0e224cc --- /dev/null +++ b/cmd/get_rkn/includes_test.go @@ -0,0 +1,45 @@ +package main + +import ( + "testing" +) + +func Test_sanitizeConfLine(t *testing.T) { + tests := []struct { + arg string + want string + }{ + {"", ""}, + {" ", ""}, + {" ", ""}, + {" #", ""}, + {" # ", ""}, + {"# ", ""}, + {"#", ""}, + {"abc", "abc"}, + {" abc", "abc"}, + {"abc ", "abc"}, + {" abc ", "abc"}, + {"ab cd", "ab cd"}, + {" ab cd", "ab cd"}, + {" ab cd ", "ab cd"}, + {"# ab cd ", ""}, + {" ab # cd ", "ab"}, + {" ab cd # ", "ab cd"}, + {" ab cd # asd ", "ab cd"}, + {" # asd ", ""}, + {"abc#asd", "abc"}, + {" abc#asd", "abc"}, + {" abc #asd", "abc"}, + {"abc #asd", "abc"}, + {"abc# asd", "abc"}, + {"abc# asd # adf", "abc"}, + } + for _, tt := range tests { + t.Run("", func(t *testing.T) { + if got := sanitizeConfLine(tt.arg); got != tt.want { + t.Errorf("sanitizeConfLine() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/get_rkn/ip_math.go b/cmd/get_rkn/ip_math.go index cdaa41d..7060e42 100644 --- a/cmd/get_rkn/ip_math.go +++ b/cmd/get_rkn/ip_math.go @@ -8,6 +8,8 @@ import ( log "github.com/sirupsen/logrus" ) +// ipInSubnets checks if the IP address is in the subnets slice and returns a boolean (found or not) +//and the subnet as a string where this IP address has been found func ipInSubnets(ip net.IP, subnets []Net) (bool, string) { for _, sn := range subnets { if sn.Contains(ip) { @@ -17,14 +19,17 @@ func ipInSubnets(ip net.IP, subnets []Net) (bool, string) { return false, "" } +// ipInBogusSubnet check if the defined in the configuration bogus networks contain the given IP address func ipInBogusSubnet(ip net.IP) bool { if b, sn := ipInSubnets(ip, CFG.Parse.BogusIp.Subnets); b { - log.Warnln("%s in bogus subnet %s -> skip", ip, sn) + log.Warnf("%s in bogus subnet %s -> skip", ip, sn) return b } return false } +// CIDRHosts composites the given CIDR to a slice of hosts (IP strings) +// "10.0.0.12/30 -> 10.0.0.12/32, 10.0.0.13/32, 10.0.0.14/32, 10.0.0.15/32, func CIDRHosts(cidr string) (hosts []string) { // it's IP address if !strings.ContainsRune(cidr, '/') { diff --git a/cmd/get_rkn/ip_math_test.go b/cmd/get_rkn/ip_math_test.go new file mode 100644 index 0000000..729dc4c --- /dev/null +++ b/cmd/get_rkn/ip_math_test.go @@ -0,0 +1,64 @@ +package main + +import ( + "net" + "testing" +) + +func Test_ipInSubnets(t *testing.T) { + type args struct { + ip net.IP + subnets []Net + } + tests := []struct { + name string + args args + wantFound bool + wantSubnet string + }{ + { + name: "empty", + args: args{ + ip: nil, + subnets: nil, + }, + wantFound: false, + wantSubnet: "", + }, + { + name: "found", + args: args{ + ip: net.IP{10, 10, 0, 15}, + subnets: []Net{ + {IPNet: &net.IPNet{IP: net.IP{10, 10, 10, 0}, Mask: net.IPMask{255, 255, 255, 0}}}, + {IPNet: &net.IPNet{IP: net.IP{10, 10, 0, 0}, Mask: net.IPMask{255, 255, 255, 0}}}, + {IPNet: &net.IPNet{IP: net.IP{192, 168, 201, 0}, Mask: net.IPMask{255, 255, 255, 252}}}, + }, + }, + wantFound: true, + wantSubnet: "10.10.0.0/24", + }, + { + name: "not-found", + args: args{ + ip: net.IP{10, 10, 0, 15}, + subnets: []Net{ + {IPNet: &net.IPNet{IP: net.IP{10, 10, 10, 0}, Mask: net.IPMask{255, 255, 255, 0}}}, + {IPNet: &net.IPNet{IP: net.IP{172, 16, 0, 0}, Mask: net.IPMask{255, 255, 0, 0}}}, + {IPNet: &net.IPNet{IP: net.IP{192, 168, 201, 0}, Mask: net.IPMask{255, 255, 255, 252}}}, + }, + }, + wantFound: false, + wantSubnet: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotFound, gotSubnet := ipInSubnets(tt.args.ip, tt.args.subnets) + if gotFound != tt.wantFound || gotSubnet != tt.wantSubnet { + t.Errorf("ipInSubnets() got = (%v, %v), want (%v, %v)", + gotFound, gotSubnet, tt.wantFound, tt.wantSubnet) + } + }) + } +} diff --git a/cmd/get_rkn/main.go b/cmd/get_rkn/main.go index 8bd03b8..0661823 100644 --- a/cmd/get_rkn/main.go +++ b/cmd/get_rkn/main.go @@ -22,12 +22,10 @@ func (l *LogFormat) Format(entry *log.Entry) ([]byte, error) { ), nil } -func init() { +func initLogging() { log.SetReportCaller(true) log.SetFormatter(&LogFormat{}) log.SetLevel(log.DebugLevel) - CFG = initConfig() - if CFG.LogLevel != "debug" { logLevel, err := log.ParseLevel(CFG.LogLevel) if err != nil { @@ -38,6 +36,8 @@ func init() { } func main() { + CFG = initConfig() + initLogging() defer log.Println("--done--") var dumpData []byte diff --git a/cmd/get_rkn/read_zip.go b/cmd/get_rkn/read_zip.go index 4f33da8..2acd793 100644 --- a/cmd/get_rkn/read_zip.go +++ b/cmd/get_rkn/read_zip.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" ) +// ReadZip opens a zipDump bytes as a zip-file and returns a reader for the content func ReadZip(zipDump []byte) (r io.ReadCloser, size uint64) { zipReader, err := zip.NewReader( bytes.NewReader(zipDump), int64(len(zipDump))) diff --git a/cmd/get_rkn/redis_buffer.go b/cmd/get_rkn/redis_buffer.go index 53b9460..c5a2605 100644 --- a/cmd/get_rkn/redis_buffer.go +++ b/cmd/get_rkn/redis_buffer.go @@ -7,13 +7,19 @@ import ( log "github.com/sirupsen/logrus" ) +// A buffer for requests to the Redis server to reduce the load of the parsing processes + type ( + // RBuff contains a mapped requests (list name as a key and a slice of requests as a value) + // and a total count of stored requests. + // When a buffer is full it's flushes all of stored requests to the Redis and resets itself RBuff struct { data map[string][]interface{} count int } ) +// NewRBuff creates a buffer mapped with given lists func NewRBuff(lists ...string) *RBuff { rb := &RBuff{data: make(map[string][]interface{})} for _, ln := range lists { @@ -22,11 +28,13 @@ func NewRBuff(lists ...string) *RBuff { return rb } +// add a `val` request to the `list` func (r *RBuff) add(list string, val interface{}) { r.data[list] = append(r.data[list], val) r.count++ } +// send flushes all stored requests to the redis via `rdb` connection and recreates func (r *RBuff) send(rdb *redis.Client, ctx context.Context) { keys := make([]string, 0) for key, values := range r.data { diff --git a/cmd/get_rkn/retry.go b/cmd/get_rkn/retry.go index 13ea374..2d344f2 100644 --- a/cmd/get_rkn/retry.go +++ b/cmd/get_rkn/retry.go @@ -7,6 +7,8 @@ import ( log "github.com/sirupsen/logrus" ) +// retry tries to call the given function with `sleep` interval while the function returns a non-nil error, +// maximum `attempts` times func retry(attempts int, sleep time.Duration, f func() error) (err error) { for i := attempts - 1; i > 0; i-- { if err = f(); err == nil { diff --git a/cmd/get_rkn/sign.sh b/cmd/get_rkn/sign.sh index e9ac308..1663e1d 100644 --- a/cmd/get_rkn/sign.sh +++ b/cmd/get_rkn/sign.sh @@ -1,3 +1,6 @@ #!/bin/bash + +# this is a requests signing command line (using a CryptoPRO CLI tool) + /opt/cprocsp/bin/amd64/csptest -sfsign -sign -in /tmp/req.xml \ -out /tmp/req.xml.signed -my "${COMPANY_NAME}" -detached -add diff --git a/internal/tools/domains.go b/internal/tools/domains.go new file mode 100644 index 0000000..4462f89 --- /dev/null +++ b/internal/tools/domains.go @@ -0,0 +1,15 @@ +package tools + +// GetUpperDomains returns a list of all upper domains of the given hostname +func GetUpperDomains(d string) (res []string) { + var dotIndexes []int + for idx, r := range d { + if r == '.' { + dotIndexes = append(dotIndexes, idx) + } + } + for i := len(dotIndexes) - 2; i > -1; i-- { + res = append(res, d[dotIndexes[i]+1:]) + } + return append(res, d) +} diff --git a/internal/tools/domains_test.go b/internal/tools/domains_test.go new file mode 100644 index 0000000..7a8ca5c --- /dev/null +++ b/internal/tools/domains_test.go @@ -0,0 +1,30 @@ +package tools + +import ( + "reflect" + "testing" +) + +func Test_getUpperDomains(t *testing.T) { + type args struct { + d string + } + tests := []struct { + name string + args args + wantRes []string + }{ + {"empty", args{""}, []string{""}}, + {"1lvl", args{"com"}, []string{"com"}}, + {"2lvl", args{"uk.com"}, []string{"uk.com"}}, + {"3lvl", args{"gov.uk.com"}, []string{"uk.com", "gov.uk.com"}}, + {"4lvl", args{"main.gov.uk.com"}, []string{"uk.com", "gov.uk.com", "main.gov.uk.com"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotRes := GetUpperDomains(tt.args.d); !reflect.DeepEqual(gotRes, tt.wantRes) { + t.Errorf("getUpperDomains() = %v, want %v", gotRes, tt.wantRes) + } + }) + } +} diff --git a/internal/tools/hostCheck.go b/internal/tools/hostCheck.go new file mode 100644 index 0000000..2c0ccf5 --- /dev/null +++ b/internal/tools/hostCheck.go @@ -0,0 +1,24 @@ +package tools + +import ( + "context" + + "github.com/go-redis/redis/v8" + log "github.com/sirupsen/logrus" +) + +// IsHostDenied checks the hostname and all of the upper domains if any of them is in the denied list +func IsHostDenied(h string, rdb *redis.Client, hostsKey string) bool { + ctx := context.Background() + for _, ud := range GetUpperDomains(h) { + rRes := rdb.SIsMember(ctx, hostsKey, ud) + if err := rRes.Err(); err != nil { + log.Errorln("redis sismember", hostsKey, err) + return false + } + if rRes.Val() { + return true + } + } + return false +} diff --git a/internal/tools/ranges.go b/internal/tools/ranges.go new file mode 100644 index 0000000..1334fc9 --- /dev/null +++ b/internal/tools/ranges.go @@ -0,0 +1,36 @@ +package tools + +import ( + "strconv" + "strings" +) + +// ParseRange unwraps numbers range contained string to the given slice +// 100-103 -> [100 101 102 103] +func ParseRange(s string, a *[]uint16) error { + var ( + v0, v1 uint64 + err error + ) + if !strings.ContainsRune(s, '-') { + v0, err = strconv.ParseUint(s, 10, 16) + if err != nil { + return err + } + *a = []uint16{uint16(v0)} + return nil + } + + p := strings.Split(s, "-") + if v0, err = strconv.ParseUint(p[0], 10, 16); err != nil { + return err + } + if v1, err = strconv.ParseUint(p[1], 10, 16); err != nil { + return err + } + for v0 <= v1 { + *a = append(*a, uint16(v0)) + v0++ + } + return nil +} diff --git a/internal/tools/ranges_test.go b/internal/tools/ranges_test.go new file mode 100644 index 0000000..693b2fd --- /dev/null +++ b/internal/tools/ranges_test.go @@ -0,0 +1,35 @@ +package tools + +import ( + "reflect" + "testing" +) + +func Test_parseRange(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + wantSlice []uint16 + wantErr bool + }{ + {"single", args{"34"}, []uint16{34}, false}, + {"singleErr", args{"3f4"}, []uint16{}, true}, + {"range", args{"34-36"}, []uint16{34, 35, 36}, false}, + {"rangeErr", args{"34-3-6"}, []uint16{}, false}, + } + for _, tt := range tests { + a := []uint16{} + t.Run(tt.name, func(t *testing.T) { + err := ParseRange(tt.args.s, &a) + if (err != nil) != tt.wantErr { + t.Errorf("ParseRange() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(a, tt.wantSlice) { + t.Errorf("after parseRange() given slice contains %v, want %v ", a, tt.wantSlice) + } + }) + } +} diff --git a/internal/tools/subnets.go b/internal/tools/subnets.go new file mode 100644 index 0000000..7805424 --- /dev/null +++ b/internal/tools/subnets.go @@ -0,0 +1,16 @@ +package tools + +import ( + "net" +) + +// IpInSubnets check if the IP address is found in the subnets +// return the boolean (found IP or not) and the subnet that contains the given IP as a string +func IpInSubnets(ip net.IP, subnets []*net.IPNet) (bool, string) { + for _, sn := range subnets { + if sn.Contains(ip) { + return true, sn.String() + } + } + return false, "" +} diff --git a/internal/tools/subnets_test.go b/internal/tools/subnets_test.go new file mode 100644 index 0000000..dff488e --- /dev/null +++ b/internal/tools/subnets_test.go @@ -0,0 +1,42 @@ +package tools + +import ( + "net" + "testing" +) + +func Test_ipInSubnets(t *testing.T) { + type args struct { + ip net.IP + subnets []*net.IPNet + } + tests := []struct { + name string + args args + wantContains bool + wantSubnet string + }{ + {"empty", args{net.IP{}, []*net.IPNet{}}, false, ""}, + {"yes", args{net.IP{10, 10, 0, 15}, []*net.IPNet{ + {net.IP{10, 10, 1, 0}, net.IPMask{255, 255, 255, 0}}, + {net.IP{10, 10, 0, 0}, net.IPMask{255, 255, 255, 0}}, + {net.IP{192, 168, 22, 0}, net.IPMask{255, 255, 0, 0}}, + }}, true, "10.10.0.0/24"}, + {"no", args{net.IP{10, 10, 0, 15}, []*net.IPNet{ + {net.IP{10, 10, 1, 0}, net.IPMask{255, 255, 255, 0}}, + {net.IP{10, 10, 2, 0}, net.IPMask{255, 255, 255, 0}}, + {net.IP{192, 168, 22, 0}, net.IPMask{255, 255, 0, 0}}, + }}, false, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotContains, gotSubnet := IpInSubnets(tt.args.ip, tt.args.subnets) + if gotContains != tt.wantContains { + t.Errorf("ipInSubnets() gotContains = %v, want %v", gotContains, tt.wantContains) + } + if gotSubnet != tt.wantSubnet { + t.Errorf("ipInSubnets() gotSubnet = %v, want %v", gotSubnet, tt.wantSubnet) + } + }) + } +} diff --git a/testdata/sni_dump_ttnet.pcap b/testdata/sni_dump_ttnet.pcap new file mode 100644 index 0000000..507c896 Binary files /dev/null and b/testdata/sni_dump_ttnet.pcap differ diff --git a/testdata/sni_dump_vk.pcap b/testdata/sni_dump_vk.pcap new file mode 100644 index 0000000..3fc3300 Binary files /dev/null and b/testdata/sni_dump_vk.pcap differ