diff --git a/.gitignore b/.gitignore
index ffa1e97..fcb50cc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,9 +4,14 @@ __pycache__/
 *.pyo
 .env
+# Config (contains the API token)
+config.txt
+
 # Data files (large)
 data/*.parquet
 data/*.csv
+data/*.db
+data/temp/
 !A股股票列表.csv
 
 # Logs
diff --git a/fetch_history_v2.py b/fetch_history_v2.py
new file mode 100644
index 0000000..cba6c00
--- /dev/null
+++ b/fetch_history_v2.py
@@ -0,0 +1,393 @@
+"""
+A-share historical data fetcher, V2 - performance-optimized edition
+Optimizations:
+1. Sharded storage - each stock is written to its own small file, so the
+   large main file is not rewritten on every fetch
+2. Batched merges - temp shards are merged into the main file every
+   MERGE_BATCH_SIZE (50) stocks, cutting IO
+3. SQLite progress tracking - more reliable checkpoint/resume
+"""
+
+import tushare as ts
+import pandas as pd
+import os
+import time
+import sqlite3
+from datetime import datetime
+from pathlib import Path
+
+# Configuration
+BASE_DIR = Path(__file__).parent
+DATA_DIR = BASE_DIR / 'data'
+TEMP_DIR = DATA_DIR / 'temp'  # per-stock shard directory
+LOGS_DIR = BASE_DIR / 'logs'
+STOCK_LIST_FILE = BASE_DIR / 'A股股票列表.csv'
+DB_FILE = DATA_DIR / 'progress.db'  # SQLite progress database
+
+# Create directories
+DATA_DIR.mkdir(exist_ok=True)
+TEMP_DIR.mkdir(exist_ok=True)
+LOGS_DIR.mkdir(exist_ok=True)
+
+# Date range
+START_DATE = '20100101'
+END_DATE = datetime.now().strftime('%Y%m%d')
+
+# Delay between requests (seconds) - tushare credit/rate limit
+REQUEST_INTERVAL = 2  # shorter interval, compensated by the lightweight per-stock writes
+
+# Batch merge threshold
+MERGE_BATCH_SIZE = 50  # merge into the main file every 50 stocks
+
+
+def setup_tushare(token=None):
+    """Initialize tushare."""
+    if not token:
+        token = os.environ.get('TUSHARE_TOKEN', '')
+
+    if not token:
+        config_file = BASE_DIR / 'config.txt'
+        if config_file.exists():
+            token = config_file.read_text().strip()
+
+    if not token:
+        raise ValueError("Missing Tushare token")
+
+    ts.set_token(token)
+    return ts.pro_api()
+
+
+def init_progress_db():
+    """Initialize the SQLite progress database."""
+    conn = sqlite3.connect(DB_FILE)
+    cursor = conn.cursor()
+
+    # Progress table
+    cursor.execute('''
+        CREATE TABLE IF NOT EXISTS progress (
+            ts_code TEXT PRIMARY KEY,
+            status TEXT DEFAULT 'pending',
+            record_count INTEGER DEFAULT 0,
+            updated_at TEXT,
+            error_msg TEXT
+        )
+    ''')
+
+    # Merge log table
+    cursor.execute('''
+        CREATE TABLE IF NOT EXISTS merge_log (
+            batch_id INTEGER PRIMARY KEY AUTOINCREMENT,
+            stock_count INTEGER,
+            merged_at TEXT,
+            file_size INTEGER
+        )
+    ''')
+
+    conn.commit()
+    conn.close()
+    print(f"Progress database: {DB_FILE}")
+
+
+def load_stock_list():
+    """Load the stock list."""
+    df = pd.read_csv(STOCK_LIST_FILE)
+    df.columns = df.columns.str.strip()
+    print(f"Loaded stock list: {len(df)} stocks")
+    return df
+
+
+def get_stock_codes_with_suffix(df):
+    """Convert stock codes to tushare format (exchange suffix by code prefix)."""
+    codes = []
+    for code in df['code']:
+        code = str(code).zfill(6)
+        first_digit = code[0]
+
+        if first_digit == '6':
+            ts_code = f"{code}.SH"  # Shanghai
+        elif first_digit in ('0', '3'):
+            ts_code = f"{code}.SZ"  # Shenzhen
+        elif first_digit in ('4', '8'):
+            ts_code = f"{code}.BJ"  # Beijing
+        else:
+            ts_code = f"{code}.SZ"
+
+        codes.append(ts_code)
+    return codes
+
+
+def init_stock_progress(codes):
+    """Seed progress rows for all stocks."""
+    conn = sqlite3.connect(DB_FILE)
+    cursor = conn.cursor()
+
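+    # INSERT OR IGNORE leaves existing rows untouched, so stocks already
+    # marked completed/no_data/error keep their status across restarts --
+    # this is what makes the checkpoint/resume reliable.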
+    # Bulk insert (skip codes that already have a row)
+    for ts_code in codes:
+        cursor.execute('''
+            INSERT OR IGNORE INTO progress (ts_code, status, updated_at)
+            VALUES (?, 'pending', ?)
+        ''', (ts_code, datetime.now().isoformat()))
+
+    conn.commit()
+
+    # Status summary
+    cursor.execute('SELECT status, COUNT(*) FROM progress GROUP BY status')
+    stats = cursor.fetchall()
+    conn.close()
+
+    print("\nCurrent progress by status:")
+    for status, count in stats:
+        print(f"  {status}: {count} stocks")
+
+    return stats
+
+
+def get_pending_stocks():
+    """Return the list of stocks still to be processed."""
+    conn = sqlite3.connect(DB_FILE)
+    cursor = conn.cursor()
+
+    cursor.execute('''
+        SELECT ts_code FROM progress
+        WHERE status = 'pending' OR status = 'error'
+        ORDER BY ts_code
+    ''')
+
+    pending = [row[0] for row in cursor.fetchall()]
+    conn.close()
+
+    return pending
+
+
+def save_stock_temp(df, ts_code):
+    """Save a single stock to its own temp file (fast; the main file is untouched)."""
+    temp_file = TEMP_DIR / f"{ts_code.replace('.', '_')}.parquet"
+    df.to_parquet(temp_file, index=False, compression='snappy')
+
+    # Update progress
+    conn = sqlite3.connect(DB_FILE)
+    cursor = conn.cursor()
+    cursor.execute('''
+        UPDATE progress
+        SET status = 'completed', record_count = ?, updated_at = ?
+        WHERE ts_code = ?
+    ''', (len(df), datetime.now().isoformat(), ts_code))
+    conn.commit()
+    conn.close()
+
+    return temp_file.stat().st_size
+
+
+def merge_batch_to_main():
+    """Merge the temp shard files into the main file."""
+    temp_files = list(TEMP_DIR.glob('*.parquet'))
+
+    if not temp_files:
+        return 0
+
+    print(f"\nMerging {len(temp_files)} temp files...")
+
+    # Read every temp file
+    batch_data = []
+    for tf in temp_files:
+        try:
+            df = pd.read_parquet(tf)
+            batch_data.append(df)
+        except Exception as e:
+            print(f"  Warning: failed to read {tf.name}: {e}")
+
+    if not batch_data:
+        return 0
+
+    # Combine the batch
+    new_data = pd.concat(batch_data, ignore_index=True)
+
+    # Read the main file and merge
+    main_file = DATA_DIR / 'stock_daily_data.parquet'
+
+    if main_file.exists():
+        existing = pd.read_parquet(main_file)
+        # Codes already merged into the main file
+        existing_codes = set(existing['ts_code'].unique())
+
+        # Only merge stocks not yet in the main file, so a crash between
+        # merging and temp-file cleanup cannot produce duplicate rows
+        truly_new = new_data[~new_data['ts_code'].isin(existing_codes)]
+
+        if len(truly_new) > 0:
+            combined = pd.concat([existing, truly_new], ignore_index=True)
+        else:
+            combined = existing
+    else:
+        combined = new_data
+
+    # Sort and save
+    combined = combined.sort_values(['ts_code', 'trade_date']).reset_index(drop=True)
+    combined.to_parquet(main_file, index=False, compression='snappy')
+
+    # Log the merge
+    conn = sqlite3.connect(DB_FILE)
+    cursor = conn.cursor()
+    cursor.execute('''
+        INSERT INTO merge_log (stock_count, merged_at, file_size)
+        VALUES (?, ?, ?)
+    ''', (len(temp_files), datetime.now().isoformat(), main_file.stat().st_size))
+    conn.commit()
+    conn.close()
+
+    # Delete the temp files
+    for tf in temp_files:
+        tf.unlink()
+
+    print(f"  Merge done: {len(new_data)} new records")
+    print(f"  Main file size: {main_file.stat().st_size / 1024 / 1024:.2f} MB")
+
+    return len(temp_files)
+
+
+def fetch_stock_data(pro, ts_code):
+    """Fetch daily bars for a single stock."""
+    try:
+        df = pro.daily(ts_code=ts_code, start_date=START_DATE, end_date=END_DATE)
+
+        if df is not None and len(df) > 0:
+            # Write to a temp shard
+            file_size = save_stock_temp(df, ts_code)
+            return True, len(df), file_size
+        else:
+            # No data; mark as done
+            conn = sqlite3.connect(DB_FILE)
+            cursor = conn.cursor()
+            cursor.execute('''
+                UPDATE progress
+                SET status = 'no_data', updated_at = ?
+                WHERE ts_code = ?
+            ''', (datetime.now().isoformat(), ts_code))
+            conn.commit()
+            conn.close()
+            return True, 0, 0
+
+    except Exception as e:
+        # Record the error
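+        # (errored stocks remain selectable by get_pending_stocks(), so they
+        # are retried automatically on the next run; the message is truncated
+        # to keep the progress DB small)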
+        conn = sqlite3.connect(DB_FILE)
+        cursor = conn.cursor()
+        cursor.execute('''
+            UPDATE progress
+            SET status = 'error', error_msg = ?, updated_at = ?
+            WHERE ts_code = ?
+        ''', (str(e)[:200], datetime.now().isoformat(), ts_code))
+        conn.commit()
+        conn.close()
+        return False, 0, 0
+
+
+def fetch_all_stocks(pro, codes):
+    """Fetch data for every stock in the list."""
+    total = len(codes)
+    batch_count = 0
+
+    print("\nStarting fetch...")
+    print(f"{total} stocks in total")
+    print(f"Merging every {MERGE_BATCH_SIZE} stocks")
+    print("-" * 50)
+
+    for i, ts_code in enumerate(codes):
+        success, records, size = fetch_stock_data(pro, ts_code)
+
+        status = "✓" if success else "✗"
+        print(f"[{i+1}/{total}] {ts_code} {status} {records} rows {size/1024:.1f}KB")
+
+        # Batch merge check
+        batch_count += 1
+        if batch_count >= MERGE_BATCH_SIZE:
+            merge_batch_to_main()
+            batch_count = 0
+
+        # Rate-limit delay
+        if i < total - 1:
+            time.sleep(REQUEST_INTERVAL)
+
+    # Merge whatever is left over
+    if batch_count > 0:
+        merge_batch_to_main()
+
+    print("\n" + "=" * 50)
+    print("Fetch complete!")
+
+
+def show_final_stats():
+    """Print the final statistics."""
+    conn = sqlite3.connect(DB_FILE)
+    cursor = conn.cursor()
+
+    # Status counts
+    cursor.execute('SELECT status, COUNT(*) FROM progress GROUP BY status')
+    stats = cursor.fetchall()
+
+    # Total record count
+    cursor.execute("SELECT SUM(record_count) FROM progress WHERE status = 'completed'")
+    total_records = cursor.fetchone()[0] or 0
+
+    # Merge history
+    cursor.execute('SELECT COUNT(*), SUM(stock_count) FROM merge_log')
+    merge_stats = cursor.fetchone()
+
+    conn.close()
+
+    print("\nFinal statistics:")
+    print("-" * 30)
+    for status, count in stats:
+        print(f"  {status}: {count} stocks")
+    print(f"\n  Total records: {total_records}")
+    print(f"  Merge batches: {merge_stats[0]}")
+
+    # File size
+    main_file = DATA_DIR / 'stock_daily_data.parquet'
+    if main_file.exists():
+        print(f"  Main file size: {main_file.stat().st_size / 1024 / 1024:.2f} MB")
+
+
+def main():
+    """Entry point."""
+    print("=" * 60)
+    print("A-share historical data fetcher V2 - performance-optimized")
+    print("=" * 60)
+    print(f"Date range: {START_DATE} ~ {END_DATE}")
+    print(f"Data directory: {DATA_DIR}")
+    print("=" * 60)
+
+    # Initialize
+    print("\nInitializing Tushare...")
+    pro = setup_tushare()
+
+    print("\nInitializing progress database...")
+    init_progress_db()
+
+    # Load the stock list
+    print("\nLoading stock list...")
+    stock_df = load_stock_list()
+    codes = get_stock_codes_with_suffix(stock_df)
+
+    # Seed progress rows
+    init_stock_progress(codes)
+
+    # Stocks still to process
+    pending = get_pending_stocks()
+
+    if not pending:
+        print("\nAll stocks are done!")
+        show_final_stats()
+        return
+
+    print(f"\nTo process: {len(pending)} stocks")
+    print(f"Estimated time: {len(pending) * REQUEST_INTERVAL / 60:.1f} minutes")
+
+    # Start fetching
+    fetch_all_stocks(pro, pending)
+
+    # Final statistics
+    show_final_stats()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/run_v2.sh b/run_v2.sh
new file mode 100644
index 0000000..d84dc8a
--- /dev/null
+++ b/run_v2.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# A-share historical data fetch - V2 optimized edition
+
+cd "$(dirname "$0")"
+
+# Check configuration
+if [ ! -f "config.txt" ]; then
+    echo "Please configure your Tushare token first:"
+    echo "  echo 'your_token' > config.txt"
+    exit 1
+fi
+
+# Check the stock list
+if [ ! -f "A股股票列表.csv" ]; then
+    echo "Missing stock list file: A股股票列表.csv"
+    exit 1
+fi
+
+# Run
+echo "Starting A-share historical data fetch V2..."
+python3 fetch_history_v2.py
+
+# Generate a CSV summary after completion
+if [ -f "data/stock_daily_data.parquet" ]; then
+    echo "Generating CSV summary..."
+    python3 -c "
+import pandas as pd
+from datetime import datetime
+
+df = pd.read_parquet('data/stock_daily_data.parquet')
+timestamp = datetime.now().strftime('%Y%m%d')
+df.to_csv(f'data/A股日线数据_{timestamp}.csv', index=False)
+print(f'CSV file: data/A股日线数据_{timestamp}.csv ({len(df)} records)')
+"
+fi
\ No newline at end of file