from vertexai.preview.generative_models import FunctionDeclaration
from vertexai.preview.generative_models import GenerationResponse
from src.patterns.web_access.serp import run as google_search
from src.patterns.web_access.tasks import SearchTask
from vertexai.preview.generative_models import Tool
from src.llm.generate import ResponseGenerator
from src.prompt.manage import TemplateManager
from src.config.logging import logger
from typing import Optional
from typing import Dict
from typing import Any 


class WebSearchAgent(SearchTask):
    """
    WebSearchAgent orchestrates search operations using a language model: it asks the model (via
    function calling) for structured search arguments, then performs the search through the
    SERP-backed ``google_search`` helper.

    Attributes:
        TEMPLATE_PATH (str): Path to the template configuration file used for search prompts.
            This is a relative path; presumably resolved against the current working directory
            by TemplateManager — TODO confirm.
        response_generator (ResponseGenerator): Instance used to generate responses from the language model.
        template_manager (TemplateManager): Loads and fills the templates for search instructions.
    """
    TEMPLATE_PATH = './config/patterns/web_access.yml'

    def __init__(self) -> None:
        """
        Initializes WebSearchAgent with a response generator and a template manager bound to
        TEMPLATE_PATH.
        """
        self.response_generator = ResponseGenerator()
        self.template_manager = TemplateManager(self.TEMPLATE_PATH)

    def create_search_function_declaration(self) -> FunctionDeclaration:
        """
        Creates the function declaration for the ``web_search`` tool exposed to the model.

        The declaration requires a ``query`` string and accepts an optional ``location`` string.

        Returns:
            FunctionDeclaration: The function declaration describing the web search tool.
        """
        return FunctionDeclaration(
            name="web_search",
            description="Perform Google Search using SERP API",
            parameters={
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "Search query"},
                    "location": {"type": "string", "description": "Geographic location for localized results", "default": ""},
                },
                "required": ["query"]
            },
        )

    def function_call(self, model_name: str, search_query: str, search_tool: Tool) -> GenerationResponse:
        """
        Generates a model response for the search query, offering the given tool to the model.

        The prompt is built from the ``tools``/``search`` template: its ``system`` entry is used
        as the system instruction and its ``user`` entry is filled with the query.

        Args:
            model_name (str): Name of the model to use for generating the response.
            search_query (str): The search query string.
            search_tool (Tool): Tool containing the function declaration for search.

        Returns:
            GenerationResponse: The response generated by the language model.

        Raises:
            Exception: Any error raised while building the prompt or generating the response is
                logged and re-raised.
        """
        try:
            template = self.template_manager.create_template('tools', 'search')
            system_instruction = template['system']
            user_instruction = self.template_manager.fill_template(template['user'], query=search_query)

            logger.info(f"Generating response for search query: {search_query}")
            return self.response_generator.generate_response(
                model_name,
                system_instruction,
                [user_instruction],
                tools=[search_tool]
            )
        except Exception as e:
            logger.error(f"Error generating search data: {e}")
            raise

    def extract_function_args(self, response: GenerationResponse) -> Optional[Dict[str, Any]]:
        """
        Extracts the function-call arguments from the language model response.

        Only the first part of the first candidate is inspected.

        Args:
            response (GenerationResponse): The response from the language model.

        Returns:
            Optional[Dict[str, Any]]: The function-call arguments as a plain dict, or None when
            the first part carries no function call or the response has no candidates/parts.
            Extraction failures (IndexError/KeyError) are logged and reported as None; they are
            not raised.
        """
        logger.info("Extracting function arguments from the response.")
        try:
            # Named `tool_call` to avoid confusion with the `function_call` method on this class.
            tool_call = response.candidates[0].content.parts[0].function_call
            # Copy the args mapping into a plain dict so callers can use dict methods like .get().
            return dict(tool_call.args) if tool_call else None
        except (IndexError, KeyError) as e:
            logger.error(f"Failed to extract function arguments: {e}")
            return None

    def run(self, model_name: str, query: str, location: str = '') -> None:
        """
        Runs the web search process: generates instructions, extracts arguments, and initiates the search.

        When the model supplies a non-empty ``query`` argument, it is used as the search terms;
        otherwise the caller's ``query`` is used. A non-empty ``location`` passed by the caller
        always takes precedence over one proposed by the model.

        Args:
            model_name (str): Name of the language model for generating the search response.
            query (str): Search query string.
            location (str, optional): Geographic location for search (default is '').

        Raises:
            Exception: Any error raised while generating the response or running the search is
                logged and re-raised.
        """
        try:
            # Create search function tool and generate response
            search_tool = Tool(function_declarations=[self.create_search_function_declaration()])
            response = self.function_call(model_name, query, search_tool)
            function_args = self.extract_function_args(response)

            if function_args:
                # `or` fallbacks guard against the model returning an empty or None value, which
                # would otherwise produce an empty search or a None location.
                search_terms = function_args.get('query') or query
                search_location = location or function_args.get('location') or ''
            else:
                search_terms = query
                search_location = location

            logger.info(f"Running web search for query: '{search_terms}', location: '{search_location}'")
            # NOTE(review): the caller's original query is passed alongside the model-refined terms;
            # confirm this argument order against the signature of serp.run.
            google_search(query, search_terms, search_location)

        except Exception as e:
            logger.error(f"Error during search execution: {e}")
            raise
