package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io"
	"io/fs"
	"os"
	"sort"
	"strconv"
	"strings"

	"github.com/pkg/errors"
	"golang.org/x/text/cases"
	"golang.org/x/text/language"
)

// keystr is the marker searched for in comments. A comment group containing it, attached to a
// struct type declaration, opts that struct into generation; the rest of the marker's line holds
// key=value arguments (see verifyArgs for the accepted keys).
const keystr = "determined:stream-gen"

// streamType is a field's Go type expression copied verbatim from the input source, e.g. "[]int"
// or "*time.Time". Matching is purely textual, so a field only maps to an output type if its type
// is spelled exactly like one of the constants below.
type streamType string

// The Go field types that have Python and TypeScript mappings. Each value is the exact source
// spelling a field's type must have in order to match.
// NOTE: the names json and time would shadow the encoding/json and time packages if this file
// ever imported them.
const (
	json           streamType = "JSONB"
	text           streamType = "string"
	textArr        streamType = "[]string"
	integer        streamType = "int"
	integer64      streamType = "int64"
	intArr         streamType = "[]int"
	boolean        streamType = "bool"
	time           streamType = "time.Time"
	timePtr        streamType = "*time.Time"
	taskID         streamType = "model.TaskID"
	requestID      streamType = "model.RequestID"
	requestIDPtr   streamType = "*model.RequestID"
	workspaceState streamType = "model.WorkspaceState"
)

// Accepted values of the "source" marker argument (server, client) and the output languages
// selected by the --python and --ts flags (python, typescript). Server streamables become Python
// ServerMsg subclasses and are skipped by the TypeScript generator; client streamables become
// Python ClientMsg subclasses and TypeScript StreamSpec classes.
const (
	server     = "server"
	client     = "client"
	python     = "python"
	typescript = "typescript"
)

// Streamable represents the struct under a determined:stream-gen comment.
type Streamable struct {
	Name     string            // Go type name of the annotated struct
	Fields   []Field           // json-tagged fields, in declaration order
	Args     map[string]string // key=value pairs parsed from the marker's line
	Position token.Position    // location of the marker, used in error messages
}

// Field is a member of a Streamable.
type Field struct {
	Name    string     // Go field name
	Type    streamType // Go type expression, verbatim from the source
	JSONTag string     // JSON key for the field, derived from its json struct tag
}

// RootVisitor is the ast.Visitor passed to ast.Walk for a whole *ast.File. Its Visit returns a
// DeclFinder carrying the same source, output slice and FileSet, which ast.Walk then uses for the
// node's children.
type RootVisitor struct {
	src []byte
	out *[]Streamable
	fs  *token.FileSet
}

// Visit implements the ast.Visitor interface.
func (x RootVisitor) Visit(node ast.Node) ast.Visitor {
	if node != nil {
		// RootVisitor and DeclFinder have identical underlying struct types, so a plain
		// conversion carries every field across.
		return DeclFinder(x)
	}
	return nil
}

// DeclFinder discards any top-level definitions which can't be a type declaration.
// Only *ast.GenDecl nodes are passed on; each one gets its own freshly allocated
// StreamableFinder, so marker state is never carried from one declaration to the next.
type DeclFinder struct {
	src []byte
	out *[]Streamable
	fs  *token.FileSet
}

// Visit implements the ast.Visitor interface.
func (x DeclFinder) Visit(node ast.Node) ast.Visitor {
	switch node.(type) {
	case *ast.GenDecl:
		return &StreamableFinder{src: x.src, fs: x.fs, out: x.out}
	default:
		// Includes the final nil visit and every non-GenDecl declaration.
		return nil
	}
}

