"""
Tron USDT/TRX 余额查询工具
使用 TronGrid API 获取完整数据
"""
import os
import base58
import requests
from flask import Flask, render_template, request, jsonify

app = Flask(__name__)

# TronGrid API key. The TRONGRID_API_KEY environment variable takes
# precedence so the key does not have to live in source control.
# NOTE(review): the literal fallback below is a secret committed to the
# repository; rotate it and drop the fallback once every deployment sets
# TRONGRID_API_KEY.
TRONGRID_API_KEY = os.environ.get(
    "TRONGRID_API_KEY", "6f467fc7-7216-4b65-afe1-a5522887623a"
)
# Sent with every TronGrid request made by this module.
HEADERS = {"TRON-PRO-API-KEY": TRONGRID_API_KEY}

# USDT TRC-20 token contract address.
USDT_CONTRACT = "TR7NHqjeKQxGTCi8q8ZY4pL8otSzgjLj6t"


def to_hex_address(base58_address: str) -> str:
    """Convert a Base58Check Tron address to its hex form (``"41..."``).

    The decoded address is a 21-byte payload (``0x41`` prefix + 20 bytes)
    followed by a 4-byte checksum. The checksum is verified and stripped,
    and the payload is returned as lowercase hex.

    If the input is not valid Base58, fails checksum verification, or does
    not decode to a ``0x41``-prefixed 21-byte payload, it is returned
    unchanged so the caller can still hand it to the API as-is.
    """
    try:
        payload = base58.b58decode_check(base58_address)
    except ValueError:
        # Invalid Base58 characters or a checksum mismatch.
        return base58_address
    if len(payload) != 21 or payload[0] != 0x41:
        return base58_address
    return payload.hex()


def _usdt_transfer_record(tx: dict, address: str) -> dict:
    """Turn one TronGrid TRC-20 transfer object into a UI transaction record."""
    import datetime

    tx_id = tx.get("transaction_id") or ""
    to_addr = tx.get("to") or ""
    # block_timestamp is in milliseconds; rendered in the server's local time.
    ts = datetime.datetime.fromtimestamp(tx.get("block_timestamp", 0) / 1000)
    return {
        "tx_id": tx_id,
        "tx_id_short": tx_id[:16] + "...",
        "timestamp": ts.strftime("%Y-%m-%d %H:%M"),
        "type": "USDT",
        # The raw value is an integer string scaled by 1_000_000.
        "amount": int(tx.get("value", 0)) / 1_000_000,
        "token": "USDT",
        "from": tx.get("from") or "",
        "to": to_addr,
        # Base58 addresses are case-sensitive, so compare them exactly.
        "direction": "in" if to_addr == address else "out",
        "status": "SUCCESS",
    }


def get_usdt_balance_and_txs(address: str, max_pages: int = 10) -> dict:
    """Fetch USDT (TRC-20) transfers for *address* from TronGrid and derive a balance.

    Pages through ``/v1/accounts/{address}/transactions/trc20`` filtered to
    ``USDT_CONTRACT``, 200 transfers per page. It follows the
    ``meta.fingerprint`` cursor TronGrid returns and stops at any of:
    a short page, a missing cursor, a request/parse error, or *max_pages*
    pages.

    Args:
        address: Base58 Tron address.
        max_pages: Upper bound on pages fetched (default 10, i.e. at most
            2000 transfers).

    Returns:
        ``{"balance": float, "transactions": list[dict]}``. ``balance`` is
        incoming minus outgoing over the transfers actually fetched, clamped
        at 0. If the account's history is longer than what was fetched, or a
        page failed, the figure is incomplete.
    """
    page_size = 200
    hex_addr = to_hex_address(address)
    url = f"https://api.trongrid.io/v1/accounts/{hex_addr}/transactions/trc20"
    # The same filters must be sent on every page when using the cursor.
    params = {"token": USDT_CONTRACT, "limit": page_size}

    # Summed as integer base units to avoid float drift.
    incoming = 0
    outgoing = 0
    all_txs = []

    for page in range(max_pages):
        try:
            resp = requests.get(url, params=params, headers=HEADERS, timeout=30)
            resp.raise_for_status()
            body = resp.json()
            txs = body.get("data") or []
            for tx in txs:
                quant = int(tx.get("value", 0))
                record = _usdt_transfer_record(tx, address)
                if record["to"] == address:
                    incoming += quant
                if record["from"] == address:
                    outgoing += quant
                all_txs.append(record)
        except Exception as e:
            # Best effort: keep whatever was collected before the failure.
            print(f"Page {page} error: {e}")
            break

        # A short page means the history is exhausted; otherwise continue from
        # TronGrid's pagination cursor. (A timestamp lower bound does not page
        # through newest-first results: it returns the same first page again.)
        fingerprint = (body.get("meta") or {}).get("fingerprint")
        if len(txs) < page_size or not fingerprint:
            break
        params["fingerprint"] = fingerprint

    balance = max(incoming - outgoing, 0)
    return {
        "balance": balance / 1_000_000,
        "transactions": all_txs,
    }


