I am trying to traverse a Vault server concurrently. I am able to search recursively, but I am having trouble parallelizing it.
Below is a sample of the code I came up with. I am looking for a way to make it faster using concurrency. Is there a way to traverse the Vault paths concurrently?
package main
import (
"flag"
"fmt"
"github.com/hashicorp/vault/api"
"log"
"net/http"
"strings"
"sync"
"time"
)
var vault_path string
var vault_addr string
var dev_vault_addr = ""
var pie_vault_addr = ""
var prod_vault_addr = ""
var dev_vault_path string = ""
var pie_vault_path string = ""
var prod_vault_path string = ""
//flags
var env = flag.String("ENV", "", "vault environment - dev,pie,prod")
var okta_user = flag.String("USER", "", "okta username")
var okta_pw = flag.String("PW", "", "okta pw")
var searchValue = flag.String("VALUE", "", "value to search for")
var searchKey = flag.String("KEY", "", "key to search for")
func main() {
flag.Parse()
switch *env {
case "dev":
fmt.Println("dev vault ")
vault_path = dev_vault_path
vault_addr = dev_vault_addr
case "pie":
fmt.Println("pie")
vault_path = pie_vault_path
vault_addr = pie_vault_addr
case "prod":
fmt.Println("prod")
vault_path = prod_vault_path
vault_addr = prod_vault_addr
}
workerCount := 1
jobs := make(chan workerJob, workerCount)
results := make(chan workerResult)
readDone := make(chan bool)
wg := &sync.WaitGroup{}
// start N workers
for i := 0; i < workerCount; i++ {
go worker(jobs, results, wg)
}
// One initial job
wg.Add(1)
go func() {
jobs <- workerJob{
Path: vault_path,
}
}()
// When all jobs finished, shutdown the system.
go func() {
wg.Wait()
readDone <- true
}()
readloop:
for {
select {
case res := <-results:
log.Printf(`result=%#v`, res.secret)
case <-readDone:
log.Printf(`got stop`)
close(jobs)
break readloop
}
}
}
func setupClient(vault_addr string) *api.Client {
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
client, err := api.NewClient(&api.Config{Address: vault_addr, HttpClient: httpClient})
if err != nil {
log.Fatal(err) // cannot continue with a nil client
}
// to pass the password
options := map[string]interface{}{
"password": okta_pw,
}
// the login path
// this is configurable, change userpass to ldap etc
path := fmt.Sprintf("auth/okta/login/%s", *okta_user)
// PUT call to get a token
secret, err := client.Logical().Write(path, options)
if err != nil {
log.Fatal(err) // login failed; secret would be nil below
}
client.SetToken(secret.Auth.ClientToken)
return client
}
func walkDir(client *api.Client, path string) {
if path == "" {
path = vault_path
}
value, err := client.Logical().List(path)
if err != nil || value == nil {
log.Println(err)
return
}
data := value.Data["keys"].([]interface{})
for _, item := range data {
itemString := item.(string)
if strings.HasSuffix(itemString, "/") {
walkDir(client, path+itemString)
} else {
//its a secret
data := read(client, path+itemString)
if *searchKey != "" && searchForKey(data, *searchKey) {
fmt.Println(path + itemString)
}
if *searchValue != "" && searchForValue(data, *searchValue) {
fmt.Println(path + itemString)
}
}
}
}
func read(client *api.Client, path string) map[string]interface{} {
value, err := client.Logical().Read(path)
if err != nil {
log.Println(err)
return nil // searchForKey/searchForValue handle a nil map safely
}
return value.Data
}
func searchForValue(mapp map[string]interface{}, searchValue string) bool {
for _, value := range mapp {
if searchValue == value {
return true
}
}
return false
}
func searchForKey(mapp map[string]interface{}, searchKey string) bool {
for key := range mapp {
if searchKey == key {
return true
}
}
return false
}
// Job for worker
type workerJob struct {
Address string
Path string
}
// Result of a worker
type workerResult struct {
secret map[string]interface{}
}
func worker(jobs chan workerJob, results chan<- workerResult, wg *sync.WaitGroup) {
// Authenticate once per worker instead of once per job.
client := setupClient(vault_addr)
for j := range jobs {
log.Printf(`Vault Path: %#v`, j.Path)
if j.Path == "" {
j.Path = vault_path
}
value, err := client.Logical().List(j.Path)
if err != nil || value == nil {
log.Println(err)
wg.Done()
continue
}
data := value.Data["keys"].([]interface{})
for _, item := range data {
itemString := item.(string)
if strings.HasSuffix(itemString, "/") {
nj := workerJob{Path: j.Path + itemString} // keep the parent prefix or the child job lists the wrong path
log.Printf(`sent new vault dir job: %#v`,nj.Path)
//one more job add to wg
wg.Add(1)
// Do not block when sending jobs
go func() {
jobs <- nj
}()
} else {
//its a secret
data := read(client, j.Path+itemString)
if *searchKey != "" && searchForKey(data, *searchKey) {
log.Print(j.Path + itemString)
results <- workerResult{secret: data}
}
if *searchValue != "" && searchForValue(data, *searchValue) {
log.Print(j.Path + itemString)
results <- workerResult{secret: data}
}
}
}
// Done one job, let wg know.
wg.Done()
}
}
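For reference, one way to keep the recursive shape while adding bounded concurrency (a sketch, not part of the original post: concurrentWalk and the limit of 10 are invented, and it reuses the read/search helpers above) is to do wg.Add before spawning each child and gate the Vault calls with a semaphore channel:

func concurrentWalk(client *api.Client, root string) {
	var wg sync.WaitGroup
	sem := make(chan struct{}, 10) // at most 10 in-flight List calls

	var visit func(path string)
	visit = func(path string) {
		defer wg.Done()
		sem <- struct{}{} // acquire a slot
		value, err := client.Logical().List(path)
		<-sem // release it before recursing so children can run
		if err != nil || value == nil {
			return
		}
		keys, ok := value.Data["keys"].([]interface{})
		if !ok {
			return
		}
		for _, item := range keys {
			itemString := item.(string)
			if strings.HasSuffix(itemString, "/") {
				wg.Add(1) // count the child before it starts so Wait cannot return early
				go visit(path + itemString)
			} else if data := read(client, path+itemString); data != nil {
				if *searchKey != "" && searchForKey(data, *searchKey) {
					fmt.Println(path + itemString)
				}
				if *searchValue != "" && searchForValue(data, *searchValue) {
					fmt.Println(path + itemString)
				}
			}
		}
	}

	wg.Add(1)
	visit(root)
	wg.Wait() // returns only after every spawned visit has finished
}

Because every wg.Add happens before the corresponding goroutine is launched, wg.Wait() is a reliable termination signal, which sidesteps the readDone/close(jobs) bookkeeping above.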
Related
Why is this program so slow? I thought the code was fairly optimized, but it takes significantly longer than the find command when run on my root filesystem.
It takes about 4 minutes, as opposed to about 40 seconds for find.
I tried removing the sorting algorithm, but that doesn't speed up the program.
package main
import (
"fmt"
"io"
"io/fs"
"log"
"os"
"sort"
"sync"
"github.com/google/fscrypt/filesystem"
"github.com/sirupsen/logrus"
"gopkg.in/alecthomas/kingpin.v2"
)
var (
mountpoint = kingpin.Flag("mount", "The mount to find the largest file usages. Can be a subpath of the mount").Required().String()
limit = kingpin.Flag("limit", "The maximum number of files returned to the display").Default("10").Short('l').Int()
)
var device string
type fileDisplay struct {
Size int64
Path string
}
type bySize []fileDisplay
func (a bySize) Len() int { return len(a) }
func (a bySize) Less(i, j int) bool { return a[i].Size < a[j].Size }
func (a bySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
var fileChan = make(chan fileDisplay)
var files []fileDisplay
func main() {
log.SetOutput(io.Discard)
kingpin.Version("0.0.1")
kingpin.Parse()
//Define limit after parsing
logrus.SetLevel(logrus.FatalLevel)
if (*mountpoint)[len(*mountpoint)-1:] != "/" {
*mountpoint = *mountpoint + "/"
}
fmt.Println("Finding the top", *limit, "largest files on filesystem", *mountpoint, "\n================================================")
mount, err := filesystem.FindMount(*mountpoint)
if err != nil {
logrus.Fatal(err)
}
device = mount.Device
entries, err := os.ReadDir(*mountpoint)
if err != nil {
logrus.Fatal(err)
}
var wg sync.WaitGroup
getFiles(*mountpoint, entries, &wg)
go func() {
defer close(fileChan)
wg.Wait()
}()
var last int64
for file := range fileChan {
if file.Size > last {
files = append(files, file)
} else {
files = append([]fileDisplay{file}, files...)
}
}
sort.Sort(bySize(files))
var shortFiles []fileDisplay
if len(files) > *limit {
shortFiles = files[len(files)-*limit:]
} else {
shortFiles = files
}
for _, file := range shortFiles {
fmt.Println(file.Path, file.DisplaySizeIEC())
}
}
func getFiles(start string, entries []fs.DirEntry, wg *sync.WaitGroup) {
for _, entry := range entries {
wg.Add(1)
go handleEntry(start, entry, wg)
}
}
func handleEntry(start string, entry fs.DirEntry, wg *sync.WaitGroup) {
defer wg.Done()
var file fileDisplay
mount, err := filesystem.FindMount(start + entry.Name())
if err != nil {
logrus.Fatalln(err, start+entry.Name())
return
}
if mount.Device == device {
if entry.Type().IsRegular() {
fileInfo, err := os.Stat(start + entry.Name())
if err != nil {
logrus.Fatalln(err, start+entry.Name())
return
}
file.Path = start + entry.Name()
file.Size = fileInfo.Size()
fileChan <- file
} else if entry.IsDir() {
entries, err := os.ReadDir(start + entry.Name())
if err != nil {
logrus.Fatalln(err, start+entry.Name())
return
}
logrus.Info("Searching ", start+entry.Name())
getFiles(start+entry.Name()+"/", entries, wg)
}
}
}
func (f *fileDisplay) DisplaySizeIEC() string {
const unit = 1024
b := f.Size
if b < unit {
return fmt.Sprintf("%dB", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.2f%ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
Edit: I tried removing the channel and just appending to the slice. This sped it up, but it's not safe because multiple goroutines could be accessing the slice at once.
My final draft drops the channel and uses a sync.RWMutex plus a custom append function that takes the lock. This lets me use append without risking multiple goroutines editing the same slice.
I dropped the channel because it was keeping goroutines alive until the for loop ranging over the channel could reach their message. My channel operations were blocking, so the goroutines slowed to the speed of the loop draining the channel.
You can see the differences here:
package main
import (
"fmt"
"io"
"io/fs"
"log"
"os"
"sort"
"sync"
"github.com/google/fscrypt/filesystem"
"github.com/sirupsen/logrus"
"gopkg.in/alecthomas/kingpin.v2"
)
var (
mountpoint = kingpin.Flag("mount", "The mount to find the largest file usages. Can be a subpath of the mount").Required().String()
limit = kingpin.Flag("limit", "The maximum number of files returned to the display").Default("10").Short('l').Int()
)
var device string
type fileDisplays struct {
sync.RWMutex
Files []fileDisplay
}
var files fileDisplays
type fileDisplay struct {
Size int64
Path string
}
type bySize []fileDisplay
func (a bySize) Len() int { return len(a) }
func (a bySize) Less(i, j int) bool { return a[i].Size < a[j].Size }
func (a bySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func main() {
log.SetOutput(io.Discard)
kingpin.Version("0.0.1")
kingpin.Parse()
//Define limit after parsing
logrus.SetLevel(logrus.FatalLevel)
if (*mountpoint)[len(*mountpoint)-1:] != "/" {
*mountpoint = *mountpoint + "/"
}
fmt.Println("Finding the top", *limit, "largest files on filesystem", *mountpoint, "\n================================================")
mount, err := filesystem.FindMount(*mountpoint)
if err != nil {
logrus.Fatal(err)
}
device = mount.Device
entries, err := os.ReadDir(*mountpoint)
if err != nil {
logrus.Fatal(err)
}
var wg sync.WaitGroup
getFiles(*mountpoint, entries, &wg)
wg.Wait()
sort.Sort(bySize(files.Files))
var shortFiles []fileDisplay
if len(files.Files) > *limit {
shortFiles = files.Files[len(files.Files)-*limit:]
} else {
shortFiles = files.Files
}
for _, file := range shortFiles {
fmt.Println(file.Path, file.DisplaySizeIEC())
}
}
func getFiles(start string, entries []fs.DirEntry, wg *sync.WaitGroup) {
for _, entry := range entries {
wg.Add(1)
go handleEntry(start, entry, wg)
}
}
func handleEntry(start string, entry fs.DirEntry, wg *sync.WaitGroup) {
defer wg.Done()
var file fileDisplay
mount, err := filesystem.FindMount(start + entry.Name())
if err != nil {
logrus.Errorln(err, start+entry.Name())
return
}
if mount.Device == device {
if entry.Type().IsRegular() {
fileInfo, err := os.Stat(start + entry.Name())
if err != nil {
logrus.Errorln(err, start+entry.Name())
return
}
file.Path = start + entry.Name()
file.Size = fileInfo.Size()
files.Append(file)
} else if entry.IsDir() {
entries, err := os.ReadDir(start + entry.Name())
if err != nil {
logrus.Errorln(err, start+entry.Name())
return
}
logrus.Info("Searching ", start+entry.Name())
getFiles(start+entry.Name()+"/", entries, wg)
}
}
}
func (f *fileDisplay) DisplaySizeIEC() string {
const unit = 1024
b := f.Size
if b < unit {
return fmt.Sprintf("%dB", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.2f%ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
func (fd *fileDisplays) Append(item fileDisplay) {
fd.Lock()
defer fd.Unlock()
fd.Files = append(fd.Files, item)
}
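One thing this version still does (not raised in the original post) is spawn a goroutine for every directory entry with no upper bound, which can exhaust file descriptors on a large tree. A minimal sketch of capping that with a counting semaphore (the sem variable, its size of 64, and getFilesBounded are invented names; handleEntry would be switched to call getFilesBounded for subdirectories):

var sem = make(chan struct{}, 64) // at most 64 entries processed at once

func getFilesBounded(start string, entries []fs.DirEntry, wg *sync.WaitGroup) {
	for _, entry := range entries {
		wg.Add(1)
		entry := entry // capture the loop variable (needed before Go 1.22)
		go func() {
			sem <- struct{}{}        // acquire a slot
			defer func() { <-sem }() // release it when this entry is done
			handleEntry(start, entry, wg) // handleEntry still calls wg.Done
		}()
	}
}

There is no deadlock risk because a parent releases its slot as soon as handleEntry returns; it never holds a slot while waiting on its children.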
I want to loop through the menu's options. However, it stops at the first option, since a select without a default case blocks, and it does not know that more options will appear dynamically.
Below is the broken code:
package main
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"os/exec"
"strings"
"time"
"github.com/getlantern/systray"
"gopkg.in/yaml.v3"
)
var menuItensPtr []*systray.MenuItem
var config map[string]string
var commands []string
func main() {
config = readconfig()
systray.Run(onReady, onExit)
}
func onReady() {
systray.SetIcon(getIcon("assets/menu.ico"))
menuItensPtr = make([]*systray.MenuItem,0)
commands = make([]string,0)
for k, v := range config {
menuItemPtr := systray.AddMenuItem(k, k)
menuItensPtr = append(menuItensPtr, menuItemPtr)
commands = append(commands, v)
}
systray.AddSeparator()
// mQuit := systray.AddMenuItem("Quit", "Quits this app")
go func() {
for {
systray.SetTitle("My tray menu")
systray.SetTooltip("https://github.com/evandrojr/my-tray-menu")
time.Sleep(1 * time.Second)
}
}()
go func() {
for{
for i, menuItenPtr := range menuItensPtr {
select {
/// EXECUTION GETS STUCK HERE!!!!!!!
case<-menuItenPtr.ClickedCh:
execute(commands[i])
}
}
// select {
// case <-mQuit.ClickedCh:
// systray.Quit()
// return
// // default:
// }
}
}()
}
func onExit() {
// Cleaning stuff will go here.
}
func getIcon(s string) []byte {
b, err := ioutil.ReadFile(s)
if err != nil {
fmt.Print(err)
}
return b
}
func execute(commands string){
command_array:= strings.Split(commands, " ")
command:=""
command, command_array = command_array[0], command_array[1:]
cmd := exec.Command(command, command_array ...)
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err != nil {
log.Fatal(err)
}
// fmt.Printf("Output %s\n", out.String())
}
func readconfig() map[string]string{
yfile, err := ioutil.ReadFile("my-tray-menu.yaml")
if err != nil {
log.Fatal(err)
}
data := make(map[string]string)
err2 := yaml.Unmarshal(yfile, &data)
if err2 != nil {
log.Fatal(err2)
}
for k, v := range data {
fmt.Printf("%s -> %s\n", k, v)
}
return data
}
Below is the ugly workaround that works:
package main
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/getlantern/systray"
"gopkg.in/yaml.v3"
)
var menuItensPtr []*systray.MenuItem
var config map[string]string
var commands []string
var labels []string
var programPath string
func main() {
setProgramPath()
config = readconfig()
time.Sleep(1 * time.Second)
systray.Run(onReady, onExit)
}
func onReady() {
systray.SetIcon(getIcon(filepath.Join(programPath,"assets/menu.ico")))
menuItensPtr = make([]*systray.MenuItem, 0)
i := 0
op0 := systray.AddMenuItem(labels[i], commands[i])
i++
op1 := systray.AddMenuItem(labels[i], commands[i])
i++
op2 := systray.AddMenuItem(labels[i], commands[i])
i++
op3 := systray.AddMenuItem(labels[i], commands[i])
i++
systray.AddSeparator()
mQuit := systray.AddMenuItem("Quit", "Quits this app")
go func() {
for {
systray.SetTitle("My tray menu")
systray.SetTooltip("https://github.com/evandrojr/my-tray-menu")
time.Sleep(1 * time.Second)
}
}()
go func() {
for {
select {
// HERE DOES NOT GET STUCK!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
case <-op0.ClickedCh:
execute(commands[0])
case <-op1.ClickedCh:
execute(commands[1])
case <-op2.ClickedCh:
execute(commands[2])
case <-op3.ClickedCh:
execute(commands[3])
case <-mQuit.ClickedCh:
systray.Quit()
return
}
}
}()
}
func onExit() {
// Cleaning stuff will go here.
}
func getIcon(s string) []byte {
b, err := ioutil.ReadFile(s)
if err != nil {
fmt.Print(err)
}
return b
}
func setProgramPath(){
ex, err := os.Executable()
if err != nil {
panic(err)
}
programPath = filepath.Dir(ex)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
}
func execute(commands string) {
command_array := strings.Split(commands, " ")
command := ""
command, command_array = command_array[0], command_array[1:]
cmd := exec.Command(command, command_array...)
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err != nil {
log.Fatal(err)
}
fmt.Printf("Output %s\n", out.String())
}
func readconfig() map[string]string {
yfile, err := ioutil.ReadFile(filepath.Join(programPath,"my-tray-menu.yaml"))
if err != nil {
log.Fatal(err)
}
data := make(map[string]string)
err2 := yaml.Unmarshal(yfile, &data)
if err2 != nil {
log.Fatal(err2)
}
labels = make([]string, 0)
commands = make([]string, 0)
for k, v := range data {
labels = append(labels, k)
commands = append(commands, v)
fmt.Printf("%s -> %s\n", k, v)
}
fmt.Print(len(labels))
return data
}
Full source code here:
https://github.com/evandrojr/my-tray-menu
select "chooses which of a set of possible send or receive operations will proceed". The spec sets out how this choice is made:
If one or more of the communications can proceed, a single one that can proceed is chosen via a uniform pseudo-random selection. Otherwise, if there is a default case, that case is chosen. If there is no default case, the "select" statement blocks until at least one of the communications can proceed.
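A tiny self-contained illustration of that rule (added here for clarity; not from the original answer):

package main

import "fmt"

func main() {
	ch := make(chan int, 1)

	// Nothing has been sent yet, so the receive cannot proceed
	// and the default case is chosen: this select does not block.
	select {
	case v := <-ch:
		fmt.Println("received", v)
	default:
		fmt.Println("nothing ready")
	}

	ch <- 42

	// Now the receive can proceed, so it is chosen over default.
	select {
	case v := <-ch:
		fmt.Println("received", v)
	default:
		fmt.Println("nothing ready")
	}
}

This prints "nothing ready" and then "received 42". Remove the default cases and the first select would block forever, which is exactly what happens in the question's loop.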
Your working example:
select {
case <-op0.ClickedCh:
execute(commands[0])
case <-op1.ClickedCh:
execute(commands[1])
// ...
}
uses select successfully to choose between the offered options. However, if you offer only a single option, e.g.
select {
case <-menuItenPtr.ClickedCh:
execute(commands[i])
}
The select will block until <-menuItenPtr.ClickedCh is ready to proceed (i.e. something is received). This is effectively the same as not using a select:
<-menuItenPtr.ClickedCh
execute(commands[i])
The result you were expecting can be achieved by providing a default option:
select {
case <-menuItenPtr.ClickedCh:
execute(commands[i])
default:
}
As per the quote from the spec above, the default option will be chosen if none of the other options can proceed. While this may work, it's not a very good solution because you effectively end up with:
for {
// Check if event happened (not blocking)
}
This will tie up CPU time unnecessarily as it continually loops checking for events. A better solution would be to start a goroutine to monitor each channel:
for i, menuItenPtr := range menuItensPtr {
go func(c chan struct{}, cmd string) {
for range c { execute(cmd) }
}(menuItenPtr.ClickedCh, commands[i])
}
// Start another goroutine to handle quit
The above will probably work but does lead to the possibility that execute will be called concurrently (which might cause issues if your code is not threadsafe). One way around this is to use the "fan in" pattern (as suggested by @kostix and in the Rob Pike video suggested by @John); something like:
cmdChan := make(chan string) // carries the command text, so it must be a string channel
for i, menuItenPtr := range menuItensPtr {
go func(c chan struct{}, cmd string) {
for range c { cmdChan <- cmd }
}(menuItenPtr.ClickedCh, commands[i])
}
go func() {
for {
select {
case cmd := <- cmdChan:
execute(cmd) // Handle command
case <-mQuit.ClickedCh:
systray.Quit()
return
}
}
}()
note: all of the code above was entered directly into the answer, so please treat it as pseudo-code!
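For clarity, here is a self-contained, compilable version of that fan-in pattern, with plain channels standing in for systray's ClickedCh (the channel slice, the command strings, and the sleep are invented for the demo):

package main

import (
	"fmt"
	"time"
)

func main() {
	// Stand-ins for each menu item's ClickedCh.
	clicks := []chan struct{}{make(chan struct{}), make(chan struct{})}
	commands := []string{"echo one", "echo two"}
	quit := make(chan struct{})

	cmdChan := make(chan string)
	// Fan in: one forwarding goroutine per click channel.
	for i, c := range clicks {
		go func(c <-chan struct{}, cmd string) {
			for range c {
				cmdChan <- cmd
			}
		}(c, commands[i])
	}

	// Single consumer: commands are handled one at a time,
	// so the handler does not need to be threadsafe.
	go func() {
		for {
			select {
			case cmd := <-cmdChan:
				fmt.Println("execute:", cmd)
			case <-quit:
				return
			}
		}
	}()

	// Simulate two clicks, then quit.
	clicks[0] <- struct{}{}
	clicks[1] <- struct{}{}
	time.Sleep(100 * time.Millisecond)
	close(quit)
}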
I am trying to write a directory-traversal program with goroutines and channels, but I can't get the results I need. I expect the total counts of subdirectories and files, but when I run the code below it gets stuck at "dirCount <- 1". PS: is it possible to write such a program that traverses to unlimited depth?
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"github.com/gorilla/mux"
)
type DirectoryItem struct {
Name string `json:"name,omitempty"`
IsDir bool `json:"isDir,omitempty"`
Size int64 `json:"size,omitempty"`
}
type DirectoryInfo struct {
Path string `json:"path,omitempty"`
Dirs []DirectoryItem `json:"dirs,omitempty"`
}
var dirItems []DirectoryItem
var dirInfo DirectoryInfo
func GetOneDirItems(w http.ResponseWriter, req *http.Request) {
fpath := "E:\\"
query := req.URL.Query()
path := query["path"][0]
fpath = fpath + path
dirInfo, _ := CheckEachItem(fpath)
json.NewEncoder(w).Encode(dirInfo)
}
func CheckEachItem(dirPath string) (directory DirectoryInfo, err error) {
var items []DirectoryItem
dir, err := ioutil.ReadDir(dirPath)
if err != nil {
return directory, err
}
for _, fi := range dir {
if fi.IsDir() {
items = append(items, DirectoryItem{Name: fi.Name(), IsDir: true, Size: 0})
} else {
items = append(items, DirectoryItem{Name: fi.Name(), IsDir: false, Size: fi.Size()})
}
}
directory = DirectoryInfo{Path: dirPath, Dirs: items}
return directory, nil
}
func CalcDirInfo(w http.ResponseWriter, req *http.Request) {
query := req.URL.Query()
path := query["path"][0]
url := "http://localhost:8090/api/GetOneDirItems?path="
url += path
dirCount := make(chan int)
fileCount := make(chan int)
go Recr(url, dirCount, fileCount)
//
dirTotalCount := 0
for i := range dirCount {
dirTotalCount += i
}
fmt.Println(dirTotalCount)
}
func Recr(url string, dirCount chan int, fileCount chan int) {
fmt.Println(url)
resp, _ := http.Get(url)
dirInfo = DirectoryInfo{}
body, _ := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
json.Unmarshal([]byte(body), &dirInfo)
for _, itm := range dirInfo.Dirs {
fmt.Println("--")
if itm.IsDir {
newUrl := url + "/" + itm.Name
//// looks like stuck in here
dirCount <- 1
go Recr(newUrl, dirCount, fileCount)
} else {
fileCount <- 1
}
}
}
func main() {
router := mux.NewRouter()
//#1 func one:
//result sample:
//{"path":"E:\\code","dirs":[{"name":"A","isDir":true},{"name":"B","isDir":false}]}
router.HandleFunc("/api/GetOneDirItems", GetOneDirItems).Methods("GET")
//#2 2nd api to invoke 1st api recursively
//expected result
//{"path":"E:\\code","dirCount":2, "fileCount":3]}
router.HandleFunc("/api/GetDirInfo", CalcDirInfo).Methods("GET")
log.Fatal(http.ListenAndServe(":8090", router))
}
I found some example code, but it does not report the right numbers...
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"path/filepath"
"sync"
"github.com/gorilla/mux"
)
//!+1
var done = make(chan struct{})
func cancelled() bool {
select {
case <-done:
return true
default:
return false
}
}
//!-1
type DirectoryItem struct {
Name string `json:"name,omitempty"`
IsDir bool `json:"isDir,omitempty"`
Size int64 `json:"size,omitempty"`
}
type DirectoryInfo struct {
Path string `json:"path,omitempty"`
Dirs []DirectoryItem `json:"dirs,omitempty"`
}
var dirItems []DirectoryItem
var dirInfo DirectoryInfo
func GetOneDirItems(w http.ResponseWriter, req *http.Request) {
fpath := "E:\\"
query := req.URL.Query()
path := query["path"][0]
fpath = fpath + path
dirInfo, _ := CheckEachItem(fpath)
json.NewEncoder(w).Encode(dirInfo)
}
func CheckEachItem(dirPath string) (directory DirectoryInfo, err error) {
var items []DirectoryItem
dir, err := ioutil.ReadDir(dirPath)
if err != nil {
return directory, err
}
for _, fi := range dir {
if fi.IsDir() {
items = append(items, DirectoryItem{Name: fi.Name(), IsDir: true, Size: 0})
} else {
items = append(items, DirectoryItem{Name: fi.Name(), IsDir: false, Size: fi.Size()})
}
}
directory = DirectoryInfo{Path: dirPath, Dirs: items}
return directory, nil
}
func CalcDirInfo(w http.ResponseWriter, req *http.Request) {
query := req.URL.Query()
path := query["path"][0]
url := "http://localhost:8090/api/GetOneDirItems?path="
url += path
fpath := "E:\\"
fpath = fpath + path
dirInfo, _ := CheckEachItem(fpath)
fileSizes := make(chan int64)
dirCount := make(chan int, 100)
var n sync.WaitGroup
for _, item := range dirInfo.Dirs {
n.Add(1)
url = url + "/" + item.Name
go Recr(url, &n, dirCount, fileSizes)
}
go func() {
n.Wait()
close(fileSizes)
close(dirCount)
}()
// Print the results periodically.
// tick := time.Tick(500 * time.Millisecond)
var nfiles, ndirs, nbytes int64
loop:
//!+3
for {
select {
case <-done:
// Drain fileSizes to allow existing goroutines to finish.
for range fileSizes {
// Do nothing.
}
return
case size, ok := <-fileSizes:
// ...
//!-3
if !ok {
break loop // fileSizes was closed
}
nfiles++
nbytes += size
case _, ok := <-dirCount:
// ...
//!-3
if !ok {
break loop // dirCount was closed
}
ndirs++
// case <-tick:
// printDiskUsage(nfiles, ndirs, nbytes)
}
}
printDiskUsage(nfiles, ndirs, nbytes) // final totals
}
func Recr(url string, n *sync.WaitGroup, dirCount chan<- int, fileSizes chan<- int64) {
defer n.Done()
resp, _ := http.Get(url)
dirInfo = DirectoryInfo{}
body, _ := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
json.Unmarshal([]byte(body), &dirInfo)
for _, itm := range dirInfo.Dirs {
if itm.IsDir {
dirCount <- 1
n.Add(1)
newUrl := url + "/" + itm.Name
go Recr(newUrl, n, dirCount, fileSizes)
} else {
fileSizes <- itm.Size
}
}
}
func main() {
// Determine the initial directories.
roots := os.Args[1:]
if len(roots) == 0 {
roots = []string{"."}
}
// API Services
router := mux.NewRouter()
router.HandleFunc("/api/GetOneDirItems", GetOneDirItems).Methods("GET")
router.HandleFunc("/api/GetDirInfo", CalcDirInfo).Methods("GET")
log.Fatal(http.ListenAndServe(":8090", router))
}
func printDiskUsage(nfiles, ndirs, nbytes int64) {
fmt.Printf("%d files %.1f GB %d dirs\n", nfiles, float64(nbytes)/1e9, ndirs)
}
// walkDir recursively walks the file tree rooted at dir
// and sends the size of each found file on fileSizes.
//!+4
func walkDir(dir string, n *sync.WaitGroup, fileSizes chan<- int64, dirCount chan<- int) {
defer n.Done()
if cancelled() {
return
}
for _, entry := range dirents(dir) {
// ...
//!-4
if entry.IsDir() {
dirCount <- 1
n.Add(1)
subdir := filepath.Join(dir, entry.Name())
go walkDir(subdir, n, fileSizes, dirCount)
} else {
fileSizes <- entry.Size()
}
//!+4
}
}
//!-4
var sema = make(chan struct{}, 20) // concurrency-limiting counting semaphore
// dirents returns the entries of directory dir.
//!+5
func dirents(dir string) []os.FileInfo {
select {
case sema <- struct{}{}: // acquire token
case <-done:
return nil // cancelled
}
defer func() { <-sema }() // release token
// ...read directory...
//!-5
f, err := os.Open(dir)
if err != nil {
fmt.Fprintf(os.Stderr, "du: %v\n", err)
return nil
}
defer f.Close()
entries, err := f.Readdir(0) // 0 => no limit; read all entries
if err != nil {
fmt.Fprintf(os.Stderr, "du: %v\n", err)
// Don't return: Readdir may return partial results.
}
return entries
}
The problem here is that your program has no way of ending. Basically, whenever the code recurses into another directory you need to count that, and when it finishes processing a directory it needs to push 1 to a done channel. When the count of directories recursed into equals the number done, you can exit the channel select (that's the other missing part):
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"github.com/gorilla/mux"
)
type DirectoryItem struct {
Name string `json:"name,omitempty"`
IsDir bool `json:"isDir,omitempty"`
Size int64 `json:"size,omitempty"`
}
type DirectoryInfo struct {
Path string `json:"path,omitempty"`
Dirs []DirectoryItem `json:"dirs,omitempty"`
}
var dirItems []DirectoryItem
var dirInfo DirectoryInfo
func GetOneDirItems(w http.ResponseWriter, req *http.Request) {
fpath := "E:\\"
query := req.URL.Query()
path := query["path"][0]
fpath = fpath + path
dirInfo, err := CheckEachItem(fpath)
if err != nil {
panic(err)
}
json.NewEncoder(w).Encode(dirInfo)
}
func CheckEachItem(dirPath string) (directory DirectoryInfo, err error) {
var items []DirectoryItem
dir, err := ioutil.ReadDir(dirPath)
if err != nil {
return directory, err
}
for _, fi := range dir {
if fi.IsDir() {
items = append(items, DirectoryItem{Name: fi.Name(), IsDir: true, Size: 0})
} else {
items = append(items, DirectoryItem{Name: fi.Name(), IsDir: false, Size: fi.Size()})
}
}
directory = DirectoryInfo{Path: dirPath, Dirs: items}
return directory, nil
}
func CalcDirInfo(w http.ResponseWriter, req *http.Request) {
query := req.URL.Query()
path := query["path"][0]
url := "http://localhost:8090/api/GetOneDirItems?path="
url += path
dirCount := make(chan int, 10)
fileCount := make(chan int, 10)
doneCount := make(chan int, 10)
go Recr(url, dirCount, fileCount, doneCount)
//
dirTotalCount := 0
doneTotalCount := 0
out:
for {
select {
case dir := <-dirCount:
dirTotalCount += dir
fmt.Printf("dirTotalCount=%d\n", dirTotalCount)
case <-fileCount:
case done := <-doneCount:
doneTotalCount += done
fmt.Printf("doneTotalCount=%d dirTotalCount=%d\n", doneTotalCount, dirTotalCount)
if doneTotalCount == dirTotalCount+1 { // +1 because the root dir is not counted as a subdirectory
break out
}
}
}
fmt.Println("ALL DONE")
fmt.Printf("TOTAL=%d\n", dirTotalCount)
}
func Recr(url string, dirCount chan int, fileCount chan int, doneCount chan int) {
// fmt.Printf("url=%s\n", url)
resp, _ := http.Get(url)
dirInfo = DirectoryInfo{}
body, _ := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
json.Unmarshal([]byte(body), &dirInfo)
// fmt.Printf("dirInfo=%+v body=%s", dirInfo, string(body))
for _, itm := range dirInfo.Dirs {
if itm.IsDir {
newUrl := url + "/" + itm.Name
//// looks like stuck in here
// fmt.Printf("pushing one dir from %s\n", url)
dirCount <- 1
go Recr(newUrl, dirCount, fileCount, doneCount)
} else {
// fmt.Println("pushing one file")
fileCount <- 1
}
}
doneCount <- 1
}
func main() {
router := mux.NewRouter()
//#1 func one:
//result sample:
//{"path":"E:\\code","dirs":[{"name":"A","isDir":true},{"name":"B","isDir":false}]}
router.HandleFunc("/api/GetOneDirItems", GetOneDirItems).Methods("GET")
//#2 2nd api to invoke 1st api recursively
//expected result
//{"path":"E:\\code","dirCount":2, "fileCount":3]}
router.HandleFunc("/api/GetDirInfo", CalcDirInfo).Methods("GET")
log.Fatal(http.ListenAndServe(":8090", router))
}
I am writing a Go project, a simple web crawler that crawls the links on a website. I want to experiment with the concurrency features, such as goroutines and channels. But when I run it, nothing happens; it's as if nothing is going on. I have no idea what went wrong. Can somebody point it out for me?
It works and shows all the crawled links if I remove the channel logic, but I want it to send the links into a buffered channel and then display them before ending the program. The program is supposed to be able to go to any depth specified; currently the depth is 1.
package main
import (
"fmt"
"log"
"net/http"
"os"
"strings"
"time"
"golang.org/x/net/html"
)
// Link type to be sent over channel
type Link struct {
URL string
ok bool
}
func main() {
if len(os.Args) != 2 {
fmt.Println("Usage: crawl [URL].")
}
url := os.Args[1]
if !strings.HasPrefix(url, "http://") {
url = "http://" + url
}
ch := make(chan *Link, 5)
crawl(url, 1, ch)
visited := make(map[string]bool)
time.Sleep(2 * time.Second)
for link := range ch {
if _, ok := visited[link.URL]; !ok {
visited[link.URL] = true
}
}
close(ch)
for l := range visited {
fmt.Println(l)
}
}
func crawl(url string, n int, ch chan *Link) {
if n < 1 {
return
}
resp, err := http.Get(url)
if err != nil {
log.Fatalf("Can not reach the site. Error = %v\n", err)
os.Exit(1)
}
b := resp.Body
defer b.Close()
z := html.NewTokenizer(b)
nextN := n - 1
for {
token := z.Next()
switch token {
case html.ErrorToken:
return
case html.StartTagToken:
current := z.Token()
if current.Data != "a" {
continue
}
result, ok := getHrefTag(current)
if !ok {
continue
}
hasProto := strings.HasPrefix(result, "http")
if hasProto {
go crawl(result, nextN, ch)
ch <- &Link{result, true}
}
}
}
}
func getHrefTag(token html.Token) (result string, ok bool) {
for _, a := range token.Attr {
if a.Key == "href" {
result = a.Val
ok = true
break
}
}
return
}
UPDATED:
After some fiddling I figured out how to change the code to remove the data races, but I still don't know how to avoid crawling URLs that were visited previously (maybe I should start another question?):
package main
import (
"fmt"
"log"
"net/http"
"os"
"strings"
"golang.org/x/net/html"
)
func main() {
if len(os.Args) != 2 {
fmt.Println("Usage: crawl [URL].")
}
url := os.Args[1]
if !strings.HasPrefix(url, "http://") {
url = "http://" + url
}
for link := range newCrawl(url, 1) {
fmt.Println(link)
}
}
func newCrawl(url string, num int) chan string {
ch := make(chan string, 20)
go func() {
crawl(url, 1, ch)
close(ch)
}()
return ch
}
func crawl(url string, n int, ch chan string) {
if n < 1 {
return
}
resp, err := http.Get(url)
if err != nil {
log.Fatalf("Can not reach the site. Error = %v\n", err)
os.Exit(1)
}
b := resp.Body
defer b.Close()
z := html.NewTokenizer(b)
nextN := n - 1
for {
token := z.Next()
switch token {
case html.ErrorToken:
return
case html.StartTagToken:
current := z.Token()
if current.Data != "a" {
continue
}
result, ok := getHrefTag(current)
if !ok {
continue
}
hasProto := strings.HasPrefix(result, "http")
if hasProto {
done := make(chan struct{})
go func() {
crawl(result, nextN, ch)
close(done)
}()
<-done
ch <- result
}
}
}
}
func getHrefTag(token html.Token) (result string, ok bool) {
for _, a := range token.Attr {
if a.Key == "href" {
result = a.Val
ok = true
break
}
}
return
}
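On the remaining question of skipping already-visited URLs: one common approach (a sketch; the visitedSet type and its method names are invented, and it needs the sync import) is a mutex-protected set consulted before each crawl:

// visitedSet is a concurrency-safe set of URLs. markVisited returns
// false if the URL was already seen, so each URL is crawled at most once.
type visitedSet struct {
	mu   sync.Mutex
	seen map[string]bool
}

func newVisitedSet() *visitedSet {
	return &visitedSet{seen: make(map[string]bool)}
}

func (v *visitedSet) markVisited(url string) bool {
	v.mu.Lock()
	defer v.mu.Unlock()
	if v.seen[url] {
		return false
	}
	v.seen[url] = true
	return true
}

crawl would then start with something like if !visited.markVisited(url) { return } before issuing the HTTP request.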
I think that recursively spawning goroutines is not a good idea. It can easily get out of control. I would prefer a flatter model, like this:
package main
import (
"fmt"
"log"
"net/http"
"os"
"strings"
"sync"
"golang.org/x/net/html"
)
func main() {
if len(os.Args) != 2 {
fmt.Println("Usage: crawl [URL].")
}
url := os.Args[1]
if !strings.HasPrefix(url, "http://") {
url = "http://" + url
}
wg := NewWorkGroup(1)
wg.Crawl(url)
for k, v := range wg.urlMap {
fmt.Printf("%s: %d\n", k, v)
}
}
// represents single link and its deph
type Link struct {
url string
deph uint32
}
// wraps all around to group
type WorkGroup struct {
*sync.WaitGroup
maxDeph uint32
numW int
pool chan *Worker
linkQ chan Link
urlMap map[string]uint32
}
type Worker struct {
result chan []Link
}
func newWorker() *Worker {
return &Worker{
result: make(chan []Link),
}
}
func NewWorkGroup(maxDeph uint32) *WorkGroup {
numW := int(maxDeph)
if maxDeph > 10 {
numW = 10
}
return &WorkGroup{
WaitGroup: new(sync.WaitGroup),
maxDeph: maxDeph,
numW: numW,
pool: make(chan *Worker, numW),
linkQ: make(chan Link, 100),
urlMap: make(map[string]uint32),
}
}
// dispatch workers -> filter visited -> send not visited to channel
// pool + dispatcher keep order so workers go level by level
func (wg *WorkGroup) spawnDispatcher() {
wg.Add(1)
go func() {
defer wg.Done()
defer close(wg.linkQ)
for w := range wg.pool {
links := <-w.result
for i := 0; i < len(links); i++ {
if _, ok := wg.urlMap[links[i].url]; !ok {
wg.urlMap[links[i].url] = links[i].deph
// dont process links that reach max deph
if links[i].deph < wg.maxDeph {
select {
case wg.linkQ <- links[i]:
// goes well
continue
default:
// channel is too short, protecting possible deadlock
}
// drop rest of links
break
}
}
}
// empty link channel + nothing in process = end
if len(wg.linkQ) == 0 && len(wg.pool) == 0 {
return
}
}
}()
}
//initialize goroutines and crawl url
func (wg *WorkGroup) Crawl(url string) {
defer close(wg.pool)
wg.spawnCrawlers()
wg.spawnDispatcher()
wg.linkQ <- Link{url: url, deph: 0}
wg.Wait()
}
func (wg *WorkGroup) spawnCrawlers() {
// custom num of workers, used maxDeph
for i := 0; i < wg.numW; i++ {
wg.newCrawler()
}
}
func (wg *WorkGroup) newCrawler() {
wg.Add(1)
go func(w *Worker) {
defer wg.Done()
defer close(w.result)
for link := range wg.linkQ {
wg.pool <- w
w.result <- getExternalUrls(link)
}
}(newWorker())
}
// default sligtly modified crawl function
func getExternalUrls(source Link) []Link {
resp, err := http.Get(source.url)
if err != nil {
log.Printf("Can not reach the site. Error = %v\n", err)
return nil
}
b := resp.Body
defer b.Close()
z := html.NewTokenizer(b)
links := []Link{}
for {
token := z.Next()
switch token {
case html.ErrorToken:
return links
case html.StartTagToken:
current := z.Token()
if current.Data != "a" {
continue
}
url, ok := getHrefTag(current)
if ok && strings.HasPrefix(url, "http") {
links = append(links, Link{url: url, deph: source.deph + 1})
}
}
}
return links
}
//default function
func getHrefTag(token html.Token) (result string, ok bool) {
for _, a := range token.Attr {
if a.Key == "href" {
result = a.Val
ok = true
break
}
}
return
}
I found a good checker for invalid web links, but how do I change it into a complete example using goroutines? The web page is How To Crawl A Website In Golang. The code dynamically adds the URLs to be searched to the pending slice, but I am having some difficulty using goroutines to do it.
package main
import (
"crypto/tls"
"errors"
"fmt"
"golang.org/x/net/html"
"io"
"net/http"
"net/url"
"strings"
"time"
)
var alreadyCrawledList []string
var pending []string
var brokenLinks []string
const localHostWithPort = "localhost:8080"
func IsLinkInPendingQueue(link string) bool {
for _, x := range pending {
if x == link {
return true
}
}
return false
}
func IsLinkAlreadyCrawled(link string) bool {
for _, x := range alreadyCrawledList {
if x == link {
return true
}
}
return false
}
func AddLinkInAlreadyCrawledList(link string) {
alreadyCrawledList = append(alreadyCrawledList, link)
}
func AddLinkInPendingQueue(link string) {
pending = append(pending, link)
}
func AddLinkInBrokenLinksQueue(link string) {
brokenLinks = append(brokenLinks, link)
}
func main() {
start := time.Now()
AddLinkInPendingQueue("http://" + localHostWithPort)
for count := 0; len(pending) > 0; count++ {
x := pending[0]
pending = pending[1:] //it dynamicly change the search url
if err := crawlPage(x); err != nil { //how to use it by using goroutine?
fmt.Println(err.Error()) // t.Errorf only exists in tests; print the error instead
}
}
duration := time.Since(start)
fmt.Println("________________")
count := 0
for _, l := range brokenLinks {
count++
fmt.Println(count, "Broken. | ", l)
}
fmt.Println("Time taken:", duration)
}
func crawlPage(uri string) error {
if IsLinkAlreadyCrawled(uri) {
fmt.Println("Already visited: Ignoring uri | ", uri)
return nil
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
client := http.Client{Transport: transport}
resp, err := client.Get(uri)
if err != nil {
fmt.Println("Got error: ", err.Error())
return err
}
if resp.StatusCode != http.StatusOK {
AddLinkInBrokenLinksQueue(uri)
return errors.New(fmt.Sprintf("Got %v instead of 200", resp.StatusCode))
}
defer resp.Body.Close()
links := ParseLinks(resp.Body)
links = ConvertLinksToLocalHost(links)
for _, link := range links {
if !InOurDomain(link) {
continue
}
absolute := FixURL(link, uri)
if !IsLinkAlreadyCrawled(absolute) && !IsLinkInPendingQueue(absolute) && absolute != uri { // Don't enqueue a page twice!
AddLinkInPendingQueue(absolute)
}
}
AddLinkInAlreadyCrawledList(uri)
return nil
}
func InOurDomain(link string) bool {
uri, err := url.Parse(link)
if err != nil {
return false
}
if uri.Scheme == "http" || uri.Scheme == "https" {
if uri.Host == localHostWithPort {
return true
}
return false
}
return true
}
func ConvertLinksToLocalHost(links []string) []string {
var convertedLinks []string
for _, link := range links {
convertedLinks = append(convertedLinks, strings.Replace(link, "leantricks.com", localHostWithPort, 1))
}
return convertedLinks
}
func FixURL(href, base string) string {
uri, err := url.Parse(href)
if err != nil {
return ""
}
baseURL, err := url.Parse(base)
if err != nil {
return ""
}
uri = baseURL.ResolveReference(uri)
return uri.String()
}
func ParseLinks(httpBody io.Reader) []string {
var links []string
page := html.NewTokenizer(httpBody)
for {
tokenType := page.Next()
if tokenType == html.ErrorToken {
return links
}
token := page.Token()
switch tokenType {
case html.StartTagToken:
fallthrough
case html.SelfClosingTagToken:
switch token.DataAtom.String() {
case "a":
fallthrough
case "link":
fallthrough
case "script":
for _, attr := range token.Attr {
if attr.Key == "href" {
links = append(links, attr.Val)
}
}
}
}
}
}
You could invoke crawlPage() concurrently and guard the alreadyCrawledList, pending, and brokenLinks variables with mutexes (not the most performant approach, though). On the other hand, the code would need to be modified a lot to make it more performant.
I did a quick check with 4 links and it seems to halve the duration. I made sample code with a simple HTTP server, and it's here.
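A minimal sketch of what "guard them with mutexes" could look like (the slice names come from the question; the mutex, the Safe suffix, and the added sync import are the sketch's own):

var mu sync.Mutex // guards alreadyCrawledList, pending, and brokenLinks

func IsLinkAlreadyCrawledSafe(link string) bool {
	mu.Lock()
	defer mu.Unlock()
	for _, x := range alreadyCrawledList {
		if x == link {
			return true
		}
	}
	return false
}

func AddLinkInPendingQueueSafe(link string) {
	mu.Lock()
	defer mu.Unlock()
	pending = append(pending, link)
}

Every read and write of those three slices has to go through helpers like these, and a sync.WaitGroup can be used to wait for the in-flight crawlPage calls before printing the broken links.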
Thanks,
- Anoop