// StreamableFinder seeks `type Thing struct` definitions with `determined:stream-gen` comments,
// builds an associated Streamable object, and adds it to the out slice.
//
// One StreamableFinder visits the children of a single *ast.GenDecl. A comment group containing
// keystr arms it (expectStreamable), and the very next node it visits must then be a TypeSpec for
// a struct. Malformed annotations are reported on stderr and terminate the process.
type StreamableFinder struct {
	src              []byte
	fs               *token.FileSet
	out              *[]Streamable
	expectStreamable bool
	position         token.Position
	streamableArgs   map[string]string
}

// Visit implements the ast.Visitor interface. It always returns nil, so ast.Walk never descends
// below the nodes that are direct children of the GenDecl.
func (x *StreamableFinder) Visit(node ast.Node) ast.Visitor {
	if node == nil {
		return nil
	}
	if !x.expectStreamable {
		// Until armed, only a comment group carrying the marker is of interest.
		if cmntgrp, ok := node.(*ast.CommentGroup); ok {
			x.armFromComment(cmntgrp)
		}
		return nil
	}

	// expectStreamable is only valid once.
	x.expectStreamable = false

	// This should be a TypeSpec with .Type that is a StructType.
	typ, ok := node.(*ast.TypeSpec)
	if !ok {
		fmt.Fprintf(os.Stderr, "found special 'determined:stream-gen' comment on non-struct\n")
		os.Exit(1)
	}
	strct, ok := typ.Type.(*ast.StructType)
	if !ok {
		fmt.Fprintf(os.Stderr, "found special 'determined:stream-gen' comment on non-struct\n")
		os.Exit(1)
	}

	// Build our Streamable from this struct, consuming the args parsed from its marker.
	result := Streamable{Name: typ.Name.String(), Position: x.position, Args: x.streamableArgs}
	x.streamableArgs = nil
	for _, field := range strct.Fields.List {
		result.Fields = append(result.Fields, x.streamFields(field)...)
	}

	// extend output
	*x.out = append(*x.out, result)
	return nil
}

// armFromComment arms x if any comment in cmntgrp contains keystr. It records the marker's
// position for later error messages and parses the rest of the marker's line into
// x.streamableArgs.
func (x *StreamableFinder) armFromComment(cmntgrp *ast.CommentGroup) {
	for _, cmnt := range cmntgrp.List {
		offset := strings.Index(cmnt.Text, keystr)
		if offset < 0 {
			continue
		}
		// Remember the location, in case we have to report where an error originates from.
		x.position = x.fs.Position(cmnt.Pos() + token.Pos(offset))
		// We found one! The next node should be our StructType.
		x.expectStreamable = true
		x.streamableArgs = parseStreamableArgs(cmnt.Text[offset+len(keystr):])
		return
	}
}

// parseStreamableArgs parses the key=value pairs on the first line of textAfter, the comment text
// that follows keystr. Pairs may be separated by any run of whitespace, and a trailing "*/" left
// over from a /* */ comment is ignored. A pair that is not exactly key=value prints a diagnostic
// and exits with status 1.
func parseStreamableArgs(textAfter string) map[string]string {
	lineAfter := strings.SplitN(textAfter, "\n", 2)[0]
	lineAfter = strings.TrimSuffix(strings.TrimSpace(lineAfter), "*/")
	args := make(map[string]string)
	for _, pair := range strings.Fields(lineAfter) {
		kv := strings.Split(pair, "=")
		if len(kv) != 2 {
			fmt.Fprintf(os.Stderr, "found invalid key=value pair %q\n", pair)
			os.Exit(1)
		}
		args[kv[0]] = kv[1]
	}
	return args
}