def get_account_info(address: str) -> "dict | None":
    """Fetch raw TRX account data from TronGrid's ``/wallet/getaccount``.

    Args:
        address: Base58 Tron address (converted with ``to_hex_address``).

    Returns:
        The decoded JSON object when the request succeeds (HTTP 200) and the
        response is an object with a non-empty ``address`` field; otherwise
        ``None``. Failures are logged, not raised.
    """
    hex_addr = to_hex_address(address)
    url = "https://api.trongrid.io/wallet/getaccount"
    payload = {"address": hex_addr}

    try:
        resp = requests.post(url, json=payload, headers=HEADERS, timeout=10)
    except requests.RequestException as e:
        print(f"Wallet API Error: {e}")
        return None

    if resp.status_code != 200:
        print(f"Wallet API Error: HTTP {resp.status_code}")
        return None

    try:
        data = resp.json()
    except ValueError as e:
        print(f"Wallet API Error: {e}")
        return None

    if isinstance(data, dict) and data.get("address"):
        return data
    return None


def get_trx_balance(address: str) -> float:
    """Return the TRX balance of *address*.

    The account's ``balance`` field (treated as 0 when absent) is divided by
    1_000_000 to convert it to TRX. Returns ``0.0`` when the account data
    cannot be fetched.
    """
    account = get_account_info(address)
    if not account:
        return 0.0
    return account.get("balance", 0) / 1_000_000


def _get_account_resource(address: str) -> dict:
    """Fetch ``/wallet/getaccountresource`` for *address*; ``{}`` on any failure."""
    url = "https://api.trongrid.io/wallet/getaccountresource"
    payload = {"address": to_hex_address(address)}
    try:
        resp = requests.post(url, json=payload, headers=HEADERS, timeout=10)
        resp.raise_for_status()
        data = resp.json()
    except (requests.RequestException, ValueError) as e:
        print(f"Account Resource API Error: {e}")
        return {}
    return data if isinstance(data, dict) else {}


def get_energy(address: str) -> dict:
    """Return bandwidth and energy figures for *address*.

    Limits and usage come from ``/wallet/getaccountresource``: the
    ``/wallet/getaccount`` payload has no bandwidth/energy limit fields.
    The staked-for-energy amount is summed from the account's ``frozenV2``
    entries and divided by 1_000_000.

    Returns:
        ``{"bandwidth": {free_limit, free_used, net_limit, net_used},
        "energy": {limit, used, frozen}}``. Missing fields default to 0.
        Returns ``{"bandwidth": {}, "energy": {}}`` when the account cannot
        be fetched.
    """
    account = get_account_info(address)
    if not account:
        return {"bandwidth": {}, "energy": {}}

    resource = _get_account_resource(address)

    # Stake 2.0 entries look like {"type": "ENERGY", "amount": ...}; the
    # ``frozen_balance`` key is kept as a fallback.
    energy_frozen = sum(
        int(entry.get("amount", entry.get("frozen_balance", 0)))
        for entry in account.get("frozenV2", [])
        if entry.get("type") == "ENERGY"
    )

    return {
        "bandwidth": {
            "free_limit": resource.get("freeNetLimit", 0),
            "free_used": resource.get("freeNetUsed", 0),
            "net_limit": resource.get("NetLimit", 0),
            "net_used": resource.get("NetUsed", 0),
        },
        "energy": {
            "limit": resource.get("EnergyLimit", 0),
            "used": resource.get("EnergyUsed", 0),
            "frozen": energy_frozen / 1_000_000,
        },
    }


