"""
セット球別の集計画面を定義したクラス

    「ロト6 当せん数字予測アプリ」
    Copyright (c) 2025 toshifumi tsutsui
    Released under the MIT license
    https://wpandora8.net/the_mit_license.html
"""

from typing import override

import polars as pl
import scipy.stats
from kivy.lang import Builder
from kivy.uix.label import Label
from kivy.uix.screenmanager import Screen
from kivy.uix.spinner import Spinner

from datagrid import DataGrid
from datagrid_co_occurrence_row import (
    DataGridCoOccurrenceHeader,
    DataGridCoOccurrenceRow,
)
from datagrid_frequency_count_row import (
    DataGridFrequencyCountHeader,
    DataGridFrequencyCountRow,
)
from shared_vars import SharedVars

Builder.load_string("""
<ByBallSetScreen>:
    datagrid_frequency_count: datagrid_frequency_count
    datagrid_co_occurrence: datagrid_co_occurrence
    spinner_ball_set: spinner_ball_set
    spinner_include_bonus: spinner_include_bonus
    label_stats: label_stats
    label_number: label_number

    AnchorLayout:
        size_hint: 0.9, 0.9
        pos_hint: {'center_x': 0.5, 'center_y': 0.5}
        BoxLayout:
            orientation: 'vertical'
            BoxLayout:
                size_hint_y: 0.1
                spacing: dp(8)
                orientation: 'horizontal'
                Label:
                    text: 'セット球:'
                    halign: 'right'
                    valign: 'middle'
                    width: dp(75)
                    size_hint_x: None
                    text_size: self.size
                AnchorLayout:
                    anchor_x: 'center'
                    anchor_y: 'center'
                    width: dp(54)
                    size_hint_x: None
                    Spinner:
                        id: spinner_ball_set
                        text: "A"
                        values: ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
                        size: dp(54), dp(36)
                        size_hint: None, None
                        on_text: root.on_spinner_ball_set_text_change()
                Label:
                    id: label_stats
                    text: '各数字の出現回数の 平均: 0.00,  標準偏差: 1.00'
                    halign: 'left'
                    valign: 'middle'
                    text_size: self.size
                    width: self.texture_size[0]
            AnchorLayout:
                size_hint_y: 0.1
                anchor_x: 'center'
                anchor_y: 'center'
                Spinner:
                    id: spinner_include_bonus
                    text: "集計にボーナス数字を含めない"
                    values: ["集計にボーナス数字を含めない", "集計にボーナス数字を含める"]
                    height: dp(36)
                    size_hint: 0.4, None
                    on_text: root.on_spinner_include_bonus_text_change()
            BoxLayout:
                orientation: 'horizontal'
                size_hint_y: 0.8
                spacing: dp(50)
                BoxLayout:
                    orientation: 'vertical'
                    size_hint_x: 0.55
                    Label:
                        text: '各数字の出現状況'
                        halign: 'center'
                        valign: 'middle'
                        size_hint_y: 0.1
                    DataGrid:
                        size_hint_y: 0.9
                        id: datagrid_frequency_count
                        viewclass: 'DataGridFrequencyCountRow'

                BoxLayout:
                    orientation: 'vertical'
                    size_hint_x: 0.45
                    Label:
                        id: label_number
                        text: "'01' と同時に出現している数字"
                        halign: 'center'
                        valign: 'middle'
                        size_hint: 1.0, 0.1
                    DataGrid:
                        size_hint_y: 0.9
                        id: datagrid_co_occurrence
                        viewclass: 'DataGridCoOccurrenceRow'
""")