// streamFields returns the Fields contributed by one struct field: one per json entry in its tag.
// Embedded and untagged fields contribute nothing. So does a field tagged `json:"-"`, which
// encoding/json never serializes. An unparseable tag prints a diagnostic and exits with status 7.
func (x *StreamableFinder) streamFields(field *ast.Field) []Field {
	if len(field.Names) == 0 || field.Tag == nil {
		return nil
	}
	// The field tag comes as a literal; so unquote it to get the string.
	tags, err := strconv.Unquote(field.Tag.Value)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to parse tag: %v\n", field.Tag.Value)
		os.Exit(7)
	}
	var out []Field
	// strings.Fields splits the tag into its space-separated KEY:"VALUE" entries.
	for _, tag := range strings.Fields(tags) {
		splits := strings.SplitN(tag, ":", 2)
		if len(splits) != 2 {
			fmt.Fprintf(os.Stderr, "failed to parse tag: %v\n", field.Tag.Value)
			os.Exit(7)
		}
		// Each VALUE is itself a quoted string literal.
		k := splits[0]
		v, err := strconv.Unquote(splits[1])
		if err != nil {
			fmt.Fprintf(os.Stderr, "failed to parse tag: %v\n", field.Tag.Value)
			os.Exit(7)
		}
		if k != "json" || v == "-" {
			continue
		}
		// The wire name is the first comma-separated element of values like "since,omitempty";
		// an empty name means encoding/json falls back to the Go field name.
		name := strings.SplitN(v, ",", 2)[0]
		if name == "" {
			name = field.Names[0].Name
		}
		// Use the type's exact source text, because the ast form is a PITA to work with.
		// Offsets come from the FileSet, so this holds whatever base the file has in the set.
		start := x.fs.Position(field.Type.Pos()).Offset
		end := x.fs.Position(field.Type.End()).Offset
		out = append(out, Field{
			Name:    field.Names[0].String(),
			Type:    streamType(x.src[start:end]),
			JSONTag: name,
		})
	}
	return out
}

// parseFiles parses each Go source file in files and returns every Streamable found, in file order
// and, within a file, in the order the visitors encounter them.
//
// A read or parse failure is returned as an error that names the offending file. Malformed
// determined:stream-gen annotations are not returned as errors: the visitors report them on
// stderr and exit the process.
func parseFiles(files []string) ([]Streamable, error) {
	var results []Streamable

	for _, f := range files {
		src, err := os.ReadFile(f) //nolint: gosec // of course the code generator reads a file
		if err != nil {
			return nil, fmt.Errorf("reading file (%v): %w", f, err)
		}
		// Each file gets its own FileSet, so positions in it map directly onto this file's src.
		// (Named fset to avoid shadowing the io/fs package.)
		fset := token.NewFileSet()
		opts := parser.ParseComments | parser.SkipObjectResolution
		file, err := parser.ParseFile(fset, f, src, opts)
		if err != nil {
			return nil, fmt.Errorf("in file (%v): %w", f, err)
		}

		ast.Walk(RootVisitor{src, &results, fset}, file)
	}

	return results, nil
}

// Builder wraps strings.Builder but doesn't return a nil error like strings.Builder.
type Builder struct {
	builder strings.Builder
}

// Writef writes a formatted string, discard errors (which strings.Builder does not return).
func (b *Builder) Writef(fstr string, args ...interface{}) {
	if len(args) == 0 {
		_, _ = b.builder.WriteString(fstr)
		return
	}
	_, _ = b.builder.WriteString(fmt.Sprintf(fstr, args...))
}

// String does not deserve a comment, but the linter wants it.
func (b *Builder) String() string {
	return b.builder.String()
}