def _hex_to_base58(hex_addr: str) -> str:
    """Convert a hex Tron address (``"41"`` + 40 hex digits) to Base58Check.

    Returns ``""`` when *hex_addr* is not a ``41``-prefixed hex string.
    """
    if not isinstance(hex_addr, str) or not hex_addr.startswith("41"):
        return ""
    try:
        # Tron's Base58 form carries a 4-byte checksum, so the *_check variant
        # is required; plain b58encode produces a different string.
        return base58.b58encode_check(bytes.fromhex(hex_addr)).decode()
    except ValueError:
        return ""


def get_transactions(address: str) -> list:
    """Fetch up to 20 transactions of *address* from TronGrid.

    Only ``TransferContract`` (plain TRX transfer) entries get an amount and
    from/to addresses. Other transaction types are still listed, with
    ``amount`` 0 and empty addresses. Errors are logged, and whatever was
    collected before the error is returned.
    """
    import datetime

    hex_addr = to_hex_address(address)
    transactions = []

    try:
        url = f"https://api.trongrid.io/v1/accounts/{hex_addr}/transactions"
        resp = requests.get(url, params={"limit": 20}, headers=HEADERS, timeout=10)
        resp.raise_for_status()
        data = resp.json()

        for tx in data.get("data") or []:
            tx_id = tx.get("txID", "")
            # block_timestamp is in milliseconds; rendered in local time.
            ts = datetime.datetime.fromtimestamp(tx.get("block_timestamp", 0) / 1000)
            ret = tx.get("ret")
            result = ret[0] if ret else {}

            amount = 0
            from_addr = ""
            to_addr = ""
            contracts = tx.get("raw_data", {}).get("contract", [])
            if contracts and contracts[0].get("type") == "TransferContract":
                value = contracts[0].get("parameter", {}).get("value", {})
                amount = int(value.get("amount", 0)) / 1_000_000
                from_addr = _hex_to_base58(value.get("owner_address", ""))
                to_addr = _hex_to_base58(value.get("to_address", ""))

            transactions.append({
                "tx_id": tx_id,
                "tx_id_short": tx_id[:16] + "...",
                "timestamp": ts.strftime("%Y-%m-%d %H:%M"),
                "type": "TRX",
                "amount": amount,
                "token": "TRX",
                "from": from_addr,
                "to": to_addr,
                # Base58 addresses are case-sensitive, so compare exactly.
                "direction": "in" if to_addr == address else "out",
                "status": "SUCCESS" if result.get("contractRet") == "SUCCESS" else "FAILED",
            })
    except Exception as e:
        # Best effort: return what was collected so far.
        print(f"TRX Tx Error: {e}")

    return transactions


@app.route("/")
def index():
    """Render the ``index.html`` template for the site root."""
    template_name = "index.html"
    return render_template(template_name)


@app.route("/query")
def query():
    """Return balances, resources and recent transfers for ``?address=``.

    Responds with ``{"error": ...}`` when the address is missing or does not
    start with ``T``. Otherwise it returns a JSON object with the address,
    the TRX and USDT balances, the bandwidth/energy summary, and up to 30
    transfers ordered by timestamp, newest first.
    """
    address = request.args.get("address", "").strip()

    error = None
    if not address:
        error = "请输入地址"
    elif not address.startswith("T"):
        error = "地址格式不正确，应以 T 开头"
    if error is not None:
        return jsonify({"error": error})

    # Fetch order is kept: USDT first, then the TRX-side lookups.
    usdt = get_usdt_balance_and_txs(address)
    trx_balance = get_trx_balance(address)
    energy = get_energy(address)
    trx_txs = get_transactions(address)

    # String sort on "timestamp"; assumes the "YYYY-MM-DD HH:MM" format the
    # fetch helpers emit, which sorts chronologically.
    merged = sorted(
        [*usdt.get("transactions", []), *trx_txs],
        key=lambda record: record.get("timestamp", ""),
        reverse=True,
    )

    return jsonify({
        "address": address,
        "trx_balance": trx_balance,
        "usdt_balance": usdt.get("balance", 0),
        "energy": energy,
        "transactions": merged[:30],
    })


if __name__ == "__main__":
    # The Werkzeug interactive debugger lets anyone who can reach it run
    # arbitrary code, so it must not be on by default while listening on all
    # interfaces. Opt in with FLASK_DEBUG=1 for local development only.
    app.run(
        host="0.0.0.0",
        port=5000,
        debug=os.environ.get("FLASK_DEBUG") == "1",
    )
