Operator Learning


By Prof. Seungchul Lee
http://iailab.kaist.ac.kr/
Industrial AI Lab at KAIST

1. Lab 1: Euler Beam

1.1. Problem Setup

  • We will solve an Euler beam problem using DeepONet in two settings:
    • Data-driven only
    • Data-driven and physics-guided (boundary conditions)



  • Problem properties

$$E = 1 \operatorname{Pa}, \quad I = 1 \operatorname{m^4}, \quad l = 1 \operatorname{m}, \quad q(x) = wx + w_0 \quad \text{where} \quad w_0 \in [0,2]$$


  • The exact solution for the test case $q(x) = 1$ (i.e. $w = 0$, $w_0 = 1$) is

$$y(x) = -{1 \over 24}x^4 + {1 \over 6}x^3 - {1 \over 4}x^2$$
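As a consistency check, the exact solution can be reproduced in closed form for any linear load. The sketch below assumes the governing equation $EI\,y'''' = -q(x)$ (a downward load, which matches the sign of the solution above) and the cantilever boundary conditions listed in Section 1.3 (clamped at $x = 0$, free at $x = 1$). The function name `cantilever_deflection` is hypothetical, and we do not claim the provided ground-truth files were generated this way.

```python
import numpy as np
from numpy.polynomial import Polynomial


def cantilever_deflection(w, w0, EI=1.0):
    """Closed-form deflection of a unit-length cantilever under q(x) = w*x + w0.

    Assumes EI * y'''' = -q(x), y(0) = y'(0) = 0 (clamped end) and
    y''(1) = y'''(1) = 0 (free end). Integrating four times gives
    EI * y = -(w x^5/120 + w0 x^4/24) + c1 x^3/6 + c2 x^2/2.
    """
    c1 = w / 2 + w0            # fixed by y'''(1) = 0
    c2 = -(w / 3 + w0 / 2)     # fixed by y''(1) = 0
    coef = np.array([0.0, 0.0, c2 / 2, c1 / 6, -w0 / 24, -w / 120])  # ascending powers of x
    return Polynomial(coef / EI)


# Test case q(x) = 1 (w = 0, w0 = 1) reproduces y = -x^4/24 + x^3/6 - x^2/4
y = cantilever_deflection(0.0, 1.0)
x = np.linspace(0, 1, 25)
exact = -x**4 / 24 + x**3 / 6 - x**2 / 4
print(np.allclose(y(x), exact))    # True
print(round(float(y(1.0)), 6))     # -0.125, the classic tip deflection -q l^4 / (8 EI)
```

Because `Polynomial` objects expose `.deriv()`, the same function also makes it easy to confirm the four boundary conditions used later as physics-guided losses.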


1.2. Solve the Euler Beam Problem with Data-Driven DeepONet

  • Build the DeepONet (a branch net for $q(x)$ and a trunk net for $x$) and train it with a mean-squared data loss:

In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

1.2.1. Import Libraries

In [ ]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
%matplotlib inline

1.2.2. CUDA

In [ ]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
cuda:0

1.2.3. Random Seed

In [ ]:
def random_seed(seed = 2021):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

1.2.4. Define Parameters

In [ ]:
# Beam properties: Young's modulus E, second moment of area I, and length l
E = 1
I = 1
l = 1

1.2.5. Define Collocation Points

In [ ]:
'Domain collocation points'
domain_points = np.linspace(0, 1, 25).reshape(-1, 1)

Visualization

In [ ]:
'Plot'

y_init = np.zeros(25)

plt.scatter(domain_points, y_init, s = 30)
plt.xlabel('x')
plt.title('Collocation points')
plt.show()

1.2.6. Load $q(x)$ and Ground Truth

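We use 10 distributed loads with varying $w$ and $w_0$ for training and one distributed load, $q(x) = 1$, for testing. Each load is sampled at 20 points, and each ground-truth deflection is given at the 25 collocation points.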

In [ ]:
data_dir = '/content/drive/MyDrive/tutorials/기계인공지능연구회/본부학술대회_241106/DeepONet/data/beam_bending'

q_train = np.load(os.path.join(data_dir, 'q_train.npy'))
q_test = np.load(os.path.join(data_dir, 'q_test.npy'))
gt_train = np.load(os.path.join(data_dir, 'gt_train.npy'))
gt_test = np.load(os.path.join(data_dir, 'gt_test.npy'))

print('Loading Conditions for Train: {}'.format(q_train.shape))
print('Ground Truth for Train:       {}'.format(gt_train.shape))

print('Loading Conditions for Test:  {}'.format(q_test.shape))
print('Ground Truth for Test:        {}'.format(gt_test.shape))
Loading Conditions for Train: (10, 20)
Ground Truth for Train:       (10, 25)
Loading Conditions for Test:  (1, 20)
Ground Truth for Test:        (1, 25)

Visualize $q(x)$

In [ ]:
# Plot a subset of the training loads against 20 evenly spaced points on [0, 1]
x_cor = np.linspace(0, 1, 20)

for i in [0, 3, 6, 9]:
    # Label only one training curve so the legend shows a single 'train' entry
    plt.plot(x_cor, q_train[i], linestyle = '-', color = 'blue', marker = 'o',
             label = 'train' if i == 9 else '_nolegend_', markevery = (0, 2))

plt.plot(x_cor, q_test.reshape(-1), linestyle = '-', color = 'red', marker = 'o', label = 'test (q = 1)', markevery = (0, 2))
plt.title("Distributed Load")
plt.legend()
plt.grid(True)
plt.xlim([-0.1, 1.1])
plt.ylim([-0.1, 2.1])
plt.show()

Visualize Ground Truth

In [ ]:
# Plot a subset of the training deflections at the 25 collocation points
x_cor = np.linspace(0, 1, 25)

for i in [0, 3, 6, 9]:
    # Label only one training curve so the legend shows a single 'train' entry
    plt.plot(x_cor, gt_train[i], linestyle = '-', color = 'blue', marker = 'o',
             label = 'train' if i == 9 else '_nolegend_', markevery = (0, 2))

plt.plot(x_cor, gt_test.reshape(-1), linestyle = '-', color = 'red', marker = 'o', label = 'test (q = 1)', markevery = (0, 2))
plt.title("Deflection of Euler Beam")
plt.legend()
plt.grid(True)
plt.show()

1.2.7. Generate Dataset

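  • Build the training set by pairing every loading condition with every collocation point: each of the 10 loads is repeated 25 times and matched with the 25 points and their ground-truth deflections, giving 250 samples.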

Domain

In [ ]:
# Load domain collocation points
num_domain = domain_points.shape[0]  # (25, _)
num_condition = q_train.shape[0]     # (10, _)

condition_list = []
XY_domain_list = []
gt_list = []

for i in range(num_condition):
    con = q_train[i]
    conditions = np.vstack([con] * num_domain)
    domains = domain_points
    gts = gt_train[i].reshape(-1, 1)

    condition_list.append(conditions)
    XY_domain_list.append(domains)
    gt_list.append(gts)

conditions_domain = np.vstack(condition_list)
XYs_domain = np.vstack(XY_domain_list)
gts_domain = np.vstack(gt_list)

print('Conditions Domain: {}'.format(conditions_domain.shape))
print('XYs Domain:        {}'.format(XYs_domain.shape))
print('GTs Domain:        {}'.format(gts_domain.shape))
Conditions Domain: (250, 20)
XYs Domain:        (250, 1)
GTs Domain:        (250, 1)
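The loop above is equivalent to a pair of NumPy calls: `np.repeat` copies each load once per collocation point, and `np.tile` repeats the whole set of points once per load, so row `k` pairs load `k // 25` with point `k % 25`. The sketch below uses random stand-in arrays (not the real `q_train` and `gt_train`) to check that equivalence.

```python
import numpy as np

rng = np.random.default_rng(0)
q_train = rng.random((10, 20))                        # stand-in: 10 loads x 20 samples
gt_train = rng.random((10, 25))                       # stand-in: 10 deflections x 25 points
domain_points = np.linspace(0, 1, 25).reshape(-1, 1)
num_domain = domain_points.shape[0]

# Vectorized construction of the stacked dataset
conditions_domain = np.repeat(q_train, num_domain, axis=0)   # (250, 20)
XYs_domain = np.tile(domain_points, (q_train.shape[0], 1))    # (250, 1)
gts_domain = gt_train.reshape(-1, 1)                          # (250, 1), row-major order matches

# Loop-based construction, as in the cell above
loop_conditions = np.vstack([np.vstack([q] * num_domain) for q in q_train])
loop_XYs = np.vstack([domain_points] * q_train.shape[0])
loop_gts = np.vstack([g.reshape(-1, 1) for g in gt_train])

print(np.array_equal(conditions_domain, loop_conditions),
      np.array_equal(XYs_domain, loop_XYs),
      np.array_equal(gts_domain, loop_gts))   # True True True
```

The vectorized form avoids building a Python list of intermediate arrays, which matters more for the thin-plate lab below, where the stacked domain set has 25,000 rows.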

1.2.8. Data Generator

In [ ]:
# Data generator
class DataGenerator(data.Dataset):
    def __init__(self, c, XY, gt, batch_size=64, rng_seed=1234):
        'Initialization'
        self.c = torch.tensor(c, dtype = torch.float32)
        self.XY = torch.tensor(XY, dtype = torch.float32)
        self.gt = torch.tensor(gt, dtype = torch.float32)
        self.N = c.shape[0]
        self.batch_size = batch_size
        self.seed = rng_seed

    def __len__(self):
        return (self.N // self.batch_size) + 1

    def __getitem__(self, index):
        c, XY, gt = self.__data_generation(index)
        return c, XY, gt

    def __data_generation(self, index):
        torch.manual_seed(self.seed + index)

        idx = torch.randperm(self.N)[:self.batch_size]

        c = self.c[idx, :]
        XY = self.XY[idx, :]
        gt = self.gt[idx,:]

        return c, XY, gt
In [ ]:
batch_size = 30000   # larger than the 250 samples, so every batch is the full training set, reshuffled
train_dataset = DataGenerator(conditions_domain, XYs_domain, gts_domain, batch_size)
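Two details of `DataGenerator` are easy to miss. First, `torch.utils.data.Dataset` defines no `__iter__`, so `iter(train_dataset)` falls back to Python's legacy sequence protocol, calling `__getitem__(0)`, `__getitem__(1)`, and so on until an `IndexError`; since `__getitem__` never raises one, the training loop can call `next()` indefinitely and `__len__` is never consulted. Second, because `batch_size` exceeds the number of samples, slicing the permutation keeps all of them, so each batch is the whole dataset in a new order. The NumPy-only stand-in below (a hypothetical `EndlessBatches` class, with `np.random.default_rng` swapped in for `torch.randperm`) illustrates both points.

```python
import numpy as np


class EndlessBatches:
    """Hypothetical NumPy stand-in for DataGenerator: __getitem__ but no __iter__."""

    def __init__(self, n, batch_size, seed=1234):
        self.n, self.batch_size, self.seed = n, batch_size, seed

    def __len__(self):
        return self.n // self.batch_size + 1   # 1 here, but iter() never calls it

    def __getitem__(self, index):
        # A new seeded permutation per index; slicing past the end keeps all n indices
        rng = np.random.default_rng(self.seed + index)
        return rng.permutation(self.n)[:self.batch_size]


it = iter(EndlessBatches(n=250, batch_size=30000))
batches = [next(it) for _ in range(5)]          # 5 > len(), yet no StopIteration
print([len(b) for b in batches])                # [250, 250, 250, 250, 250]
```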

1.2.9. Define Networks and Hyperparameters

In [ ]:
class MLP(nn.Module):
    def __init__(self, input_dim):
        super().__init__()

        self.linears = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 32),
            nn.Tanh(),
            nn.Linear(32, 32),
            nn.Tanh(),
            nn.Linear(32, 32),
            nn.Tanh(),
            nn.Linear(32, 16)
        )

    def forward(self, x):
        a = x.float()
        return self.linears(a)
In [ ]:
random_seed(seed = 2021)

dim_q = q_train[0].shape[0]

# Initialize model
branch_net = MLP(input_dim = dim_q).to(device)
trunk_net = MLP(input_dim = 1).to(device)

optimizer = optim.Adam(list(branch_net.parameters()) + list(trunk_net.parameters()), lr = 5e-4)
Loss_func = nn.MSELoss(reduction='mean')

1.2.10. Define Utility Functions

In [ ]:
def DeepOnet(model, con_, X_):
    '''
    model = [branch_net, trunk_net]
    '''
    branch_net = model[0]
    trunk_net = model[1]

    B = branch_net(con_)
    T = trunk_net(X_)

    u = torch.sum(B * T, dim=1).unsqueeze(1)
    return u

def derivative(y, t) :
    df = torch.autograd.grad(y, t, create_graph = True, retain_graph = True, grad_outputs = torch.ones(y.size()).to(device))[0]
    return df

def requires_grad(x):
    return torch.tensor(x, requires_grad = True).to(device)
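`DeepOnet` implements the DeepONet approximation $G_\theta(q)(x) \approx \sum_{i=1}^{p} b_i(q)\, t_i(x)$ with $p = 16$, evaluated row by row on the stacked dataset. The NumPy sketch below (random arrays standing in for the branch and trunk outputs) shows that this row-wise dot product gives the same numbers as evaluating each net once per distinct input and taking a matrix product. The unstacked form is an alternative, not what this notebook does, but it avoids pushing the same load through the branch net 25 times.

```python
import numpy as np

rng = np.random.default_rng(0)
n_loads, n_points, p = 10, 25, 16
B = rng.standard_normal((n_loads, p))     # stand-in branch outputs b(q), one row per load
T = rng.standard_normal((n_points, p))    # stand-in trunk outputs t(x), one row per point

# Stacked form, as in DeepOnet: row k pairs load k // n_points with point k % n_points
B_stacked = np.repeat(B, n_points, axis=0)                        # (250, p)
T_stacked = np.tile(T, (n_loads, 1))                              # (250, p)
u_stacked = np.sum(B_stacked * T_stacked, axis=1, keepdims=True)  # (250, 1)

# Unstacked form: each net evaluated once per distinct input, then all pairwise dot products
u_grid = B @ T.T                                                  # (n_loads, n_points)

print(np.allclose(u_stacked.reshape(n_loads, n_points), u_grid))  # True
```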
In [ ]:
'PLOT'

def PLOT(branch_test, trunk_test):
    branch_test.eval()
    trunk_test.eval()

    def exact_solution(x):
        return -(x ** 4) / 24 + x ** 3 / 6 - x ** 2 / 4

    test_q = q_test[0].reshape(1, 20)
    test_qs = np.vstack([test_q] * 25)
    test_XYs_domain = np.linspace(0, 1, 25).reshape(-1, 1)

    test_XYs_domain = requires_grad(test_XYs_domain)
    test_qs = requires_grad(test_qs)

    'DeepONet'
    B_test = branch_test(test_qs)
    T_test = trunk_test(test_XYs_domain)
    test_u = torch.sum(B_test * T_test, dim=1).unsqueeze(1)

    'Plot'
    x_ = np.linspace(0, l, len(test_u))
    plt.figure(figsize = (6, 4))
    plt.plot(x_, test_u.detach().cpu().numpy(), c =  'r', label = 'Predict u', linestyle = 'dashed', linewidth = 2)
    plt.plot(x_, exact_solution(x_), c =  'k', label = 'Exact Solution')
    plt.legend(fontsize = 10)
    plt.xticks(fontsize = 10)
    plt.yticks(fontsize = 10)
    plt.xlabel('x (m)', fontsize = 10)
    plt.ylabel('Displacement (m)', fontsize = 10)
    plt.title('Test Result', fontsize = 15)
    plt.show()
#     clear_output(wait=True)

1.2.11. Train

In [ ]:
num_epochs = 2000
epoch = 0
In [ ]:
train_iter = iter(train_dataset)

while epoch < num_epochs + 1:

    train_batch = next(train_iter)
    con_domain, XY_domain, gt_domain = train_batch

    ############### Requires grad #################
    con_domain, XY_domain, gt_domain = requires_grad(con_domain), requires_grad(XY_domain), requires_grad(gt_domain)


    model = [branch_net, trunk_net]
    T_domain = DeepOnet(model, con_domain, XY_domain)


    ################# Data loss ###################
    loss_data = Loss_func(T_domain.float(), gt_domain.float().to(device))

    loss = loss_data

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # if epoch % 100 == 0:
    #     with torch.no_grad():
    #         print('Epoch: {} Loss: {:.6f} '.format(epoch, loss.item()))

    if epoch % 500 == 0:
        with torch.no_grad():
            print('Epoch: {} Loss: {:.6f} '.format(epoch, loss.item()))
            PLOT(branch_net, trunk_net)

    epoch += 1
Epoch: 0 Loss: 0.010332 
Epoch: 500 Loss: 0.000073 
Epoch: 1000 Loss: 0.000017 
Epoch: 1500 Loss: 0.000006 
Epoch: 2000 Loss: 0.000005 

1.3. Solve the Euler Beam Problem with Data- and Physics-Guided DeepONet

  • Problem properties

$$E = 1 \operatorname{Pa}, \quad I = 1 \operatorname{m^4}, \quad l = 1 \operatorname{m}, \quad q(x) = wx + w_0$$

$$\quad \text{where} \quad w \in [-2,2] \quad \text{and} \quad w_0 \in [0,2] $$


  • One Dirichlet boundary condition on the left boundary:

$$y(0) = 0$$


  • One Neumann boundary condition on the left boundary:

$$y'(0) = 0$$


  • Two free-end boundary conditions on the right boundary (zero bending moment and zero shear force):

$$y''(1) = 0, \quad y'''(1) = 0$$


  • The exact solution for the test case $q(x) = 1$ (i.e. $w = 0$, $w_0 = 1$) is

$$y(x) = -{1 \over 24}x^4 + {1 \over 6}x^3 - {1 \over 4}x^2$$


  • The network is the same DeepONet as in Section 1.2; the loss adds the four boundary-condition residuals to the data loss.

1.3.1. Define Collocation Points

In [ ]:
'Domain collocation points'
domain_points = np.linspace(0, 1, 25).reshape(-1, 1)

'BC collocation points'
bc_points_x0 = (np.zeros(1)).reshape(-1, 1)
bc_points_x1 = (np.ones(1)).reshape(-1, 1)
In [ ]:
'Plot'

y_init = np.zeros(25)

plt.scatter(domain_points, y_init, s = 30)
plt.scatter(bc_points_x0, 0, color = 'r', s = 30)
plt.scatter(bc_points_x1, 0, color = 'r', s = 30)
plt.xlabel('x')
plt.title('Collocation points')
plt.show()

1.3.2. Generate Dataset

Boundary Condition $x_0$

In [ ]:
num_BC = bc_points_x0.shape[0]     # (1, _)
num_condition = q_train.shape[0]   # (10, _)

condition_list = []
XY_bc0_list = []

for i in range(num_condition):
    con = q_train[i]
    conditions = np.vstack([con] * num_BC)
    bc_x0 = bc_points_x0

    condition_list.append(conditions)
    XY_bc0_list.append(bc_x0)

conditions_bc_x0 = np.vstack(condition_list)
XYs_bc_x0 = np.vstack(XY_bc0_list)

print('Conditions BC0: {}'.format(conditions_bc_x0.shape))
print('XYs BC0:        {}'.format(XYs_bc_x0.shape))
Conditions BC0: (10, 20)
XYs BC0:        (10, 1)

Boundary Condition $x_1$

In [ ]:
num_BC = bc_points_x1.shape[0]     # (1, _)
num_condition = q_train.shape[0]   # (10, _)

condition_list = []
XY_bc1_list = []

for i in range(num_condition):
    con = q_train[i]
    conditions = np.vstack([con] * num_BC)
    bc_x1 = bc_points_x1

    condition_list.append(conditions)
    XY_bc1_list.append(bc_x1)

conditions_bc_x1 = np.vstack(condition_list)
XYs_bc_x1 = np.vstack(XY_bc1_list)

print('Conditions BC1: {}'.format(conditions_bc_x1.shape))
print('Xs BC1:         {}'.format(XYs_bc_x1.shape))
Conditions BC1: (10, 20)
Xs BC1:         (10, 1)

1.3.3. Data Generator

In [ ]:
'BC generator'
class BCGenerator(data.Dataset):
    def __init__(self, c, XY, batch_size=64, rng_seed=1234):
        'Initialization'
        self.c = torch.tensor(c, dtype = torch.float32)
        self.XY = torch.tensor(XY, dtype = torch.float32)
        self.N = c.shape[0]
        self.batch_size = batch_size
        self.seed = rng_seed

    def __len__(self):
        return (self.N // self.batch_size) + 1

    def __getitem__(self, index):
        c, XY = self.__data_generation(index)
        return c, XY

    def __data_generation(self, index):
        torch.manual_seed(self.seed + index)

        idx = torch.randperm(self.N)[:self.batch_size]

        c = self.c[idx, :]
        XY = self.XY[idx, :]

        return c, XY
In [ ]:
batch_size = 30000
train_bc_x0_dataset = BCGenerator(conditions_bc_x0, XYs_bc_x0, batch_size)
train_bc_x1_dataset = BCGenerator(conditions_bc_x1, XYs_bc_x1, batch_size)

1.3.4. Define Networks and Hyper-parameter

In [ ]:
class MLP(nn.Module):
    def __init__(self, input_dim):
        super().__init__()

        self.linears = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 32),
            nn.Tanh(),
            nn.Linear(32, 32),
            nn.Tanh(),
            nn.Linear(32, 32),
            nn.Tanh(),
            nn.Linear(32, 16)
        )

    def forward(self, x):
        a = x.float()
        return self.linears(a)
In [ ]:
random_seed(seed = 5)

dim_q = q_train[0].shape[0]

# Initialize model
branch_net = MLP(input_dim = dim_q).to(device)
trunk_net = MLP(input_dim = 1).to(device)

optimizer = optim.Adam(list(branch_net.parameters()) + list(trunk_net.parameters()), lr = 5e-4)
Loss_func = nn.MSELoss(reduction='mean')

1.3.5. Define Utility Functions

In [ ]:
def PIDeepOnet(model, con_, X_):
    '''
    model = [branch_net, trunk_net]
    '''
    branch_net = model[0]
    trunk_net = model[1]

    B = branch_net(con_)
    T = trunk_net(X_)

    u = torch.sum(B * T, dim=1).unsqueeze(1)
    return u

def derivative(y, t) :
    df = torch.autograd.grad(y, t, create_graph = True,retain_graph = True, grad_outputs = torch.ones(y.size()).to(device))[0]
    return df

def requires_grad(x):
    return torch.tensor(x, requires_grad = True).to(device)
In [ ]:
'PLOT'
def PLOT(branch_test, trunk_test):
    branch_test.eval()
    trunk_test.eval()

    def exact_solution(x):
        return -(x ** 4) / 24 + x ** 3 / 6 - x ** 2 / 4

    test_q = q_test[0].reshape(1, 20)
    test_qs = np.vstack([test_q] * 25)
    test_XYs_domain = np.linspace(0, 1, 25).reshape(-1, 1)

    test_XYs_domain = requires_grad(test_XYs_domain)
    test_qs = requires_grad(test_qs)

    'DeepONet'
    B_test = branch_test(test_qs)
    T_test = trunk_test(test_XYs_domain)
    test_u = torch.sum(B_test * T_test, dim=1).unsqueeze(1)

    'Plot'
    x_ = np.linspace(0, l, len(test_u))
    plt.figure(figsize = (6, 4))
    plt.plot(x_, test_u.detach().cpu().numpy(), c = 'r', label = 'Predict u', linestyle = 'dashed', linewidth = 2)
    plt.plot(x_, exact_solution(x_), c = 'k', label = 'Exact Solution')
    plt.legend(fontsize = 10)
    plt.xticks(fontsize = 10)
    plt.yticks(fontsize = 10)
    plt.xlabel('x (m)', fontsize = 10)
    plt.ylabel('Displacement (m)', fontsize = 10)
    plt.title('Test Result', fontsize = 15)
    plt.show()
#     clear_output(wait=True)

1.3.6. Define Boundary Conditions

  • One Dirichlet boundary condition on the left boundary:

$$y(0) = 0$$

  • One Neumann boundary condition on the left boundary:

$$y'(0) = 0$$

  • Two boundary conditions on the right boundary:

$$y''(1) = 0, \quad y'''(1) = 0$$

In [ ]:
def BC_x0(model, c_x0, XY_x0):

    u_0 = PIDeepOnet(model, c_x0, XY_x0)
    u_x_0 = derivative(u_0, XY_x0)

    return u_0.float(), u_x_0.float()


def BC_x1(model, c_x1, XY_x1):

    u_1 = PIDeepOnet(model, c_x1, XY_x1)
    u_x_1 = derivative(u_1, XY_x1)
    u_xx_1 = derivative(u_x_1, XY_x1)
    u_xxx_1 = derivative(u_xx_1, XY_x1)

    return  u_xx_1.float(), u_xxx_1.float()
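Written out, the objective minimized in the training loop that follows is (our reading of the code: every term is a mean-squared error with unit weight, $N = 250$ stacked domain samples, and $M = 10$ boundary samples, one per training load at each end):

$$ \mathcal{L}(\theta) = \frac{1}{N}\sum_{k=1}^{N}\big(G_\theta(q_k)(x_k) - y_k\big)^2 + \frac{1}{M}\sum_{j=1}^{M}\Big[G_\theta(q_j)(0)^2 + \big(\partial_x G_\theta(q_j)(0)\big)^2 + \big(\partial_x^2 G_\theta(q_j)(1)\big)^2 + \big(\partial_x^3 G_\theta(q_j)(1)\big)^2\Big] $$

where $G_\theta(q)(x) = \sum_{i=1}^{16} b_i(q)\,t_i(x)$ is the branch-trunk inner product computed by `PIDeepOnet`, and the $x$-derivatives come from repeated calls to `derivative`, i.e. automatic differentiation with respect to the trunk input. Because each boundary MSE is a mean over the same $M$ samples, their sum collapses into the single bracketed average.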

1.3.7. Train

In [ ]:
num_epochs = 2000

epoch = 0
In [ ]:
train_iter = iter(train_dataset)
train_bc_x0_iter = iter(train_bc_x0_dataset)
train_bc_x1_iter = iter(train_bc_x1_dataset)

while epoch < num_epochs + 1:

    train_batch = next(train_iter)
    train_bc_x0_batch = next(train_bc_x0_iter)
    train_bc_x1_batch = next(train_bc_x1_iter)

    con_domain, XY_domain, gt_domain = train_batch
    con_x0, XY_x0 = train_bc_x0_batch
    con_x1, XY_x1 = train_bc_x1_batch

    ############### Requires grad #################
    'Domain'
    con_domain, XY_domain, gt_domain = requires_grad(con_domain), requires_grad(XY_domain), requires_grad(gt_domain)
    'BC'
    con_x0, XY_x0 = requires_grad(con_x0), requires_grad(XY_x0)
    con_x1, XY_x1 = requires_grad(con_x1), requires_grad(XY_x1)


    model = [branch_net, trunk_net]
    T_domain = PIDeepOnet(model, con_domain, XY_domain)


    ################# Data Loss ####################
    loss_data = Loss_func(T_domain.float(), gt_domain.float().to(device))

    ################## BC Loss #####################
    u_x0, du_x0 = BC_x0(model, con_x0, XY_x0)
    ddu_x1, dddu_x1 = BC_x1(model, con_x1, XY_x1)

    loss_BC_x0_1 = Loss_func(u_x0, torch.zeros_like(u_x0).to(device))
    loss_BC_x0_2 = Loss_func(du_x0, torch.zeros_like(du_x0).to(device))
    loss_BC_x1_1 = Loss_func(ddu_x1, torch.zeros_like(ddu_x1).to(device))
    loss_BC_x1_2 = Loss_func(dddu_x1, torch.zeros_like(dddu_x1).to(device))
    loss_bc = loss_BC_x0_1 + loss_BC_x0_2 + loss_BC_x1_1 + loss_BC_x1_2

    loss = loss_data + loss_bc

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 500 == 0:
        with torch.no_grad():
            print('Epoch: {} Loss: {:.6f} '.format(epoch, loss.item()))
            PLOT(branch_net, trunk_net)

    epoch += 1
Epoch: 0 Loss: 0.003302 
Epoch: 500 Loss: 0.000150 
Epoch: 1000 Loss: 0.000065 
Epoch: 1500 Loss: 0.000007 
Epoch: 2000 Loss: 0.000001 

2. Lab 2: Elastic Deformation of a Thin Plate (PI-DeepONet)

2.1. Problem Setup

  • We will use PI-DeepONet to solve the thin-plate equations for the displacement and stress distributions of a thin plate:
    • Physics-informed only (PDE and BCs)
  • Following Kirchhoff-Love plate theory, three hypotheses are used:

    • straight lines normal to the mid-surface remain straight after deformation
    • straight lines normal to the mid-surface remain normal to the mid-surface after deformation
    • the thickness of the plate does not change during a deformation
  • A non-uniform stretching force is applied to a square elastic plate.

  • Because the geometry and the in-plane forces are symmetric, only one quarter of the plate (yellow domain) is considered.



  • Problem properties

$$ E = 50 \operatorname{MPa}, \quad \nu = 0.3, \quad \omega = 20 \operatorname{mm}, \quad h = 1 \operatorname{mm}, \quad f \in [0.5, 1.5] \operatorname{MPa} $$
  • Governing equations (Föppl–von Kármán equations) for the isotropic elastic plate:

$$ \begin{aligned} &{E \over 1 - \nu^2}\left({\partial^2 u \over \partial x^2} + {1 - \nu \over 2}{\partial^2 u \over \partial y^2} + {1 + \nu \over 2}{\partial^2 v \over \partial x \partial y} \right) = 0\\ &{E \over 1 - \nu^2}\left({\partial^2 v \over \partial y^2} + {1 - \nu \over 2}{\partial^2 v \over \partial x^2} + {1 + \nu \over 2}{\partial^2 u \over \partial x \partial y} \right) = 0 \end{aligned} $$
  • Two Dirichlet boundary conditions at $x = 0,\, y = 0\; (B.C.①, B.C.②)$:

$$ \begin{aligned} v(x,y) &= 0 \qquad \text{at} \quad y = 0\\ u(x,y) &= 0 \qquad \text{at} \quad x = 0 \end{aligned} $$
  • Two free boundary conditions at $y = \omega / 2\; (B.C.③)$:

$$ \sigma_{yy} = 0,\quad \sigma_{yx} = 0 $$
  • Free boundary condition and in-plane force boundary condition at $x = \omega / 2\; (B.C.④)$:

$$ \sigma_{xx} = P \centerdot h,\quad \sigma_{xy} = 0 $$
  • The neural network and the loss functions are built in Section 2.2.

2.1.1. Numerical Solution

  • The numerical solution for the test case $f = 1$ is illustrated in the figures below: the $x$- and $y$-direction displacements $u$, $v$ and the stresses $\sigma_{xx}$, $\sigma_{yy}$.

  • We solve this problem with PI-DeepONet and compare the result with the numerical solution.



In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

2.2 Solve the Thin Plate Problem with PI-DeepONet

2.2.1. Import Libraries

In [ ]:
import numpy as np
import math
import matplotlib.pyplot as plt
import random
import os
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
%matplotlib inline

2.2.2. CUDA

In [ ]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
cuda:0

2.2.3. Random Seed

In [ ]:
def random_seed(seed = 5):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

2.2.4. Define Properties

In [ ]:
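# Properties (units: MPa and mm)
E = 50      # Young's modulus [MPa]
nu = 0.3    # Poisson's ratio
a = 10
w = 10      # half-width of the plate, omega / 2 [mm]; the quarter domain is [0, w] x [0, w]
h = 1       # thickness [mm]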

2.2.5. Generate Collocation Points

In [ ]:
Nx = 50                                           # Number of grid points along x
Ny = 50                                           # Number of grid points along y
x = np.linspace(0, w, Nx)                         # x coordinates, shape (Nx,)
y = np.linspace(0, w, Ny)                         # y coordinates, shape (Ny,)

xy = np.meshgrid(x, y)
xy_domain = np.concatenate((xy[0].reshape(-1, 1), xy[1].reshape(-1, 1)), 1)

xy_left = xy_domain[xy_domain[:, 0] == 0]
xy_right = xy_domain[xy_domain[: ,0] == w]
xy_top = xy_domain[xy_domain[:, 1] == w]
xy_bottom = xy_domain[xy_domain[:, 1] == 0]

plt.figure(figsize=(5, 5))
plt.scatter(xy_domain[:, 0], xy_domain[:, 1], s=10, color='gray', label='Interior')
plt.scatter(xy_bottom[:, 0], xy_bottom[:, 1], s=10, color='red', label='Bottom')
plt.scatter(xy_top[:, 0], xy_top[:, 1], s=10, color='blue', label='Top')
plt.scatter(xy_left[:, 0], xy_left[:, 1], s=10, color='green', label='Left')
plt.scatter(xy_right[:, 0], xy_right[:, 1], s=10, color='purple', label='Right')

plt.legend()
plt.title('Collocation Points')
plt.xlabel('X')
plt.ylabel('Y')
plt.grid(True)
plt.show()

2.2.6. Load $P(y)$

In [ ]:
# Define Condition parameter
full_f = [0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5]
f_train= [0.5, 0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5]
f_test = [1]

data_dir = '/content/drive/MyDrive/tutorials/기계인공지능연구회/본부학술대회_241106/DeepONet/data/thin_plate'
P_train = np.load(os.path.join(data_dir, 'tp_P_train.npy'))
P_test = np.load(os.path.join(data_dir, 'tp_P_test.npy'))

print('Loading Conditions for Train: {}'.format(P_train.shape))
print('Loading Conditions for Test:  {}'.format(P_test.shape))
Loading Conditions for Train: (10, 250)
Loading Conditions for Test:  (1, 250)

Plot

In [ ]:
x_cor = np.linspace(0, 10, 250)
plt.plot(x_cor, P_test[0], color = 'red', label = 'test (f=1)', linewidth = 2, marker= 'o', markevery = (0, 10))
plt.plot(x_cor, P_train[3], color = 'blue', label = 'train' ,linewidth = 2, marker= 'o', markevery = (0, 10))
plt.plot(x_cor, P_train[6], color = 'blue', linewidth = 2, marker= 'o', markevery = (0, 10))
plt.plot(x_cor, P_train[9], color = 'blue', linewidth = 2, marker= 'o', markevery = (0, 10))
plt.xlabel('y')
plt.ylabel('P')
plt.ylim([0,1.7])
plt.title('P Plot')
plt.grid(True)
plt.legend()
plt.show()

2.2.7. Generate Dataset

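  • Build the datasets by pairing every loading condition with every collocation point: the 10 training loads times the 2,500 domain points give 25,000 domain samples, and the 10 loads times the 50 points on each edge give 500 samples per boundary.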

Domain

In [ ]:
num_domain = xy_domain.shape[0]  # (2500,_)
num_condition = P_train.shape[0]  #(10,_)

condition_list = []
XY_domain_list = []

for i in range(num_condition):
    con = P_train[i]
    conditions = np.vstack([con] * num_domain)
    domains = xy_domain

    condition_list.append(conditions)
    XY_domain_list.append(domains)

conditions_domain = np.vstack(condition_list)
XYs_domain = np.vstack(XY_domain_list)

print('Conditions Domain: {}'.format(conditions_domain.shape))
print('XYs Domain:        {}'.format(XYs_domain.shape))
Conditions Domain: (25000, 250)
XYs Domain:        (25000, 2)

Boundary Conditions

Boundary Condition $x_0$

In [ ]:
num_BC = xy_left.shape[0]    # (50,_)
num_condition = P_train.shape[0]  # (10,_)

condition_list = []
XY_left_list = []

for i in range(num_condition):
    con = P_train[i]
    conditions = np.vstack([con] * num_BC)
    XYs_bc_left = xy_left

    condition_list.append(conditions)
    XY_left_list.append(XYs_bc_left)

conditions_left = np.vstack(condition_list)
XYs_left = np.vstack(XY_left_list)

print('Conditions BC x0: {}'.format(conditions_left.shape))
print('XYs BC x0:        {}'.format(XYs_left.shape))
Conditions BC x0: (500, 250)
XYs BC x0:        (500, 2)

Boundary Condition $x_1$

In [ ]:
num_BC = xy_right.shape[0]    # (50,_)
num_condition = P_train.shape[0]  # (10,_)

condition_list = []
XY_right_list = []

for i in range(num_condition):
    con = P_train[i]
    conditions = np.vstack([con] * num_BC)
    XYs_bc_right = xy_right

    condition_list.append(conditions)
    XY_right_list.append(XYs_bc_right)

conditions_right = np.vstack(condition_list)
XYs_right = np.vstack(XY_right_list)

print('Conditions BC x1: {}'.format(conditions_right.shape))
print('XYs BC x1:        {}'.format(XYs_right.shape))
Conditions BC x1: (500, 250)
XYs BC x1:        (500, 2)

Boundary Condition $y_0$

In [ ]:
num_BC = xy_bottom.shape[0]    # (50,_)
num_condition = P_train.shape[0]  # (10,_)

condition_list = []
XY_bottom_list = []

for i in range(num_condition):
    con = P_train[i]
    conditions = np.vstack([con] * num_BC)
    XYs_bc_bottom = xy_bottom

    condition_list.append(conditions)
    XY_bottom_list.append(XYs_bc_bottom)

conditions_bottom = np.vstack(condition_list)
XYs_bottom = np.vstack(XY_bottom_list)

print('Conditions BC y0: {}'.format(conditions_bottom.shape))
print('XYs BC y0:        {}'.format(XYs_bottom.shape))
Conditions BC y0: (500, 250)
XYs BC y0:        (500, 2)

Boundary Condition $y_1$

In [ ]:
num_BC = xy_top.shape[0]    # (50,_)
num_condition = P_train.shape[0]  # (10,_)

condition_list = []
XY_top_list = []

for i in range(num_condition):
    con = P_train[i]
    conditions = np.vstack([con] * num_BC)
    XYs_bc_top = xy_top

    condition_list.append(conditions)
    XY_top_list.append(XYs_bc_top)

conditions_top = np.vstack(condition_list)
XYs_top = np.vstack(XY_top_list)

print('Conditions BC y1: {}'.format(conditions_top.shape))
print('XYs BC y1:        {}'.format(XYs_top.shape))
Conditions BC y1: (500, 250)
XYs BC y1:        (500, 2)
In [ ]:
# Data generator
class DataGenerator(data.Dataset):
    def __init__(self, c, XY, batch_size=64, rng_seed=1234):
        'Initialization'
        self.c = torch.tensor(c, dtype = torch.float32)
        self.XY = torch.tensor(XY, dtype = torch.float32)
        self.N = c.shape[0]
        self.batch_size = batch_size
        self.seed = rng_seed

    def __len__(self):
        return (self.N // self.batch_size) + 1

    def __getitem__(self, index):
        c, XY = self.__data_generation(index)
        return c, XY

    def __data_generation(self, index):
        torch.manual_seed(self.seed + index)

        idx = torch.randperm(self.N)[:self.batch_size]

        c = self.c[idx, :]
        XY = self.XY[idx, :]

        return c, XY
In [ ]:
batch_size = 30000   # exceeds every set size (at most 25,000 samples), so each batch is the whole set, reshuffled
train_domain = DataGenerator(conditions_domain, XYs_domain, batch_size)
train_left = DataGenerator(conditions_left, XYs_left, batch_size)
train_right = DataGenerator(conditions_right, XYs_right, batch_size)
train_bottom = DataGenerator(conditions_bottom, XYs_bottom, batch_size)
train_top = DataGenerator(conditions_top, XYs_top, batch_size)

2.2.8. Define Neural Networks and Hyperparameters

In [ ]:
class MLP(nn.Module):
    def __init__(self, input_dim):
        super().__init__()

        self.linears = nn.Sequential(
            nn.Linear(input_dim, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        a = x.float()
        return self.linears(a)
In [ ]:
random_seed()

# Initialize model
dim_P = P_train[0].shape[0]

branch_net_u = MLP(input_dim = dim_P).to(device)
branch_net_v = MLP(input_dim = dim_P).to(device)
trunk_net_u = MLP(input_dim = 2).to(device)
trunk_net_v = MLP(input_dim = 2).to(device)

optimizer = optim.Adam(
    list(branch_net_u.parameters()) + list(branch_net_v.parameters()) +
    list(trunk_net_u.parameters()) + list(trunk_net_v.parameters()) ,lr=5e-4)
Loss_func = nn.MSELoss(reduction='mean')

2.2.9. Define Utility Functions

In [ ]:
def PIDeepOnet(model, con_, X_):
    '''
    model = [branch_net_u, branch_net_v, trunk_net_u, trunk_net_v]
    '''

    branch_net_u = model[0]
    branch_net_v = model[1]
    trunk_net_u = model[2]
    trunk_net_v = model[3]

    B1 = branch_net_u(con_)
    B2 = branch_net_v(con_)
    T1 = trunk_net_u(X_)
    T2 = trunk_net_v(X_)

    u = torch.sum(B1 * T1, dim=1).unsqueeze(1)
    v = torch.sum(B2 * T2, dim=1).unsqueeze(1)

    return u, v

def derivative(y, t) :
    df = torch.autograd.grad(y, t, create_graph=True,retain_graph = True, grad_outputs=torch.ones(y.size()).to(device))[0]
    df_x = df[:, 0:1]
    df_y = df[:, 1:2]
    return df_x, df_y

def requires_grad(x):
    return torch.tensor(x,requires_grad=True).to(device)
In [ ]:
'PLOT'
from IPython.display import clear_output

def check_stress(model, test_XYs, test_Ps):

    branch_net_u, branch_net_v, trunk_net_u, trunk_net_v = model[0], model[1], model[2], model[3]
    B1_test = branch_net_u(test_Ps)
    B2_test = branch_net_v(test_Ps)
    T1_test = trunk_net_u(test_XYs)
    T2_test = trunk_net_v(test_XYs)

    u = torch.sum(B1_test * T1_test, dim=1).unsqueeze(1)
    v = torch.sum(B2_test * T2_test, dim=1).unsqueeze(1)

    du_x, _ = derivative(u, test_XYs)
    _, dv_y = derivative(v, test_XYs)


    sig_xx = (du_x + nu * dv_y) * E / (1 - nu**2)
    sig_yy = (dv_y + nu * du_x) * E / (1 - nu**2)

    return sig_xx, sig_yy


def PLOT(branch_net_u, branch_net_v, trunk_net_u, trunk_net_v):
    branch_net_u.eval()
    branch_net_v.eval()
    trunk_net_u.eval()
    trunk_net_v.eval()

    test_points = xy_domain.copy()
    num_points = len(test_points)
    test_Ps = np.vstack([P_test]*num_points)
    test_XYs = test_points

    test_Ps = requires_grad(test_Ps)
    test_XYs = requires_grad(test_XYs)

    color_legend = [[0, 0.182], [-0.06, 0.011], [-0.0022,1.0], [-0.15, 0.45]]
    title = ['x-displacement ($u$)', 'y-displacement ($v$)', '$\sigma_{xx}$', '$\sigma_{yy}$']


    # test_Ps and test_XYs are already tensors on `device`, so pass them in directly
    B1_test = branch_net_u(test_Ps)
    B2_test = branch_net_v(test_Ps)
    T1_test = trunk_net_u(test_XYs)
    T2_test = trunk_net_v(test_XYs)

    test_u = torch.sum(B1_test*T1_test, dim=1).unsqueeze(1)
    test_v = torch.sum(B2_test*T2_test, dim=1).unsqueeze(1)

    test_sig_x, test_sig_y = check_stress([branch_net_u, branch_net_v, trunk_net_u,trunk_net_v], test_XYs, test_Ps)

    test_results = torch.cat((test_u, test_v, test_sig_x, test_sig_y), dim=1)

    plt.figure(figsize=(6, 5))
    for idx in range(4):
        plt.subplot(2,2,idx+1)
        plt.scatter(test_points[:,0],
                    test_points[:,1],
                    c = test_results.detach().cpu().numpy()[:, idx],
                    cmap = 'rainbow',
                    s = 7)
        plt.clim(color_legend[idx])
        plt.title(title[idx], fontsize = 13)
        plt.xlabel('x (mm)', fontsize = 12)
        plt.ylabel('y (mm)', fontsize = 12)
        plt.axis('square')
        plt.xlim(0, 10)
        plt.ylim(0, 10)
        plt.colorbar()
    plt.tight_layout()
#     plt.savefig('./images/{}.png'.format(epoch))
    plt.show()
    clear_output(wait=True)

    return test_u, test_v, test_sig_x, test_sig_y

2.2.10. Define PDE with Boundary Conditions

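  • Governing equations for the isotropic elastic plate (plane-stress equilibrium in terms of the displacements $u$, $v$; this is the in-plane part of the Föppl–von Kármán equations with no out-of-plane deflection):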

$$ \begin{align*} &{E \over 1 - \nu^2}\left({\partial^2 u \over \partial x^2} + {1 - \nu \over 2}{\partial^2 u \over \partial y^2} + {1 + \nu \over 2}{\partial^2 v \over \partial x \partial y} \right) = 0\\\\ &{E \over 1 - \nu^2}\left({\partial^2 v \over \partial y^2} + {1 - \nu \over 2}{\partial^2 v \over \partial x^2} + {1 + \nu \over 2}{\partial^2 u \over \partial x \partial y} \right) = 0 \end{align*} $$
In [ ]:
def PDE(model, c_domain, XY_domain):

    u, v = PIDeepOnet(model, c_domain, XY_domain)

    du_x, du_y = derivative(u, XY_domain)
    du_xx, du_xy = derivative(du_x, XY_domain)
    _, du_yy =  derivative(du_y, XY_domain)

    dv_x, dv_y = derivative(v, XY_domain)
    dv_xx, dv_xy = derivative(dv_x, XY_domain)
    _, dv_yy = derivative(dv_y, XY_domain)

    force_eq_x = (du_xx + 0.5 * (1 - nu) * du_yy + 0.5 * (1 + nu) * dv_xy) * E / (1 - nu**2)
    force_eq_y = (dv_yy + 0.5 * (1 - nu) * dv_xx + 0.5 * (1 + nu) * du_xy) * E / (1 - nu**2)

    return force_eq_x.float(), force_eq_y.float()
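`PDE` evaluates the residual with autograd. As an independent sanity check of the formula (not part of the original notebook), the sketch below evaluates the same expressions with central finite differences for hand-picked trial fields. The values `E = 1` and `nu = 0.3` are illustrative assumptions. For $u = x^2$, $v = xy$, the equations above give $\frac{E}{1-\nu^2}\left(2 + \frac{1+\nu}{2}\right)$ for the $x$-residual and $0$ for the $y$-residual. Any linear displacement field gives a zero residual.

```python
def navier_residual(u, v, x, y, E, nu, h=1e-3):
    """Finite-difference version of force_eq_x / force_eq_y from PDE().

    u, v are callables u(x, y), v(x, y) describing a trial displacement field.
    """
    def f_xx(f):
        return (f(x + h, y) - 2 * f(x, y) + f(x - h, y)) / h**2

    def f_yy(f):
        return (f(x, y + h) - 2 * f(x, y) + f(x, y - h)) / h**2

    def f_xy(f):
        return (f(x + h, y + h) - f(x + h, y - h)
                - f(x - h, y + h) + f(x - h, y - h)) / (4 * h**2)

    c = E / (1 - nu**2)
    force_x = c * (f_xx(u) + 0.5 * (1 - nu) * f_yy(u) + 0.5 * (1 + nu) * f_xy(v))
    force_y = c * (f_yy(v) + 0.5 * (1 - nu) * f_xx(v) + 0.5 * (1 + nu) * f_xy(u))
    return force_x, force_y

E, nu = 1.0, 0.3   # illustrative material constants

# Quadratic trial field u = x^2, v = x*y, evaluated at (0.7, 0.4)
fx, fy = navier_residual(lambda x, y: x**2, lambda x, y: x * y, 0.7, 0.4, E, nu)
expected_fx = E / (1 - nu**2) * (2 + 0.5 * (1 + nu))
print(fx, expected_fx, fy)

# Linear trial field: every second derivative vanishes, so both residuals are ~0
gx, gy = navier_residual(lambda x, y: 0.1 * x + 0.2 * y,
                         lambda x, y: -0.05 * x, 0.7, 0.4, E, nu)
print(gx, gy)
```

Central differences are exact for these low-order polynomials up to round-off. The check therefore isolates the algebra of the residual from the discretization error.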
  • Two Dirichlet boundary conditions at $x = 0,\, y = 0\; (B.C.①, B.C.②)$:

$$ v(x,y) = 0 \qquad \text{at} \quad y = 0\\\\ u(x,y) = 0 \qquad \text{at} \quad x = 0 $$
  • Two free boundary conditions at $y = \omega / 2\; (B.C.③)$:

$$ \sigma_{yy} = 0,\quad \sigma_{yx} = 0 $$
  • Free boundary condition and in-plane force boundary condition at $x = \omega / 2\; (B.C.④)$:

$$ \sigma_{xx} = P \centerdot h,\quad \sigma_{xy} = 0 $$
In [ ]:
import math  # used for math.pi in BC_right

def BC_bottom(model, con_bottom, XY_bottom):
    _, v = PIDeepOnet(model, con_bottom, XY_bottom)
    return v.float()

def BC_left(model, con_left, XY_left):
    u, _ = PIDeepOnet(model, con_left, XY_left)
    return u.float()

def BC_top(model, con_top, XY_top):
    u, v = PIDeepOnet(model, con_top, XY_top)

    du_x, du_y = derivative(u, XY_top)
    dv_x, dv_y = derivative(v, XY_top)

    sig_xy = (dv_x + du_y) * E / (1 + nu) / 2
    sig_yy = (dv_y + nu * du_x) * E / (1 - nu**2)

    return sig_xy.float(), sig_yy.float()

def BC_right(model, con_right, XY_right):
    u, v = PIDeepOnet(model, con_right, XY_right)

    du_x, du_y = derivative(u, XY_right)
    dv_x, dv_y = derivative(v, XY_right)

    sig_xx = (du_x + nu * dv_y) * E / (1 - nu**2)
    sig_xy = (dv_x + du_y) * E / (1 + nu) / 2
    sig_ex = con_right[:, 0:1] * h * torch.cos(math.pi / (2 * w) * XY_right[:, 1:2]).reshape(-1, 1)

    sig_xx_ex = sig_xx.float() - sig_ex.float()

    return sig_xx_ex.float(), sig_xy.float()
In [ ]:
epoch = 0
num_epochs = 50000
min_loss = np.inf

2.2.12. Train

In [ ]:
train_domain_iter= iter(train_domain)
train_left_iter= iter(train_left)
train_right_iter= iter(train_right)
train_bottom_iter= iter(train_bottom)
train_top_iter= iter(train_top)

while epoch < num_epochs + 1:

    train_domain_batch = next(train_domain_iter)
    train_bottom_batch = next(train_bottom_iter)
    train_left_batch = next(train_left_iter)
    train_top_batch = next(train_top_iter)
    train_right_batch = next(train_right_iter)

    con_domain, XY_domain = train_domain_batch
    con_bottom, XY_bottom = train_bottom_batch
    con_left, XY_left = train_left_batch
    con_top, XY_top = train_top_batch
    con_right, XY_right = train_right_batch

    ############### Requires grad #################
    'Domain'
    con_domain, XY_domain = requires_grad(con_domain), requires_grad(XY_domain)
    'BC'
    con_left, XY_left = requires_grad(con_left), requires_grad(XY_left)
    con_right, XY_right = requires_grad(con_right), requires_grad(XY_right)
    con_bottom, XY_bottom = requires_grad(con_bottom), requires_grad(XY_bottom)
    con_top, XY_top = requires_grad(con_top), requires_grad(XY_top)

    model = [branch_net_u, branch_net_v, trunk_net_u, trunk_net_v]

    ################# PDE Loss #####################
    force_eq_x, force_eq_y = PDE(model, con_domain, XY_domain)
    loss_PDE_x = Loss_func(force_eq_x, torch.zeros_like(force_eq_x).to(device))
    loss_PDE_y = Loss_func(force_eq_y, torch.zeros_like(force_eq_y).to(device))
    loss_PDE = loss_PDE_x + loss_PDE_y

    ################## BC Loss #####################
    bottom_v = BC_bottom(model, con_bottom, XY_bottom)
    left_u = BC_left(model, con_left, XY_left)
    top_sig_xy, top_sig_yy = BC_top(model, con_top, XY_top)
    right_sig_xx_ex, right_sig_xy = BC_right(model, con_right, XY_right)

    loss_BC_bottom = Loss_func(bottom_v,torch.zeros_like(bottom_v)).to(device)

    loss_BC_left = Loss_func(left_u, torch.zeros_like(left_u)).to(device)

    loss_BC_top_1 = Loss_func(top_sig_xy, torch.zeros_like(top_sig_xy)).to(device)
    loss_BC_top_2 = Loss_func(top_sig_yy, torch.zeros_like(top_sig_yy)).to(device)
    loss_BC_top = loss_BC_top_1 + loss_BC_top_2

    loss_BC_right_1 = Loss_func(right_sig_xx_ex, torch.zeros_like(right_sig_xx_ex)).to(device)
    loss_BC_right_2 = Loss_func(right_sig_xy, torch.zeros_like(right_sig_xy)).to(device)
    loss_BC_right = loss_BC_right_1 + loss_BC_right_2

    loss_BC = loss_BC_left + loss_BC_bottom + loss_BC_right + loss_BC_top

    loss = loss_PDE + loss_BC

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print('Epoch: {} Loss: {:.6f} PDELoss: {:.6f} BCLoss: {:.6f}'.format(epoch, loss.item(), loss_PDE.item(), loss_BC.item()))
        PLOT(branch_net_u, branch_net_v, trunk_net_u, trunk_net_v)

    epoch += 1

2.2.13. Test Pretrained Model

In [ ]:
weight_dir = '/content/drive/MyDrive/tutorials/기계인공지능연구회/본부학술대회_241106/DeepONet/weight/thin_plate/'
# The .pt files hold whole nn.Module objects, so PyTorch >= 2.6 needs weights_only=False to unpickle them;
# map_location lets models saved on a GPU load on a CPU-only runtime as well
branch_u_trained = torch.load(weight_dir + 'tp_branch_net_u.pt', map_location=device, weights_only=False)
branch_v_trained = torch.load(weight_dir + 'tp_branch_net_v.pt', map_location=device, weights_only=False)
trunk_u_trained = torch.load(weight_dir + 'tp_trunk_net_u.pt', map_location=device, weights_only=False)
trunk_v_trained = torch.load(weight_dir + 'tp_trunk_net_v.pt', map_location=device, weights_only=False)
test_u, test_v, test_sig_x, test_sig_y = PLOT(branch_u_trained, branch_v_trained, trunk_u_trained, trunk_v_trained)
In [ ]:
Plate_data = np.load('/content/drive/MyDrive/tutorials/기계인공지능연구회/본부학술대회_241106/DeepONet/data/thin_plate/Plate_data.npy')
loc = Plate_data[:, 0:2]
u = Plate_data[:, 2:3]
v = Plate_data[:, 3:4]
stress = Plate_data[:, 4:6]
In [ ]:
def RESULT(model_test, con_test, XY_test):

    pde_u, pde_v = PIDeepOnet(model_test, con_test, XY_test)
    pde_disp = np.hstack([pde_u.cpu().detach().numpy(), pde_v.cpu().detach().numpy()])
    sig_x, sig_y = check_stress(model_test, XY_test, con_test)
    pde_sig = np.hstack([sig_x.cpu().detach().numpy(), sig_y.cpu().detach().numpy()])
    return pde_disp, pde_sig

diag_ = [i for i in range(u.shape[0]) if loc[i, 0] + loc[i, 1] == 10]
diag_x = np.linspace(0, 10, 101).reshape(-1, 1)
diag_y = -diag_x + 10
diag = np.concatenate((diag_x, diag_y), 1)
XY_test = diag
con_test = np.vstack([P_test] * XY_test.shape[0])

con_test, XY_test = requires_grad(con_test), requires_grad(XY_test)

model_test = [branch_u_trained, branch_v_trained, trunk_u_trained, trunk_v_trained]

results = {
    "FEM": [u[diag_], v[diag_], stress[diag_, 0], stress[diag_, 1]],
    "PI-DeepONet": RESULT(model_test, con_test, XY_test),
}

for key in ["PI-DeepONet"]:
    disp, sig = results[key]
    results[key] = [disp[:, 0], disp[:, 1], sig[:, 0], sig[:, 1]]

titles = ['x-displacement ($u$)', 'y-displacement ($v$)', r'$\sigma_{xx}$', r'$\sigma_{yy}$']
line_styles = {'FEM': 'k', 'PI-DeepONet': '--'}

plt.figure(figsize=(10, 9))

for idx, title in enumerate(titles):
    plt.subplot(2, 2, idx + 1)
    for label, result in results.items():
        plt.plot(diag[:, 0], result[idx], line_styles[label], linewidth=3, label=label)
    plt.xlabel('x (mm)', fontsize=15)
    plt.ylabel(title, fontsize=15)
    plt.xlim((0, 10))
    plt.legend(fontsize=11)

plt.tight_layout()
plt.show()
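The comparison above is visual only. A common way to summarize agreement with the FEM reference in a single number is the relative $L^2$ error, $\lVert \hat{y} - y \rVert_2 / \lVert y \rVert_2$. The original notebook does not compute this metric. The helper below is a minimal sketch, checked on made-up numbers.

```python
import numpy as np

def relative_l2(pred, ref):
    """Relative L2 error ||pred - ref|| / ||ref|| over all sampled points."""
    pred = np.asarray(pred, dtype=float).ravel()
    ref = np.asarray(ref, dtype=float).ravel()
    return np.linalg.norm(pred - ref) / np.linalg.norm(ref)

# Toy check with made-up numbers
print(relative_l2([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))   # 0.0
print(relative_l2([0.0, 0.0], [3.0, 4.0]))             # 1.0
```

In the notebook, `relative_l2(results['PI-DeepONet'][k], results['FEM'][k])` for `k = 0, 1, 2, 3` would give one number per panel. This assumes the FEM points selected by `diag_` are the same 101 points, in the same order, as `diag`, which is the same alignment the plot above relies on.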

3. Lab 3: Flow Around a Cylinder

3.1. Problem Setup

  • We will solve the steady Navier-Stokes equations with a varying viscosity using PI-DeepONet
  • Data and Physics-informed (PDE and BC)



  • Problem properties

$$\rho = 1\operatorname{kg/m^3}, \quad \mu \in [0.002, 0.2]\operatorname{N\cdot s/m^2}, \quad D = 2h = 1\operatorname{m}, \quad L = 2\operatorname{m}, \quad u_{in} = 1\operatorname{m/s}, \quad \nu = \frac{\mu}{\rho}$$
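The Reynolds number, $Re = \rho\, u_{in} D / \mu$, gives a rough guide to the flow regimes covered. It can be tabulated for the viscosity values listed later in `vis_list` (3.2.4). Using the channel height $D = 1$ m as the length scale is an assumption in this sketch. Based on the cylinder diameter (the collocation points use `radius = 0.05`), every value would be ten times smaller.

```python
import numpy as np

# Problem properties from above; D is taken as the channel height
rho, u_in, D = 1.0, 1.0, 1.0

# Viscosities used in this lab (copied from vis_list in 3.2.4)
vis_list = np.array([0.002, 0.004, 0.006, 0.008, 0.01, 0.02,
                     0.04, 0.06, 0.08, 0.1, 0.2])

Re = rho * u_in * D / vis_list   # one Reynolds number per viscosity

for mu, re in zip(vis_list, Re):
    print(f"mu = {mu:<6} -> Re = {re:6.1f}")
```

With $\rho = u_{in} = D = 1$, this reduces to $Re = 1/\mu$. The data therefore span roughly two orders of magnitude, from $Re = 5$ to $Re = 500$.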


  • 2D Navier-Stokes Equations & boundary conditions (for steady state)

$$ \begin{align*} \rho \left(u{\partial u \over \partial x} + v{\partial u \over \partial y} + {1 \over \rho}{\partial p \over \partial x}\right) - \mu \ \left({\partial^2 u \over {\partial x^2}} + {\partial^2 u \over {\partial y^2}}\right) &= 0\\\\ \rho \left(u{\partial v \over \partial x} + v{\partial v \over \partial y} + {1 \over \rho}{\partial p \over \partial y}\right) - \mu \ \left({\partial^2 v \over {\partial x^2}} + {\partial^2 v \over {\partial y^2}}\right) &= 0\\\\ {\partial u \over \partial x} + {\partial v \over \partial y} &= 0 \end{align*} $$


  • Two Dirichlet boundary conditions on the top and bottom walls (no-slip condition)

$$u(x,y) = 0, \quad v(x,y) = 0 \qquad \text{at} \quad y = \frac{D}{2} \ \; \text{or} \; -\frac{D}{2}$$


  • Two Dirichlet boundary conditions at the inlet boundary (uniform inflow)

$$u(-0.5,y) = u_{\text{in}}, \quad v(-0.5,y) = 0$$


  • Two Dirichlet boundary conditions at the outlet boundary

$$p(1.5,y) = 0, \quad v(1.5,y) = 0$$


  • Two Dirichlet boundary conditions at the cylinder boundary

$$u(\text{cylinder}) = 0, \quad v(\text{cylinder}) = 0$$


  • Build the neural networks and loss functions as shown below:

In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

3.2. Solve the Navier-Stokes Equations with PI-DeepONet

3.2.1. Load Libraries

In [ ]:
import numpy as np
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data

%matplotlib inline

import os
import warnings
warnings.filterwarnings("ignore")

3.2.2. CUDA

In [ ]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
cuda:0

3.2.3. Random Seed

In [ ]:
def random_seed(seed = 2021):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

3.2.4. Define Conditions and Load Data

In [ ]:
vis_list = np.array([0.002, 0.004, 0.006, 0.008, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.2]).reshape(-1, 1) # 11
vis_train = np.array([0.002, 0.004, 0.008, 0.01, 0.02, 0.06, 0.08, 0.1, 0.2]).reshape(-1, 1) # 9
vis_test = np.array([0.006, 0.04]).reshape(-1, 1) # 2


print("Number of training conditions: {}".format(len(vis_train)))
print("Number of test conditions:     {}".format(len(vis_test)))
Number of training conditions: 9
Number of test conditions:     2
In [ ]:
train_xy_domain = np.load('/content/drive/MyDrive/tutorials/기계인공지능연구회/본부학술대회_241106/DeepONet/data/viscosity/viscosity_train_xy_domain.npy')
train_gt_domain = np.load('/content/drive/MyDrive/tutorials/기계인공지능연구회/본부학술대회_241106/DeepONet/data/viscosity/viscosity_train_gt_domain.npy')
test_xy_domain = np.load('/content/drive/MyDrive/tutorials/기계인공지능연구회/본부학술대회_241106/DeepONet/data/viscosity/viscosity_test_xy_domain.npy')
test_gt_domain = np.load('/content/drive/MyDrive/tutorials/기계인공지능연구회/본부학술대회_241106/DeepONet/data/viscosity/viscosity_test_gt_domain.npy')
In [ ]:
print("train XY shape: {}".format(train_xy_domain.shape))
print("train gt shape: {}".format(train_gt_domain.shape))
print("test XY shape:  {}".format(test_xy_domain.shape))
print("test gt shape:  {}".format(test_gt_domain.shape))
train XY shape: (9, 19692, 2)
train gt shape: (9, 19692, 3)
test XY shape:  (2, 19692, 2)
test gt shape:  (2, 19692, 3)

3.2.5. Generate Collocation Points

Domain

In [ ]:
domain_points = train_xy_domain[0]

Boundary Condition

In [ ]:
bc_top_x = np.linspace(-0.5, 1.5, 200).reshape(-1, 1)
bc_top_y = 0.5 * np.ones_like(bc_top_x).reshape(-1, 1)

bc_bottom_x = np.linspace(-0.5, 1.5, 200).reshape(-1, 1)
bc_bottom_y = -0.5 * np.ones_like(bc_bottom_x).reshape(-1, 1)

bc_inlet_y = np.linspace(-0.5, 0.5, 100).reshape(-1, 1)
bc_inlet_x = -0.5 * np.ones_like(bc_inlet_y).reshape(-1, 1)

bc_outlet_y = np.linspace(-0.5, 0.5, 100).reshape(-1, 1)
bc_outlet_x = 1.5 * np.ones_like(bc_outlet_y).reshape(-1, 1)

radius = 0.05
theta = np.linspace(0, 2 * np.pi, 200)
bc_cylinder_x = (0 + radius * np.cos(theta)).reshape(-1, 1)
bc_cylinder_y = (0 + radius * np.sin(theta)).reshape(-1, 1)

bc_top_points = np.hstack((bc_top_x, bc_top_y))
bc_bottom_points = np.hstack((bc_bottom_x, bc_bottom_y))
bc_inlet_points = np.hstack((bc_inlet_x, bc_inlet_y))
bc_outlet_points = np.hstack((bc_outlet_x, bc_outlet_y))
bc_cylinder_points = np.hstack((bc_cylinder_x, bc_cylinder_y))

bc_wall_points = np.concatenate((bc_top_points, bc_bottom_points, bc_cylinder_points), 0)

PLOT

In [ ]:
plt.scatter(domain_points[:, 0], domain_points[:, 1], s = 1)
plt.scatter(bc_wall_points[:, 0], bc_wall_points[:, 1], s = 5)
plt.scatter(bc_inlet_points[:, 0], bc_inlet_points[:, 1], s = 5)
plt.scatter(bc_outlet_points[:, 0], bc_outlet_points[:, 1], s = 5)

plt.title('Collocation Points')
plt.axis('scaled')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

3.2.6. Generate Dataset

  • Build the dataset by pairing every training condition with every collocation point and stacking the results:

Domain

In [ ]:
# Domain dataset
num_domain = domain_points.shape[0]      # (19692, _)
num_condition = vis_train.shape[0]  #(9, _)

condition_list = []
XY_domain_list = []
gt_list = []

for i in range(num_condition):
    con = vis_train[i]
    conditions = np.vstack([con] * num_domain)
    domains = domain_points
    gts = train_gt_domain[i]

    condition_list.append(conditions)
    XY_domain_list.append(domains)
    gt_list.append(gts)

conditions_domain = np.vstack(condition_list)
XYs_domain = np.vstack(XY_domain_list)
gts_domain = np.vstack(gt_list)

print('Conditions Domain: {}'.format(conditions_domain.shape))
print('XYs Domain:        {}'.format(XYs_domain.shape))
print('GTs Domain:        {}'.format(gts_domain.shape))
Conditions Domain: (177228, 1)
XYs Domain:        (177228, 2)
GTs Domain:        (177228, 3)
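The loop above repeats each viscosity once per collocation point and tiles the point cloud once per condition. The same condition-major layout can be produced without a loop using `np.repeat` and `np.tile`. By the same reasoning, `train_gt_domain.reshape(-1, 3)` matches the stacked ground truth. The sketch below checks this on a small hypothetical example; `stack_conditions` is an illustrative helper, not part of the original notebook.

```python
import numpy as np

def stack_conditions(conditions, points):
    """Pair every condition with every collocation point (condition-major order).

    conditions: (n_cond, d_c), points: (n_pts, 2)
    returns arrays of shape (n_cond * n_pts, d_c) and (n_cond * n_pts, 2)
    """
    n_cond, n_pts = conditions.shape[0], points.shape[0]
    c = np.repeat(conditions, n_pts, axis=0)   # each condition repeated n_pts times
    xy = np.tile(points, (n_cond, 1))           # the point cloud repeated n_cond times
    return c, xy

# Small hypothetical example: 3 viscosities, 4 points
vis = np.array([[0.002], [0.004], [0.008]])
pts = np.array([[0.0, 0.0], [0.5, 0.1], [1.0, -0.2], [1.5, 0.3]])
c, xy = stack_conditions(vis, pts)
print(c.shape, xy.shape)   # (12, 1) (12, 2)
```

Applied to `vis_train` and `domain_points`, the helper would give the same `(177228, 1)` and `(177228, 2)` arrays as the loop.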

Boundary Condition

BC Inlet

In [ ]:
num_BC = bc_inlet_points.shape[0]    # (100,_)
num_condition = vis_train.shape[0]  # (9,_)

condition_list = []
bc_list = []

for i in range(num_condition):
    con = vis_train[i]
    conditions = np.vstack([con] * num_BC)
    XYs_bc_inlet = bc_inlet_points

    condition_list.append(conditions)
    bc_list.append(XYs_bc_inlet)

conditions_bc_inlet = np.vstack(condition_list)
XYs_bc_inlet = np.vstack(bc_list)

print('Conditions BC: {}'.format(conditions_bc_inlet.shape))
print('XYs BC:        {}'.format(XYs_bc_inlet.shape))
Conditions BC: (900, 1)
XYs BC:        (900, 2)

BC Wall

In [ ]:
num_BC = bc_wall_points.shape[0]    # (600,_): top + bottom + cylinder, 200 points each
num_condition = vis_train.shape[0]  # (9,_)

condition_list = []
bc_list = []

for i in range(num_condition):
    con = vis_train[i]
    conditions = np.vstack([con] * num_BC)
    XYs_bc_wall = bc_wall_points

    condition_list.append(conditions)
    bc_list.append(XYs_bc_wall)

conditions_bc_wall = np.vstack(condition_list)
XYs_bc_wall = np.vstack(bc_list)

print('Conditions BC: {}'.format(conditions_bc_wall.shape))
print('XYs BC:        {}'.format(XYs_bc_wall.shape))
Conditions BC: (5400, 1)
XYs BC:        (5400, 2)

BC Outlet

In [ ]:
num_BC = bc_outlet_points.shape[0]    # (100,_)
num_condition = vis_train.shape[0]  # (9,_)

condition_list = []
bc_list = []

for i in range(num_condition):
    con = vis_train[i]
    conditions = np.vstack([con] * num_BC)
    XYs_bc_outlet = bc_outlet_points

    condition_list.append(conditions)
    bc_list.append(XYs_bc_outlet)

conditions_bc_outlet = np.vstack(condition_list)
XYs_bc_outlet = np.vstack(bc_list)

print('Conditions BC: {}'.format(conditions_bc_outlet.shape))
print('XYs BC:        {}'.format(XYs_bc_outlet.shape))
Conditions BC: (900, 1)
XYs BC:        (900, 2)

3.2.7. Data Generator

In [ ]:
# Data generator
class DataGenerator(data.Dataset):
    def __init__(self, c, XY, gt, batch_size=64, rng_seed=1234):
        'Initialization'
        self.c = torch.tensor(c, dtype = torch.float32)
        self.XY = torch.tensor(XY, dtype = torch.float32)
        self.gt = torch.tensor(gt, dtype = torch.float32)
        self.N = c.shape[0]
        self.batch_size = batch_size
        self.seed = rng_seed

    def __len__(self):
        return (self.N // self.batch_size) + 1

    def __getitem__(self, index):
        c, XY, gt = self.__data_generation(index)
        return c, XY, gt

    def __data_generation(self, index):
        torch.manual_seed(self.seed + index)

        idx = torch.randperm(self.N)[:self.batch_size]

        c = self.c[idx, :]
        XY = self.XY[idx, :]
        gt = self.gt[idx,:]

        return c, XY, gt


# BC generator
class BCGenerator(data.Dataset):
    def __init__(self, c, XY, batch_size=64, rng_seed=1234):
        'Initialization'
        self.c = torch.tensor(c, dtype = torch.float32)
        self.XY = torch.tensor(XY, dtype = torch.float32)
        self.N = c.shape[0]
        self.batch_size = batch_size
        self.seed = rng_seed

    def __len__(self):
        return (self.N // self.batch_size) + 1

    def __getitem__(self, index):
        c, XY= self.__data_generation(index)
        return c, XY

    def __data_generation(self, index):
        torch.manual_seed(self.seed + index)

        idx = torch.randperm(self.N)[:self.batch_size]

        c = self.c[idx, :]
        XY = self.XY[idx, :]

        return c, XY
In [ ]:
batch_size = 5000
train_domain = DataGenerator(conditions_domain, XYs_domain, gts_domain, batch_size)
train_inlet = BCGenerator(conditions_bc_inlet, XYs_bc_inlet, batch_size)
train_wall = BCGenerator(conditions_bc_wall, XYs_bc_wall, batch_size)
train_outlet = BCGenerator(conditions_bc_outlet, XYs_bc_outlet, batch_size)
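Each call to `__getitem__` draws a fresh random subset without replacement, seeded with `seed + index`, so the batches are reproducible. Because `batch_size = 5000` exceeds the 900 inlet/outlet points, every boundary batch simply contains all boundary points in a random order. Note that `torch.manual_seed(self.seed + index)` reseeds the global PyTorch RNG on every batch. The NumPy sketch below follows the same sampling scheme with a local generator (`np.random.default_rng`), so no global state is touched. It will not reproduce the exact indices of `torch.randperm`. `minibatch_indices` is a hypothetical helper name.

```python
import numpy as np

def minibatch_indices(n, batch_size, seed=1234):
    """Yield index batches like DataGenerator / BCGenerator:
    batch k is a permutation of range(n) seeded with seed + k, truncated to batch_size.
    A local NumPy Generator is used instead of reseeding a global RNG."""
    k = 0
    while True:
        rng = np.random.default_rng(seed + k)
        yield rng.permutation(n)[:batch_size]
        k += 1

# Domain set: 9 conditions x 19692 points; inlet set: 9 conditions x 100 points
domain_batches = minibatch_indices(177228, 5000)
inlet_batches = minibatch_indices(900, 5000)

first_domain = next(domain_batches)
first_inlet = next(inlet_batches)
print(first_domain.shape, first_inlet.shape)   # (5000,) (900,)
```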

3.2.8. Define Networks and Hyperparameters

In [ ]:
class MLP(nn.Module):
    def __init__(self, input_dim):
        super().__init__()

        self.linears = nn.Sequential(
            nn.Linear(input_dim, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 100)
        )

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        a = x.float()
        return self.linears(a)
In [ ]:
random_seed()
# Initialize model
dim_vis = 1

branch_net_u = MLP(input_dim = dim_vis).to(device)
branch_net_v = MLP(input_dim = dim_vis).to(device)
branch_net_p = MLP(input_dim = dim_vis).to(device)
trunk_net_u = MLP(input_dim = 2).to(device)
trunk_net_v = MLP(input_dim = 2).to(device)
trunk_net_p = MLP(input_dim = 2).to(device)

optimizer = optim.Adam(
    list(branch_net_u.parameters()) + list(branch_net_v.parameters()) + list(branch_net_p.parameters()) +
    list(trunk_net_u.parameters()) + list(trunk_net_v.parameters()) + list(trunk_net_p.parameters()), lr=5e-4)
Loss_func = nn.MSELoss(reduction='mean')
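Each `MLP` above is one `Linear(input_dim, 100)` layer followed by six `Linear(100, 100)` layers, all with biases. Counting parameters by hand gives a feel for the model size. The sketch below is a plain arithmetic check; the helper name `mlp_param_count` is hypothetical. In the notebook, the same number for one network can be confirmed with `sum(p.numel() for p in branch_net_u.parameters())`.

```python
def mlp_param_count(input_dim, width=100, n_hidden_linear=6):
    """Parameters of the MLP above: Linear(input_dim, width) followed by
    n_hidden_linear Linear(width, width) layers, all with biases."""
    first = input_dim * width + width
    rest = n_hidden_linear * (width * width + width)
    return first + rest

branch = mlp_param_count(1)        # dim_vis = 1 (viscosity)
trunk = mlp_param_count(2)         # (x, y) coordinates
total = 3 * branch + 3 * trunk     # u, v and p each get one branch and one trunk net
print(branch, trunk, total)        # 60800 60900 365100
```

Almost all parameters sit in the six hidden `100 x 100` layers. Changing the input dimension of the branch or trunk barely affects the total.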

3.2.9. Define Utility Functions

In [ ]:
def PIDeepOnet(model, con_, XY_):
    '''
    model = [branch_net_u, branch_net_v,branch_net_p, trunk_net_u, trunk_net_v, trunk_net_p]
    '''

    branch_net_u = model[0]
    branch_net_v = model[1]
    branch_net_p = model[2]
    trunk_net_u = model[3]
    trunk_net_v = model[4]
    trunk_net_p = model[5]

    B1 = branch_net_u(con_)
    B2 = branch_net_v(con_)
    B3 = branch_net_p(con_)

    T1 = trunk_net_u(XY_)
    T2 = trunk_net_v(XY_)
    T3 = trunk_net_p(XY_)

    u = torch.sum(B1 * T1, dim=1).unsqueeze(1)
    v = torch.sum(B2 * T2, dim=1).unsqueeze(1)
    p = torch.sum(B3 * T3, dim=1).unsqueeze(1)

    return u, v, p

def derivative(y, t) :
    df = torch.autograd.grad(y, t, create_graph = True, retain_graph = True, grad_outputs = torch.ones(y.size()).to(device))[0]
    df_x = df[:, 0:1]
    df_y = df[:, 1:2]
    return df_x, df_y

def requires_grad(x):
    return torch.tensor(x, requires_grad = True).to(device)
In [ ]:
'PLOT'
from scipy.interpolate import griddata
import scipy.interpolate
import matplotlib.patches as patches
from IPython.display import clear_output

def PLOT(branch_net_u, branch_net_v, branch_net_p, trunk_net_u, trunk_net_v, trunk_net_p):
    branch_net_u.eval()
    branch_net_v.eval()
    branch_net_p.eval()
    trunk_net_u.eval()
    trunk_net_v.eval()
    trunk_net_p.eval()

    def interpolate_output(x, y, us,extent):
        "Interpolates irregular points onto a mesh"

        # define mesh to interpolate onto
        xyi = np.meshgrid(
        np.linspace(extent[0], extent[1], 100),
        np.linspace(extent[2], extent[3], 100),
            indexing="ij"
        )

        # linearly interpolate points onto mesh
        us = [scipy.interpolate.griddata(
            (x, y), u, tuple(xyi)
            )
            for u in us]

        return us

    with torch.no_grad():
        count = 0
        plt.figure(figsize=(14, 4))

        for i, gt in enumerate(gts):
            # Plot only the two held-out test viscosities: vis_list[2] = 0.006 and vis_list[6] = 0.04
            if i != 2 and i != 6:
                continue

            XY = XY_[i]

            x = XY[:,0]
            y = XY[:,1]
            u = gt[:,0]
            v = gt[:,1]
            # p = gt[:,2]
            gt_vel = np.sqrt(u**2 + v**2)

            num_col = len(gt)
            c_tmp = np.full((num_col, 1), vis_list[i])
            c_test = torch.tensor(c_tmp).to(device)
            XY_test = torch.tensor(XY).to(device)

            B_u = branch_net_u(c_test)
            B_v = branch_net_v(c_test)
            T_u = trunk_net_u(XY_test)
            T_v = trunk_net_v(XY_test)

            u_pred = torch.sum(B_u * T_u, dim=1)
            v_pred = torch.sum(B_v * T_v, dim=1)
            vel_pred = np.sqrt(u_pred.detach().cpu().numpy()**2 + v_pred.detach().cpu().numpy()**2)

            gt_inter, vel_inter, e_inter = interpolate_output(x, y,
                                                             [gt_vel, vel_pred, np.abs(gt_vel - vel_pred)],
                                                             [-0.5, 1.5, -0.5, 0.5])

            plt.subplot(2, 3, (count * 3) +1)
            if count == 0:
                plt.title('Ground Truth', fontsize=15)
            plt.imshow(gt_inter.T, origin = 'lower', extent=[-0.5, 1.5, -0.5, 0.5], vmin = np.min(gt_vel), vmax = np.max(gt_vel), cmap = 'jet')
            shp = patches.Circle((0, 0), radius=0.04, color='white')
            plt.gca().add_patch(shp)
            plt.colorbar()
            plt.axis('scaled')
            plt.ylabel('Vis: {}'.format(vis_list[i][0]), fontsize=13)

            plt.subplot(2, 3, (count * 3) + 2)
            if count == 0:
                plt.title('Prediction', fontsize=15)
            plt.imshow(vel_inter.T, origin = 'lower', extent = [-0.5, 1.5, -0.5, 0.5], vmin = np.min(gt_vel), vmax = np.max(gt_vel), cmap = 'jet')
            shp = patches.Circle((0, 0), radius=0.04, color='white')
            plt.gca().add_patch(shp)
            plt.colorbar()
            plt.axis('scaled')

            plt.subplot(2, 3, (count * 3) + 3)
            if count == 0:
                plt.title('Error', fontsize=15)
            plt.imshow(e_inter.T, origin = 'lower', extent = [-0.5, 1.5, -0.5, 0.5], vmin = 0, vmax = 0.3, cmap = 'jet')
            shp = patches.Circle((0, 0), radius = 0.04, color = 'white')
            plt.gca().add_patch(shp)
            plt.colorbar()
            plt.axis('scaled')

            count += 1
        plt.tight_layout()
        plt.show()
        clear_output(wait=True)

XY_ = np.concatenate((train_xy_domain[0:2], test_xy_domain[0:1], train_xy_domain[2:5], test_xy_domain[1:2], train_xy_domain[5:]), 0)
gts = np.concatenate((train_gt_domain[0:2], test_gt_domain[0:1], train_gt_domain[2:5], test_gt_domain[1:2], train_gt_domain[5:]), 0)

3.2.10. Define PDE with Boundary Conditions

  • 2D Navier-Stokes Equations & boundary conditions (for steady state)

$$ \begin{align*} \rho \left(u{\partial u \over \partial x} + v{\partial u \over \partial y} + {1 \over \rho}{\partial p \over \partial x}\right) - \mu \ \left({\partial^2 u \over {\partial x^2}} + {\partial^2 u \over {\partial y^2}}\right) &= 0\\\\ \rho \left(u{\partial v \over \partial x} + v{\partial v \over \partial y} + {1 \over \rho}{\partial p \over \partial y}\right) - \mu \ \left({\partial^2 v \over {\partial x^2}} + {\partial^2 v \over {\partial y^2}}\right) &= 0\\\\ {\partial u \over \partial x} + {\partial v \over \partial y} &= 0 \end{align*} $$

PDE

In [ ]:
def PDE(model, c_domain, XY_domain):

    u, v, p = PIDeepOnet(model, c_domain, XY_domain)

    du_x, du_y = derivative(u, XY_domain)
    du_xx, _ = derivative(du_x, XY_domain)
    _, du_yy =  derivative(du_y, XY_domain)

    dv_x, dv_y = derivative(v, XY_domain)
    dv_xx, _ = derivative(dv_x, XY_domain)
    _, dv_yy = derivative(dv_y, XY_domain)

    dp_x, dp_y = derivative(p, XY_domain)

    vis = c_domain
    pde_u = 1 * (u * du_x + v * du_y) + dp_x - vis * (du_xx + du_yy)
    pde_v = 1 * (u * dv_x + v * dv_y) + dp_y - vis * (dv_xx + dv_yy)
    pde_cont = du_x + dv_y

    return pde_u.float(), pde_v.float(), pde_cont.float()
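As an independent check of the residual expressions in `PDE`, which hard-codes $\rho = 1$, the sketch below evaluates them with central finite differences for plane Poiseuille flow. For walls at $y = \pm 0.5$, this flow is $u = \frac{G}{2\mu}\left(\tfrac14 - y^2\right)$, $v = 0$, $p = -Gx$, an exact steady solution. All three residuals should therefore vanish. The flow ignores the cylinder, so it tests only the PDE terms, not this lab's boundary conditions. The pressure gradient `G` is an illustrative assumption.

```python
def ns_residual(u, v, p, x, y, mu, rho=1.0, h=1e-3):
    """Finite-difference version of pde_u, pde_v, pde_cont from PDE() (rho = 1 there)."""
    def d_x(f):
        return (f(x + h, y) - f(x - h, y)) / (2 * h)

    def d_y(f):
        return (f(x, y + h) - f(x, y - h)) / (2 * h)

    def d_xx(f):
        return (f(x + h, y) - 2 * f(x, y) + f(x - h, y)) / h**2

    def d_yy(f):
        return (f(x, y + h) - 2 * f(x, y) + f(x, y - h)) / h**2

    U, V = u(x, y), v(x, y)
    pde_u = rho * (U * d_x(u) + V * d_y(u)) + d_x(p) - mu * (d_xx(u) + d_yy(u))
    pde_v = rho * (U * d_x(v) + V * d_y(v)) + d_y(p) - mu * (d_xx(v) + d_yy(v))
    pde_cont = d_x(u) + d_y(v)
    return pde_u, pde_v, pde_cont

# Plane Poiseuille flow between the walls y = -0.5 and y = 0.5
mu = 0.02        # one of the training viscosities
G = 0.1          # -dp/dx, illustrative value
u = lambda x, y: G / (2 * mu) * (0.25 - y**2)
v = lambda x, y: 0.0
p = lambda x, y: -G * x

res = ns_residual(u, v, p, x=0.3, y=0.1, mu=mu)
print(res)   # all three entries close to zero
```

The pressure gradient $-G$ balances the viscous term $\mu\,\partial^2 u/\partial y^2 = -G$, and the convective terms vanish because $u$ does not vary in $x$ and $v = 0$.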
  • Two Dirichlet boundary conditions on the top and bottom walls (no-slip condition):

$$u(x,y) = 0, \quad v(x,y) = 0 \qquad \text{at} \quad y = \frac{D}{2} \ \; \text{or} \; -\frac{D}{2}$$


  • Two Dirichlet boundary conditions at the cylinder boundary

$$u(\text{cylinder}) = 0, \quad v(\text{cylinder}) = 0$$


  • Two Dirichlet boundary conditions at the inlet boundary

$$u(-0.5,y) = u_{\text{in}}, \quad v(-0.5,y) = 0$$


  • Two Dirichlet boundary conditions at the outlet boundary

$$p(1.5,y) = 0, \quad v(1.5,y) = 0$$


Boundary Conditions

In [ ]:
def BC_wall(model, con_wall, XY_wall):
    wall_u, wall_v, _ = PIDeepOnet(model, con_wall, XY_wall)
    return wall_u.float(), wall_v.float()

def BC_inlet(model, con_inlet, XY_inlet):
    inlet_u, inlet_v, _ = PIDeepOnet(model, con_inlet, XY_inlet)
    inlet_u = inlet_u - torch.ones_like(inlet_u).to(device)

    return inlet_u.float(), inlet_v.float()

def BC_outlet(model, con_outlet, XY_outlet):
    _, outlet_v, outlet_p = PIDeepOnet(model, con_outlet, XY_outlet)
    return outlet_v.float(), outlet_p.float()

3.2.11. Train

In [ ]:
epoch = 0
num_epochs = 150000
min_loss = np.inf
In [ ]:
train_domain_iter = iter(train_domain)
train_wall_iter = iter(train_wall)
train_inlet_iter = iter(train_inlet)
train_outlet_iter = iter(train_outlet)

while epoch < num_epochs + 1:

    train_domain_batch = next(train_domain_iter)
    train_wall_batch = next(train_wall_iter)
    train_inlet_batch = next(train_inlet_iter)
    train_outlet_batch = next(train_outlet_iter)

    con_domain, XY_domain, gt_domain = train_domain_batch
    con_wall, XY_wall = train_wall_batch
    con_inlet, XY_inlet = train_inlet_batch
    con_outlet, XY_outlet = train_outlet_batch

    ############### Requires grad #################
    'Domain'
    con_domain, XY_domain = requires_grad(con_domain), requires_grad(XY_domain)
    gt_domain = gt_domain.to(device)  # data targets only; no gradient needed
    'Boundary Conditions'
    con_wall, XY_wall = requires_grad(con_wall), requires_grad(XY_wall) # BC wall
    con_inlet, XY_inlet = requires_grad(con_inlet), requires_grad(XY_inlet) # BC inlet
    con_outlet, XY_outlet = requires_grad(con_outlet), requires_grad(XY_outlet) # BC outlet


    model = [branch_net_u, branch_net_v, branch_net_p, trunk_net_u, trunk_net_v, trunk_net_p]
    u, v, p = PIDeepOnet(model, con_domain, XY_domain)


    ################# Data loss ###################
    loss_data_u = Loss_func(u.float(), gt_domain[:, 0:1].float())
    loss_data_v = Loss_func(v.float(), gt_domain[:, 1:2].float())
    loss_data_p = Loss_func(p.float(), gt_domain[:, 2:3].float())

    loss_data = loss_data_u + loss_data_v + loss_data_p

    ################## PDE loss ###################
    PDE_u, PDE_v, PDE_cont = PDE(model, con_domain, XY_domain)

    loss_PDE_u = Loss_func(PDE_u, torch.zeros_like(PDE_u).to(device))
    loss_PDE_v = Loss_func(PDE_v, torch.zeros_like(PDE_v).to(device))
    loss_PDE_cont = Loss_func(PDE_cont, torch.zeros_like(PDE_cont).to(device))

    loss_pde = loss_PDE_u + loss_PDE_v + loss_PDE_cont

    ################## BC loss ####################
    wall_u, wall_v = BC_wall(model, con_wall, XY_wall)
    inlet_u, inlet_v = BC_inlet(model, con_inlet, XY_inlet)
    outlet_v, outlet_p = BC_outlet(model, con_outlet, XY_outlet)

    loss_BC_wall_u = Loss_func(wall_u, torch.zeros_like(wall_u).to(device))
    loss_BC_wall_v = Loss_func(wall_v, torch.zeros_like(wall_v).to(device))
    loss_BC_inlet_u = Loss_func(inlet_u, torch.zeros_like(inlet_u).to(device))
    loss_BC_inlet_v = Loss_func(inlet_v, torch.zeros_like(inlet_v).to(device))
    loss_BC_outlet_v = Loss_func(outlet_v, torch.zeros_like(outlet_v).to(device))
    loss_BC_outlet_p = Loss_func(outlet_p, torch.zeros_like(outlet_p).to(device))

    loss_BC_wall = loss_BC_wall_u + loss_BC_wall_v
    loss_BC_inlet = loss_BC_inlet_u + loss_BC_inlet_v
    loss_BC_outlet  =loss_BC_outlet_v + loss_BC_outlet_p
    loss_bc = loss_BC_wall + loss_BC_inlet + loss_BC_outlet

    loss = 10 * loss_data + 0.1 * loss_bc + 0.01 * loss_pde

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        with torch.no_grad():
            print('Epoch: {} Loss: {:.4f} DATALoss: {:.4f} PDELoss: {:.4f} BCLoss: {:.4f}'.format(epoch, loss.item(), 10 * loss_data.item(), 0.01 * loss_pde.item(), 0.1 * loss_bc.item()))
            PLOT(branch_net_u, branch_net_v, branch_net_p, trunk_net_u, trunk_net_v, trunk_net_p)
    epoch += 1
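Both training loops call `next()` 150,001 times (Lab 3) on iterators created with `iter(train_domain)` and similar calls, even though `len(train_domain)` is only 36 (177228 // 5000 + 1). This works because `DataGenerator` and `BCGenerator` define `__getitem__` but no `__iter__`. In that case `iter()` falls back to Python's sequence protocol, calling `__getitem__(0)`, `__getitem__(1)` and so on until an `IndexError` is raised. `__data_generation` never raises, so the iterator is endless and `__len__` is never consulted. The minimal class below is a hypothetical stand-in that shows the same behavior without PyTorch.

```python
class IndexOnlyDataset:
    """Mimics DataGenerator: __getitem__ ignores bounds and __iter__ is not defined."""

    def __len__(self):
        return 3                      # not used by iter()

    def __getitem__(self, index):
        return index * 10             # never raises IndexError, so iteration never stops

it = iter(IndexOnlyDataset())
batches = [next(it) for _ in range(5)]
print(batches)                        # [0, 10, 20, 30, 40]
```

Calling `list()` or a plain `for` loop on such an object would never terminate. Drawing a fixed number of batches with `next()` inside a `while epoch < num_epochs + 1` loop is the safe pattern here.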

3.3. Test Pretrained Model

In [ ]:
with torch.no_grad():
    weight_dir = '/content/drive/MyDrive/tutorials/기계인공지능연구회/본부학술대회_241106/DeepONet/weight/flow_around_a_cylinder/'
    # Whole nn.Module objects were saved, so PyTorch >= 2.6 needs weights_only=False;
    # map_location lets models saved on a GPU load on a CPU-only runtime as well
    branch1_test = torch.load(weight_dir + 'fac_branch_net_u.pt', map_location=device, weights_only=False)
    branch2_test = torch.load(weight_dir + 'fac_branch_net_v.pt', map_location=device, weights_only=False)
    branch3_test = torch.load(weight_dir + 'fac_branch_net_p.pt', map_location=device, weights_only=False)
    trunk1_test = torch.load(weight_dir + 'fac_trunk_net_u.pt', map_location=device, weights_only=False)
    trunk2_test = torch.load(weight_dir + 'fac_trunk_net_v.pt', map_location=device, weights_only=False)
    trunk3_test = torch.load(weight_dir + 'fac_trunk_net_p.pt', map_location=device, weights_only=False)
    PLOT(branch1_test, branch2_test, branch3_test, trunk1_test, trunk2_test, trunk3_test)

3.4. Interpolation Visualization

(GIF: interpolation visualization)