// genTypescript renders TypeScript bindings for the streamables whose source is "client".
//
// Each client streamable becomes an exported `<Entity>Spec` class extending StreamSpec, where
// <Entity> is the struct name with its "SubscriptionSpec" suffix removed, lowercased, then
// title-cased. Streamables with source "server" are skipped. An error is returned for any other
// source value, or for a field whose Go type has no TypeScript mapping.
func genTypescript(streamables []Streamable) ([]byte, error) {
	b := Builder{}
	// tsTypes maps a Go field type to {TypeScript annotation, default expression}. The default is
	// emitted as `this.#f = f || DEFAULT;`, so it must be a complete, non-empty TS expression.
	tsTypes := map[streamType][2]string{
		json:           {"any", "{}"},
		text:           {"string", "''"},
		textArr:        {"Array<string>", "[]"},
		boolean:        {"boolean", "false"},
		integer:        {"number", "0"},
		integer64:      {"number", "0"},
		intArr:         {"Array<number>", "[]"},
		time:           {"string", "''"},
		timePtr:        {"string | undefined", "undefined"},
		taskID:         {"string", "''"},
		requestID:      {"number", "0"},
		requestIDPtr:   {"number | undefined", "undefined"},
		workspaceState: {"types.WorkspaceState", "types.WorkspaceState.Unspecified"},
	}
	// Loop-invariant; Caser.String resets the caser's state on every call.
	caser := cases.Title(language.English)

	b.Writef("// Code generated by stream-gen. DO NOT EDIT.\n")
	b.Writef("\n")
	b.Writef("import { isEqual } from 'lodash';\n")
	b.Writef("\n")
	b.Writef("import { Streamable, StreamSpec } from '.';\n")
	b.Writef("\n")
	typesImported := false
	for _, s := range streamables {
		source := s.Args["source"]
		if source == server {
			continue
		}
		if source != client {
			// Match genPython, which also rejects unknown sources rather than dropping them.
			return nil, fmt.Errorf("invalid 'source' value %q @ %v", source, s.Position)
		}

		// Resolve each field's annotation once; an unsupported type aborts generation.
		annos := make([][2]string, len(s.Fields))
		for i, f := range s.Fields {
			anno, ok := tsTypes[f.Type]
			if !ok {
				return nil, fmt.Errorf(
					"struct %v, field %v: no type annotation matches %q", s.Name, f.Name, f.Type,
				)
			}
			annos[i] = anno
			if strings.Contains(anno[0], "types") && !typesImported {
				typesImported = true
				b.Writef("import * as types from 'types';\n\n")
			}
		}

		entity := strings.ToLower(strings.TrimSuffix(s.Name, "SubscriptionSpec"))
		className := caser.String(entity) + "Spec"

		b.Writef("export class %v extends StreamSpec {\n", className)
		b.Writef("  readonly #id: Streamable = '%vs';\n", entity)
		for i, f := range s.Fields {
			b.Writef("  #%v: %v;\n", f.JSONTag, annos[i][0])
		}
		b.Writef("\n")
		b.Writef("  constructor(\n")
		for i, f := range s.Fields {
			b.Writef("    %v?: %v,\n", f.JSONTag, annos[i][0])
		}
		b.Writef("  ) {\n")
		b.Writef("    super();\n")
		for i, f := range s.Fields {
			b.Writef("    this.#%v = %v || %v;\n", f.JSONTag, f.JSONTag, annos[i][1])
		}
		b.Writef("  }\n")
		b.Writef("\n")
		b.Writef("  public equals = (sp?: StreamSpec): boolean => {\n")
		b.Writef("    if (!sp) return false;\n")
		b.Writef("    if (sp instanceof %v) {\n", className)
		b.Writef("      return (\n")
		if len(s.Fields) == 0 {
			// `return ()` is not valid TypeScript; field-less specs of one class are all equal.
			b.Writef("        true\n")
		}
		for i, f := range s.Fields {
			if i > 0 {
				b.Writef("        &&\n")
			}
			b.Writef("        isEqual(sp.#%v, this.#%v)\n", f.JSONTag, f.JSONTag)
		}
		b.Writef("      );\n")
		b.Writef("    }\n")
		b.Writef("    return false;\n")
		b.Writef("  };\n")
		b.Writef("\n")
		b.Writef("  public id = (): Streamable => {\n")
		b.Writef("    return this.#id;\n")
		b.Writef("  };\n")
		b.Writef("\n")
		b.Writef("  public toWire = (): Record<string, unknown> => {\n")
		b.Writef("    return {\n")
		for _, f := range s.Fields {
			b.Writef("      %v: this.#%v,\n", f.JSONTag, f.JSONTag)
		}
		b.Writef("    };\n")
		b.Writef("  };\n")
		b.Writef("}\n\n")
	}
	return []byte(b.String()), nil
}