class ByBallSetScreen(Screen):
    """セット球別の集計画面を定義したクラス"""

    def __init__(self, shared_vars: SharedVars, **kwargs):
        """セット球別の集計画面のインスタンスを作成して返す。

        Args:
            shared_vars (SharedVars): アプリ内で共有する変数を管理するクラスのインスタンス。
        """

        super().__init__(**kwargs)

        self._shared_vars = shared_vars
        """アプリ内で共有する変数を管理するクラスのインスタンス"""

        self._include_bonus: bool = False
        """集計にボーナス数字を含める場合は True"""

        self._setup_widget_references()

        self._datagrid_frequency_count.header_item = DataGridFrequencyCountHeader()
        self._datagrid_co_occurrence.header_item = DataGridCoOccurrenceHeader()

        self._datagrid_frequency_count.bind(selected_row=self._on_select_change)  # type: ignore

        self._df_by_number: pl.DataFrame = self._aggregate_by_ball_set()
        """数字別の出現回数を集計した DataFrame"""

    def _setup_widget_references(self):
        """各ウィジェットの参照を設定する。"""

        self._datagrid_frequency_count: DataGrid = self.ids.datagrid_frequency_count
        """数字別の出現回数のデータグリッド"""

        self._datagrid_co_occurrence: DataGrid = self.ids.datagrid_co_occurrence
        """特定の数字と同時に出現している数字の出現回数のデータグリッド"""

        self._spinner_ball_set: Spinner = self.ids.spinner_ball_set
        """セット球を選択するスピナー"""

        self._spinner_include_bonus: Spinner = self.ids.spinner_include_bonus
        """ボーナス数字を集計に含めるかどうかを選択するスピナー"""

        self._label_number: Label = self.ids.label_number
        """特定の数字を表示するラベル"""

        self._label_stats: Label = self.ids.label_stats
        """出現回数の平均と標準偏差を表示するラベル"""

    @override
    def on_pre_enter(self, *args):
        """画面が表示されるとき。"""

        self._calculation()

    def on_spinner_ball_set_text_change(self):
        """セット球のスピナーの選択肢が変更されたとき。"""

        self._calculation()

    def on_spinner_include_bonus_text_change(self):
        """集計にボーナス数字を含めるかどうかのスピナーの選択肢が変更されたとき。"""

        if self._spinner_include_bonus.text == "集計にボーナス数字を含めない":
            self._include_bonus = False
        else:
            self._include_bonus = True

        self._calculation()

    def _calculation(self) -> None:
        """データを集計して表示する。"""

        self._df_by_number = self._aggregate_by_ball_set()

        if not self._df_by_number.is_empty():
            col_name: str = "numbers_with_bonus" if self._include_bonus else "numbers"
            number: int = self._df_by_number.get_column(col_name)[0]

            self._aggregate_co_occurrence_numbers(
                self._shared_vars.df,
                self._df_by_number,
                self._spinner_ball_set.text,
                number,
            )

    def _on_select_change(
        self, instance: DataGrid, selected_row: DataGridFrequencyCountRow
    ) -> None:
        """特定の数字が選択されたとき。

        Args:
            selected_row (DataGridFrequencyCountRow): 選択された行のアイテム。
        """

        if isinstance(selected_row, DataGridFrequencyCountRow):
            number: int = selected_row.numbers
            self._label_number.text = f"'{number:0>2}' と同時に出現している数字"

            self._aggregate_co_occurrence_numbers(
                self._shared_vars.df,
                self._df_by_number,
                self._spinner_ball_set.text,
                number,
            )

    def _aggregate_by_ball_set(self) -> pl.DataFrame:
        """セット球別に各数字の出現回数を集計する。"""

        df: pl.DataFrame = self._shared_vars.df

        if df.is_empty():
            return pl.DataFrame()

        else:
            ball_set: str = self._spinner_ball_set.text
            df_by_number: pl.DataFrame = self._get_frequency_count_dataframe(
                df, ball_set
            )

            mean: float = df_by_number.get_column("count").mean()  # type: ignore
            std: float = df_by_number.get_column("count").std()  # type: ignore
            text: str = f"（各数字の出現回数の 平均: {mean:.2f}, 標準偏差: {std:.2f}）"
            self._label_stats.text = text

            df_by_number = self._add_cdf_values(df_by_number, mean, std)

            self._datagrid_frequency_count.data = df_by_number.to_dicts()
            self._datagrid_frequency_count.scroll_to_top()

            return df_by_number

    def _aggregate_co_occurrence_numbers(
        self, df: pl.DataFrame, df_by_number: pl.DataFrame, ball_set: str, number: int
    ) -> None:
        """特定の数字と同時に出現している数字の出現回数を集計する。

        Args:
            df (pl.DataFrame): 元の DataFrame
            df_by_number (pl.DataFrame): 数字別の出現回数を集計した DataFrame
            ball_set (str): セット球の名前
            number (int): 特定の数字
        """

        self._label_number.text = f"'{number:0>2}' と同時に出現している数字"

        df_co_occurrence: pl.DataFrame = self._get_co_occurrence(
            df, df_by_number, ball_set, number
        )

        self._datagrid_co_occurrence.data = df_co_occurrence.to_dicts()
        self._datagrid_co_occurrence.scroll_to_top()

    ###############################################################################
    # Controllers

    def _get_frequency_count_dataframe(
        self, df: pl.DataFrame, ball_set: str
    ) -> pl.DataFrame:
        """数字別の出現回数を集計した DataFrame を返す。

        Args:
            df (pl.DataFrame): 元の DataFrame
            ball_set (str): セット球の名前

        Returns:
            pl.DataFrame: 数字別の出現回数を集計した DataFrame。
        """

        col_name: str = "numbers_with_bonus" if self._include_bonus else "numbers"

        df_new: pl.DataFrame = (
            df.select(["ball_set", col_name])
            .filter(pl.col("ball_set") == ball_set)
            .explode(col_name)
            .group_by(col_name)
            .agg(pl.len().alias("count"))
            .with_columns(
                (
                    pl.col("count") / df.filter(pl.col("ball_set") == ball_set).height
                ).alias("rate")
            )
            .select(pl.all().sort_by(["count", col_name], descending=True))
        )

        if self._include_bonus:
            df_new = df_new.with_columns(pl.col(col_name).alias("numbers"))

        return df_new

    def _get_co_occurrence(
        self, df: pl.DataFrame, df_by_number: pl.DataFrame, ball_set: str, number: int
    ) -> pl.DataFrame:
        """特定の数字と同時に出現した数字の出現回数を集計した DataFrame を返す。

        Args:
            df (pl.DataFrame): 元の DataFrame
            df_by_number (pl.DataFrame): 数字別の出現回数を集計した DataFrame
            ball_set (str): セット球の名前
            number (int): 特定の数字

        Returns:
            pl.DataFrame: 特定の数字と同時に出現した数字の出現回数を集計した DataFrame
        """

        col_name: str = "numbers_with_bonus" if self._include_bonus else "numbers"

        df_new: pl.DataFrame = (
            (
                df.select(["ball_set", col_name])
                .filter(
                    (pl.col("ball_set") == ball_set)
                    & (pl.col(col_name).list.contains(number))
                )
                .explode(col_name)
                .group_by(col_name)
                .agg(pl.len().alias("co_occurrence"))
            )
            .join(df_by_number, on=col_name, how="inner")
            .select(pl.all().sort_by(["co_occurrence", "count"], descending=True))
        ).filter(pl.col(col_name) != number)

        if self._include_bonus:
            df_new = df_new.with_columns(pl.col(col_name).alias("numbers"))

        return df_new

    def _add_cdf_values(
        self, df: pl.DataFrame, mean: float, std: float
    ) -> pl.DataFrame:
        """DataFrame に CDF（累積分布関数）の値を追加する。

        Args:
            df (pl.DataFrame): CDF の値を追加する DataFrame
            mean (float): 平均値
            std (float): 標準偏差

        Returns:
            pl.DataFrame: CDF の値を追加した DataFrame
        """

        return df.with_columns(
            (
                pl.col("count").map_elements(
                    lambda x: scipy.stats.norm.cdf(x, loc=mean, scale=std),
                    return_dtype=pl.Float64,
                )
            ).alias("cdf_value")
        )
