from decimal import Decimal, ROUND_HALF_UP

from secondary.models import (
    ALevelCompetencyScale,
    ALevelComponentWeight,
    ALevelModuleAssessment,
    ALevelSubjectModule,
    SecondaryComputationPolicy,
)


def _to_decimal(value):
    """Convert *value* to a ``Decimal`` by way of its string form.

    Routing through ``str`` means a float such as ``0.1`` becomes
    ``Decimal("0.1")`` instead of its exact binary expansion, and strings
    keep their written precision (``"12.50"`` stays ``12.50``).
    """
    as_text = str(value)
    return Decimal(as_text)


def get_a_level_policy(section=None, on_date=None):
    """Return the Upper Secondary (A-Level) computation policy to use.

    First asks ``SecondaryComputationPolicy.get_active_policy`` for a policy
    matching *section* and *on_date*. If that yields nothing truthy, falls
    back to the active Upper Secondary policy with the latest
    ``effective_from`` (ties broken by highest ``id``). Returns ``None``
    when no active Upper Secondary policy exists at all.

    NOTE(review): the fallback ignores *section* and *on_date* entirely;
    confirm that using a policy from another section is intended.
    """
    upper_secondary = SecondaryComputationPolicy.Level.UPPER_SECONDARY

    scoped_policy = SecondaryComputationPolicy.get_active_policy(
        section=section,
        level=upper_secondary,
        on_date=on_date,
    )
    if scoped_policy:
        return scoped_policy

    # No scoped match: take the newest active Upper Secondary policy.
    active_upper = SecondaryComputationPolicy.objects.filter(
        is_active=True,
        level=upper_secondary,
    )
    newest_first = active_upper.order_by("-effective_from", "-id")
    return newest_first.first()


def _latest_attempt(records):
    """Return the most recent attempt among *records*, or ``None``.

    Attempts are ranked by ``(attempt_no, assessed_on, id)``. The record
    with the highest key wins; on an exact tie, the one appearing last in
    *records* wins. An empty or ``None`` *records* yields ``None``.
    """
    if not records:
        return None

    newest = None
    newest_key = None
    for candidate in records:
        candidate_key = (candidate.attempt_no, candidate.assessed_on, candidate.id)
        # ">=" lets a later equal-key record replace an earlier one.
        if newest is None or candidate_key >= newest_key:
            newest, newest_key = candidate, candidate_key
    return newest


def _get_component_weights(policy, subject):
    """Load the component weights for *subject* under *policy*.

    Rows are returned as a list ordered by ``component_type``.

    Raises:
        ValueError: if no weights are configured, or if the weights do not
            add up to 100 once the total is quantized to two places.
    """
    queryset = ALevelComponentWeight.objects.filter(policy=policy, subject=subject)
    rows = list(queryset.order_by("component_type"))
    if not rows:
        raise ValueError("No A-Level component weights configured for this subject.")

    running_total = Decimal("0")
    for row in rows:
        running_total += _to_decimal(row.weight)

    if running_total.quantize(Decimal("0.01")) != Decimal("100.00"):
        raise ValueError("A-Level component weights for the subject must total 100%.")
    return rows


def _resolve_points(record, scale_map):
    """Return the point value earned by a single assessment *record*.

    An explicit ``points_awarded`` takes precedence. Otherwise the record's
    ``competency_code`` is looked up in *scale_map* (code -> scale row) and
    that scale's ``point_value`` is used. A missing record or an unknown
    code scores ``Decimal("0.00")``.
    """
    zero = Decimal("0.00")
    if record is None:
        return zero

    explicit_points = record.points_awarded
    if explicit_points is not None:
        return _to_decimal(explicit_points)

    matched_scale = scale_map.get(record.competency_code)
    if not matched_scale:
        return zero
    return _to_decimal(matched_scale.point_value)


def _final_scale(policy, weighted_point):
    """Find the competency band of *policy* whose range contains *weighted_point*.

    A band matches when ``min_weighted_point <= weighted_point <=
    max_weighted_point``. If several bands match, the one with the highest
    ``min_weighted_point`` is preferred, then the lowest ``display_order``.
    Returns ``None`` when no band matches.
    """
    inclusive_bounds = {
        "min_weighted_point__lte": weighted_point,
        "max_weighted_point__gte": weighted_point,
    }
    matching_bands = ALevelCompetencyScale.objects.filter(policy=policy, **inclusive_bounds)
    return matching_bands.order_by("-min_weighted_point", "display_order").first()


def _round_2dp(value):
    """Round a ``Decimal`` to two places using ROUND_HALF_UP (used for all reported points)."""
    return value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)