// genPython renders Python bindings for all streamables.
//
// Streamables with source "server" become ServerMsg subclasses whose constructor takes every
// field; a "delete_msg" arg additionally emits a DeleteMsg subclass with that name. Streamables
// with source "client" become ClientMsg subclasses whose fields all default to None. Any other
// source value, or a field type with no Python mapping, is returned as an error.
func genPython(streamables []Streamable) ([]byte, error) {
	b := Builder{}
	// pyTypes maps a Go field type to its Python annotation; built once per call, not per field.
	pyTypes := map[streamType]string{
		json:           "typing.Any",
		text:           "str",
		textArr:        "typing.List[str]",
		boolean:        "bool",
		integer:        "int",
		integer64:      "int",
		intArr:         "typing.List[int]",
		time:           "float",
		timePtr:        "typing.Optional[float]",
		taskID:         "str",
		requestID:      "int",
		requestIDPtr:   "typing.Optional[int]",
		workspaceState: "str",
	}
	typeAnno := func(f Field) (string, error) {
		out, ok := pyTypes[f.Type]
		if !ok {
			return "", fmt.Errorf("no type annotation matches %q", f.Type)
		}
		return out, nil
	}
	// optional wraps anno in typing.Optional unless it already is one.
	optional := func(anno string) string {
		if strings.HasPrefix(anno, "typing.Optional") {
			return anno
		}
		return fmt.Sprintf("typing.Optional[%v]", anno)
	}
	b.Writef("# Code generated by stream-gen. DO NOT EDIT.\n")
	b.Writef("\n")
	b.Writef("\"\"\"Wire formats for the determined streaming updates subsystem\"\"\"\n")
	b.Writef("\n")
	b.Writef("import typing\n")
	b.Writef("\n")
	b.Writef("\n")
	b.Writef("class ServerMsg:\n")
	b.Writef("    @classmethod\n")
	b.Writef("    def from_json(cls, obj: typing.Any) -> \"ServerMsg\":\n")
	b.Writef("        return cls(**obj)  # type: ignore\n")
	b.Writef("\n")
	b.Writef("    def to_json(self) -> typing.Dict[str, typing.Any]:\n")
	b.Writef("        return dict(vars(self))\n")
	b.Writef("\n")
	b.Writef("    def __repr__(self) -> str:\n")
	b.Writef("        body = \", \".join(f\"{k}={v}\" for k, v in vars(self).items())\n")
	b.Writef("        return f\"{type(self).__name__}({body})\"\n")
	b.Writef("\n")
	b.Writef("    def __eq__(self, other: object) -> bool:\n")
	b.Writef("        return isinstance(other, type(self)) and vars(self) == vars(other)\n")
	b.Writef("\n")
	b.Writef("\n")
	b.Writef("class DeleteMsg:\n")
	b.Writef("    def __init__(self, keys: str) -> None:\n")
	b.Writef("        self.keys = keys\n")
	b.Writef("\n")
	b.Writef("    @classmethod\n")
	b.Writef("    def from_json(cls, keys: str) -> \"DeleteMsg\":\n")
	b.Writef("        return cls(keys)\n")
	b.Writef("\n")
	b.Writef("    def to_json(self) -> str:\n")
	b.Writef("        return self.keys\n")
	b.Writef("\n")
	b.Writef("    def __repr__(self) -> str:\n")
	b.Writef("        return f\"{type(self).__name__}({self.keys})\"\n")
	b.Writef("\n")
	b.Writef("    def __eq__(self, other: object) -> bool:\n")
	b.Writef("        return isinstance(other, type(self)) and self.keys == other.keys\n")
	b.Writef("\n")
	b.Writef("\n")
	b.Writef("class ClientMsg:\n")
	b.Writef("    def to_json(self) -> typing.Dict[str, typing.Any]:\n")
	b.Writef("        return {k: v for k, v in vars(self).items() if v is not None}\n")
	b.Writef("\n")
	b.Writef("    def __repr__(self) -> str:\n")
	b.Writef("        body = \", \".join(f\"{k}={v}\" for k, v in vars(self).items())\n")
	b.Writef("        return f\"{type(self).__name__}({body})\"\n")
	b.Writef("\n")
	b.Writef("    def __eq__(self, other: object) -> bool:\n")
	b.Writef("        return isinstance(other, type(self)) and self.to_json() == other.to_json()\n")

	for _, s := range streamables {
		source := s.Args["source"]

		switch source {
		case server:
			// Generate a subclass of a ServerMsg, all fields are always filled.
			b.Writef("\n\n")
			b.Writef("class %v(ServerMsg):\n", s.Name)
			b.Writef("    def __init__(\n")
			b.Writef("        self,\n")
			for _, f := range s.Fields {
				anno, err := typeAnno(f)
				if err != nil {
					return nil, fmt.Errorf("struct %v, field %v: %v", s.Name, f.Name, err)
				}
				b.Writef("        %v: %q,\n", f.JSONTag, anno)
			}
			b.Writef("    ) -> None:\n")
			for _, f := range s.Fields {
				b.Writef("        self.%v = %v\n", f.JSONTag, f.JSONTag)
			}
			if deleter := s.Args["delete_msg"]; deleter != "" {
				// Also generate a Delete message
				b.Writef("\n\n")
				b.Writef("class %v(DeleteMsg):\n", deleter)
				b.Writef("    pass\n")
			}
		case client:
			// Generate a subclass of a ClientMsg, all fields are always optional.
			b.Writef("\n\n")
			b.Writef("class %v(ClientMsg):\n", s.Name)
			b.Writef("    def __init__(\n")
			b.Writef("        self,\n")
			for _, f := range s.Fields {
				anno, err := typeAnno(f)
				if err != nil {
					return nil, fmt.Errorf("struct %v, field %v: %v", s.Name, f.Name, err)
				}
				b.Writef("        %v: %q = None,\n", f.JSONTag, optional(anno))
			}
			b.Writef("    ) -> None:\n")
			for _, f := range s.Fields {
				b.Writef("        self.%v = %v\n", f.JSONTag, f.JSONTag)
			}
		default:
			// Returned rather than os.Exit'ed, consistent with this function's other failures;
			// main prints it to stderr and exits with status 1.
			return nil, fmt.Errorf("invalid 'source' value %q @ %v", source, s.Position)
		}
	}
	return []byte(b.String()), nil
}

