๐ ์ธ๊ณผ์ถ๋ก ๊ฐ์ธ ๊ณต๋ถ์ฉ ํฌ์คํธ ๊ธ์ ๋๋ค. ์ถ์ฒ๋ ์ฒจ๋ถํ ๋งํฌ๋ฅผ ์ฐธ๊ณ ํด์ฃผ์ธ์!
โป ์ ๋ฆฌ
๐ ์ ๋ฆฌ
• ํ๊ท : ๋ฐ์ดํฐ๋ฅผ ์ ๋ก ๋ถํ ํ๊ณ , ๊ฐ ์ ์์ ATE ๋ฅผ ๊ณ์ฐํ ๋ค์, ์ ์ ATE ๋ฅผ ์ ์ฒด ๋ฐ์ดํฐ์ ์ ๋ํ ๋จ์ผ ATE ๋ก ๊ฒฐํฉํ๋ ๊ฒ
• ๋งค์นญ estimator
โ What is Regression Doing After All?
โฏ ํ๊ท๋ถ์
• ํ๊ท๋ถ์์ ์ ์ฉํ๋ฉด Treatment group ๊ณผ Control group ์ ๋น๊ตํ ๋, ์ถ๊ฐ์ ์ธ ๋ณ์๋ค์ ์ ์ดํ ์ ์๋ค. ์ฆ, X๋ฅผ ํต์ ํจ์ผ๋ก์จ ATE ๋ฅผ ์๋ณํ ์ ์๋ค : (Y0, Y1) ⊥ T | X โจ ์กฐ๊ฑด๋ถ ๋ ๋ฆฝ์ฑ ๊ฐ์
• ํ๊ท๋ถ์๊ณผ Matching ์ functional form ์ ๊ฐ์ ํ๋๋ ์ ํ๋๋์ ์ฐจ์ด๋ง ์กด์ฌํ๋ค.
โก The Subclassification Estimator
• ์ถ์ ํ๊ณ ์ ํ๋ ์ธ๊ณผํจ๊ณผ๊ฐ ์์ผ๋ ๊ต๋์์ธ X ๋๋ฌธ์ ์ถ์ ์ด ์ด๋ ค์ด ๊ฒฝ์ฐ, ๊ต๋์์ธ์ ํจ๊ณผ๊ฐ ๋์ผํ ์๊ทธ๋ฃน ๋ด์์ Treatment ์ control group ์ ๋น๊ตํด์ผ ํ๋ค. ์กฐ๊ฑด๋ถ ๋ ๋ฆฝ ๊ฐ์ ์ผ๋ก ๋ง์กฑํ๋ค๋ฉด ATE ๋ ์๋์ ๊ฐ์ด ๊ณ์ฐ๋ ์ ์๋ค.
โช ๋ณ์ X๊ฐ K ๊ฐ์ ์๋ก ๋ค๋ฅธ ์ {X1, X2, ... , Xk} ์ ์ทจํ๋ค๊ณ ๋งํ ์ ์์ผ๋ฉฐ, ๊ฐ ์ ์ treatment ํจ๊ณผ๋ฅผ ๊ณ์ฐํ๊ณ ์ด๋ฅผ ATE ๋ก ๊ฒฐํฉํ๋ค.
โข Matching Estimator
โฏ ์์ 1
• ex. ์ฐ์ ํ๋ก๊ทธ๋จ์ด ์์ ์ ๋ฏธ์น๋ ์ํฅ ์ถ์
...
• ํ๊ท ์ ์ผ๋ก ๋จ์ ๋น๊ตํด๋ณด๋ฉด ์ฐ์์์ด ์ฐ์๋ฅผ ๋ฐ์ง ์์ ์ฌ๋๋ณด๋ค ๋์ ๋ ๋ฒ๋ค๋ ๊ฒ์ ์ ์ ์๋ค.
trainee.query("trainees==1")["earnings"].mean() - trainee.query("trainees==0")["earnings"].mean()
# -4297.49373433584
• ๊ทธ๋ฌ๋ ํ๋ฅผ ๋ณด๋ฉด, ์ฐ์์์ด ์ฐ์์์ด ์๋ ์ฌ๋๋ณด๋ค ํจ์ฌ ์ด๋ฆฌ๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ ์ ์๋ค. ๋ฐ๋ผ์ ์ด๋ฅผ ๋ฐ์ํ๊ธฐ ์ํด ๋์ด๋ฅผ ๋์ผํ๊ฒ ๋ง์ถฐ์ค๋ค. ๊ฐ๋ น 28์ธ๋ก ๋์ด๊ฐ ๋์ผํ unit 1 ๊ณผ unit 27 ์ ๋งค์นญํ๋ ์์ผ๋ก ์งํํ ์ ์๋ค.
• 1๊ฐ ์ด์์ unit ์ด ์ผ์นํ๋ ๊ฒฝ์ฐ, ํด๋น ๊ทธ๋ฃน ์ค์์ ๋ฌด์์๋ก ์ ํํ ์ ์๋ค.
# make dataset where no one has the same age
unique_on_age = (trainee
.query("trainees==0")
.drop_duplicates("age"))
matches = (trainee
.query("trainees==1")
.merge(unique_on_age, on="age", how="left", suffixes=("_t_1", "_t_0"))
.assign(t1_minuts_t0 = lambda d: d["earnings_t_1"] - d["earnings_t_0"]))
matches.head(7)
• ๋ง์ง๋ง ์ด (t1-t0) ์ ํ๊ท ์ ์ทจํ๋ฉด ์ฐ๋ น์ ํต์ ํ๋ฉด์ ATET ์ถ์ ์น๋ฅผ ์ป์ ์ ์๋ค.
matches["t1_minuts_t0"].mean()
# 2457.8947368421054
• ํ์ง๋ง ์ค์ ๋ก ๋งค์นญ์์ ์ผ๋ฐ์ ์ผ๋ก ํ๋ ์ด์์ ๋ณ์๋ฅผ ๊ฐ์ง๊ณ ์๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ผ๋ฉฐ unit ์ ๊ฐ์ด ์๋ฒฝํ๊ฒ ์ผ์นํ๋ ๊ฒฝ์ฐ๋ ๋๋ฌผ๋ค. ์ด๋ฌํ ๊ฒฝ์ฐ unit ์ด ์๋ก ์ผ๋ง๋ ๊ฐ๊น์ด์ง ๋น๊ตํ๊ธฐ ์ํด ์ธก์ ํ๋ ๋ฐฉ์์ด ํ์ํ๊ณ ์ผ๋ฐ์ ์ผ๋ก ์ ํด๋ฆฌ๋์ ๊ฑฐ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ค. ์ด๋ ์ ํด๋ฆฌ๋์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ๋ ค๋ฉด ๋ณ์๊ฐ ๋๋ต ๊ฐ์ ์ค์ผ์ผ์ด ๋๋๋ก ์กฐ์ ํด์ผ ํ๋ ๊ณผ์ ์ด ํ์ํ๋ค.
โฏ ์์ 2
• ex. ํ์๊ฐ ํ๋ณต๊น์ง ๋ฉฐ์น ์ด ๊ฑธ๋ฆฌ๋์ง๋ฅผ ํตํด ์ฝ๋ฌผ์ ํจ๊ณผ๋ฅผ ๊ณ์ฐ
med = pd.read_csv("./data/medicine_impact_recovery.csv")
med.head()
• ๋จ์ํ ํ๊ท ์ฐจ์ด๋ฅผ ๊ณ์ฐํ๋ค๋ฉด E(Y|T=1) - E(Y|T=0) ์ด๊ณ , ์ด๋ฅผ ๊ณ์ฐํ๋ฉด ์ฝ๋ฌผ์ ํฌ์ฝํ ํ์๊ฐ ํ๋ณตํ๋๋ฐ ํ๊ท 16.9์ผ์ด ๋ ๊ฑธ๋ฆฐ๋ค๋ ๊ฒฐ๊ณผ๋ฅผ ๋์ถํ๋ค. ๊ต๋์์ธ์ผ๋ก ์ธํด ์ด๋ฌํ (์ง๊ด๊ณผ ๋ฐ๋๋) ๊ฒฐ๊ณผ๊ฐ ๋์จ ๊ฒ์ผ๋ก ์ถ์ธกํด ๋ณผ ์ ์๋ค.
med.query("medication==1")["recovery"].mean() - med.query("medication==0")["recovery"].mean()
# 16.895799546498726
• ์ด๋ฌํ ํธํฅ์ ์กฐ์ ํ๊ธฐ ์ํด X๋ฅผ ํต์ ํ๋ค. ์ผ๋จ ๋ณ์๋ฅผ ์ค์ผ์ผ๋งํด์ค๋ค.
# scale features
X = ["severity", "age", "sex"]
y = "recovery"
med = med.assign(**{f: (med[f] - med[f].mean())/med[f].std() for f in X})
med.head()
• KNN ์๊ณ ๋ฆฌ์ฆ์ ์ด์ฉํด ๋งค์นญ์ ์งํํ๋ค. mt0 ๋ ์ฒ์น๋์ง ์์ ๊ด์ธก์น๋ฅผ ์ ์ฅํ๊ณ , mt1 ์ ์ฒ์น๋ ๊ด์ธก์น๋ฅผ ์ ์ฅํ๋ค.
from sklearn.neighbors import KNeighborsRegressor
treated = med.query("medication==1")
untreated = med.query("medication==0")
mt0 = KNeighborsRegressor(n_neighbors=1).fit(untreated[X], untreated[y])
mt1 = KNeighborsRegressor(n_neighbors=1).fit(treated[X], treated[y])
predicted = pd.concat([
# find matches for the treated looking at the untreated knn model
treated.assign(match=mt0.predict(treated[X])),
# find matches for the untreated looking at the treated knn model
untreated.assign(match=mt1.predict(untreated[X]))
])
predicted.head()
• ๋งค์นญ์ ํตํด ATE ๋ฅผ ๊ณ์ฐํด๋ผ ์ ์๋ค.
np.mean((2*predicted["medication"] - 1)*(predicted["recovery"] - predicted["match"]))
# -0.9954
๋งค์นญ์ ํตํด, X๋ฅผ ํต์ ํ ๋, ์ฝ์ด ํ๊ท ์ ์ผ๋ก ์ฝ 1์ผ์ ๋ ํ๋ณต์๊ฐ์ ๋จ์ถํ๋ค๋ ๊ฒ์ ํ์ธํด๋ณผ ์ ์๋ค.
โฃ Matching bias
โฏ ols ๋ฅผ ํ์ฉํ bias ๋ณด์
• ATET ๋ ํธํฅ๋ estimator ์ด๋ค. (์ฆ๋ช ๊ณผ์ ์๋ต)
• Bias ๋ ๋งค์นญ ๋ถ์ผ์น์ ์ ๋๊ฐ ํด ๋ ๋ฐ์ํ๋ค. bias ๋ฅผ ์ค์ด๊ธฐ ์ํด ATET ๋ฅผ ์๋์ ์์๊ณผ ๊ฐ์ด ์ ์ํ ์ ์๋ค. μ^0(x) ๋ E[Y|X,T=0] ์ผ๋ก ์ฒ์น๋์ง ์์ ์ํ์ ํผํ ํ ์ ํํ๊ท์ ๊ฐ๋ค.
• OLS ๋ ์ถ์ ๊ธฐ์ ๋ถ์ฐจ์ ์ธ ์์์ด๋ค.
from sklearn.linear_model import LinearRegression
# fit the linear regression model to estimate mu_0(x)
ols0 = LinearRegression().fit(untreated[X], untreated[y])
ols1 = LinearRegression().fit(treated[X], treated[y])
# find the units that match to the treated
treated_match_index = mt0.kneighbors(treated[X], n_neighbors=1)[1].ravel()
# find the units that match to the untreatd
untreated_match_index = mt1.kneighbors(untreated[X], n_neighbors=1)[1].ravel()
predicted = pd.concat([
(treated
# find the Y match on the other group
.assign(match=mt0.predict(treated[X]))
# build the bias correction term
.assign(bias_correct=ols0.predict(treated[X]) - ols0.predict(untreated.iloc[treated_match_index][X]))),
(untreated
.assign(match=mt1.predict(untreated[X]))
.assign(bias_correct=ols1.predict(untreated[X]) - ols1.predict(treated.iloc[untreated_match_index][X])))
])
predicted.head()
• ๋ํ ๋งค์นญ์ ๋น๋ชจ์์ ์ถ์ ๋์ด๋ค. ์ ํ์ฑ์ด๋ ์ด๋ค ์ข ๋ฅ์ ๋งค๊ฐ๋ณ์ ๋ชจ๋ธ๋ ๊ฐ์ ํ์ง ์๋๋ค. ๋ฐ๋ผ์ ์ ํํ๊ท๋ณด๋ค ์ ์ฐํ ๋ฐฉ์์ด๋ฉฐ ๋น์ ํ์ฑ์ด ๋งค์ฐ ๊ฐํ ์ํฉ์์ ์๋ํ ์ ์๋ค.
np.mean((2*predicted["medication"] - 1)*((predicted["recovery"] - predicted["match"])-predicted["bias_correct"]))
## -7.36266090614141
โฏ CausalModel ์ ํ์ฉํ ์ถ์
from causalinference import CausalModel
cm = CausalModel(
Y=med["recovery"].values,
D=med["medication"].values,
X=med[["severity", "age", "sex"]].values
)
cm.est_via_matching(matches=1, bias_adj=True)
print(cm.estimates)
โจ ์ฝ๋ฌผ์ด ์ค์ ๋ก ํ์์ ๋ณ์ ์ ์ ๊ธฐ๊ฐ์ ์ค์ฌ์ค๋ค๊ณ ์์ ์๊ฒ ๋งํ ์ ์๋ค.
โค The Curse of Dimensionality
• ๋งค์นญ๋ ๊ด์ธก์น๋ค์ด ์ ์ฌํ์ง ์์ ๋ ํธํฅ์ด ๋ฐ์ํ๋ค. ๋ ๋ง์ ๋ณ์๊ฐ ์กด์ฌํ ์๋ก ํด๋น ๊ด์ธก์น์ ๋งค์นญ๋๋ ๊ด์ธก์น ์ฌ์ด์ ๊ฑฐ๋ฆฌ๊ฐ ๋ ๋ฉ์ด์ง๋ค. ์ฆ, ์ฐจ์์ ์ ์ฃผ ํ์์ด ๋ฐ์ํ๋ค.
• ์ ํํ๊ท์ ๊ฒฝ์ฐ ์ด ๋ฌธ์ ๋ฅผ ์ ์ฒ๋ฆฌํ๋ค. ๋ชจ๋ ๋ณ์ X๋ฅผ ๋จ์ผ ์ฐจ์ Y๋ก ํฌ์ํ๊ธฐ ๋๋ฌธ์ด๋ค. ํด๋น ํฌ์์ ๋ํ ์ฒ์น ๋ฐ ํต์ ์ ๋ํ ๋น๊ต๋ฅผ ์งํํ๋ค.
'1๏ธโฃ AIโขDS > ๐ฅ Casual inference' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[The Brave and True] 12. Doubly Robust Estimation (0) | 2023.07.14 |
---|---|
[The Brave and True] 11. Propensity score (0) | 2023.07.13 |
[The Brave and True] 9. Non Compliance and LATE (0) | 2023.07.04 |
[The Brave and True] 8. Instrumental variables (0) | 2023.07.03 |
์ธ๊ณผ์ถ๋ก ์ ๋ฐ์ดํฐ ๊ณผํ - ๋จธ์ ๋ฌ๋์ ํด์ ๊ฐ๋ฅ์ฑ๊ณผ ์ธ๊ณผ์ถ๋ก (0) | 2023.06.29 |
๋๊ธ