diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..d6d2e92 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,265 @@ +# AGENTS.md + +Guidelines for AI agents working with the qfeval-data repository. + +## Repository Overview + +**qfeval-data** is a Python library for handling financial time series data. It provides the `Data` class—a specialized data structure built on PyTorch tensors for efficient manipulation of timestamped, symbol-indexed financial data (OHLCV). + +For detailed specifications, see the `docs/` directory: +- `docs/README.md` - Documentation index (Japanese: `docs/README.ja.md`) +- `docs/data.md` - Complete Data class API reference (Japanese: `docs/data.ja.md`) +- `docs/flattener.md` - Flattener class reference (Japanese: `docs/flattener.ja.md`) +- `docs/util.md` - Utility functions reference (Japanese: `docs/util.ja.md`) +- `docs/examples.md` - Practical examples and recipes (Japanese: `docs/examples.ja.md`) + +## Codebase Structure + +``` +qfeval-data/ +├── qfeval_data/ # Main package +│ ├── __init__.py # Exports: Data, Flattener, __version__ +│ ├── data.py # Core Data class +│ ├── flattener.py # Tensor flattening utilities +│ ├── util.py # Helper functions +│ ├── plot.py # Visualization (requires matplotlib) +│ └── version.py # Version string +├── tests/ # pytest test suite +└── pyproject.toml # Project configuration +``` + +## Development Commands + + +```bash +# Install dependencies +pip install -e ".[dev]" + +# Run tests +pytest + +# Run tests with coverage +pytest --cov=qfeval_data + +# Linting and formatting +black qfeval_data tests +isort qfeval_data tests +flake8 qfeval_data tests + +# Type checking +mypy qfeval_data +``` + +## Code Style + +- Formatter: `black` (line-length: 80) +- Import sorting: `isort` +- Linting: `flake8` +- Type checking: `mypy` (strict mode) + +## Key Design Patterns + +1. **Lazy slicing**: Data slicing creates views without copying tensors +2. **Sorted indexes**: Timestamps and symbols are always sorted internally +3. **Method chaining**: Most methods return `Data` for fluent API +4. **PyTorch backend**: Full GPU support via tensor operations + +--- + +## Using qfeval_data.Data (PyPI Package) + +This section is for agents that consume qfeval-data as a dependency. + + + +### Installation + + +```bash +pip install qfeval-data + +# With plotting support +pip install qfeval-data[plot] +``` + +### Core Concepts + +The `Data` class wraps a dictionary of PyTorch tensors indexed by: +- **timestamps**: `np.ndarray[datetime64]` (sorted) +- **symbols**: `np.ndarray[str]` (sorted) +- **columns**: Named tensors with shape `(num_timestamps, num_symbols, *extra_dims)` + +### Creating Data Objects + +```python +from qfeval_data import Data +import pandas as pd +import numpy as np +import torch + +# From pandas DataFrame (requires "timestamp" and "symbol" columns) +df = pd.DataFrame({ + "timestamp": ["2024-01-01", "2024-01-01", "2024-01-02", "2024-01-02"], + "symbol": ["AAPL", "GOOG", "AAPL", "GOOG"], + "open": [100.0, 200.0, 101.0, 201.0], + "close": [105.0, 205.0, 106.0, 206.0], +}) +data = Data.from_dataframe(df) + +# From tensors directly +tensors = { + "open": torch.tensor([[100.0, 200.0], [101.0, 201.0]]), + "close": torch.tensor([[105.0, 205.0], [106.0, 206.0]]), +} +timestamps = np.array(["2024-01-01", "2024-01-02"], dtype="datetime64[D]") +symbols = np.array(["AAPL", "GOOG"]) +data = Data.from_tensors(tensors, timestamps, symbols) +``` + +### Accessing Data + +```python +# Column access +opens = data.get("open") # Single column +data.open # Attribute access shortcut + +# Slicing (lazy - no data copy) +subset = data[:10, :] # First 10 timestamps +subset = data["2024-01-01", :] # By timestamp value +subset = data[:, "AAPL"] # Single symbol +subset = data[:, ["AAPL", "GOOG"]] # Multiple symbols + +# Properties +data.timestamps # np.ndarray of timestamps +data.symbols # np.ndarray of symbols +data.columns # List of column names +data.shape # (num_timestamps, num_symbols) +data.tensors # Dict[str, Tensor] after slicing +``` + +### Arithmetic Operations + +All arithmetic is element-wise on tensors: + +```python +returns = (data.close / data.open) - 1 +mask = data.close > data.open # Boolean Data +``` + +### Time Series Operations + +```python +data.shift(1) # Shift forward by 1 timestamp +data.pct_change() # Percent change +data.diff() # Difference +data.cumsum() # Cumulative sum +data.moving_average(2) # 2-period moving average +``` + +### Aggregation (axis: 0=timestamp, 1=symbol, None=both) + +```python +data.mean(axis=0) # Mean across timestamps +data.sum(axis=1) # Sum across symbols +data.std() # Std dev across all +data.min(axis=0) +data.max(axis=0) +data.count() # Count non-NaN +``` + +### Missing Value Handling + +```python +data.dropna(axis=0, how="any") # Drop timestamps with any NaN +data.fillna(0.0) # Fill NaN with value +data.fillna(method="ffill") # Forward fill +``` + +### Financial Metrics + +```python +data.close.annualized_return() +data.close.annualized_volatility() +data.close.annualized_sharpe_ratio() +data.close.maximum_drawdown() +data.close.metrics() # All metrics combined +``` + +### Resampling + +```python +data.daily() +data.weekly() +data.monthly() +data.yearly() +``` + +### Conversion + +```python +df = data.to_dataframe() # pandas DataFrame (long format) +csv = data.to_csv() # CSV string +``` + +### Device and Dtype + +```python +data.to(torch.float64) # Change dtype +data.device # Current device +data.dtype # Current dtype +``` + +### Method Chaining Example + +```python +result = ( + data + .get(["open", "close"]) + .dropna() + .pct_change() + .fillna(0.0) + .mean(axis=1) +) +``` + +### Flattener Utility + +Convert between `Data` (timestamp/symbol indexed) and flat `Tensor` (batch indexed): + +```python +from qfeval_data import Flattener + +flattener = Flattener(data.close) +flat_tensor = flattener.flatten(data.close) # Data -> Tensor +restored = flattener.unflatten(flat_tensor, "prices") # Tensor -> Data +``` + +### Further Documentation + +For complete API documentation and more examples, see: +- `docs/data.md` - Full Data class reference with all methods and parameters +- `docs/flattener.md` - Flattener class details +- `docs/examples.md` - Practical recipes for common tasks diff --git a/README.ja.md b/README.ja.md new file mode 100644 index 0000000..4bce791 --- /dev/null +++ b/README.ja.md @@ -0,0 +1,33 @@ +# qfeval-data + +[[English](README.md)] + +[![python](https://img.shields.io/badge/python-%3E=3.9-blue.svg)](https://pypi.org/project/qfeval_data/) +[![pypi](https://img.shields.io/pypi/v/qfeval_data.svg)](https://pypi.org/project/qfeval_data/) +[![CI](https://github.com/pfnet-research/qfeval-data/actions/workflows/ci-python.yaml/badge.svg)](https://github.com/pfnet-research/qfeval-data/actions/workflows/ci-python.yaml) +[![codecov](https://codecov.io/gh/pfnet-research/qfeval-data/graph/badge.svg?token=5A02B1JV7V)](https://codecov.io/gh/pfnet-research/qfeval-data) +[![downloads](https://img.shields.io/pypi/dm/qfeval_data)](https://pypi.org/project/qfeval_data) +[![code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + +qfevalは、Preferred Networks 金融チームが開発している、金融時系列処理のためのフレームワークです。 +データ形式の仕様定義、金融時系列データを効率的に扱うためのクラス/関数群、および金融時系列モデルの評価フレームワークが含まれます。 + +qfeval-dataは、qfevalの中でも、金融時系列データを効率的に扱うためのデータフレームを提供します。 + +## インストール + +```bash +pip install qfeval_data +``` + +## 使い方 + +詳細なドキュメントは [docs/README.ja.md](docs/README.ja.md) を参照してください。 + +## リリース手順 + +1. `release/X.X.X` のブランチを作成する。 +2. version.yaml (Bump) のワークフローが実行され、`Bumping version from Z.Z.Z to X.X.X` というタイトルのプルリクエストが作成されるので、これをマージする。 +3. `release/X.X.X` ブランチを `master` にマージするプルリクエスト(タイトルは `Release/X.X.X` のままで OK)を作成する。 +4. 他の人から Approval を得て、`Release/X.X.X` のプルリクエストのマージをする。 +5. [Release ワークフロー](https://github.com/pfnet-research/qfeval-data/actions/workflows/release.yaml) が走るのでこれの完了を待ち、 PyPI の [qfeval/data](https://pypi.org/project/qfeval_data/#history) で新しいバージョンが追加されたことを確認する。 diff --git a/README.md b/README.md index 0aa5fc4..6c0edd8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ # qfeval-data + +[[日本語](README.ja.md)] + [![python](https://img.shields.io/badge/python-%3E=3.9-blue.svg)](https://pypi.org/project/qfeval_data/) [![pypi](https://img.shields.io/pypi/v/qfeval_data.svg)](https://pypi.org/project/qfeval_data/) [![CI](https://github.com/pfnet-research/qfeval-data/actions/workflows/ci-python.yaml/badge.svg)](https://github.com/pfnet-research/qfeval-data/actions/workflows/ci-python.yaml) @@ -6,13 +9,6 @@ [![downloads](https://img.shields.io/pypi/dm/qfeval_data)](https://pypi.org/project/qfeval_data) [![code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -qfevalは、Preferred Networks 金融チームが開発している、金融時系列処理のためのフレームワークです。 -データ形式の仕様定義、金融時系列データを効率的に扱うためのクラス/関数群、および金融時系列モデルの評価フレームワークが含まれます。 - -qfeval-dataは、qfevalの中でも、金融時系列データを効率的に扱うためのデータフレームを提供します。 - ---- - qfeval is a framework developed by Preferred Networks' Financial Solutions team for processing financial time series data. It includes: data format specification definitions, a set of classes/functions for efficiently handling financial time series data, and a framework for evaluating financial time series models. @@ -25,12 +21,13 @@ pip install qfeval_data ``` ## Usage -TBD -## リリース手順 +See [docs/README.md](docs/README.md) for detailed documentation. + +## Release Process -1. `release/X.X.X` のブランチを作成する。 -2. version.yaml (Bump) のワークフローが実行され、`Bumping version from Z.Z.Z to X.X.X` というタイトルのプルリクエストが作成されるので、これをマージする。 -3. `release/X.X.X` ブランチを `master` にマージするプルリクエスト(タイトルは `Release/X.X.X` のままで OK)を作成する。 -4. 他の人から Approval を得て、`Release/X.X.X` のプルリクエストのマージをする。 -5. [Release ワークフロー](https://github.com/pfnet-research/qfeval-data/actions/workflows/release.yaml) が走るのでこれの完了を待ち、 PyPI の [qfeval/data](https://pypi.org/project/qfeval_data/#history) で新しいバージョンが追加されたことを確認する。 +1. Create a `release/X.X.X` branch. +2. The version.yaml (Bump) workflow will run and create a pull request titled `Bumping version from Z.Z.Z to X.X.X`. Merge this PR. +3. Create a pull request to merge the `release/X.X.X` branch into `master` (the title `Release/X.X.X` is fine). +4. Get approval from another team member and merge the `Release/X.X.X` pull request. +5. Wait for the [Release workflow](https://github.com/pfnet-research/qfeval-data/actions/workflows/release.yaml) to complete, then verify the new version appears on PyPI at [qfeval/data](https://pypi.org/project/qfeval_data/#history). diff --git a/docs/README.ja.md b/docs/README.ja.md new file mode 100644 index 0000000..270670c --- /dev/null +++ b/docs/README.ja.md @@ -0,0 +1,67 @@ +# qfeval-data ドキュメント + +[[English](README.md)] + +**qfeval-data** は金融時系列データを効率的に操作するための Python ライブラリです。タイムスタンプとシンボルでインデックス付けされた金融データを扱うための、PyTorch テンソルをベースにした特殊なデータ構造である `Data` クラスを提供します。 + +## 主な機能 + +- **PyTorch バックエンド**: PyTorch テンソルによる GPU アクセラレーションの完全サポート +- **遅延スライシング**: 不要なコピーを行わない効率的なデータアクセス +- **金融特化**: OHLCV データ、メトリクス、テクニカル指標の組み込みサポート +- **柔軟な I/O**: CSV、DataFrame からの読み込み、またはテンソルからの直接構築 +- **可視化**: matplotlib による統合されたローソク足チャートとプロット機能 + +## インストール + +```bash +pip install qfeval-data + +# プロット機能付き +pip install qfeval-data[plot] +``` + +## クイックスタート + +```python +from qfeval_data import Data +import pandas as pd + +# CSV から読み込み +data = Data.from_csv("prices.csv") + +# または DataFrame から +df = pd.DataFrame({ + "timestamp": ["2024-01-01", "2024-01-01", "2024-01-02", "2024-01-02"], + "symbol": ["AAPL", "GOOG", "AAPL", "GOOG"], + "close": [150.0, 140.0, 152.0, 142.0], +}) +data = Data.from_dataframe(df) + +# アクセスと操作 +returns = data.pct_change() +avg_return = returns.mean(axis=0) + +# メトリクスの計算 +metrics = data.metrics() +print(metrics.to_dataframe()) +``` + +## ドキュメント目次 + +- [Data クラスリファレンス](data.ja.md) - `Data` クラスの完全な API リファレンス +- [Flattener リファレンス](flattener.ja.md) - `Data` とフラットテンソル間の変換 +- [ユーティリティ関数](util.ja.md) - 配列と時間操作のヘルパー関数 +- [使用例](examples.ja.md) - 一般的な使用パターンとレシピ + +## 必要条件 + +- Python >= 3.9 +- PyTorch +- NumPy +- pandas +- qfeval-functions + +## ライセンス + +詳細は [LICENSE](../LICENSE) ファイルを参照してください。 diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..9fa52ac --- /dev/null +++ b/docs/README.md @@ -0,0 +1,72 @@ +# qfeval-data Documentation + +[[日本語](README.ja.md)] + +**qfeval-data** is a Python library for efficient manipulation of financial time series data. It provides the `Data` class—a specialized data structure built on PyTorch tensors for working with timestamped, symbol-indexed financial data. + +## Key Features + +- **PyTorch Backend**: Full GPU acceleration support via PyTorch tensors +- **Lazy Slicing**: Efficient data access without unnecessary copying +- **Financial Focus**: Built-in support for OHLCV data, metrics, and technical indicators +- **Flexible I/O**: Load from CSV, DataFrame, or construct from tensors +- **Visualization**: Integrated candlestick charts and plotting via matplotlib + +## Installation + + +```bash +pip install qfeval-data + +# With plotting support +pip install qfeval-data[plot] +``` + +## Quick Start + + + +```python +from qfeval_data import Data +import pandas as pd + +# Create from DataFrame +df = pd.DataFrame({ + "timestamp": ["2024-01-01", "2024-01-01", "2024-01-02", "2024-01-02"], + "symbol": ["AAPL", "GOOG", "AAPL", "GOOG"], + "close": [150.0, 140.0, 152.0, 142.0], +}) +data = Data.from_dataframe(df) + +# Access and manipulate +returns = data.pct_change() +avg_return = returns.mean(axis=0) + +# Calculate metrics +metrics = data.metrics() +print(metrics.to_dataframe()) +``` + +## Documentation Contents + +- [Data Class Reference](data.md) - Complete API reference for the `Data` class +- [Flattener Reference](flattener.md) - Converting between `Data` and flat tensors +- [Utility Functions](util.md) - Helper functions for arrays and time operations +- [Examples](examples.md) - Common usage patterns and recipes + +## Requirements + +- Python >= 3.9 +- PyTorch +- NumPy +- pandas +- qfeval-functions + +## License + +See the [LICENSE](../LICENSE) file for details. diff --git a/docs/data.ja.md b/docs/data.ja.md new file mode 100644 index 0000000..a1c6a37 --- /dev/null +++ b/docs/data.ja.md @@ -0,0 +1,924 @@ +# Data クラスリファレンス + +`Data` クラスは qfeval-data の中核コンポーネントです。タイムスタンプとシンボルでインデックス付けされた数値テンソルを管理し、効率的な金融時系列データの操作のために設計されています。 + + + +## 概要 + +```python +from qfeval_data import Data +``` + +### データ構造 + +- **テンソル**: カラム名(文字列)から PyTorch テンソルへの辞書 +- **形状**: 各テンソルは `(num_timestamps, num_symbols, *extra_dimensions)` の形状を持つ +- **タイムスタンプ**: `np.ndarray[datetime64]` - 常にソート済み +- **シンボル**: `np.ndarray[str]` - 常にソート済み + +### 設計原則 + +1. **遅延スライシング**: スライス操作はデータをコピーせずビューを作成 +2. **ソート済みインデックス**: タイムスタンプとシンボルは構築時に自動的にソート +3. **メソッドチェーン**: ほとんどのメソッドは `Data` オブジェクトを返し、流暢な API を実現 +4. **GPU サポート**: PyTorch テンソルバックエンドによる完全なデバイス柔軟性 + +--- + +## 構築メソッド + +### `Data.from_dataframe(df, dtype=None, device=None)` + +pandas DataFrame から `Data` オブジェクトを作成します。 + +**パラメータ:** +- `df` (`pd.DataFrame`): `timestamp` と `symbol` カラムが必須の DataFrame +- `dtype` (`torch.dtype`, 省略可): テンソルのデータ型 +- `device` (`str` または `torch.device`, 省略可): テンソルのデバイス + +**戻り値:** `Data` + +**例:** +```python +import pandas as pd +from qfeval_data import Data + +df = pd.DataFrame({ + "timestamp": ["2024-01-01", "2024-01-01", "2024-01-02", "2024-01-02"], + "symbol": ["AAPL", "GOOG", "AAPL", "GOOG"], + "open": [150.0, 140.0, 152.0, 142.0], + "close": [155.0, 145.0, 153.0, 143.0], +}) +data = Data.from_dataframe(df) +``` + +**多次元カラム:** + +多次元データにはカラム名にブラケット記法を使用: +```python +df = pd.DataFrame({ + "timestamp": ["2024-01-01"], + "symbol": ["AAPL"], + "embedding[0]": [0.1], + "embedding[1]": [0.2], + "embedding[2]": [0.3], +}) +data = Data.from_dataframe(df) +# data.embedding の形状は (1, 1, 3) +``` + +--- + +### `Data.from_csv(input, dtype=None, device=None)` + +CSV ファイルから `Data` オブジェクトを読み込みます。 + +**パラメータ:** +- `input` (`str` またはファイルライクオブジェクト): CSV ファイルへのパスまたはファイルオブジェクト +- `dtype` (`torch.dtype`, 省略可): テンソルのデータ型 +- `device` (`str` または `torch.device`, 省略可): テンソルのデバイス + +**戻り値:** `Data` + +**例:** +```python +data = Data.from_csv("prices.csv") +# data = Data.from_csv("prices.csv.xz") # 圧縮ファイルもサポート +``` + +**CSV フォーマット:** + +```csv +timestamp,symbol,open,high,low,close,volume +2024-01-01,AAPL,150.0,156.0,149.0,155.0,1000000 +2024-01-01,GOOG,140.0,146.0,139.0,145.0,800000 +``` + +--- + +### `Data.from_tensors(tensors, timestamps, symbols)` + +テンソルから直接 `Data` オブジェクトを作成します。最もプリミティブなコンストラクタです。 + +**パラメータ:** +- `tensors` (`Dict[str, torch.Tensor]`): カラム名からテンソルへの辞書 +- `timestamps` (`np.ndarray`): datetime64 値の1次元配列 +- `symbols` (`np.ndarray`): シンボル文字列の1次元配列 + +**戻り値:** `Data` + +**例:** +```python +import torch +import numpy as np +from qfeval_data import Data + +tensors = { + "open": torch.tensor([[150.0, 140.0], [152.0, 142.0]]), + "close": torch.tensor([[155.0, 145.0], [153.0, 143.0]]), +} +timestamps = np.array(["2024-01-01", "2024-01-02"], dtype="datetime64[D]") +symbols = np.array(["AAPL", "GOOG"]) + +data = Data.from_tensors(tensors, timestamps, symbols) +``` + +**注意:** +- タイムスタンプとシンボルは自動的にソートされ、テンソルもそれに応じて再インデックスされます +- すべてのテンソルは `(len(timestamps), len(symbols), ...)` の形状を持つ必要があります +- すべてのテンソルは同じデバイス上にある必要があります + +--- + +### `Data.from_preset(name="pfn-topix500", dtype=None, device=None, paths=[])` + +システムパスからプリセットデータファイルを読み込みます。 + +**パラメータ:** +- `name` (`str`): プリセット名(`data/{name}.csv` または `data/{name}.csv.xz` を検索) +- `dtype` (`torch.dtype`, 省略可): テンソルのデータ型 +- `device` (`str` または `torch.device`, 省略可): テンソルのデバイス +- `paths` (`List[str]`): 追加の検索パス + +**戻り値:** `Data` + +**例外:** プリセットが見つからない場合 `FileNotFoundError` + +--- + +## プロパティ + +### データアクセスプロパティ + +| プロパティ | 型 | 説明 | +|----------|------|-------------| +| `tensors` | `Dict[str, Tensor]` | スライシング適用後のテンソル | +| `tensor` | `Tensor` | 単一テンソル(カラムが1つの場合のみ) | +| `raw_tensors` | `Dict[str, Tensor]` | スライシングなしの直接テンソルアクセス | +| `raw_tensor` | `Tensor` | 単一の生テンソル | +| `arrays` | `Dict[str, np.ndarray]` | テンソルの NumPy 配列版 | +| `array` | `np.ndarray` | 単一配列版 | + +### メタデータプロパティ + +| プロパティ | 型 | 説明 | +|----------|------|-------------| +| `timestamps` | `np.ndarray` | ソート済み datetime64 配列 | +| `symbols` | `np.ndarray` | ソート済み文字列配列 | +| `columns` | `List[str]` | カラム名のリスト | +| `shape` | `Tuple[int, int]` | `(num_timestamps, num_symbols)` | +| `device` | `torch.device` | テンソルのデバイス | +| `dtype` | `torch.dtype` | テンソルのデータ型 | + +--- + +## インデックスとスライシング + +### `data[timestamp_idx, symbol_idx]` + +タイムスタンプとシンボルインデックスでデータにアクセス。複数のインデックススタイルをサポート: + +**整数インデックス:** +```python +data[0, :] # 最初のタイムスタンプ、全シンボル +data[:, 0] # 全タイムスタンプ、最初のシンボル +data[0, 0] # 単一要素 +data[-1, :] # 最後のタイムスタンプ +``` + +**スライスインデックス:** +```python +data[:10, :] # 最初の10タイムスタンプ +data[5:15, :] # タイムスタンプ 5-14 +data[:, :3] # 最初の3シンボル +``` + +**値ベースインデックス:** +```python +data["2024-01-01", :] # タイムスタンプ値で指定 +data["2024-01-01":"2024-01-31", :] # タイムスタンプ範囲 +data[:, "AAPL"] # シンボル値で指定 +data[:, ["AAPL", "GOOG"]] # 複数シンボル +``` + +**ブールマスクインデックス:** +```python +mask = data.close > data.open # ブール Data +filtered = data[mask] # マスク適用(マッチしない箇所は NaN に) +``` + +--- + +## カラムアクセス + +### `data.get(*columns)` / `data.get(columns)` / `data.get(pattern=...)` + +カラムのサブセットを抽出します。 + +**シグネチャ:** +```python +def get(self, *columns: str) -> Data: ... +def get(self, columns: Iterable[str]) -> Data: ... +def get(self, filter_func: Callable[[str], bool]) -> Data: ... +def get(self, *, pattern: str) -> Data: ... +``` + +**例:** +```python +# 単一カラム +opens = data.get("open") + +# 複数カラム +ohlc = data.get("open", "high", "low", "close") +ohlc = data.get(["open", "high", "low", "close"]) + +# フィルタ関数 +prices = data.get(lambda c: c in ["open", "close"]) + +# Glob パターン +prices = data.get(pattern="*price*") +``` + +### 属性アクセス + +カラムは属性としてアクセス可能: +```python +data.close # data.get("close") と同等 +data.volume # data.get("volume") と同等 +``` + +--- + +### `data.set(key, value)` + +カラムを追加または更新します。 + +**パラメータ:** +- `key` (`str`): カラム名 +- `value` (`torch.Tensor` または `Data`): カラム値 + +**例:** +```python +data.set("returns", data.close.pct_change().tensor) +data.set("spread", data.high - data.low) +``` + +--- + +### `data.rename(columns)` + +カラム名を変更します。 + +**パラメータ:** +- `columns` (`str`, `List[str]`, または `Dict[str, str]`): 新しいカラム名 + +**戻り値:** `Data` + +**例:** +```python +# 単一カラムの名前変更(Data が1カラムの場合) +renamed = data.get("close").rename("price") + +# リストで名前変更(カラム数と一致する必要あり) +renamed = data.rename(["o", "h", "l", "c", "v"]) + +# 辞書で選択的に名前変更 +renamed = data.rename({"open": "o", "close": "c"}) +``` + +--- + +## 算術演算 + +すべての算術演算はテンソルに対して要素単位で行われます: + +### 二項演算子 + +| 演算子 | 説明 | +|----------|-------------| +| `+`, `-`, `*`, `/` | 基本算術 | +| `//` | 切り捨て除算 | +| `%` | 剰余 | +| `**` | べき乗 | +| `@` | 行列乗算 | +| `&`, `\|`, `^` | ビット演算 | + +### 比較演算子 + +| 演算子 | 説明 | +|----------|-------------| +| `==`, `!=` | 等価(ブール Data を返す) | +| `<`, `>`, `<=`, `>=` | 比較(ブール Data を返す) | + +**注意:** Python の真偽値評価を避けるため `.eq()` と `.ne()` メソッドを使用してください。 + +### 単項演算子 + +| 演算子 | 説明 | +|----------|-------------| +| `-x` | 符号反転 | +| `+x` | 正 | +| `abs(x)` | 絶対値 | +| `~x` | ビット否定 | + +**例:** +```python +returns = (data.close / data.open) - 1 +spread = data.high - data.low +is_up = data.close > data.open +``` + +--- + +## 時系列演算 + +### `data.shift(shift=1, skipna=False)` + +タイムスタンプ軸に沿って値をシフトします。 + +**パラメータ:** +- `shift` (`int`): シフトする期間数(正=前方、負=後方) +- `skipna` (`bool`): True の場合、シフト時に NaN 値をスキップ + +**戻り値:** `Data` + +**例:** +```python +previous = data.shift(1) # 前日の値 +next_day = data.shift(-1) # 翌日の値 +``` + +--- + +### `data.pct_change(periods=1, skipna=False)` + +変化率を計算します。 + +**計算式:** `(current / previous) - 1` + +**パラメータ:** +- `periods` (`int`): 比較のためのシフト期間 +- `skipna` (`bool`): NaN 値をスキップ + +**戻り値:** `Data` + +**例:** +```python +daily_returns = data.close.pct_change() +weekly_returns = data.close.pct_change(periods=5) +``` + +--- + +### `data.diff(periods=1, skipna=False)` + +現在値と前回値の差を計算します。 + +**計算式:** `current - previous` + +**パラメータ:** +- `periods` (`int`): 比較のためのシフト期間 +- `skipna` (`bool`): NaN 値をスキップ + +**戻り値:** `Data` + +--- + +### `data.cumsum(axis=0, skipna=True)` + +軸に沿った累積和。 + +**パラメータ:** +- `axis` (`int` または `str`): 軸(0/"timestamp" または 1/"symbol") +- `skipna` (`bool`): NaN 値をスキップ + +**戻り値:** `Data` + +--- + +### `data.cumprod(axis=0, skipna=True)` + +軸に沿った累積積。 + +**パラメータ:** +- `axis` (`int` または `str`): 軸(0/"timestamp" または 1/"symbol") +- `skipna` (`bool`): NaN 値をスキップ + +**戻り値:** `Data` + +--- + +## 集約メソッド + +すべての集約メソッドは `axis` パラメータをサポート: +- `axis=0` または `axis="timestamp"`: タイムスタンプ方向に集約 +- `axis=1` または `axis="symbol"`: シンボル方向に集約 +- `axis=None`: 両軸で集約 + +### 統計集約 + +| メソッド | 説明 | +|--------|-------------| +| `sum(axis=None)` | 合計 | +| `mean(axis=None)` | 算術平均 | +| `min(axis=None)` | 最小値 | +| `max(axis=None)` | 最大値 | +| `var(axis=None, ddof=1)` | 分散 | +| `std(axis=None, ddof=1)` | 標準偏差 | +| `skew(axis=None, ddof=1)` | 歪度 | +| `kurt(axis=None, ddof=1)` | 尖度 | +| `count(axis=None)` | 非 NaN 値のカウント | + +### 位置集約 + +| メソッド | 説明 | +|--------|-------------| +| `first(axis="timestamp", skipna=True)` | 最初の値 | +| `last(axis="timestamp", skipna=True)` | 最後の値 | + +**例:** +```python +# 全タイムスタンプの平均価格 +avg_price = data.close.mean(axis=0) + +# シンボルごとの合計出来高 +total_vol = data.volume.sum(axis=0) + +# 全体統計 +stats = data.close.mean() # スカラー(単一値) +``` + +--- + +## 欠損値処理 + +### `data.dropna(axis=0, how="any", thresh=None)` + +欠損値を含む行またはカラムを削除します。 + +**パラメータ:** +- `axis` (`int` または `str`): 削除する軸(0=タイムスタンプ、1=シンボル) +- `how` (`str`): "any"(いずれかが NaN なら削除)または "all"(すべてが NaN なら削除) +- `thresh` (`int`, 省略可): 必要な非 NaN 値の最小数 + +**戻り値:** `Data` + +**例:** +```python +# 欠損値があるタイムスタンプを削除 +clean = data.dropna(axis=0, how="any") + +# すべてが欠損のシンボルを削除 +clean = data.dropna(axis=1, how="all") +``` + +--- + +### `data.fillna(value=0.0, method=None, axis=0)` + +欠損値を補完します。 + +**パラメータ:** +- `value` (`float`): NaN を置き換える値(`method=None` の場合) +- `method` (`str`, 省略可): 補完方法 - `"ffill"`(前方補完)または `"bfill"`(後方補完) +- `axis` (`int` または `str`): 補完メソッドの軸 + +**戻り値:** `Data` + +**例:** +```python +# ゼロで補完 +filled = data.fillna(0.0) + +# 前方補完(前の値を使用) +filled = data.fillna(method="ffill") + +# 後方補完(次の値を使用) +filled = data.fillna(method="bfill") +``` + +--- + +## 金融メトリクス + +### `data.annualized_return()` + +年率リターンを計算します。 + +**計算式:** `(last / first) ^ (1 / years) - 1` + +**戻り値:** タイムスタンプ次元が折りたたまれた `Data` + +**エイリアス:** `ar()` + +--- + +### `data.annualized_volatility()` + +年率ボラティリティ(年率換算されたリターンの標準偏差)を計算します。 + +**戻り値:** タイムスタンプ次元が折りたたまれた `Data` + +**エイリアス:** `avol()` + +--- + +### `data.annualized_sharpe_ratio()` + +年率シャープレシオを計算します。 + +**計算式:** `annualized_return / annualized_volatility` + +**戻り値:** タイムスタンプ次元が折りたたまれた `Data` + +**エイリアス:** `asr()` + +--- + +### `data.maximum_drawdown()` + +最大ドローダウン(最大のピークからトラフへの下落)を計算します。 + +**戻り値:** タイムスタンプ次元が折りたたまれた `Data` + +**エイリアス:** `mdd()` + +--- + +### `data.metrics()` + +すべてのメトリクスを一度に計算します。 + +**戻り値:** 以下のカラムを持つ `Data`: +- `annualized_sharpe_ratio` +- `annualized_return` +- `annualized_volatility` +- `maximum_drawdown` + +**例:** +```python +metrics = data.close.metrics() +print(metrics.to_dataframe()) +``` + +--- + +## リサンプリングメソッド + +データをより低い頻度にダウンサンプリングします。OHLC カラムは特別に処理されます: +- `open`: ウィンドウ内の最初の有効値 +- `high`: ウィンドウ内の最大値 +- `low`: ウィンドウ内の最小値 +- `close`: ウィンドウ内の最後の有効値 +- その他のカラム: デフォルトで合計 + +### メソッド + +| メソッド | 頻度 | +|--------|-----------| +| `minutely()` | 1分 | +| `hourly()` | 1時間 | +| `daily()` | 1日 | +| `weekly()` | 7日 | +| `monthly()` | 1ヶ月 | +| `yearly()` | 1年 | + +**パラメータ(すべてのメソッド共通):** +- `origin` (`np.datetime64`, 省略可): バケット化の起点時刻 +- `offset` (`np.timedelta64`, 省略可): タイムゾーンオフセット調整 +- `aggregation_f` (callable): 非 OHLC カラムの集約関数 + +**例:** +```python +# ティックデータを日次 OHLCV に変換 +daily = tick_data.daily() + +# タイムゾーンオフセット付き週次データ +weekly = data.weekly(offset=np.timedelta64(9, "h")) +``` + +--- + +### `data.downsample(delta, origin=None, offset=None, aggregation_f=nansum)` + +任意の頻度への汎用ダウンサンプリング。 + +**パラメータ:** +- `delta` (`np.timedelta64`): バケット化の時間間隔 +- `origin` (`np.datetime64`, 省略可): 起点時刻 +- `offset` (`np.timedelta64`, 省略可): タイムゾーンオフセット +- `aggregation_f` (callable): 集約関数 + +**例:** +```python +# 15分足 +bars_15m = data.downsample(np.timedelta64(15, "m")) +``` + +--- + +## テクニカル指標 + +### `data.moving_average(window=25, skipna=True)` + +単純移動平均を計算します。 + +**パラメータ:** +- `window` (`int`): ウィンドウサイズ +- `skipna` (`bool`): NaN 値をスキップ + +**戻り値:** `Data` + +--- + +### `data.bollinger_band(window=20, sigma=2.0, skipna=True)` + +ボリンジャーバンドを計算します。 + +**パラメータ:** +- `window` (`int`): 移動平均のウィンドウサイズ +- `sigma` (`float`): バンドの標準偏差数 + +**戻り値:** `Tuple[Data, Data, Data]` - (上限、中央、下限) + +**例:** +```python +upper, middle, lower = data.close.bollinger_band(window=20, sigma=2.0) +``` + +--- + +## 可視化メソッド + +すべての可視化メソッドには matplotlib が必要です(`pip install qfeval-data[plot]`)。 + +### `data.plot(ax=None, **kwargs)` + +カラムに基づいてプロットタイプを自動検出。OHLC データにはローソク足、それ以外は折れ線グラフを使用。 + +**パラメータ:** +- `ax` (`matplotlib.axes.Axes`, 省略可): プロット先の Axes + +**戻り値:** `List[matplotlib.axes.Axes]` + +--- + +### `data.line(ax=None, even=False, **kwargs)` + +折れ線グラフ。 + +**パラメータ:** +- `ax` (`matplotlib.axes.Axes`, 省略可): プロット先の Axes +- `even` (`bool`): 等間隔の x 軸を使用(時間ギャップを無視) +- `**kwargs`: `matplotlib.plot()` に渡される + +--- + +### `data.bar(width=0.8, bottom=0.0, ax=None, **kwargs)` + +棒グラフ。 + +**パラメータ:** +- `width` (`float` または `Data`): 棒の幅 +- `bottom` (`float` または `Data`): 棒の底の位置 +- `ax` (`matplotlib.axes.Axes`, 省略可): プロット先の Axes + +--- + +### `data.candlestick(ax=None, **kwargs)` + +OHLC ローソク足チャート。`open`, `high`, `low`, `close` カラムが必要です。 + +**パラメータ:** +- `ax` (`matplotlib.axes.Axes`, 省略可): プロット先の Axes +- `upcolor` (`str`): 陽線の色(デフォルト: "#ee3333") +- `downcolor` (`str`): 陰線の色(デフォルト: "#118822") +- `neutralcolor` (`str`): 中立の色(デフォルト: "#444444") +- `width` (`float`): ローソク実体の幅(デフォルト: 0.6) +- `linewidth` (`float`): 髭の線幅(デフォルト: 0.5) + +**例:** +```python +import matplotlib.pyplot as plt +from qfeval_data import Data + +data = Data.from_csv("prices.csv") +data[:, "AAPL"].candlestick() # 単一シンボルをプロット +plt.close() +``` + +--- + +## 変換メソッド + +### `data.to_dataframe()` + +ロング形式の pandas DataFrame に変換します。 + +**戻り値:** `timestamp`, `symbol`, および全データカラムを持つ `pd.DataFrame` + +**例:** +```python +df = data.to_dataframe() +# timestamp symbol open close +# 0 2024-01-01 AAPL 150.0 155.0 +# 1 2024-01-01 GOOG 140.0 145.0 +``` + +--- + +### `data.to_table()` + +ワイド形式(2次元テーブル)の pandas DataFrame に変換します。 + +**戻り値:** `pd.DataFrame` + +**注意:** +- 単一カラム: タイムスタンプがインデックス、シンボルがカラム +- 複数カラム: 単一タイムスタンプまたは単一シンボルが必要 + +--- + +### `data.to_series()` + +pandas Series に変換します。単一カラムかつ単一シンボルが必要です。 + +**戻り値:** `pd.Series` + +--- + +### `data.to_csv(path=None)` + +CSV 形式でエクスポートします。 + +**パラメータ:** +- `path` (`str`, 省略可): ファイルパス。None の場合、CSV 文字列を返します。 + +**戻り値:** `str`(path が None の場合)または `None` + +--- + +## ユーティリティメソッド + +### `data.copy(deep=False)` + +コピーを作成します。 + +**パラメータ:** +- `deep` (`bool`): True の場合、テンソルをコピー;そうでなければテンソル参照を共有 + +**戻り値:** `Data` + +--- + +### `data.to(dtype_or_device)` + +dtype および/または device を変換します。 + +**シグネチャ:** +```python +def to(self, dtype: torch.dtype) -> Data: ... +def to(self, device: torch.device) -> Data: ... +def to(self, tensor: torch.Tensor) -> Data: ... +def to(self, data: Data) -> Data: ... +``` + +**例:** +```python +data_f64 = data.to(torch.float64) +other_data = data.get("close") +data_like = data.to(other_data) # dtype/device を合わせる +``` + +--- + +### `data.like(other)` + +別の Data のタイムスタンプとシンボルに合わせてリシェイプします。 + +**パラメータ:** +- `other` (`Data`): 形状の参照となる Data + +**戻り値:** `Data` + +**注意:** +- 欠落したタイムスタンプ/シンボルの組み合わせは NaN で補完 +- 余分な組み合わせは破棄 + +--- + +### `data.merge(*others)` + +複数の Data オブジェクトをマージ(タイムスタンプ/シンボルの和集合)。 + +**パラメータ:** +- `*others` (`Data`): マージする Data オブジェクト + +**戻り値:** `Data` + +**注意:** +- 重複するセルでは、最後の非 NaN 値が優先 +- 同名のカラムは互換性のある形状を持つ必要あり + +--- + +### `data.apply(f, *args, skipna=False)` + +テンソルに関数を適用します。 + +**パラメータ:** +- `f` (callable): テンソルを受け取りテンソルを返す関数 +- `*args`: 追加の引数(Data または値) +- `skipna` (`bool`): NaN 値をスキップ + +**戻り値:** `Data` + +**例:** +```python +# カスタム関数を適用 +result = data.close.apply(lambda x: torch.log(x + 1)) + +# 追加引数付き +other_data = data.close +result = data.close.apply(lambda x, y: x * y, other_data) +``` + +--- + +## 比較メソッド + +### `data.equals(other)` + +完全な等価性をチェック(NaN の位置を含む)。 + +**戻り値:** `bool` + +--- + +### `data.allclose(other, rtol=1e-5, atol=1e-8)` + +近似等価性をチェック。 + +**パラメータ:** +- `rtol` (`float`): 相対許容誤差 +- `atol` (`float`): 絶対許容誤差 + +**戻り値:** `bool` + +--- + +## シリアライゼーション + +`Data` クラスは Python の pickle プロトコルをサポート: + +```python +import pickle +import tempfile +import os + +# tempfile を使用して保存と読み込み +with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + pickle.dump(data, f) + temp_path = f.name + +with open(temp_path, "rb") as f: + loaded_data = pickle.load(f) + +os.unlink(temp_path) # クリーンアップ +``` diff --git a/docs/data.md b/docs/data.md new file mode 100644 index 0000000..617b9bf --- /dev/null +++ b/docs/data.md @@ -0,0 +1,1050 @@ +# Data Class Reference + +The `Data` class is the core component of qfeval-data. It manages numerical tensors indexed by timestamps and symbols, designed for efficient financial time series manipulation. + + + +## Overview + +```python +from qfeval_data import Data +``` + +### Data Structure + +- **Tensors**: Dictionary mapping column names (strings) to PyTorch tensors +- **Shape**: Each tensor has shape `(num_timestamps, num_symbols, *extra_dimensions)` +- **Timestamps**: `np.ndarray[datetime64]` - always sorted +- **Symbols**: `np.ndarray[str]` - always sorted + +### Design Principles + +1. **Lazy Slicing**: Slicing operations create views without copying data +2. **Sorted Indexes**: Timestamps and symbols are automatically sorted on construction +3. **Method Chaining**: Most methods return `Data` objects for fluent API +4. **GPU Support**: Full PyTorch tensor backend with device flexibility + +--- + +## Construction Methods + +### `Data.from_dataframe(df, dtype=None, device=None)` + +Create a `Data` object from a pandas DataFrame. + +**Parameters:** +- `df` (`pd.DataFrame`): DataFrame with required `timestamp` and `symbol` columns +- `dtype` (`torch.dtype`, optional): Data type for tensors +- `device` (`str` or `torch.device`, optional): Device for tensors + +**Returns:** `Data` + +**Example:** +```python +import pandas as pd +from qfeval_data import Data + +df = pd.DataFrame({ + "timestamp": ["2024-01-01", "2024-01-01", "2024-01-02", "2024-01-02"], + "symbol": ["AAPL", "GOOG", "AAPL", "GOOG"], + "open": [150.0, 140.0, 152.0, 142.0], + "close": [155.0, 145.0, 153.0, 143.0], +}) +data = Data.from_dataframe(df) +``` + +**Multi-dimensional columns:** + +Use bracket notation in column names for multi-dimensional data: +```python +df = pd.DataFrame({ + "timestamp": ["2024-01-01"], + "symbol": ["AAPL"], + "embedding[0]": [0.1], + "embedding[1]": [0.2], + "embedding[2]": [0.3], +}) +data = Data.from_dataframe(df) +# data.embedding has shape (1, 1, 3) +``` + +--- + +### `Data.from_csv(input, dtype=None, device=None)` + +Load a `Data` object from a CSV file. + +**Parameters:** +- `input` (`str` or file-like): Path to CSV file or file object +- `dtype` (`torch.dtype`, optional): Data type for tensors +- `device` (`str` or `torch.device`, optional): Device for tensors + +**Returns:** `Data` + +**Example:** +```python +data = Data.from_csv("prices.csv") +# data = Data.from_csv("prices.csv.xz") # Supports compressed files +``` + +**CSV Format:** + +```csv +timestamp,symbol,open,high,low,close,volume +2024-01-01,AAPL,150.0,156.0,149.0,155.0,1000000 +2024-01-01,GOOG,140.0,146.0,139.0,145.0,800000 +``` + +--- + +### `Data.from_tensors(tensors, timestamps, symbols)` + +Create a `Data` object directly from tensors. This is the most primitive constructor. + +**Parameters:** +- `tensors` (`Dict[str, torch.Tensor]`): Dictionary of column name to tensor +- `timestamps` (`np.ndarray`): 1D array of datetime64 values +- `symbols` (`np.ndarray`): 1D array of symbol strings + +**Returns:** `Data` + +**Example:** +```python +import torch +import numpy as np +from qfeval_data import Data + +tensors = { + "open": torch.tensor([[150.0, 140.0], [152.0, 142.0]]), + "close": torch.tensor([[155.0, 145.0], [153.0, 143.0]]), +} +timestamps = np.array(["2024-01-01", "2024-01-02"], dtype="datetime64[D]") +symbols = np.array(["AAPL", "GOOG"]) + +data = Data.from_tensors(tensors, timestamps, symbols) +``` + +**Notes:** +- Timestamps and symbols are automatically sorted; tensors are reindexed accordingly +- All tensors must have shape `(len(timestamps), len(symbols), ...)` +- All tensors must be on the same device + +--- + +### `Data.from_preset(name="pfn-topix500", dtype=None, device=None, paths=[])` + +Load a preset data file from the system path. + +**Parameters:** +- `name` (`str`): Preset name (searches for `data/{name}.csv` or `data/{name}.csv.xz`) +- `dtype` (`torch.dtype`, optional): Data type for tensors +- `device` (`str` or `torch.device`, optional): Device for tensors +- `paths` (`List[str]`): Additional paths to search + +**Returns:** `Data` + +**Raises:** `FileNotFoundError` if preset not found + +--- + +## Properties + +### Data Access Properties + +| Property | Type | Description | +|----------|------|-------------| +| `tensors` | `Dict[str, Tensor]` | Tensors after slicing applied | +| `tensor` | `Tensor` | Single tensor (requires exactly 1 column) | +| `raw_tensors` | `Dict[str, Tensor]` | Direct tensor access without slicing | +| `raw_tensor` | `Tensor` | Single raw tensor | +| `arrays` | `Dict[str, np.ndarray]` | NumPy array versions of tensors | +| `array` | `np.ndarray` | Single array version | + +### Metadata Properties + +| Property | Type | Description | +|----------|------|-------------| +| `timestamps` | `np.ndarray` | Sorted datetime64 array | +| `symbols` | `np.ndarray` | Sorted string array | +| `columns` | `List[str]` | List of column names | +| `shape` | `Tuple[int, int]` | `(num_timestamps, num_symbols)` | +| `device` | `torch.device` | Tensor device | +| `dtype` | `torch.dtype` | Tensor data type | + +--- + +## Indexing and Slicing + +### `data[timestamp_idx, symbol_idx]` + +Access data by timestamp and symbol indices. Supports multiple indexing styles: + +**Integer indexing:** +```python +data[0, :] # First timestamp, all symbols +data[:, 0] # All timestamps, first symbol +data[0, 0] # Single element +data[-1, :] # Last timestamp +``` + +**Slice indexing:** +```python +data[:10, :] # First 10 timestamps +data[5:15, :] # Timestamps 5-14 +data[:, :3] # First 3 symbols +``` + +**Value-based indexing:** +```python +data["2024-01-01", :] # By timestamp value +data["2024-01-01":"2024-01-31", :] # Timestamp range +data[:, "AAPL"] # By symbol value +data[:, ["AAPL", "GOOG"]] # Multiple symbols +``` + +**Boolean mask indexing:** +```python +mask = data.close > data.open # Boolean Data +filtered = data[mask] # Apply mask (non-matching become NaN) +``` + +--- + +## Column Access + +### `data.get(*columns)` / `data.get(columns)` / `data.get(pattern=...)` + +Extract a subset of columns. + +**Signatures:** +```python +def get(self, *columns: str) -> Data: ... +def get(self, columns: Iterable[str]) -> Data: ... +def get(self, filter_func: Callable[[str], bool]) -> Data: ... +def get(self, *, pattern: str) -> Data: ... +``` + +**Examples:** +```python +# Single column +opens = data.get("open") + +# Multiple columns +ohlc = data.get("open", "high", "low", "close") +ohlc = data.get(["open", "high", "low", "close"]) + +# Filter function +prices = data.get(lambda c: c in ["open", "close"]) + +# Glob pattern +prices = data.get(pattern="*price*") +``` + +### Attribute Access + +Columns can be accessed as attributes: +```python +data.close # Equivalent to data.get("close") +data.volume # Equivalent to data.get("volume") +``` + +--- + +### `data.set(key, value)` + +Add or update a column. + +**Parameters:** +- `key` (`str`): Column name +- `value` (`torch.Tensor` or `Data`): Column values + +**Example:** +```python +data.set("returns", data.close.pct_change().tensor) +data.set("spread", data.high - data.low) +``` + +--- + +### `data.rename(columns)` + +Rename columns. + +**Parameters:** +- `columns` (`str`, `List[str]`, or `Dict[str, str]`): New column names + +**Returns:** `Data` + +**Examples:** +```python +# Rename single column (when Data has one column) +renamed = data.get("close").rename("price") + +# Rename with list (must match column count) +renamed = data.rename(["o", "h", "l", "c", "v"]) + +# Rename with dict (selective) +renamed = data.rename({"open": "o", "close": "c"}) +``` + +--- + +## Arithmetic Operations + +All arithmetic operations are element-wise on tensors: + +### Binary Operators + +| Operator | Description | +|----------|-------------| +| `+`, `-`, `*`, `/` | Basic arithmetic | +| `//` | Floor division | +| `%` | Modulo | +| `**` | Power | +| `@` | Matrix multiplication | +| `&`, `\|`, `^` | Bitwise operations | + +### Comparison Operators + +| Operator | Description | +|----------|-------------| +| `==`, `!=` | Equality (returns boolean Data) | +| `<`, `>`, `<=`, `>=` | Comparison (returns boolean Data) | + +**Note:** Use `.eq()` and `.ne()` methods to avoid Python's truthiness evaluation. + +### Unary Operators + +| Operator | Description | +|----------|-------------| +| `-x` | Negation | +| `+x` | Positive | +| `abs(x)` | Absolute value | +| `~x` | Bitwise not | + +**Examples:** +```python +returns = (data.close / data.open) - 1 +spread = data.high - data.low +is_up = data.close > data.open +``` + +--- + +## Time Series Operations + +### `data.shift(shift=1, skipna=False)` + +Shift values along the timestamp axis. + +**Parameters:** +- `shift` (`int`): Number of periods to shift (positive = forward, negative = backward) +- `skipna` (`bool`): If True, skip NaN values when shifting + +**Returns:** `Data` + +**Example:** +```python +previous = data.shift(1) # Previous day's values +next_day = data.shift(-1) # Next day's values +``` + +--- + +### `data.pct_change(periods=1, skipna=False)` + +Calculate percentage change. + +**Formula:** `(current / previous) - 1` + +**Parameters:** +- `periods` (`int`): Periods to shift for comparison +- `skipna` (`bool`): Skip NaN values + +**Returns:** `Data` + +**Example:** +```python +daily_returns = data.close.pct_change() +weekly_returns = data.close.pct_change(periods=5) +``` + +--- + +### `data.diff(periods=1, skipna=False)` + +Calculate difference between current and previous values. + +**Formula:** `current - previous` + +**Parameters:** +- `periods` (`int`): Periods to shift for comparison +- `skipna` (`bool`): Skip NaN values + +**Returns:** `Data` + +--- + +### `data.cumsum(axis=0, skipna=True)` + +Cumulative sum along an axis. + +**Parameters:** +- `axis` (`int` or `str`): Axis (0/"timestamp" or 1/"symbol") +- `skipna` (`bool`): Skip NaN values + +**Returns:** `Data` + +--- + +### `data.cumprod(axis=0, skipna=True)` + +Cumulative product along an axis. + +**Parameters:** +- `axis` (`int` or `str`): Axis (0/"timestamp" or 1/"symbol") +- `skipna` (`bool`): Skip NaN values + +**Returns:** `Data` + +--- + +### `data.group_shift(shift=1, reference=None)` + +Shift values while skipping timestamps where any symbol has missing values. + +**Parameters:** +- `shift` (`int`): Number of periods to shift +- `reference` (`Data`, optional): Data to determine skip pattern from + +**Returns:** `Data` + +--- + +## Aggregation Methods + +All aggregation methods support the `axis` parameter: +- `axis=0` or `axis="timestamp"`: Aggregate across timestamps +- `axis=1` or `axis="symbol"`: Aggregate across symbols +- `axis=None`: Aggregate across both axes + +### Statistical Aggregations + +| Method | Description | +|--------|-------------| +| `sum(axis=None)` | Sum of values | +| `mean(axis=None)` | Arithmetic mean | +| `min(axis=None)` | Minimum value | +| `max(axis=None)` | Maximum value | +| `var(axis=None, ddof=1)` | Variance | +| `std(axis=None, ddof=1)` | Standard deviation | +| `skew(axis=None, ddof=1)` | Skewness | +| `kurt(axis=None, ddof=1)` | Kurtosis | +| `count(axis=None)` | Count of non-NaN values | + +### Position Aggregations + +| Method | Description | +|--------|-------------| +| `first(axis="timestamp", skipna=True)` | First value | +| `last(axis="timestamp", skipna=True)` | Last value | + +**Examples:** +```python +# Average price across all timestamps +avg_price = data.close.mean(axis=0) + +# Total volume per symbol +total_vol = data.volume.sum(axis=0) + +# Overall statistics +stats = data.close.mean() # Scalar (single value) +``` + +--- + +## Missing Value Handling + +### `data.dropna(axis=0, how="any", thresh=None)` + +Remove rows or columns with missing values. + +**Parameters:** +- `axis` (`int` or `str`): Axis along which to drop (0=timestamps, 1=symbols) +- `how` (`str`): "any" (drop if any NaN) or "all" (drop if all NaN) +- `thresh` (`int`, optional): Minimum number of non-NaN values required + +**Returns:** `Data` + +**Example:** +```python +# Drop timestamps with any missing values +clean = data.dropna(axis=0, how="any") + +# Drop symbols with all missing values +clean = data.dropna(axis=1, how="all") +``` + +--- + +### `data.fillna(value=0.0, method=None, axis=0)` + +Fill missing values. + +**Parameters:** +- `value` (`float`): Value to fill NaN with (when `method=None`) +- `method` (`str`, optional): Fill method - `"ffill"` (forward fill) or `"bfill"` (backward fill) +- `axis` (`int` or `str`): Axis for fill methods + +**Returns:** `Data` + +**Examples:** +```python +# Fill with zero +filled = data.fillna(0.0) + +# Forward fill (use previous value) +filled = data.fillna(method="ffill") + +# Backward fill (use next value) +filled = data.fillna(method="bfill") +``` + +--- + +## Financial Metrics + +### `data.annualized_return()` + +Calculate annualized return. + +**Formula:** `(last / first) ^ (1 / years) - 1` + +**Returns:** `Data` with single timestamp dimension collapsed + +**Alias:** `ar()` + +--- + +### `data.annualized_volatility()` + +Calculate annualized volatility (standard deviation of returns scaled to yearly). + +**Returns:** `Data` with single timestamp dimension collapsed + +**Alias:** `avol()` + +--- + +### `data.annualized_sharpe_ratio()` + +Calculate annualized Sharpe ratio. + +**Formula:** `annualized_return / annualized_volatility` + +**Returns:** `Data` with single timestamp dimension collapsed + +**Alias:** `asr()` + +--- + +### `data.maximum_drawdown()` + +Calculate maximum drawdown (largest peak-to-trough decline). + +**Returns:** `Data` with single timestamp dimension collapsed + +**Alias:** `mdd()` + +--- + +### `data.metrics()` + +Calculate all metrics at once. + +**Returns:** `Data` with columns: +- `annualized_sharpe_ratio` +- `annualized_return` +- `annualized_volatility` +- `maximum_drawdown` + +**Example:** +```python +metrics = data.close.metrics() +print(metrics.to_dataframe()) +``` + +--- + +## Resampling Methods + +Downsample data to lower frequency. OHLC columns are handled specially: +- `open`: First valid value in window +- `high`: Maximum value in window +- `low`: Minimum value in window +- `close`: Last valid value in window +- Other columns: Sum by default + +### Methods + +| Method | Frequency | +|--------|-----------| +| `minutely()` | 1 minute | +| `hourly()` | 1 hour | +| `daily()` | 1 day | +| `weekly()` | 7 days | +| `monthly()` | 1 month | +| `yearly()` | 1 year | + +**Parameters (all methods):** +- `origin` (`np.datetime64`, optional): Origin time for bucketing +- `offset` (`np.timedelta64`, optional): Timezone offset adjustment +- `aggregation_f` (callable): Aggregation function for non-OHLC columns + +**Example:** +```python +# Convert tick data to daily OHLCV (no-op if already daily) +daily = data.daily() + +# Weekly data with timezone offset +weekly = data.weekly(offset=np.timedelta64(9, "h")) +``` + +--- + +### `data.downsample(delta, origin=None, offset=None, aggregation_f=nansum)` + +Generic downsampling to arbitrary frequency. + +**Parameters:** +- `delta` (`np.timedelta64`): Time interval for bucketing +- `origin` (`np.datetime64`, optional): Origin time +- `offset` (`np.timedelta64`, optional): Timezone offset +- `aggregation_f` (callable): Aggregation function + +**Example:** +```python +# Downsample to 2-day bars +bars_2d = data.downsample(np.timedelta64(2, "D")) +``` + +--- + +## Technical Indicators + +### `data.moving_average(window=25, skipna=True)` + +Calculate simple moving average. + +**Parameters:** +- `window` (`int`): Window size +- `skipna` (`bool`): Skip NaN values + +**Returns:** `Data` + +--- + +### `data.bollinger_band(window=20, sigma=2.0, skipna=True)` + +Calculate Bollinger Bands. + +**Parameters:** +- `window` (`int`): Window size for moving average +- `sigma` (`float`): Number of standard deviations for bands + +**Returns:** `Tuple[Data, Data, Data]` - (upper, middle, lower) + +**Example:** +```python +upper, middle, lower = data.close.bollinger_band(window=20, sigma=2.0) +``` + +--- + +## Visualization Methods + +All visualization methods require matplotlib (`pip install qfeval-data[plot]`). + +### `data.plot(ax=None, **kwargs)` + +Auto-detect plot type based on columns. Uses candlestick for OHLC data, line plot otherwise. + +**Parameters:** +- `ax` (`matplotlib.axes.Axes`, optional): Axes to plot on + +**Returns:** `List[matplotlib.axes.Axes]` + +--- + +### `data.line(ax=None, even=False, **kwargs)` + +Line plot. + +**Parameters:** +- `ax` (`matplotlib.axes.Axes`, optional): Axes to plot on +- `even` (`bool`): Use even x-axis spacing (ignore time gaps) +- `**kwargs`: Passed to `matplotlib.plot()` + +--- + +### `data.bar(width=0.8, bottom=0.0, ax=None, **kwargs)` + +Bar plot. + +**Parameters:** +- `width` (`float` or `Data`): Bar width +- `bottom` (`float` or `Data`): Bar bottom position +- `ax` (`matplotlib.axes.Axes`, optional): Axes to plot on + +--- + +### `data.candlestick(ax=None, **kwargs)` + +OHLC candlestick chart. Requires `open`, `high`, `low`, `close` columns. + +**Parameters:** +- `ax` (`matplotlib.axes.Axes`, optional): Axes to plot on +- `upcolor` (`str`): Color for up candles (default: "#ee3333") +- `downcolor` (`str`): Color for down candles (default: "#118822") +- `neutralcolor` (`str`): Color for neutral candles (default: "#444444") +- `width` (`float`): Candle body width (default: 0.6) +- `linewidth` (`float`): Wick line width (default: 0.5) + +**Example:** +```python +import matplotlib.pyplot as plt +from qfeval_data import Data + +data = Data.from_csv("prices.csv") +data[:, "AAPL"].candlestick() # Plot single symbol +plt.close() +``` + +--- + +### `data.vlines(ymax=0.0, ax=None, **kwargs)` + +Vertical lines plot. + +--- + +### `data.fill_between(y2=0.0, ax=None, **kwargs)` + +Fill area between curves. + +--- + +### `data.plot_moving_average(window=25, ax=None, **kwargs)` + +Plot moving average line. + +--- + +### `data.plot_bollinger_band(window=20, sigma=2.0, ax=None, **kwargs)` + +Plot Bollinger Bands with fill. + +--- + +## Conversion Methods + +### `data.to_dataframe()` + +Convert to pandas DataFrame in long format. + +**Returns:** `pd.DataFrame` with columns: `timestamp`, `symbol`, and all data columns + +**Example:** +```python +df = data.to_dataframe() +# timestamp symbol open close +# 0 2024-01-01 AAPL 150.0 155.0 +# 1 2024-01-01 GOOG 140.0 145.0 +``` + +--- + +### `data.to_table()` + +Convert to pandas DataFrame in wide format (2D table). + +**Returns:** `pd.DataFrame` + +**Notes:** +- Single column: timestamps as index, symbols as columns +- Multiple columns: requires single timestamp or single symbol + +--- + +### `data.to_series()` + +Convert to pandas Series. Requires single column and single symbol. + +**Returns:** `pd.Series` + +--- + +### `data.to_csv(path=None)` + +Export to CSV format. + +**Parameters:** +- `path` (`str`, optional): File path. If None, returns CSV string. + +**Returns:** `str` (if path is None) or `None` + +--- + +### `data.to_matrix()` + +Convert to DataFrame with timestamps as index and symbols as columns. + +**Returns:** `pd.DataFrame` + +--- + +### `data.to_matrix_csv(path=None)` + +Export matrix format as CSV. + +--- + +## Utility Methods + +### `data.copy(deep=False)` + +Create a copy. + +**Parameters:** +- `deep` (`bool`): If True, copy tensors; otherwise, share tensor references + +**Returns:** `Data` + +--- + +### `data.to(dtype_or_device)` + +Convert dtype and/or device. + +**Signatures:** +```python +def to(self, dtype: torch.dtype) -> Data: ... +def to(self, device: torch.device) -> Data: ... +def to(self, tensor: torch.Tensor) -> Data: ... +def to(self, data: Data) -> Data: ... +``` + +**Examples:** +```python +data_f64 = data.to(torch.float64) +other_data = data.get("close") +data_like = data.to(other_data) # Match dtype/device +``` + +--- + +### `data.like(other)` + +Reshape to match another Data's timestamps and symbols. + +**Parameters:** +- `other` (`Data`): Reference Data for shape + +**Returns:** `Data` + +**Notes:** +- Missing timestamp/symbol combinations filled with NaN +- Extra combinations discarded + +--- + +### `data.merge(*others)` + +Merge multiple Data objects (union of timestamps/symbols). + +**Parameters:** +- `*others` (`Data`): Data objects to merge + +**Returns:** `Data` + +**Notes:** +- For overlapping cells, last non-NaN value wins +- Columns with same name must have compatible shapes + +--- + +### `data.merge_columns(other)` + +Merge columns from another Data (same timestamps/symbols required). + +**Parameters:** +- `other` (`Data`): Data with columns to add + +**Returns:** `Data` + +--- + +### `data.apply(f, *args, skipna=False)` + +Apply function to tensors. + +**Parameters:** +- `f` (callable): Function taking tensor(s) and returning tensor +- `*args`: Additional arguments (Data or values) +- `skipna` (`bool`): Skip NaN values + +**Returns:** `Data` + +**Example:** +```python +# Apply custom function +result = data.close.apply(lambda x: torch.log(x + 1)) + +# With additional argument +other_data = data.close +result = data.close.apply(lambda x, y: x * y, other_data) +``` + +--- + +### `data.subsequences(start, stop, indexes=None)` + +Extract time subsequences. + +**Parameters:** +- `start` (`int`): Start offset from each timestamp +- `stop` (`int`): Stop offset from each timestamp +- `indexes`: Specific timestamps to extract from + +**Returns:** `Data` with extra dimension for subsequence + +--- + +### `data.zeros()` + +Create Data with same shape filled with zeros. + +**Returns:** `Data` + +--- + +## Comparison Methods + +### `data.equals(other)` + +Check exact equality (including NaN positions). + +**Returns:** `bool` + +--- + +### `data.allclose(other, rtol=1e-5, atol=1e-8)` + +Check approximate equality. + +**Parameters:** +- `rtol` (`float`): Relative tolerance +- `atol` (`float`): Absolute tolerance + +**Returns:** `bool` + +--- + +## Index Conversion + +### `data.timestamp_index(v, side="equal")` + +Convert timestamp value(s) to integer indices. + +**Returns:** `int` or `np.ndarray` + +--- + +### `data.symbol_index(v, side="equal")` + +Convert symbol value(s) to integer indices. + +**Returns:** `int` or `np.ndarray` + +--- + +### `data.has_timestamps()` + +Check if Data has valid timestamps (not aggregated). + +**Returns:** `bool` + +--- + +### `data.has_symbols()` + +Check if Data has valid symbols (not aggregated). + +**Returns:** `bool` + +--- + +### `data.size(dim=None)` + +Get size of dimension(s). + +**Parameters:** +- `dim` (`int` or `str`, optional): Specific dimension + +**Returns:** `Tuple[int, int]` (if dim is None) or `int` + +--- + +## Serialization + +The `Data` class supports Python's pickle protocol: + +```python +import pickle +import tempfile +import os + +# Save and load using tempfile +with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + pickle.dump(data, f) + temp_path = f.name + +with open(temp_path, "rb") as f: + loaded_data = pickle.load(f) + +os.unlink(temp_path) # Clean up +``` diff --git a/docs/examples.ja.md b/docs/examples.ja.md new file mode 100644 index 0000000..e442a9f --- /dev/null +++ b/docs/examples.ja.md @@ -0,0 +1,620 @@ +# 使用例とレシピ + +このドキュメントでは qfeval-data の実用的な例と一般的な使用パターンを紹介します。 + + + +## 目次 + +1. [データの読み込み](#データの読み込み) +2. [基本操作](#基本操作) +3. [時系列分析](#時系列分析) +4. [ポートフォリオ分析](#ポートフォリオ分析) +5. [データ変換](#データ変換) +6. [可視化](#可視化) +7. [機械学習との統合](#機械学習との統合) +8. [複数シンボルの操作](#複数シンボルの操作) + +--- + +## データの読み込み + +### CSV ファイルから + +```python +from qfeval_data import Data + +# 基本的な読み込み +data = Data.from_csv("prices.csv") + +# dtype を指定 +data = Data.from_csv("prices.csv", dtype=torch.float32) +``` + +**期待される CSV フォーマット:** + +```csv +timestamp,symbol,open,high,low,close,volume +2024-01-02,AAPL,185.5,186.2,184.1,185.8,50000000 +2024-01-02,GOOG,140.0,141.5,139.5,141.0,20000000 +2024-01-03,AAPL,186.0,187.5,185.0,186.5,48000000 +2024-01-03,GOOG,141.0,142.0,140.0,141.5,19000000 +``` + +### pandas DataFrame から + +```python +import pandas as pd +from qfeval_data import Data + +# サンプルデータを作成 +df = pd.DataFrame({ + "timestamp": pd.date_range("2024-01-01", periods=10, freq="D").repeat(2), + "symbol": ["AAPL", "GOOG"] * 10, + "close": [150 + i * 0.5 + (0 if i % 2 == 0 else 10) for i in range(20)], +}) + +data = Data.from_dataframe(df) +print(data.shape) # (10, 2) +``` + +### 生テンソルから + +```python +import torch +import numpy as np +from qfeval_data import Data + +# テンソルを作成 +timestamps = np.array(["2024-01-01", "2024-01-02", "2024-01-03"], dtype="datetime64[D]") +symbols = np.array(["AAPL", "GOOG", "MSFT"]) +prices = torch.randn(3, 3) * 10 + 100 # 3 タイムスタンプ x 3 シンボル + +data = Data.from_tensors({"close": prices}, timestamps, symbols) +``` + +### 多次元データ + +```python +# タイムスタンプ/シンボルごとの埋め込みベクトル +df = pd.DataFrame({ + "timestamp": ["2024-01-01", "2024-01-01"], + "symbol": ["AAPL", "GOOG"], + "embedding[0]": [0.1, 0.2], + "embedding[1]": [0.3, 0.4], + "embedding[2]": [0.5, 0.6], +}) +data = Data.from_dataframe(df) +print(data.embedding.tensor.shape) # (1, 2, 3) +``` + +--- + +## 基本操作 + +### データへのアクセス + +```python +from qfeval_data import Data + +# data はセットアップで作成済み + +# 特定のカラムを取得 +closes = data.close # 属性アクセス +closes = data.get("close") # メソッドアクセス +ohlc = data.get("open", "high", "low", "close") + +# 時間でスライス +first_week = data[:5, :] # 最初の5タイムスタンプ +jan_data = data["2024-01-01":"2024-01-12", :] + +# シンボルでスライス +apple = data[:, "AAPL"] # 単一シンボル +tech = data[:, ["AAPL", "GOOG", "MSFT"]] # 複数シンボル + +# 組み合わせスライス +apple_jan = data["2024-01-01":"2024-01-12", "AAPL"] +``` + +### 算術演算 + +```python +# リターン +returns = data.close.pct_change() + +# 対数リターン +log_returns = (data.close / data.close.shift(1)).apply(torch.log) + +# スプレッド +spread = data.high - data.low + +# カスタム計算 +typical_price = (data.high + data.low + data.close) / 3 +``` + +### フィルタリング + +```python +# ブールフィルタリング +up_days = data[data.close > data.open] # マッチしない箇所は NaN に + +# 欠損値を削除 +clean = data.dropna() + +# 欠損値を補完 +filled = data.fillna(method="ffill") +``` + +--- + +## 時系列分析 + +### ローリング計算 + +```python +# 移動平均(ウィンドウサイズ <= データ長) +ma_5 = data.close.moving_average(window=5) + +# ボリンジャーバンド +upper, middle, lower = data.close.bollinger_band(window=5, sigma=2.0) +``` + +### ラグ特徴量 + +```python +# 過去の値 +prev_close = data.close.shift(1) +prev_5_close = data.close.shift(5) + +# 将来の値(ターゲット用) +next_close = data.close.shift(-1) +next_return = data.close.shift(-1).pct_change() +``` + +### リサンプリング + +```python +# ティックデータから日次データ +daily = tick_data.daily() + +# 週次 OHLCV +weekly = daily.weekly() + +# タイムゾーンオフセット付き月次 +monthly = daily.monthly(offset=np.timedelta64(9, "h")) + +# カスタム間隔 +bars_15m = data.downsample(np.timedelta64(15, "m")) +``` + +--- + +## ポートフォリオ分析 + +### 単一銘柄のメトリクス + +```python +# 単一銘柄のメトリクスを取得 +apple = data[:, "AAPL"] +metrics = apple.close.metrics() +print(metrics.to_dataframe()) +# annualized_sharpe_ratio annualized_return annualized_volatility maximum_drawdown +# symbol +# AAPL 1.25 0.15 0.12 0.08 +``` + +### クロスセクション分析 + +```python +# 全銘柄のメトリクスを比較 +all_metrics = data.close.metrics() + +# 最高シャープレシオを見つける +sharpe = all_metrics.get("annualized_sharpe_ratio") +best_idx = sharpe.tensor.argmax() +best_symbol = data.symbols[best_idx] +print(f"最高シャープレシオ: {best_symbol}") +``` + +### ポートフォリオリターン + +```python +import torch + +# 等ウェイトポートフォリオ +weights = torch.ones(data.shape[1]) / data.shape[1] +portfolio_returns = (data.close.pct_change() * weights).sum(axis=1) + +# カスタムウェイト +weights = torch.tensor([0.4, 0.3, 0.3]) # AAPL, GOOG, MSFT +portfolio_returns = (data.close.pct_change() * weights).sum(axis=1) + +# ポートフォリオ累積リターン +cumulative = (1 + portfolio_returns).cumprod() +``` + +### 相関分析 + +```python +# リターンを計算 +returns = data.close.pct_change() + +# 相関のために numpy に変換 +returns_array = returns.dropna().array +import numpy as np +corr_matrix = np.corrcoef(returns_array.T) +print(pd.DataFrame(corr_matrix, index=data.symbols, columns=data.symbols)) +``` + +--- + +## データ変換 + +### 正規化 + +```python +# 時間方向の Z スコア正規化 +mean = data.close.mean(axis=0) +std = data.close.std(axis=0) +normalized = (data.close - mean) / std + +# Min-Max 正規化 +min_val = data.close.min(axis=0) +max_val = data.close.max(axis=0) +scaled = (data.close - min_val) / (max_val - min_val) +``` + +### 特徴量作成 + +```python +def create_features(data): + """一般的なテクニカル特徴量を作成""" + features = [] + + # リターン + features.append(data.close.pct_change().rename("return_1d")) + features.append(data.close.pct_change(3).rename("return_3d")) + + # 移動平均(ウィンドウサイズ <= データ長) + ma_3 = data.close.moving_average(3) + ma_5 = data.close.moving_average(5) + features.append((data.close / ma_3 - 1).rename("close_ma3_ratio")) + features.append((data.close / ma_5 - 1).rename("close_ma5_ratio")) + features.append((ma_3 / ma_5 - 1).rename("ma3_ma5_ratio")) + + # 出来高比率 + if "volume" in data.columns: + vol_ma = data.volume.moving_average(5) + features.append((data.volume / vol_ma).rename("volume_ratio")) + + # すべての特徴量をマージ + result = features[0] + for f in features[1:]: + result = result.merge_columns(f) + return result + +features = create_features(data) +``` + +### データソースのマージ + +```python +# 複数のデータソースをマージ +prices = Data.from_csv("prices.csv") +fundamentals = Data.from_csv("fundamentals.csv") + +# 同じタイムスタンプ/シンボル - カラムをマージ +combined = prices.merge_columns(fundamentals) + +# 異なるタイムスタンプ/シンボル - 和集合マージ +combined = prices.merge(fundamentals) +``` + +--- + +## 可視化 + +### 基本プロット + +```python +import matplotlib.pyplot as plt +from qfeval_data import Data + +prices = Data.from_csv("prices.csv") +aapl = prices[:, "AAPL"] + +# OHLC データにはローソク足 +aapl.candlestick() +plt.title("AAPL") +plt.close() +``` + +### ローソク足チャート + +```python +# 明示的なローソク足(aapl は前の例で作成済み) +aapl.candlestick() +plt.title("AAPL ローソク足") +plt.close() + +# カスタムカラー +aapl.candlestick( + upcolor="#00ff00", + downcolor="#ff0000", + width=0.8 +) +plt.close() +``` + +### 折れ線グラフ + +```python +# 単一系列 +aapl.close.line() +plt.title("AAPL 終値") +plt.close() + +# 複数系列 +fig, ax = plt.subplots() +aapl.close.line(ax=ax, label="Close") +aapl.close.moving_average(5).line(ax=ax, label="MA5") +plt.legend() +plt.close() +``` + +### テクニカル指標 + +```python +# 移動平均オーバーレイ +fig, ax = plt.subplots() +aapl.candlestick(ax=ax) +aapl.close.plot_moving_average(window=5, ax=ax, color="blue") +plt.close() + +# ボリンジャーバンド +fig, ax = plt.subplots() +aapl.candlestick(ax=ax) +aapl.close.plot_bollinger_band(window=5, ax=ax) +plt.close() +``` + +### 複数サブプロット + +```python +fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True) + +# ボリンジャーバンド付き価格 +aapl.candlestick(ax=axes[0]) +aapl.close.plot_bollinger_band(window=5, ax=axes[0]) +axes[0].set_title("価格") + +# 出来高 +aapl.volume.bar(ax=axes[1]) +axes[1].set_title("出来高") + +# リターン +aapl.close.pct_change().line(ax=axes[2]) +axes[2].set_title("日次リターン") + +plt.tight_layout() +plt.close() +``` + +--- + +## 機械学習との統合 + +### PyTorch 用データ準備 + +```python +import torch +from qfeval_data import Data, Flattener + +# data はセットアップで作成済み + +# 特徴量とターゲットを作成(Flattener 用に単一カラム) +feature = data.close +target = data.close.pct_change() # 日次リターン + +# アライメント用の Flattener を作成 +flattener = Flattener(feature, target) + +# テンソルに変換 +X = flattener.flatten(feature).unsqueeze(-1) # 形状: (B, 1) +y = flattener.flatten(target) # 形状: (B,) + +print(f"特徴量形状: {X.shape}") +print(f"ターゲット形状: {y.shape}") +``` + +### 訓練ループ + +```python +import torch.nn as nn +import torch.optim as optim + +# シンプルなモデル +class Model(nn.Module): + def __init__(self, input_dim): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(input_dim, 64), + nn.ReLU(), + nn.Linear(64, 32), + nn.ReLU(), + nn.Linear(32, 1) + ) + + def forward(self, x): + return self.layers(x).squeeze(-1) + +model = Model(X.shape[1]) +optimizer = optim.Adam(model.parameters(), lr=0.001) +criterion = nn.MSELoss() + +# 訓練(例のため短いループ) +for epoch in range(10): + optimizer.zero_grad() + pred = model(X) + loss = criterion(pred, y) + loss.backward() + optimizer.step() +``` + +### 予測の作成 + +```python +# 予測を作成 +model.eval() +with torch.no_grad(): + predictions = model(X) + +# Data 形式に戻す +pred_data = flattener.unflatten(predictions, "prediction") + +# 予測形状を確認 +print(f"予測形状: {pred_data.shape}") +``` + +### 時系列分割 + +```python +# 時間で分割(サンプルデータの日付を使用) +split_date = "2024-01-08" +train_data = data[:split_date, :] +test_data = data[split_date:, :] + +print(f"訓練: {train_data.shape}, テスト: {test_data.shape}") +``` + +--- + +## 複数シンボルの操作 + +### クロスセクション操作 + +```python +# 各タイムスタンプ内でシンボル間のランク付け +def rank_cross_section(data): + """各タイムスタンプでシンボル間の値をランク付け""" + return data.apply( + lambda x: x.argsort(dim=1).argsort(dim=1).float() / (x.shape[1] - 1) + ) + +ranked = rank_cross_section(data.close.pct_change()) +``` + +### セクター分析 + +```python +# セクターマッピングがあると仮定 +sector_map = {"AAPL": "Tech", "GOOG": "Tech", "JPM": "Finance", "XOM": "Energy"} +sectors = [sector_map.get(s, "Other") for s in data.symbols] + +# セクターでグループ化 +tech_symbols = [s for s, sec in zip(data.symbols, sectors) if sec == "Tech"] +tech_data = data[:, tech_symbols] + +# セクター平均 +tech_avg = tech_data.close.mean(axis=1).rename("tech_avg") +``` + +### ユニバースフィルタリング + +```python +# 流動性でフィルタ(サンプルデータ用に閾値を調整) +avg_volume = data.volume.mean(axis=0) +liquid_mask = avg_volume.tensor > 900000 +liquid_symbols = data.symbols[liquid_mask.cpu().numpy()] +liquid_data = data[:, liquid_symbols.tolist()] + +# 価格でフィルタ +avg_price = data.close.mean(axis=0) +valid_mask = (avg_price.tensor > 5) & (avg_price.tensor < 1000) +valid_symbols = data.symbols[valid_mask.cpu().numpy()] +``` + +### ペアトレーディング + +```python +# 2銘柄間のスプレッドを計算 +spread = data[:, "AAPL"].close - data[:, "GOOG"].close + +# スプレッドを正規化 +spread_mean = spread.mean(axis=0) +spread_std = spread.std(axis=0) +zscore = (spread - spread_mean) / spread_std + +# シグナル生成 +long_signal = zscore < -2 # AAPL 買い、GOOG 売り +short_signal = zscore > 2 # AAPL 売り、GOOG 買い +``` diff --git a/docs/examples.md b/docs/examples.md new file mode 100644 index 0000000..27a2758 --- /dev/null +++ b/docs/examples.md @@ -0,0 +1,617 @@ +# Examples and Recipes + +This document provides practical examples and common usage patterns for qfeval-data. + + + +## Table of Contents + +1. [Loading Data](#loading-data) +2. [Basic Operations](#basic-operations) +3. [Time Series Analysis](#time-series-analysis) +4. [Portfolio Analysis](#portfolio-analysis) +5. [Data Transformation](#data-transformation) +6. [Visualization](#visualization) +7. [Machine Learning Integration](#machine-learning-integration) +8. [Working with Multiple Symbols](#working-with-multiple-symbols) + +--- + +## Loading Data + +### From CSV File + +```python +from qfeval_data import Data + +# Basic loading +data = Data.from_csv("prices.csv") + +# With specific dtype +data = Data.from_csv("prices.csv", dtype=torch.float32) +``` + +**Expected CSV format:** + +```csv +timestamp,symbol,open,high,low,close,volume +2024-01-02,AAPL,185.5,186.2,184.1,185.8,50000000 +2024-01-02,GOOG,140.0,141.5,139.5,141.0,20000000 +2024-01-03,AAPL,186.0,187.5,185.0,186.5,48000000 +2024-01-03,GOOG,141.0,142.0,140.0,141.5,19000000 +``` + +### From pandas DataFrame + +```python +import pandas as pd +from qfeval_data import Data + +# Create sample data +df = pd.DataFrame({ + "timestamp": pd.date_range("2024-01-01", periods=10, freq="D").repeat(2), + "symbol": ["AAPL", "GOOG"] * 10, + "close": [150 + i * 0.5 + (0 if i % 2 == 0 else 10) for i in range(20)], +}) + +data = Data.from_dataframe(df) +print(data.shape) # (10, 2) +``` + +### From Raw Tensors + +```python +import torch +import numpy as np +from qfeval_data import Data + +# Create tensors +timestamps = np.array(["2024-01-01", "2024-01-02", "2024-01-03"], dtype="datetime64[D]") +symbols = np.array(["AAPL", "GOOG", "MSFT"]) +prices = torch.randn(3, 3) * 10 + 100 # 3 timestamps x 3 symbols + +data = Data.from_tensors({"close": prices}, timestamps, symbols) +``` + +### Multi-dimensional Data + +```python +# Embedding vectors per timestamp/symbol +df = pd.DataFrame({ + "timestamp": ["2024-01-01", "2024-01-01"], + "symbol": ["AAPL", "GOOG"], + "embedding[0]": [0.1, 0.2], + "embedding[1]": [0.3, 0.4], + "embedding[2]": [0.5, 0.6], +}) +data = Data.from_dataframe(df) +print(data.embedding.tensor.shape) # (1, 2, 3) +``` + +--- + +## Basic Operations + +### Accessing Data + +```python +from qfeval_data import Data + +# data is created in setup + +# Get specific columns +closes = data.close # Attribute access +closes = data.get("close") # Method access +ohlc = data.get("open", "high", "low", "close") + +# Slice by time +first_week = data[:5, :] # First 5 timestamps +jan_data = data["2024-01-01":"2024-01-12", :] + +# Slice by symbol +apple = data[:, "AAPL"] # Single symbol +tech = data[:, ["AAPL", "GOOG", "MSFT"]] # Multiple symbols + +# Combined slicing +apple_jan = data["2024-01-01":"2024-01-12", "AAPL"] +``` + +### Arithmetic + +```python +# Returns +returns = data.close.pct_change() + +# Log returns +log_returns = (data.close / data.close.shift(1)).apply(torch.log) + +# Spread +spread = data.high - data.low + +# Custom calculations +typical_price = (data.high + data.low + data.close) / 3 +``` + +### Filtering + +```python +# Boolean filtering +up_days = data[data.close > data.open] # Non-matching become NaN + +# Drop missing values +clean = data.dropna() + +# Fill missing values +filled = data.fillna(method="ffill") +``` + +--- + +## Time Series Analysis + +### Rolling Calculations + +```python +# Moving average (window size <= data length) +ma_5 = data.close.moving_average(window=5) + +# Bollinger Bands +upper, middle, lower = data.close.bollinger_band(window=5, sigma=2.0) +``` + +### Lagged Features + +```python +# Previous values +prev_close = data.close.shift(1) +prev_5_close = data.close.shift(5) + +# Future values (for targets) +next_close = data.close.shift(-1) +next_return = data.close.shift(-1).pct_change() +``` + +### Resampling + +```python +# Daily data from tick data (no-op if already daily) +daily = tick_data.daily() + +# Weekly OHLCV +weekly = daily.weekly() + +# Monthly with timezone offset +monthly = daily.monthly(offset=np.timedelta64(9, "h")) + +# Custom interval (2-day bars for daily data) +bars_2d = data.downsample(np.timedelta64(2, "D")) +``` + +--- + +## Portfolio Analysis + +### Single Stock Metrics + +```python +# Get metrics for a single stock +apple = data[:, "AAPL"] +metrics = apple.close.metrics() +print(metrics.to_dataframe()) +``` + +### Cross-sectional Analysis + +```python +# Compare metrics across all stocks +all_metrics = data.close.metrics() + +# Find best Sharpe ratio +sharpe = all_metrics.get("annualized_sharpe_ratio") +best_idx = sharpe.tensor.argmax() +best_symbol = data.symbols[best_idx] +print(f"Best Sharpe: {best_symbol}") +``` + +### Portfolio Returns + +```python +import torch + +# Equal-weighted portfolio +weights = torch.ones(data.shape[1]) / data.shape[1] +portfolio_returns = (data.close.pct_change() * weights).sum(axis=1) + +# Custom weights +weights = torch.tensor([0.4, 0.3, 0.3]) # AAPL, GOOG, MSFT +portfolio_returns = (data.close.pct_change() * weights).sum(axis=1) + +# Portfolio cumulative return +cumulative = (1 + portfolio_returns).cumprod() +``` + +### Correlation Analysis + +```python +# Calculate returns +returns = data.close.pct_change() + +# Convert to numpy for correlation +returns_array = returns.dropna().array +import numpy as np +corr_matrix = np.corrcoef(returns_array.T) +print(pd.DataFrame(corr_matrix, index=data.symbols, columns=data.symbols)) +``` + +--- + +## Data Transformation + +### Normalization + +```python +# Z-score normalization across time +mean = data.close.mean(axis=0) +std = data.close.std(axis=0) +normalized = (data.close - mean) / std + +# Min-max normalization +min_val = data.close.min(axis=0) +max_val = data.close.max(axis=0) +scaled = (data.close - min_val) / (max_val - min_val) +``` + +### Creating Features + +```python +def create_features(data): + """Create common technical features.""" + features = [] + + # Returns + features.append(data.close.pct_change().rename("return_1d")) + features.append(data.close.pct_change(3).rename("return_3d")) + + # Moving averages (window size <= data length) + ma_3 = data.close.moving_average(3) + ma_5 = data.close.moving_average(5) + features.append((data.close / ma_3 - 1).rename("close_ma3_ratio")) + features.append((data.close / ma_5 - 1).rename("close_ma5_ratio")) + features.append((ma_3 / ma_5 - 1).rename("ma3_ma5_ratio")) + + # Volume ratio + if "volume" in data.columns: + vol_ma = data.volume.moving_average(5) + features.append((data.volume / vol_ma).rename("volume_ratio")) + + # Merge all features + result = features[0] + for f in features[1:]: + result = result.merge_columns(f) + return result + +features = create_features(data) +``` + +### Merging Data Sources + +```python +# Merge multiple data sources +prices = Data.from_csv("prices.csv") +fundamentals = Data.from_csv("fundamentals.csv") + +# Same timestamps/symbols - merge columns +combined = prices.merge_columns(fundamentals) + +# Different timestamps/symbols - union merge +combined = prices.merge(fundamentals) +``` + +--- + +## Visualization + +### Basic Plots + +```python +import matplotlib.pyplot as plt +from qfeval_data import Data + +prices = Data.from_csv("prices.csv") +aapl = prices[:, "AAPL"] + +# Candlestick for OHLC data +aapl.candlestick() +plt.title("AAPL") +plt.close() +``` + +### Candlestick Chart + +```python +# Explicit candlestick (using aapl from previous example) +aapl.candlestick() +plt.title("AAPL Candlestick") +plt.close() + +# Custom colors +aapl.candlestick( + upcolor="#00ff00", + downcolor="#ff0000", + width=0.8 +) +plt.close() +``` + +### Line Plots + +```python +# Single series +aapl.close.line() +plt.title("AAPL Close Price") +plt.close() + +# Multiple series +fig, ax = plt.subplots() +aapl.close.line(ax=ax, label="Close") +aapl.close.moving_average(5).line(ax=ax, label="MA5") +plt.legend() +plt.close() +``` + +### Technical Indicators + +```python +# Moving average overlay +fig, ax = plt.subplots() +aapl.candlestick(ax=ax) +aapl.close.plot_moving_average(window=5, ax=ax, color="blue") +plt.close() + +# Bollinger Bands +fig, ax = plt.subplots() +aapl.candlestick(ax=ax) +aapl.close.plot_bollinger_band(window=5, ax=ax) +plt.close() +``` + +### Multiple Subplots + +```python +fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True) + +# Price with Bollinger Bands +aapl.candlestick(ax=axes[0]) +aapl.close.plot_bollinger_band(window=5, ax=axes[0]) +axes[0].set_title("Price") + +# Volume +aapl.volume.bar(ax=axes[1]) +axes[1].set_title("Volume") + +# Returns +aapl.close.pct_change().line(ax=axes[2]) +axes[2].set_title("Daily Returns") + +plt.tight_layout() +plt.close() +``` + +--- + +## Machine Learning Integration + +### Preparing Data for PyTorch + +```python +import torch +from qfeval_data import Data, Flattener + +# data is created in setup + +# Create feature and target (single column for Flattener) +feature = data.close +target = data.close.pct_change() # Daily return + +# Create flattener for alignment +flattener = Flattener(feature, target) + +# Convert to tensors +X = flattener.flatten(feature).unsqueeze(-1) # shape: (B, 1) +y = flattener.flatten(target) # shape: (B,) + +print(f"Features shape: {X.shape}") +print(f"Target shape: {y.shape}") +``` + +### Training Loop + +```python +import torch.nn as nn +import torch.optim as optim + +# Simple model +class Model(nn.Module): + def __init__(self, input_dim): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(input_dim, 64), + nn.ReLU(), + nn.Linear(64, 32), + nn.ReLU(), + nn.Linear(32, 1) + ) + + def forward(self, x): + return self.layers(x).squeeze(-1) + +model = Model(X.shape[1]) +optimizer = optim.Adam(model.parameters(), lr=0.001) +criterion = nn.MSELoss() + +# Training (short loop for example) +for epoch in range(10): + optimizer.zero_grad() + pred = model(X) + loss = criterion(pred, y) + loss.backward() + optimizer.step() +``` + +### Making Predictions + +```python +# Make predictions +model.eval() +with torch.no_grad(): + predictions = model(X) + +# Convert back to Data format +pred_data = flattener.unflatten(predictions, "prediction") + +# View predictions shape +print(f"Predictions shape: {pred_data.shape}") +``` + +### Time Series Split + +```python +# Split by time (using dates in the sample data) +split_date = "2024-01-08" +train_data = data[:split_date, :] +test_data = data[split_date:, :] + +print(f"Train: {train_data.shape}, Test: {test_data.shape}") +``` + +--- + +## Working with Multiple Symbols + +### Cross-sectional Operations + +```python +# Rank within each timestamp +def rank_cross_section(d): + """Rank values across symbols for each timestamp.""" + return d.apply( + lambda x: x.argsort(dim=1).argsort(dim=1).float() / (x.shape[1] - 1) + ) + +ranked = rank_cross_section(data.close.pct_change()) +``` + +### Sector Analysis + +```python +# Assuming you have sector mapping +sector_map = {"AAPL": "Tech", "GOOG": "Tech", "MSFT": "Tech"} +sectors = [sector_map.get(s, "Other") for s in data.symbols] + +# Group by sector +tech_symbols = [s for s, sec in zip(data.symbols, sectors) if sec == "Tech"] +tech_data = data[:, tech_symbols] + +# Sector average +tech_avg = tech_data.close.mean(axis=1).rename("tech_avg") +``` + +### Universe Filtering + +```python +# Filter by liquidity (threshold adjusted for sample data) +avg_volume = data.volume.mean(axis=0) +liquid_mask = avg_volume.tensor > 900000 +liquid_symbols = data.symbols[liquid_mask.cpu().numpy()] +liquid_data = data[:, liquid_symbols.tolist()] + +# Filter by price +avg_price = data.close.mean(axis=0) +valid_mask = (avg_price.tensor > 5) & (avg_price.tensor < 1000) +valid_symbols = data.symbols[valid_mask.cpu().numpy()] +``` + +### Pair Trading + +```python +# Calculate spread between two stocks +spread = data[:, "AAPL"].close - data[:, "GOOG"].close + +# Normalize spread +spread_mean = spread.mean(axis=0) +spread_std = spread.std(axis=0) +zscore = (spread - spread_mean) / spread_std + +# Generate signals +long_signal = zscore < -2 # Buy AAPL, sell GOOG +short_signal = zscore > 2 # Sell AAPL, buy GOOG +``` diff --git a/docs/flattener.ja.md b/docs/flattener.ja.md new file mode 100644 index 0000000..f8c0545 --- /dev/null +++ b/docs/flattener.ja.md @@ -0,0 +1,245 @@ +# Flattener クラスリファレンス + +`Flattener` クラスは、`Data` オブジェクト(タイムスタンプ/シンボルインデックス付き)とフラットな `torch.Tensor` オブジェクト(単一のバッチインデックス付き)間の変換を支援します。 + + + +## 概要 + +```python +from qfeval_data import Flattener +``` + +Flattener は以下のような場合に便利です: +- 金融データを機械学習モデル用のバッチ形式に変換 +- フラットなテンソル表現での作業 +- モデル出力をタイムスタンプ/シンボルインデックス形式に戻す変換 + +## コンストラクタ + +### `Flattener(*data)` + +1つ以上の Data オブジェクトから Flattener を作成します。 + +**パラメータ:** +- `*data` (`Data`): フラット化マスクを定義する1つ以上の Data オブジェクト + +**動作:** +- 有効な(非 NaN の)タイムスタンプ/シンボルペアのマスクを作成 +- すべての入力 Data オブジェクトは同じタイムスタンプとシンボルを持つ必要あり +- すべての入力 Data において NaN 値がないペアのみが有効と見なされる + +**例:** +```python +from qfeval_data import Data, Flattener + +# data はセットアップで作成済み +flattener = Flattener(data) +``` + +**複数の Data オブジェクトの場合:** +```python +# Flattener は両方のデータセットで有効なペアのみを含む +flattener = Flattener(prices, features) +``` + +--- + +## メソッド + +### `flattener.flatten(data)` + +Data オブジェクトをフラットなテンソルに変換します。 + +**パラメータ:** +- `data` (`Data`): フラット化する Data オブジェクト(コンストラクタ入力と同じタイムスタンプ/シンボルを持つ必要あり) + +**戻り値:** 形状 `(batch_size, *extra_dims)` の `torch.Tensor` + +**形状変換:** +- 入力 Data 形状: `(T, S, *extra_dims)` ここで T=タイムスタンプ数、S=シンボル数 +- 出力テンソル形状: `(B, *extra_dims)` ここで B=有効ペア数 + +**例:** +```python +# data と flattener はセットアップで作成済み +flat_tensor = flattener.flatten(data.close) +print(flat_tensor.shape) # (B,) ここで B = 有効なタイムスタンプ/シンボルペアの数 +``` + +**注意:** +- 有効な(非 NaN の)ペアのみが出力に含まれる +- 要素の順序は行優先(タイムスタンプが最も遅く変化) + +--- + +### `flattener.unflatten(tensor, name="")` + +フラットなテンソルを Data オブジェクトに戻します。 + +**パラメータ:** +- `tensor` (`torch.Tensor`): 形状 `(batch_size, *extra_dims)` のフラットテンソル +- `name` (`str`): 返される Data オブジェクトのカラム名 + +**戻り値:** 元のタイムスタンプ/シンボルに一致する形状の `Data` + +**形状変換:** +- 入力テンソル形状: `(B, *extra_dims)` +- 出力 Data 形状: `(T, S, *extra_dims)` + +**例:** +```python +# 処理後(flat_tensor を処理してシミュレート) +output_tensor = flat_tensor * 2 # 形状: (B,) + +# Data に戻す +predictions = flattener.unflatten(output_tensor, name="prediction") +print(predictions.shape) # (T, S) +``` + +**注意:** +- 無効なペア(フラット化時にマスクされたもの)は NaN で補完 +- テンソルのバッチサイズは構築時の有効ペア数と一致する必要あり + +--- + +### `flattener.timestamp_indexes()` + +フラット化された表現の各要素のタイムスタンプインデックスを取得します。 + +**戻り値:** 形状 `(batch_size,)` の `torch.Tensor` + +**例:** +```python +ts_idx = flattener.timestamp_indexes() +# ts_idx[i] = フラット化されたテンソルの i 番目の要素のタイムスタンプインデックス +``` + +--- + +### `flattener.symbol_indexes()` + +フラット化された表現の各要素のシンボルインデックスを取得します。 + +**戻り値:** 形状 `(batch_size,)` の `torch.Tensor` + +**例:** +```python +sym_idx = flattener.symbol_indexes() +# sym_idx[i] = フラット化されたテンソルの i 番目の要素のシンボルインデックス +``` + +--- + +## 完全な例 + +```python +import torch +from qfeval_data import Data, Flattener + +# data はセットアップで作成済み +print(f"元の形状: {data.shape}") # (4, 2) + +# Flattener 作成 +flattener = Flattener(data) + +# 終値をフラット化 +prices = flattener.flatten(data.close) +print(f"フラット化後の形状: {prices.shape}") # (8,) + +# 何らかの処理 +log_prices = torch.log(prices) + +# Data に戻す +result = flattener.unflatten(log_prices, "log_price") +print(f"結果の形状: {result.shape}") # (4, 2) + +# インデックスマッピングを取得 +ts_idx = flattener.timestamp_indexes() +sym_idx = flattener.symbol_indexes() +print(f"最初の要素: timestamp={ts_idx[0].item()}, symbol={sym_idx[0].item()}") +``` + +--- + +## 機械学習ワークフロー + +```python +import torch +import torch.nn as nn +from qfeval_data import Data, Flattener + +# data はセットアップで作成済み +feature = data.close # シンプルのため単一特徴量 +target = data.close.pct_change() # 日次リターン + +# 両方から Flattener を作成(アライメントを保証) +flattener = Flattener(feature, target) + +# 訓練用にフラット化 +X = flattener.flatten(feature).unsqueeze(-1) # 形状: (B, 1) +y = flattener.flatten(target) # 形状: (B,) + +# モデル訓練 +model = nn.Linear(1, 1) +optimizer = torch.optim.Adam(model.parameters()) + +for epoch in range(10): # 例のため短い訓練 + pred = model(X).squeeze() + loss = ((pred - y) ** 2).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + +# 予測を作成してアンフラット化 +with torch.no_grad(): + predictions = model(X).squeeze() + pred_data = flattener.unflatten(predictions, "prediction") + +# pred_data は元データと同じタイムスタンプ/シンボル構造を持つ +print(pred_data.shape) +``` + +--- + +## 注意事項 + +1. **メモリ効率**: フラット化は有効な要素をコピーして連続テンソルを作成します(ビューではない) +2. **NaN 処理**: 非 NaN の要素のみがフラット化された表現に含まれる +3. **デバイス一貫性**: Flattener は入力 Data と同じデバイスで動作 +4. **バッチ次元**: フラット化されたテンソルはタイムスタンプとシンボルの次元を単一のバッチ次元に結合 diff --git a/docs/flattener.md b/docs/flattener.md new file mode 100644 index 0000000..70cbef7 --- /dev/null +++ b/docs/flattener.md @@ -0,0 +1,237 @@ +# Flattener Class Reference + +The `Flattener` class assists conversion between `Data` objects (with timestamp/symbol indices) and flat `torch.Tensor` objects (with a single batch index). + + + +## Overview + +```python +from qfeval_data import Flattener +``` + +The Flattener is useful when you need to: +- Convert financial data to batch format for machine learning models +- Work with flat tensor representations +- Convert model outputs back to timestamp/symbol indexed format + +## Constructor + +### `Flattener(*data)` + +Create a Flattener from one or more Data objects. + +**Parameters:** +- `*data` (`Data`): One or more Data objects that define the flattening mask + +**Behavior:** +- Creates a mask of valid (non-NaN) timestamp/symbol pairs +- All input Data objects must have the same timestamps and symbols +- A pair is considered valid if it has no NaN values across all input Data + +**Example:** +```python +from qfeval_data import Data, Flattener + +# data is created in setup +flattener = Flattener(data) +``` + +**With multiple Data objects:** +```python +# Flattener will only include pairs valid in BOTH datasets +flattener = Flattener(prices, features) +``` + +--- + +## Methods + +### `flattener.flatten(data)` + +Convert a Data object to a flat tensor. + +**Parameters:** +- `data` (`Data`): Data object to flatten (must have same timestamps/symbols as constructor input) + +**Returns:** `torch.Tensor` with shape `(batch_size, *extra_dims)` + +**Shape transformation:** +- Input Data shape: `(T, S, *extra_dims)` where T=timestamps, S=symbols +- Output tensor shape: `(B, *extra_dims)` where B=number of valid pairs + +**Example:** +```python +# data and flattener are created in setup +flat_tensor = flattener.flatten(data.close) +print(flat_tensor.shape) # (B,) where B = number of valid timestamp/symbol pairs +``` + +**Notes:** +- Only valid (non-NaN) pairs are included in the output +- The order of elements follows row-major order (timestamp varies slowest) + +--- + +### `flattener.unflatten(tensor, name="")` + +Convert a flat tensor back to a Data object. + +**Parameters:** +- `tensor` (`torch.Tensor`): Flat tensor with shape `(batch_size, *extra_dims)` +- `name` (`str`): Column name for the returned Data object + +**Returns:** `Data` with shape matching the original timestamps/symbols + +**Shape transformation:** +- Input tensor shape: `(B, *extra_dims)` +- Output Data shape: `(T, S, *extra_dims)` + +**Example:** +```python +# After processing (simulate model output) +output_tensor = flat_tensor * 2 # shape: (B,) + +# Convert back to Data +predictions = flattener.unflatten(output_tensor, name="prediction") +print(predictions.shape) # (T, S) +``` + +**Notes:** +- Invalid pairs (those that were masked during flattening) are filled with NaN +- Tensor batch size must match the number of valid pairs from construction + +--- + +### `flattener.timestamp_indexes()` + +Get the timestamp index for each element in the flattened representation. + +**Returns:** `torch.Tensor` with shape `(batch_size,)` + +**Example:** +```python +ts_idx = flattener.timestamp_indexes() +# ts_idx[i] = timestamp index of the i-th element in flattened tensor +``` + +--- + +### `flattener.symbol_indexes()` + +Get the symbol index for each element in the flattened representation. + +**Returns:** `torch.Tensor` with shape `(batch_size,)` + +**Example:** +```python +sym_idx = flattener.symbol_indexes() +# sym_idx[i] = symbol index of the i-th element in flattened tensor +``` + +--- + +## Complete Example + +```python +import torch +from qfeval_data import Data, Flattener + +# data is created in setup +print(f"Original shape: {data.shape}") # (4, 2) + +# Create flattener +flattener = Flattener(data) + +# Flatten closing prices +prices = flattener.flatten(data.close) +print(f"Flattened shape: {prices.shape}") # (8,) + +# Do some processing +log_prices = torch.log(prices) + +# Unflatten back to Data +result = flattener.unflatten(log_prices, "log_price") +print(f"Result shape: {result.shape}") # (4, 2) + +# Get index mapping +ts_idx = flattener.timestamp_indexes() +sym_idx = flattener.symbol_indexes() +print(f"First element: timestamp={ts_idx[0].item()}, symbol={sym_idx[0].item()}") +``` + +--- + +## Machine Learning Workflow + +```python +import torch +import torch.nn as nn +from qfeval_data import Data, Flattener + +# data is created in setup +feature = data.close # Single feature for simplicity +target = data.close.pct_change() # Daily return + +# Create flattener from both (ensures alignment) +flattener = Flattener(feature, target) + +# Flatten for training +X = flattener.flatten(feature).unsqueeze(-1) # shape: (B, 1) +y = flattener.flatten(target) # shape: (B,) + +# Train model +model = nn.Linear(1, 1) +optimizer = torch.optim.Adam(model.parameters()) + +for epoch in range(10): # Short training for example + pred = model(X).squeeze() + loss = ((pred - y) ** 2).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + +# Make predictions and unflatten +with torch.no_grad(): + predictions = model(X).squeeze() + pred_data = flattener.unflatten(predictions, "prediction") + +# Now pred_data has the same timestamp/symbol structure as original data +print(pred_data.shape) +``` + +--- + +## Notes + +1. **Memory Efficiency**: Flattening creates a contiguous tensor by copying valid elements, not a view +2. **NaN Handling**: Only non-NaN elements are included in the flattened representation +3. **Device Consistency**: The flattener operates on the same device as the input Data +4. **Batch Dimension**: The flattened tensor combines timestamp and symbol dimensions into a single batch dimension diff --git a/docs/util.ja.md b/docs/util.ja.md new file mode 100644 index 0000000..53d69ea --- /dev/null +++ b/docs/util.ja.md @@ -0,0 +1,303 @@ +# ユーティリティ関数リファレンス + +`qfeval_data.util` モジュールは、配列操作、時間計算、その他のユーティリティのためのヘルパー関数を提供します。 + + + +## 概要 + +```python +from qfeval_data import util +``` + +--- + +## 配列操作 + +### `util.to_numpy(tensor)` + +PyTorch テンソルを NumPy 配列に変換します。 + +**パラメータ:** +- `tensor` (`torch.Tensor`): PyTorch テンソル(GPU 上でも勾配があっても可) + +**戻り値:** `np.ndarray` + +**例:** +```python +import torch +from qfeval_data import util + +tensor = torch.tensor([1.0, 2.0, 3.0]) +array = util.to_numpy(tensor) +print(type(array)) # +``` + +**注意:** +- 計算グラフから自動的にデタッチ +- 必要に応じて GPU から CPU に自動的に移動 + +--- + +### `util.nans(shape=None, like=None)` + +NaN 値で埋められたテンソルを作成します。 + +**パラメータ:** +- `shape` (`Tuple[int, ...]`, 省略可): 出力テンソルの形状 +- `like` (`torch.Tensor`): dtype と device の参照テンソル + +**戻り値:** `torch.Tensor` + +**例:** +```python +import torch +from qfeval_data import util + +ref = torch.tensor([1.0, 2.0]) +nans = util.nans((3, 4), like=ref) +print(nans.shape) # torch.Size([3, 4]) +print(nans.device) # cpu +``` + +**注意:** +- `like` パラメータは必須 +- `shape` が None の場合、`like` の形状を使用 + +--- + +### `util.make_array_mapping(ref, like)` + +2つのソート済み配列間のインデックスマッピングを作成します。 + +**パラメータ:** +- `ref` (`np.ndarray`): 参照配列(ソート済み) +- `like` (`np.ndarray`): マッピング元の配列(ソート済み) + +**戻り値:** `Tuple[np.ndarray, np.ndarray]` +- 最初の配列: `ref[indices[i]] == like[i]` となるインデックス +- 2番目の配列: マッピングが無効な場合に True となるブールマスク + +**例:** +```python +import numpy as np +from qfeval_data import util + +ref = np.array(["A", "B", "C", "D"]) +like = np.array(["B", "D", "E"]) + +indexes, mask = util.make_array_mapping(ref, like) +print(indexes) # [1, 3, 0] (E のインデックスは 0 だがマスクされる) +print(mask) # [False, False, True] (E は ref にない) +``` + +**ユースケース:** 異なるシンボルを持つ異なるソースからのデータを揃える。 + +--- + +### `util.are_broadcastable_shapes(*shapes)` + +形状が NumPy ライクな演算でブロードキャスト可能かチェックします。 + +**パラメータ:** +- `*shapes` (`Tuple[int, ...]` または `torch.Size`): チェックする形状 + +**戻り値:** `bool` + +**例:** +```python +from qfeval_data import util + +print(util.are_broadcastable_shapes((3, 4), (4,))) # True +print(util.are_broadcastable_shapes((3, 4), (3, 1))) # True +print(util.are_broadcastable_shapes((3, 4), (2, 4))) # False +``` + +--- + +## 時間関数 + +### `util.floor_time(t, d, origin=None, offset=None)` + +日時を時間間隔で切り捨てます。 + +**パラメータ:** +- `t` (`np.datetime64` または `np.ndarray`): 切り捨てるタイムスタンプ +- `d` (`np.timedelta64`): 時間間隔 +- `origin` (`np.datetime64`, 省略可): 間隔計算の起点 +- `offset` (`np.timedelta64`, 省略可): 切り捨て前に適用するオフセット + +**戻り値:** `np.datetime64` または `np.ndarray` + +**例:** +```python +import numpy as np +from qfeval_data import util + +t = np.datetime64("2024-01-15T14:35:00") +d = np.timedelta64(1, "h") + +floored = util.floor_time(t, d) +print(floored) # 2024-01-15T14:00:00 +``` + +**オフセット付き(タイムゾーン調整):** +```python +# 9時間オフセットで日境界に切り捨て(JST タイムゾーン) +t = np.datetime64("2024-01-15T08:00:00") # UTC +offset = np.timedelta64(9, "h") +floored = util.floor_time(t, np.timedelta64(1, "D"), offset=offset) +# 結果は JST の日境界を考慮 +``` + +--- + +### `util.ceil_time(t, d, origin=None, offset=None)` + +日時を時間間隔で切り上げます。 + +**パラメータ:** +- `t` (`np.datetime64` または `np.ndarray`): 切り上げるタイムスタンプ +- `d` (`np.timedelta64`): 時間間隔 +- `origin` (`np.datetime64`, 省略可): 間隔計算の起点 +- `offset` (`np.timedelta64`, 省略可): 切り上げ前に適用するオフセット + +**戻り値:** `np.datetime64` または `np.ndarray` + +**例:** +```python +import numpy as np +from qfeval_data import util + +t = np.datetime64("2024-01-15T14:35:00") +d = np.timedelta64(1, "h") + +ceiled = util.ceil_time(t, d) +print(ceiled) # 2024-01-15T15:00:00 + +# すでに境界上にある場合は同じ値を返す +t2 = np.datetime64("2024-01-15T14:00:00") +print(util.ceil_time(t2, d)) # 2024-01-15T14:00:00 +``` + +--- + +### `util.time_origin(d)` + +指定された間隔のデフォルト時間起点を取得します。 + +**パラメータ:** +- `d` (`np.timedelta64`): 時間間隔 + +**戻り値:** `np.datetime64` + +**動作:** +- 月/年間隔の場合: `1000-01-01` を返す +- その他の間隔の場合: `1893-01-01` を返す(日曜日、ダウ・ジョーンズより前) + +**例:** +```python +import numpy as np +from qfeval_data import util + +print(util.time_origin(np.timedelta64(1, "D"))) # 1893-01-01 +print(util.time_origin(np.timedelta64(1, "M"))) # 1000-01-01 +``` + +**注意:** +- `1893-01-01` が選ばれた理由: + - 日曜日である(週計算に有用) + - ダウ・ジョーンズ工業株価平均(1896年)より前 + - ナノ秒精度で十分な範囲を確保 +- 週間隔は日曜日から始まる7日間を使用 + +--- + +## その他のユーティリティ + +### `util.sha1(x)` + +様々なデータ型の SHA1 ハッシュを計算します。 + +**パラメータ:** +- `x` (`bytes`, `str`, `np.ndarray`, または `torch.Tensor`): ハッシュするデータ + +**戻り値:** `str`(16進数ハッシュ) + +**例:** +```python +import numpy as np +import torch +from qfeval_data import util + +print(util.sha1("hello")) # aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d +print(util.sha1(np.array([1, 2, 3]))) # ハッシュには形状情報が含まれる +print(util.sha1(torch.tensor([1.0, 2.0]))) # テンソルでも動作 +``` + +**注意:** +- 配列/テンソルの場合、ハッシュには形状情報が含まれる +- テンソルはハッシュ前に自動的に NumPy に変換 + +--- + +### `util.gc()` + +ガベージコレクションを実行し、GPU メモリをクリアします。 + +**例:** +```python +from qfeval_data import util + +# 処理後にメモリを解放 +util.gc() +``` + +**動作:** +- Python ガベージコレクション(世代2)を実行 +- GPU が利用可能な場合、CUDA キャッシュをクリア + +--- + +### `util.torch_device(device)` + +デバイス指定を `torch.device` にパースします。 + +**パラメータ:** +- `device` (`str`, `torch.device`, または `None`): デバイス指定 + +**戻り値:** `torch.device` + +**特別な値:** +- `None`: CPU デバイスを返す +- `"auto"`: CUDA が利用可能なら CUDA、そうでなければ CPU を返す +- `"cpu"`, `"cuda"`, `"cuda:0"` など: 標準 PyTorch デバイス文字列 + +**例:** +```python +from qfeval_data import util + +print(util.torch_device(None)) # cpu +print(util.torch_device("auto")) # cuda または cpu +print(util.torch_device("cuda:0")) # cuda:0 +``` + +--- + +## 型変数 + +モジュールはジェネリック型付けのための型変数を定義しています: + +```python +import typing + +# ジェネリック型 +T = typing.TypeVar("T") + +# 配列ライク型(torch.Tensor または np.ndarray) +Array = typing.TypeVar("Array", torch.Tensor, np.ndarray) +``` diff --git a/docs/util.md b/docs/util.md new file mode 100644 index 0000000..f6c0f8d --- /dev/null +++ b/docs/util.md @@ -0,0 +1,303 @@ +# Utility Functions Reference + +The `qfeval_data.util` module provides helper functions for array operations, time calculations, and other utilities. + + + +## Overview + +```python +from qfeval_data import util +``` + +--- + +## Array Operations + +### `util.to_numpy(tensor)` + +Convert a PyTorch tensor to a NumPy array. + +**Parameters:** +- `tensor` (`torch.Tensor`): PyTorch tensor (can be on GPU or have gradients) + +**Returns:** `np.ndarray` + +**Example:** +```python +import torch +from qfeval_data import util + +tensor = torch.tensor([1.0, 2.0, 3.0]) +array = util.to_numpy(tensor) +print(type(array)) # +``` + +**Notes:** +- Automatically detaches from computation graph +- Automatically moves from GPU to CPU if necessary + +--- + +### `util.nans(shape=None, like=None)` + +Create a tensor filled with NaN values. + +**Parameters:** +- `shape` (`Tuple[int, ...]`, optional): Shape of the output tensor +- `like` (`torch.Tensor`): Reference tensor for dtype and device + +**Returns:** `torch.Tensor` + +**Example:** +```python +import torch +from qfeval_data import util + +ref = torch.tensor([1.0, 2.0]) +nans = util.nans((3, 4), like=ref) +print(nans.shape) # torch.Size([3, 4]) +print(nans.device) # cpu +``` + +**Notes:** +- `like` parameter is required +- If `shape` is None, uses the shape of `like` + +--- + +### `util.make_array_mapping(ref, like)` + +Create index mapping between two sorted arrays. + +**Parameters:** +- `ref` (`np.ndarray`): Reference array (sorted) +- `like` (`np.ndarray`): Array to map from (sorted) + +**Returns:** `Tuple[np.ndarray, np.ndarray]` +- First array: indices where `ref[indices[i]] == like[i]` +- Second array: boolean mask where True indicates the mapping is invalid + +**Example:** +```python +import numpy as np +from qfeval_data import util + +ref = np.array(["A", "B", "C", "D"]) +like = np.array(["B", "D", "E"]) + +indexes, mask = util.make_array_mapping(ref, like) +print(indexes) # [1, 3, 0] (index for E is 0, but masked) +print(mask) # [False, False, True] (E not in ref) +``` + +**Use case:** Aligning data from different sources with different symbols. + +--- + +### `util.are_broadcastable_shapes(*shapes)` + +Check if shapes are broadcastable for NumPy-like operations. + +**Parameters:** +- `*shapes` (`Tuple[int, ...]` or `torch.Size`): Shapes to check + +**Returns:** `bool` + +**Example:** +```python +from qfeval_data import util + +print(util.are_broadcastable_shapes((3, 4), (4,))) # True +print(util.are_broadcastable_shapes((3, 4), (3, 1))) # True +print(util.are_broadcastable_shapes((3, 4), (2, 4))) # False +``` + +--- + +## Time Functions + +### `util.floor_time(t, d, origin=None, offset=None)` + +Floor datetime to a time interval. + +**Parameters:** +- `t` (`np.datetime64` or `np.ndarray`): Timestamp(s) to floor +- `d` (`np.timedelta64`): Time interval +- `origin` (`np.datetime64`, optional): Origin for interval calculation +- `offset` (`np.timedelta64`, optional): Offset to apply before flooring + +**Returns:** `np.datetime64` or `np.ndarray` + +**Example:** +```python +import numpy as np +from qfeval_data import util + +t = np.datetime64("2024-01-15T14:35:00") +d = np.timedelta64(1, "h") + +floored = util.floor_time(t, d) +print(floored) # 2024-01-15T14:00:00 +``` + +**With offset (timezone adjustment):** +```python +# Floor to day boundary with 9-hour offset (JST timezone) +t = np.datetime64("2024-01-15T08:00:00") # UTC +offset = np.timedelta64(9, "h") +floored = util.floor_time(t, np.timedelta64(1, "D"), offset=offset) +# Result considers JST day boundary +``` + +--- + +### `util.ceil_time(t, d, origin=None, offset=None)` + +Ceil datetime to a time interval. + +**Parameters:** +- `t` (`np.datetime64` or `np.ndarray`): Timestamp(s) to ceil +- `d` (`np.timedelta64`): Time interval +- `origin` (`np.datetime64`, optional): Origin for interval calculation +- `offset` (`np.timedelta64`, optional): Offset to apply before ceiling + +**Returns:** `np.datetime64` or `np.ndarray` + +**Example:** +```python +import numpy as np +from qfeval_data import util + +t = np.datetime64("2024-01-15T14:35:00") +d = np.timedelta64(1, "h") + +ceiled = util.ceil_time(t, d) +print(ceiled) # 2024-01-15T15:00:00 + +# If already on boundary, returns same value +t2 = np.datetime64("2024-01-15T14:00:00") +print(util.ceil_time(t2, d)) # 2024-01-15T14:00:00 +``` + +--- + +### `util.time_origin(d)` + +Get the default time origin for a given interval. + +**Parameters:** +- `d` (`np.timedelta64`): Time interval + +**Returns:** `np.datetime64` + +**Behavior:** +- For monthly/yearly intervals: returns `1000-01-01` +- For other intervals: returns `1893-01-01` (a Sunday, predating Dow Jones) + +**Example:** +```python +import numpy as np +from qfeval_data import util + +print(util.time_origin(np.timedelta64(1, "D"))) # 1893-01-01 +print(util.time_origin(np.timedelta64(1, "M"))) # 1000-01-01 +``` + +**Notes:** +- `1893-01-01` is chosen because: + - It's a Sunday (useful for week calculations) + - Predates Dow Jones Industrial Average (1896) + - Allows sufficient range for nanosecond precision +- Weekly intervals use 7-day periods starting from Sunday + +--- + +## Other Utilities + +### `util.sha1(x)` + +Compute SHA1 hash of various data types. + +**Parameters:** +- `x` (`bytes`, `str`, `np.ndarray`, or `torch.Tensor`): Data to hash + +**Returns:** `str` (hexadecimal hash) + +**Example:** +```python +import numpy as np +import torch +from qfeval_data import util + +print(util.sha1("hello")) # aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d +print(util.sha1(np.array([1, 2, 3]))) # hash includes shape info +print(util.sha1(torch.tensor([1.0, 2.0]))) # works with tensors +``` + +**Notes:** +- For arrays/tensors, the hash includes shape information +- Tensors are automatically converted to NumPy before hashing + +--- + +### `util.gc()` + +Run garbage collection and clear GPU memory. + +**Example:** +```python +from qfeval_data import util + +# Free memory after processing +util.gc() +``` + +**Behavior:** +- Runs Python garbage collection (generation 2) +- Clears CUDA cache if GPUs are available + +--- + +### `util.torch_device(device)` + +Parse device specification to `torch.device`. + +**Parameters:** +- `device` (`str`, `torch.device`, or `None`): Device specification + +**Returns:** `torch.device` + +**Special values:** +- `None`: Returns CPU device +- `"auto"`: Returns CUDA if available, otherwise CPU +- `"cpu"`, `"cuda"`, `"cuda:0"`, etc.: Standard PyTorch device strings + +**Example:** +```python +from qfeval_data import util + +print(util.torch_device(None)) # cpu +print(util.torch_device("auto")) # cuda or cpu +print(util.torch_device("cuda:0")) # cuda:0 +``` + +--- + +## Type Variables + +The module defines type variables for generic typing: + +```python +import typing + +# Generic type +T = typing.TypeVar("T") + +# Array-like type (torch.Tensor or np.ndarray) +Array = typing.TypeVar("Array", torch.Tensor, np.ndarray) +``` diff --git a/tests/data/fundamentals.csv b/tests/data/fundamentals.csv new file mode 100644 index 0000000..ac75907 --- /dev/null +++ b/tests/data/fundamentals.csv @@ -0,0 +1,22 @@ +timestamp,symbol,pe_ratio,market_cap +2024-01-02,AAPL,28.5,2900000000000 +2024-01-02,GOOG,25.0,1800000000000 +2024-01-02,MSFT,35.0,2800000000000 +2024-01-03,AAPL,28.6,2910000000000 +2024-01-03,GOOG,25.1,1810000000000 +2024-01-03,MSFT,35.1,2810000000000 +2024-01-04,AAPL,28.7,2920000000000 +2024-01-04,GOOG,25.2,1820000000000 +2024-01-04,MSFT,35.2,2820000000000 +2024-01-05,AAPL,28.8,2930000000000 +2024-01-05,GOOG,25.3,1830000000000 +2024-01-05,MSFT,35.3,2830000000000 +2024-01-08,AAPL,28.9,2940000000000 +2024-01-08,GOOG,25.4,1840000000000 +2024-01-08,MSFT,35.4,2840000000000 +2024-01-09,AAPL,29.0,2950000000000 +2024-01-09,GOOG,25.5,1850000000000 +2024-01-09,MSFT,35.5,2850000000000 +2024-01-10,AAPL,29.1,2960000000000 +2024-01-10,GOOG,25.6,1860000000000 +2024-01-10,MSFT,35.6,2860000000000 diff --git a/tests/data/prices.csv b/tests/data/prices.csv new file mode 100644 index 0000000..effd930 --- /dev/null +++ b/tests/data/prices.csv @@ -0,0 +1,22 @@ +timestamp,symbol,open,high,low,close,volume +2024-01-02,AAPL,185.5,186.2,184.1,185.8,50000000 +2024-01-02,GOOG,140.0,141.5,139.5,141.0,20000000 +2024-01-02,MSFT,370.0,372.0,368.0,371.0,25000000 +2024-01-03,AAPL,186.0,187.5,185.0,186.5,48000000 +2024-01-03,GOOG,141.0,142.0,140.0,141.5,19000000 +2024-01-03,MSFT,371.0,373.0,369.0,372.0,24000000 +2024-01-04,AAPL,186.5,188.0,185.5,187.0,52000000 +2024-01-04,GOOG,141.5,143.0,141.0,142.5,21000000 +2024-01-04,MSFT,372.0,374.0,370.0,373.0,26000000 +2024-01-05,AAPL,187.0,189.0,186.0,188.5,55000000 +2024-01-05,GOOG,142.5,144.0,142.0,143.5,22000000 +2024-01-05,MSFT,373.0,375.0,371.0,374.0,27000000 +2024-01-08,AAPL,188.5,190.0,187.5,189.0,53000000 +2024-01-08,GOOG,143.5,145.0,143.0,144.0,20000000 +2024-01-08,MSFT,374.0,376.0,372.0,375.0,25000000 +2024-01-09,AAPL,189.0,191.0,188.0,190.5,58000000 +2024-01-09,GOOG,144.0,146.0,143.5,145.5,23000000 +2024-01-09,MSFT,375.0,377.0,373.0,376.0,28000000 +2024-01-10,AAPL,190.5,192.0,189.5,191.0,54000000 +2024-01-10,GOOG,145.5,147.0,145.0,146.0,21000000 +2024-01-10,MSFT,376.0,378.0,374.0,377.0,26000000 diff --git a/tests/test_docs.py b/tests/test_docs.py new file mode 100644 index 0000000..12d6b09 --- /dev/null +++ b/tests/test_docs.py @@ -0,0 +1,95 @@ +"""Tests for documentation code examples. + +This module extracts and tests code examples from markdown files in AGENTS.md +and docs/*.md. + +Markdown format: +- Code blocks with ```python are extracted and executed +- HTML comments before a code block skip that block +- HTML comments contain hidden setup code +- HTML comments contain hidden teardown code +""" + +import re +from pathlib import Path +from typing import List, Tuple + +import pytest + + +def extract_code_blocks(markdown_content: str) -> List[Tuple[str, bool]]: + """Extract Python code blocks from markdown content. + + Returns: + List of (code, should_skip) tuples. + """ + blocks: List[Tuple[str, bool]] = [] + + # Extract hidden setup code from HTML comments + setup_pattern = r"" + setup_matches = re.findall(setup_pattern, markdown_content, re.DOTALL) + setup_code = "\n".join(setup_matches) + + # Find all code blocks with skip markers + # Pattern matches optional skip comment followed by python code block + pattern = r"(?:\s*)?(```python\n(.*?)\n```)" + matches = list(re.finditer(pattern, markdown_content, re.DOTALL)) + + for match in matches: + full_match = match.group(0) + code = match.group(2) + should_skip = full_match.strip().startswith("