// printHelp writes the command-line usage text to output. Write errors are ignored.
func printHelp(output io.Writer) {
	const usage = `stream-gen generates bindings for determined streaming updates.

usage: stream-gen IN.GO... --python/ts [--output OUTPUT]

All structs in the input files IN.GO... which contain special 'determined:stream-gen' comments will
be included in the generated output.

Presently the only output languages are python and typescript.

Output will be written to stdout, or a location specified by --output.  The OUTPUT will only be
overwritten if it would be modified.
`
	_, _ = io.WriteString(output, usage)
}

// verifyArgs checks every streamable's determined:stream-gen arguments. Each key must be a
// recognized arg, and every required arg must be present. On the first violation it prints a
// diagnostic with the marker's position to stderr and exits with status 1.
func verifyArgs(streamables []Streamable) {
	allowedArgs := map[string]bool{
		"delete_msg": true,
		"source":     true,
	}
	requiredArgs := []string{"source"}

	for _, s := range streamables {
		// Check keys in sorted order. Map iteration order is randomized, so without sorting, a
		// marker with several bad args could report a different one on each run.
		keys := make([]string, 0, len(s.Args))
		for k := range s.Args {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		for _, k := range keys {
			if !allowedArgs[k] {
				fmt.Fprintf(os.Stderr, "unrecognized arg %q (%v=%v) @ %v\n", k, k, s.Args[k], s.Position)
				os.Exit(1)
			}
		}
		for _, k := range requiredArgs {
			if _, ok := s.Args[k]; !ok {
				fmt.Fprintf(os.Stderr, "missing required arg %q @ %v\n", k, s.Position)
				os.Exit(1)
			}
		}
	}
}

// main parses the command line, extracts streamables from the input Go files, generates bindings
// for the selected language, and writes them to stdout or to --output. Usage errors exit with
// status 2, and generation or I/O failures exit with status 1.
func main() {
	// Command-line options are parsed by hand rather than with the flag package.
	args := os.Args[1:]
	if len(args) == 0 {
		// no args provided
		printHelp(os.Stdout)
		os.Exit(0)
	}

	var (
		output  string
		lang    string
		gofiles []string
	)
	for i := 0; i < len(args); i++ {
		switch arg := args[i]; {
		case arg == "-h", arg == "--help":
			printHelp(os.Stdout)
			os.Exit(0)
		case arg == "--python":
			lang = python
		case arg == "--ts":
			lang = typescript
		case arg == "-o", arg == "--output":
			if i+1 >= len(args) {
				fmt.Fprintf(os.Stderr, "Missing --output parameter.\nTry --help.\n")
				os.Exit(2)
			}
			i++
			output = args[i]
		case strings.HasPrefix(arg, "-"):
			fmt.Fprintf(os.Stderr, "Unrecognized option: %v.\nTry --help.\n", arg)
			os.Exit(2)
		default:
			gofiles = append(gofiles, arg)
		}
	}
	switch {
	case len(gofiles) == 0:
		fmt.Fprintf(os.Stderr, "No input files.\nTry --help.\n")
		os.Exit(2)
	case lang == "":
		fmt.Fprintf(os.Stderr, "No language specifier.\nTry --help.\n")
		os.Exit(2)
	}

	results, err := parseFiles(gofiles)
	if err != nil {
		fmt.Fprintf(os.Stderr, "%v\n", err)
		os.Exit(1)
	}

	// verifyArgs exits with code 1 on any invalid annotation.
	verifyArgs(results)

	// Emit streamables ordered by name rather than by input-file order.
	sort.Slice(results, func(i, j int) bool {
		return results[i].Name < results[j].Name
	})

	// lang is known to be python or typescript at this point.
	generate := genPython
	if lang == typescript {
		generate = genTypescript
	}
	content, err := generate(results)
	if err != nil {
		fmt.Fprintf(os.Stderr, "%v\n", err)
		os.Exit(1)
	}

	if output == "" {
		if _, err := os.Stdout.Write(content); err != nil {
			fmt.Fprintf(os.Stderr, "failed writing to stdout: %v\n", err.Error())
			os.Exit(1)
		}
		return
	}

	// Only rewrite OUTPUT when its content would change; a missing file counts as changed.
	old, err := os.ReadFile(output)
	if err != nil && !errors.Is(err, fs.ErrNotExist) {
		fmt.Fprintf(os.Stderr, "failed reading old content of %v: %v\n", output, err.Error())
		os.Exit(1)
	}
	if bytes.Equal(old, content) {
		fmt.Fprintf(os.Stderr, "output is up-to-date\n")
		os.Exit(0)
	}
	if err := os.WriteFile(output, content, 0o666); err != nil { //nolint: gosec // let umask do its thing
		fmt.Fprintf(os.Stderr, "failed writing to %v: %v\n", output, err.Error())
		os.Exit(1)
	}
}
