// Copyright 2020 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package main

import (
	"bytes"
	"fmt"
	"html/template"
	"os"
	"regexp"
	"strings"

	"github.com/cockroachdb/cockroach/pkg/cli/exit"
	"github.com/cockroachdb/errors"
	"github.com/cockroachdb/gostdlib/go/format"
)

// main runs the generator. On failure it reports the error on stderr and
// terminates the process with a non-zero (unspecified-error) exit code.
func main() {
	err := run()
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, "ERROR:", err)
	exit.WithCode(exit.UnspecifiedError())
}

func run() error {
	if len(os.Args) < 5 {
		return errors.Newf("usage: %s <package> <type> <input> <output>\n", os.Args[0])
	}
	pkg, opType, in, out := os.Args[1], os.Args[2], os.Args[3], os.Args[4]

	source, err := os.ReadFile(in)
	if err != nil {
		return err
	}

	opPattern := regexp.MustCompile(`type (\w+) struct {`)
	var ops []string
	for _, line := range strings.Split(string(source), "\n") {
		line = strings.TrimSpace(line)
		if matches := opPattern.FindStringSubmatch(line); len(matches) > 0 {
			ops = append(ops, matches[1])
		}
	}

	tmpl, err := template.New("visitor").Parse(visitorTemplate)
	if err != nil {
		return err
	}

	// Render the template.
	var gen bytes.Buffer
	if err := tmpl.Execute(&gen, info{
		Pkg:  pkg,
		Type: opType,
		Ops:  ops,
	}); err != nil {
		return err
	}

	// Run gofmt on the generated source.
	formatted, err := format.Source(gen.Bytes())
	if err != nil {
		return errors.Wrap(err, "gofmt")
	}

	// Write the output file.
	f, err := os.OpenFile(out, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
	if err != nil {
		return err
	}
	defer func() { _ = f.Close() }()
	if _, err := f.Write(formatted); err != nil {
		return err
	}

	return nil
}

// info is the data value that visitorTemplate is executed with.
type info struct {
	// Pkg is the package clause of the generated file. Type is the prefix
	// used to name the generated Op and Visitor interfaces.
	Pkg, Type string
	// Ops lists the operation type names, one Visit method per entry.
	Ops []string
}

// visitorTemplate is the template that run parses and executes with an info
// value to produce the generated file. For the type prefix T taken from
// info.Type, the rendered output declares:
//
//   - TOp, an interface that embeds Op and adds a
//     Visit(context.Context, TVisitor) error method. Op is not declared by
//     this template; presumably the target package provides it (verify).
//   - TVisitor, an interface with one method per entry in info.Ops. Each
//     method is named after the op type and takes the op by value.
//   - A value-receiver Visit method on every op type, which calls the
//     visitor method of the same name.
//
// The "-" in {{range .Ops -}} trims the whitespace immediately after that
// action. run passes the rendered text through gofmt (format.Source), so the
// template does not have to be precisely formatted.
const visitorTemplate = `// Copyright 2020 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

// Code generated by generate_visitor.go. DO NOT EDIT.

package {{.Pkg}}

import "context"

{{$type := .Type}}
// {{$type}}Op is an operation which can be visited by {{$type}}Visitor.
type {{$type}}Op interface {
	Op
	Visit(context.Context, {{$type}}Visitor) error
}

// {{$type}}Visitor is a visitor for {{$type}}Op operations.
type {{$type}}Visitor interface {
{{range .Ops -}}
	{{.}}(context.Context, {{.}}) error
{{end}}
}

{{range .Ops}}
// Visit is part of the {{$type}}Op interface.
func (op {{.}}) Visit(ctx context.Context, v {{$type}}Visitor) error {
	return v.{{.}}(ctx, op)
}
{{end}}`
