* Necessity of small <math>d_{H}(P,Q)</math> for DA.
* Necessity of small joint training error.
===Domain Generalization===
Also known as out-of-distribution (OOD) generalization.
Training: <math>|E|</math> training domains (environments), each providing samples <math>\{(x_i^e, y_i^e)\}_{i=1}^{m_e} \sim P^{(e)}</math> with <math>1 \leq e \leq |E|</math>.
Goal: Find <math>h \in H</math> that performs well on an unseen domain (domain <math>|E|+1</math>).
At test time we observe samples <math>\{(x_i^{(K+1)}, y_i^{(K+1)})\}_{i=1}^{m_{K+1}} \sim P^{(K+1)}</math>, where <math>K = |E|</math>, and the goal is to minimize the test-domain risk
<math>R^{(K+1)}(h) = E_{(x,y) \sim P^{(K+1)}}[\ell(h(x), y)]</math>.
;Example datasets
* [http://ai.bu.edu/M3SDA/ DomainNet]
** 6 domains: clipart, infograph, painting, quickdraw, real, sketch
* PACS
** 4 domains: art, cartoon, photo, sketch
* Some simpler datasets
** Rotated MNIST
** Colored MNIST
One baseline method is to do nothing domain-specific: pool the data from all training domains and train normally with ERM.
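The ERM baseline above can be sketched as follows. This is a minimal illustration with synthetic data and a least-squares linear model; the data, shapes, and variable names are assumptions for the example, not from any specific benchmark.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: 3 training domains, each with its own sample of (x, y) pairs.
domains = [(rng.normal(size=(50, 4)), rng.normal(size=50)) for _ in range(3)]

# ERM does nothing domain-specific: pool every domain's data into one dataset.
X = np.vstack([x for x, _ in domains])       # shape (150, 4)
y = np.concatenate([t for _, t in domains])  # shape (150,)

# Fit one model h(x) = x @ w by minimizing the pooled empirical risk
# (squared loss -> ordinary least squares).
w, *_ = np.linalg.lstsq(X, y, rcond=None)

# Per-domain empirical risks R^(e)(h) under squared loss.
risks = [np.mean((x @ w - t) ** 2) for x, t in domains]
print(w.shape, len(risks))
```

The hope of domain generalization is that the held-out domain's risk stays close to these training-domain risks.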
===Domain-adversarial neural networks (DANN)===
* Train a feature extractor <math>\phi</math> and a classifier <math>w</math> to yield <math>f = w \circ \phi</math>.
* A domain classifier <math>c</math> tries to predict, from the features <math>\phi(x)</math>, which domain an example came from.
* <math>\text{loss} = \frac{1}{k} \sum_{j=1}^{k} E_{(x,y) \sim P^{(j)}}[\ell(w \circ \phi(x), y)] - \lambda L_{\text{domain}}(c \circ \phi)</math>, where <math>L_{\text{domain}}</math> is the domain-classification loss.
* We optimize the saddle point <math>\min_{\phi, w} \max_{c} \text{loss}</math>: <math>c</math> tries to classify domains accurately, while <math>\phi</math> is trained to make the domains indistinguishable in feature space.
* We can instead use the Wasserstein distance to match feature distributions across domains:
** <math>\text{loss} = \frac{1}{k} \sum_{j=1}^{k} E[\ell(w \circ \phi(x), y)] + \lambda \sum_{j_1 < j_2} W(P_{\phi(X)|Y}^{(j_1)}, P_{\phi(X)|Y}^{(j_2)})</math>
** Both objectives are solved using alternating gradient descent/ascent.
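A single forward evaluation of a DANN-style objective can be sketched as below. This is an illustrative toy with a linear feature extractor and logistic label/domain heads; the data, parameter shapes, and names are assumptions for the example, and no training loop is shown.

```python
import numpy as np

rng = np.random.default_rng(1)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce(p, t):
    # Binary cross-entropy, clipped for numerical stability.
    p = np.clip(p, 1e-7, 1 - 1e-7)
    return -np.mean(t * np.log(p) + (1 - t) * np.log(1 - p))

# Toy data from k = 2 domains with binary labels.
k = 2
xs = [rng.normal(size=(40, 5)) for _ in range(k)]
ys = [rng.integers(0, 2, size=40) for _ in range(k)]

# Parameters: feature extractor phi, label head w, domain head c.
Phi = rng.normal(size=(5, 3))   # phi(x) = x @ Phi
w_clf = rng.normal(size=3)      # label classifier (logistic)
c_clf = rng.normal(size=3)      # domain classifier (logistic)
lam = 0.1

# Label loss averaged over the k training domains.
label_loss = np.mean([bce(sigmoid(x @ Phi @ w_clf), y) for x, y in zip(xs, ys)])

# Domain-classification loss on the extracted features (domain 0 vs domain 1).
feats = np.vstack([x @ Phi for x in xs])
dom = np.concatenate([np.full(len(x), e) for e, x in enumerate(xs)])
domain_loss = bce(sigmoid(feats @ c_clf), dom)

# The quantity (phi, w) minimize and c maximizes: c wants small domain_loss,
# phi wants large domain_loss (indistinguishable features). In practice this
# is trained by alternating updates or a gradient-reversal layer.
loss = label_loss - lam * domain_loss
print(float(label_loss), float(domain_loss))
```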
===Meta Learning for Domain Generalization===
[Li ''et al.'' 2017]
Idea: Split the <math>K</math> training domains into meta-train and meta-test domains to simulate domain shift at training time.
;Meta-train
* Loss on the <math>K_1</math> meta-train domains: <math>L_{\text{meta-train}}(\theta) = \frac{1}{K_1} \sum_{j=1}^{K_1} E_{(x,y) \sim P^{(j)}}[\ell(f_{\theta}(x), y)]</math>
* Take one gradient step to update the model:
** <math>\theta' = \theta - \eta \nabla_{\theta} L_{\text{meta-train}}(\theta)</math>
;Meta-test
* Loss on the remaining <math>K - K_1</math> meta-test domains, evaluated at the updated parameters: <math>L_{\text{meta-test}}(\theta') = \frac{1}{K - K_1} \sum_{j=K_1+1}^{K} E_{(x,y) \sim P^{(j)}}[\ell(f_{\theta'}(x), y)]</math>
Overall objective:
* <math>\min_{\theta} L_{\text{meta-train}}(\theta) + \beta L_{\text{meta-test}}(\theta')</math>
To minimize this objective we need <math>\nabla_{\theta} L_{\text{meta-test}}(\theta')</math>; since <math>\theta'</math> itself depends on <math>\nabla_{\theta} L_{\text{meta-train}}(\theta)</math>, this gradient involves the Hessian with respect to <math>\theta</math>. It can be computed via a ''Hessian-vector product'' without ever materializing the Hessian, which could be very large.
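A Hessian-vector product never requires the full Hessian. One simple way to see this is the finite-difference identity <math>Hv \approx (\nabla L(\theta + \epsilon v) - \nabla L(\theta - \epsilon v)) / 2\epsilon</math>, sketched below on a quadratic loss where the Hessian is known exactly (autodiff frameworks compute the same product exactly via double backpropagation); the function names here are illustrative.

```python
import numpy as np

def hvp(grad_fn, theta, v, eps=1e-5):
    """Approximate H @ v using only two gradient evaluations,
    never forming the (potentially huge) Hessian H."""
    return (grad_fn(theta + eps * v) - grad_fn(theta - eps * v)) / (2 * eps)

# Check on a quadratic L(theta) = 0.5 * theta^T A theta,
# whose gradient is A @ theta and whose Hessian is exactly A.
rng = np.random.default_rng(2)
A = rng.normal(size=(4, 4))
A = A @ A.T  # make it symmetric

grad = lambda th: A @ th
theta = rng.normal(size=4)
v = rng.normal(size=4)

approx = hvp(grad, theta, v)
print(np.allclose(approx, A @ v, atol=1e-4))
```

For a quadratic loss the finite difference is exact up to floating-point error; for general losses it is an <math>O(\epsilon^2)</math> approximation.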
===Invariant Risk Minimization (IRM)===
[Arjovsky ''et al.'' (2019)]
Idea: Find a feature extractor <math>\phi(\cdot)</math> such that the optimal classifier on top of it is the same for every domain.
Define the per-domain risk <math>R^e(\phi, w) = E_{(x,y) \sim P^{(e)}}[\ell(w \circ \phi(x), y)]</math>.
Objective: <math>\min_{\phi, \hat{w}} \frac{1}{k} \sum_{e} R^e(\phi, \hat{w})</math> s.t. <math>\forall e</math>, <math>\hat{w} \in \operatorname{argmin}_{\beta} R^e(\phi, \beta)</math>
This is a bi-level optimization problem, which is difficult to solve because the constraint is itself defined by another optimization problem.
The paper uses a Lagrangian relaxation (IRMv1):
<math>\min_{\phi, \hat{w}} \frac{1}{k} \sum_{e} \left[ R^e(\phi, \hat{w}) + \lambda \Vert \nabla_{\hat{w}} R^e(\phi, \hat{w}) \Vert^2_2 \right]</math>.
Argument: a solution to this optimization should use only invariant features, since non-invariant features have a different conditional distribution given the label in different domains.
[Rosenfeld ''et al.'' Oct 2020]
This argument does not hold in general: IRM can fail to recover the invariant predictor.
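The relaxed IRM objective can be evaluated directly once the risks and their gradients with respect to <math>\hat{w}</math> are available. Below is a minimal sketch for a linear feature map, a scalar head, and squared loss; the synthetic environments, shapes, and names are assumptions for the example.

```python
import numpy as np

rng = np.random.default_rng(3)

# Three toy environments, each a sample of (x, y) pairs.
envs = [(rng.normal(size=(60, 5)), rng.normal(size=60)) for _ in range(3)]

Phi = rng.normal(size=5)  # linear feature map: phi(x) = x @ Phi (scalar feature)
w_hat = 1.0               # shared classifier on top of the features
lam = 1.0                 # penalty weight lambda

def risk_and_grad(x, y):
    """Per-environment risk R^e and its gradient w.r.t. the scalar w_hat,
    under squared loss l(w*phi(x), y) = (w*phi(x) - y)^2."""
    f = x @ Phi
    resid = w_hat * f - y
    risk = np.mean(resid ** 2)
    grad_w = np.mean(2 * resid * f)  # d/dw_hat of the mean squared loss
    return risk, grad_w

terms = [risk_and_grad(x, y) for x, y in envs]

# Relaxed IRM objective: mean risk plus the squared gradient penalties.
objective = np.mean([r for r, _ in terms]) + lam * sum(g ** 2 for _, g in terms)
print(float(objective))
```

A gradient penalty of zero for every environment means <math>\hat{w}</math> is simultaneously stationary for all of them, which is the relaxed form of the original argmin constraint.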
===Which method for generalization works the best?===
[Gulrajani & Lopez-Paz 2020] offer the following empirical observations:
* Model selection is critical in domain generalization. They compare three strategies:
** Training-domain validation set
** Leave-one-domain-out validation set
** Oracle selection
;Data augmentation is important in domain generalization
* random crops, random horizontal flips, random color jitter
* image-to-image neural translation
;ERM (doing nothing domain-specific) outperforms other DG methods!
* Possible with careful model selection.
* Larger models help with domain generalization.
* See DomainBed for code and data.
==Misc==