def _modular_exam_points(subject, component_type, records, scale_map):
    """Score the modular-exam component from the latest attempt per active module.

    Args:
        subject: the subject whose active ``ALevelSubjectModule`` rows are used.
        component_type: the modular-exam component type value.
        records: the student's assessment records for this subject/policy.
        scale_map: competency code -> scale row, passed to ``_resolve_points``.

    Returns:
        ``(points, chosen_records, breakdown_rows)``. ``points`` is the
        two-place (half-up) mean of the per-module points, or ``0.00`` when no
        attempt was found. When the subject has no active modules, the single
        latest modular-exam record for the subject is scored instead and no
        breakdown rows are produced.
    """
    # Materialise once: ``exists()`` followed by iteration issues two queries.
    modules = list(
        ALevelSubjectModule.objects.filter(subject=subject, is_active=True).order_by("module_order", "id")
    )
    component_records = [record for record in records if record.component_type == component_type]

    module_points = []
    chosen = []
    breakdown = []
    if modules:
        for module in modules:
            latest = _latest_attempt(
                [record for record in component_records if record.module_id == module.id]
            )
            # NOTE(review): a module with no attempt is skipped rather than
            # scored as zero, so it does not lower the mean — confirm intent.
            if not latest:
                continue
            points = _resolve_points(latest, scale_map)
            module_points.append(points)
            chosen.append(latest)
            breakdown.append(
                {
                    "module_code": module.code,
                    "module_name": module.name,
                    "attempt_no": latest.attempt_no,
                    "competency_code": latest.competency_code,
                    "points": points,
                }
            )
    else:
        latest = _latest_attempt(component_records)
        if latest:
            module_points.append(_resolve_points(latest, scale_map))
            chosen.append(latest)

    if not module_points:
        return Decimal("0.00"), chosen, breakdown
    mean = sum(module_points) / Decimal(len(module_points))
    return _round_2dp(mean), chosen, breakdown


def _single_component_points(component_type, records, scale_map):
    """Score a non-modular component from its latest subject-level attempt.

    Only records of *component_type* with no module attached are considered.

    Returns:
        ``(points, chosen_records)``. ``points`` is rounded to two places
        (half-up) and is ``0.00`` when there is no attempt; ``chosen_records``
        holds the selected record, or is empty.
    """
    latest = _latest_attempt(
        [
            record
            for record in records
            if record.component_type == component_type and record.module_id is None
        ]
    )
    chosen = [latest] if latest else []
    return _round_2dp(_resolve_points(latest, scale_map)), chosen


def compute_a_level_subject_result(student, subject, policy=None):
    """Compute a student's weighted A-Level result for one subject.

    Each configured component is scored from the student's latest attempt(s).
    The modular exam uses the per-module mean; other components use a single
    subject-level record. Each component contributes ``points * weight / 100``.
    Their sum is mapped to a final competency band.

    Args:
        student: the student whose ``ALevelModuleAssessment`` rows are read.
        subject: the subject being computed; ``subject.section`` selects the
            default policy.
        policy: optional explicit policy; defaults to ``get_a_level_policy``.

    Returns:
        dict with ``policy_id``, ``weighted_point``, ``final_competency``
        (band code or ``"N/A"``), ``final_descriptor`` (or ``""``),
        ``component_points`` (per type: points, ``<type>_weight`` and
        ``<type>_contribution``), ``selected_records`` and
        ``modular_breakdown``.

    Raises:
        ValueError: when no policy is found, the policy is not Upper
            Secondary, the policy has no competency scale, or the subject's
            component weights are missing or do not total 100%.
    """
    policy = policy or get_a_level_policy(section=subject.section)
    if not policy:
        raise ValueError("No active A-Level computation policy found.")
    if policy.level != SecondaryComputationPolicy.Level.UPPER_SECONDARY:
        raise ValueError("Selected policy is not an A-Level (Upper Secondary) policy.")

    scale_map = {
        scale.code: scale
        for scale in ALevelCompetencyScale.objects.filter(policy=policy)
    }
    if not scale_map:
        raise ValueError("No A-Level competency scale configured for the selected policy.")

    weights = _get_component_weights(policy=policy, subject=subject)
    records = list(
        ALevelModuleAssessment.objects.filter(
            policy=policy,
            student=student,
            subject=subject,
        ).select_related("module")
    )

    component_points = {}
    selected_records = []
    modular_breakdown = []

    for component in weights:
        component_type = component.component_type
        weight = _to_decimal(component.weight)

        if component_type == ALevelComponentWeight.ComponentType.MODULAR_EXAM:
            points, chosen, breakdown = _modular_exam_points(subject, component_type, records, scale_map)
            modular_breakdown.extend(breakdown)
        else:
            points, chosen = _single_component_points(component_type, records, scale_map)
        selected_records.extend(chosen)

        component_points[component_type] = points
        component_points[f"{component_type}_weight"] = weight
        component_points[f"{component_type}_contribution"] = _round_2dp(points * weight / Decimal("100"))

    # Summing the already-rounded contributions keeps the reported total equal
    # to the sum of the per-component contribution figures.
    weighted_point = _round_2dp(
        sum(
            component_points[f"{component.component_type}_contribution"]
            for component in weights
        )
    )

    final_scale = _final_scale(policy=policy, weighted_point=weighted_point)
    return {
        "policy_id": policy.id,
        "weighted_point": weighted_point,
        "final_competency": final_scale.code if final_scale else "N/A",
        "final_descriptor": final_scale.descriptor if final_scale else "",
        "component_points": component_points,
        "selected_records": selected_records,
        "modular_breakdown": modular_breakdown,
    }
